simplifications.py 7.57 KB
Newer Older
1
2
from itertools import chain
from typing import Callable, List, Sequence, Union
Martin Bauer's avatar
Martin Bauer committed
3

Martin Bauer's avatar
Martin Bauer committed
4
5
import sympy as sp

Martin Bauer's avatar
Martin Bauer committed
6
from pystencils.assignment import Assignment
7
from pystencils.astnodes import Node
Martin Bauer's avatar
Martin Bauer committed
8
from pystencils.field import AbstractField, Field
Martin Bauer's avatar
Martin Bauer committed
9
from pystencils.sympyextensions import subs_additive
10

11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28

def sort_assignments_topologically(assignments: Sequence[Union[Assignment, Node]]) -> List[Union[Assignment, Node]]:
    """Orders the entries such that every symbol is defined before it is used.

    An ``Assignment`` defines its lhs and uses the free symbols of its rhs; a ``Node`` defines
    its ``symbols_defined`` and uses its ``undefined_symbols``. Entries of any other type
    contribute no dependencies.

    Args:
        assignments: sequence of assignments and/or AST nodes

    Returns:
        list with the same entries, in the order produced by ``sp.topological_sort``
    """

    def _defined_symbols(entry):
        # symbols that the entry provides to the entries depending on it
        if isinstance(entry, Assignment):
            return [entry.lhs]
        if isinstance(entry, Node):
            return entry.symbols_defined
        return []

    def _reads(entry, symbol):
        # True if the entry needs `symbol` to be defined beforehand
        if isinstance(entry, Assignment) and symbol in entry.rhs.free_symbols:
            return True
        return isinstance(entry, Node) and symbol in entry.undefined_symbols

    # edge (i, j): entry i defines a symbol that entry j reads
    dependency_edges = []
    for producer_idx, producer in enumerate(assignments):
        for symbol in _defined_symbols(producer):
            dependency_edges.extend((producer_idx, consumer_idx)
                                    for consumer_idx, consumer in enumerate(assignments)
                                    if _reads(consumer, symbol))

    vertices = range(len(assignments))
    sorted_indices = sp.topological_sort((vertices, dependency_edges))
    return [assignments[idx] for idx in sorted_indices]
29

30

31
def sympy_cse(ac):
    """Searches for common subexpressions inside the equation collection.

    The search is done in both the existing subexpressions and the main assignments, using
    ``sp.cse``. Only ``Assignment`` entries take part in the search; all other objects found in
    the collection are passed on unchanged and end up in the subexpression list of the result.

    Args:
        ac: collection providing ``subexpressions``, ``main_assignments``,
            ``subexpression_symbol_generator`` and ``copy``

    Returns:
        a new collection, created with ``ac.copy``, that additionally contains the
        subexpressions that were found
    """
    symbol_gen = ac.subexpression_symbol_generator

    subexpression_assignments = [e for e in ac.subexpressions if isinstance(e, Assignment)]
    main_assignments = [e for e in ac.main_assignments if isinstance(e, Assignment)]
    other_objects = [e for e in chain(ac.subexpressions, ac.main_assignments) if not isinstance(e, Assignment)]

    replacements, new_eq = sp.cse(subexpression_assignments + main_assignments, symbols=symbol_gen)
    replacement_eqs = [Assignment(*r) for r in replacements]

    # new_eq only contains the Assignment entries, so it has to be split at the number of
    # *Assignment* subexpressions. Splitting at len(ac.subexpressions) would move main
    # assignments into the subexpressions as soon as a non-Assignment subexpression exists.
    split_index = len(subexpression_assignments)
    modified_subexpressions = new_eq[:split_index]
    modified_update_equations = new_eq[split_index:]

    new_subexpressions = sort_assignments_topologically(other_objects + replacement_eqs + modified_subexpressions)
    return ac.copy(modified_update_equations, new_subexpressions)
51
52


Martin Bauer's avatar
Martin Bauer committed
53
54
def sympy_cse_on_assignment_list(assignments: List[Assignment]) -> List[Assignment]:
    """Runs :func:`sympy_cse` on a plain list of assignments.

    The list is wrapped into a temporary ``AssignmentCollection`` (constructed with an empty
    list as first argument and ``assignments`` as second), which is then passed to ``sympy_cse``.

    Returns:
        the ``all_assignments`` of the collection returned by ``sympy_cse``
    """
    # imported locally, as in the original code (NOTE(review): presumably to avoid a circular import)
    from pystencils.simp.assignment_collection import AssignmentCollection

    temporary_collection = AssignmentCollection([], assignments)
    simplified_collection = sympy_cse(temporary_collection)
    return simplified_collection.all_assignments


60
def subexpression_substitution_in_existing_subexpressions(ac):
    """Replaces the terms of earlier subexpressions inside the subexpressions that follow them.

    For every subexpression, the right hand sides of all subexpressions located before it in
    the list are searched for in its own right hand side and replaced by the corresponding
    left hand side symbol.

    Returns:
        copy of ``ac`` with unchanged main assignments and the rewritten subexpressions
    """
    rewritten = []
    preceding = []  # original (not rewritten) subexpressions located before the current one
    for current in ac.subexpressions:
        rhs = current.rhs
        for earlier in preceding:
            # first the additive substitution, then a plain sympy substitution of the full rhs
            rhs = subs_additive(rhs, earlier.lhs, earlier.rhs, required_match_replacement=1.0)
            rhs = rhs.subs(earlier.rhs, earlier.lhs)
        rewritten.append(Assignment(current.lhs, rhs))
        preceding.append(current)

    return ac.copy(ac.main_assignments, rewritten)
72
73


74
def subexpression_substitution_in_main_assignments(ac):
    """Replaces terms in the main assignments by already existing subexpressions.

    Returns:
        copy of ``ac`` in which the right hand side of each main assignment was processed
        with ``subs_additive`` once for every subexpression, in list order
    """

    def _insert_subexpressions(expr):
        # apply the subexpressions one after another to the same expression
        for sub_expr in ac.subexpressions:
            expr = subs_additive(expr, sub_expr.lhs, sub_expr.rhs, required_match_replacement=1.0)
        return expr

    updated_main_assignments = [Assignment(eq.lhs, _insert_subexpressions(eq.rhs))
                                for eq in ac.main_assignments]
    return ac.copy(updated_main_assignments)
83
84


85
def add_subexpressions_for_divisions(ac):
    r"""Introduces subexpressions for all divisions which have no constant in the denominator.

    For example :math:`\frac{1}{x}` is replaced while :math:`\frac{1}{3}` is not replaced.

    A division is detected as a ``sp.Pow`` with a negative, integer, numeric exponent. The
    search covers the right hand sides of ``ac.all_assignments`` and also descends into powers
    that are not divisions themselves, e.g. into the base of ``(a + 1/x)**2``.

    Returns:
        result of ``ac.new_with_substitutions`` for a mapping from each divisor term to a
        symbol drawn from ``ac.subexpression_symbol_generator``
    """
    divisors = set()

    def search_divisors(term):
        is_division = (term.func == sp.Pow
                       and term.exp.is_integer and term.exp.is_number and term.exp < 0)
        if is_division:
            # the whole power is replaced, so there is no need to look inside it
            divisors.add(term)
            return
        # all other terms - including powers that are no divisions - may contain divisions
        for a in term.args:
            search_divisors(a)

    for eq in ac.all_assignments:
        search_divisors(eq.rhs)

    # sorted by string representation, to hand out the new symbols in a reproducible order
    divisors = sorted(divisors, key=str)
    new_symbol_gen = ac.subexpression_symbol_generator
    # generator stays the first argument of zip: the argument order determines how many
    # symbols are drawn from it
    substitutions = {divisor: new_symbol for new_symbol, divisor in zip(new_symbol_gen, divisors)}
    return ac.new_with_substitutions(substitutions, True)
107
108


109
def add_subexpressions_for_sums(ac):
    r"""Introduces subexpressions for all sums - i.e. splits addends into subexpressions.

    Only the addends of sums whose addends contain no further sum are collected. Addends that
    are plain symbols are skipped, unless they are field accesses.

    Returns:
        result of ``ac.new_with_substitutions`` for a mapping from each collected addend to a
        symbol drawn from ``ac.subexpression_symbol_generator``; left hand sides are not substituted
    """
    collected_addends = []

    def has_sum(expr):
        # True if expr is a sum itself or has a sum somewhere below it
        if expr.func == sp.add.Add:
            return True
        return not expr.is_Atom and any(has_sum(child) for child in expr.args)

    def collect_addends(expr):
        is_innermost_sum = expr.func == sp.add.Add and not any(has_sum(child) for child in expr.args)
        if is_innermost_sum:
            collected_addends.extend(expr.args)
        # the children are visited in any case
        for child in expr.args:
            collect_addends(child)

    def is_candidate(addend):
        # field accesses are kept, all other symbols are dropped
        return isinstance(addend, AbstractField.AbstractAccess) or not isinstance(addend, sp.Symbol)

    for assignment in ac.all_assignments:
        collect_addends(assignment.rhs)

    candidates = [addend for addend in collected_addends if is_candidate(addend)]
    symbol_generator = ac.subexpression_symbol_generator
    # generator stays the first argument of zip: the argument order determines how many
    # symbols are drawn from it
    substitutions = {addend: symbol for symbol, addend in zip(symbol_generator, candidates)}
    return ac.new_with_substitutions(substitutions, True, substitute_on_lhs=False)


