From 2ab418a8af11bb7d02fb645d05c1d69f22a71013 Mon Sep 17 00:00:00 2001 From: Martin Bauer <martin.bauer@fau.de> Date: Thu, 12 Jan 2017 14:14:40 +0100 Subject: [PATCH] new_lbm: Adapting moment-based simplifications --- equationcollection/equationcollection.py | 23 +++++++++++++++-------- equationcollection/simplifications.py | 14 ++++++++------ sympyextensions.py | 1 + 3 files changed, 24 insertions(+), 14 deletions(-) diff --git a/equationcollection/equationcollection.py b/equationcollection/equationcollection.py index 0f0c5e9b6..ede7d66fb 100644 --- a/equationcollection/equationcollection.py +++ b/equationcollection/equationcollection.py @@ -20,7 +20,7 @@ class EquationCollection: # ----------------------------------------- Creation --------------------------------------------------------------- - def __init__(self, equations, subExpressions, simplificationHints={}): + def __init__(self, equations, subExpressions, simplificationHints={}, subexpressionSymbolNameGenerator=None): self.mainEquations = equations self.subexpressions = subExpressions self.simplificationHints = simplificationHints @@ -35,7 +35,10 @@ class EquationCollection: continue yield newSymbol - self.subexpressionSymbolNameGenerator = symbolGen() + if subexpressionSymbolNameGenerator is None: + self.subexpressionSymbolNameGenerator = symbolGen() + else: + self.subexpressionSymbolNameGenerator = subexpressionSymbolNameGenerator def newWithAdditionalSubexpressions(self, newEquations, additionalSubExpressions): """ @@ -44,9 +47,11 @@ class EquationCollection: Simplifications hints are copied over. """ assert len(self.mainEquations) == len(newEquations), "Number of update equations cannot be changed" - return EquationCollection(newEquations, - self.subexpressions + additionalSubExpressions, - self.simplificationHints) + res = EquationCollection(newEquations, + self.subexpressions + additionalSubExpressions, + self.simplificationHints) + res.subexpressionSymbolNameGenerator = self.subexpressionSymbolNameGenerator + return res def newWithSubstitutionsApplied(self, substitutionDict): """ @@ -55,7 +60,9 @@ class EquationCollection: """ newSubexpressions = [fastSubs(eq, substitutionDict) for eq in self.subexpressions] newEquations = [fastSubs(eq, substitutionDict) for eq in self.mainEquations] - return EquationCollection(newEquations, newSubexpressions, self.simplificationHints) + res = EquationCollection(newEquations, newSubexpressions, self.simplificationHints) + res.subexpressionSymbolNameGenerator = self.subexpressionSymbolNameGenerator + return res def addSimplificationHint(self, key, value): """ @@ -190,7 +197,7 @@ class EquationCollection: queue.append(ds) handledSymbols.add(ds) - for eq in self.mainEquations: + for eq in self.allEquations: if eq.lhs in symbolsToExtract: newEquations.append(eq) addSymbolsFromExpr(eq.rhs) @@ -202,7 +209,7 @@ class EquationCollection: else: addSymbolsFromExpr(subexprMap[e]) - newSubExpr = [eq for eq in self.subexpressions if eq.lhs in handledSymbols] + newSubExpr = [eq for eq in self.subexpressions if eq.lhs in handledSymbols and eq.lhs not in symbolsToExtract] return EquationCollection(newEquations, newSubExpr) def newWithoutUnusedSubexpressions(self): diff --git a/equationcollection/simplifications.py b/equationcollection/simplifications.py index 24a21fdf5..c8e93702b 100644 --- a/equationcollection/simplifications.py +++ b/equationcollection/simplifications.py @@ -10,7 +10,8 @@ def sympyCSE(equationCollection): with the additional subexpressions found """ symbolGen = equationCollection.subexpressionSymbolNameGenerator - replacements, newEq = sp.cse(equationCollection.subexpressions + equationCollection.mainEquations, symbols=symbolGen) + replacements, newEq = sp.cse(equationCollection.subexpressions + equationCollection.mainEquations, + symbols=symbolGen) replacementEqs = [sp.Eq(*r) for r in replacements] modifiedSubexpressions = newEq[:len(equationCollection.subexpressions)] @@ -20,7 +21,8 @@ def sympyCSE(equationCollection): topologicallySortedPairs = sp.cse_main.reps_toposort([[e.lhs, e.rhs] for e in newSubexpressions]) newSubexpressions = [sp.Eq(a[0], a[1]) for a in topologicallySortedPairs] - return EquationCollection(modifiedUpdateEquations, newSubexpressions, equationCollection.simplificationHints) + return EquationCollection(modifiedUpdateEquations, newSubexpressions, equationCollection.simplificationHints, + equationCollection.subexpressionSymbolNameGenerator) def applyOnAllEquations(equationCollection, operation): @@ -32,7 +34,8 @@ def applyOnAllEquations(equationCollection, operation): def applyOnAllSubexpressions(equationCollection, operation): return EquationCollection(equationCollection.mainEquations, [operation(s) for s in equationCollection.subexpressions], - equationCollection.simplificationHints) + equationCollection.simplificationHints, + equationCollection.subexpressionSymbolNameGenerator) def subexpressionSubstitutionInExistingSubexpressions(equationCollection): @@ -46,7 +49,8 @@ def subexpressionSubstitutionInExistingSubexpressions(equationCollection): newRhs = newRhs.subs(subExpr.rhs, subExpr.lhs) result.append(sp.Eq(s.lhs, newRhs)) - return EquationCollection(equationCollection.mainEquations, result, equationCollection.simplificationHints) + return EquationCollection(equationCollection.mainEquations, result, equationCollection.simplificationHints, + equationCollection.subexpressionSymbolNameGenerator) def subexpressionSubstitutionInUpdateEquations(equationCollection): @@ -58,5 +62,3 @@ def subexpressionSubstitutionInUpdateEquations(equationCollection): newRhs = replaceAdditive(newRhs, subExpr.lhs, subExpr.rhs, requiredMatchReplacement=1.0) result.append(sp.Eq(s.lhs, newRhs)) return equationCollection.newWithAdditionalSubexpressions(result, []) - - diff --git a/sympyextensions.py b/sympyextensions.py index 261438c64..e02b58be9 100644 --- a/sympyextensions.py +++ b/sympyextensions.py @@ -330,3 +330,4 @@ def matrixFromColumnVectors(columnVectors): def commonDenominator(expr): denominators = [r.q for r in expr.atoms(sp.Rational)] return sp.lcm(denominators) + -- GitLab