From 2f213e10e2017d16c22bd3b4ed0154e9dc939d60 Mon Sep 17 00:00:00 2001
From: Martin Bauer <martin.bauer@fau.de>
Date: Wed, 11 Apr 2018 11:13:49 +0200
Subject: [PATCH] Tests for simplifications + postprocessing + small fixes

---
 .../assignment_collection.py                  |  2 +-
 assignment_collection/simplifications.py      | 44 +++++++++++--------
 .../simplificationstrategy.py                 |  2 +-
 test_simplification_strategy.py               | 43 ++++++++++++++++++
 4 files changed, 70 insertions(+), 21 deletions(-)
 create mode 100644 test_simplification_strategy.py

diff --git a/assignment_collection/assignment_collection.py b/assignment_collection/assignment_collection.py
index 01e7d5a03..9b7e9948b 100644
--- a/assignment_collection/assignment_collection.py
+++ b/assignment_collection/assignment_collection.py
@@ -329,7 +329,7 @@ class AssignmentCollection:
             result += f"\t{eq}\n"
         result += "Main Assignments:\n"
         for eq in self.main_assignments:
-            result += f"{eq}\n"
+            result += f"\t{eq}\n"
         return result
 
 
diff --git a/assignment_collection/simplifications.py b/assignment_collection/simplifications.py
index a635e8b33..6e7173f13 100644
--- a/assignment_collection/simplifications.py
+++ b/assignment_collection/simplifications.py
@@ -4,8 +4,10 @@ from pystencils.assignment import Assignment
 from pystencils.assignment_collection.assignment_collection import AssignmentCollection
 from pystencils.sympyextensions import subs_additive
 
+AC = AssignmentCollection
 
-def sympy_cse(ac: AssignmentCollection) -> AssignmentCollection:
+
+def sympy_cse(ac: AC) -> AC:
     """Searches for common subexpressions inside the equation collection.
 
     Searches is done in both the existing subexpressions as well as the assignments themselves.
@@ -29,25 +31,11 @@ def sympy_cse(ac: AssignmentCollection) -> AssignmentCollection:
 
 def sympy_cse_on_assignment_list(assignments: List[Assignment]) -> List[Assignment]:
     """Extracts common subexpressions from a list of assignments."""
-    ec = AssignmentCollection([], assignments)
+    ec = AC([], assignments)
     return sympy_cse(ec).all_assignments
 
 
-def apply_to_all_assignments(assignment_collection: AssignmentCollection,
-                             operation: Callable[[sp.Expr], sp.Expr]) -> AssignmentCollection:
-    """Applies sympy expand operation to all equations in collection."""
-    result = [Assignment(eq.lhs, operation(eq.rhs)) for eq in assignment_collection.main_assignments]
-    return assignment_collection.copy(result)
-
-
-def apply_on_all_subexpressions(ac: AssignmentCollection,
-                                operation: Callable[[sp.Expr], sp.Expr]) -> AssignmentCollection:
-    """Applies the given operation on all subexpressions of the AssignmentCollection."""
-    result = [Assignment(eq.lhs, operation(eq.rhs)) for eq in ac.subexpressions]
-    return ac.copy(ac.main_assignments, result)
-
-
-def subexpression_substitution_in_existing_subexpressions(ac: AssignmentCollection) -> AssignmentCollection:
+def subexpression_substitution_in_existing_subexpressions(ac: AC) -> AC:
     """Goes through the subexpressions list and replaces the term in the following subexpressions."""
     result = []
     for outer_ctr, s in enumerate(ac.subexpressions):
@@ -61,7 +49,7 @@ def subexpression_substitution_in_existing_subexpressions(ac: AssignmentCollecti
     return ac.copy(ac.main_assignments, result)
 
 
-def subexpression_substitution_in_main_assignments(ac: AssignmentCollection) -> AssignmentCollection:
+def subexpression_substitution_in_main_assignments(ac: AC) -> AC:
     """Replaces already existing subexpressions in the equations of the assignment_collection."""
     result = []
     for s in ac.main_assignments:
@@ -72,7 +60,7 @@ def subexpression_substitution_in_main_assignments(ac: AssignmentCollection) ->
     return ac.copy(result)
 
 
-def add_subexpressions_for_divisions(ac: AssignmentCollection) -> AssignmentCollection:
+def add_subexpressions_for_divisions(ac: AC) -> AC:
     """Introduces subexpressions for all divisions which have no constant in the denominator.
 
     For example :math:`\frac{1}{x}` is replaced, :math:`\frac{1}{3}` is not replaced.
