Skip to content
Snippets Groups Projects
test_simplifications.py 6.95 KiB
Newer Older
from sys import version_info as vs
Markus Holzer's avatar
Markus Holzer committed

import pystencils.config
import sympy as sp
import pystencils as ps
from pystencils import Assignment, AssignmentCollection, fields
from pystencils.simp import subexpression_substitution_in_main_assignments
from pystencils.simp import add_subexpressions_for_divisions
from pystencils.simp import add_subexpressions_for_sums
from pystencils.simp import add_subexpressions_for_field_reads
from pystencils.simp.simplifications import add_subexpressions_for_constants
from pystencils.typing import BasicType, TypedSymbol

a, b, c, d, x, y, z = sp.symbols("a b c d x y z")
s0, s1, s2, s3 = sp.symbols("s_:4")
f = sp.symbols("f_:9")


def test_subexpression_substitution_in_main_assignments():
    subexpressions = [
        Assignment(s0, 2 * a + 2 * b),
        Assignment(s1, 2 * a + 2 * b + 2 * c),
        Assignment(s2, 2 * a + 2 * b + 2 * c + 2 * d),
        Assignment(s3, 2 * a + 2 * b * c),
        Assignment(x, s1 + s2 + s0 + s3)
    ]
    main = [
        Assignment(f[0], s1 + s2 + s0 + s3),
        Assignment(f[1], s1 + s2 + s0 + s3),
        Assignment(f[2], s1 + s2 + s0 + s3),
        Assignment(f[3], s1 + s2 + s0 + s3),
        Assignment(f[4], s1 + s2 + s0 + s3)
    ]
    ac = AssignmentCollection(main, subexpressions)
    ac = subexpression_substitution_in_main_assignments(ac)
    for i in range(0, len(ac.main_assignments)):
        assert ac.main_assignments[i].rhs == x


def test_add_subexpressions_for_divisions():
    subexpressions = [
        Assignment(s0, 2 / a + 2 / b),
        Assignment(s1, 2 / a + 2 / b + 2 / c),
        Assignment(s2, 2 / a + 2 / b + 2 / c + 2 / d),
        Assignment(s3, 2 / a + 2 / b / c),
        Assignment(x, s1 + s2 + s0 + s3)
    ]
    main = [
        Assignment(f[0], s1 + s2 + s0 + s3)
    ]
    ac = AssignmentCollection(main, subexpressions)
    divs_before_optimisation = ac.operation_count["divs"]
    ac = add_subexpressions_for_divisions(ac)
    divs_after_optimisation = ac.operation_count["divs"]
    assert divs_before_optimisation - divs_after_optimisation == 8
    rhs = []
    for i in range(len(ac.subexpressions)):
        rhs.append(ac.subexpressions[i].rhs)

    assert 1/a in rhs
    assert 1/b in rhs
    assert 1/c in rhs
    assert 1/d in rhs


def test_add_subexpressions_for_constants():
    half = sp.Rational(1,2)
    sqrt_2 = sp.sqrt(2)
    main = [
        Assignment(f[0], half * a + half * b + half * c),
        Assignment(f[1], - half * a - half * b),
        Assignment(f[2], a * sqrt_2 - b * sqrt_2),
        Assignment(f[3], a**2 + b**2)
    ]
    ac = AssignmentCollection(main)
    ac = add_subexpressions_for_constants(ac)
    
    assert len(ac.subexpressions) == 2
    
    half_subexp = None
    sqrt_subexp = None

    for asm in ac.subexpressions:
        if asm.rhs == half:
            half_subexp = asm.lhs
        elif asm.rhs == sqrt_2:
            sqrt_subexp = asm.lhs
        else:
            pytest.fail(f"An unexpected subexpression was encountered: {asm}")
            
    assert half_subexp is not None
    assert sqrt_subexp is not None
    
    for asm in ac.main_assignments[:3]:
        assert isinstance(asm.rhs, sp.Mul)

    assert any(arg == half_subexp for arg in ac.main_assignments[0].rhs.args)
    assert any(arg == half_subexp for arg in ac.main_assignments[1].rhs.args)
    assert any(arg == sqrt_subexp for arg in ac.main_assignments[2].rhs.args)

    #   Do not replace exponents!
    assert ac.main_assignments[3].rhs == a**2 + b**2


