Commit 2f213e10 authored by Martin Bauer's avatar Martin Bauer
Browse files

Tests for simplifications + postprocessing + small fixes

parent d3a1c41a
...@@ -329,7 +329,7 @@ class AssignmentCollection: ...@@ -329,7 +329,7 @@ class AssignmentCollection:
result += f"\t{eq}\n" result += f"\t{eq}\n"
result += "Main Assignments:\n" result += "Main Assignments:\n"
for eq in self.main_assignments: for eq in self.main_assignments:
result += f"{eq}\n" result += f"\t{eq}\n"
return result return result
......
...@@ -4,8 +4,10 @@ from pystencils.assignment import Assignment ...@@ -4,8 +4,10 @@ from pystencils.assignment import Assignment
from pystencils.assignment_collection.assignment_collection import AssignmentCollection from pystencils.assignment_collection.assignment_collection import AssignmentCollection
from pystencils.sympyextensions import subs_additive from pystencils.sympyextensions import subs_additive
AC = AssignmentCollection
def sympy_cse(ac: AssignmentCollection) -> AssignmentCollection:
def sympy_cse(ac: AC) -> AC:
"""Searches for common subexpressions inside the equation collection. """Searches for common subexpressions inside the equation collection.
Searches is done in both the existing subexpressions as well as the assignments themselves. Searches is done in both the existing subexpressions as well as the assignments themselves.
...@@ -29,25 +31,11 @@ def sympy_cse(ac: AssignmentCollection) -> AssignmentCollection: ...@@ -29,25 +31,11 @@ def sympy_cse(ac: AssignmentCollection) -> AssignmentCollection:
def sympy_cse_on_assignment_list(assignments: List[Assignment]) -> List[Assignment]: def sympy_cse_on_assignment_list(assignments: List[Assignment]) -> List[Assignment]:
"""Extracts common subexpressions from a list of assignments.""" """Extracts common subexpressions from a list of assignments."""
ec = AssignmentCollection([], assignments) ec = AC([], assignments)
return sympy_cse(ec).all_assignments return sympy_cse(ec).all_assignments
def apply_to_all_assignments(assignment_collection: AssignmentCollection, def subexpression_substitution_in_existing_subexpressions(ac: AC) -> AC:
operation: Callable[[sp.Expr], sp.Expr]) -> AssignmentCollection:
"""Applies sympy expand operation to all equations in collection."""
result = [Assignment(eq.lhs, operation(eq.rhs)) for eq in assignment_collection.main_assignments]
return assignment_collection.copy(result)
def apply_on_all_subexpressions(ac: AssignmentCollection,
operation: Callable[[sp.Expr], sp.Expr]) -> AssignmentCollection:
"""Applies the given operation on all subexpressions of the AssignmentCollection."""
result = [Assignment(eq.lhs, operation(eq.rhs)) for eq in ac.subexpressions]
return ac.copy(ac.main_assignments, result)
def subexpression_substitution_in_existing_subexpressions(ac: AssignmentCollection) -> AssignmentCollection:
"""Goes through the subexpressions list and replaces the term in the following subexpressions.""" """Goes through the subexpressions list and replaces the term in the following subexpressions."""
result = [] result = []
for outer_ctr, s in enumerate(ac.subexpressions): for outer_ctr, s in enumerate(ac.subexpressions):
...@@ -61,7 +49,7 @@ def subexpression_substitution_in_existing_subexpressions(ac: AssignmentCollecti ...@@ -61,7 +49,7 @@ def subexpression_substitution_in_existing_subexpressions(ac: AssignmentCollecti
return ac.copy(ac.main_assignments, result) return ac.copy(ac.main_assignments, result)
def subexpression_substitution_in_main_assignments(ac: AssignmentCollection) -> AssignmentCollection: def subexpression_substitution_in_main_assignments(ac: AC) -> AC:
"""Replaces already existing subexpressions in the equations of the assignment_collection.""" """Replaces already existing subexpressions in the equations of the assignment_collection."""
result = [] result = []
for s in ac.main_assignments: for s in ac.main_assignments:
...@@ -72,7 +60,7 @@ def subexpression_substitution_in_main_assignments(ac: AssignmentCollection) -> ...@@ -72,7 +60,7 @@ def subexpression_substitution_in_main_assignments(ac: AssignmentCollection) ->
return ac.copy(result) return ac.copy(result)
def add_subexpressions_for_divisions(ac: AssignmentCollection) -> AssignmentCollection: def add_subexpressions_for_divisions(ac: AC) -> AC:
"""Introduces subexpressions for all divisions which have no constant in the denominator. """Introduces subexpressions for all divisions which have no constant in the denominator.
For example :math:`\frac{1}{x}` is replaced, :math:`\frac{1}{3}` is not replaced. For example :math:`\frac{1}{x}` is replaced, :math:`\frac{1}{3}` is not replaced.
...@@ -93,3 +81,21 @@ def add_subexpressions_for_divisions(ac: AssignmentCollection) -> AssignmentColl ...@@ -93,3 +81,21 @@ def add_subexpressions_for_divisions(ac: AssignmentCollection) -> AssignmentColl
new_symbol_gen = ac.subexpression_symbol_generator new_symbol_gen = ac.subexpression_symbol_generator
substitutions = {divisor: new_symbol for new_symbol, divisor in zip(new_symbol_gen, divisors)} substitutions = {divisor: new_symbol for new_symbol, divisor in zip(new_symbol_gen, divisors)}
return ac.new_with_substitutions(substitutions, True) return ac.new_with_substitutions(substitutions, True)
def apply_to_all_assignments(operation: Callable[[sp.Expr], sp.Expr]) -> Callable[[AC], AC]:
"""Applies sympy expand operation to all equations in collection."""
def f(assignment_collection: AC) -> AC:
result = [Assignment(eq.lhs, operation(eq.rhs)) for eq in assignment_collection.main_assignments]
return assignment_collection.copy(result)
f.__name__ = operation.__name__
return f
def apply_on_all_subexpressions(operation: Callable[[sp.Expr], sp.Expr]) -> Callable[[AC], AC]:
"""Applies the given operation on all subexpressions of the AC."""
def f(ac: AC) -> AC:
result = [Assignment(eq.lhs, operation(eq.rhs)) for eq in ac.subexpressions]
return ac.copy(ac.main_assignments, result)
f.__name__ = operation.__name__
return f
\ No newline at end of file
...@@ -60,7 +60,7 @@ class SimplificationStrategy(object): ...@@ -60,7 +60,7 @@ class SimplificationStrategy(object):
except ImportError: except ImportError:
result = "Name, Adds, Muls, Divs, Runtime\n" result = "Name, Adds, Muls, Divs, Runtime\n"
for e in self.elements: for e in self.elements:
result += ",".join(e) + "\n" result += ",".join([str(tuple_item) for tuple_item in e]) + "\n"
return result return result
def _repr_html_(self): def _repr_html_(self):
......
import sympy as sp
from pystencils import Assignment, AssignmentCollection
from pystencils.assignment_collection import SimplificationStrategy, apply_on_all_subexpressions, \
subexpression_substitution_in_existing_subexpressions
def test_simplification_strategy():
a, b, c, d, x, y, z = sp.symbols("a b c d x y z")
s0, s1, s2, s3 = sp.symbols("s_:4")
a0, a1, a2, a3 = sp.symbols("a_:4")
subexpressions = [
Assignment(s0, 2*a + 2*b),
Assignment(s1, 2 * a + 2 * b + 2*c),
Assignment(s2, 2 * a + 2 * b + 2*c + 2*d),
]
main = [
Assignment(a0, s0 + s1),
Assignment(a1, s0 + s2),
Assignment(a2, s1 + s2),
]
ac = AssignmentCollection(main, subexpressions)
strategy = SimplificationStrategy()
strategy.add(subexpression_substitution_in_existing_subexpressions)
strategy.add(apply_on_all_subexpressions(sp.factor))
result = strategy(ac)
assert result.operation_count['adds'] == 7
assert result.operation_count['muls'] == 5
assert result.operation_count['divs'] == 0
# Trigger display routines, such that they are at least executed
report = strategy.show_intermediate_results(ac, symbols=[s0])
assert 's_0' in str(report)
report = strategy.show_intermediate_results(ac)
assert 's_{1}' in report._repr_html_()
report = strategy.create_simplification_report(ac)
assert 'Adds' in str(report)
assert 'Adds' in report._repr_html_()
assert 'factor' in str(strategy)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment