From 350bb4d52e8d33a0bd8a0e4ab3a46316590d7cd0 Mon Sep 17 00:00:00 2001
From: Martin Bauer <martin.bauer@fau.de>
Date: Wed, 24 May 2017 09:56:07 +0200
Subject: [PATCH] Automatic Chapman Enskog Analysis of moment-based methods

---
 equationcollection/equationcollection.py | 12 ++++--
 sympyextensions.py                       | 52 ++++++++++++++++++++++++
 types.py                                 |  3 +-
 3 files changed, 62 insertions(+), 5 deletions(-)

diff --git a/equationcollection/equationcollection.py b/equationcollection/equationcollection.py
index 86fb5700c..9c529309d 100644
--- a/equationcollection/equationcollection.py
+++ b/equationcollection/equationcollection.py
@@ -57,13 +57,19 @@ class EquationCollection(object):
             res.subexpressions = subexpressions
         return res
 
-    def copyWithSubstitutionsApplied(self, substitutionDict, addSubstitutionsAsSubexpressions=False):
+    def copyWithSubstitutionsApplied(self, substitutionDict, addSubstitutionsAsSubexpressions=False,
+                                     substituteOnLhs=True):
         """
         Returns a new equation collection, where terms are substituted according to the passed `substitutionDict`.
         Substitutions are made in the subexpression terms and the main equations
         """
-        newSubexpressions = [fastSubs(eq, substitutionDict) for eq in self.subexpressions]
-        newEquations = [fastSubs(eq, substitutionDict) for eq in self.mainEquations]
+        if substituteOnLhs:
+            newSubexpressions = [fastSubs(eq, substitutionDict) for eq in self.subexpressions]
+            newEquations = [fastSubs(eq, substitutionDict) for eq in self.mainEquations]
+        else:
+            newSubexpressions = [sp.Eq(eq.lhs, fastSubs(eq.rhs, substitutionDict)) for eq in self.subexpressions]
+            newEquations = [sp.Eq(eq.lhs, fastSubs(eq.rhs, substitutionDict)) for eq in self.mainEquations]
+
         if addSubstitutionsAsSubexpressions:
             newSubexpressions = [sp.Eq(b, a) for a, b in substitutionDict.items()] + newSubexpressions
             newSubexpressions = sortEquationsTopologically(newSubexpressions)
diff --git a/sympyextensions.py b/sympyextensions.py
index 97169a4ce..e567d9cd8 100644
--- a/sympyextensions.py
+++ b/sympyextensions.py
@@ -1,9 +1,61 @@
 import operator
+from functools import reduce
 from collections import defaultdict, Sequence
+import itertools
 import warnings
 import sympy as sp
 
 
+def prod(seq):
+    """Takes a sequence and returns the product of all elements"""
+    return reduce(operator.mul, seq, 1)
+
+
+def allIn(a, b):
+    """Tests if all elements of a container 'a' are contained in 'b'"""
+    return all(element in b for element in a)
+
+
+def normalizeProduct(product):
+    """
+    Expects a sympy expression that can be interpreted as a product and
+    - for a Mul node returns its factors ('args')
+    - for a Pow node with positive integer exponent returns a list of factors
+    - for other node types [product] is returned
+    """
+    def handlePow(power):
+        if power.exp.is_integer and power.exp.is_number and power.exp > 0:
+            return [power.base] * power.exp
+        else:
+            return [power]
+
+    if product.func == sp.Pow:
+        return handlePow(product)
+    elif product.func == sp.Mul:
+        result = []
+        for a in product.args:
+            if a.func == sp.Pow:
+                result += handlePow(a)
+            else:
+                result.append(a)
+        return result
+    else:
+        return [product]
+
+
+def productSymmetric(*args, withDiagonal=True):
+    """Similar to itertools.product but returns only values where the index is ascending i.e. values below diagonal"""
+    ranges = [range(len(a)) for a in args]
+    for idx in itertools.product(*ranges):
+        validIndex = True
+        for t in range(1, len(idx)):
+            if (withDiagonal and idx[t - 1] > idx[t]) or (not withDiagonal and idx[t - 1] >= idx[t]):
+                validIndex = False
+                break
+        if validIndex:
+            yield tuple(a[i] for a, i in zip(args, idx))
+
+
 def fastSubs(term, subsDict):
     """Similar to sympy subs function.
     This version is much faster for big substitution dictionaries than sympy version"""
diff --git a/types.py b/types.py
index d4a6b6a53..43373f0cb 100644
--- a/types.py
+++ b/types.py
@@ -24,8 +24,7 @@ class TypedSymbol(sp.Symbol):
 
     def _hashable_content(self):
         superClassContents = list(super(TypedSymbol, self)._hashable_content())
-        t = tuple(superClassContents + [hash(repr(self._dtype))])
-        return t
+        return tuple(superClassContents + [hash(repr(self._dtype))])
 
     def __getnewargs__(self):
         return self.name, self.dtype
-- 
GitLab