Skip to content
Snippets Groups Projects
Commit fd4d1bc0 authored by Markus Holzer's avatar Markus Holzer
Browse files

Added test cases for pystencils simplifications

parent 8bc8b39a
No related merge requests found
from .assignment_collection import AssignmentCollection from .assignment_collection import AssignmentCollection
from .simplifications import ( from .simplifications import (
add_subexpressions_for_divisions, add_subexpressions_for_field_reads, add_subexpressions_for_divisions, add_subexpressions_for_field_reads,
apply_on_all_subexpressions, apply_to_all_assignments, add_subexpressions_for_sums, apply_on_all_subexpressions, apply_to_all_assignments,
subexpression_substitution_in_existing_subexpressions, subexpression_substitution_in_existing_subexpressions,
subexpression_substitution_in_main_assignments, sympy_cse, sympy_cse_on_assignment_list) subexpression_substitution_in_main_assignments, sympy_cse, sympy_cse_on_assignment_list)
from .simplificationstrategy import SimplificationStrategy from .simplificationstrategy import SimplificationStrategy
...@@ -10,4 +10,4 @@ __all__ = ['AssignmentCollection', 'SimplificationStrategy', ...@@ -10,4 +10,4 @@ __all__ = ['AssignmentCollection', 'SimplificationStrategy',
'sympy_cse', 'sympy_cse_on_assignment_list', 'apply_to_all_assignments', 'sympy_cse', 'sympy_cse_on_assignment_list', 'apply_to_all_assignments',
'apply_on_all_subexpressions', 'subexpression_substitution_in_existing_subexpressions', 'apply_on_all_subexpressions', 'subexpression_substitution_in_existing_subexpressions',
'subexpression_substitution_in_main_assignments', 'add_subexpressions_for_divisions', 'subexpression_substitution_in_main_assignments', 'add_subexpressions_for_divisions',
'add_subexpressions_for_field_reads'] 'add_subexpressions_for_sums', 'add_subexpressions_for_field_reads']
...@@ -18,7 +18,7 @@ def sort_assignments_topologically(assignments: Sequence[Union[Assignment, Node] ...@@ -18,7 +18,7 @@ def sort_assignments_topologically(assignments: Sequence[Union[Assignment, Node]
elif isinstance(e1, Node): elif isinstance(e1, Node):
symbols = e1.symbols_defined symbols = e1.symbols_defined
else: else:
raise NotImplementedError("Cannot sort topologically. Object of type " + type(e1) + " cannot be handled.") raise NotImplementedError(f"Cannot sort topologically. Object of type {type(e1)} cannot be handled.")
for lhs in symbols: for lhs in symbols:
for c2, e2 in enumerate(assignments): for c2, e2 in enumerate(assignments):
...@@ -112,14 +112,14 @@ def add_subexpressions_for_sums(ac): ...@@ -112,14 +112,14 @@ def add_subexpressions_for_sums(ac):
addends = [] addends = []
def contains_sum(term): def contains_sum(term):
if term.func == sp.add.Add: if term.func == sp.Add:
return True return True
if term.is_Atom: if term.is_Atom:
return False return False
return any([contains_sum(a) for a in term.args]) return any([contains_sum(a) for a in term.args])
def search_addends(term): def search_addends(term):
if term.func == sp.add.Add: if term.func == sp.Add:
if all([not contains_sum(a) for a in term.args]): if all([not contains_sum(a) for a in term.args]):
addends.extend(term.args) addends.extend(term.args)
for a in term.args: for a in term.args:
......
import sympy as sp
from pystencils.simp import subexpression_substitution_in_main_assignments
from pystencils.simp import add_subexpressions_for_divisions
from pystencils.simp import add_subexpressions_for_sums
from pystencils.simp import add_subexpressions_for_field_reads
from pystencils import Assignment, AssignmentCollection, fields
a, b, c, d, x, y, z = sp.symbols("a b c d x y z")
s0, s1, s2, s3 = sp.symbols("s_:4")
f = sp.symbols("f_:9")
def test_subexpression_substitution_in_main_assignments():
subexpressions = [
Assignment(s0, 2 * a + 2 * b),
Assignment(s1, 2 * a + 2 * b + 2 * c),
Assignment(s2, 2 * a + 2 * b + 2 * c + 2 * d),
Assignment(s3, 2 * a + 2 * b * c),
Assignment(x, s1 + s2 + s0 + s3)
]
main = [
Assignment(f[0], s1 + s2 + s0 + s3),
Assignment(f[1], s1 + s2 + s0 + s3),
Assignment(f[2], s1 + s2 + s0 + s3),
Assignment(f[3], s1 + s2 + s0 + s3),
Assignment(f[4], s1 + s2 + s0 + s3)
]
ac = AssignmentCollection(main, subexpressions)
ac = subexpression_substitution_in_main_assignments(ac)
for i in range(0, len(ac.main_assignments)):
assert ac.main_assignments[i].rhs == x
def test_add_subexpressions_for_divisions():
subexpressions = [
Assignment(s0, 2 / a + 2 / b),
Assignment(s1, 2 / a + 2 / b + 2 / c),
Assignment(s2, 2 / a + 2 / b + 2 / c + 2 / d),
Assignment(s3, 2 / a + 2 / b / c),
Assignment(x, s1 + s2 + s0 + s3)
]
main = [
Assignment(f[0], s1 + s2 + s0 + s3)
]
ac = AssignmentCollection(main, subexpressions)
divs_before_optimisation = ac.operation_count["divs"]
ac = add_subexpressions_for_divisions(ac)
divs_after_optimisation = ac.operation_count["divs"]
assert divs_before_optimisation - divs_after_optimisation == 8
rhs = []
for i in range(len(ac.subexpressions)):
rhs.append(ac.subexpressions[i].rhs)
assert 1/a in rhs
assert 1/b in rhs
assert 1/c in rhs
assert 1/d in rhs
def test_add_subexpressions_for_sums():
subexpressions = [
Assignment(s0, a + b + c + d),
Assignment(s1, 3 * a * sp.sqrt(x) + 4 * b + c),
Assignment(s2, 3 * a * sp.sqrt(x) + 4 * b + c),
Assignment(s3, 3 * a * sp.sqrt(x) + 4 * b + c)
]
main = [
Assignment(f[0], s1 + s2 + s0 + s3)
]
ac = AssignmentCollection(main, subexpressions)
ops_before_optimisation = ac.operation_count
ac = add_subexpressions_for_sums(ac)
ops_after_optimisation = ac.operation_count
assert ops_after_optimisation["adds"] == ops_before_optimisation["adds"]
assert ops_after_optimisation["muls"] < ops_before_optimisation["muls"]
assert ops_after_optimisation["sqrts"] < ops_before_optimisation["sqrts"]
rhs = []
for i in range(len(ac.subexpressions)):
rhs.append(ac.subexpressions[i].rhs)
assert a + b + c + d in rhs
assert 3 * a * sp.sqrt(x) in rhs
def test_add_subexpressions_for_field_reads():
s, v = fields("s(5), v(5): double[2D]")
subexpressions = []
main = [
Assignment(s[0, 0](0), 3 * v[0, 0](0)),
Assignment(s[0, 0](1), 10 * v[0, 0](1))
]
ac = AssignmentCollection(main, subexpressions)
assert len(ac.subexpressions) == 0
ac = add_subexpressions_for_field_reads(ac)
assert len(ac.subexpressions) == 2
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment