simplifications.py 4.22 KB
Newer Older
1
import sympy as sp
Martin Bauer's avatar
Martin Bauer committed
2
from typing import Callable, List
Martin Bauer's avatar
Martin Bauer committed
3
4
from pystencils.assignment import Assignment
from pystencils.assignment_collection.assignment_collection import AssignmentCollection
Martin Bauer's avatar
Martin Bauer committed
5
from pystencils.sympyextensions import subs_additive
6
7


Martin Bauer's avatar
Martin Bauer committed
8
9
10
11
12
def sympy_cse(ac: AssignmentCollection) -> AssignmentCollection:
    """Searches for common subexpressions inside the assignment collection.

    The search covers both the existing subexpressions and the main assignments.
    It uses sympy's common-subexpression detection (``sp.cse``) and returns a new
    assignment collection that contains the additional subexpressions found.
    Symbols for the new subexpressions are drawn from ``ac.subexpression_symbol_generator``.

    Args:
        ac: assignment collection to simplify
    Returns:
        a copy of ``ac`` with rewritten main assignments and an extended,
        topologically sorted list of subexpressions
    """
    # Import from the defining module: ``reps_toposort`` is not exported from the
    # top-level sympy namespace, so ``sp.cse_main.reps_toposort`` only works on SymPy
    # versions where the ``cse_main`` submodule happens to leak into ``sympy``.
    from sympy.simplify.cse_main import reps_toposort

    symbol_gen = ac.subexpression_symbol_generator
    num_subexpressions = len(ac.subexpressions)
    # Subexpressions and main assignments are processed together so that terms shared
    # between the two groups are detected as well.
    replacements, new_eq = sp.cse(ac.subexpressions + ac.main_assignments,
                                  symbols=symbol_gen)
    replacement_eqs = [Assignment(*r) for r in replacements]

    # sp.cse returns the reduced expressions in input order, so split them back up.
    modified_subexpressions = new_eq[:num_subexpressions]
    modified_update_equations = new_eq[num_subexpressions:]

    # The freshly found replacements and the rewritten subexpressions may reference each
    # other; reps_toposort orders the (lhs, rhs) pairs so that a symbol's definition
    # precedes the right-hand sides that use it.
    new_subexpressions = replacement_eqs + modified_subexpressions
    topologically_sorted_pairs = reps_toposort([[e.lhs, e.rhs] for e in new_subexpressions])
    new_subexpressions = [Assignment(a[0], a[1]) for a in topologically_sorted_pairs]

    return ac.copy(modified_update_equations, new_subexpressions)
28
29


Martin Bauer's avatar
Martin Bauer committed
30
31
def sympy_cse_on_assignment_list(assignments: List[Assignment]) -> List[Assignment]:
    """Runs common-subexpression elimination on a plain list of assignments.

    The list is wrapped into a temporary AssignmentCollection (with an empty first
    argument), processed by ``sympy_cse``, and every assignment of the resulting
    collection is returned as a flat list.
    """
    temporary_collection = AssignmentCollection([], assignments)
    simplified = sympy_cse(temporary_collection)
    return simplified.all_assignments


Martin Bauer's avatar
Martin Bauer committed
36
37
def apply_to_all_assignments(assignment_collection: AssignmentCollection,
                             operation: Callable[[sp.Expr], sp.Expr]) -> AssignmentCollection:
    """Applies ``operation`` to the right-hand side of every main assignment.

    Despite the name, ``operation`` is not applied to the subexpressions; use
    ``apply_on_all_subexpressions`` for those. Left-hand sides are kept unchanged.

    Args:
        assignment_collection: collection whose main assignments are transformed
        operation: function mapping a sympy expression to a new expression,
                   for example ``sp.expand``
    Returns:
        a copy of the collection (via ``copy``) with the transformed main assignments
    """
    result = [Assignment(eq.lhs, operation(eq.rhs)) for eq in assignment_collection.main_assignments]
    return assignment_collection.copy(result)


Martin Bauer's avatar
Martin Bauer committed
43
44
def apply_on_all_subexpressions(ac: AssignmentCollection,
                                operation: Callable[[sp.Expr], sp.Expr]) -> AssignmentCollection:
    """Transforms every subexpression of ``ac`` by applying ``operation`` to its right-hand side.

    Left-hand sides stay as they are and the main assignments are passed through
    unchanged to ``ac.copy``, which builds the returned collection.
    """
    def transformed(assignment):
        # Keep the symbol being defined, rewrite only what it is defined as.
        return Assignment(assignment.lhs, operation(assignment.rhs))

    new_subexpressions = list(map(transformed, ac.subexpressions))
    return ac.copy(ac.main_assignments, new_subexpressions)
48
49


Martin Bauer's avatar
Martin Bauer committed
50
def subexpression_substitution_in_existing_subexpressions(ac: AssignmentCollection) -> AssignmentCollection:
    """Rewrites each subexpression in terms of the subexpressions defined before it.

    For every subexpression, each *preceding* subexpression's right-hand side is
    replaced by its left-hand side symbol inside the current right-hand side. This is
    done first with ``subs_additive`` and then with a plain ``subs`` for exact matches.
    The right-hand sides that are searched for are the original ones from ``ac``.
    Main assignments are passed through unchanged.

    Args:
        ac: assignment collection whose subexpressions are rewritten
    Returns:
        a copy of ``ac`` with the rewritten subexpressions
    """
    subexpressions = ac.subexpressions
    result = []
    for position, current in enumerate(subexpressions):
        new_rhs = current.rhs
        # Only earlier subexpressions are substituted, so every referenced symbol is
        # still defined before its use.
        for previous in subexpressions[:position]:
            new_rhs = subs_additive(new_rhs, previous.lhs, previous.rhs, required_match_replacement=1.0)
            new_rhs = new_rhs.subs(previous.rhs, previous.lhs)
        result.append(Assignment(current.lhs, new_rhs))

    return ac.copy(ac.main_assignments, result)
62
63


Martin Bauer's avatar
Martin Bauer committed
64
65
def subexpression_substitution_in_main_assignments(ac: AssignmentCollection) -> AssignmentCollection:
    """Rewrites the main assignments in terms of the already existing subexpressions.

    For each main assignment, ``subs_additive`` is applied once per subexpression, in
    subexpression order, with that subexpression's left-hand side and right-hand side.
    Left-hand sides of the main assignments are kept. The result is ``ac.copy`` with the
    new main assignments.
    """
    def substitute_subexpressions(expression):
        for sub_expr in ac.subexpressions:
            expression = subs_additive(expression, sub_expr.lhs, sub_expr.rhs, required_match_replacement=1.0)
        return expression

    rewritten = [Assignment(assignment.lhs, substitute_subexpressions(assignment.rhs))
                 for assignment in ac.main_assignments]
    return ac.copy(rewritten)
73
74


Martin Bauer's avatar
Martin Bauer committed
75
def add_subexpressions_for_divisions(ac: AssignmentCollection) -> AssignmentCollection:
    r"""Introduces subexpressions for all divisions which have no constant in the denominator.

    For example :math:`\frac{1}{x}` is replaced, :math:`\frac{1}{3}` is not replaced.

    A "division" here is a power with a negative integer exponent, e.g. ``x**-1`` or
    ``(x + y)**-2``, found anywhere in the right-hand sides of ``ac.all_assignments``.
    Each distinct such power gets a fresh symbol from ``ac.subexpression_symbol_generator``,
    and the substitutions are applied via ``ac.new_with_substitutions``.

    Args:
        ac: assignment collection to process
    Returns:
        the collection returned by ``ac.new_with_substitutions``
    """
    # A dict is used as an insertion-ordered set: divisors keep their first-encounter order,
    # so the symbol assigned to each divisor (and thus the generated code) does not depend
    # on hash randomization the way iteration over a ``set`` does.
    divisors = {}

    def search_divisors(term):
        if term.func == sp.Pow and term.exp.is_integer and term.exp.is_number and term.exp < 0:
            divisors[term] = None
        else:
            # Powers that are not divisions themselves are descended into as well, so a
            # division nested in e.g. (1/x + y)**2 is also found.
            for arg in term.args:
                search_divisors(arg)

    for assignment in ac.all_assignments:
        search_divisors(assignment.rhs)

    new_symbol_gen = ac.subexpression_symbol_generator
    substitutions = {divisor: new_symbol for new_symbol, divisor in zip(new_symbol_gen, divisors)}
    # NOTE(review): the positional True presumably requests adding the substitutions as
    # subexpressions - confirm against AssignmentCollection.new_with_substitutions.
    return ac.new_with_substitutions(substitutions, True)