@@ -93,3 +81,21 @@ def add_subexpressions_for_divisions(ac: AssignmentCollection) -> AssignmentColl
     new_symbol_gen = ac.subexpression_symbol_generator
     substitutions = {divisor: new_symbol for new_symbol, divisor in zip(new_symbol_gen, divisors)}
     return ac.new_with_substitutions(substitutions, True)
+
+
+def apply_to_all_assignments(operation: Callable[[sp.Expr], sp.Expr]) -> Callable[[AC], AC]:
+    """Applies sympy expand operation to all equations in collection."""
+    def f(assignment_collection: AC) -> AC:
+        result = [Assignment(eq.lhs, operation(eq.rhs)) for eq in assignment_collection.main_assignments]
+        return assignment_collection.copy(result)
+    f.__name__ = operation.__name__
+    return f
+
+
+def apply_on_all_subexpressions(operation: Callable[[sp.Expr], sp.Expr]) -> Callable[[AC], AC]:
+    """Applies the given operation on all subexpressions of the AC."""
+    def f(ac: AC) -> AC:
+        result = [Assignment(eq.lhs, operation(eq.rhs)) for eq in ac.subexpressions]
+        return ac.copy(ac.main_assignments, result)
+    f.__name__ = operation.__name__
+    return f
\ No newline at end of file
diff --git a/assignment_collection/simplificationstrategy.py b/assignment_collection/simplificationstrategy.py
index a9e9d0d61..a66fcd2cd 100644
--- a/assignment_collection/simplificationstrategy.py
+++ b/assignment_collection/simplificationstrategy.py
@@ -60,7 +60,7 @@ class SimplificationStrategy(object):
                 except ImportError:
                     result = "Name, Adds, Muls, Divs, Runtime\n"
                     for e in self.elements:
-                        result += ",".join(e) + "\n"
+                        result += ",".join([str(tuple_item) for tuple_item in e]) + "\n"
                     return result
 
             def _repr_html_(self):
diff --git a/test_simplification_strategy.py b/test_simplification_strategy.py
new file mode 100644
index 000000000..9c15551dd
--- /dev/null
+++ b/test_simplification_strategy.py
@@ -0,0 +1,43 @@
+import sympy as sp
+from pystencils import Assignment, AssignmentCollection
+from pystencils.assignment_collection import SimplificationStrategy, apply_on_all_subexpressions, \
+    subexpression_substitution_in_existing_subexpressions
+
+
+def test_simplification_strategy():
+    a, b, c, d, x, y, z = sp.symbols("a b c d x y z")
+    s0, s1, s2, s3 = sp.symbols("s_:4")
+    a0, a1, a2, a3 = sp.symbols("a_:4")
+
+    subexpressions = [
+        Assignment(s0, 2*a + 2*b),
+        Assignment(s1, 2 * a + 2 * b + 2*c),
+        Assignment(s2, 2 * a + 2 * b + 2*c + 2*d),
+    ]
+    main = [
+        Assignment(a0, s0 + s1),
+        Assignment(a1, s0 + s2),
+        Assignment(a2, s1 + s2),
+    ]
+    ac = AssignmentCollection(main, subexpressions)
+
+    strategy = SimplificationStrategy()
+    strategy.add(subexpression_substitution_in_existing_subexpressions)
+    strategy.add(apply_on_all_subexpressions(sp.factor))
+
+    result = strategy(ac)
+    assert result.operation_count['adds'] == 7
+    assert result.operation_count['muls'] == 5
+    assert result.operation_count['divs'] == 0
+
+    # Trigger display routines, such that they are at least executed
+    report = strategy.show_intermediate_results(ac, symbols=[s0])
+    assert 's_0' in str(report)
+    report = strategy.show_intermediate_results(ac)
+    assert 's_{1}' in report._repr_html_()
+
+    report = strategy.create_simplification_report(ac)
+    assert 'Adds' in str(report)
+    assert 'Adds' in report._repr_html_()
+
+    assert 'factor' in str(strategy)
-- 
GitLab