def test_add_subexpressions_for_sums():
    subexpressions = [
        Assignment(s0, a + b + c + d),
        Assignment(s1, 3 * a * sp.sqrt(x) + 4 * b + c),
        Assignment(s2, 3 * a * sp.sqrt(x) + 4 * b + c),
        Assignment(s3, 3 * a * sp.sqrt(x) + 4 * b + c)
    ]
    main = [
        Assignment(f[0], s1 + s2 + s0 + s3)
    ]
    ac = AssignmentCollection(main, subexpressions)
    ops_before_optimisation = ac.operation_count
    ac = add_subexpressions_for_sums(ac)
    ops_after_optimisation = ac.operation_count
    assert ops_after_optimisation["adds"] == ops_before_optimisation["adds"]
    assert ops_after_optimisation["muls"] < ops_before_optimisation["muls"]
    assert ops_after_optimisation["sqrts"] < ops_before_optimisation["sqrts"]

    rhs = []
    for i in range(len(ac.subexpressions)):
        rhs.append(ac.subexpressions[i].rhs)

    assert a + b + c + d in rhs
    assert 3 * a * sp.sqrt(x) in rhs


def test_add_subexpressions_for_field_reads():
    s, v = fields("s(5), v(5): double[2D]")
    subexpressions = []

    main = [Assignment(s[0, 0](0), 3 * v[0, 0](0)),
            Assignment(s[0, 0](1), 10 * v[0, 0](1))]

    ac = AssignmentCollection(main, subexpressions)
    assert len(ac.subexpressions) == 0
    ac2 = add_subexpressions_for_field_reads(ac)
    assert len(ac2.subexpressions) == 2
    ac3 = add_subexpressions_for_field_reads(ac, data_type="float32")
    assert len(ac3.subexpressions) == 2
    assert isinstance(ac3.subexpressions[0].lhs, TypedSymbol)
    assert ac3.subexpressions[0].lhs.dtype == BasicType("float32")
    # added check for early out of add_subexpressions_for_field_reads is no fields appear on the rhs (See #92)
    main = [Assignment(s[0, 0](0), 3.0),
            Assignment(s[0, 0](1), 4.0)]

    ac4 = AssignmentCollection(main, subexpressions)
    assert len(ac4.subexpressions) == 0
    ac5 = add_subexpressions_for_field_reads(ac4)
    assert ac5 is not None
    assert ac4 is ac5


@pytest.mark.parametrize('target', (ps.Target.CPU, ps.Target.GPU))
@pytest.mark.parametrize('dtype', ('float32', 'float64'))
@pytest.mark.skipif((vs.major, vs.minor, vs.micro) == (3, 8, 2), reason="does not work on python 3.8.2 for some reason")
def test_sympy_optimizations(target, dtype):
    if target == ps.Target.GPU:
        pytest.importorskip("cupy")
    src, dst = ps.fields(f'src, dst:  {dtype}[2d]')

    assignments = ps.AssignmentCollection({
        src[0, 0]: 1.0 * (sp.exp(dst[0, 0]) - 1)
    })

    config = pystencils.config.CreateKernelConfig(target=target, default_number_float=dtype)
    ast = ps.create_kernel(assignments, config=config)

    ps.show_code(ast)

    code = ps.get_code_str(ast)
    if dtype == 'float32':
        assert 'expf(' in code
    elif dtype == 'float64':
        assert 'exp(' in code


@pytest.mark.parametrize('target', (ps.Target.CPU, ps.Target.GPU))
@pytest.mark.parametrize('simplification', (True, False))
@pytest.mark.skipif((vs.major, vs.minor, vs.micro) == (3, 8, 2), reason="does not work on python 3.8.2 for some reason")
def test_evaluate_constant_terms(target, simplification):
    if target == ps.Target.GPU:
        pytest.importorskip("cupy")
    src, dst = ps.fields('src, dst:  float32[2d]')

    # cos of a number will always be simplified
    assignments = ps.AssignmentCollection({
        src[0, 0]: -sp.cos(1) + dst[0, 0]
    })

Markus Holzer's avatar
Markus Holzer committed
    config = pystencils.config.CreateKernelConfig(target=target, default_assignment_simplifications=simplification)
    ast = ps.create_kernel(assignments, config=config)
    code = ps.get_code_str(ast)
    assert 'cos(' not in code