simplifications.py 7.92 KB
Newer Older
1 2
from itertools import chain
from typing import Callable, List, Sequence, Union
Martin Bauer's avatar
Martin Bauer committed
3

Martin Bauer's avatar
Martin Bauer committed
4 5
import sympy as sp

Martin Bauer's avatar
Martin Bauer committed
6
from pystencils.assignment import Assignment
7
from pystencils.astnodes import Node
Martin Bauer's avatar
Martin Bauer committed
8
from pystencils.field import AbstractField, Field
Martin Bauer's avatar
Martin Bauer committed
9
from pystencils.sympyextensions import subs_additive
10

11 12 13 14 15

def sort_assignments_topologically(assignments: Sequence[Union[Assignment, Node]]) -> List[Union[Assignment, Node]]:
    """Sorts assignments in topological order, such that symbols used on rhs occur first on a lhs.

    Two kinds of objects are supported:

    * assignment-like objects (anything with ``lhs`` and ``rhs``): they define ``lhs`` and use
      the free symbols of ``rhs``
    * AST ``Node`` objects: they define ``symbols_defined`` and use ``undefined_symbols``

    Args:
        assignments: sequence of assignments and/or AST nodes

    Returns:
        new list containing the same objects, ordered so that every definition precedes its uses

    Raises:
        NotImplementedError: if an object is neither assignment-like nor a ``Node``
        (``sympy.topological_sort`` additionally raises if the dependencies are cyclic)
    """
    def defined_symbols(obj):
        # assignment-like objects take precedence over Node (e.g. AST assignment nodes define their lhs)
        if hasattr(obj, 'lhs') and hasattr(obj, 'rhs'):
            return [obj.lhs]
        if isinstance(obj, Node):
            return obj.symbols_defined
        raise NotImplementedError(f"Cannot sort topologically. Object of type {type(obj)} cannot be handled.")

    def used_symbols(obj):
        # Nodes report their own dependencies (which may be richer than the rhs symbols);
        # everything else reaching this point is assignment-like (validated by defined_symbols).
        if isinstance(obj, Node):
            return obj.undefined_symbols
        return obj.rhs.free_symbols

    # Compute definitions first so unsupported objects are rejected before any other work,
    # then compute each element's used symbols exactly once instead of once per (producer, consumer) pair.
    defined = [defined_symbols(a) for a in assignments]
    used = [used_symbols(a) for a in assignments]

    # edge (c1, c2): element c1 defines a symbol that element c2 uses -> c1 must come before c2
    edges = [(c1, c2)
             for c1, symbols in enumerate(defined)
             for lhs in symbols
             for c2, consumed in enumerate(used)
             if lhs in consumed]
    return [assignments[i] for i in sp.topological_sort((range(len(assignments)), edges))]
30

31

32
def sympy_cse(ac, **kwargs):
    """Searches for common subexpressions inside the assignment collection.

    The search is done in both the existing subexpressions and the main assignments. It uses the sympy
    subexpression detection (``sympy.cse``) to do this and returns a new assignment collection with the
    additional subexpressions found.

    Objects that are not :class:`Assignment` instances (e.g. AST nodes) are not passed to ``sympy.cse``.
    They are kept and placed among the subexpressions of the result, together with the new
    subexpressions, and everything is sorted topologically.

    Args:
        ac: the assignment collection to optimize
        **kwargs: forwarded to ``sympy.cse``

    Returns:
        a copy of ``ac`` with rewritten main assignments and the extended subexpressions
    """
    symbol_gen = ac.subexpression_symbol_generator

    subexpression_assignments = [e for e in ac.subexpressions if isinstance(e, Assignment)]
    main_assignments = [e for e in ac.main_assignments if isinstance(e, Assignment)]
    other_objects = [e for e in chain(ac.subexpressions, ac.main_assignments) if not isinstance(e, Assignment)]
    replacements, new_eq = sp.cse(subexpression_assignments + main_assignments, symbols=symbol_gen, **kwargs)

    replacement_eqs = [Assignment(*r) for r in replacements]

    # Split at the number of subexpression *Assignments* actually passed to cse. Using
    # len(ac.subexpressions) would be off by the number of non-Assignment objects in the
    # subexpressions and would move the first main assignments into the subexpressions.
    num_subexpressions = len(subexpression_assignments)
    modified_subexpressions = new_eq[:num_subexpressions]
    modified_update_equations = new_eq[num_subexpressions:]

    new_subexpressions = sort_assignments_topologically(other_objects + replacement_eqs + modified_subexpressions)
    return ac.copy(modified_update_equations, new_subexpressions)
52 53


Martin Bauer's avatar
Martin Bauer committed
54 55
def sympy_cse_on_assignment_list(assignments: List[Assignment]) -> List[Assignment]:
    """Run common subexpression elimination on a bare list of assignments.

    The list is wrapped into a temporary ``AssignmentCollection([], assignments)`` (second argument presumably
    being the subexpressions — TODO confirm), :func:`sympy_cse` is applied to it, and the ``all_assignments``
    of the optimized collection are returned.
    """
    # kept as a function-local import, like the original (presumably avoids a circular import — TODO confirm)
    from pystencils.simp.assignment_collection import AssignmentCollection
    optimized_collection = sympy_cse(AssignmentCollection([], assignments))
    return optimized_collection.all_assignments


61
def subexpression_substitution_in_existing_subexpressions(ac):
    """Rewrites every subexpression in terms of the subexpressions defined before it.

    For each subexpression, the rhs of every *preceding* subexpression (as given in ``ac``, not the
    already rewritten version) is searched for in its rhs and replaced by that preceding subexpression's
    lhs. This is done first with ``subs_additive`` (``required_match_replacement=1.0``) and then with
    a plain sympy ``subs``.

    Objects without ``lhs``/``rhs`` (e.g. AST nodes) are kept unchanged in place and are not used as
    substitution sources, like ``transform_rhs`` in this module treats them.

    Returns:
        a copy of ``ac`` with unchanged main assignments and the rewritten subexpressions
    """
    result = []
    previous_subexpressions = []
    for subexpression in ac.subexpressions:
        if not (hasattr(subexpression, 'lhs') and hasattr(subexpression, 'rhs')):
            result.append(subexpression)
            continue

        new_rhs = subexpression.rhs
        for previous in previous_subexpressions:
            new_rhs = subs_additive(new_rhs, previous.lhs, previous.rhs, required_match_replacement=1.0)
            new_rhs = new_rhs.subs(previous.rhs, previous.lhs)
        result.append(Assignment(subexpression.lhs, new_rhs))
        # remember the ORIGINAL subexpression (not the rewritten one) as substitution source
        previous_subexpressions.append(subexpression)

    return ac.copy(ac.main_assignments, result)
73 74


75
def subexpression_substitution_in_main_assignments(ac):
    """Rewrite the main assignments in terms of the existing subexpressions.

    For every main assignment, ``subs_additive`` is applied once per subexpression (in order, with
    ``required_match_replacement=1.0``). This substitutes the subexpression's lhs for matching parts of its
    rhs; see ``subs_additive`` for the exact matching rules. Only the new main assignments are passed to
    ``ac.copy``.
    """
    def substitute_known_subexpressions(expression):
        for known in ac.subexpressions:
            expression = subs_additive(expression, known.lhs, known.rhs, required_match_replacement=1.0)
        return expression

    rewritten = [Assignment(main.lhs, substitute_known_subexpressions(main.rhs)) for main in ac.main_assignments]
    return ac.copy(rewritten)
84 85


86
def add_subexpressions_for_divisions(ac):
    r"""Introduces subexpressions for all divisions which have no constant in the denominator.

    For example :math:`\frac{1}{x}` is replaced while :math:`\frac{1}{3}` is not replaced.

    A "division" is a power with a numeric, negative integer exponent (e.g. ``x**-1``, ``x**-2``).
    Every distinct such power found in the rhs of any assignment is replaced by a new subexpression
    symbol. Such a power is not searched further, but all other terms are. This includes powers with
    other exponents, e.g. the ``1/x`` inside ``(1/x + y)**2``.
    """
    divisors = set()

    def search_divisors(term):
        if term.func == sp.Pow and term.exp.is_integer and term.exp.is_number and term.exp < 0:
            divisors.add(term)
            return
        # Recurse into everything else. Previously any non-divisor Pow stopped the search,
        # so divisions nested inside e.g. (1/x + y)**2 or sqrt(1/x + 1) were missed.
        for a in term.args:
            search_divisors(a)

    for eq in ac.all_assignments:
        search_divisors(eq.rhs)

    # sort by string representation so the divisor -> symbol mapping does not depend on set order
    sorted_divisors = sorted(divisors, key=str)
    new_symbol_gen = ac.subexpression_symbol_generator
    substitutions = {divisor: new_symbol for new_symbol, divisor in zip(new_symbol_gen, sorted_divisors)}
    return ac.new_with_substitutions(substitutions, add_substitutions_as_subexpressions=True, substitute_on_lhs=False)
108 109


110
def add_subexpressions_for_sums(ac):
    r"""Introduces subexpressions for all sums - i.e. splits addends into subexpressions.

    Every ``sp.Add`` whose arguments contain no further sums contributes its addends. Plain symbols are
    skipped, except field accesses (``AbstractField.AbstractAccess``), which are replaced as well.
    Each distinct addend is replaced by exactly one new subexpression symbol.
    """
    addends = []

    def contains_sum(term):
        if term.func == sp.Add:
            return True
        if term.is_Atom:
            return False
        return any(contains_sum(a) for a in term.args)

    def search_addends(term):
        # only innermost sums contribute their addends, but the search always continues downwards
        if term.func == sp.Add and not any(contains_sum(a) for a in term.args):
            addends.extend(term.args)
        for a in term.args:
            search_addends(a)

    for eq in ac.all_assignments:
        search_addends(eq.rhs)

    # dict.fromkeys removes duplicates while keeping first-occurrence order. Without it, a repeated addend
    # consumed one generated symbol per occurrence but only the last one was used, leaving numbering gaps.
    # NOTE(review): numeric addends (e.g. the 2 in x*y + 2) are not filtered out and get substituted too.
    unique_addends = [a for a in dict.fromkeys(addends)
                      if not isinstance(a, sp.Symbol) or isinstance(a, AbstractField.AbstractAccess)]
    new_symbol_gen = ac.subexpression_symbol_generator
    substitutions = {addend: new_symbol for new_symbol, addend in zip(new_symbol_gen, unique_addends)}
    return ac.new_with_substitutions(substitutions, add_substitutions_as_subexpressions=True, substitute_on_lhs=False)


137
def add_subexpressions_for_field_reads(ac, subexpressions=True, main_assignments=True):
    r"""Substitutes field accesses on rhs of assignments with subexpressions

    Can change semantics of the update rule (which is the goal of this transformation).
    This is useful if a field should be updated in place - all values are loaded before into subexpression
    variables, then the new values are computed and written to the same field in-place.

    Args:
        ac: assignment collection to transform
        subexpressions: whether field reads in the rhs of the subexpressions are collected
        main_assignments: whether field reads in the rhs of the main assignments are collected

    Returns:
        new assignment collection in which each collected field access is replaced by a fresh symbol
        (added as subexpressions via ``add_substitutions_as_subexpressions=True``, without topological sorting)
    """
    field_reads = set()
    to_iterate = []
    if subexpressions:
        to_iterate = chain(to_iterate, ac.subexpressions)
    if main_assignments:
        to_iterate = chain(to_iterate, ac.main_assignments)

    for assignment in to_iterate:
        # objects without lhs/rhs (e.g. AST nodes) are skipped
        if hasattr(assignment, 'lhs') and hasattr(assignment, 'rhs'):
            field_reads.update(assignment.rhs.atoms(Field.Access))

    # Sort by string representation, like add_subexpressions_for_divisions does, so the field access -> symbol
    # mapping (and thus the generated code) does not depend on set iteration order.
    substitutions = {fa: next(ac.subexpression_symbol_generator) for fa in sorted(field_reads, key=str)}
    return ac.new_with_substitutions(substitutions, add_substitutions_as_subexpressions=True,
                                     substitute_on_lhs=False, sort_topologically=False)
Martin Bauer's avatar
Martin Bauer committed
157 158


159 160 161 162
def transform_rhs(assignment_list, transformation, *args, **kwargs):
    """Apply ``transformation`` to the rhs of every assignment-like element of ``assignment_list``.

    Elements having both ``lhs`` and ``rhs`` are turned into ``Assignment(lhs, transformation(rhs, ...))``;
    any additional positional/keyword arguments are forwarded to ``transformation``. All other objects
    (e.g. AST nodes) are passed through unchanged. A new list is returned.
    """
    def transformed(item):
        if not (hasattr(item, 'lhs') and hasattr(item, 'rhs')):
            return item
        return Assignment(item.lhs, transformation(item.rhs, *args, **kwargs))

    return [transformed(item) for item in assignment_list]


def transform_lhs_and_rhs(assignment_list, transformation, *args, **kwargs):
    """Apply ``transformation`` to both sides of every assignment-like element of ``assignment_list``.

    Elements having both ``lhs`` and ``rhs`` become ``Assignment(transformation(lhs), transformation(rhs))``
    (lhs is transformed first); extra positional/keyword arguments are forwarded to ``transformation``.
    All other objects (e.g. AST nodes) are passed through unchanged. A new list is returned.
    """
    result = []
    for item in assignment_list:
        if hasattr(item, 'lhs') and hasattr(item, 'rhs'):
            new_lhs = transformation(item.lhs, *args, **kwargs)
            new_rhs = transformation(item.rhs, *args, **kwargs)
            item = Assignment(new_lhs, new_rhs)
        result.append(item)
    return result


def apply_to_all_assignments(operation: Callable[[sp.Expr], sp.Expr]):
    """Creates a simplification that applies ``operation`` to the rhs of every main assignment.

    Despite the name, subexpressions are left untouched (see ``apply_on_all_subexpressions`` for those).
    Main-assignment objects without ``lhs``/``rhs`` (e.g. AST nodes) are kept unchanged.

    Args:
        operation: function mapping a sympy expression to a sympy expression, e.g. ``sp.expand``

    Returns:
        a function ``f(ac)`` returning a copy of the assignment collection with transformed main
        assignments. ``f.__name__`` is set to the operation's name (falling back to its type name for
        callables without ``__name__``, such as ``functools.partial`` objects).
    """

    def f(ac):
        return ac.copy(transform_rhs(ac.main_assignments, operation))

    # partial objects and other callables may lack __name__; don't crash on them
    f.__name__ = getattr(operation, '__name__', type(operation).__name__)
    return f


184
def apply_on_all_subexpressions(operation: Callable[[sp.Expr], sp.Expr]):
    """Creates a simplification that applies ``operation`` to the rhs of every subexpression of the AC.

    Main assignments are passed to the copy unchanged. Subexpression objects without ``lhs``/``rhs``
    (e.g. AST nodes) are kept unchanged.

    Args:
        operation: function mapping a sympy expression to a sympy expression

    Returns:
        a function ``f(ac)`` returning a copy of the assignment collection with transformed subexpressions.
        ``f.__name__`` is set to the operation's name (falling back to its type name for callables
        without ``__name__``, such as ``functools.partial`` objects).
    """

    def f(ac):
        return ac.copy(ac.main_assignments, transform_rhs(ac.subexpressions, operation))

    # partial objects and other callables may lack __name__; don't crash on them
    f.__name__ = getattr(operation, '__name__', type(operation).__name__)
    return f