From 2f213e10e2017d16c22bd3b4ed0154e9dc939d60 Mon Sep 17 00:00:00 2001 From: Martin Bauer <martin.bauer@fau.de> Date: Wed, 11 Apr 2018 11:13:49 +0200 Subject: [PATCH] Tests for simplifications + postprocessing + small fixes --- .../assignment_collection.py | 2 +- assignment_collection/simplifications.py | 44 +++++++++++-------- .../simplificationstrategy.py | 2 +- test_simplification_strategy.py | 43 ++++++++++++++++++ 4 files changed, 70 insertions(+), 21 deletions(-) create mode 100644 test_simplification_strategy.py diff --git a/assignment_collection/assignment_collection.py b/assignment_collection/assignment_collection.py index 01e7d5a03..9b7e9948b 100644 --- a/assignment_collection/assignment_collection.py +++ b/assignment_collection/assignment_collection.py @@ -329,7 +329,7 @@ class AssignmentCollection: result += f"\t{eq}\n" result += "Main Assignments:\n" for eq in self.main_assignments: - result += f"{eq}\n" + result += f"\t{eq}\n" return result diff --git a/assignment_collection/simplifications.py b/assignment_collection/simplifications.py index a635e8b33..6e7173f13 100644 --- a/assignment_collection/simplifications.py +++ b/assignment_collection/simplifications.py @@ -4,8 +4,10 @@ from pystencils.assignment import Assignment from pystencils.assignment_collection.assignment_collection import AssignmentCollection from pystencils.sympyextensions import subs_additive +AC = AssignmentCollection -def sympy_cse(ac: AssignmentCollection) -> AssignmentCollection: + +def sympy_cse(ac: AC) -> AC: """Searches for common subexpressions inside the equation collection. Searches is done in both the existing subexpressions as well as the assignments themselves. @@ -29,25 +31,11 @@ def sympy_cse(ac: AssignmentCollection) -> AssignmentCollection: def sympy_cse_on_assignment_list(assignments: List[Assignment]) -> List[Assignment]: """Extracts common subexpressions from a list of assignments.""" - ec = AssignmentCollection([], assignments) + ec = AC([], assignments) return sympy_cse(ec).all_assignments -def apply_to_all_assignments(assignment_collection: AssignmentCollection, - operation: Callable[[sp.Expr], sp.Expr]) -> AssignmentCollection: - """Applies sympy expand operation to all equations in collection.""" - result = [Assignment(eq.lhs, operation(eq.rhs)) for eq in assignment_collection.main_assignments] - return assignment_collection.copy(result) - - -def apply_on_all_subexpressions(ac: AssignmentCollection, - operation: Callable[[sp.Expr], sp.Expr]) -> AssignmentCollection: - """Applies the given operation on all subexpressions of the AssignmentCollection.""" - result = [Assignment(eq.lhs, operation(eq.rhs)) for eq in ac.subexpressions] - return ac.copy(ac.main_assignments, result) - - -def subexpression_substitution_in_existing_subexpressions(ac: AssignmentCollection) -> AssignmentCollection: +def subexpression_substitution_in_existing_subexpressions(ac: AC) -> AC: """Goes through the subexpressions list and replaces the term in the following subexpressions.""" result = [] for outer_ctr, s in enumerate(ac.subexpressions): @@ -61,7 +49,7 @@ def subexpression_substitution_in_existing_subexpressions(ac: AssignmentCollecti return ac.copy(ac.main_assignments, result) -def subexpression_substitution_in_main_assignments(ac: AssignmentCollection) -> AssignmentCollection: +def subexpression_substitution_in_main_assignments(ac: AC) -> AC: """Replaces already existing subexpressions in the equations of the assignment_collection.""" result = [] for s in ac.main_assignments: @@ -72,7 +60,7 @@ def subexpression_substitution_in_main_assignments(ac: AssignmentCollection) -> return ac.copy(result) -def add_subexpressions_for_divisions(ac: AssignmentCollection) -> AssignmentCollection: +def add_subexpressions_for_divisions(ac: AC) -> AC: """Introduces subexpressions for all divisions which have no constant in the denominator. For example :math:`\frac{1}{x}` is replaced, :math:`\frac{1}{3}` is not replaced. @@ -93,3 +81,21 @@ def add_subexpressions_for_divisions(ac: AssignmentCollection) -> AssignmentColl new_symbol_gen = ac.subexpression_symbol_generator substitutions = {divisor: new_symbol for new_symbol, divisor in zip(new_symbol_gen, divisors)} return ac.new_with_substitutions(substitutions, True) + + +def apply_to_all_assignments(operation: Callable[[sp.Expr], sp.Expr]) -> Callable[[AC], AC]: + """Applies sympy expand operation to all equations in collection.""" + def f(assignment_collection: AC) -> AC: + result = [Assignment(eq.lhs, operation(eq.rhs)) for eq in assignment_collection.main_assignments] + return assignment_collection.copy(result) + f.__name__ = operation.__name__ + return f + + +def apply_on_all_subexpressions(operation: Callable[[sp.Expr], sp.Expr]) -> Callable[[AC], AC]: + """Applies the given operation on all subexpressions of the AC.""" + def f(ac: AC) -> AC: + result = [Assignment(eq.lhs, operation(eq.rhs)) for eq in ac.subexpressions] + return ac.copy(ac.main_assignments, result) + f.__name__ = operation.__name__ + return f \ No newline at end of file diff --git a/assignment_collection/simplificationstrategy.py b/assignment_collection/simplificationstrategy.py index a9e9d0d61..a66fcd2cd 100644 --- a/assignment_collection/simplificationstrategy.py +++ b/assignment_collection/simplificationstrategy.py @@ -60,7 +60,7 @@ class SimplificationStrategy(object): except ImportError: result = "Name, Adds, Muls, Divs, Runtime\n" for e in self.elements: - result += ",".join(e) + "\n" + result += ",".join([str(tuple_item) for tuple_item in e]) + "\n" return result def _repr_html_(self): diff --git a/test_simplification_strategy.py b/test_simplification_strategy.py new file mode 100644 index 000000000..9c15551dd --- /dev/null +++ b/test_simplification_strategy.py @@ -0,0 +1,43 @@ +import sympy as sp +from pystencils import Assignment, AssignmentCollection +from pystencils.assignment_collection import SimplificationStrategy, apply_on_all_subexpressions, \ + subexpression_substitution_in_existing_subexpressions + + +def test_simplification_strategy(): + a, b, c, d, x, y, z = sp.symbols("a b c d x y z") + s0, s1, s2, s3 = sp.symbols("s_:4") + a0, a1, a2, a3 = sp.symbols("a_:4") + + subexpressions = [ + Assignment(s0, 2*a + 2*b), + Assignment(s1, 2 * a + 2 * b + 2*c), + Assignment(s2, 2 * a + 2 * b + 2*c + 2*d), + ] + main = [ + Assignment(a0, s0 + s1), + Assignment(a1, s0 + s2), + Assignment(a2, s1 + s2), + ] + ac = AssignmentCollection(main, subexpressions) + + strategy = SimplificationStrategy() + strategy.add(subexpression_substitution_in_existing_subexpressions) + strategy.add(apply_on_all_subexpressions(sp.factor)) + + result = strategy(ac) + assert result.operation_count['adds'] == 7 + assert result.operation_count['muls'] == 5 + assert result.operation_count['divs'] == 0 + + # Trigger display routines, such that they are at least executed + report = strategy.show_intermediate_results(ac, symbols=[s0]) + assert 's_0' in str(report) + report = strategy.show_intermediate_results(ac) + assert 's_{1}' in report._repr_html_() + + report = strategy.create_simplification_report(ac) + assert 'Adds' in str(report) + assert 'Adds' in report._repr_html_() + + assert 'factor' in str(strategy) -- GitLab