From 2ab418a8af11bb7d02fb645d05c1d69f22a71013 Mon Sep 17 00:00:00 2001
From: Martin Bauer <martin.bauer@fau.de>
Date: Thu, 12 Jan 2017 14:14:40 +0100
Subject: [PATCH] new_lbm: Adapting moment-based simplifications

---
 equationcollection/equationcollection.py | 23 +++++++++++++++--------
 equationcollection/simplifications.py    | 14 ++++++++------
 sympyextensions.py                       |  1 +
 3 files changed, 24 insertions(+), 14 deletions(-)

diff --git a/equationcollection/equationcollection.py b/equationcollection/equationcollection.py
index 0f0c5e9b6..ede7d66fb 100644
--- a/equationcollection/equationcollection.py
+++ b/equationcollection/equationcollection.py
@@ -20,7 +20,7 @@ class EquationCollection:
 
     # ----------------------------------------- Creation ---------------------------------------------------------------
 
-    def __init__(self, equations, subExpressions, simplificationHints={}):
+    def __init__(self, equations, subExpressions, simplificationHints={}, subexpressionSymbolNameGenerator=None):
         self.mainEquations = equations
         self.subexpressions = subExpressions
         self.simplificationHints = simplificationHints
@@ -35,7 +35,10 @@ class EquationCollection:
                     continue
                 yield newSymbol
 
-        self.subexpressionSymbolNameGenerator = symbolGen()
+        if subexpressionSymbolNameGenerator is None:
+            self.subexpressionSymbolNameGenerator = symbolGen()
+        else:
+            self.subexpressionSymbolNameGenerator = subexpressionSymbolNameGenerator
 
     def newWithAdditionalSubexpressions(self, newEquations, additionalSubExpressions):
         """
@@ -44,9 +47,11 @@ class EquationCollection:
         Simplifications hints are copied over.
         """
         assert len(self.mainEquations) == len(newEquations), "Number of update equations cannot be changed"
-        return EquationCollection(newEquations,
-                                  self.subexpressions + additionalSubExpressions,
-                                  self.simplificationHints)
+        res = EquationCollection(newEquations,
+                                 self.subexpressions + additionalSubExpressions,
+                                 self.simplificationHints)
+        res.subexpressionSymbolNameGenerator = self.subexpressionSymbolNameGenerator
+        return res
 
     def newWithSubstitutionsApplied(self, substitutionDict):
         """
@@ -55,7 +60,9 @@ class EquationCollection:
         """
         newSubexpressions = [fastSubs(eq, substitutionDict) for eq in self.subexpressions]
         newEquations = [fastSubs(eq, substitutionDict) for eq in self.mainEquations]
-        return EquationCollection(newEquations, newSubexpressions, self.simplificationHints)
+        res = EquationCollection(newEquations, newSubexpressions, self.simplificationHints)
+        res.subexpressionSymbolNameGenerator = self.subexpressionSymbolNameGenerator
+        return res
 
     def addSimplificationHint(self, key, value):
         """
@@ -190,7 +197,7 @@ class EquationCollection:
                     queue.append(ds)
                     handledSymbols.add(ds)
 
-        for eq in self.mainEquations:
+        for eq in self.allEquations:
             if eq.lhs in symbolsToExtract:
                 newEquations.append(eq)
                 addSymbolsFromExpr(eq.rhs)
@@ -202,7 +209,7 @@ class EquationCollection:
             else:
                 addSymbolsFromExpr(subexprMap[e])
 
-        newSubExpr = [eq for eq in self.subexpressions if eq.lhs in handledSymbols]
+        newSubExpr = [eq for eq in self.subexpressions if eq.lhs in handledSymbols and eq.lhs not in symbolsToExtract]
         return EquationCollection(newEquations, newSubExpr)
 
     def newWithoutUnusedSubexpressions(self):
diff --git a/equationcollection/simplifications.py b/equationcollection/simplifications.py
index 24a21fdf5..c8e93702b 100644
--- a/equationcollection/simplifications.py
+++ b/equationcollection/simplifications.py
@@ -10,7 +10,8 @@ def sympyCSE(equationCollection):
     with the additional subexpressions found
     """
     symbolGen = equationCollection.subexpressionSymbolNameGenerator
-    replacements, newEq = sp.cse(equationCollection.subexpressions + equationCollection.mainEquations, symbols=symbolGen)
+    replacements, newEq = sp.cse(equationCollection.subexpressions + equationCollection.mainEquations,
+                                 symbols=symbolGen)
     replacementEqs = [sp.Eq(*r) for r in replacements]
 
     modifiedSubexpressions = newEq[:len(equationCollection.subexpressions)]
@@ -20,7 +21,8 @@ def sympyCSE(equationCollection):
     topologicallySortedPairs = sp.cse_main.reps_toposort([[e.lhs, e.rhs] for e in newSubexpressions])
     newSubexpressions = [sp.Eq(a[0], a[1]) for a in topologicallySortedPairs]
 
-    return EquationCollection(modifiedUpdateEquations, newSubexpressions, equationCollection.simplificationHints)
+    return EquationCollection(modifiedUpdateEquations, newSubexpressions, equationCollection.simplificationHints,
+                              equationCollection.subexpressionSymbolNameGenerator)
 
 
 def applyOnAllEquations(equationCollection, operation):
@@ -32,7 +34,8 @@ def applyOnAllEquations(equationCollection, operation):
 def applyOnAllSubexpressions(equationCollection, operation):
     return EquationCollection(equationCollection.mainEquations,
                               [operation(s) for s in equationCollection.subexpressions],
-                              equationCollection.simplificationHints)
+                              equationCollection.simplificationHints,
+                              equationCollection.subexpressionSymbolNameGenerator)
 
 
 def subexpressionSubstitutionInExistingSubexpressions(equationCollection):
@@ -46,7 +49,8 @@ def subexpressionSubstitutionInExistingSubexpressions(equationCollection):
             newRhs = newRhs.subs(subExpr.rhs, subExpr.lhs)
         result.append(sp.Eq(s.lhs, newRhs))
 
-    return EquationCollection(equationCollection.mainEquations, result, equationCollection.simplificationHints)
+    return EquationCollection(equationCollection.mainEquations, result, equationCollection.simplificationHints,
+                              equationCollection.subexpressionSymbolNameGenerator)
 
 
 def subexpressionSubstitutionInUpdateEquations(equationCollection):
@@ -58,5 +62,3 @@ def subexpressionSubstitutionInUpdateEquations(equationCollection):
             newRhs = replaceAdditive(newRhs, subExpr.lhs, subExpr.rhs, requiredMatchReplacement=1.0)
         result.append(sp.Eq(s.lhs, newRhs))
     return equationCollection.newWithAdditionalSubexpressions(result, [])
-
-
diff --git a/sympyextensions.py b/sympyextensions.py
index 261438c64..e02b58be9 100644
--- a/sympyextensions.py
+++ b/sympyextensions.py
@@ -330,3 +330,4 @@ def matrixFromColumnVectors(columnVectors):
 def commonDenominator(expr):
     denominators = [r.q for r in expr.atoms(sp.Rational)]
     return sp.lcm(denominators)
+
-- 
GitLab