From fd4d1bc0b49ef03e94afae67dd11b66c631940fe Mon Sep 17 00:00:00 2001
From: markus holzer <markus.holzer@fau.de>
Date: Fri, 7 Aug 2020 13:19:56 +0200
Subject: [PATCH] Added test cases for pystencils simplifications

---
 pystencils/simp/__init__.py              |  4 +-
 pystencils/simp/simplifications.py       |  6 +-
 pystencils_tests/test_simplifications.py | 97 ++++++++++++++++++++++++
 3 files changed, 102 insertions(+), 5 deletions(-)
 create mode 100644 pystencils_tests/test_simplifications.py

diff --git a/pystencils/simp/__init__.py b/pystencils/simp/__init__.py
index ab0d608fb..dadaa7911 100644
--- a/pystencils/simp/__init__.py
+++ b/pystencils/simp/__init__.py
@@ -1,7 +1,7 @@
 from .assignment_collection import AssignmentCollection
 from .simplifications import (
     add_subexpressions_for_divisions, add_subexpressions_for_field_reads,
-    apply_on_all_subexpressions, apply_to_all_assignments,
+    add_subexpressions_for_sums, apply_on_all_subexpressions, apply_to_all_assignments,
     subexpression_substitution_in_existing_subexpressions,
     subexpression_substitution_in_main_assignments, sympy_cse, sympy_cse_on_assignment_list)
 from .simplificationstrategy import SimplificationStrategy
@@ -10,4 +10,4 @@ __all__ = ['AssignmentCollection', 'SimplificationStrategy',
            'sympy_cse', 'sympy_cse_on_assignment_list', 'apply_to_all_assignments',
            'apply_on_all_subexpressions', 'subexpression_substitution_in_existing_subexpressions',
            'subexpression_substitution_in_main_assignments', 'add_subexpressions_for_divisions',
-           'add_subexpressions_for_field_reads']
+           'add_subexpressions_for_sums', 'add_subexpressions_for_field_reads']
diff --git a/pystencils/simp/simplifications.py b/pystencils/simp/simplifications.py
index 5d9b819d5..234b7a373 100644
--- a/pystencils/simp/simplifications.py
+++ b/pystencils/simp/simplifications.py
@@ -18,7 +18,7 @@ def sort_assignments_topologically(assignments: Sequence[Union[Assignment, Node]
         elif isinstance(e1, Node):
             symbols = e1.symbols_defined
         else:
-            raise NotImplementedError("Cannot sort topologically. Object of type " + type(e1) + " cannot be handled.")
+            raise NotImplementedError(f"Cannot sort topologically. Object of type {type(e1)} cannot be handled.")
 
         for lhs in symbols:
             for c2, e2 in enumerate(assignments):
@@ -112,14 +112,14 @@ def add_subexpressions_for_sums(ac):
     addends = []
 
     def contains_sum(term):
-        if term.func == sp.add.Add:
+        if term.func == sp.Add:
             return True
         if term.is_Atom:
             return False
         return any([contains_sum(a) for a in term.args])
 
     def search_addends(term):
-        if term.func == sp.add.Add:
+        if term.func == sp.Add:
             if all([not contains_sum(a) for a in term.args]):
                 addends.extend(term.args)
         for a in term.args:
diff --git a/pystencils_tests/test_simplifications.py b/pystencils_tests/test_simplifications.py
new file mode 100644
index 000000000..b9f9cc8a1
--- /dev/null
+++ b/pystencils_tests/test_simplifications.py
@@ -0,0 +1,97 @@
+import sympy as sp
+
+from pystencils.simp import subexpression_substitution_in_main_assignments
+from pystencils.simp import add_subexpressions_for_divisions
+from pystencils.simp import add_subexpressions_for_sums
+from pystencils.simp import add_subexpressions_for_field_reads
+from pystencils import Assignment, AssignmentCollection, fields
+
+a, b, c, d, x, y, z = sp.symbols("a b c d x y z")
+s0, s1, s2, s3 = sp.symbols("s_:4")
+f = sp.symbols("f_:9")
+
+
+def test_subexpression_substitution_in_main_assignments():
+    subexpressions = [
+        Assignment(s0, 2 * a + 2 * b),
+        Assignment(s1, 2 * a + 2 * b + 2 * c),
+        Assignment(s2, 2 * a + 2 * b + 2 * c + 2 * d),
+        Assignment(s3, 2 * a + 2 * b * c),
+        Assignment(x, s1 + s2 + s0 + s3)
+    ]
+    main = [
+        Assignment(f[0], s1 + s2 + s0 + s3),
+        Assignment(f[1], s1 + s2 + s0 + s3),
+        Assignment(f[2], s1 + s2 + s0 + s3),
+        Assignment(f[3], s1 + s2 + s0 + s3),
+        Assignment(f[4], s1 + s2 + s0 + s3)
+    ]
+    ac = AssignmentCollection(main, subexpressions)
+    ac = subexpression_substitution_in_main_assignments(ac)
+    for i in range(0, len(ac.main_assignments)):
+        assert ac.main_assignments[i].rhs == x
+
+
+def test_add_subexpressions_for_divisions():
+    subexpressions = [
+        Assignment(s0, 2 / a + 2 / b),
+        Assignment(s1, 2 / a + 2 / b + 2 / c),
+        Assignment(s2, 2 / a + 2 / b + 2 / c + 2 / d),
+        Assignment(s3, 2 / a + 2 / b / c),
+        Assignment(x, s1 + s2 + s0 + s3)
+    ]
+    main = [
+        Assignment(f[0], s1 + s2 + s0 + s3)
+    ]
+    ac = AssignmentCollection(main, subexpressions)
+    divs_before_optimisation = ac.operation_count["divs"]
+    ac = add_subexpressions_for_divisions(ac)
+    divs_after_optimisation = ac.operation_count["divs"]
+    assert divs_before_optimisation - divs_after_optimisation == 8
+    rhs = []
+    for i in range(len(ac.subexpressions)):
+        rhs.append(ac.subexpressions[i].rhs)
+
+    assert 1/a in rhs
+    assert 1/b in rhs
+    assert 1/c in rhs
+    assert 1/d in rhs
+
+
+def test_add_subexpressions_for_sums():
+    subexpressions = [
+        Assignment(s0, a + b + c + d),
+        Assignment(s1, 3 * a * sp.sqrt(x) + 4 * b + c),
+        Assignment(s2, 3 * a * sp.sqrt(x) + 4 * b + c),
+        Assignment(s3, 3 * a * sp.sqrt(x) + 4 * b + c)
+    ]
+    main = [
+        Assignment(f[0], s1 + s2 + s0 + s3)
+    ]
+    ac = AssignmentCollection(main, subexpressions)
+    ops_before_optimisation = ac.operation_count
+    ac = add_subexpressions_for_sums(ac)
+    ops_after_optimisation = ac.operation_count
+    assert ops_after_optimisation["adds"] == ops_before_optimisation["adds"]
+    assert ops_after_optimisation["muls"] < ops_before_optimisation["muls"]
+    assert ops_after_optimisation["sqrts"] < ops_before_optimisation["sqrts"]
+
+    rhs = []
+    for i in range(len(ac.subexpressions)):
+        rhs.append(ac.subexpressions[i].rhs)
+
+    assert a + b + c + d in rhs
+    assert 3 * a * sp.sqrt(x) in rhs
+
+
+def test_add_subexpressions_for_field_reads():
+    s, v = fields("s(5), v(5): double[2D]")
+    subexpressions = []
+    main = [
+        Assignment(s[0, 0](0), 3 * v[0, 0](0)),
+        Assignment(s[0, 0](1), 10 * v[0, 0](1))
+    ]
+    ac = AssignmentCollection(main, subexpressions)
+    assert len(ac.subexpressions) == 0
+    ac = add_subexpressions_for_field_reads(ac)
+    assert len(ac.subexpressions) == 2
-- 
GitLab