From 350bb4d52e8d33a0bd8a0e4ab3a46316590d7cd0 Mon Sep 17 00:00:00 2001 From: Martin Bauer <martin.bauer@fau.de> Date: Wed, 24 May 2017 09:56:07 +0200 Subject: [PATCH] Automatic Chapman Enskog Analysis of moment-based methods --- equationcollection/equationcollection.py | 12 ++++-- sympyextensions.py | 52 ++++++++++++++++++++++++ types.py | 3 +- 3 files changed, 62 insertions(+), 5 deletions(-) diff --git a/equationcollection/equationcollection.py b/equationcollection/equationcollection.py index 86fb5700c..9c529309d 100644 --- a/equationcollection/equationcollection.py +++ b/equationcollection/equationcollection.py @@ -57,13 +57,19 @@ class EquationCollection(object): res.subexpressions = subexpressions return res - def copyWithSubstitutionsApplied(self, substitutionDict, addSubstitutionsAsSubexpressions=False): + def copyWithSubstitutionsApplied(self, substitutionDict, addSubstitutionsAsSubexpressions=False, + substituteOnLhs=True): """ Returns a new equation collection, where terms are substituted according to the passed `substitutionDict`. Substitutions are made in the subexpression terms and the main equations """ - newSubexpressions = [fastSubs(eq, substitutionDict) for eq in self.subexpressions] - newEquations = [fastSubs(eq, substitutionDict) for eq in self.mainEquations] + if substituteOnLhs: + newSubexpressions = [fastSubs(eq, substitutionDict) for eq in self.subexpressions] + newEquations = [fastSubs(eq, substitutionDict) for eq in self.mainEquations] + else: + newSubexpressions = [sp.Eq(eq.lhs, fastSubs(eq.rhs, substitutionDict)) for eq in self.subexpressions] + newEquations = [sp.Eq(eq.lhs, fastSubs(eq.rhs, substitutionDict)) for eq in self.mainEquations] + if addSubstitutionsAsSubexpressions: newSubexpressions = [sp.Eq(b, a) for a, b in substitutionDict.items()] + newSubexpressions newSubexpressions = sortEquationsTopologically(newSubexpressions) diff --git a/sympyextensions.py b/sympyextensions.py index 97169a4ce..e567d9cd8 100644 --- a/sympyextensions.py +++ b/sympyextensions.py @@ -1,9 +1,61 @@ import operator +from functools import reduce from collections import defaultdict, Sequence +import itertools import warnings import sympy as sp +def prod(seq): + """Takes a sequence and returns the product of all elements""" + return reduce(operator.mul, seq, 1) + + +def allIn(a, b): + """Tests if all elements of a container 'a' are contained in 'b'""" + return all(element in b for element in a) + + +def normalizeProduct(product): + """ + Expects a sympy expression that can be interpreted as a product and + - for a Mul node returns its factors ('args') + - for a Pow node with positive integer exponent returns a list of factors + - for other node types [product] is returned + """ + def handlePow(power): + if power.exp.is_integer and power.exp.is_number and power.exp > 0: + return [power.base] * power.exp + else: + return [power] + + if product.func == sp.Pow: + return handlePow(product) + elif product.func == sp.Mul: + result = [] + for a in product.args: + if a.func == sp.Pow: + result += handlePow(a) + else: + result.append(a) + return result + else: + return [product] + + +def productSymmetric(*args, withDiagonal=True): + """Similar to itertools.product but returns only values where the index is ascending i.e. values below diagonal""" + ranges = [range(len(a)) for a in args] + for idx in itertools.product(*ranges): + validIndex = True + for t in range(1, len(idx)): + if (withDiagonal and idx[t - 1] > idx[t]) or (not withDiagonal and idx[t - 1] >= idx[t]): + validIndex = False + break + if validIndex: + yield tuple(a[i] for a, i in zip(args, idx)) + + def fastSubs(term, subsDict): """Similar to sympy subs function. This version is much faster for big substitution dictionaries than sympy version""" diff --git a/types.py b/types.py index d4a6b6a53..43373f0cb 100644 --- a/types.py +++ b/types.py @@ -24,8 +24,7 @@ class TypedSymbol(sp.Symbol): def _hashable_content(self): superClassContents = list(super(TypedSymbol, self)._hashable_content()) - t = tuple(superClassContents + [hash(repr(self._dtype))]) - return t + return tuple(superClassContents + [hash(repr(self._dtype))]) def __getnewargs__(self): return self.name, self.dtype -- GitLab