diff --git a/assignment_collection/assignment_collection.py b/assignment_collection/assignment_collection.py index 01e7d5a037744323d7fdaba77a3e31146f2aa6e0..9b7e9948b40ab42e7850e413ae629f9974418679 100644 --- a/assignment_collection/assignment_collection.py +++ b/assignment_collection/assignment_collection.py @@ -329,7 +329,7 @@ class AssignmentCollection: result += f"\t{eq}\n" result += "Main Assignments:\n" for eq in self.main_assignments: - result += f"{eq}\n" + result += f"\t{eq}\n" return result diff --git a/assignment_collection/simplifications.py b/assignment_collection/simplifications.py index a635e8b33ab8d8438cea1d55e8cf8f8c86c00b6f..6e7173f13236d38e931184e26007dc0e57b3f8e5 100644 --- a/assignment_collection/simplifications.py +++ b/assignment_collection/simplifications.py @@ -4,8 +4,10 @@ from pystencils.assignment import Assignment from pystencils.assignment_collection.assignment_collection import AssignmentCollection from pystencils.sympyextensions import subs_additive +AC = AssignmentCollection -def sympy_cse(ac: AssignmentCollection) -> AssignmentCollection: + +def sympy_cse(ac: AC) -> AC: """Searches for common subexpressions inside the equation collection. Searches is done in both the existing subexpressions as well as the assignments themselves. @@ -29,25 +31,11 @@ def sympy_cse(ac: AssignmentCollection) -> AssignmentCollection: def sympy_cse_on_assignment_list(assignments: List[Assignment]) -> List[Assignment]: """Extracts common subexpressions from a list of assignments.""" - ec = AssignmentCollection([], assignments) + ec = AC([], assignments) return sympy_cse(ec).all_assignments -def apply_to_all_assignments(assignment_collection: AssignmentCollection, - operation: Callable[[sp.Expr], sp.Expr]) -> AssignmentCollection: - """Applies sympy expand operation to all equations in collection.""" - result = [Assignment(eq.lhs, operation(eq.rhs)) for eq in assignment_collection.main_assignments] - return assignment_collection.copy(result) - - -def apply_on_all_subexpressions(ac: AssignmentCollection, - operation: Callable[[sp.Expr], sp.Expr]) -> AssignmentCollection: - """Applies the given operation on all subexpressions of the AssignmentCollection.""" - result = [Assignment(eq.lhs, operation(eq.rhs)) for eq in ac.subexpressions] - return ac.copy(ac.main_assignments, result) - - -def subexpression_substitution_in_existing_subexpressions(ac: AssignmentCollection) -> AssignmentCollection: +def subexpression_substitution_in_existing_subexpressions(ac: AC) -> AC: """Goes through the subexpressions list and replaces the term in the following subexpressions.""" result = [] for outer_ctr, s in enumerate(ac.subexpressions): @@ -61,7 +49,7 @@ def subexpression_substitution_in_existing_subexpressions(ac: AssignmentCollecti return ac.copy(ac.main_assignments, result) -def subexpression_substitution_in_main_assignments(ac: AssignmentCollection) -> AssignmentCollection: +def subexpression_substitution_in_main_assignments(ac: AC) -> AC: """Replaces already existing subexpressions in the equations of the assignment_collection.""" result = [] for s in ac.main_assignments: @@ -72,7 +60,7 @@ def subexpression_substitution_in_main_assignments(ac: AssignmentCollection) -> return ac.copy(result) -def add_subexpressions_for_divisions(ac: AssignmentCollection) -> AssignmentCollection: +def add_subexpressions_for_divisions(ac: AC) -> AC: """Introduces subexpressions for all divisions which have no constant in the denominator. For example :math:`\frac{1}{x}` is replaced, :math:`\frac{1}{3}` is not replaced. @@ -93,3 +81,21 @@ def add_subexpressions_for_divisions(ac: AssignmentCollection) -> AssignmentColl new_symbol_gen = ac.subexpression_symbol_generator substitutions = {divisor: new_symbol for new_symbol, divisor in zip(new_symbol_gen, divisors)} return ac.new_with_substitutions(substitutions, True) + + +def apply_to_all_assignments(operation: Callable[[sp.Expr], sp.Expr]) -> Callable[[AC], AC]: + """Applies sympy expand operation to all equations in collection.""" + def f(assignment_collection: AC) -> AC: + result = [Assignment(eq.lhs, operation(eq.rhs)) for eq in assignment_collection.main_assignments] + return assignment_collection.copy(result) + f.__name__ = operation.__name__ + return f + + +def apply_on_all_subexpressions(operation: Callable[[sp.Expr], sp.Expr]) -> Callable[[AC], AC]: + """Applies the given operation on all subexpressions of the AC.""" + def f(ac: AC) -> AC: + result = [Assignment(eq.lhs, operation(eq.rhs)) for eq in ac.subexpressions] + return ac.copy(ac.main_assignments, result) + f.__name__ = operation.__name__ + return f \ No newline at end of file diff --git a/assignment_collection/simplificationstrategy.py b/assignment_collection/simplificationstrategy.py index a9e9d0d61bf7b0a64b4b9b4f258f68393196c279..a66fcd2cdccb08c056ff8d8cfe42afca09e588f8 100644 --- a/assignment_collection/simplificationstrategy.py +++ b/assignment_collection/simplificationstrategy.py @@ -60,7 +60,7 @@ class SimplificationStrategy(object): except ImportError: result = "Name, Adds, Muls, Divs, Runtime\n" for e in self.elements: - result += ",".join(e) + "\n" + result += ",".join([str(tuple_item) for tuple_item in e]) + "\n" return result def _repr_html_(self): diff --git a/test_simplification_strategy.py b/test_simplification_strategy.py new file mode 100644 index 0000000000000000000000000000000000000000..9c15551ddb7ae1153a7ff30caf9a6378b045de8c --- /dev/null +++ b/test_simplification_strategy.py @@ -0,0 +1,43 @@ +import sympy as sp +from pystencils import Assignment, AssignmentCollection +from pystencils.assignment_collection import SimplificationStrategy, apply_on_all_subexpressions, \ + subexpression_substitution_in_existing_subexpressions + + +def test_simplification_strategy(): + a, b, c, d, x, y, z = sp.symbols("a b c d x y z") + s0, s1, s2, s3 = sp.symbols("s_:4") + a0, a1, a2, a3 = sp.symbols("a_:4") + + subexpressions = [ + Assignment(s0, 2*a + 2*b), + Assignment(s1, 2 * a + 2 * b + 2*c), + Assignment(s2, 2 * a + 2 * b + 2*c + 2*d), + ] + main = [ + Assignment(a0, s0 + s1), + Assignment(a1, s0 + s2), + Assignment(a2, s1 + s2), + ] + ac = AssignmentCollection(main, subexpressions) + + strategy = SimplificationStrategy() + strategy.add(subexpression_substitution_in_existing_subexpressions) + strategy.add(apply_on_all_subexpressions(sp.factor)) + + result = strategy(ac) + assert result.operation_count['adds'] == 7 + assert result.operation_count['muls'] == 5 + assert result.operation_count['divs'] == 0 + + # Trigger display routines, such that they are at least executed + report = strategy.show_intermediate_results(ac, symbols=[s0]) + assert 's_0' in str(report) + report = strategy.show_intermediate_results(ac) + assert 's_{1}' in report._repr_html_() + + report = strategy.create_simplification_report(ac) + assert 'Adds' in str(report) + assert 'Adds' in report._repr_html_() + + assert 'factor' in str(strategy)