From fd4d1bc0b49ef03e94afae67dd11b66c631940fe Mon Sep 17 00:00:00 2001 From: markus holzer <markus.holzer@fau.de> Date: Fri, 7 Aug 2020 13:19:56 +0200 Subject: [PATCH] Added test cases for pystencils simplifications --- pystencils/simp/__init__.py | 4 +- pystencils/simp/simplifications.py | 6 +- pystencils_tests/test_simplifications.py | 97 ++++++++++++++++++++++++ 3 files changed, 102 insertions(+), 5 deletions(-) create mode 100644 pystencils_tests/test_simplifications.py diff --git a/pystencils/simp/__init__.py b/pystencils/simp/__init__.py index ab0d608fb..dadaa7911 100644 --- a/pystencils/simp/__init__.py +++ b/pystencils/simp/__init__.py @@ -1,7 +1,7 @@ from .assignment_collection import AssignmentCollection from .simplifications import ( add_subexpressions_for_divisions, add_subexpressions_for_field_reads, - apply_on_all_subexpressions, apply_to_all_assignments, + add_subexpressions_for_sums, apply_on_all_subexpressions, apply_to_all_assignments, subexpression_substitution_in_existing_subexpressions, subexpression_substitution_in_main_assignments, sympy_cse, sympy_cse_on_assignment_list) from .simplificationstrategy import SimplificationStrategy @@ -10,4 +10,4 @@ __all__ = ['AssignmentCollection', 'SimplificationStrategy', 'sympy_cse', 'sympy_cse_on_assignment_list', 'apply_to_all_assignments', 'apply_on_all_subexpressions', 'subexpression_substitution_in_existing_subexpressions', 'subexpression_substitution_in_main_assignments', 'add_subexpressions_for_divisions', - 'add_subexpressions_for_field_reads'] + 'add_subexpressions_for_sums', 'add_subexpressions_for_field_reads'] diff --git a/pystencils/simp/simplifications.py b/pystencils/simp/simplifications.py index 5d9b819d5..234b7a373 100644 --- a/pystencils/simp/simplifications.py +++ b/pystencils/simp/simplifications.py @@ -18,7 +18,7 @@ def sort_assignments_topologically(assignments: Sequence[Union[Assignment, Node] elif isinstance(e1, Node): symbols = e1.symbols_defined else: - raise NotImplementedError("Cannot sort topologically. Object of type " + type(e1) + " cannot be handled.") + raise NotImplementedError(f"Cannot sort topologically. Object of type {type(e1)} cannot be handled.") for lhs in symbols: for c2, e2 in enumerate(assignments): @@ -112,14 +112,14 @@ def add_subexpressions_for_sums(ac): addends = [] def contains_sum(term): - if term.func == sp.add.Add: + if term.func == sp.Add: return True if term.is_Atom: return False return any([contains_sum(a) for a in term.args]) def search_addends(term): - if term.func == sp.add.Add: + if term.func == sp.Add: if all([not contains_sum(a) for a in term.args]): addends.extend(term.args) for a in term.args: diff --git a/pystencils_tests/test_simplifications.py b/pystencils_tests/test_simplifications.py new file mode 100644 index 000000000..b9f9cc8a1 --- /dev/null +++ b/pystencils_tests/test_simplifications.py @@ -0,0 +1,97 @@ +import sympy as sp + +from pystencils.simp import subexpression_substitution_in_main_assignments +from pystencils.simp import add_subexpressions_for_divisions +from pystencils.simp import add_subexpressions_for_sums +from pystencils.simp import add_subexpressions_for_field_reads +from pystencils import Assignment, AssignmentCollection, fields + +a, b, c, d, x, y, z = sp.symbols("a b c d x y z") +s0, s1, s2, s3 = sp.symbols("s_:4") +f = sp.symbols("f_:9") + + +def test_subexpression_substitution_in_main_assignments(): + subexpressions = [ + Assignment(s0, 2 * a + 2 * b), + Assignment(s1, 2 * a + 2 * b + 2 * c), + Assignment(s2, 2 * a + 2 * b + 2 * c + 2 * d), + Assignment(s3, 2 * a + 2 * b * c), + Assignment(x, s1 + s2 + s0 + s3) + ] + main = [ + Assignment(f[0], s1 + s2 + s0 + s3), + Assignment(f[1], s1 + s2 + s0 + s3), + Assignment(f[2], s1 + s2 + s0 + s3), + Assignment(f[3], s1 + s2 + s0 + s3), + Assignment(f[4], s1 + s2 + s0 + s3) + ] + ac = AssignmentCollection(main, subexpressions) + ac = subexpression_substitution_in_main_assignments(ac) + for i in range(0, len(ac.main_assignments)): + assert ac.main_assignments[i].rhs == x + + +def test_add_subexpressions_for_divisions(): + subexpressions = [ + Assignment(s0, 2 / a + 2 / b), + Assignment(s1, 2 / a + 2 / b + 2 / c), + Assignment(s2, 2 / a + 2 / b + 2 / c + 2 / d), + Assignment(s3, 2 / a + 2 / b / c), + Assignment(x, s1 + s2 + s0 + s3) + ] + main = [ + Assignment(f[0], s1 + s2 + s0 + s3) + ] + ac = AssignmentCollection(main, subexpressions) + divs_before_optimisation = ac.operation_count["divs"] + ac = add_subexpressions_for_divisions(ac) + divs_after_optimisation = ac.operation_count["divs"] + assert divs_before_optimisation - divs_after_optimisation == 8 + rhs = [] + for i in range(len(ac.subexpressions)): + rhs.append(ac.subexpressions[i].rhs) + + assert 1/a in rhs + assert 1/b in rhs + assert 1/c in rhs + assert 1/d in rhs + + +def test_add_subexpressions_for_sums(): + subexpressions = [ + Assignment(s0, a + b + c + d), + Assignment(s1, 3 * a * sp.sqrt(x) + 4 * b + c), + Assignment(s2, 3 * a * sp.sqrt(x) + 4 * b + c), + Assignment(s3, 3 * a * sp.sqrt(x) + 4 * b + c) + ] + main = [ + Assignment(f[0], s1 + s2 + s0 + s3) + ] + ac = AssignmentCollection(main, subexpressions) + ops_before_optimisation = ac.operation_count + ac = add_subexpressions_for_sums(ac) + ops_after_optimisation = ac.operation_count + assert ops_after_optimisation["adds"] == ops_before_optimisation["adds"] + assert ops_after_optimisation["muls"] < ops_before_optimisation["muls"] + assert ops_after_optimisation["sqrts"] < ops_before_optimisation["sqrts"] + + rhs = [] + for i in range(len(ac.subexpressions)): + rhs.append(ac.subexpressions[i].rhs) + + assert a + b + c + d in rhs + assert 3 * a * sp.sqrt(x) in rhs + + +def test_add_subexpressions_for_field_reads(): + s, v = fields("s(5), v(5): double[2D]") + subexpressions = [] + main = [ + Assignment(s[0, 0](0), 3 * v[0, 0](0)), + Assignment(s[0, 0](1), 10 * v[0, 0](1)) + ] + ac = AssignmentCollection(main, subexpressions) + assert len(ac.subexpressions) == 0 + ac = add_subexpressions_for_field_reads(ac) + assert len(ac.subexpressions) == 2 -- GitLab