import sympy as sp
from typing import Callable, List
from pystencils import Assignment, AssignmentCollection
from pystencils.sympyextensions import subs_additive


def sympy_cse_on_assignment_list(assignments: List[Assignment]) -> List[Assignment]:
    """Run common-subexpression elimination on a flat list of assignments.

    The assignments are wrapped into an ``AssignmentCollection`` together with an empty list,
    :func:`sympy_cse` is applied to that collection, and the ``all_assignments`` of the
    resulting collection are returned.
    """
    collection = AssignmentCollection(assignments, [])
    simplified = sympy_cse(collection)
    return simplified.all_assignments


def sympy_cse(ac: AssignmentCollection, **kwargs) -> AssignmentCollection:
    """Searches for common subexpressions inside the assignment collection.

    The search is done in both the existing subexpressions and the main assignments.
    It uses sympy's subexpression detection (:func:`sympy.cse`) to do this and returns a new
    assignment collection with the additional subexpressions found.

    Args:
        ac: assignment collection to simplify; new subexpression symbols are drawn from
            ``ac.subexpression_symbol_generator``
        **kwargs: additional keyword arguments forwarded to :func:`sympy.cse`
            (e.g. ``optimizations`` or ``order``); ``symbols`` must not be passed

    Returns:
        result of ``ac.copy`` with the rewritten main assignments and the combined,
        topologically sorted subexpressions
    """
    # Import from the defining module: ``sp.cse_main`` is only reachable as an attribute of the
    # top-level ``sympy`` package in older SymPy releases that star-imported their submodules.
    from sympy.simplify.cse_main import reps_toposort

    symbol_gen = ac.subexpression_symbol_generator
    num_subexpressions = len(ac.subexpressions)
    replacements, new_eq = sp.cse(ac.subexpressions + ac.main_assignments,
                                  symbols=symbol_gen, **kwargs)
    replacement_eqs = [Assignment(*r) for r in replacements]

    # sp.cse returns the rewritten inputs in the same order they were passed in
    modified_subexpressions = new_eq[:num_subexpressions]
    modified_update_equations = new_eq[num_subexpressions:]

    # The cse-generated replacements and the rewritten original subexpressions can reference each
    # other, so they are ordered with reps_toposort so that definitions precede their uses.
    new_subexpressions = replacement_eqs + modified_subexpressions
    sorted_pairs = reps_toposort([[e.lhs, e.rhs] for e in new_subexpressions])
    new_subexpressions = [Assignment(lhs, rhs) for lhs, rhs in sorted_pairs]

    return ac.copy(modified_update_equations, new_subexpressions)


def apply_to_all_assignments(assignment_collection: AssignmentCollection,
                             operation: Callable[[sp.Expr], sp.Expr]) -> AssignmentCollection:
    """Applies ``operation`` to the right-hand side of every main assignment in the collection.

    Despite the name, subexpressions are not passed through ``operation``; use
    ``apply_on_all_subexpressions`` for those.

    Args:
        assignment_collection: collection whose main assignments are transformed
        operation: function mapping a sympy expression to a new expression (e.g. ``sp.expand``)

    Returns:
        result of ``assignment_collection.copy`` called with the transformed main assignments
    """
    result = [Assignment(eq.lhs, operation(eq.rhs)) for eq in assignment_collection.main_assignments]
    return assignment_collection.copy(result)


def apply_on_all_subexpressions(ac: AssignmentCollection,
                                operation: Callable[[sp.Expr], sp.Expr]) -> AssignmentCollection:
    """Apply ``operation`` to the right-hand side of every subexpression.

    The main assignments are handed to ``ac.copy`` unchanged; only the subexpressions are rebuilt.

    Args:
        ac: assignment collection whose subexpressions are transformed
        operation: function mapping a sympy expression to a new expression

    Returns:
        result of ``ac.copy`` with the original main assignments and the transformed subexpressions
    """
    transformed = []
    for sub in ac.subexpressions:
        new_rhs = operation(sub.rhs)
        transformed.append(Assignment(sub.lhs, new_rhs))
    return ac.copy(ac.main_assignments, transformed)


def subexpression_substitution_in_existing_subexpressions(ac: AssignmentCollection) -> AssignmentCollection:
    """Goes through the subexpressions list and replaces the term in the following subexpressions.

    For each subexpression, the right-hand sides of all subexpressions *preceding* it in
    ``ac.subexpressions`` are searched for in its right-hand side and replaced by the
    corresponding left-hand-side symbols. The preceding subexpressions are used in their
    original (not yet substituted) form.

    Args:
        ac: assignment collection whose subexpressions are rewritten

    Returns:
        result of ``ac.copy`` with the unchanged main assignments and the rewritten subexpressions
    """
    result = []
    for outer_ctr, s in enumerate(ac.subexpressions):
        new_rhs = s.rhs
        # Only earlier subexpressions are substituted, presumably because a subexpression may
        # only reference symbols defined before it.
        for sub_expr in ac.subexpressions[:outer_ctr]:
            # first the pystencils additive substitution (full match required), then a plain
            # sympy ``subs`` of the exact right-hand side
            new_rhs = subs_additive(new_rhs, sub_expr.lhs, sub_expr.rhs, required_match_replacement=1.0)
            new_rhs = new_rhs.subs(sub_expr.rhs, sub_expr.lhs)
        result.append(Assignment(s.lhs, new_rhs))
    return ac.copy(ac.main_assignments, result)


def subexpression_substitution_in_main_assignments(ac: AssignmentCollection) -> AssignmentCollection:
    """Replace occurrences of existing subexpressions inside the main assignments.

    For every main assignment, each subexpression's right-hand side is substituted by its
    left-hand-side symbol (via ``subs_additive`` with a required match of 1.0), in the order
    the subexpressions appear in ``ac.subexpressions``.

    Args:
        ac: assignment collection whose main assignments are rewritten

    Returns:
        result of ``ac.copy`` called with the rewritten main assignments
    """
    def substitute_all(expr):
        # apply every subexpression in turn, feeding each result into the next substitution
        for sub_expr in ac.subexpressions:
            expr = subs_additive(expr, sub_expr.lhs, sub_expr.rhs, required_match_replacement=1.0)
        return expr

    rewritten = [Assignment(assignment.lhs, substitute_all(assignment.rhs))
                 for assignment in ac.main_assignments]
    return ac.copy(rewritten)


def add_subexpressions_for_divisions(ac: AssignmentCollection) -> AssignmentCollection:
    r"""Introduces subexpressions for all divisions which have no constant in the denominator.

    For example :math:`\frac{1}{x}` is replaced, :math:`\frac{1}{3}` is not replaced.

    Every power with a negative integer exponent (e.g. ``x**-1`` or ``(a + b)**-2``) occurring
    anywhere in the right-hand side of an assignment, including inside other powers, is replaced
    by a new symbol drawn from ``ac.subexpression_symbol_generator``. Numeric constants such as
    ``1/3`` are sympy rationals rather than powers and are therefore left untouched.

    Args:
        ac: assignment collection to transform

    Returns:
        result of ``ac.new_with_substitutions`` for the divisor -> new symbol mapping
    """
    # A dict is used as an insertion-ordered set. Iterating a plain ``set`` of sympy expressions
    # follows their hash values, which can change between interpreter runs (string hash
    # randomization), making the divisor -> symbol assignment non-reproducible.
    divisors = {}

    def search_divisors(term):
        if term.func == sp.Pow and term.exp.is_integer and term.exp.is_number and term.exp < 0:
            divisors[term] = None
            # the base of a collected divisor is not searched any further
            return
        # Recurse into every other node, including powers with non-negative or non-integer
        # exponents, so divisions nested e.g. in ``(1 + 1/x)**2`` or ``sqrt(y + 1/x)`` are found.
        for arg in term.args:
            search_divisors(arg)

    for eq in ac.all_assignments:
        search_divisors(eq.rhs)

    new_symbol_gen = ac.subexpression_symbol_generator
    # ``divisors`` comes first in zip: zip stops as soon as its first iterable is exhausted,
    # so no symbol is drawn from the generator without being used.
    substitutions = {divisor: new_symbol for divisor, new_symbol in zip(divisors, new_symbol_gen)}
    # NOTE(review): the positional ``True`` presumably requests that the substitutions be added
    # as new subexpressions -- confirm against AssignmentCollection.new_with_substitutions.
    return ac.new_with_substitutions(substitutions, True)