From bed12f7527a1813beccc06a897107f693ed06325 Mon Sep 17 00:00:00 2001 From: Martin Bauer <martin.bauer@fau.de> Date: Fri, 20 Jan 2017 13:52:56 +0100 Subject: [PATCH] pystencils: generalized equationcollection --- equationcollection/equationcollection.py | 127 +++++++++++++---------- equationcollection/simplifications.py | 19 ++-- sympyextensions.py | 6 +- 3 files changed, 83 insertions(+), 69 deletions(-) diff --git a/equationcollection/equationcollection.py b/equationcollection/equationcollection.py index 8c6c6f32f..2ea0c2b6f 100644 --- a/equationcollection/equationcollection.py +++ b/equationcollection/equationcollection.py @@ -1,4 +1,5 @@ import sympy as sp +from copy import copy, deepcopy from pystencils.sympyextensions import fastSubs, countNumberOfOperations @@ -20,52 +21,50 @@ class EquationCollection: # ----------------------------------------- Creation --------------------------------------------------------------- - def __init__(self, equations, subExpressions, simplificationHints={}, subexpressionSymbolNameGenerator=None): + def __init__(self, equations, subExpressions, simplificationHints=None, subexpressionSymbolNameGenerator=None): self.mainEquations = equations self.subexpressions = subExpressions + + if simplificationHints is None: + simplificationHints = {} + self.simplificationHints = simplificationHints - def symbolGen(): - """Use this generator to create new unused symbols for subexpressions""" - counter = 0 - while True: - counter += 1 - newSymbol = sp.Symbol("xi_" + str(counter)) - if newSymbol in self.boundSymbols: - continue - yield newSymbol + class SymbolGen: + def __init__(self): + self._ctr = 0 + + def __iter__(self): + return self + + def __next__(self): + self._ctr += 1 + return sp.Symbol("xi_" + str(self._ctr)) if subexpressionSymbolNameGenerator is None: - self.subexpressionSymbolNameGenerator = symbolGen() + self.subexpressionSymbolNameGenerator = SymbolGen() else: self.subexpressionSymbolNameGenerator = subexpressionSymbolNameGenerator - def newWithAdditionalSubexpressions(self, newEquations, additionalSubExpressions): - """ - Returns a new equation collection, that has `newEquations` as mainEquations. - The `additionalSubExpressions` are appended to the existing subexpressions. - Simplifications hints are copied over. - """ - assert len(self.mainEquations) == len(newEquations), "Number of update equations cannot be changed" - res = EquationCollection(newEquations, - self.subexpressions + additionalSubExpressions, - self.simplificationHints) - res.subexpressionSymbolNameGenerator = self.subexpressionSymbolNameGenerator + def copy(self, mainEquations=None, subexpressions=None): + res = deepcopy(self) + if mainEquations is not None: + res.mainEquations = mainEquations + if subexpressions is not None: + res.subexpressions = subexpressions return res - def newWithSubstitutionsApplied(self, substitutionDict, addSubstitutionsAsSubexpresions=False): + def copyWithSubstitutionsApplied(self, substitutionDict, addSubstitutionsAsSubexpressions=False): """ Returns a new equation collection, where terms are substituted according to the passed `substitutionDict`. Substitutions are made in the subexpression terms and the main equations """ newSubexpressions = [fastSubs(eq, substitutionDict) for eq in self.subexpressions] newEquations = [fastSubs(eq, substitutionDict) for eq in self.mainEquations] - if addSubstitutionsAsSubexpresions: + if addSubstitutionsAsSubexpressions: newSubexpressions = [sp.Eq(b, a) for a, b in substitutionDict.items()] + newSubexpressions - res = EquationCollection(newEquations, newSubexpressions, self.simplificationHints) - res.subexpressionSymbolNameGenerator = self.subexpressionSymbolNameGenerator - return res + return self.copy(newEquations, newSubexpressions) def addSimplificationHint(self, key, value): """ @@ -178,41 +177,45 @@ class EquationCollection: substitutionDict[otherSubexpressionEq.lhs] = newLhs else: processedOtherSubexpressionEquations.append(fastSubs(otherSubexpressionEq, substitutionDict)) - return EquationCollection(self.mainEquations + other.mainEquations, - self.subexpressions + processedOtherSubexpressionEquations) + return self.copy(self.mainEquations + other.mainEquations, + self.subexpressions + processedOtherSubexpressionEquations) - def extract(self, symbolsToExtract): - """ - Creates a new equation collection with equations that have symbolsToExtract as left-hand-sides and - only the necessary subexpressions that are used in these equations - """ - symbolsToExtract = set(symbolsToExtract) - newEquations = [] + def getDependentSymbols(self, symbolSequence): + """Returns a list of symbols that depend on the passed symbols.""" - subexprMap = {e.lhs: e.rhs for e in self.subexpressions} - handledSymbols = set() - queue = [] + queue = list(symbolSequence) def addSymbolsFromExpr(expr): dependentSymbols = expr.atoms(sp.Symbol) for ds in dependentSymbols: - if ds not in handledSymbols: - queue.append(ds) - handledSymbols.add(ds) + queue.append(ds) - for eq in self.allEquations: - if eq.lhs in symbolsToExtract: - newEquations.append(eq) - addSymbolsFromExpr(eq.rhs) + handledSymbols = set() + eqMap = {e.lhs: e.rhs for e in self.allEquations} while len(queue) > 0: e = queue.pop(0) - if e not in subexprMap: + if e in handledSymbols: continue - else: - addSymbolsFromExpr(subexprMap[e]) + if e in eqMap: + addSymbolsFromExpr(eqMap[e]) + handledSymbols.add(e) + + return handledSymbols + + def extract(self, symbolsToExtract): + """ + Creates a new equation collection with equations that have symbolsToExtract as left-hand-sides and + only the necessary subexpressions that are used in these equations + """ + symbolsToExtract = set(symbolsToExtract) + dependentSymbols = self.getDependentSymbols(symbolsToExtract) + newEquations = [] + for eq in self.allEquations: + if eq.lhs in symbolsToExtract: + newEquations.append(eq) - newSubExpr = [eq for eq in self.subexpressions if eq.lhs in handledSymbols and eq.lhs not in symbolsToExtract] + newSubExpr = [eq for eq in self.subexpressions if eq.lhs in dependentSymbols and eq.lhs not in symbolsToExtract] return EquationCollection(newEquations, newSubExpr) def newWithoutUnusedSubexpressions(self): @@ -221,18 +224,30 @@ class EquationCollection: allLhs = [eq.lhs for eq in self.mainEquations] return self.extract(allLhs) - def insertSubexpressions(self): + def insertSubexpressions(self, subexpressionSymbolsToKeep=set()): """Returns a new equation collection by inserting all subexpressions into the main equations""" if len(self.subexpressions) == 0: - return EquationCollection(self.mainEquations, self.subexpressions, self.simplificationHints) - subsDict = {self.subexpressions[0].lhs: self.subexpressions[0].rhs} + return self.copy() + + subexpressionSymbolsToKeep = set(subexpressionSymbolsToKeep) + + keptSubexpressions = [] + if self.subexpressions[0].lhs in subexpressionSymbolsToKeep: + subsDict = {} + keptSubexpressions = self.subexpressions[0] + else: + subsDict = {self.subexpressions[0].lhs: self.subexpressions[0].rhs} + subExpr = [e for e in self.subexpressions] for i in range(1, len(subExpr)): subExpr[i] = fastSubs(subExpr[i], subsDict) - subsDict[subExpr[i].lhs] = subExpr[i].rhs + if subExpr[i].lhs in subexpressionSymbolsToKeep: + keptSubexpressions.append(subExpr[i]) + else: + subsDict[subExpr[i].lhs] = subExpr[i].rhs newEq = [fastSubs(eq, subsDict) for eq in self.mainEquations] - return EquationCollection(newEq, [], self.simplificationHints) + return self.copy(newEq, keptSubexpressions) def lambdify(self, symbols, module=None, fixedSymbols={}): """ @@ -241,7 +256,7 @@ class EquationCollection: :param module: same as sympy.lambdify paramter of same same, i.e. which module to use e.g. 'numpy' :param fixedSymbols: dictionary with substitutions, that are applied before lambdification """ - eqs = self.newWithSubstitutionsApplied(fixedSymbols).insertSubexpressions().mainEquations + eqs = self.copyWithSubstitutionsApplied(fixedSymbols).insertSubexpressions().mainEquations lambdas = {eq.lhs: sp.lambdify(symbols, eq.rhs, module) for eq in eqs} def f(*args, **kwargs): diff --git a/equationcollection/simplifications.py b/equationcollection/simplifications.py index 412a889cc..328373599 100644 --- a/equationcollection/simplifications.py +++ b/equationcollection/simplifications.py @@ -1,5 +1,4 @@ import sympy as sp -from pystencils.equationcollection import EquationCollection from pystencils.sympyextensions import replaceAdditive @@ -21,21 +20,18 @@ def sympyCSE(equationCollection): topologicallySortedPairs = sp.cse_main.reps_toposort([[e.lhs, e.rhs] for e in newSubexpressions]) newSubexpressions = [sp.Eq(a[0], a[1]) for a in topologicallySortedPairs] - return EquationCollection(modifiedUpdateEquations, newSubexpressions, equationCollection.simplificationHints, - equationCollection.subexpressionSymbolNameGenerator) + return equationCollection.copy(modifiedUpdateEquations, newSubexpressions) def applyOnAllEquations(equationCollection, operation): """Applies sympy expand operation to all equations in collection""" result = [operation(s) for s in equationCollection.mainEquations] - return equationCollection.newWithAdditionalSubexpressions(result, []) + return equationCollection.copy(result) def applyOnAllSubexpressions(equationCollection, operation): - return EquationCollection(equationCollection.mainEquations, - [operation(s) for s in equationCollection.subexpressions], - equationCollection.simplificationHints, - equationCollection.subexpressionSymbolNameGenerator) + return equationCollection.copy(equationCollection.mainEquations, + [operation(s) for s in equationCollection.subexpressions]) def subexpressionSubstitutionInExistingSubexpressions(equationCollection): @@ -49,8 +45,7 @@ def subexpressionSubstitutionInExistingSubexpressions(equationCollection): newRhs = newRhs.subs(subExpr.rhs, subExpr.lhs) result.append(sp.Eq(s.lhs, newRhs)) - return EquationCollection(equationCollection.mainEquations, result, equationCollection.simplificationHints, - equationCollection.subexpressionSymbolNameGenerator) + return equationCollection.copy(equationCollection.mainEquations, result) def subexpressionSubstitutionInMainEquations(equationCollection): @@ -61,7 +56,7 @@ def subexpressionSubstitutionInMainEquations(equationCollection): for subExpr in equationCollection.subexpressions: newRhs = replaceAdditive(newRhs, subExpr.lhs, subExpr.rhs, requiredMatchReplacement=1.0) result.append(sp.Eq(s.lhs, newRhs)) - return equationCollection.newWithAdditionalSubexpressions(result, []) + return equationCollection.copy(result) def addSubexpressionsForDivisions(equationCollection): @@ -80,4 +75,4 @@ def addSubexpressionsForDivisions(equationCollection): newSymbolGen = equationCollection.subexpressionSymbolNameGenerator substitutions = {divisor: newSymbol for newSymbol, divisor in zip(newSymbolGen, divisors)} - return equationCollection.newWithSubstitutionsApplied(substitutions, True) + return equationCollection.copyWithSubstitutionsApplied(substitutions, True) diff --git a/sympyextensions.py b/sympyextensions.py index e02b58be9..8daa67136 100644 --- a/sympyextensions.py +++ b/sympyextensions.py @@ -14,7 +14,11 @@ def fastSubs(term, subsDict): return expr paramList = [visit(a) for a in expr.args] return expr if not paramList else expr.func(*paramList) - return visit(term) + + if len(subsDict) == 0: + return term + else: + return visit(term) def replaceAdditive(expr, replacement, subExpression, requiredMatchReplacement=0.5, requiredMatchOriginal=None): -- GitLab