# equationcollection.py 13.1 KB
# Newer Older
1
import sympy as sp
2
3
from copy import deepcopy
from pystencils.sympyextensions import fastSubs, countNumberOfOperations, sortEquationsTopologically
4
5


# Michael Kuron's avatar
# Michael Kuron committed
6
class EquationCollection(object):
    """
    A collection of equations with subexpression definitions, also represented as equations,
    that are used in the main equations. EquationCollections can be passed to simplification methods.
    These simplification methods can change the subexpressions, but the number and
    left hand side of the main equations themselves is not altered.
    Additionally a dictionary of simplification hints is stored, which are set by the functions that create
    equation collections to transport information to the simplification system.

    :ivar mainEquations: list of sympy equations
    :ivar subexpressions: list of sympy equations defining subexpressions used in main equations
    :ivar simplificationHints: dictionary that is used to annotate the equation collection with hints that are
                               used by the simplification system. See documentation of the simplification rules for
                               potentially required hints and their meaning.
    :ivar subexpressionSymbolNameGenerator: iterator yielding new symbols; used by :meth:`merge` to rename
                                            clashing subexpressions
    """

    # ----------------------------------------- Creation ---------------------------------------------------------------

    def __init__(self, equations, subExpressions, simplificationHints=None, subexpressionSymbolNameGenerator=None):
        self.mainEquations = equations
        self.subexpressions = subExpressions

        if simplificationHints is None:
            simplificationHints = {}
        self.simplificationHints = simplificationHints

        if subexpressionSymbolNameGenerator is None:
            self.subexpressionSymbolNameGenerator = SymbolGen()
        else:
            self.subexpressionSymbolNameGenerator = subexpressionSymbolNameGenerator

    @property
    def mainTerms(self):
        # NOTE(review): always empty in this class - presumably a hook for subclasses, confirm
        return []

    def copy(self, mainEquations=None, subexpressions=None):
        """Returns a deep copy of this collection.

        :param mainEquations: if given, replaces the main equations of the copy (list is used as-is)
        :param subexpressions: if given, replaces the subexpressions of the copy (list is used as-is)
        """
        res = deepcopy(self)
        if mainEquations is not None:
            res.mainEquations = mainEquations
        if subexpressions is not None:
            res.subexpressions = subexpressions
        return res

    def copyWithSubstitutionsApplied(self, substitutionDict, addSubstitutionsAsSubexpressions=False,
                                     substituteOnLhs=True):
        """
        Returns a new equation collection, where terms are substituted according to the passed `substitutionDict`.
        Substitutions are made in the subexpression terms and the main equations.

        :param substitutionDict: mapping from terms to their replacements
        :param addSubstitutionsAsSubexpressions: if True, each entry ``term -> replacement`` is additionally added
                                                 as subexpression ``replacement = term`` and the subexpressions
                                                 are sorted topologically afterwards
        :param substituteOnLhs: if False, only the right-hand sides are substituted
        """
        if substituteOnLhs:
            newSubexpressions = [fastSubs(eq, substitutionDict) for eq in self.subexpressions]
            newEquations = [fastSubs(eq, substitutionDict) for eq in self.mainEquations]
        else:
            newSubexpressions = [sp.Eq(eq.lhs, fastSubs(eq.rhs, substitutionDict)) for eq in self.subexpressions]
            newEquations = [sp.Eq(eq.lhs, fastSubs(eq.rhs, substitutionDict)) for eq in self.mainEquations]

        if addSubstitutionsAsSubexpressions:
            newSubexpressions = [sp.Eq(b, a) for a, b in substitutionDict.items()] + newSubexpressions
            newSubexpressions = sortEquationsTopologically(newSubexpressions)
        return self.copy(newEquations, newSubexpressions)

    def addSimplificationHint(self, key, value):
        """
        Adds an entry to the simplificationHints dictionary, and checks that it does not exist yet
        """
        assert key not in self.simplificationHints, "This hint already exists"
        self.simplificationHints[key] = value

    # ---------------------------------------------- Properties  -------------------------------------------------------

    @property
    def allEquations(self):
        """Subexpression and main equations in one sequence"""
        return self.subexpressions + self.mainEquations

    @property
    def freeSymbols(self):
        """All symbols used in the equation collection, which have not been defined inside the equation system"""
        freeSymbols = set()
        for eq in self.allEquations:
            freeSymbols.update(eq.rhs.atoms(sp.Symbol))
        return freeSymbols - self.boundSymbols

    @property
    def boundSymbols(self):
        """Set of all symbols which occur on left-hand-sides i.e. all symbols which are defined."""
        boundSymbolsSet = {eq.lhs for eq in self.allEquations}
        assert len(boundSymbolsSet) == len(self.subexpressions) + len(self.mainEquations), \
            "Not in SSA form - same symbol assigned multiple times"
        return boundSymbolsSet

    @property
    def definedSymbols(self):
        """All symbols that occur as left-hand-sides of the main equations"""
        return {eq.lhs for eq in self.mainEquations}

    @property
    def operationCount(self):
        """See :func:`countNumberOfOperations` """
        return countNumberOfOperations(self.allEquations)

    def get(self, symbols, fromMainEquationsOnly=False):
        """Return the equations which have symbols as left hand sides.

        :param symbols: a single symbol or an iterable of symbols
        :param fromMainEquationsOnly: if True, subexpressions are not searched
        """
        # A single symbol is not iterable - wrap it (list(symbol) would raise TypeError)
        if not hasattr(symbols, "__iter__"):
            symbols = [symbols]
        symbols = set(symbols)

        if not fromMainEquationsOnly:
            eqsToSearchIn = self.allEquations
        else:
            eqsToSearchIn = self.mainEquations

        return [eq for eq in eqsToSearchIn if eq.lhs in symbols]

    # ----------------------------------------- Display and Printing   -------------------------------------------------

    def _repr_html_(self):
        def makeHtmlEquationTable(equations):
            noBorder = 'style="border:none"'
            htmlTable = '<table style="border:none; width: 100%; ">'
            line = '<tr {nb}> <td {nb}>$${eq}$$</td>  </tr> '
            for eq in equations:
                formatDict = {'eq': sp.latex(eq),
                              'nb': noBorder, }
                htmlTable += line.format(**formatDict)
            htmlTable += "</table>"
            return htmlTable

        result = ""
        if len(self.subexpressions) > 0:
            result += "<div>Subexpressions:</div>"
            result += makeHtmlEquationTable(self.subexpressions)
        result += "<div>Main Equations:</div>"
        result += makeHtmlEquationTable(self.mainEquations)
        return result

    def __repr__(self):
        return "Equation Collection for " + ",".join([str(eq.lhs) for eq in self.mainEquations])

    def __str__(self):
        result = "Subexpressions\n"
        for eq in self.subexpressions:
            result += str(eq) + "\n"
        result += "Main Equations\n"
        for eq in self.mainEquations:
            result += str(eq) + "\n"
        return result

    # -------------------------------------   Manipulation  ------------------------------------------------------------

    def merge(self, other):
        """Returns a new collection which contains self and other. Subexpressions are renamed if they clash.

        A subexpression of `other` whose left-hand side is also a subexpression of `self` is dropped if its
        right-hand side (after applying earlier renamings) is identical to the own one; otherwise it gets a
        fresh name from `subexpressionSymbolNameGenerator`, and later equations of `other` are updated.
        """
        ownDefs = {e.lhs for e in self.mainEquations}
        otherDefs = {e.lhs for e in other.mainEquations}
        assert len(ownDefs.intersection(otherDefs)) == 0, "Cannot merge, since both collection define the same symbols"

        ownSubexpressionSymbols = {e.lhs: e.rhs for e in self.subexpressions}
        substitutionDict = {}

        # Symbols already in use anywhere - generated names must avoid them to prevent new clashes
        takenSymbols = set()
        for eq in self.allEquations + other.allEquations:
            takenSymbols.add(eq.lhs)
            takenSymbols.update(eq.rhs.atoms(sp.Symbol))

        processedOtherSubexpressionEquations = []
        for otherSubexpressionEq in other.subexpressions:
            lhs = otherSubexpressionEq.lhs
            if lhs in ownSubexpressionSymbols:
                # compare the *renamed* rhs: after earlier renamings the textual rhs may no longer mean the same
                renamedRhs = fastSubs(otherSubexpressionEq.rhs, substitutionDict)
                if renamedRhs == ownSubexpressionSymbols[lhs]:
                    continue  # exactly the same subexpression equation exists already
                # different definition - a new name has to be introduced
                newLhs = next(self.subexpressionSymbolNameGenerator)
                while newLhs in takenSymbols:
                    newLhs = next(self.subexpressionSymbolNameGenerator)
                takenSymbols.add(newLhs)
                processedOtherSubexpressionEquations.append(sp.Eq(newLhs, renamedRhs))
                substitutionDict[lhs] = newLhs
            else:
                processedOtherSubexpressionEquations.append(fastSubs(otherSubexpressionEq, substitutionDict))

        processedOtherMainEquations = [fastSubs(eq, substitutionDict) for eq in other.mainEquations]
        return self.copy(self.mainEquations + processedOtherMainEquations,
                         self.subexpressions + processedOtherSubexpressionEquations)

    def getDependentSymbols(self, symbolSequence):
        """Returns the set of symbols that the passed symbols (transitively) depend on.

        Starting at `symbolSequence`, the right-hand sides of the equations defining these symbols are followed
        recursively. The result contains the passed symbols themselves, all reached left-hand-side symbols and all
        symbols occurring on reached right-hand sides.
        """
        eqMap = {e.lhs: e.rhs for e in self.allEquations}
        handledSymbols = set()
        # a stack (O(1) pop) suffices: visiting order does not influence the resulting set
        stack = list(symbolSequence)
        while stack:
            symbol = stack.pop()
            if symbol in handledSymbols:
                continue
            handledSymbols.add(symbol)
            if symbol in eqMap:
                stack.extend(eqMap[symbol].atoms(sp.Symbol))
        return handledSymbols

    def extract(self, symbolsToExtract):
        """
        Creates a new equation collection with equations that have symbolsToExtract as left-hand-sides and
        only the necessary subexpressions that are used in these equations.
        Simplification hints and the subexpression name generator are carried over (via :meth:`copy`).
        """
        symbolsToExtract = set(symbolsToExtract)
        dependentSymbols = self.getDependentSymbols(symbolsToExtract)
        newEquations = [eq for eq in self.allEquations if eq.lhs in symbolsToExtract]
        newSubExpr = [eq for eq in self.subexpressions if eq.lhs in dependentSymbols and eq.lhs not in symbolsToExtract]
        return self.copy(newEquations, newSubExpr)

    def newWithoutUnusedSubexpressions(self):
        """Returns a new equation collection containing only the subexpressions that
        are used/referenced in the equations"""
        allLhs = [eq.lhs for eq in self.mainEquations]
        return self.extract(allLhs)

    def appendToSubexpressions(self, rhs, lhs=None, topologicalSort=True):
        """Appends subexpression ``lhs = rhs`` to this collection in place and returns `lhs`.

        :param lhs: symbol of the new subexpression; a new ``sp.Dummy`` if not given
        :param topologicalSort: if True, the subexpressions are sorted topologically afterwards
        """
        if lhs is None:
            lhs = sp.Dummy()
        eq = sp.Eq(lhs, rhs)
        self.subexpressions.append(eq)
        if topologicalSort:
            self.topologicalSort(subexpressions=True, mainEquations=False)
        return lhs

    def topologicalSort(self, subexpressions=True, mainEquations=True):
        """Sorts subexpressions and/or main equations topologically, in place."""
        if subexpressions:
            self.subexpressions = sortEquationsTopologically(self.subexpressions)
        if mainEquations:
            self.mainEquations = sortEquationsTopologically(self.mainEquations)

    def insertSubexpression(self, symbol):
        """Returns a new collection where the subexpression defining `symbol` is removed and its right-hand side
        is substituted into all remaining subexpressions and main equations.
        If no subexpression defines `symbol`, self is returned (not a copy)."""
        newSubexpressions = []
        subsDict = None
        for se in self.subexpressions:
            if se.lhs == symbol:
                subsDict = {se.lhs: se.rhs}
            else:
                newSubexpressions.append(se)
        if subsDict is None:
            return self

        newSubexpressions = [sp.Eq(eq.lhs, fastSubs(eq.rhs, subsDict)) for eq in newSubexpressions]
        newEqs = [sp.Eq(eq.lhs, fastSubs(eq.rhs, subsDict)) for eq in self.mainEquations]
        return self.copy(newEqs, newSubexpressions)

    def insertSubexpressions(self, subexpressionSymbolsToKeep=None):
        """Returns a new equation collection by inserting all subexpressions into the main equations.

        :param subexpressionSymbolsToKeep: subexpressions with these left-hand sides are kept (with earlier
                                           eliminated subexpressions inserted into them) instead of being inserted
        """
        if len(self.subexpressions) == 0:
            return self.copy()

        subexpressionSymbolsToKeep = set() if subexpressionSymbolsToKeep is None else set(subexpressionSymbolsToKeep)

        keptSubexpressions = []
        subsDict = {}
        # Relies on subexpressions being ordered so that each one only references earlier ones;
        # then every value in subsDict is already fully expanded.
        for subexpression in self.subexpressions:
            subexpression = fastSubs(subexpression, subsDict)
            if subexpression.lhs in subexpressionSymbolsToKeep:
                keptSubexpressions.append(subexpression)
            else:
                subsDict[subexpression.lhs] = subexpression.rhs

        newEq = [fastSubs(eq, subsDict) for eq in self.mainEquations]
        return self.copy(newEq, keptSubexpressions)

    def lambdify(self, symbols, module=None, fixedSymbols=None):
        """
        Returns a function to evaluate this equation collection
        :param symbols: symbol(s) which are the parameter for the created function
        :param module: same as sympy.lambdify parameter of the same name, i.e. which module to use e.g. 'numpy'
        :param fixedSymbols: dictionary with substitutions, that are applied before lambdification
        :return: function returning a dict mapping each main equation's left-hand side to its evaluated value
        """
        if fixedSymbols is None:
            fixedSymbols = {}
        eqs = self.copyWithSubstitutionsApplied(fixedSymbols).insertSubexpressions().mainEquations
        lambdas = {eq.lhs: sp.lambdify(symbols, eq.rhs, module) for eq in eqs}

        def f(*args, **kwargs):
            return {lhs: func(*args, **kwargs) for lhs, func in lambdas.items()}

        return f


class SymbolGen:
    """Endless iterator producing numbered sympy symbols ``xi_1``, ``xi_2``, ...

    Each instance keeps its own counter, so two generators hand out the same sequence of names.
    """

    def __init__(self):
        # number of symbols handed out so far
        self._ctr = 0

    def __iter__(self):
        return self

    def __next__(self):
        counter = self._ctr + 1
        self._ctr = counter
        symbolName = "xi_" + str(counter)
        return sp.Symbol(symbolName)

    def next(self):
        # Python 2 iterator protocol
        return self.__next__()