equationcollection.py
```python
import sympy as sp
from copy import deepcopy
from pystencils.sympyextensions import fastSubs, countNumberOfOperations, sortEquationsTopologically


class EquationCollection(object):
    """
    A collection of equations with subexpression definitions, also represented as equations,
    that are used in the main equations. EquationCollections can be passed to simplification methods.
    These simplification methods can change the subexpressions, but the number and left-hand sides of the
    main equations themselves are not altered.
    Additionally, a dictionary of simplification hints is stored. The hints are set by the functions that create
    equation collections, to transport information to the simplification system.

    :ivar mainEquations: list of sympy equations
    :ivar subexpressions: list of sympy equations defining subexpressions used in main equations
    :ivar simplificationHints: dictionary that is used to annotate the equation collection with hints that are
                               used by the simplification system. See documentation of the simplification rules for
                               potentially required hints and their meaning.
    """

    # ----------------------------------------- Creation ---------------------------------------------------------------

    def __init__(self, equations, subExpressions, simplificationHints=None, subexpressionSymbolNameGenerator=None):
        self.mainEquations = equations
        self.subexpressions = subExpressions

        if simplificationHints is None:
            simplificationHints = {}

        self.simplificationHints = simplificationHints

        if subexpressionSymbolNameGenerator is None:
            self.subexpressionSymbolNameGenerator = SymbolGen()
        else:
            self.subexpressionSymbolNameGenerator = subexpressionSymbolNameGenerator

    @property
    def mainTerms(self):
        return []

    def copy(self, mainEquations=None, subexpressions=None):
        """Returns a deep copy, optionally replacing the main equations and/or the subexpressions."""
        res = deepcopy(self)
        if mainEquations is not None:
            res.mainEquations = mainEquations
        if subexpressions is not None:
            res.subexpressions = subexpressions
        return res

    def copyWithSubstitutionsApplied(self, substitutionDict, addSubstitutionsAsSubexpressions=False,
                                     substituteOnLhs=True):
        """
        Returns a new equation collection, where terms are substituted according to the passed `substitutionDict`.
        Substitutions are made in the subexpression terms and the main equations.
        """
        if substituteOnLhs:
            newSubexpressions = [fastSubs(eq, substitutionDict) for eq in self.subexpressions]
            newEquations = [fastSubs(eq, substitutionDict) for eq in self.mainEquations]
        else:
            newSubexpressions = [sp.Eq(eq.lhs, fastSubs(eq.rhs, substitutionDict)) for eq in self.subexpressions]
            newEquations = [sp.Eq(eq.lhs, fastSubs(eq.rhs, substitutionDict)) for eq in self.mainEquations]

        if addSubstitutionsAsSubexpressions:
            # each substituted term becomes a subexpression defining its replacement symbol
            newSubexpressions = [sp.Eq(b, a) for a, b in substitutionDict.items()] + newSubexpressions
            newSubexpressions = sortEquationsTopologically(newSubexpressions)
        return self.copy(newEquations, newSubexpressions)

    def addSimplificationHint(self, key, value):
        """
        Adds an entry to the simplificationHints dictionary and checks that it does not exist yet.
        """
        assert key not in self.simplificationHints, "This hint already exists"
        self.simplificationHints[key] = value

    # ---------------------------------------------- Properties -------------------------------------------------------

    @property
    def allEquations(self):
        """Subexpression and main equations in one sequence"""
        return self.subexpressions + self.mainEquations

    @property
    def freeSymbols(self):
        """All symbols used in the equation collection, which have not been defined inside the equation system"""
        freeSymbols = set()
        for eq in self.allEquations:
            freeSymbols.update(eq.rhs.atoms(sp.Symbol))
        return freeSymbols - self.boundSymbols

    @property
    def boundSymbols(self):
        """Set of all symbols which occur on left-hand sides, i.e. all symbols which are defined."""
        boundSymbolsSet = set([eq.lhs for eq in self.allEquations])
        assert len(boundSymbolsSet) == len(self.subexpressions) + len(self.mainEquations), \
            "Not in SSA form - same symbol assigned multiple times"
        return boundSymbolsSet

    @property
    def definedSymbols(self):
        """All symbols that occur as left-hand sides of the main equations"""
        return set([eq.lhs for eq in self.mainEquations])

    @property
    def operationCount(self):
        """See :func:`countNumberOfOperations` """
        return countNumberOfOperations(self.allEquations)

    def get(self, symbols, fromMainEquationsOnly=False):
        """Returns the equations whose left-hand sides are among the given symbol(s)"""
        if not hasattr(symbols, "__len__"):
            symbols = [symbols]  # a single symbol was passed; wrap it instead of iterating over it
        symbols = set(symbols)

        if not fromMainEquationsOnly:
            eqsToSearchIn = self.allEquations
        else:
            eqsToSearchIn = self.mainEquations

        return [eq for eq in eqsToSearchIn if eq.lhs in symbols]

    # ----------------------------------------- Display and Printing -------------------------------------------------

    def _repr_html_(self):
        def makeHtmlEquationTable(equations):
            noBorder = 'style="border:none"'
            htmlTable = '<table style="border:none">'
            line = '<tr {nb}><td {nb}>$${eq}$$</td></tr>'
            for eq in equations:
                formatDict = {'eq': sp.latex(eq),
                              'nb': noBorder, }
                htmlTable += line.format(**formatDict)
            htmlTable += "</table>"
            return htmlTable

        result = ""
        if len(self.subexpressions) > 0:
            result += "<div>Subexpressions:</div>"
            result += makeHtmlEquationTable(self.subexpressions)
        result += "<div>Main Equations:</div>"
        result += makeHtmlEquationTable(self.mainEquations)
        return result

    def __repr__(self):
        return "Equation Collection for " + ",".join([str(eq.lhs) for eq in self.mainEquations])

    def __str__(self):
        result = "Subexpressions\n"
        for eq in self.subexpressions:
            result += str(eq) + "\n"
        result += "Main Equations\n"
        for eq in self.mainEquations:
            result += str(eq) + "\n"
        return result

    # ------------------------------------- Manipulation ------------------------------------------------------------

    def merge(self, other):
        """Returns a new collection which contains self and other. Subexpressions are renamed if they clash."""
        ownDefs = set([e.lhs for e in self.mainEquations])
        otherDefs = set([e.lhs for e in other.mainEquations])
        assert len(ownDefs.intersection(otherDefs)) == 0, "Cannot merge, since both collections define the same symbols"

        ownSubexpressionSymbols = {e.lhs: e.rhs for e in self.subexpressions}
        substitutionDict = {}

        processedOtherSubexpressionEquations = []
        for otherSubexpressionEq in other.subexpressions:
            if otherSubexpressionEq.lhs in ownSubexpressionSymbols:
                if otherSubexpressionEq.rhs == ownSubexpressionSymbols[otherSubexpressionEq.lhs]:
                    continue  # exactly the same subexpression equation exists already
                else:
                    # different definition - a new name has to be introduced
                    newLhs = next(self.subexpressionSymbolNameGenerator)
                    newEq = sp.Eq(newLhs, fastSubs(otherSubexpressionEq.rhs, substitutionDict))
                    processedOtherSubexpressionEquations.append(newEq)
                    substitutionDict[otherSubexpressionEq.lhs] = newLhs
            else:
                processedOtherSubexpressionEquations.append(fastSubs(otherSubexpressionEq, substitutionDict))

        processedOtherMainEquations = [fastSubs(eq, substitutionDict) for eq in other.mainEquations]
        return self.copy(self.mainEquations + processedOtherMainEquations,
                         self.subexpressions + processedOtherSubexpressionEquations)

    def getDependentSymbols(self, symbolSequence):
        """Returns the set of symbols required to compute the passed symbols: the passed symbols themselves plus
        every symbol their defining equations reference, directly or transitively."""

        queue = list(symbolSequence)

        def addSymbolsFromExpr(expr):
            dependentSymbols = expr.atoms(sp.Symbol)
            for ds in dependentSymbols:
                queue.append(ds)

        handledSymbols = set()
        eqMap = {e.lhs: e.rhs for e in self.allEquations}

        while len(queue) > 0:
            e = queue.pop(0)
            if e in handledSymbols:
                continue
            if e in eqMap:
                addSymbolsFromExpr(eqMap[e])
            handledSymbols.add(e)

        return handledSymbols

    def extract(self, symbolsToExtract):
        """
        Creates a new equation collection with equations that have symbolsToExtract as left-hand sides and
        only the necessary subexpressions that are used in these equations
        """
        symbolsToExtract = set(symbolsToExtract)
        dependentSymbols = self.getDependentSymbols(symbolsToExtract)
        newEquations = []
        for eq in self.allEquations:
            if eq.lhs in symbolsToExtract:
                newEquations.append(eq)

        newSubExpr = [eq for eq in self.subexpressions if eq.lhs in dependentSymbols and eq.lhs not in symbolsToExtract]
        return EquationCollection(newEquations, newSubExpr)

    def newWithoutUnusedSubexpressions(self):
        """Returns a new equation collection containing only the subexpressions that
        are used/referenced in the equations"""
        allLhs = [eq.lhs for eq in self.mainEquations]
        return self.extract(allLhs)

    def appendToSubexpressions(self, rhs, lhs=None, topologicalSort=True):
        """Appends the subexpression lhs = rhs in place (a new Dummy symbol is used if no lhs is given)
        and returns the lhs."""
        if lhs is None:
            lhs = sp.Dummy()
        eq = sp.Eq(lhs, rhs)
        self.subexpressions.append(eq)
        if topologicalSort:
            self.topologicalSort(subexpressions=True, mainEquations=False)
        return lhs

    def topologicalSort(self, subexpressions=True, mainEquations=True):
        """Sorts the subexpressions and/or the main equations topologically, in place."""
        if subexpressions:
            self.subexpressions = sortEquationsTopologically(self.subexpressions)
        if mainEquations:
            self.mainEquations = sortEquationsTopologically(self.mainEquations)

    def insertSubexpression(self, symbol):
        """Returns a new equation collection in which the subexpression defining `symbol` is inlined into all
        other equations and removed. Returns self unchanged if no subexpression defines `symbol`."""
        newSubexpressions = []
        subsDict = None
        for se in self.subexpressions:
            if se.lhs == symbol:
                subsDict = {se.lhs: se.rhs}
            else:
                newSubexpressions.append(se)
        if subsDict is None:
            return self

        newSubexpressions = [sp.Eq(eq.lhs, fastSubs(eq.rhs, subsDict)) for eq in newSubexpressions]
        newEqs = [sp.Eq(eq.lhs, fastSubs(eq.rhs, subsDict)) for eq in self.mainEquations]
        return self.copy(newEqs, newSubexpressions)

    def insertSubexpressions(self, subexpressionSymbolsToKeep=None):
        """Returns a new equation collection by inserting all subexpressions into the main equations,
        except those whose left-hand sides are listed in subexpressionSymbolsToKeep"""
        if len(self.subexpressions) == 0:
            return self.copy()

        if subexpressionSymbolsToKeep is None:
            subexpressionSymbolsToKeep = set()
        subexpressionSymbolsToKeep = set(subexpressionSymbolsToKeep)

        keptSubexpressions = []
        if self.subexpressions[0].lhs in subexpressionSymbolsToKeep:
            subsDict = {}
            keptSubexpressions = [self.subexpressions[0]]  # must be a list, later entries are appended
        else:
            subsDict = {self.subexpressions[0].lhs: self.subexpressions[0].rhs}

        subExpr = [e for e in self.subexpressions]
        for i in range(1, len(subExpr)):
            subExpr[i] = fastSubs(subExpr[i], subsDict)
            if subExpr[i].lhs in subexpressionSymbolsToKeep:
                keptSubexpressions.append(subExpr[i])
            else:
                subsDict[subExpr[i].lhs] = subExpr[i].rhs

        newEq = [fastSubs(eq, subsDict) for eq in self.mainEquations]
        return self.copy(newEq, keptSubexpressions)

    def lambdify(self, symbols, module=None, fixedSymbols=None):
        """
        Returns a function to evaluate this equation collection

        :param symbols: symbol(s) which are the parameters of the created function
        :param module: same as the sympy.lambdify parameter of the same name, i.e. which module to use, e.g. 'numpy'
        :param fixedSymbols: dictionary with substitutions that are applied before lambdification
        """
        if fixedSymbols is None:
            fixedSymbols = {}
        eqs = self.copyWithSubstitutionsApplied(fixedSymbols).insertSubexpressions().mainEquations
        lambdas = {eq.lhs: sp.lambdify(symbols, eq.rhs, module) for eq in eqs}

        def f(*args, **kwargs):
            # evaluate every main equation and map its left-hand side to the result
            return {s: func(*args, **kwargs) for s, func in lambdas.items()}

        return f


class SymbolGen:
    """Generates an endless sequence of new symbols xi_1, xi_2, ..."""

    def __init__(self):
        self._ctr = 0

    def __iter__(self):
        return self

    def __next__(self):
        self._ctr += 1
        return sp.Symbol("xi_" + str(self._ctr))

    def next(self):
        # Python 2 compatible alias of __next__
        return self.__next__()
```
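`copyWithSubstitutionsApplied` applies the substitution to whole equations by default and to right-hand sides only when `substituteOnLhs=False`. The sketch below shows the difference on a toy model. It assumes each equation is a `(lhs, rhs_tokens)` pair, and it restricts substitutions to symbol-to-symbol renames, whereas `fastSubs` can replace arbitrary terms. `with_substitutions` and `subs_tokens` are hypothetical names and are not part of pystencils.

```python
from typing import Dict, List, Tuple

# Toy model (assumption): an equation is (lhs, rhs tokens), e.g. ("y", ("a", "*", "x")).
Equation = Tuple[str, Tuple[str, ...]]


def subs_tokens(tokens: Tuple[str, ...], mapping: Dict[str, str]) -> Tuple[str, ...]:
    """Rename every token that appears in mapping; leave the others untouched."""
    return tuple(mapping.get(t, t) for t in tokens)


def with_substitutions(equations: List[Equation], mapping: Dict[str, str],
                       substitute_on_lhs: bool = True) -> List[Equation]:
    """Apply the mapping to every right-hand side, and optionally to the left-hand side too."""
    return [(mapping.get(lhs, lhs) if substitute_on_lhs else lhs, subs_tokens(rhs, mapping))
            for lhs, rhs in equations]


eqs = [("a", ("x", "+", "1")), ("y", ("a", "*", "x"))]
print(with_substitutions(eqs, {"a": "b"}))
# [('b', ('x', '+', '1')), ('y', ('b', '*', 'x'))]
print(with_substitutions(eqs, {"a": "b"}, substitute_on_lhs=False))
# [('a', ('x', '+', '1')), ('y', ('b', '*', 'x'))]
```

In the second call, the definition keeps its old name `a` while `y` now reads `b`. A right-hand-side-only substitution is therefore meant for terms that are not themselves defined in the collection.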
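The `freeSymbols` and `boundSymbols` properties reduce to set arithmetic over left- and right-hand sides. Here is a minimal stdlib sketch, assuming a simplified model that is not used by this module: each equation is a `(lhs, rhs_symbols)` pair holding a symbol name and the set of names its right-hand side references. The helper names are hypothetical. The sketch raises an exception for the SSA check where the original uses `assert`.

```python
from typing import List, Set, Tuple

# Simplified model (assumption): an equation is
# (lhs name, set of symbol names referenced on the right-hand side).
Equation = Tuple[str, Set[str]]


def bound_symbols(equations: List[Equation]) -> Set[str]:
    """All left-hand sides; like boundSymbols, insist on SSA form."""
    lhs_names = [name for name, _ in equations]
    if len(set(lhs_names)) != len(lhs_names):
        raise ValueError("Not in SSA form - same symbol assigned multiple times")
    return set(lhs_names)


def free_symbols(equations: List[Equation]) -> Set[str]:
    """Symbols used on some right-hand side but defined by no equation."""
    used: Set[str] = set()
    for _, rhs in equations:
        used |= rhs
    return used - bound_symbols(equations)


subexpressions = [("t", {"a", "b"})]
main_equations = [("y", {"t", "c"})]
print(sorted(free_symbols(subexpressions + main_equations)))  # ['a', 'b', 'c']
```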
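`merge` drops subexpressions that are defined identically in both collections. When a name clashes but the definition differs, it gives the incoming subexpression a fresh name from the symbol generator and rewrites every later reference to it. The sketch mirrors that control flow on a toy model in which an equation is a `(lhs, rhs_tokens)` pair, so renaming is plain token replacement. `merge`, `rename` and `symbol_gen` are hypothetical stand-ins; `symbol_gen` uses `itertools.count` in place of the `SymbolGen` class. Like the original, the sketch assumes generated names never collide with names already in use.

```python
from itertools import count
from typing import Dict, Iterator, List, Tuple

# Toy model (assumption): an equation is (lhs, rhs tokens), e.g. ("t", ("a", "+", "b")).
Equation = Tuple[str, Tuple[str, ...]]


def symbol_gen(prefix: str = "xi_") -> Iterator[str]:
    """Counterpart of SymbolGen: yields xi_1, xi_2, ... forever."""
    return (prefix + str(i) for i in count(1))


def rename(eq: Equation, mapping: Dict[str, str]) -> Equation:
    lhs, rhs = eq
    return mapping.get(lhs, lhs), tuple(mapping.get(tok, tok) for tok in rhs)


def merge(own_subs: List[Equation], own_main: List[Equation],
          other_subs: List[Equation], other_main: List[Equation],
          names: Iterator[str]) -> Tuple[List[Equation], List[Equation]]:
    assert not {lhs for lhs, _ in own_main} & {lhs for lhs, _ in other_main}, \
        "Cannot merge, since both collections define the same symbols"
    own_defs = dict(own_subs)
    mapping: Dict[str, str] = {}
    merged_subs = list(own_subs)
    for lhs, rhs in other_subs:
        if lhs in own_defs:
            if rhs == own_defs[lhs]:
                continue  # identical definition already present
            new_lhs = next(names)  # clash with a different definition: introduce a fresh name
            merged_subs.append((new_lhs, rename((lhs, rhs), mapping)[1]))
            mapping[lhs] = new_lhs
        else:
            merged_subs.append(rename((lhs, rhs), mapping))
    merged_main = own_main + [rename(eq, mapping) for eq in other_main]
    return merged_subs, merged_main


subs, main = merge([("t", ("a", "+", "b"))], [("y", ("t", "*", "2"))],
                   [("t", ("a", "-", "b"))], [("z", ("t", "*", "t"))], symbol_gen())
print(subs)  # [('t', ('a', '+', 'b')), ('xi_1', ('a', '-', 'b'))]
print(main)  # [('y', ('t', '*', '2')), ('z', ('xi_1', '*', 'xi_1'))]
```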
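`getDependentSymbols` is a breadth-first walk from the requested symbols through the right-hand sides of their definitions. `extract` and `newWithoutUnusedSubexpressions` then keep only the subexpressions the walk reached. The stdlib sketch below assumes a simplified `(lhs, rhs_symbols)` model, and `required_symbols` and `prune_unused` are hypothetical names. It uses `collections.deque` because popping from its left end is O(1), while `list.pop(0)` is O(n).

```python
from collections import deque
from typing import Dict, Iterable, List, Set, Tuple

# Simplified model (assumption): (lhs, names referenced on the rhs).
Equation = Tuple[str, Set[str]]


def required_symbols(equations: List[Equation], targets: Iterable[str]) -> Set[str]:
    """The targets plus every symbol reachable through their definitions."""
    definitions: Dict[str, Set[str]] = dict(equations)
    queue = deque(targets)
    handled: Set[str] = set()
    while queue:
        sym = queue.popleft()
        if sym in handled:
            continue
        queue.extend(definitions.get(sym, ()))  # free symbols have no definition to follow
        handled.add(sym)
    return handled


def prune_unused(subexpressions: List[Equation], main_equations: List[Equation]) -> List[Equation]:
    """Keep only the subexpressions that the main equations need."""
    needed = required_symbols(subexpressions + main_equations, [lhs for lhs, _ in main_equations])
    return [eq for eq in subexpressions if eq[0] in needed]


subs = [("t1", {"a"}), ("t2", {"t1", "b"}), ("unused", {"c"})]
main = [("y", {"t2"})]
print([lhs for lhs, _ in prune_unused(subs, main)])  # ['t1', 't2']
```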
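`appendToSubexpressions` and `topologicalSort` delegate ordering to `sortEquationsTopologically` from `pystencils.sympyextensions`, whose implementation is not shown here. As a stand-in (not the pystencils code), the same idea can be sketched with `graphlib.TopologicalSorter` (Python 3.9+) on a simplified `(lhs, rhs_symbols)` model. `sort_topologically` is a hypothetical name.

```python
from graphlib import TopologicalSorter  # Python 3.9+
from typing import Dict, List, Set, Tuple

# Simplified model (assumption): (lhs, names referenced on the rhs).
Equation = Tuple[str, Set[str]]


def sort_topologically(equations: List[Equation]) -> List[Equation]:
    """Order equations so that every symbol is defined before it is used."""
    by_lhs: Dict[str, Equation] = {lhs: (lhs, rhs) for lhs, rhs in equations}
    # An equation only waits for the rhs symbols that are defined in this list;
    # free symbols impose no ordering constraint.
    graph = {lhs: {s for s in rhs if s in by_lhs} for lhs, rhs in equations}
    # static_order() raises graphlib.CycleError for cyclic definitions.
    return [by_lhs[lhs] for lhs in TopologicalSorter(graph).static_order()]


eqs = [("c", {"b"}), ("b", {"a", "x"}), ("a", {"x"})]
print([lhs for lhs, _ in sort_topologically(eqs)])  # ['a', 'b', 'c']
```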
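`insertSubexpressions` is a single forward pass over the subexpressions. Each definition is first rewritten with everything inlined so far; it is then either kept or added to the substitution dictionary. Because the subexpressions are topologically sorted, one pass suffices. The stdlib sketch below assumes a toy model in which every right-hand side is a linear combination `{symbol: coefficient}`, with the key `"1"` holding the constant term, so substitution is exact without sympy. `substitute` and `insert_subexpressions` are hypothetical names. The sketch handles the first subexpression in the same loop as the rest.

```python
from typing import Dict, Iterable, List, Tuple

# Toy model (assumption): a right-hand side is a linear combination
# {symbol: coefficient}; the key "1" holds the constant term.
Linear = Dict[str, int]
Equation = Tuple[str, Linear]


def substitute(expr: Linear, defs: Dict[str, Linear]) -> Linear:
    """Replace every symbol that has an entry in defs by its definition."""
    result: Linear = {}
    for sym, coeff in expr.items():
        for s, c in defs.get(sym, {sym: 1}).items():
            result[s] = result.get(s, 0) + coeff * c
    return {s: c for s, c in result.items() if c != 0}


def insert_subexpressions(subexpressions: List[Equation], main: List[Equation],
                          keep: Iterable[str] = ()) -> Tuple[List[Equation], List[Equation]]:
    keep_set = set(keep)
    defs: Dict[str, Linear] = {}
    kept: List[Equation] = []
    for lhs, rhs in subexpressions:
        rhs = substitute(rhs, defs)  # earlier definitions are already fully inlined
        if lhs in keep_set:
            kept.append((lhs, rhs))
        else:
            defs[lhs] = rhs
    return kept, [(lhs, substitute(rhs, defs)) for lhs, rhs in main]


subs = [("t1", {"a": 1, "b": 1}), ("t2", {"t1": 2, "1": 3})]
main = [("y", {"t2": 1, "a": -2})]
print(insert_subexpressions(subs, main))  # ([], [('y', {'b': 2, '1': 3})])
```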
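`lambdify` chains three steps:

1. Apply `fixedSymbols`.
2. Inline all subexpressions via `insertSubexpressions`.
3. Compile one function per main equation with `sp.lambdify` and wrap them in a closure that returns a dict keyed by left-hand side.

The stdlib sketch below follows that pipeline under explicit assumptions. It uses a toy model in which every right-hand side is a linear combination `{symbol: coefficient}`, with the key `"1"` as the constant term. Evaluating that combination stands in for `sp.lambdify`, and fixed values are treated as constant definitions. `inline` and `make_evaluator` are hypothetical names.

```python
from typing import Callable, Dict, List, Sequence, Tuple

# Toy model (assumption): {symbol: coefficient}; the key "1" is the constant term.
Linear = Dict[str, float]
Equation = Tuple[str, Linear]


def inline(expr: Linear, defs: Dict[str, Linear]) -> Linear:
    """Replace every symbol that has an entry in defs by its definition."""
    out: Linear = {}
    for sym, coeff in expr.items():
        for s, c in defs.get(sym, {sym: 1.0}).items():
            out[s] = out.get(s, 0.0) + coeff * c
    return out


def make_evaluator(subexpressions: List[Equation], main: List[Equation],
                   params: Sequence[str], fixed: Dict[str, float]) -> Callable[..., Dict[str, float]]:
    # Steps 1 and 2: fixed values become constants, subexpressions are inlined in order.
    defs: Dict[str, Linear] = {name: {"1": value} for name, value in fixed.items()}
    for lhs, rhs in subexpressions:
        defs[lhs] = inline(rhs, defs)
    compiled = {lhs: inline(rhs, defs) for lhs, rhs in main}

    def evaluate(*args: float) -> Dict[str, float]:
        # Step 3: bind positional arguments to params and evaluate every main equation.
        env = {"1": 1.0, **dict(zip(params, args))}
        return {lhs: sum(c * env[s] for s, c in expr.items()) for lhs, expr in compiled.items()}

    return evaluate


f = make_evaluator([("t", {"x": 2.0, "k": 1.0})],
                   [("y", {"t": 1.0, "1": 0.5}), ("z", {"x": -1.0})],
                   params=["x"], fixed={"k": 3.0})
print(f(2.0))  # {'y': 7.5, 'z': -2.0}
```

Inlining first matters: the compiled functions only see the parameters. Any subexpression symbol left on a right-hand side would be unresolvable at call time (in this sketch, a `KeyError`).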