From adf98ef6833d23251ade77a2f6ba691ebb0a6fd6 Mon Sep 17 00:00:00 2001 From: Martin Bauer <martin.bauer@fau.de> Date: Tue, 20 Dec 2016 12:24:31 +0100 Subject: [PATCH] equation collection into pystencils --- equationcollection/__init__.py | 1 + equationcollection/equationcollection.py | 206 +++++++++++++++ equationcollection/simplifications.py | 62 +++++ sympyextensions.py | 319 +++++++++++++++++++++++ 4 files changed, 588 insertions(+) create mode 100644 equationcollection/__init__.py create mode 100644 equationcollection/equationcollection.py create mode 100644 equationcollection/simplifications.py create mode 100644 sympyextensions.py diff --git a/equationcollection/__init__.py b/equationcollection/__init__.py new file mode 100644 index 000000000..295aa35bb --- /dev/null +++ b/equationcollection/__init__.py @@ -0,0 +1 @@ +from pystencils.equationcollection.equationcollection import EquationCollection diff --git a/equationcollection/equationcollection.py b/equationcollection/equationcollection.py new file mode 100644 index 000000000..fc29cbf98 --- /dev/null +++ b/equationcollection/equationcollection.py @@ -0,0 +1,206 @@ +import sympy as sp +from pystencils.transformations import fastSubs + + +class EquationCollection: + """ + A collection of equations with subexpression definitions, also represented as equations, + that are used in the main equations. EquationCollections can be passed to simplification methods. + These simplification methods can change the subexpressions, but the number and + left hand side of the main equations themselves is not altered. + Additionally a dictionary of simplification hints is stored, which are set by the functions that create + equation collections to transport information to the simplification system. + + :ivar mainEquations: list of sympy equations + :ivar subexpressions: list of sympy equations defining subexpressions used in main equations + :ivar simplificationHints: dictionary that is used to annotate the equation collection with hints that are + used by the simplification system. See documentation of the simplification rules for + potentially required hints and their meaning. + """ + + # ----------------------------------------- Creation --------------------------------------------------------------- + + def __init__(self, equations, subExpressions, simplificationHints={}): + self.mainEquations = equations + self.subexpressions = subExpressions + self.simplificationHints = simplificationHints + + def symbolGen(): + """Use this generator to create new unused symbols for subexpressions""" + counter = 0 + while True: + counter += 1 + newSymbol = sp.Symbol("xi_" + str(counter)) + if newSymbol in self.boundSymbols: + continue + yield newSymbol + + self.subexpressionSymbolNameGenerator = symbolGen() + + def createNewWithAdditionalSubexpressions(self, newEquations, additionalSubExpressions): + assert len(self.mainEquations) == len(newEquations), "Number of update equations cannot be changed" + return EquationCollection(newEquations, + self.subexpressions + additionalSubExpressions, + self.simplificationHints) + + def createNewWithSubstitutionsApplied(self, substitutionDict): + newSubexpressions = [fastSubs(eq, substitutionDict) for eq in self.subexpressions] + newEquations = [fastSubs(eq, substitutionDict) for eq in self.mainEquations] + return EquationCollection(newEquations, newSubexpressions, self.simplificationHints) + + def addSimplificationHint(self, key, value): + assert key not in self.simplificationHints, "This hint already exists" + self.simplificationHints[key] = value + + # ---------------------------------------------- Properties ------------------------------------------------------- + + @property + def allEquations(self): + return self.subexpressions + self.mainEquations + + @property + def freeSymbols(self): + """All symbols used in the equation collection, which have not been defined inside the equation system""" + freeSymbols = set() + for eq in self.allEquations: + freeSymbols.update(eq.rhs.atoms(sp.Symbol)) + return freeSymbols - self.boundSymbols + + @property + def boundSymbols(self): + """Set of all symbols which occur on left-hand-sides i.e. all symbols which are defined.""" + boundSymbolsSet = set([eq.lhs for eq in self.allEquations]) + assert len(boundSymbolsSet) == len(self.subexpressions) + len(self.mainEquations), \ + "Not in SSA form - same symbol assigned multiple times" + return boundSymbolsSet + + @property + def definedSymbols(self): + """All symbols that occur as left-hand-sides of the main equations""" + return set([eq.lhs for eq in self.mainEquations]) + + # ----------------------------------------- Display and Printing ------------------------------------------------- + + def _repr_html_(self): + def makeHtmlEquationTable(equations): + noBorder = 'style="border:none"' + htmlTable = '<table style="border:none; width: 100%; ">' + line = '<tr {nb}> <td {nb}>${lhs}$</td> <td {nb}>$=$</td> ' \ + '<td style="border:none; width: 100%;">${rhs}$</td> </tr>' + for eq in equations: + formatDict = {'lhs': sp.latex(eq.lhs), + 'rhs': sp.latex(eq.rhs), + 'nb': noBorder, } + htmlTable += line.format(**formatDict) + htmlTable += "</table>" + return htmlTable + + result = "" + if len(self.subexpressions) > 0: + result += "<div>Subexpressions:<div>" + result += makeHtmlEquationTable(self.subexpressions) + result += "<div>Main Equations:<div>" + result += makeHtmlEquationTable(self.mainEquations) + return result + + def __repr__(self): + return "Equation Collection for " + ",".join([str(eq.lhs) for eq in self.mainEquations]) + + # ------------------------------------- Manipulation ------------------------------------------------------------ + + def merge(self, other): + """Returns a new collection which contains self and other. Subexpressions are renamed if they clash.""" + ownDefs = set([e.lhs for e in self.mainEquations]) + otherDefs = set([e.lhs for e in other.mainEquations]) + assert len(ownDefs.intersection(otherDefs)) == 0, "Cannot merge, since both collection define the same symbols" + + ownSubexpressionSymbols = {e.lhs: e.rhs for e in self.subexpressions} + substitutionDict = {} + + processedOtherSubexpressionEquations = [] + for otherSubexpressionEq in other.subexpressions: + if otherSubexpressionEq.lhs in ownSubexpressionSymbols: + if otherSubexpressionEq.rhs == ownSubexpressionSymbols[otherSubexpressionEq.lhs]: + continue # exact the same subexpression equation exists already + else: + # different definition - a new name has to be introduced + newLhs = self.subexpressionSymbolNameGenerator() + newEq = sp.Eq(newLhs, fastSubs(otherSubexpressionEq.rhs, substitutionDict)) + processedOtherSubexpressionEquations.append(newEq) + substitutionDict[otherSubexpressionEq.lhs] = newLhs + else: + processedOtherSubexpressionEquations.append(fastSubs(otherSubexpressionEq, substitutionDict)) + return EquationCollection(self.mainEquations + other.mainEquations, + self.subexpressions + processedOtherSubexpressionEquations) + + def extract(self, symbolsToExtract): + """ + Creates a new equation collection with equations that have symbolsToExtract as left-hand-sides and + only the necessary subexpressions that are used in these equations + """ + symbolsToExtract = set(symbolsToExtract) + newEquations = [] + + subexprMap = {e.lhs: e.rhs for e in self.subexpressions} + handledSymbols = set() + queue = [] + + def addSymbolsFromExpr(expr): + dependentSymbols = expr.atoms(sp.Symbol) + for ds in dependentSymbols: + if ds not in handledSymbols: + queue.append(ds) + handledSymbols.add(ds) + + for eq in self.mainEquations: + if eq.lhs in symbolsToExtract: + newEquations.append(eq) + addSymbolsFromExpr(eq.rhs) + + while len(queue) > 0: + e = queue.pop(0) + if e not in subexprMap: + continue + else: + addSymbolsFromExpr(subexprMap[e]) + + newSubExpr = [eq for eq in self.subexpressions if eq.lhs in handledSymbols] + return EquationCollection(newEquations, newSubExpr) + + def newWithoutUnusedSubexpressions(self): + """Returns a new equation collection containing only the subexpressions that + are used/referenced in the equations""" + allLhs = [eq.lhs for eq in self.mainEquations] + return self.extract(allLhs) + + def insertSubexpressions(self): + """Returns a new equation collection by inserting all subexpressions into the main equations""" + if len(self.subexpressions) == 0: + return EquationCollection(self.mainEquations, self.subexpressions, self.simplificationHints) + subsDict = {self.subexpressions[0].lhs: self.subexpressions[0].rhs} + subExpr = [e for e in self.subexpressions] + for i in range(1, len(subExpr)): + subExpr[i] = fastSubs(subExpr[i], subsDict) + subsDict[subExpr[i].lhs] = subExpr[i].rhs + + newEq = [fastSubs(eq, subsDict) for eq in self.mainEquations] + return EquationCollection(newEq, [], self.simplificationHints) + + def lambdify(self, symbols, module=None, fixedSymbols={}): + """ + Returns a function to evaluate this equation collection + :param symbols: symbol(s) which are the parameter for the created function + :param module: same as sympy.lambdify paramter of same same, i.e. which module to use e.g. 'numpy' + :param fixedSymbols: dictionary with substitutions, that are applied before lambdification + """ + eqs = self.createNewWithSubstitutionsApplied(fixedSymbols).insertSubexpressions().mainEquations + print('abc') + for eq in eqs: + print(eq) + sp.lambdify(eq.rhs, symbols, module) + lambdas = {eq.lhs: sp.lambdify(eq.rhs, symbols, module) for eq in eqs} + + def f(*args, **kwargs): + return {s: f(*args, **kwargs) for s, f in lambdas.items()} + + return f diff --git a/equationcollection/simplifications.py b/equationcollection/simplifications.py new file mode 100644 index 000000000..cd5912994 --- /dev/null +++ b/equationcollection/simplifications.py @@ -0,0 +1,62 @@ +import sympy as sp +from pystencils.equationcollection import EquationCollection +from pystencils.sympyextensions import replaceAdditive + + +def sympyCSE(equationCollection): + """ + Searches for common subexpressions inside the equation collection, in both the existing subexpressions as well + as the equations themselves. It uses the sympy subexpression detection to do this. Return a new equation collection + with the additional subexpressions found + """ + symbolGen = equationCollection.subexpressionSymbolNameGenerator + replacements, newEq = sp.cse(equationCollection.subexpressions + equationCollection.mainEquations, symbols=symbolGen) + replacementEqs = [sp.Eq(*r) for r in replacements] + + modifiedSubexpressions = newEq[:len(equationCollection.subexpressions)] + modifiedUpdateEquations = newEq[len(equationCollection.subexpressions):] + + newSubexpressions = replacementEqs + modifiedSubexpressions + topologicallySortedPairs = sp.cse_main.reps_toposort([[e.lhs, e.rhs] for e in newSubexpressions]) + newSubexpressions = [sp.Eq(a[0], a[1]) for a in topologicallySortedPairs] + + return EquationCollection(modifiedUpdateEquations, newSubexpressions, equationCollection.simplificationHints) + + +def applyOnAllEquations(equationCollection, operation): + """Applies sympy expand operation to all equations in collection""" + result = [operation(s) for s in equationCollection.mainEquations] + return equationCollection.createNewWithAdditionalSubexpressions(result, []) + + +def applyOnAllSubexpressions(equationCollection, operation): + return EquationCollection(equationCollection.mainEquations, + [operation(s) for s in equationCollection.subexpressions], + equationCollection.simplificationHints) + + +def subexpressionSubstitutionInExistingSubexpressions(equationCollection): + """Goes through the subexpressions list and replaces the term in the following subexpressions""" + result = [] + for outerCtr, s in enumerate(equationCollection.subexpressions): + newRhs = s.rhs + for innerCtr in range(outerCtr): + subExpr = equationCollection.subexpressions[innerCtr] + newRhs = replaceAdditive(newRhs, subExpr.lhs, subExpr.rhs, requiredMatchReplacement=1.0) + newRhs = newRhs.subs(subExpr.rhs, subExpr.lhs) + result.append(sp.Eq(s.lhs, newRhs)) + + return EquationCollection(equationCollection.mainEquations, result, equationCollection.simplificationHints) + + +def subexpressionSubstitutionInUpdateEquations(equationCollection): + """Replaces already existing subexpressions in the equations of the equationCollection""" + result = [] + for s in equationCollection.mainEquations: + newRhs = s.rhs + for subExpr in equationCollection.subexpressions: + newRhs = replaceAdditive(newRhs, subExpr.lhs, subExpr.rhs, requiredMatchReplacement=1.0) + result.append(sp.Eq(s.lhs, newRhs)) + return equationCollection.createNewWithAdditionalSubexpressions(result, []) + + diff --git a/sympyextensions.py b/sympyextensions.py new file mode 100644 index 000000000..e89cf1374 --- /dev/null +++ b/sympyextensions.py @@ -0,0 +1,319 @@ +import sympy as sp +import operator +from collections import defaultdict, Sequence +import warnings + + +def fastSubs(term, subsDict): + """Similar to sympy subs function. + This version is much faster for big substitution dictionaries than sympy version""" + def visit(expr): + if expr in subsDict: + return subsDict[expr] + if not hasattr(expr, 'args'): + return expr + paramList = [visit(a) for a in expr.args] + return expr if not paramList else expr.func(*paramList) + return visit(term) + + +def replaceAdditive(expr, replacement, subExpression, requiredMatchReplacement=0.5, requiredMatchOriginal=None): + """ + Transformation for replacing a given subexpression inside a sum + + Example 1: + expr = 3*x + 3 * y + replacement = k + subExpression = x+y + return = 3*k + + Example 2: + expr = 3*x + 3 * y + z + replacement = k + subExpression = x+y+z + return: + if minimalMatchingTerms >=3 the expression would not be altered + if smaller than 3 the result is 3*k - 2*z + + :param expr: input expression + :param replacement: expression that is inserted for subExpression (if found) + :param subExpression: expression to replace + :param requiredMatchReplacement: + - if float: the percentage of terms of the subExpression that has to be matched in order to replace + - if integer: the total number of terms that has to be matched in order to replace + - None: is equal to integer 1 + - if both match parameters are given, both restrictions have to be fulfilled (i.e. logical AND) + :param requiredMatchOriginal: + - if float: the percentage of terms of the original addition expression that has to be matched + - if integer: the total number of terms that has to be matched in order to replace + - None: is equal to integer 1 + :return: new expression with replacement + """ + def normalizeMatchParameter(matchParameter, expressingLength): + if matchParameter is None: + return 1 + elif isinstance(matchParameter, float): + assert 0 <= matchParameter <= 1 + res = int(matchParameter * expressingLength) + return max(res, 1) + elif isinstance(matchParameter, int): + assert matchParameter > 0 + return matchParameter + raise ValueError("Invalid parameter") + + normalizedReplacementMatch = normalizeMatchParameter(requiredMatchReplacement, len(subExpression.args)) + + def visit(currentExpr): + if currentExpr.is_Add: + exprMaxLength = max(len(currentExpr.args), len(subExpression.args)) + normalizedCurrentExprMatch = normalizeMatchParameter(requiredMatchOriginal, exprMaxLength) + exprCoeffs = currentExpr.as_coefficients_dict() + subexprCoeffDict = subExpression.as_coefficients_dict() + intersection = set(subexprCoeffDict.keys()).intersection(set(exprCoeffs)) + if len(intersection) >= max(normalizedReplacementMatch, normalizedCurrentExprMatch): + # find common factor + factors = defaultdict(lambda: 0) + skips = 0 + for commonSymbol in subexprCoeffDict.keys(): + if commonSymbol not in exprCoeffs: + skips += 1 + continue + factor = exprCoeffs[commonSymbol] / subexprCoeffDict[commonSymbol] + factors[sp.simplify(factor)] += 1 + + commonFactor = max(factors.items(), key=operator.itemgetter(1))[0] + if factors[commonFactor] >= max(normalizedCurrentExprMatch, normalizedReplacementMatch): + return currentExpr - commonFactor * subExpression + commonFactor * replacement + + # if no subexpression was found + paramList = [visit(a) for a in currentExpr.args] + if not paramList: + return currentExpr + else: + return currentExpr.func(*paramList, evaluate=False) + + return visit(expr) + + +def replaceSecondOrderProducts(expr, searchSymbols, positive=None, replaceMixed=None): + """ + Replaces second order mixed terms like x*y by 2* ( (x+y)**2 - x**2 - y**2 ) + This makes the term longer - simplify usually is undoing these - however this + transformation can be done to find more common sub-expressions + :param expr: input expression + :param searchSymbols: list of symbols that are searched for + Example: given [ x,y,z] terms like x*y, x*z, z*y are replaced + :param positive: there are two ways to do this substitution, either with term + (x+y)**2 or (x-y)**2 . if positive=True the first version is done, + if positive=False the second version is done, if positive=None the + sign is determined by the sign of the mixed term that is replaced + :param replaceMixed: if a list is passed here the expr x+y or x-y is replaced by a special new symbol + the replacement equation is added to the list + :return: + """ + if replaceMixed is not None: + mixedSymbolsReplaced = set([e.lhs for e in replaceMixed]) + + if expr.is_Mul: + distinctVelTerms = set() + nrOfVelTerms = 0 + otherFactors = 1 + for t in expr.args: + if t in searchSymbols: + nrOfVelTerms += 1 + distinctVelTerms.add(t) + else: + otherFactors *= t + if len(distinctVelTerms) == 2 and nrOfVelTerms == 2: + u, v = list(distinctVelTerms) + if positive is None: + otherFactorsWithoutSymbols = otherFactors + for s in otherFactors.atoms(sp.Symbol): + otherFactorsWithoutSymbols = otherFactorsWithoutSymbols.subs(s, 1) + positive = otherFactorsWithoutSymbols.is_positive + assert positive is not None + sign = 1 if positive else -1 + if replaceMixed is not None: + newSymbolStr = 'P' if positive else 'M' + mixedSymbolName = u.name + newSymbolStr + v.name + mixedSymbol = sp.Symbol(mixedSymbolName.replace("_", "")) + if mixedSymbol not in mixedSymbolsReplaced: + mixedSymbolsReplaced.add(mixedSymbol) + replaceMixed.append(sp.Eq(mixedSymbol, u + sign * v)) + else: + mixedSymbol = u + sign * v + return sp.Rational(1, 2) * sign * otherFactors * (mixedSymbol ** 2 - u ** 2 - v ** 2) + + paramList = [replaceSecondOrderProducts(a, searchSymbols, positive, replaceMixed) for a in expr.args] + result = expr.func(*paramList, evaluate=False) if paramList else expr + return result + + +def removeHigherOrderTerms(term, order=3, symbols=None): + """ + Remove all terms from a sum that contain 'order' or more factors of given 'symbols' + Example: symbols = ['u_x', 'u_y'] and order =2 + removes terms u_x**2, u_x*u_y, u_y**2, u_x**3, .... + """ + from sympy.core.power import Pow + from sympy.core.add import Add, Mul + + result = 0 + term = term.expand() + + if not symbols: + symbols = sp.symbols(" ".join(["u_%d" % (i,) for i in range(3)])) + symbols += sp.symbols(" ".join(["u_%d" % (i,) for i in range(3)]), real=True) + + def velocityFactorsInProduct(product): + uFactorCount = 0 + for factor in product.args: + if type(factor) == Pow: + if factor.args[0] in symbols: + uFactorCount += factor.args[1] + if factor in symbols: + uFactorCount += 1 + return uFactorCount + + if type(term) == Mul: + if velocityFactorsInProduct(term) <= order: + return term + else: + return sp.Rational(0, 1) + + if type(term) != Add: + return term + + for sumTerm in term.args: + if velocityFactorsInProduct(sumTerm) <= order: + result += sumTerm + return result + + +def completeTheSquare(expr, symbolToComplete, newVariable): + """ + Transforms second order polynomial into only squared part i.e. + a*symbolToComplete**2 + b*symbolToComplete + c + is transformed into + newVariable**2 + d + + returns replacedExpr, "a tuple to to replace newVariable such that old expr comes out again" + + if given expr is not a second order polynomial: + return expr, None + """ + p = sp.Poly(expr, symbolToComplete) + coeffs = p.all_coeffs() + if len(coeffs) != 3: + return expr, None + a, b, _ = coeffs + expr = expr.subs(symbolToComplete, newVariable - b / (2 * a)) + return sp.simplify(expr), (newVariable, symbolToComplete + b / (2 * a)) + + +def makeExponentialFuncArgumentSquares(expr, variablesToCompleteSquares): + """Completes squares in arguments of exponential which makes them simpler to integrate + Very useful for integrating Maxwell-Boltzmann and its moment generating function""" + expr = sp.simplify(expr) + dim = len(variablesToCompleteSquares) + dummies = [sp.Dummy() for i in range(dim)] + + def visit(term): + if term.func == sp.exp: + expArg = term.args[0] + for i in range(dim): + expArg, substitution = completeTheSquare(expArg, variablesToCompleteSquares[i], dummies[i]) + return sp.exp(sp.simplify(expArg)) + else: + paramList = [visit(a) for a in term.args] + if not paramList: + return term + else: + return term.func(*paramList) + + result = visit(expr) + for i in range(dim): + result = result.subs(dummies[i], variablesToCompleteSquares[i]) + return result + + +def pow2mul(expr): + """ + Convert integer powers in an expression to Muls, like a**2 => a*a. + """ + pows = list(expr.atoms(sp.Pow)) + if any(not e.is_Integer for b, e in (i.as_base_exp() for i in pows)): + raise ValueError("A power contains a non-integer exponent") + repl = zip(pows, (sp.Mul(*[b]*e, evaluate=False) for b, e in (i.as_base_exp() for i in pows))) + return expr.subs(repl) + + +def extractMostCommonFactor(term): + """Processes a sum of fractions: determines the most common factor and splits term in common factor and rest""" + import operator + from collections import Counter + from sympy.functions import Abs + + coeffDict = term.as_coefficients_dict() + counter = Counter([Abs(v) for v in coeffDict.values()]) + commonFactor, occurances = max(counter.items(), key=operator.itemgetter(1)) + if occurances == 1 and (1 in counter): + commonFactor = 1 + return commonFactor, term / commonFactor + + +def countNumberOfOperations(term): + """ + Counts the number of additions, multiplications and division + :param term: a sympy term, equation or sequence of terms/equations + :return: a dictionary with 'adds', 'muls' and 'divs' keys + """ + result = {'adds': 0, 'muls': 0, 'divs': 0} + + if isinstance(term, Sequence): + for element in term: + r = countNumberOfOperations(element) + for operationName in result.keys(): + result[operationName] += r[operationName] + return result + elif isinstance(term, sp.Eq): + term = term.rhs + + term = term.evalf() + + def visit(t): + visitChildren = True + if t.func is sp.Add: + result['adds'] += len(t.args) - 1 + elif t.func is sp.Mul: + result['muls'] += len(t.args) - 1 + for a in t.args: + if a == 1 or a == -1: + result['muls'] -= 1 + elif t.func is sp.Float: + pass + elif isinstance(t, sp.Symbol): + pass + elif t.is_integer: + pass + elif t.func is sp.Pow: + visitChildren = False + if t.exp.is_integer and t.exp.is_number: + if t.exp >= 0: + result['muls'] += int(t.exp) - 1 + else: + result['muls'] -= 1 + result['divs'] += 1 + result['muls'] += (-int(t.exp)) - 1 + else: + warnings.warn("Counting operations: only integer exponents are supported in Pow, " + "counting will be inaccurate") + else: + warnings.warn("Unknown sympy node of type " + str(t.func) + " counting will be inaccurate") + + if visitChildren: + for a in t.args: + visit(a) + + visit(term) + return result -- GitLab