# equationcollection.py
import sympy as sp
from pystencils.sympyextensions import fastSubs, countNumberOfOperations


class EquationCollection:
    """
    A collection of equations with subexpression definitions, also represented as equations,
    that are used in the main equations. EquationCollections can be passed to simplification methods.
    These simplification methods can change the subexpressions, but the number and
    left hand side of the main equations themselves is not altered.
    Additionally a dictionary of simplification hints is stored, which are set by the functions that create
    equation collections to transport information to the simplification system.

    :ivar mainEquations: list of sympy equations
    :ivar subexpressions: list of sympy equations defining subexpressions used in main equations
    :ivar simplificationHints: dictionary that is used to annotate the equation collection with hints that are
                               used by the simplification system. See documentation of the simplification rules for
                               potentially required hints and their meaning.
    :ivar subexpressionSymbolNameGenerator: iterator (or zero-argument callable) that produces new symbols
                                            for subexpressions, e.g. when :meth:`merge` has to rename one
    """

    # ----------------------------------------- Creation ---------------------------------------------------------------

    def __init__(self, equations, subExpressions, simplificationHints=None, subexpressionSymbolNameGenerator=None):
        """
        :param equations: main equations (list of sympy equations)
        :param subExpressions: subexpression definitions (list of sympy equations)
        :param simplificationHints: optional hint dictionary; if omitted, a new empty dict is created for this
                                    collection
        :param subexpressionSymbolNameGenerator: optional iterator or zero-argument callable producing new
                                                 subexpression symbols. Defaults to a generator of ``xi_<n>``
                                                 symbols that skips symbols bound in this collection.
        """
        self.mainEquations = equations
        self.subexpressions = subExpressions
        # A fresh dict per instance: a shared mutable default would leak hints between unrelated collections
        # (and make addSimplificationHint fail on the second collection that adds the same key).
        self.simplificationHints = {} if simplificationHints is None else simplificationHints

        def symbolGen():
            """Use this generator to create new unused symbols for subexpressions"""
            counter = 0
            while True:
                counter += 1
                newSymbol = sp.Symbol("xi_" + str(counter))
                if newSymbol in self.boundSymbols:
                    continue
                yield newSymbol

        if subexpressionSymbolNameGenerator is None:
            self.subexpressionSymbolNameGenerator = symbolGen()
        else:
            self.subexpressionSymbolNameGenerator = subexpressionSymbolNameGenerator

    def newWithAdditionalSubexpressions(self, newEquations, additionalSubExpressions):
        """
        Returns a new equation collection, that has `newEquations` as mainEquations.
        The `additionalSubExpressions` are appended to the existing subexpressions.
        Simplifications hints are copied over into a separate dict, so adding hints to the new collection does
        not modify this one. The subexpression symbol generator is shared.
        """
        assert len(self.mainEquations) == len(newEquations), "Number of update equations cannot be changed"
        return EquationCollection(newEquations,
                                  self.subexpressions + additionalSubExpressions,
                                  dict(self.simplificationHints),
                                  subexpressionSymbolNameGenerator=self.subexpressionSymbolNameGenerator)

    def newWithSubstitutionsApplied(self, substitutionDict):
        """
        Returns a new equation collection, where terms are substituted according to the passed `substitutionDict`.
        Substitutions are made in the subexpression terms and the main equations.
        Simplification hints are copied into a separate dict; the subexpression symbol generator is shared.
        """
        newSubexpressions = [fastSubs(eq, substitutionDict) for eq in self.subexpressions]
        newEquations = [fastSubs(eq, substitutionDict) for eq in self.mainEquations]
        return EquationCollection(newEquations, newSubexpressions, dict(self.simplificationHints),
                                  subexpressionSymbolNameGenerator=self.subexpressionSymbolNameGenerator)

    def addSimplificationHint(self, key, value):
        """
        Adds an entry to the simplificationHints dictionary, and checks that it does not exist yet
        """
        assert key not in self.simplificationHints, "This hint already exists"
        self.simplificationHints[key] = value

    # ---------------------------------------------- Properties  -------------------------------------------------------

    @property
    def allEquations(self):
        """Subexpression and main equations in one sequence"""
        return self.subexpressions + self.mainEquations

    @property
    def freeSymbols(self):
        """All symbols used in the equation collection, which have not been defined inside the equation system"""
        freeSymbols = set()
        for eq in self.allEquations:
            freeSymbols.update(eq.rhs.atoms(sp.Symbol))
        return freeSymbols - self.boundSymbols

    @property
    def boundSymbols(self):
        """Set of all symbols which occur on left-hand-sides i.e. all symbols which are defined."""
        boundSymbolsSet = set([eq.lhs for eq in self.allEquations])
        assert len(boundSymbolsSet) == len(self.subexpressions) + len(self.mainEquations), \
            "Not in SSA form - same symbol assigned multiple times"
        return boundSymbolsSet

    @property
    def definedSymbols(self):
        """All symbols that occur as left-hand-sides of the main equations"""
        return set([eq.lhs for eq in self.mainEquations])

    @property
    def operationCount(self):
        """See :func:`countNumberOfOperations` """
        return countNumberOfOperations(self.allEquations)

    def get(self, symbols, fromMainEquationsOnly=False):
        """
        Return the equations which have `symbols` as left hand sides.

        :param symbols: a single sympy symbol or an iterable of symbols
        :param fromMainEquationsOnly: if True only the main equations are searched, otherwise subexpressions too
        :return: list of matching equations in collection order
        """
        if isinstance(symbols, sp.Basic):
            # A single symbol is not iterable, so it has to be wrapped rather than passed to list()/set().
            symbols = [symbols]
        symbols = set(symbols)

        eqsToSearchIn = self.mainEquations if fromMainEquationsOnly else self.allEquations
        return [eq for eq in eqsToSearchIn if eq.lhs in symbols]

    # ----------------------------------------- Display and Printing   -------------------------------------------------

    def _repr_html_(self):
        def makeHtmlEquationTable(equations):
            noBorder = 'style="border:none"'
            htmlTable = '<table style="border:none; width: 100%; ">'
            line = '<tr {nb}> <td {nb}>$${eq}$$</td>  </tr> '
            for eq in equations:
                formatDict = {'eq': sp.latex(eq),
                              'nb': noBorder, }
                htmlTable += line.format(**formatDict)
            htmlTable += "</table>"
            return htmlTable

        result = ""
        if len(self.subexpressions) > 0:
            result += "<div>Subexpressions:</div>"
            result += makeHtmlEquationTable(self.subexpressions)
        result += "<div>Main Equations:</div>"
        result += makeHtmlEquationTable(self.mainEquations)
        return result

    def __repr__(self):
        return "Equation Collection for " + ",".join([str(eq.lhs) for eq in self.mainEquations])

    def __str__(self):
        result = "Subexpressions\n"
        for eq in self.subexpressions:
            result += str(eq) + "\n"
        result += "Main Equations\n"
        for eq in self.mainEquations:
            result += str(eq) + "\n"
        return result

    # -------------------------------------   Manipulation  ------------------------------------------------------------

    def _newSubexpressionSymbol(self, forbiddenSymbols=()):
        """Draws symbols from subexpressionSymbolNameGenerator until one not in `forbiddenSymbols` is found."""
        generator = self.subexpressionSymbolNameGenerator
        while True:
            # The default generator is a generator object (advance with next()); also accept plain callables.
            newSymbol = generator() if callable(generator) else next(generator)
            if newSymbol not in forbiddenSymbols:
                return newSymbol

    def merge(self, other):
        """
        Returns a new collection which contains self and other. Subexpressions are renamed if they clash.

        A subexpression of `other` whose left-hand-side symbol already occurs in `self` is dropped when `self`
        contains an identical subexpression definition. Otherwise it is renamed to a new symbol. The renaming is
        applied to the following subexpressions and to the main equations of `other`.
        Simplification hints are not carried over to the result.
        """
        ownDefs = set([e.lhs for e in self.mainEquations])
        otherDefs = set([e.lhs for e in other.mainEquations])
        assert len(ownDefs.intersection(otherDefs)) == 0, "Cannot merge, since both collection define the same symbols"

        # Every symbol occurring in self (defined or free) - an lhs of other reusing one of these is a clash.
        ownSymbols = set()
        for eq in self.allEquations:
            ownSymbols.update(eq.atoms(sp.Symbol))
        # New names must not collide with anything that occurs in either collection.
        usedSymbols = set(ownSymbols)
        for eq in other.allEquations:
            usedSymbols.update(eq.atoms(sp.Symbol))

        ownSubexpressionSymbols = {e.lhs: e.rhs for e in self.subexpressions}
        substitutionDict = {}

        processedOtherSubexpressionEquations = []
        for otherSubexpressionEq in other.subexpressions:
            lhs = otherSubexpressionEq.lhs
            # Apply earlier renamings first, so the rhs refers to the same symbols as the definitions in self.
            rhs = fastSubs(otherSubexpressionEq.rhs, substitutionDict)
            if lhs not in ownSymbols:
                processedOtherSubexpressionEquations.append(fastSubs(otherSubexpressionEq, substitutionDict))
            elif lhs in ownSubexpressionSymbols and rhs == ownSubexpressionSymbols[lhs]:
                continue  # exact the same subexpression equation exists already
            else:
                # different definition (or clash with a main equation / free symbol of self) - rename
                newLhs = self._newSubexpressionSymbol(usedSymbols)
                usedSymbols.add(newLhs)
                processedOtherSubexpressionEquations.append(sp.Eq(newLhs, rhs))
                substitutionDict[lhs] = newLhs

        # other's main equations must see the renamed subexpressions, not self's definitions of the old names
        if substitutionDict:
            otherMainEquations = [fastSubs(eq, substitutionDict) for eq in other.mainEquations]
        else:
            otherMainEquations = other.mainEquations
        return EquationCollection(self.mainEquations + otherMainEquations,
                                  self.subexpressions + processedOtherSubexpressionEquations)

    def extract(self, symbolsToExtract):
        """
        Creates a new equation collection with equations that have symbolsToExtract as left-hand-sides and
        only the necessary subexpressions that are used in these equations
        """
        symbolsToExtract = set(symbolsToExtract)
        newEquations = []

        subexprMap = {e.lhs: e.rhs for e in self.subexpressions}
        handledSymbols = set()
        queue = []

        def addSymbolsFromExpr(expr):
            for ds in expr.atoms(sp.Symbol):
                if ds not in handledSymbols:
                    queue.append(ds)
                    handledSymbols.add(ds)

        for eq in self.allEquations:
            if eq.lhs in symbolsToExtract:
                newEquations.append(eq)
                addSymbolsFromExpr(eq.rhs)

        # Transitively collect the dependencies of the extracted equations. The processing order is irrelevant
        # (the result is filtered in original subexpression order below), so pop() from the end in O(1).
        while queue:
            e = queue.pop()
            if e in subexprMap:
                addSymbolsFromExpr(subexprMap[e])

        newSubExpr = [eq for eq in self.subexpressions if eq.lhs in handledSymbols and eq.lhs not in symbolsToExtract]
        return EquationCollection(newEquations, newSubExpr)

    def newWithoutUnusedSubexpressions(self):
        """Returns a new equation collection containing only the subexpressions that
        are used/referenced in the equations"""
        allLhs = [eq.lhs for eq in self.mainEquations]
        return self.extract(allLhs)

    def insertSubexpressions(self):
        """
        Returns a new equation collection by inserting all subexpressions into the main equations.
        Simplification hints are copied into a separate dict.
        """
        if len(self.subexpressions) == 0:
            return EquationCollection(self.mainEquations, self.subexpressions, dict(self.simplificationHints))
        subsDict = {self.subexpressions[0].lhs: self.subexpressions[0].rhs}
        subExpr = [e for e in self.subexpressions]
        for i in range(1, len(subExpr)):
            # later subexpressions may reference earlier ones - resolve those first
            subExpr[i] = fastSubs(subExpr[i], subsDict)
            subsDict[subExpr[i].lhs] = subExpr[i].rhs

        newEq = [fastSubs(eq, subsDict) for eq in self.mainEquations]
        return EquationCollection(newEq, [], dict(self.simplificationHints))

    def lambdify(self, symbols, module=None, fixedSymbols=None):
        """
        Returns a function to evaluate this equation collection
        :param symbols: symbol(s) which are the parameter for the created function
        :param module: same as sympy.lambdify parameter of same name, i.e. which module to use e.g. 'numpy'
        :param fixedSymbols: dictionary with substitutions, that are applied before lambdification
        :return: function that takes values for `symbols` and returns a dict mapping each main equation's
                 left-hand-side to its evaluated right-hand-side
        """
        if fixedSymbols is None:
            fixedSymbols = {}
        eqs = self.newWithSubstitutionsApplied(fixedSymbols).insertSubexpressions().mainEquations
        lambdas = {eq.lhs: sp.lambdify(symbols, eq.rhs, module) for eq in eqs}

        def f(*args, **kwargs):
            return {s: func(*args, **kwargs) for s, func in lambdas.items()}

        return f