test_simplification_strategy.py 2.68 KB
Newer Older
1
import sympy as sp
Martin Bauer's avatar
Martin Bauer committed
2

3
import pystencils as ps
4
from pystencils import Assignment, AssignmentCollection
Martin Bauer's avatar
Martin Bauer committed
5
6
7
from pystencils.simp import (
    SimplificationStrategy, apply_on_all_subexpressions,
    subexpression_substitution_in_existing_subexpressions)
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46


def test_simplification_strategy():
    """Check a two-step SimplificationStrategy and its reporting helpers.

    The strategy applies ``subexpression_substitution_in_existing_subexpressions``
    followed by ``sp.factor`` on every subexpression. The operation counts of
    the resulting collection are pinned. Afterwards the display/report routines
    are executed so they are at least smoke-tested.
    """
    # Only the symbols that are actually used below are created.
    a, b, c, d = sp.symbols("a b c d")
    s0, s1, s2 = sp.symbols("s_:3")
    a0, a1, a2 = sp.symbols("a_:3")

    # Each subexpression extends the previous one by one more term, so later
    # right-hand sides contain earlier ones as sub-sums.
    subexpressions = [
        Assignment(s0, 2 * a + 2 * b),
        Assignment(s1, 2 * a + 2 * b + 2 * c),
        Assignment(s2, 2 * a + 2 * b + 2 * c + 2 * d),
    ]
    main = [
        Assignment(a0, s0 + s1),
        Assignment(a1, s0 + s2),
        Assignment(a2, s1 + s2),
    ]
    ac = AssignmentCollection(main, subexpressions)

    strategy = SimplificationStrategy()
    strategy.add(subexpression_substitution_in_existing_subexpressions)
    strategy.add(apply_on_all_subexpressions(sp.factor))

    result = strategy(ac)
    # Expected counts after substitution + factoring (pinned regression values).
    assert result.operation_count['adds'] == 7
    assert result.operation_count['muls'] == 5
    assert result.operation_count['divs'] == 0

    # Trigger display routines, such that they are at least executed
    report = strategy.show_intermediate_results(ac, symbols=[s0])
    assert 's_0' in str(report)
    report = strategy.show_intermediate_results(ac)
    assert 's_{1}' in report._repr_html_()

    report = strategy.create_simplification_report(ac)
    assert 'Adds' in str(report)
    assert 'Adds' in report._repr_html_()

    # The strategy's string form should mention the applied sympy function.
    assert 'factor' in str(strategy)
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82


def test_split_inner_loop():
    """Check that the ``split_groups`` simplification hint splits the inner loop.

    Eight component writes to the field ``dst`` and one scalar assignment are
    compiled twice:

    * with a ``split_groups`` hint of four groups, expecting four inner loops
      plus one outer loop;
    * without any hint, expecting a single inner loop plus the outer loop.
    """
    import re  # local import: only this test needs it

    def count_for_loops(code):
        """Return the number of C ``for`` loops in *code*."""
        # Match the ``for`` keyword followed by its opening parenthesis, so that
        # identifiers or other text merely containing "for" are not counted.
        return len(re.findall(r'\bfor\s*\(', code))

    dst = ps.fields('dst(8): double[2D]')
    s = sp.symbols('s_:8')
    x = sp.symbols('x')

    main = [Assignment(dst[0, 0](i), s[i]) for i in range(8)]
    main.append(Assignment(x, sum(s)))

    ac = AssignmentCollection(list(main), [])
    split_groups = [[dst[0, 0](0), dst[0, 0](1)],
                    [dst[0, 0](2), dst[0, 0](3)],
                    [dst[0, 0](4), dst[0, 0](5)],
                    [dst[0, 0](6), dst[0, 0](7), x]]
    ac.simplification_hints['split_groups'] = split_groups
    ast = ps.create_kernel(ac)

    code = ps.get_code_str(ast)
    # we have four inner loops as indicated in split groups (4 elements) plus one outer loop
    assert count_for_loops(code) == 5

    # A fresh collection without the hint must not be split.
    ac = AssignmentCollection(list(main), [])
    ast = ps.create_kernel(ac)

    code = ps.get_code_str(ast)
    # one inner loop and one outer loop
    assert count_for_loops(code) == 2