Skip to content
Snippets Groups Projects
Commit 2ab418a8 authored by Martin Bauer's avatar Martin Bauer
Browse files

new_lbm: Adapting moment-based simplifications

parent 406599d6
No related merge requests found
...@@ -20,7 +20,7 @@ class EquationCollection: ...@@ -20,7 +20,7 @@ class EquationCollection:
# ----------------------------------------- Creation --------------------------------------------------------------- # ----------------------------------------- Creation ---------------------------------------------------------------
def __init__(self, equations, subExpressions, simplificationHints={}): def __init__(self, equations, subExpressions, simplificationHints={}, subexpressionSymbolNameGenerator=None):
self.mainEquations = equations self.mainEquations = equations
self.subexpressions = subExpressions self.subexpressions = subExpressions
self.simplificationHints = simplificationHints self.simplificationHints = simplificationHints
...@@ -35,7 +35,10 @@ class EquationCollection: ...@@ -35,7 +35,10 @@ class EquationCollection:
continue continue
yield newSymbol yield newSymbol
self.subexpressionSymbolNameGenerator = symbolGen() if subexpressionSymbolNameGenerator is None:
self.subexpressionSymbolNameGenerator = symbolGen()
else:
self.subexpressionSymbolNameGenerator = subexpressionSymbolNameGenerator
def newWithAdditionalSubexpressions(self, newEquations, additionalSubExpressions): def newWithAdditionalSubexpressions(self, newEquations, additionalSubExpressions):
""" """
...@@ -44,9 +47,11 @@ class EquationCollection: ...@@ -44,9 +47,11 @@ class EquationCollection:
Simplifications hints are copied over. Simplifications hints are copied over.
""" """
assert len(self.mainEquations) == len(newEquations), "Number of update equations cannot be changed" assert len(self.mainEquations) == len(newEquations), "Number of update equations cannot be changed"
return EquationCollection(newEquations, res = EquationCollection(newEquations,
self.subexpressions + additionalSubExpressions, self.subexpressions + additionalSubExpressions,
self.simplificationHints) self.simplificationHints)
res.subexpressionSymbolNameGenerator = self.subexpressionSymbolNameGenerator
return res
def newWithSubstitutionsApplied(self, substitutionDict): def newWithSubstitutionsApplied(self, substitutionDict):
""" """
...@@ -55,7 +60,9 @@ class EquationCollection: ...@@ -55,7 +60,9 @@ class EquationCollection:
""" """
newSubexpressions = [fastSubs(eq, substitutionDict) for eq in self.subexpressions] newSubexpressions = [fastSubs(eq, substitutionDict) for eq in self.subexpressions]
newEquations = [fastSubs(eq, substitutionDict) for eq in self.mainEquations] newEquations = [fastSubs(eq, substitutionDict) for eq in self.mainEquations]
return EquationCollection(newEquations, newSubexpressions, self.simplificationHints) res = EquationCollection(newEquations, newSubexpressions, self.simplificationHints)
res.subexpressionSymbolNameGenerator = self.subexpressionSymbolNameGenerator
return res
def addSimplificationHint(self, key, value): def addSimplificationHint(self, key, value):
""" """
...@@ -190,7 +197,7 @@ class EquationCollection: ...@@ -190,7 +197,7 @@ class EquationCollection:
queue.append(ds) queue.append(ds)
handledSymbols.add(ds) handledSymbols.add(ds)
for eq in self.mainEquations: for eq in self.allEquations:
if eq.lhs in symbolsToExtract: if eq.lhs in symbolsToExtract:
newEquations.append(eq) newEquations.append(eq)
addSymbolsFromExpr(eq.rhs) addSymbolsFromExpr(eq.rhs)
...@@ -202,7 +209,7 @@ class EquationCollection: ...@@ -202,7 +209,7 @@ class EquationCollection:
else: else:
addSymbolsFromExpr(subexprMap[e]) addSymbolsFromExpr(subexprMap[e])
newSubExpr = [eq for eq in self.subexpressions if eq.lhs in handledSymbols] newSubExpr = [eq for eq in self.subexpressions if eq.lhs in handledSymbols and eq.lhs not in symbolsToExtract]
return EquationCollection(newEquations, newSubExpr) return EquationCollection(newEquations, newSubExpr)
def newWithoutUnusedSubexpressions(self): def newWithoutUnusedSubexpressions(self):
......
...@@ -10,7 +10,8 @@ def sympyCSE(equationCollection): ...@@ -10,7 +10,8 @@ def sympyCSE(equationCollection):
with the additional subexpressions found with the additional subexpressions found
""" """
symbolGen = equationCollection.subexpressionSymbolNameGenerator symbolGen = equationCollection.subexpressionSymbolNameGenerator
replacements, newEq = sp.cse(equationCollection.subexpressions + equationCollection.mainEquations, symbols=symbolGen) replacements, newEq = sp.cse(equationCollection.subexpressions + equationCollection.mainEquations,
symbols=symbolGen)
replacementEqs = [sp.Eq(*r) for r in replacements] replacementEqs = [sp.Eq(*r) for r in replacements]
modifiedSubexpressions = newEq[:len(equationCollection.subexpressions)] modifiedSubexpressions = newEq[:len(equationCollection.subexpressions)]
...@@ -20,7 +21,8 @@ def sympyCSE(equationCollection): ...@@ -20,7 +21,8 @@ def sympyCSE(equationCollection):
topologicallySortedPairs = sp.cse_main.reps_toposort([[e.lhs, e.rhs] for e in newSubexpressions]) topologicallySortedPairs = sp.cse_main.reps_toposort([[e.lhs, e.rhs] for e in newSubexpressions])
newSubexpressions = [sp.Eq(a[0], a[1]) for a in topologicallySortedPairs] newSubexpressions = [sp.Eq(a[0], a[1]) for a in topologicallySortedPairs]
return EquationCollection(modifiedUpdateEquations, newSubexpressions, equationCollection.simplificationHints) return EquationCollection(modifiedUpdateEquations, newSubexpressions, equationCollection.simplificationHints,
equationCollection.subexpressionSymbolNameGenerator)
def applyOnAllEquations(equationCollection, operation): def applyOnAllEquations(equationCollection, operation):
...@@ -32,7 +34,8 @@ def applyOnAllEquations(equationCollection, operation): ...@@ -32,7 +34,8 @@ def applyOnAllEquations(equationCollection, operation):
def applyOnAllSubexpressions(equationCollection, operation): def applyOnAllSubexpressions(equationCollection, operation):
return EquationCollection(equationCollection.mainEquations, return EquationCollection(equationCollection.mainEquations,
[operation(s) for s in equationCollection.subexpressions], [operation(s) for s in equationCollection.subexpressions],
equationCollection.simplificationHints) equationCollection.simplificationHints,
equationCollection.subexpressionSymbolNameGenerator)
def subexpressionSubstitutionInExistingSubexpressions(equationCollection): def subexpressionSubstitutionInExistingSubexpressions(equationCollection):
...@@ -46,7 +49,8 @@ def subexpressionSubstitutionInExistingSubexpressions(equationCollection): ...@@ -46,7 +49,8 @@ def subexpressionSubstitutionInExistingSubexpressions(equationCollection):
newRhs = newRhs.subs(subExpr.rhs, subExpr.lhs) newRhs = newRhs.subs(subExpr.rhs, subExpr.lhs)
result.append(sp.Eq(s.lhs, newRhs)) result.append(sp.Eq(s.lhs, newRhs))
return EquationCollection(equationCollection.mainEquations, result, equationCollection.simplificationHints) return EquationCollection(equationCollection.mainEquations, result, equationCollection.simplificationHints,
equationCollection.subexpressionSymbolNameGenerator)
def subexpressionSubstitutionInUpdateEquations(equationCollection): def subexpressionSubstitutionInUpdateEquations(equationCollection):
...@@ -58,5 +62,3 @@ def subexpressionSubstitutionInUpdateEquations(equationCollection): ...@@ -58,5 +62,3 @@ def subexpressionSubstitutionInUpdateEquations(equationCollection):
newRhs = replaceAdditive(newRhs, subExpr.lhs, subExpr.rhs, requiredMatchReplacement=1.0) newRhs = replaceAdditive(newRhs, subExpr.lhs, subExpr.rhs, requiredMatchReplacement=1.0)
result.append(sp.Eq(s.lhs, newRhs)) result.append(sp.Eq(s.lhs, newRhs))
return equationCollection.newWithAdditionalSubexpressions(result, []) return equationCollection.newWithAdditionalSubexpressions(result, [])
...@@ -330,3 +330,4 @@ def matrixFromColumnVectors(columnVectors): ...@@ -330,3 +330,4 @@ def matrixFromColumnVectors(columnVectors):
def commonDenominator(expr): def commonDenominator(expr):
denominators = [r.q for r in expr.atoms(sp.Rational)] denominators = [r.q for r in expr.atoms(sp.Rational)]
return sp.lcm(denominators) return sp.lcm(denominators)
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment