simplifications.py 5.07 KB
Newer Older
1
import sympy as sp
Martin Bauer's avatar
Martin Bauer committed
2
from typing import Callable, List
Martin Bauer's avatar
Martin Bauer committed
3
4

from pystencils import Field
Martin Bauer's avatar
Martin Bauer committed
5
from pystencils.assignment import Assignment
6
from pystencils.simp.assignment_collection import AssignmentCollection
Martin Bauer's avatar
Martin Bauer committed
7
from pystencils.sympyextensions import subs_additive
8

9
# Short alias for AssignmentCollection, used in the type annotations of this module.
AC = AssignmentCollection
10

11
12

def sympy_cse(ac: AC) -> AC:
    """Searches for common subexpressions inside the assignment collection.

    The search covers both the existing subexpressions and the main assignments. It uses sympy's
    common subexpression elimination (``sp.cse``). New subexpression symbols are drawn from
    ``ac.subexpression_symbol_generator``.

    Returns:
        a copy of ``ac`` whose main assignments are rewritten in terms of the found common
        subexpressions, and whose subexpressions are the old (rewritten) subexpressions plus the new
        ones, topologically sorted so every symbol is defined before it is used.
    """
    # Import from the defining module: ``cse_main`` is a submodule of ``sympy.simplify`` and is not
    # reliably reachable as ``sp.cse_main`` on the top-level sympy package.
    from sympy.simplify.cse_main import reps_toposort

    symbol_gen = ac.subexpression_symbol_generator
    replacements, new_eq = sp.cse(ac.subexpressions + ac.main_assignments,
                                  symbols=symbol_gen)
    replacement_eqs = [Assignment(*r) for r in replacements]

    # the rewritten expressions come back in input order: subexpressions first, then main assignments
    num_subexpressions = len(ac.subexpressions)
    modified_subexpressions = new_eq[:num_subexpressions]
    modified_main_assignments = new_eq[num_subexpressions:]

    # freshly extracted terms and rewritten old subexpressions may reference each other,
    # so order them such that each lhs symbol is defined before it is used
    new_subexpressions = replacement_eqs + modified_subexpressions
    topologically_sorted_pairs = reps_toposort([[e.lhs, e.rhs] for e in new_subexpressions])
    new_subexpressions = [Assignment(a[0], a[1]) for a in topologically_sorted_pairs]

    return ac.copy(modified_main_assignments, new_subexpressions)
32
33


Martin Bauer's avatar
Martin Bauer committed
34
35
def sympy_cse_on_assignment_list(assignments: List[Assignment]) -> List[Assignment]:
    """Extracts common subexpressions from a plain list of assignments.

    The assignments are wrapped into a temporary collection with an empty first argument
    (presumably the main-assignment slot), simplified with ``sympy_cse``, and all assignments of
    the simplified collection are returned.
    """
    temporary_collection = AC([], assignments)
    simplified_collection = sympy_cse(temporary_collection)
    return simplified_collection.all_assignments


40
def subexpression_substitution_in_existing_subexpressions(ac: AC) -> AC:
    """Rewrites each subexpression in terms of the subexpressions that precede it.

    For every subexpression, each earlier subexpression is substituted into its rhs: first via
    ``subs_additive``, then via a plain ``subs`` that replaces the earlier rhs by its lhs symbol.
    The main assignments are passed through unchanged.
    """
    subexpressions = ac.subexpressions
    rewritten = []
    for position, current in enumerate(subexpressions):
        rhs = current.rhs
        for earlier in subexpressions[:position]:
            rhs = subs_additive(rhs, earlier.lhs, earlier.rhs, required_match_replacement=1.0)
            rhs = rhs.subs(earlier.rhs, earlier.lhs)
        rewritten.append(Assignment(current.lhs, rhs))
    return ac.copy(ac.main_assignments, rewritten)
52
53


54
def subexpression_substitution_in_main_assignments(ac: AC) -> AC:
    """Replaces already existing subexpressions inside the main assignments.

    The rhs of every main assignment is passed through ``subs_additive`` once per subexpression,
    in subexpression order. Only the rewritten main assignments are handed to ``ac.copy``.
    """
    def substitute_subexpressions(expression):
        for sub_expr in ac.subexpressions:
            expression = subs_additive(expression, sub_expr.lhs, sub_expr.rhs,
                                       required_match_replacement=1.0)
        return expression

    rewritten_main = [Assignment(a.lhs, substitute_subexpressions(a.rhs)) for a in ac.main_assignments]
    return ac.copy(rewritten_main)
63
64


65
def add_subexpressions_for_divisions(ac: AC) -> AC:
    r"""Introduces subexpressions for all divisions which have no constant in the denominator.

    For example :math:`\frac{1}{x}` is replaced while :math:`\frac{1}{3}` is not replaced.
    A division is recognized as a power with a negative integer exponent, e.g. ``x**(-1)``.
    Such powers are also found when nested inside other powers, e.g. ``sqrt(1/x + y)``.

    Returns:
        a new collection in which each found division is replaced by a fresh symbol from
        ``ac.subexpression_symbol_generator``, with the divisions added as subexpressions.
    """
    divisors = set()

    def search_divisors(term):
        if term.func == sp.Pow and term.exp.is_integer and term.exp.is_number and term.exp < 0:
            divisors.add(term)
        else:
            # also descend into powers that are not divisions themselves (e.g. (1/x + y)**2),
            # otherwise divisions nested inside them would be missed
            for a in term.args:
                search_divisors(a)

    for eq in ac.all_assignments:
        search_divisors(eq.rhs)

    # sort so the divisor -> symbol mapping does not depend on set iteration order (hash
    # randomization), which would otherwise make the generated subexpressions differ between runs
    sorted_divisors = sorted(divisors, key=sp.default_sort_key)
    new_symbol_gen = ac.subexpression_symbol_generator
    substitutions = {divisor: new_symbol for new_symbol, divisor in zip(new_symbol_gen, sorted_divisors)}
    return ac.new_with_substitutions(substitutions, True)
86
87


Martin Bauer's avatar
Martin Bauer committed
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
def add_subexpressions_for_field_reads(ac: AC, subexpressions=True, main_assignments=True) -> AC:
    r"""Substitutes field accesses on rhs of assignments with subexpressions.

    Can change semantics of the update rule (which is the goal of this transformation).
    This is useful if a field should be updated in place - all values are loaded before into subexpression
    variables, then the new values are computed and written to the same field in-place.

    Args:
        ac: assignment collection to transform
        subexpressions: if True, field accesses on the rhs of subexpressions are collected
        main_assignments: if True, field accesses on the rhs of main assignments are collected

    Returns:
        new collection where each collected field access is replaced by a fresh ``sp.Dummy`` symbol,
        with the substitutions added as subexpressions (``add_substitutions_as_subexpressions=True``);
        left-hand sides are not substituted (``substitute_on_lhs=False``).
    """
    field_reads = set()
    if subexpressions:
        for assignment in ac.subexpressions:
            field_reads.update(assignment.rhs.atoms(Field.Access))
    if main_assignments:
        for assignment in ac.main_assignments:
            field_reads.update(assignment.rhs.atoms(Field.Access))
    # create the dummies in a fixed order so the resulting subexpressions do not depend on
    # set iteration order (hash randomization)
    substitutions = {fa: sp.Dummy() for fa in sorted(field_reads, key=sp.default_sort_key)}
    return ac.new_with_substitutions(substitutions, add_substitutions_as_subexpressions=True, substitute_on_lhs=False)


106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
def apply_to_all_assignments(operation: Callable[[sp.Expr], sp.Expr]) -> Callable[[AC], AC]:
    """Lifts an expression operation to an operation on assignment collections.

    The returned function applies ``operation`` to the rhs of every *main* assignment of the
    collection it receives and returns ``copy`` of that collection with the transformed main
    assignments. Subexpressions are not transformed. The returned function carries the name of
    ``operation`` so it can be identified.
    """
    def f(assignment_collection: AC) -> AC:
        result = [Assignment(eq.lhs, operation(eq.rhs)) for eq in assignment_collection.main_assignments]
        return assignment_collection.copy(result)
    # callables like functools.partial objects have no __name__; fall back to their type name
    f.__name__ = getattr(operation, '__name__', type(operation).__name__)
    return f


def apply_on_all_subexpressions(operation: Callable[[sp.Expr], sp.Expr]) -> Callable[[AC], AC]:
    """Applies the given operation on all subexpressions of the AC.

    The returned function applies ``operation`` to the rhs of every subexpression and returns a copy
    of the collection with unchanged main assignments and the transformed subexpressions. It carries
    the name of ``operation`` so it can be identified.
    """
    def f(ac: AC) -> AC:
        result = [Assignment(eq.lhs, operation(eq.rhs)) for eq in ac.subexpressions]
        return ac.copy(ac.main_assignments, result)
    # callables like functools.partial objects have no __name__; fall back to their type name
    f.__name__ = getattr(operation, '__name__', type(operation).__name__)
    return f