from collections import deque
from copy import copy
from typing import Any, Dict, Iterable, Iterator, List, Optional, Sequence, Set, Union

import sympy as sp

from pystencils.assignment import Assignment
from pystencils.simp.simplifications import (
    sort_assignments_topologically, transform_lhs_and_rhs, transform_rhs)
from pystencils.sympyextensions import count_operations, fast_subs


class AssignmentCollection:
    """
    A collection of equations with subexpression definitions, also represented as assignments,
    that are used in the main equations. AssignmentCollection can be passed to simplification methods.
    These simplification methods can change the subexpressions, but the number and
    left hand side of the main equations themselves are not altered.
    Additionally a dictionary of simplification hints is stored, which are set by the functions that create
    assignment collections to transport information to the simplification system.

    Attributes:
        main_assignments: list of assignments
        subexpressions: list of assignments defining subexpressions used in main equations
        simplification_hints: dict that is used to annotate the assignment collection with hints that are
                              used by the simplification system. See documentation of the simplification rules for
                              potentially required hints and their meaning.
        subexpression_symbol_generator: iterator producing new symbols when subexpressions are added without an
                                        explicit left hand side; used to get symbols that are unique for this
                                        AssignmentCollection
    """

    # ------------------------------- Creation & Inplace Manipulation --------------------------------------------------

    def __init__(self, main_assignments: Union[List[Assignment], Dict[sp.Expr, sp.Expr]],
                 subexpressions: Optional[Union[List[Assignment], Dict[sp.Expr, sp.Expr]]] = None,
                 simplification_hints: Optional[Dict[str, Any]] = None,
                 subexpression_symbol_generator: Optional[Iterator[sp.Symbol]] = None) -> None:
        """Creates a collection from main assignments and (optional) subexpressions.

        Args:
            main_assignments: list of assignments, or dict mapping left hand sides to right hand sides
            subexpressions: list of assignments, or dict mapping left hand sides to right hand sides;
                            None means no subexpressions
            simplification_hints: optional dict of hints, stored by reference
            subexpression_symbol_generator: iterator of fresh symbols; defaults to a new ``SymbolGen()``
        """
        if subexpressions is None:
            subexpressions = []
        if isinstance(main_assignments, dict):
            main_assignments = [Assignment(k, v) for k, v in main_assignments.items()]
        if isinstance(subexpressions, dict):
            subexpressions = [Assignment(k, v) for k, v in subexpressions.items()]

        # Lists passed in are stored by reference (not copied).
        self.main_assignments = main_assignments
        self.subexpressions = subexpressions

        self.simplification_hints = simplification_hints if simplification_hints is not None else {}

        if subexpression_symbol_generator is None:
            self.subexpression_symbol_generator = SymbolGen()
        else:
            self.subexpression_symbol_generator = subexpression_symbol_generator

    def add_simplification_hint(self, key: str, value: Any) -> None:
        """Adds an entry to the simplification_hints dictionary and checks that it does not exist yet."""
        assert key not in self.simplification_hints, "This hint already exists"
        self.simplification_hints[key] = value

    def add_subexpression(self, rhs: sp.Expr, lhs: Optional[sp.Symbol] = None, topological_sort=True) -> sp.Symbol:
        """Adds a subexpression to current collection.

        Args:
            rhs: right hand side of new subexpression
            lhs: optional left hand side of new subexpression. If None a new unique symbol is generated.
            topological_sort: sort the subexpressions topologically after insertion, to make sure that
                              definition of a symbol comes before its usage. If False, subexpression is appended.

        Returns:
            left hand side symbol (which could have been generated)
        """
        if lhs is None:
            lhs = next(self.subexpression_symbol_generator)
        eq = Assignment(lhs, rhs)
        self.subexpressions.append(eq)
        if topological_sort:
            self.topological_sort(sort_subexpressions=True,
                                  sort_main_assignments=False)
        return lhs

    def topological_sort(self, sort_subexpressions: bool = True, sort_main_assignments: bool = True) -> None:
        """Sorts subexpressions and/or main_equations topologically to make sure symbol usage comes after definition."""
        if sort_subexpressions:
            self.subexpressions = sort_assignments_topologically(self.subexpressions)
        if sort_main_assignments:
            self.main_assignments = sort_assignments_topologically(self.main_assignments)

    # ---------------------------------------------- Properties  -------------------------------------------------------

    @property
    def all_assignments(self) -> List[Assignment]:
        """Subexpression and main equations as a single list (subexpressions first)."""
        return self.subexpressions + self.main_assignments

    @property
    def free_symbols(self) -> Set[sp.Symbol]:
        """All symbols used in the assignment collection, which do not occur as left hand sides in any assignment."""
        free_symbols = set()
        for eq in self.all_assignments:
            free_symbols.update(eq.rhs.atoms(sp.Symbol))
        return free_symbols - self.bound_symbols

    @property
    def bound_symbols(self) -> Set[sp.Symbol]:
        """All symbols which occur on the left hand side of a main assignment or a subexpression."""
        bound_symbols_set = {eq.lhs for eq in self.all_assignments}
        assert len(bound_symbols_set) == len(self.subexpressions) + len(self.main_assignments), \
            "Not in SSA form - same symbol assigned multiple times"
        return bound_symbols_set

    @property
    def free_fields(self):
        """All fields accessed in the assignment collection, which do not occur as left hand sides in any assignment."""
        return {s.field for s in self.free_symbols if hasattr(s, 'field')}

    @property
    def bound_fields(self):
        """All field accessed on the left hand side of a main assignment or a subexpression."""
        return {s.field for s in self.bound_symbols if hasattr(s, 'field')}

    @property
    def defined_symbols(self) -> Set[sp.Symbol]:
        """All symbols which occur as left-hand-sides of one of the main equations"""
        return {assignment.lhs for assignment in self.main_assignments}

    @property
    def operation_count(self):
        """See :func:`count_operations` """
        return count_operations(self.all_assignments, only_type=None)

    def dependent_symbols(self, symbols: Iterable[sp.Symbol]) -> Set[sp.Symbol]:
        """Returns the passed symbols together with all symbols they (transitively) depend on.

        A symbol 'a' depends on a symbol 'b', if there is an assignment 'a <- some_expression(b)' i.e. when
        'b' is required to compute 'a'. Starting from ``symbols``, the right hand sides of their defining
        assignments are followed recursively; symbols without a defining assignment are included but not expanded.
        """
        assignment_dict = {e.lhs: e.rhs for e in self.all_assignments}

        # Breadth-first traversal; deque gives O(1) pops from the front (list.pop(0) is O(n)).
        queue = deque(symbols)
        handled_symbols = set()
        while queue:
            symbol = queue.popleft()
            if symbol in handled_symbols:
                continue
            if symbol in assignment_dict:
                queue.extend(assignment_dict[symbol].atoms(sp.Symbol))
            handled_symbols.add(symbol)

        return handled_symbols

    def lambdify(self, symbols: Sequence[sp.Symbol], fixed_symbols: Optional[Dict[sp.Symbol, Any]] = None, module=None):
        """Returns a python function to evaluate this equation collection.

        Args:
            symbols: symbol(s) which are the parameter for the created function
            fixed_symbols: dictionary with substitutions, that are applied before sympy's lambdify
            module: same as sympy.lambdify parameter. Defines which module to use e.g. 'numpy'

        Examples:
              >>> a, b, c, d = sp.symbols("a b c d")
              >>> ac = AssignmentCollection([Assignment(c, a + b), Assignment(d, a**2 + b)],
              ...                           subexpressions=[Assignment(b, a + b / 2)])
              >>> python_function = ac.lambdify([a], fixed_symbols={b: 2})
              >>> python_function(4)
              {c: 6, d: 18}
        """
        assignments = self.new_with_substitutions(fixed_symbols, substitute_on_lhs=False) if fixed_symbols else self
        assignments = assignments.new_without_subexpressions().main_assignments
        lambdas = {assignment.lhs: sp.lambdify(symbols, assignment.rhs, module) for assignment in assignments}

        def f(*args, **kwargs):
            return {s: func(*args, **kwargs) for s, func in lambdas.items()}

        return f

    # ---------------------------- Creating new modified collections ---------------------------------------------------

    def copy(self,
             main_assignments: Optional[List[Assignment]] = None,
             subexpressions: Optional[List[Assignment]] = None) -> 'AssignmentCollection':
        """Returns a copy with optionally replaced main_assignments and/or subexpressions.

        The simplification hints dict and the symbol generator are shallow-copied, so the copy keeps the
        generator's current position but advances independently of self.
        """
        res = copy(self)
        res.simplification_hints = self.simplification_hints.copy()
        res.subexpression_symbol_generator = copy(self.subexpression_symbol_generator)

        if main_assignments is not None:
            res.main_assignments = main_assignments
        else:
            res.main_assignments = self.main_assignments.copy()

        if subexpressions is not None:
            res.subexpressions = subexpressions
        else:
            res.subexpressions = self.subexpressions.copy()

        return res

    def new_with_substitutions(self, substitutions: Dict, add_substitutions_as_subexpressions: bool = False,
                               substitute_on_lhs: bool = True,
                               sort_topologically: bool = True) -> 'AssignmentCollection':
        """Returns new object, where terms are substituted according to the passed substitution dict.

        Args:
            substitutions: dict that is passed to sympy subs, substitutions are done main assignments and subexpressions
            add_substitutions_as_subexpressions: if True, the substitutions are added as assignments to subexpressions
                                                 (each ``old: new`` entry becomes the subexpression ``new <- old``)
            substitute_on_lhs: if False, the substitutions are done only on the right hand side of assignments
            sort_topologically: if subexpressions are added as substitutions and this parameters is true,
                                the subexpressions are sorted topologically after insertion

        Returns:
            New AssignmentCollection where substitutions have been applied, self is not altered.
        """
        transform = transform_lhs_and_rhs if substitute_on_lhs else transform_rhs
        transformed_subexpressions = transform(self.subexpressions, fast_subs, substitutions)
        transformed_assignments = transform(self.main_assignments, fast_subs, substitutions)

        if add_substitutions_as_subexpressions:
            transformed_subexpressions = [Assignment(b, a) for a, b in
                                          substitutions.items()] + transformed_subexpressions
            if sort_topologically:
                transformed_subexpressions = sort_assignments_topologically(transformed_subexpressions)
        return self.copy(transformed_assignments, transformed_subexpressions)

    def new_merged(self, other: 'AssignmentCollection') -> 'AssignmentCollection':
        """Returns a new collection which contains self and other. Subexpressions are renamed if they clash.

        Note: renamed symbols are drawn from self's subexpression symbol generator, which is advanced in place.
        """
        own_definitions = {e.lhs for e in self.main_assignments}
        other_definitions = {e.lhs for e in other.main_assignments}
        assert len(own_definitions.intersection(other_definitions)) == 0, \
            "Cannot new_merged, since both collection define the same symbols"

        own_subexpression_symbols = {e.lhs: e.rhs for e in self.subexpressions}
        substitution_dict = {}

        processed_other_subexpression_equations = []
        for other_subexpression_eq in other.subexpressions:
            if other_subexpression_eq.lhs in own_subexpression_symbols:
                if other_subexpression_eq.rhs == own_subexpression_symbols[other_subexpression_eq.lhs]:
                    continue  # exactly the same subexpression equation exists already
                # different definition - a new name has to be introduced
                new_lhs = next(self.subexpression_symbol_generator)
                new_eq = Assignment(new_lhs, fast_subs(other_subexpression_eq.rhs, substitution_dict))
                processed_other_subexpression_equations.append(new_eq)
                substitution_dict[other_subexpression_eq.lhs] = new_lhs
            else:
                processed_other_subexpression_equations.append(fast_subs(other_subexpression_eq, substitution_dict))

        processed_other_main_assignments = [fast_subs(eq, substitution_dict) for eq in other.main_assignments]
        return self.copy(self.main_assignments + processed_other_main_assignments,
                         self.subexpressions + processed_other_subexpression_equations)

    def new_filtered(self, symbols_to_extract: Iterable[sp.Symbol]) -> 'AssignmentCollection':
        """Extracts equations that have symbols_to_extract as left hand side, together with necessary subexpressions.

        Assignments (main or subexpression) whose left hand side is in ``symbols_to_extract`` become the main
        assignments of the result; all other subexpressions they depend on are kept as subexpressions.
        Simplification hints and the symbol generator state are carried over (via :meth:`copy`), so symbols
        generated later cannot clash with ones already used in this collection.

        Returns:
            new AssignmentCollection, self is not altered
        """
        symbols_to_extract = set(symbols_to_extract)
        dependent_symbols = self.dependent_symbols(symbols_to_extract)
        new_assignments = [eq for eq in self.all_assignments if eq.lhs in symbols_to_extract]
        new_sub_expr = [eq for eq in self.subexpressions
                        if eq.lhs in dependent_symbols and eq.lhs not in symbols_to_extract]
        return self.copy(new_assignments, new_sub_expr)

    def new_without_unused_subexpressions(self) -> 'AssignmentCollection':
        """Returns new collection that only contains subexpressions required to compute the main assignments."""
        all_lhs = [eq.lhs for eq in self.main_assignments]
        return self.new_filtered(all_lhs)

    def new_with_inserted_subexpression(self, symbol: sp.Symbol) -> 'AssignmentCollection':
        """Eliminates the subexpression with the given symbol on its left hand side, by substituting it everywhere.

        If no subexpression defines ``symbol``, self (not a copy) is returned.
        """
        new_subexpressions = []
        subs_dict = None
        for se in self.subexpressions:
            if se.lhs == symbol:
                subs_dict = {se.lhs: se.rhs}
            else:
                new_subexpressions.append(se)
        if subs_dict is None:
            return self

        new_subexpressions = [Assignment(eq.lhs, fast_subs(eq.rhs, subs_dict)) for eq in new_subexpressions]
        new_eqs = [Assignment(eq.lhs, fast_subs(eq.rhs, subs_dict)) for eq in self.main_assignments]
        return self.copy(new_eqs, new_subexpressions)

    def new_without_subexpressions(self, subexpressions_to_keep: Optional[Set[sp.Symbol]] = None
                                   ) -> 'AssignmentCollection':
        """Returns a new collection where all subexpressions have been inserted.

        Args:
            subexpressions_to_keep: left hand sides of subexpressions that are not inserted but kept as
                                    subexpressions. Earlier, non-kept subexpressions are still inserted into them.
        """
        if len(self.subexpressions) == 0:
            return self.copy()

        keep = set(subexpressions_to_keep) if subexpressions_to_keep is not None else set()

        kept_subexpressions = []
        substitution_dict = {}
        # Single forward pass: each subexpression only receives the definitions that appear before it.
        for subexpression in self.subexpressions:
            if substitution_dict:
                subexpression = fast_subs(subexpression, substitution_dict)
            if subexpression.lhs in keep:
                kept_subexpressions.append(subexpression)
            else:
                substitution_dict[subexpression.lhs] = subexpression.rhs

        new_assignments = [fast_subs(eq, substitution_dict) for eq in self.main_assignments]
        return self.copy(new_assignments, kept_subexpressions)

    # ----------------------------------------- Display and Printing   -------------------------------------------------

    def _repr_html_(self):
        """Interface to Jupyter notebook, to display as a nicely formatted HTML table"""
        def make_html_equation_table(equations):
            no_border = 'style="border:none"'
            html_table = '<table style="border:none; width: 100%; ">'
            line = '<tr {nb}> <td {nb}>$${eq}$$</td>  </tr> '
            for eq in equations:
                format_dict = {'eq': sp.latex(eq),
                               'nb': no_border, }
                html_table += line.format(**format_dict)
            html_table += "</table>"
            return html_table

        result = ""
        if len(self.subexpressions) > 0:
            result += "<div>Subexpressions:</div>"
            result += make_html_equation_table(self.subexpressions)
        result += "<div>Main Assignments:</div>"
        result += make_html_equation_table(self.main_assignments)
        return result

    def __repr__(self):
        return "Assignment Collection for " + ",".join([str(eq.lhs) for eq in self.main_assignments])

    def __str__(self):
        result = "Subexpressions:\n"
        for eq in self.subexpressions:
            result += "\t{eq}\n".format(eq=eq)
        result += "Main Assignments:\n"
        for eq in self.main_assignments:
            result += "\t{eq}\n".format(eq=eq)
        return result

    def __iter__(self):
        """Iterates over subexpressions first, then main assignments."""
        return self.all_assignments.__iter__()

    @property
    def main_assignments_dict(self):
        """Main assignments as a dict mapping left hand side to right hand side."""
        return {a.lhs: a.rhs for a in self.main_assignments}

    @property
    def subexpressions_dict(self):
        """Subexpressions as a dict mapping left hand side to right hand side."""
        return {a.lhs: a.rhs for a in self.subexpressions}

    def set_main_assignments_from_dict(self, main_assignments_dict):
        """Replaces the main assignments by assignments built from a ``{lhs: rhs}`` dict."""
        self.main_assignments = [Assignment(k, v)
                                 for k, v in main_assignments_dict.items()]

    def set_sub_expressions_from_dict(self, sub_expressions_dict):
        """Replaces the subexpressions by assignments built from a ``{lhs: rhs}`` dict."""
        self.subexpressions = [Assignment(k, v)
                               for k, v in sub_expressions_dict.items()]

    def find(self, *args, **kwargs):
        """Union of ``find(*args, **kwargs)`` over all assignments; an empty collection yields an empty set."""
        # set().union(...) also works with zero arguments, unlike set.union(*[]).
        return set().union(*(a.find(*args, **kwargs) for a in self.all_assignments))

    def match(self, *args, **kwargs):
        """Merges the non-empty ``match`` results of all assignments; later assignments override earlier keys."""
        rtn = {}
        for a in self.all_assignments:
            partial_result = a.match(*args, **kwargs)
            if partial_result:
                rtn.update(partial_result)
        return rtn

    def subs(self, *args, **kwargs):
        """Returns a new collection with ``subs`` applied to every assignment (hints and symbol generator kept)."""
        return self.copy(
            main_assignments=[a.subs(*args, **kwargs) for a in self.main_assignments],
            subexpressions=[a.subs(*args, **kwargs) for a in self.subexpressions]
        )

    def replace(self, *args, **kwargs):
        """Returns a new collection with ``replace`` applied to every assignment (hints and symbol generator kept)."""
        return self.copy(
            main_assignments=[a.replace(*args, **kwargs) for a in self.main_assignments],
            subexpressions=[a.replace(*args, **kwargs) for a in self.subexpressions]
        )

    def __eq__(self, other):
        """Order-insensitive comparison of all assignments; defining __eq__ leaves instances unhashable."""
        if not isinstance(other, AssignmentCollection):
            return NotImplemented
        return set(self.all_assignments) == set(other.all_assignments)


class SymbolGen:
    """Default symbol generator producing numbered symbols ``<symbol>_0``, ``<symbol>_1``, ...

    With the default prefix this yields ``xi_0``, ``xi_1``, ... The generator never stops. Its counter is
    plain instance state, so ``copy.copy`` of a generator continues numbering from the same position
    but advances independently.

    Args:
        symbol: name prefix of the generated symbols
    """

    def __init__(self, symbol: str = "xi") -> None:
        self._ctr = 0  # index used for the next generated symbol
        self._symbol = symbol

    def __iter__(self) -> 'SymbolGen':
        return self

    def __next__(self):
        """Returns the next symbol ``<symbol>_<counter>`` and increments the counter."""
        name = "{}_{}".format(self._symbol, self._ctr)
        self._ctr += 1
        return sp.Symbol(name)