Commit 2f213e10 by Martin Bauer

### Tests for simplifications + postprocessing + small fixes

parent d3a1c41a
 ... ... @@ -329,7 +329,7 @@ class AssignmentCollection: result += f"\t{eq}\n" result += "Main Assignments:\n" for eq in self.main_assignments: result += f"{eq}\n" result += f"\t{eq}\n" return result ... ...
 ... ... @@ -4,8 +4,10 @@ from pystencils.assignment import Assignment from pystencils.assignment_collection.assignment_collection import AssignmentCollection from pystencils.sympyextensions import subs_additive AC = AssignmentCollection def sympy_cse(ac: AssignmentCollection) -> AssignmentCollection: def sympy_cse(ac: AC) -> AC: """Searches for common subexpressions inside the equation collection. Searches is done in both the existing subexpressions as well as the assignments themselves. ... ... @@ -29,25 +31,11 @@ def sympy_cse(ac: AssignmentCollection) -> AssignmentCollection: def sympy_cse_on_assignment_list(assignments: List[Assignment]) -> List[Assignment]: """Extracts common subexpressions from a list of assignments.""" ec = AssignmentCollection([], assignments) ec = AC([], assignments) return sympy_cse(ec).all_assignments def apply_to_all_assignments(assignment_collection: AssignmentCollection, operation: Callable[[sp.Expr], sp.Expr]) -> AssignmentCollection: """Applies sympy expand operation to all equations in collection.""" result = [Assignment(eq.lhs, operation(eq.rhs)) for eq in assignment_collection.main_assignments] return assignment_collection.copy(result) def apply_on_all_subexpressions(ac: AssignmentCollection, operation: Callable[[sp.Expr], sp.Expr]) -> AssignmentCollection: """Applies the given operation on all subexpressions of the AssignmentCollection.""" result = [Assignment(eq.lhs, operation(eq.rhs)) for eq in ac.subexpressions] return ac.copy(ac.main_assignments, result) def subexpression_substitution_in_existing_subexpressions(ac: AssignmentCollection) -> AssignmentCollection: def subexpression_substitution_in_existing_subexpressions(ac: AC) -> AC: """Goes through the subexpressions list and replaces the term in the following subexpressions.""" result = [] for outer_ctr, s in enumerate(ac.subexpressions): ... ... @@ -61,7 +49,7 @@ def subexpression_substitution_in_existing_subexpressions(ac: AssignmentCollecti return ac.copy(ac.main_assignments, result) def subexpression_substitution_in_main_assignments(ac: AssignmentCollection) -> AssignmentCollection: def subexpression_substitution_in_main_assignments(ac: AC) -> AC: """Replaces already existing subexpressions in the equations of the assignment_collection.""" result = [] for s in ac.main_assignments: ... ... @@ -72,7 +60,7 @@ def subexpression_substitution_in_main_assignments(ac: AssignmentCollection) -> return ac.copy(result) def add_subexpressions_for_divisions(ac: AssignmentCollection) -> AssignmentCollection: def add_subexpressions_for_divisions(ac: AC) -> AC: """Introduces subexpressions for all divisions which have no constant in the denominator. For example :math:`\frac{1}{x}` is replaced, :math:`\frac{1}{3}` is not replaced. ... ... @@ -93,3 +81,21 @@ def add_subexpressions_for_divisions(ac: AssignmentCollection) -> AssignmentColl new_symbol_gen = ac.subexpression_symbol_generator substitutions = {divisor: new_symbol for new_symbol, divisor in zip(new_symbol_gen, divisors)} return ac.new_with_substitutions(substitutions, True) def apply_to_all_assignments(operation: Callable[[sp.Expr], sp.Expr]) -> Callable[[AC], AC]: """Applies sympy expand operation to all equations in collection.""" def f(assignment_collection: AC) -> AC: result = [Assignment(eq.lhs, operation(eq.rhs)) for eq in assignment_collection.main_assignments] return assignment_collection.copy(result) f.__name__ = operation.__name__ return f def apply_on_all_subexpressions(operation: Callable[[sp.Expr], sp.Expr]) -> Callable[[AC], AC]: """Applies the given operation on all subexpressions of the AC.""" def f(ac: AC) -> AC: result = [Assignment(eq.lhs, operation(eq.rhs)) for eq in ac.subexpressions] return ac.copy(ac.main_assignments, result) f.__name__ = operation.__name__ return f \ No newline at end of file
 ... ... @@ -60,7 +60,7 @@ class SimplificationStrategy(object): except ImportError: result = "Name, Adds, Muls, Divs, Runtime\n" for e in self.elements: result += ",".join(e) + "\n" result += ",".join([str(tuple_item) for tuple_item in e]) + "\n" return result def _repr_html_(self): ... ...
 import sympy as sp from pystencils import Assignment, AssignmentCollection from pystencils.assignment_collection import SimplificationStrategy, apply_on_all_subexpressions, \ subexpression_substitution_in_existing_subexpressions def test_simplification_strategy(): a, b, c, d, x, y, z = sp.symbols("a b c d x y z") s0, s1, s2, s3 = sp.symbols("s_:4") a0, a1, a2, a3 = sp.symbols("a_:4") subexpressions = [ Assignment(s0, 2*a + 2*b), Assignment(s1, 2 * a + 2 * b + 2*c), Assignment(s2, 2 * a + 2 * b + 2*c + 2*d), ] main = [ Assignment(a0, s0 + s1), Assignment(a1, s0 + s2), Assignment(a2, s1 + s2), ] ac = AssignmentCollection(main, subexpressions) strategy = SimplificationStrategy() strategy.add(subexpression_substitution_in_existing_subexpressions) strategy.add(apply_on_all_subexpressions(sp.factor)) result = strategy(ac) assert result.operation_count['adds'] == 7 assert result.operation_count['muls'] == 5 assert result.operation_count['divs'] == 0 # Trigger display routines, such that they are at least executed report = strategy.show_intermediate_results(ac, symbols=[s0]) assert 's_0' in str(report) report = strategy.show_intermediate_results(ac) assert 's_{1}' in report._repr_html_() report = strategy.create_simplification_report(ac) assert 'Adds' in str(report) assert 'Adds' in report._repr_html_() assert 'factor' in str(strategy)
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!