Skip to content
Snippets Groups Projects
Commit bed12f75 authored by Martin Bauer's avatar Martin Bauer
Browse files

pystencils: generalized equationcollection

parent 69ec4168
Branches
Tags
No related merge requests found
import sympy as sp
from copy import copy, deepcopy
from pystencils.sympyextensions import fastSubs, countNumberOfOperations
......@@ -20,52 +21,50 @@ class EquationCollection:
# ----------------------------------------- Creation ---------------------------------------------------------------
def __init__(self, equations, subExpressions, simplificationHints={}, subexpressionSymbolNameGenerator=None):
def __init__(self, equations, subExpressions, simplificationHints=None, subexpressionSymbolNameGenerator=None):
self.mainEquations = equations
self.subexpressions = subExpressions
if simplificationHints is None:
simplificationHints = {}
self.simplificationHints = simplificationHints
def symbolGen():
"""Use this generator to create new unused symbols for subexpressions"""
counter = 0
while True:
counter += 1
newSymbol = sp.Symbol("xi_" + str(counter))
if newSymbol in self.boundSymbols:
continue
yield newSymbol
class SymbolGen:
def __init__(self):
self._ctr = 0
def __iter__(self):
return self
def __next__(self):
self._ctr += 1
return sp.Symbol("xi_" + str(self._ctr))
if subexpressionSymbolNameGenerator is None:
self.subexpressionSymbolNameGenerator = symbolGen()
self.subexpressionSymbolNameGenerator = SymbolGen()
else:
self.subexpressionSymbolNameGenerator = subexpressionSymbolNameGenerator
def newWithAdditionalSubexpressions(self, newEquations, additionalSubExpressions):
"""
Returns a new equation collection, that has `newEquations` as mainEquations.
The `additionalSubExpressions` are appended to the existing subexpressions.
Simplifications hints are copied over.
"""
assert len(self.mainEquations) == len(newEquations), "Number of update equations cannot be changed"
res = EquationCollection(newEquations,
self.subexpressions + additionalSubExpressions,
self.simplificationHints)
res.subexpressionSymbolNameGenerator = self.subexpressionSymbolNameGenerator
def copy(self, mainEquations=None, subexpressions=None):
res = deepcopy(self)
if mainEquations is not None:
res.mainEquations = mainEquations
if subexpressions is not None:
res.subexpressions = subexpressions
return res
def newWithSubstitutionsApplied(self, substitutionDict, addSubstitutionsAsSubexpresions=False):
def copyWithSubstitutionsApplied(self, substitutionDict, addSubstitutionsAsSubexpressions=False):
"""
Returns a new equation collection, where terms are substituted according to the passed `substitutionDict`.
Substitutions are made in the subexpression terms and the main equations
"""
newSubexpressions = [fastSubs(eq, substitutionDict) for eq in self.subexpressions]
newEquations = [fastSubs(eq, substitutionDict) for eq in self.mainEquations]
if addSubstitutionsAsSubexpresions:
if addSubstitutionsAsSubexpressions:
newSubexpressions = [sp.Eq(b, a) for a, b in substitutionDict.items()] + newSubexpressions
res = EquationCollection(newEquations, newSubexpressions, self.simplificationHints)
res.subexpressionSymbolNameGenerator = self.subexpressionSymbolNameGenerator
return res
return self.copy(newEquations, newSubexpressions)
def addSimplificationHint(self, key, value):
"""
......@@ -178,41 +177,45 @@ class EquationCollection:
substitutionDict[otherSubexpressionEq.lhs] = newLhs
else:
processedOtherSubexpressionEquations.append(fastSubs(otherSubexpressionEq, substitutionDict))
return EquationCollection(self.mainEquations + other.mainEquations,
self.subexpressions + processedOtherSubexpressionEquations)
return self.copy(self.mainEquations + other.mainEquations,
self.subexpressions + processedOtherSubexpressionEquations)
def extract(self, symbolsToExtract):
"""
Creates a new equation collection with equations that have symbolsToExtract as left-hand-sides and
only the necessary subexpressions that are used in these equations
"""
symbolsToExtract = set(symbolsToExtract)
newEquations = []
def getDependentSymbols(self, symbolSequence):
"""Returns a list of symbols that depend on the passed symbols."""
subexprMap = {e.lhs: e.rhs for e in self.subexpressions}
handledSymbols = set()
queue = []
queue = list(symbolSequence)
def addSymbolsFromExpr(expr):
dependentSymbols = expr.atoms(sp.Symbol)
for ds in dependentSymbols:
if ds not in handledSymbols:
queue.append(ds)
handledSymbols.add(ds)
queue.append(ds)
for eq in self.allEquations:
if eq.lhs in symbolsToExtract:
newEquations.append(eq)
addSymbolsFromExpr(eq.rhs)
handledSymbols = set()
eqMap = {e.lhs: e.rhs for e in self.allEquations}
while len(queue) > 0:
e = queue.pop(0)
if e not in subexprMap:
if e in handledSymbols:
continue
else:
addSymbolsFromExpr(subexprMap[e])
if e in eqMap:
addSymbolsFromExpr(eqMap[e])
handledSymbols.add(e)
return handledSymbols
def extract(self, symbolsToExtract):
"""
Creates a new equation collection with equations that have symbolsToExtract as left-hand-sides and
only the necessary subexpressions that are used in these equations
"""
symbolsToExtract = set(symbolsToExtract)
dependentSymbols = self.getDependentSymbols(symbolsToExtract)
newEquations = []
for eq in self.allEquations:
if eq.lhs in symbolsToExtract:
newEquations.append(eq)
newSubExpr = [eq for eq in self.subexpressions if eq.lhs in handledSymbols and eq.lhs not in symbolsToExtract]
newSubExpr = [eq for eq in self.subexpressions if eq.lhs in dependentSymbols and eq.lhs not in symbolsToExtract]
return EquationCollection(newEquations, newSubExpr)
def newWithoutUnusedSubexpressions(self):
......@@ -221,18 +224,30 @@ class EquationCollection:
allLhs = [eq.lhs for eq in self.mainEquations]
return self.extract(allLhs)
def insertSubexpressions(self):
def insertSubexpressions(self, subexpressionSymbolsToKeep=set()):
"""Returns a new equation collection by inserting all subexpressions into the main equations"""
if len(self.subexpressions) == 0:
return EquationCollection(self.mainEquations, self.subexpressions, self.simplificationHints)
subsDict = {self.subexpressions[0].lhs: self.subexpressions[0].rhs}
return self.copy()
subexpressionSymbolsToKeep = set(subexpressionSymbolsToKeep)
keptSubexpressions = []
if self.subexpressions[0].lhs in subexpressionSymbolsToKeep:
subsDict = {}
keptSubexpressions = self.subexpressions[0]
else:
subsDict = {self.subexpressions[0].lhs: self.subexpressions[0].rhs}
subExpr = [e for e in self.subexpressions]
for i in range(1, len(subExpr)):
subExpr[i] = fastSubs(subExpr[i], subsDict)
subsDict[subExpr[i].lhs] = subExpr[i].rhs
if subExpr[i].lhs in subexpressionSymbolsToKeep:
keptSubexpressions.append(subExpr[i])
else:
subsDict[subExpr[i].lhs] = subExpr[i].rhs
newEq = [fastSubs(eq, subsDict) for eq in self.mainEquations]
return EquationCollection(newEq, [], self.simplificationHints)
return self.copy(newEq, keptSubexpressions)
def lambdify(self, symbols, module=None, fixedSymbols={}):
"""
......@@ -241,7 +256,7 @@ class EquationCollection:
:param module: same as sympy.lambdify paramter of same same, i.e. which module to use e.g. 'numpy'
:param fixedSymbols: dictionary with substitutions, that are applied before lambdification
"""
eqs = self.newWithSubstitutionsApplied(fixedSymbols).insertSubexpressions().mainEquations
eqs = self.copyWithSubstitutionsApplied(fixedSymbols).insertSubexpressions().mainEquations
lambdas = {eq.lhs: sp.lambdify(symbols, eq.rhs, module) for eq in eqs}
def f(*args, **kwargs):
......
import sympy as sp
from pystencils.equationcollection import EquationCollection
from pystencils.sympyextensions import replaceAdditive
......@@ -21,21 +20,18 @@ def sympyCSE(equationCollection):
topologicallySortedPairs = sp.cse_main.reps_toposort([[e.lhs, e.rhs] for e in newSubexpressions])
newSubexpressions = [sp.Eq(a[0], a[1]) for a in topologicallySortedPairs]
return EquationCollection(modifiedUpdateEquations, newSubexpressions, equationCollection.simplificationHints,
equationCollection.subexpressionSymbolNameGenerator)
return equationCollection.copy(modifiedUpdateEquations, newSubexpressions)
def applyOnAllEquations(equationCollection, operation):
"""Applies sympy expand operation to all equations in collection"""
result = [operation(s) for s in equationCollection.mainEquations]
return equationCollection.newWithAdditionalSubexpressions(result, [])
return equationCollection.copy(result)
def applyOnAllSubexpressions(equationCollection, operation):
return EquationCollection(equationCollection.mainEquations,
[operation(s) for s in equationCollection.subexpressions],
equationCollection.simplificationHints,
equationCollection.subexpressionSymbolNameGenerator)
return equationCollection.copy(equationCollection.mainEquations,
[operation(s) for s in equationCollection.subexpressions])
def subexpressionSubstitutionInExistingSubexpressions(equationCollection):
......@@ -49,8 +45,7 @@ def subexpressionSubstitutionInExistingSubexpressions(equationCollection):
newRhs = newRhs.subs(subExpr.rhs, subExpr.lhs)
result.append(sp.Eq(s.lhs, newRhs))
return EquationCollection(equationCollection.mainEquations, result, equationCollection.simplificationHints,
equationCollection.subexpressionSymbolNameGenerator)
return equationCollection.copy(equationCollection.mainEquations, result)
def subexpressionSubstitutionInMainEquations(equationCollection):
......@@ -61,7 +56,7 @@ def subexpressionSubstitutionInMainEquations(equationCollection):
for subExpr in equationCollection.subexpressions:
newRhs = replaceAdditive(newRhs, subExpr.lhs, subExpr.rhs, requiredMatchReplacement=1.0)
result.append(sp.Eq(s.lhs, newRhs))
return equationCollection.newWithAdditionalSubexpressions(result, [])
return equationCollection.copy(result)
def addSubexpressionsForDivisions(equationCollection):
......@@ -80,4 +75,4 @@ def addSubexpressionsForDivisions(equationCollection):
newSymbolGen = equationCollection.subexpressionSymbolNameGenerator
substitutions = {divisor: newSymbol for newSymbol, divisor in zip(newSymbolGen, divisors)}
return equationCollection.newWithSubstitutionsApplied(substitutions, True)
return equationCollection.copyWithSubstitutionsApplied(substitutions, True)
......@@ -14,7 +14,11 @@ def fastSubs(term, subsDict):
return expr
paramList = [visit(a) for a in expr.args]
return expr if not paramList else expr.func(*paramList)
return visit(term)
if len(subsDict) == 0:
return term
else:
return visit(term)
def replaceAdditive(expr, replacement, subExpression, requiredMatchReplacement=0.5, requiredMatchOriginal=None):
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment