equationcollection.py 11.2 KB
 # EquationCollection: sympy main equations together with the subexpression definitions they reference.
from collections import deque

import sympy as sp
from pystencils.sympyextensions import fastSubs, countNumberOfOperations


class EquationCollection:
    """
    A collection of equations with subexpression definitions, also represented as equations,
    that are used in the main equations. EquationCollections can be passed to simplification methods.
    These simplification methods can change the subexpressions, but the number and the
    left hand sides of the main equations themselves are not altered.
    Additionally a dictionary of simplification hints is stored, which is filled by the functions that create
    equation collections to transport information to the simplification system.

    :ivar mainEquations: list of sympy equations
    :ivar subexpressions: list of sympy equations defining subexpressions used in main equations
    :ivar simplificationHints: dictionary that is used to annotate the equation collection with hints that are
                               used by the simplification system. See documentation of the simplification rules for
                               potentially required hints and their meaning.
    :ivar subexpressionSymbolNameGenerator: source of new symbols for subexpressions (by default an iterator
                                            yielding xi_1, xi_2, ... that skips symbols already bound here)
    """

    # ----------------------------------------- Creation ---------------------------------------------------------------

    def __init__(self, equations, subExpressions, simplificationHints=None, subexpressionSymbolNameGenerator=None):
        """
        :param equations: main equations (sympy equations)
        :param subExpressions: equations defining the subexpressions used by the main equations
        :param simplificationHints: optional dict of hints for the simplification system. When omitted, a new empty
                                    dict is created for this collection.
        :param subexpressionSymbolNameGenerator: optional source of new subexpression symbols; when omitted, an
                                                 iterator over xi_1, xi_2, ... is used
        """
        self.mainEquations = equations
        self.subexpressions = subExpressions
        # A shared ``{}`` default would be mutated by addSimplificationHint() and leak into every other
        # collection that was created without explicit hints.
        self.simplificationHints = {} if simplificationHints is None else simplificationHints

        def symbolGen():
            """Yield symbols xi_1, xi_2, ... that are not bound in this collection at the time they are produced"""
            counter = 0
            while True:
                counter += 1
                newSymbol = sp.Symbol("xi_" + str(counter))
                if newSymbol in self.boundSymbols:
                    continue
                yield newSymbol

        if subexpressionSymbolNameGenerator is None:
            self.subexpressionSymbolNameGenerator = symbolGen()
        else:
            self.subexpressionSymbolNameGenerator = subexpressionSymbolNameGenerator

    def newWithAdditionalSubexpressions(self, newEquations, additionalSubExpressions):
        """
        Returns a new equation collection that has `newEquations` as mainEquations.
        The `additionalSubExpressions` are appended to the existing subexpressions.
        Simplification hints are copied over and the subexpression symbol generator is shared.
        """
        assert len(self.mainEquations) == len(newEquations), "Number of update equations cannot be changed"
        return EquationCollection(newEquations,
                                  self.subexpressions + additionalSubExpressions,
                                  dict(self.simplificationHints),
                                  subexpressionSymbolNameGenerator=self.subexpressionSymbolNameGenerator)

    def newWithSubstitutionsApplied(self, substitutionDict):
        """
        Returns a new equation collection, where terms are substituted according to the passed `substitutionDict`.
        Substitutions are made in the subexpression terms and the main equations.
        Simplification hints are copied over and the subexpression symbol generator is shared.
        """
        newSubexpressions = [fastSubs(eq, substitutionDict) for eq in self.subexpressions]
        newEquations = [fastSubs(eq, substitutionDict) for eq in self.mainEquations]
        return EquationCollection(newEquations, newSubexpressions, dict(self.simplificationHints),
                                  subexpressionSymbolNameGenerator=self.subexpressionSymbolNameGenerator)

    def addSimplificationHint(self, key, value):
        """
        Adds an entry to the simplificationHints dictionary and checks that it does not exist yet
        """
        assert key not in self.simplificationHints, "This hint already exists"
        self.simplificationHints[key] = value

    # ---------------------------------------------- Properties -------------------------------------------------------

    @property
    def allEquations(self):
        """Subexpression and main equations in one sequence (subexpressions first)"""
        return self.subexpressions + self.mainEquations

    @property
    def freeSymbols(self):
        """All symbols used in the equation collection, which have not been defined inside the equation system"""
        freeSymbols = set()
        for eq in self.allEquations:
            freeSymbols.update(eq.rhs.atoms(sp.Symbol))
        return freeSymbols - self.boundSymbols

    @property
    def boundSymbols(self):
        """Set of all symbols which occur on left-hand-sides i.e. all symbols which are defined."""
        boundSymbolsSet = {eq.lhs for eq in self.allEquations}
        assert len(boundSymbolsSet) == len(self.subexpressions) + len(self.mainEquations), \
            "Not in SSA form - same symbol assigned multiple times"
        return boundSymbolsSet

    @property
    def definedSymbols(self):
        """All symbols that occur as left-hand-sides of the main equations"""
        return {eq.lhs for eq in self.mainEquations}

    @property
    def operationCount(self):
        """See :func:`countNumberOfOperations` """
        return countNumberOfOperations(self.allEquations)

    def get(self, symbols, fromMainEquationsOnly=False):
        """
        Return the equations which have the given symbol(s) as left hand sides.

        :param symbols: a single symbol or an iterable of symbols
        :param fromMainEquationsOnly: if True, only the main equations are searched, otherwise subexpressions too
        """
        try:
            symbols = set(symbols)
        except TypeError:
            # a single, non-iterable symbol was passed (list(symbol) used to fail here)
            symbols = {symbols}

        eqsToSearchIn = self.mainEquations if fromMainEquationsOnly else self.allEquations
        return [eq for eq in eqsToSearchIn if eq.lhs in symbols]

    # ----------------------------------------- Display and Printing -------------------------------------------------

    def _repr_html_(self):
        """Rich HTML representation (IPython/Jupyter display hook): one LaTeX table row per equation"""
        def makeHtmlEquationTable(equations):
            noBorder = 'style="border:none"'
            rowTemplate = '<tr {nb}> <td {nb}>$${eq}$$</td> </tr>'
            rows = [rowTemplate.format(eq=sp.latex(eq), nb=noBorder) for eq in equations]
            return '<table style="border:none; width: 100%;">' + "".join(rows) + "</table>"

        result = ""
        if self.subexpressions:
            result += "<div>Subexpressions:</div>"
            result += makeHtmlEquationTable(self.subexpressions)
        result += "<div>Main Equations:</div>"
        result += makeHtmlEquationTable(self.mainEquations)
        return result

    def __repr__(self):
        return "Equation Collection for " + ",".join([str(eq.lhs) for eq in self.mainEquations])

    def __str__(self):
        result = "Subexpressions\n"
        for eq in self.subexpressions:
            result += str(eq) + "\n"
        result += "Main Equations\n"
        for eq in self.mainEquations:
            result += str(eq) + "\n"
        return result

    # ------------------------------------- Manipulation ------------------------------------------------------------

    def merge(self, other):
        """
        Returns a new collection which contains the equations of self and `other`.

        Subexpressions of `other` that self already defines identically are dropped. A subexpression of `other`
        whose left-hand side is already bound in self with a different definition is renamed using the subexpression
        symbol generator. The renaming is applied to the following subexpressions and to the main equations
        of `other`.
        """
        ownDefs = {e.lhs for e in self.mainEquations}
        otherDefs = {e.lhs for e in other.mainEquations}
        assert len(ownDefs.intersection(otherDefs)) == 0, "Cannot merge, since both collection define the same symbols"

        ownSubexpressionSymbols = {e.lhs: e.rhs for e in self.subexpressions}
        ownBoundSymbols = {e.lhs for e in self.allEquations}
        otherBoundSymbols = {e.lhs for e in other.allEquations}
        generator = self.subexpressionSymbolNameGenerator

        def freshSymbol():
            """Next generated symbol that is bound neither in self nor in `other`"""
            while True:
                # the default generator is an iterator; a plain callable is still accepted
                candidate = generator() if callable(generator) else next(generator)
                if candidate not in ownBoundSymbols and candidate not in otherBoundSymbols:
                    return candidate

        substitutionDict = {}
        processedOtherSubexpressionEquations = []
        for otherSubexpressionEq in other.subexpressions:
            lhs = otherSubexpressionEq.lhs
            # apply earlier renamings first, so the definition is compared by its meaning in the merged collection
            newRhs = fastSubs(otherSubexpressionEq.rhs, substitutionDict)
            if lhs in ownSubexpressionSymbols and ownSubexpressionSymbols[lhs] == newRhs:
                continue  # exactly the same subexpression equation exists already
            if lhs in ownBoundSymbols:
                # different definition of an already bound symbol - a new name has to be introduced
                newLhs = freshSymbol()
                processedOtherSubexpressionEquations.append(sp.Eq(newLhs, newRhs))
                substitutionDict[lhs] = newLhs
            else:
                processedOtherSubexpressionEquations.append(fastSubs(otherSubexpressionEq, substitutionDict))

        # the main equations of `other` must refer to the renamed subexpressions as well
        otherMainEquations = [fastSubs(eq, substitutionDict) for eq in other.mainEquations]
        return EquationCollection(self.mainEquations + otherMainEquations,
                                  self.subexpressions + processedOtherSubexpressionEquations)

    def extract(self, symbolsToExtract):
        """
        Creates a new equation collection with equations that have symbolsToExtract as left-hand-sides and
        only the necessary subexpressions that are used in these equations
        """
        symbolsToExtract = set(symbolsToExtract)
        newEquations = []

        subexprMap = {e.lhs: e.rhs for e in self.subexpressions}
        handledSymbols = set()
        queue = deque()

        def addSymbolsFromExpr(expr):
            for dependentSymbol in expr.atoms(sp.Symbol):
                if dependentSymbol not in handledSymbols:
                    queue.append(dependentSymbol)
                    handledSymbols.add(dependentSymbol)

        for eq in self.allEquations:
            if eq.lhs in symbolsToExtract:
                newEquations.append(eq)
                addSymbolsFromExpr(eq.rhs)

        # breadth-first walk through the subexpression dependencies
        while queue:
            e = queue.popleft()
            if e in subexprMap:
                addSymbolsFromExpr(subexprMap[e])

        newSubExpr = [eq for eq in self.subexpressions if eq.lhs in handledSymbols and eq.lhs not in symbolsToExtract]
        return EquationCollection(newEquations, newSubExpr)

    def newWithoutUnusedSubexpressions(self):
        """Returns a new equation collection containing only the subexpressions that are used/referenced in the
        equations"""
        allLhs = [eq.lhs for eq in self.mainEquations]
        return self.extract(allLhs)

    def insertSubexpressions(self):
        """Returns a new equation collection by inserting all subexpressions into the main equations"""
        if not self.subexpressions:
            return EquationCollection(self.mainEquations, self.subexpressions, dict(self.simplificationHints))

        firstSubexpression = self.subexpressions[0]
        subsDict = {firstSubexpression.lhs: firstSubexpression.rhs}
        # each subexpression may use the previous ones, so substitute them in definition order
        for subExprEq in self.subexpressions[1:]:
            substituted = fastSubs(subExprEq, subsDict)
            subsDict[substituted.lhs] = substituted.rhs

        newEq = [fastSubs(eq, subsDict) for eq in self.mainEquations]
        return EquationCollection(newEq, [], dict(self.simplificationHints))

    def lambdify(self, symbols, module=None, fixedSymbols=None):
        """
        Returns a function to evaluate this equation collection

        :param symbols: symbol(s) which are the parameter for the created function
        :param module: same as the sympy.lambdify parameter of the same name, i.e. which module to use e.g. 'numpy'
        :param fixedSymbols: dictionary with substitutions, that are applied before lambdification
        :return: function taking the values for `symbols` and returning a dict mapping each main equation's
                 left-hand side to its evaluated right-hand side
        """
        if fixedSymbols is None:
            fixedSymbols = {}
        eqs = self.newWithSubstitutionsApplied(fixedSymbols).insertSubexpressions().mainEquations
        lambdas = {eq.lhs: sp.lambdify(symbols, eq.rhs, module) for eq in eqs}

        def f(*args, **kwargs):
            return {lhs: func(*args, **kwargs) for lhs, func in lambdas.items()}

        return f