simplifications.py 6.02 KB
Newer Older
Martin Bauer's avatar
Martin Bauer committed
1
from typing import Callable, List
Martin Bauer's avatar
Martin Bauer committed
2

Martin Bauer's avatar
Martin Bauer committed
3 4
import sympy as sp

Martin Bauer's avatar
Martin Bauer committed
5
from pystencils.assignment import Assignment
Martin Bauer's avatar
Martin Bauer committed
6
from pystencils.field import AbstractField, Field
7
from pystencils.simp.assignment_collection import AssignmentCollection, transform_rhs
Martin Bauer's avatar
Martin Bauer committed
8
from pystencils.sympyextensions import subs_additive
9

10
# Short alias for AssignmentCollection, used in the type annotations throughout this module.
AC = AssignmentCollection
11

12 13

def sympy_cse(ac: AC) -> AC:
    """Searches for common subexpressions inside the assignment collection.

    The search covers both the existing subexpressions and the main assignments, using sympy's
    common-subexpression elimination (``sp.cse``). Newly found common terms are bound to symbols
    drawn from ``ac.subexpression_symbol_generator``.

    Returns:
        a new assignment collection whose subexpressions are the newly extracted terms plus the
        rewritten original subexpressions, topologically sorted so that every symbol is defined
        before it is used; the main assignments are rewritten to use the extracted symbols.
    """
    # Import from the defining module: ``cse_main`` is not guaranteed to be reachable as an
    # attribute of the top-level ``sympy`` package (``sp.cse_main``) in newer sympy releases.
    from sympy.simplify.cse_main import reps_toposort

    symbol_gen = ac.subexpression_symbol_generator
    replacements, new_eq = sp.cse(ac.subexpressions + ac.main_assignments, symbols=symbol_gen)
    replacement_eqs = [Assignment(*r) for r in replacements]

    # sp.cse returns the reduced expressions in input order: the first len(subexpressions)
    # entries are the rewritten subexpressions, the rest are the rewritten main assignments.
    num_subexpressions = len(ac.subexpressions)
    modified_subexpressions = new_eq[:num_subexpressions]
    modified_update_equations = new_eq[num_subexpressions:]

    # Extracted terms may reference original subexpression symbols and vice versa, so sort the
    # combined list to guarantee definition-before-use.
    new_subexpressions = replacement_eqs + modified_subexpressions
    topologically_sorted_pairs = reps_toposort([[e.lhs, e.rhs] for e in new_subexpressions])
    new_subexpressions = [Assignment(lhs, rhs) for lhs, rhs in topologically_sorted_pairs]

    return ac.copy(modified_update_equations, new_subexpressions)
33 34


Martin Bauer's avatar
Martin Bauer committed
35 36
def sympy_cse_on_assignment_list(assignments: List[Assignment]) -> List[Assignment]:
    """Runs common-subexpression elimination over a plain list of assignments.

    The list is wrapped in a temporary assignment collection, processed by ``sympy_cse`` and
    returned flattened again (``all_assignments`` of the resulting collection).
    """
    return sympy_cse(AC([], assignments)).all_assignments


41
def subexpression_substitution_in_existing_subexpressions(ac: AC) -> AC:
    """Rewrites each subexpression in terms of the subexpressions that precede it.

    For every subexpression, each earlier (original) subexpression is folded in: first with the
    ``subs_additive`` helper, then with a plain sympy ``subs(rhs, lhs)`` so that occurrences of the
    earlier right-hand side are expressed through its left-hand side symbol. The main assignments
    are passed to ``ac.copy`` unchanged.
    """
    subexpressions = ac.subexpressions
    rewritten = []
    for position, current in enumerate(subexpressions):
        rhs = current.rhs
        for earlier in subexpressions[:position]:
            rhs = subs_additive(rhs, earlier.lhs, earlier.rhs, required_match_replacement=1.0)
            rhs = rhs.subs(earlier.rhs, earlier.lhs)
        rewritten.append(Assignment(current.lhs, rhs))
    return ac.copy(ac.main_assignments, rewritten)
53 54


55
def subexpression_substitution_in_main_assignments(ac: AC) -> AC:
    """Rewrites the main assignments in terms of the already existing subexpressions.

    Every subexpression, in order, is folded into each main assignment's right-hand side with the
    ``subs_additive`` helper. Only the main assignments are rebuilt; ``ac.copy`` is called with the
    new main assignments alone.
    """
    def fold_subexpressions(expr):
        for sub_expr in ac.subexpressions:
            expr = subs_additive(expr, sub_expr.lhs, sub_expr.rhs, required_match_replacement=1.0)
        return expr

    rewritten = [Assignment(a.lhs, fold_subexpressions(a.rhs)) for a in ac.main_assignments]
    return ac.copy(rewritten)
64 65


66
def add_subexpressions_for_divisions(ac: AC) -> AC:
    r"""Introduces subexpressions for all divisions which have no constant in the denominator.

    For example :math:`\frac{1}{x}` is replaced while :math:`\frac{1}{3}` is not replaced.

    A division is recognised as a ``sp.Pow`` whose exponent is a negative integer number
    (e.g. ``x**-1``, ``x**-2``). Every distinct such power found anywhere in the right-hand sides
    of all assignments is bound to a fresh symbol from ``ac.subexpression_symbol_generator`` and
    added as a subexpression via ``ac.new_with_substitutions``.
    """
    divisors = set()

    def search_divisors(term):
        # Short-circuit order matters: `term.exp < 0` is only evaluated for numeric exponents,
        # where it yields a plain truth value.
        if term.func == sp.Pow and term.exp.is_integer and term.exp.is_number and term.exp < 0:
            divisors.add(term)
            return
        # Any other node - including powers with a non-negative or symbolic exponent such as
        # (a + 1/x)**2 - may still contain divisions further down, so keep descending.
        for arg in term.args:
            search_divisors(arg)

    for assignment in ac.all_assignments:
        search_divisors(assignment.rhs)

    # Sort by string form so that the divisor -> symbol mapping is deterministic instead of
    # depending on set iteration order.
    sorted_divisors = sorted(divisors, key=str)
    new_symbol_gen = ac.subexpression_symbol_generator
    # Iterate the divisors first: zip then stops without drawing a surplus symbol from the
    # (shared) generator once the divisors are exhausted.
    substitutions = {divisor: new_symbol for divisor, new_symbol in zip(sorted_divisors, new_symbol_gen)}
    return ac.new_with_substitutions(substitutions, True)
88 89


90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116
def add_subexpressions_for_sums(ac: AC) -> AC:
    r"""Introduces subexpressions for all sums - i.e. splits addends into subexpressions.

    Every ``sp.Add`` node none of whose arguments contains another ``sp.Add`` contributes its
    arguments (the addends) as candidates. Plain symbols are skipped unless they are field
    accesses, and numeric literals are skipped as well. Each distinct remaining addend is bound to
    a fresh symbol from ``ac.subexpression_symbol_generator`` and added as a subexpression;
    left-hand sides are not substituted.
    """
    addends = []

    def contains_sum(term):
        if term.func == sp.Add:
            return True
        if term.is_Atom:
            return False
        return any(contains_sum(a) for a in term.args)

    def search_addends(term):
        # Only innermost sums (no nested sum in any argument) contribute addends.
        if term.func == sp.Add and not any(contains_sum(a) for a in term.args):
            addends.extend(term.args)
        for a in term.args:
            search_addends(a)

    for eq in ac.all_assignments:
        search_addends(eq.rhs)

    def is_candidate(addend):
        # Numeric literals (e.g. the 1 in x*y + 1) are excluded: substituting a number with
        # ``subs`` replaces it everywhere it occurs, including exponents and coefficients
        # (x**2 -> x**xi), which degrades the generated code without saving any operation.
        if addend.is_Number:
            return False
        return not isinstance(addend, sp.Symbol) or isinstance(addend, AbstractField.AbstractAccess)

    # dict.fromkeys drops duplicates while keeping first-occurrence order, so each distinct
    # addend consumes exactly one symbol from the generator.
    unique_addends = [a for a in dict.fromkeys(addends) if is_candidate(a)]
    new_symbol_gen = ac.subexpression_symbol_generator
    # Iterate the addends first so that zip stops without drawing a surplus symbol.
    substitutions = {addend: new_symbol for addend, new_symbol in zip(unique_addends, new_symbol_gen)}
    return ac.new_with_substitutions(substitutions, True, substitute_on_lhs=False)


Martin Bauer's avatar
Martin Bauer committed
117 118 119 120 121 122 123 124 125 126 127 128 129 130
def add_subexpressions_for_field_reads(ac: AC, subexpressions=True, main_assignments=True) -> AC:
    r"""Substitutes field accesses on rhs of assignments with subexpressions.

    Can change semantics of the update rule (which is the goal of this transformation).
    This is useful if a field should be updated in place - all values are loaded before into
    subexpression variables, then the new values are computed and written to the same field in-place.

    Args:
        ac: assignment collection to transform
        subexpressions: if True, field accesses are collected from the rhs of the subexpressions
        main_assignments: if True, field accesses are collected from the rhs of the main assignments

    Returns:
        new assignment collection in which every collected field access is bound to a fresh symbol
        from ``ac.subexpression_symbol_generator`` (added as a subexpression); left-hand sides are
        not substituted.
    """
    sources = []
    if subexpressions:
        sources.extend(ac.subexpressions)
    if main_assignments:
        sources.extend(ac.main_assignments)

    field_reads = set()
    for assignment in sources:
        field_reads.update(assignment.rhs.atoms(Field.Access))

    # Sort by string form so that the field access -> symbol mapping is deterministic instead of
    # depending on set iteration order.
    substitutions = {fa: next(ac.subexpression_symbol_generator) for fa in sorted(field_reads, key=str)}
    return ac.new_with_substitutions(substitutions, add_substitutions_as_subexpressions=True, substitute_on_lhs=False)


135 136
def apply_to_all_assignments(operation: Callable[[sp.Expr], sp.Expr]) -> Callable[[AC], AC]:
    """Wraps ``operation`` into a transformation applied to the rhs of every main assignment.

    Note that only the main assignments are transformed; ``ac.copy`` receives just the transformed
    main assignments.

    Args:
        operation: function mapping a sympy expression to a new expression, e.g. ``sp.expand``

    Returns:
        function taking an assignment collection and returning a copy with ``operation`` applied to
        the right-hand sides of its main assignments. The returned function carries the name of
        ``operation``, or the name of its type for callables without ``__name__``.
    """
    def f(ac: AC) -> AC:
        return ac.copy(transform_rhs(ac.main_assignments, operation))
    # functools.partial objects and other callable instances have no __name__; fall back to the
    # type name instead of raising AttributeError.
    f.__name__ = getattr(operation, '__name__', type(operation).__name__)
    return f


def apply_on_all_subexpressions(operation: Callable[[sp.Expr], sp.Expr]) -> Callable[[AC], AC]:
    """Applies the given operation on all subexpressions of the AC.

    Args:
        operation: function mapping a sympy expression to a new expression

    Returns:
        function taking an assignment collection and returning a copy whose main assignments are
        kept and whose subexpressions have ``operation`` applied to their right-hand sides. The
        returned function carries the name of ``operation``, or the name of its type for callables
        without ``__name__``.
    """
    def f(ac: AC) -> AC:
        return ac.copy(ac.main_assignments, transform_rhs(ac.subexpressions, operation))
    # functools.partial objects and other callable instances have no __name__; fall back to the
    # type name instead of raising AttributeError.
    f.__name__ = getattr(operation, '__name__', type(operation).__name__)
    return f