136
def add_subexpressions_for_field_reads(ac, subexpressions=True, main_assignments=True):
    r"""Substitutes field accesses on rhs of assignments with subexpressions

    Can change semantics of the update rule (which is the goal of this transformation)
    This is useful if a field should be updated in place - all values are loaded into
    subexpression variables first, then the new values are computed and written to the
    same field in-place.

    Args:
        ac: collection to transform
        subexpressions: if True, field accesses on the rhs of the subexpressions are replaced
        main_assignments: if True, field accesses on the rhs of the main assignments are replaced

    Returns:
        result of ``ac.new_with_substitutions``, with the substitutions added as
        subexpressions and left hand sides left untouched
    """
    accesses = set()

    def _collect_reads(assignments):
        for assignment in assignments:
            accesses.update(assignment.rhs.atoms(Field.Access))

    if subexpressions:
        _collect_reads(ac.subexpressions)
    if main_assignments:
        _collect_reads(ac.main_assignments)

    # one new symbol per distinct field access
    substitutions = {}
    for access in accesses:
        substitutions[access] = next(ac.subexpression_symbol_generator)

    return ac.new_with_substitutions(substitutions, add_substitutions_as_subexpressions=True, substitute_on_lhs=False)


154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
def transform_rhs(assignment_list, transformation, *args, **kwargs):
    """Applies a transformation function to the rhs of each assignment in the passed list.

    Elements that are not of type ``Assignment``, like AST nodes, are put into the result unchanged.
    Additional positional and keyword arguments are forwarded to the transformation function.

    Returns:
        new list with the same length and order as ``assignment_list``
    """
    transformed = []
    for entry in assignment_list:
        if isinstance(entry, Assignment):
            entry = Assignment(entry.lhs, transformation(entry.rhs, *args, **kwargs))
        transformed.append(entry)
    return transformed


def transform_lhs_and_rhs(assignment_list, transformation, *args, **kwargs):
    """Applies a transformation function to both sides of each assignment in the passed list.

    Elements that are not of type ``Assignment`` are put into the result unchanged.
    Additional positional and keyword arguments are forwarded to the transformation function.

    Returns:
        new list with the same length and order as ``assignment_list``
    """
    transformed = []
    for entry in assignment_list:
        if isinstance(entry, Assignment):
            new_lhs = transformation(entry.lhs, *args, **kwargs)
            new_rhs = transformation(entry.rhs, *args, **kwargs)
            entry = Assignment(new_lhs, new_rhs)
        transformed.append(entry)
    return transformed


def apply_to_all_assignments(operation: Callable[[sp.Expr], sp.Expr]):
    """Creates a function that applies the given operation to all main assignments of a collection.

    Args:
        operation: callable that is applied to the rhs of every main assignment

    Returns:
        function that takes a collection and returns a copy of it (``ac.copy``) with
        transformed main assignments. The function carries the name of ``operation``, if
        the operation has one.
    """

    def f(ac):
        return ac.copy(transform_rhs(ac.main_assignments, operation))

    # not every callable has a __name__ (e.g. functools.partial objects) - keep the default then
    f.__name__ = getattr(operation, '__name__', f.__name__)
    return f


179
def apply_on_all_subexpressions(operation: Callable[[sp.Expr], sp.Expr]):
    """Creates a function that applies the given operation on all subexpressions of a collection.

    Args:
        operation: callable that is applied to the rhs of every subexpression

    Returns:
        function that takes a collection and returns a copy of it (``ac.copy``) with
        unchanged main assignments and transformed subexpressions. The function carries the
        name of ``operation``, if the operation has one.
    """

    def f(ac):
        return ac.copy(ac.main_assignments, transform_rhs(ac.subexpressions, operation))

    # not every callable has a __name__ (e.g. functools.partial objects) - keep the default then
    f.__name__ = getattr(operation, '__name__', f.__name__)
    return f