# test_simplifications.py
# Tests for the simplification passes in pystencils.simp.
import sympy as sp

from pystencils.simp import subexpression_substitution_in_main_assignments
from pystencils.simp import add_subexpressions_for_divisions
from pystencils.simp import add_subexpressions_for_sums
from pystencils.simp import add_subexpressions_for_field_reads
from pystencils import Assignment, AssignmentCollection, fields

# f_0 ... f_8: left-hand sides for main assignments ("f_:9" is sympy range syntax).
f = sp.symbols("f_:9")

# s_0 ... s_3: left-hand sides for subexpressions.
(s0, s1, s2, s3) = sp.symbols("s_:4")

# Free scalar symbols used on the right-hand sides of the test expressions.
(a, b, c, d, x, y, z) = sp.symbols("a b c d x y z")


def test_subexpression_substitution_in_main_assignments():
    """Main-assignment RHS matching a subexpression RHS is replaced by its symbol.

    The subexpression ``x`` has exactly the same right-hand side as every main
    assignment, so after the pass each main assignment should simply read ``x``.
    """
    subexpressions = [
        Assignment(s0, 2 * a + 2 * b),
        Assignment(s1, 2 * a + 2 * b + 2 * c),
        Assignment(s2, 2 * a + 2 * b + 2 * c + 2 * d),
        Assignment(s3, 2 * a + 2 * b * c),
        Assignment(x, s1 + s2 + s0 + s3),
    ]
    # Five main assignments f_0 ... f_4, all with the same RHS as ``x``.
    main = [Assignment(f[i], s1 + s2 + s0 + s3) for i in range(5)]
    ac = AssignmentCollection(main, subexpressions)
    ac = subexpression_substitution_in_main_assignments(ac)

    # Guard against a vacuous pass: the loop below asserts nothing if the
    # main assignments were dropped.
    assert len(ac.main_assignments) == len(main)
    for assignment in ac.main_assignments:
        assert assignment.rhs == x


def test_add_subexpressions_for_divisions():
    """Reciprocals are factored out into subexpressions, reducing division count.

    After the pass, the total number of divisions must drop by exactly 8.
    Each reciprocal ``1/a``, ``1/b``, ``1/c`` and ``1/d`` must appear as the
    right-hand side of some subexpression.
    """
    subexpressions = [
        Assignment(s0, 2 / a + 2 / b),
        Assignment(s1, 2 / a + 2 / b + 2 / c),
        Assignment(s2, 2 / a + 2 / b + 2 / c + 2 / d),
        Assignment(s3, 2 / a + 2 / b / c),
        Assignment(x, s1 + s2 + s0 + s3),
    ]
    main = [
        Assignment(f[0], s1 + s2 + s0 + s3),
    ]
    ac = AssignmentCollection(main, subexpressions)
    divs_before_optimisation = ac.operation_count["divs"]
    ac = add_subexpressions_for_divisions(ac)
    divs_after_optimisation = ac.operation_count["divs"]
    assert divs_before_optimisation - divs_after_optimisation == 8

    # Collect every subexpression RHS once; set membership replaces the
    # manual index loop.
    rhs = {subexpression.rhs for subexpression in ac.subexpressions}
    for symbol in (a, b, c, d):
        assert 1 / symbol in rhs


def test_add_subexpressions_for_sums():
    """Repeated summands are extracted into subexpressions.

    The add count must stay unchanged, while the multiplication and
    square-root counts must both decrease. ``a + b + c + d`` and
    ``3 * a * sqrt(x)`` must each appear as the RHS of some subexpression.
    """
    subexpressions = [
        Assignment(s0, a + b + c + d),
        # s1, s2 and s3 are deliberately identical so the shared product
        # 3 * a * sqrt(x) is worth extracting.
        Assignment(s1, 3 * a * sp.sqrt(x) + 4 * b + c),
        Assignment(s2, 3 * a * sp.sqrt(x) + 4 * b + c),
        Assignment(s3, 3 * a * sp.sqrt(x) + 4 * b + c),
    ]
    main = [
        Assignment(f[0], s1 + s2 + s0 + s3),
    ]
    ac = AssignmentCollection(main, subexpressions)
    ops_before_optimisation = ac.operation_count
    ac = add_subexpressions_for_sums(ac)
    ops_after_optimisation = ac.operation_count
    assert ops_after_optimisation["adds"] == ops_before_optimisation["adds"]
    assert ops_after_optimisation["muls"] < ops_before_optimisation["muls"]
    assert ops_after_optimisation["sqrts"] < ops_before_optimisation["sqrts"]

    rhs = {subexpression.rhs for subexpression in ac.subexpressions}
    assert a + b + c + d in rhs
    assert 3 * a * sp.sqrt(x) in rhs


def test_add_subexpressions_for_field_reads():
    """Field reads in the main assignments are turned into subexpressions.

    The collection starts with no subexpressions. The main assignments read
    two field components, ``v[0, 0](0)`` and ``v[0, 0](1)``. After the pass,
    exactly two subexpressions must exist.
    """
    dst, src = fields("s(5), v(5): double[2D]")
    collection = AssignmentCollection(
        [
            Assignment(dst[0, 0](0), 3 * src[0, 0](0)),
            Assignment(dst[0, 0](1), 10 * src[0, 0](1)),
        ],
        [],
    )
    assert len(collection.subexpressions) == 0

    collection = add_subexpressions_for_field_reads(collection)
    assert len(collection.subexpressions) == 2