import sympy as sp from pystencils.equationcollection import EquationCollection from pystencils.sympyextensions import replaceAdditive def sympyCSE(equationCollection): """ Searches for common subexpressions inside the equation collection, in both the existing subexpressions as well as the equations themselves. It uses the sympy subexpression detection to do this. Return a new equation collection with the additional subexpressions found """ symbolGen = equationCollection.subexpressionSymbolNameGenerator replacements, newEq = sp.cse(equationCollection.subexpressions + equationCollection.mainEquations, symbols=symbolGen) replacementEqs = [sp.Eq(*r) for r in replacements] modifiedSubexpressions = newEq[:len(equationCollection.subexpressions)] modifiedUpdateEquations = newEq[len(equationCollection.subexpressions):] newSubexpressions = replacementEqs + modifiedSubexpressions topologicallySortedPairs = sp.cse_main.reps_toposort([[e.lhs, e.rhs] for e in newSubexpressions]) newSubexpressions = [sp.Eq(a[0], a[1]) for a in topologicallySortedPairs] return EquationCollection(modifiedUpdateEquations, newSubexpressions, equationCollection.simplificationHints) def applyOnAllEquations(equationCollection, operation): """Applies sympy expand operation to all equations in collection""" result = [operation(s) for s in equationCollection.mainEquations] return equationCollection.createNewWithAdditionalSubexpressions(result, []) def applyOnAllSubexpressions(equationCollection, operation): return EquationCollection(equationCollection.mainEquations, [operation(s) for s in equationCollection.subexpressions], equationCollection.simplificationHints) def subexpressionSubstitutionInExistingSubexpressions(equationCollection): """Goes through the subexpressions list and replaces the term in the following subexpressions""" result = [] for outerCtr, s in enumerate(equationCollection.subexpressions): newRhs = s.rhs for innerCtr in range(outerCtr): subExpr = equationCollection.subexpressions[innerCtr] newRhs = replaceAdditive(newRhs, subExpr.lhs, subExpr.rhs, requiredMatchReplacement=1.0) newRhs = newRhs.subs(subExpr.rhs, subExpr.lhs) result.append(sp.Eq(s.lhs, newRhs)) return EquationCollection(equationCollection.mainEquations, result, equationCollection.simplificationHints) def subexpressionSubstitutionInUpdateEquations(equationCollection): """Replaces already existing subexpressions in the equations of the equationCollection""" result = [] for s in equationCollection.mainEquations: newRhs = s.rhs for subExpr in equationCollection.subexpressions: newRhs = replaceAdditive(newRhs, subExpr.lhs, subExpr.rhs, requiredMatchReplacement=1.0) result.append(sp.Eq(s.lhs, newRhs)) return equationCollection.createNewWithAdditionalSubexpressions(result, [])