Commit 24ef1d2f authored by Markus Holzer's avatar Markus Holzer
Browse files

Added test cases for Assignmentcollection

parent fd4d1bc0
...@@ -6,8 +6,7 @@ import sympy as sp ...@@ -6,8 +6,7 @@ import sympy as sp
import pystencils import pystencils
from pystencils.assignment import Assignment from pystencils.assignment import Assignment
from pystencils.simp.simplifications import ( from pystencils.simp.simplifications import (sort_assignments_topologically, transform_lhs_and_rhs, transform_rhs)
sort_assignments_topologically, transform_lhs_and_rhs, transform_rhs)
from pystencils.sympyextensions import count_operations, fast_subs from pystencils.sympyextensions import count_operations, fast_subs
...@@ -263,7 +262,7 @@ class AssignmentCollection: ...@@ -263,7 +262,7 @@ class AssignmentCollection:
own_definitions = set([e.lhs for e in self.main_assignments]) own_definitions = set([e.lhs for e in self.main_assignments])
other_definitions = set([e.lhs for e in other.main_assignments]) other_definitions = set([e.lhs for e in other.main_assignments])
assert len(own_definitions.intersection(other_definitions)) == 0, \ assert len(own_definitions.intersection(other_definitions)) == 0, \
"Cannot new_merged, since both collection define the same symbols" "Cannot merge collections, since both define the same symbols"
own_subexpression_symbols = {e.lhs: e.rhs for e in self.subexpressions} own_subexpression_symbols = {e.lhs: e.rhs for e in self.subexpressions}
substitution_dict = {} substitution_dict = {}
...@@ -334,7 +333,7 @@ class AssignmentCollection: ...@@ -334,7 +333,7 @@ class AssignmentCollection:
kept_subexpressions = [] kept_subexpressions = []
if self.subexpressions[0].lhs in subexpressions_to_keep: if self.subexpressions[0].lhs in subexpressions_to_keep:
substitution_dict = {} substitution_dict = {}
kept_subexpressions = self.subexpressions[0] kept_subexpressions.append(self.subexpressions[0])
else: else:
substitution_dict = {self.subexpressions[0].lhs: self.subexpressions[0].rhs} substitution_dict = {self.subexpressions[0].lhs: self.subexpressions[0].rhs}
......
import pytest import pytest
import sympy as sp import sympy as sp
import pystencils as ps
from pystencils import Assignment, AssignmentCollection from pystencils import Assignment, AssignmentCollection
from pystencils.astnodes import Conditional from pystencils.astnodes import Conditional
from pystencils.simp.assignment_collection import SymbolGen from pystencils.simp.assignment_collection import SymbolGen
a, b, c = sp.symbols("a b c")
x, y, z, t = sp.symbols("x y z t")
symbol_gen = SymbolGen("a")
f = ps.fields("f(2) : [2D]")
d = ps.fields("d(2) : [2D]")
def test_assignment_collection():
x, y, z, t = sp.symbols("x y z t")
symbol_gen = SymbolGen("a")
def test_assignment_collection():
ac = AssignmentCollection([Assignment(z, x + y)], ac = AssignmentCollection([Assignment(z, x + y)],
[], subexpression_symbol_generator=symbol_gen) [], subexpression_symbol_generator=symbol_gen)
...@@ -32,10 +36,6 @@ def test_assignment_collection(): ...@@ -32,10 +36,6 @@ def test_assignment_collection():
def test_free_and_defined_symbols(): def test_free_and_defined_symbols():
x, y, z, t = sp.symbols("x y z t")
a, b = sp.symbols("a b")
symbol_gen = SymbolGen("a")
ac = AssignmentCollection([Assignment(z, x + y), Conditional(t > 0, Assignment(a, b+1), Assignment(a, b+2))], ac = AssignmentCollection([Assignment(z, x + y), Conditional(t > 0, Assignment(a, b+1), Assignment(a, b+2))],
[], subexpression_symbol_generator=symbol_gen) [], subexpression_symbol_generator=symbol_gen)
...@@ -45,35 +45,128 @@ def test_free_and_defined_symbols(): ...@@ -45,35 +45,128 @@ def test_free_and_defined_symbols():
def test_vector_assignments(): def test_vector_assignments():
"""From #17 (https://i10git.cs.fau.de/pycodegen/pystencils/issues/17)""" """From #17 (https://i10git.cs.fau.de/pycodegen/pystencils/issues/17)"""
assignments = ps.Assignment(sp.Matrix([a, b, c]), sp.Matrix([1, 2, 3]))
import pystencils as ps
import sympy as sp
a, b, c = sp.symbols("a b c")
assignments = ps.Assignment(sp.Matrix([a,b,c]), sp.Matrix([1,2,3]))
print(assignments) print(assignments)
def test_wrong_vector_assignments(): def test_wrong_vector_assignments():
"""From #17 (https://i10git.cs.fau.de/pycodegen/pystencils/issues/17)""" """From #17 (https://i10git.cs.fau.de/pycodegen/pystencils/issues/17)"""
import pystencils as ps
import sympy as sp
a, b = sp.symbols("a b")
with pytest.raises(AssertionError, with pytest.raises(AssertionError,
match=r'Matrix(.*) and Matrix(.*) must have same length when performing vector assignment!'): match=r'Matrix(.*) and Matrix(.*) must have same length when performing vector assignment!'):
ps.Assignment(sp.Matrix([a,b]), sp.Matrix([1,2,3])) ps.Assignment(sp.Matrix([a, b]), sp.Matrix([1, 2, 3]))
def test_vector_assignment_collection(): def test_vector_assignment_collection():
"""From #17 (https://i10git.cs.fau.de/pycodegen/pystencils/issues/17)""" """From #17 (https://i10git.cs.fau.de/pycodegen/pystencils/issues/17)"""
import pystencils as ps y_m, x_m = sp.Matrix([a, b, c]), sp.Matrix([1, 2, 3])
import sympy as sp assignments = ps.AssignmentCollection({y_m: x_m})
a, b, c = sp.symbols("a b c")
y, x = sp.Matrix([a,b,c]), sp.Matrix([1,2,3])
assignments = ps.AssignmentCollection({y: x})
print(assignments) print(assignments)
assignments = ps.AssignmentCollection([ps.Assignment(y,x)]) assignments = ps.AssignmentCollection([ps.Assignment(y_m, x_m)])
print(assignments) print(assignments)
def test_new_with_substitutions():
a1 = ps.Assignment(f[0, 0](0), a * b)
a2 = ps.Assignment(f[0, 0](1), b * c)
ac = ps.AssignmentCollection([a1, a2], subexpressions=[])
subs_dict = {f[0, 0](0): d[0, 0](0), f[0, 0](1): d[0, 0](1)}
subs_ac = ac.new_with_substitutions(subs_dict,
add_substitutions_as_subexpressions=False,
substitute_on_lhs=True,
sort_topologically=True)
assert subs_ac.main_assignments[0].lhs == d[0, 0](0)
assert subs_ac.main_assignments[1].lhs == d[0, 0](1)
subs_ac = ac.new_with_substitutions(subs_dict,
add_substitutions_as_subexpressions=False,
substitute_on_lhs=False,
sort_topologically=True)
assert subs_ac.main_assignments[0].lhs == f[0, 0](0)
assert subs_ac.main_assignments[1].lhs == f[0, 0](1)
subs_dict = {a * b: sp.symbols('xi')}
subs_ac = ac.new_with_substitutions(subs_dict,
add_substitutions_as_subexpressions=False,
substitute_on_lhs=False,
sort_topologically=True)
assert subs_ac.main_assignments[0].rhs == sp.symbols('xi')
assert len(subs_ac.subexpressions) == 0
subs_ac = ac.new_with_substitutions(subs_dict,
add_substitutions_as_subexpressions=True,
substitute_on_lhs=False,
sort_topologically=True)
assert subs_ac.main_assignments[0].rhs == sp.symbols('xi')
assert len(subs_ac.subexpressions) == 1
assert subs_ac.subexpressions[0].lhs == sp.symbols('xi')
def test_copy():
a1 = ps.Assignment(f[0, 0](0), a * b)
a2 = ps.Assignment(f[0, 0](1), b * c)
ac = ps.AssignmentCollection([a1, a2], subexpressions=[])
ac2 = ac.copy()
assert ac2 == ac
def test_set_expressions():
a1 = ps.Assignment(f[0, 0](0), a * b)
a2 = ps.Assignment(f[0, 0](1), b * c)
ac = ps.AssignmentCollection([a1, a2], subexpressions=[])
ac.set_main_assignments_from_dict({d[0, 0](0): b * c})
assert len(ac.main_assignments) == 1
assert ac.main_assignments[0] == ps.Assignment(d[0, 0](0), b * c)
ac.set_sub_expressions_from_dict({sp.symbols('xi'): a * b})
assert len(ac.subexpressions) == 1
assert ac.subexpressions[0] == ps.Assignment(sp.symbols('xi'), a * b)
ac = ac.new_without_subexpressions(subexpressions_to_keep={sp.symbols('xi')})
assert ac.subexpressions[0] == ps.Assignment(sp.symbols('xi'), a * b)
ac = ac.new_without_unused_subexpressions()
assert len(ac.subexpressions) == 0
ac2 = ac.new_without_subexpressions()
assert ac == ac2
def test_free_and_bound_symbols():
a1 = ps.Assignment(a, d[0, 0](0))
a2 = ps.Assignment(f[0, 0](1), b * c)
ac = ps.AssignmentCollection([a2], subexpressions=[a1])
assert f[0, 0](1) in ac.bound_symbols
assert d[0, 0](0) in ac.free_symbols
def test_new_merged():
a1 = ps.Assignment(a, b * c)
a2 = ps.Assignment(a, x * y)
a3 = ps.Assignment(t, x ** 2)
# main assignments
a4 = ps.Assignment(f[0, 0](0), a)
a5 = ps.Assignment(d[0, 0](0), a)
ac = ps.AssignmentCollection([a4], subexpressions=[a1])
ac2 = ps.AssignmentCollection([a5], subexpressions=[a2, a3])
merged_ac = ac.new_merged(ac2)
assert len(merged_ac.subexpressions) == 3
assert len(merged_ac.main_assignments) == 2
assert ps.Assignment(sp.symbols('xi_0'), x * y) in merged_ac.subexpressions
assert ps.Assignment(d[0, 0](0), sp.symbols('xi_0')) in merged_ac.main_assignments
assert a1 in merged_ac.subexpressions
assert a3 in merged_ac.subexpressions
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment