Commit 2f213e10 authored by Martin Bauer's avatar Martin Bauer
Browse files

Tests for simplifications + postprocessing + small fixes

parent d3a1c41a
......@@ -329,7 +329,7 @@ class AssignmentCollection:
result += f"\t{eq}\n"
result += "Main Assignments:\n"
for eq in self.main_assignments:
result += f"{eq}\n"
result += f"\t{eq}\n"
return result
......
......@@ -4,8 +4,10 @@ from pystencils.assignment import Assignment
from pystencils.assignment_collection.assignment_collection import AssignmentCollection
from pystencils.sympyextensions import subs_additive
AC = AssignmentCollection
def sympy_cse(ac: AssignmentCollection) -> AssignmentCollection:
def sympy_cse(ac: AC) -> AC:
"""Searches for common subexpressions inside the equation collection.
Searches is done in both the existing subexpressions as well as the assignments themselves.
......@@ -29,25 +31,11 @@ def sympy_cse(ac: AssignmentCollection) -> AssignmentCollection:
def sympy_cse_on_assignment_list(assignments: List[Assignment]) -> List[Assignment]:
"""Extracts common subexpressions from a list of assignments."""
ec = AssignmentCollection([], assignments)
ec = AC([], assignments)
return sympy_cse(ec).all_assignments
def apply_to_all_assignments(assignment_collection: AssignmentCollection,
operation: Callable[[sp.Expr], sp.Expr]) -> AssignmentCollection:
"""Applies sympy expand operation to all equations in collection."""
result = [Assignment(eq.lhs, operation(eq.rhs)) for eq in assignment_collection.main_assignments]
return assignment_collection.copy(result)
def apply_on_all_subexpressions(ac: AssignmentCollection,
operation: Callable[[sp.Expr], sp.Expr]) -> AssignmentCollection:
"""Applies the given operation on all subexpressions of the AssignmentCollection."""
result = [Assignment(eq.lhs, operation(eq.rhs)) for eq in ac.subexpressions]
return ac.copy(ac.main_assignments, result)
def subexpression_substitution_in_existing_subexpressions(ac: AssignmentCollection) -> AssignmentCollection:
def subexpression_substitution_in_existing_subexpressions(ac: AC) -> AC:
"""Goes through the subexpressions list and replaces the term in the following subexpressions."""
result = []
for outer_ctr, s in enumerate(ac.subexpressions):
......@@ -61,7 +49,7 @@ def subexpression_substitution_in_existing_subexpressions(ac: AssignmentCollecti
return ac.copy(ac.main_assignments, result)
def subexpression_substitution_in_main_assignments(ac: AssignmentCollection) -> AssignmentCollection:
def subexpression_substitution_in_main_assignments(ac: AC) -> AC:
"""Replaces already existing subexpressions in the equations of the assignment_collection."""
result = []
for s in ac.main_assignments:
......@@ -72,7 +60,7 @@ def subexpression_substitution_in_main_assignments(ac: AssignmentCollection) ->
return ac.copy(result)
def add_subexpressions_for_divisions(ac: AssignmentCollection) -> AssignmentCollection:
def add_subexpressions_for_divisions(ac: AC) -> AC:
"""Introduces subexpressions for all divisions which have no constant in the denominator.
For example :math:`\frac{1}{x}` is replaced, :math:`\frac{1}{3}` is not replaced.
......@@ -93,3 +81,21 @@ def add_subexpressions_for_divisions(ac: AssignmentCollection) -> AssignmentColl
new_symbol_gen = ac.subexpression_symbol_generator
substitutions = {divisor: new_symbol for new_symbol, divisor in zip(new_symbol_gen, divisors)}
return ac.new_with_substitutions(substitutions, True)
def apply_to_all_assignments(operation: Callable[[sp.Expr], sp.Expr]) -> Callable[[AC], AC]:
"""Applies sympy expand operation to all equations in collection."""
def f(assignment_collection: AC) -> AC:
result = [Assignment(eq.lhs, operation(eq.rhs)) for eq in assignment_collection.main_assignments]
return assignment_collection.copy(result)
f.__name__ = operation.__name__
return f
def apply_on_all_subexpressions(operation: Callable[[sp.Expr], sp.Expr]) -> Callable[[AC], AC]:
"""Applies the given operation on all subexpressions of the AC."""
def f(ac: AC) -> AC:
result = [Assignment(eq.lhs, operation(eq.rhs)) for eq in ac.subexpressions]
return ac.copy(ac.main_assignments, result)
f.__name__ = operation.__name__
return f
\ No newline at end of file
......@@ -60,7 +60,7 @@ class SimplificationStrategy(object):
except ImportError:
result = "Name, Adds, Muls, Divs, Runtime\n"
for e in self.elements:
result += ",".join(e) + "\n"
result += ",".join([str(tuple_item) for tuple_item in e]) + "\n"
return result
def _repr_html_(self):
......
import sympy as sp
from pystencils import Assignment, AssignmentCollection
from pystencils.assignment_collection import SimplificationStrategy, apply_on_all_subexpressions, \
subexpression_substitution_in_existing_subexpressions
def test_simplification_strategy():
a, b, c, d, x, y, z = sp.symbols("a b c d x y z")
s0, s1, s2, s3 = sp.symbols("s_:4")
a0, a1, a2, a3 = sp.symbols("a_:4")
subexpressions = [
Assignment(s0, 2*a + 2*b),
Assignment(s1, 2 * a + 2 * b + 2*c),
Assignment(s2, 2 * a + 2 * b + 2*c + 2*d),
]
main = [
Assignment(a0, s0 + s1),
Assignment(a1, s0 + s2),
Assignment(a2, s1 + s2),
]
ac = AssignmentCollection(main, subexpressions)
strategy = SimplificationStrategy()
strategy.add(subexpression_substitution_in_existing_subexpressions)
strategy.add(apply_on_all_subexpressions(sp.factor))
result = strategy(ac)
assert result.operation_count['adds'] == 7
assert result.operation_count['muls'] == 5
assert result.operation_count['divs'] == 0
# Trigger display routines, such that they are at least executed
report = strategy.show_intermediate_results(ac, symbols=[s0])
assert 's_0' in str(report)
report = strategy.show_intermediate_results(ac)
assert 's_{1}' in report._repr_html_()
report = strategy.create_simplification_report(ac)
assert 'Adds' in str(report)
assert 'Adds' in report._repr_html_()
assert 'factor' in str(strategy)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment