diff --git a/pystencils/simp/assignment_collection.py b/pystencils/simp/assignment_collection.py index 696038dd59f7faad83e74dbeecdfcfb038ea127a..6bd1c66021c129e35288101c24cdcca96fcebd46 100644 --- a/pystencils/simp/assignment_collection.py +++ b/pystencils/simp/assignment_collection.py @@ -6,8 +6,7 @@ import sympy as sp import pystencils from pystencils.assignment import Assignment -from pystencils.simp.simplifications import ( - sort_assignments_topologically, transform_lhs_and_rhs, transform_rhs) +from pystencils.simp.simplifications import (sort_assignments_topologically, transform_lhs_and_rhs, transform_rhs) from pystencils.sympyextensions import count_operations, fast_subs @@ -263,7 +262,7 @@ class AssignmentCollection: own_definitions = set([e.lhs for e in self.main_assignments]) other_definitions = set([e.lhs for e in other.main_assignments]) assert len(own_definitions.intersection(other_definitions)) == 0, \ - "Cannot new_merged, since both collection define the same symbols" + "Cannot merge collections, since both define the same symbols" own_subexpression_symbols = {e.lhs: e.rhs for e in self.subexpressions} substitution_dict = {} @@ -334,7 +333,7 @@ class AssignmentCollection: kept_subexpressions = [] if self.subexpressions[0].lhs in subexpressions_to_keep: substitution_dict = {} - kept_subexpressions = self.subexpressions[0] + kept_subexpressions.append(self.subexpressions[0]) else: substitution_dict = {self.subexpressions[0].lhs: self.subexpressions[0].rhs} diff --git a/pystencils_tests/test_assignment_collection.py b/pystencils_tests/test_assignment_collection.py index a16d51db44bb722ee1a9da9a20ad4192d4f6188e..f0c1f2a91e93739c1f40d8e2223641bd2f8b9bbb 100644 --- a/pystencils_tests/test_assignment_collection.py +++ b/pystencils_tests/test_assignment_collection.py @@ -1,15 +1,19 @@ import pytest import sympy as sp +import pystencils as ps from pystencils import Assignment, AssignmentCollection from pystencils.astnodes import Conditional from pystencils.simp.assignment_collection import SymbolGen +a, b, c = sp.symbols("a b c") +x, y, z, t = sp.symbols("x y z t") +symbol_gen = SymbolGen("a") +f = ps.fields("f(2) : [2D]") +d = ps.fields("d(2) : [2D]") -def test_assignment_collection(): - x, y, z, t = sp.symbols("x y z t") - symbol_gen = SymbolGen("a") +def test_assignment_collection(): ac = AssignmentCollection([Assignment(z, x + y)], [], subexpression_symbol_generator=symbol_gen) @@ -32,10 +36,6 @@ def test_assignment_collection(): def test_free_and_defined_symbols(): - x, y, z, t = sp.symbols("x y z t") - a, b = sp.symbols("a b") - symbol_gen = SymbolGen("a") - ac = AssignmentCollection([Assignment(z, x + y), Conditional(t > 0, Assignment(a, b+1), Assignment(a, b+2))], [], subexpression_symbol_generator=symbol_gen) @@ -45,35 +45,128 @@ def test_free_and_defined_symbols(): def test_vector_assignments(): """From #17 (https://i10git.cs.fau.de/pycodegen/pystencils/issues/17)""" - - import pystencils as ps - import sympy as sp - a, b, c = sp.symbols("a b c") - assignments = ps.Assignment(sp.Matrix([a,b,c]), sp.Matrix([1,2,3])) + assignments = ps.Assignment(sp.Matrix([a, b, c]), sp.Matrix([1, 2, 3])) print(assignments) def test_wrong_vector_assignments(): """From #17 (https://i10git.cs.fau.de/pycodegen/pystencils/issues/17)""" - - import pystencils as ps - import sympy as sp - a, b = sp.symbols("a b") - with pytest.raises(AssertionError, - match=r'Matrix(.*) and Matrix(.*) must have same length when performing vector assignment!'): - ps.Assignment(sp.Matrix([a,b]), sp.Matrix([1,2,3])) + match=r'Matrix(.*) and Matrix(.*) must have same length when performing vector assignment!'): + ps.Assignment(sp.Matrix([a, b]), sp.Matrix([1, 2, 3])) def test_vector_assignment_collection(): """From #17 (https://i10git.cs.fau.de/pycodegen/pystencils/issues/17)""" - import pystencils as ps - import sympy as sp - a, b, c = sp.symbols("a b c") - y, x = sp.Matrix([a,b,c]), sp.Matrix([1,2,3]) - assignments = ps.AssignmentCollection({y: x}) + y_m, x_m = sp.Matrix([a, b, c]), sp.Matrix([1, 2, 3]) + assignments = ps.AssignmentCollection({y_m: x_m}) print(assignments) - assignments = ps.AssignmentCollection([ps.Assignment(y,x)]) + assignments = ps.AssignmentCollection([ps.Assignment(y_m, x_m)]) print(assignments) + + +def test_new_with_substitutions(): + a1 = ps.Assignment(f[0, 0](0), a * b) + a2 = ps.Assignment(f[0, 0](1), b * c) + + ac = ps.AssignmentCollection([a1, a2], subexpressions=[]) + subs_dict = {f[0, 0](0): d[0, 0](0), f[0, 0](1): d[0, 0](1)} + subs_ac = ac.new_with_substitutions(subs_dict, + add_substitutions_as_subexpressions=False, + substitute_on_lhs=True, + sort_topologically=True) + + assert subs_ac.main_assignments[0].lhs == d[0, 0](0) + assert subs_ac.main_assignments[1].lhs == d[0, 0](1) + + subs_ac = ac.new_with_substitutions(subs_dict, + add_substitutions_as_subexpressions=False, + substitute_on_lhs=False, + sort_topologically=True) + + assert subs_ac.main_assignments[0].lhs == f[0, 0](0) + assert subs_ac.main_assignments[1].lhs == f[0, 0](1) + + subs_dict = {a * b: sp.symbols('xi')} + subs_ac = ac.new_with_substitutions(subs_dict, + add_substitutions_as_subexpressions=False, + substitute_on_lhs=False, + sort_topologically=True) + + assert subs_ac.main_assignments[0].rhs == sp.symbols('xi') + assert len(subs_ac.subexpressions) == 0 + + subs_ac = ac.new_with_substitutions(subs_dict, + add_substitutions_as_subexpressions=True, + substitute_on_lhs=False, + sort_topologically=True) + + assert subs_ac.main_assignments[0].rhs == sp.symbols('xi') + assert len(subs_ac.subexpressions) == 1 + assert subs_ac.subexpressions[0].lhs == sp.symbols('xi') + + +def test_copy(): + a1 = ps.Assignment(f[0, 0](0), a * b) + a2 = ps.Assignment(f[0, 0](1), b * c) + + ac = ps.AssignmentCollection([a1, a2], subexpressions=[]) + ac2 = ac.copy() + assert ac2 == ac + + +def test_set_expressions(): + a1 = ps.Assignment(f[0, 0](0), a * b) + a2 = ps.Assignment(f[0, 0](1), b * c) + + ac = ps.AssignmentCollection([a1, a2], subexpressions=[]) + + ac.set_main_assignments_from_dict({d[0, 0](0): b * c}) + assert len(ac.main_assignments) == 1 + assert ac.main_assignments[0] == ps.Assignment(d[0, 0](0), b * c) + + ac.set_sub_expressions_from_dict({sp.symbols('xi'): a * b}) + assert len(ac.subexpressions) == 1 + assert ac.subexpressions[0] == ps.Assignment(sp.symbols('xi'), a * b) + + ac = ac.new_without_subexpressions(subexpressions_to_keep={sp.symbols('xi')}) + assert ac.subexpressions[0] == ps.Assignment(sp.symbols('xi'), a * b) + + ac = ac.new_without_unused_subexpressions() + assert len(ac.subexpressions) == 0 + + ac2 = ac.new_without_subexpressions() + assert ac == ac2 + + +def test_free_and_bound_symbols(): + a1 = ps.Assignment(a, d[0, 0](0)) + a2 = ps.Assignment(f[0, 0](1), b * c) + + ac = ps.AssignmentCollection([a2], subexpressions=[a1]) + assert f[0, 0](1) in ac.bound_symbols + assert d[0, 0](0) in ac.free_symbols + + +def test_new_merged(): + a1 = ps.Assignment(a, b * c) + a2 = ps.Assignment(a, x * y) + a3 = ps.Assignment(t, x ** 2) + + # main assignments + a4 = ps.Assignment(f[0, 0](0), a) + a5 = ps.Assignment(d[0, 0](0), a) + + ac = ps.AssignmentCollection([a4], subexpressions=[a1]) + ac2 = ps.AssignmentCollection([a5], subexpressions=[a2, a3]) + + merged_ac = ac.new_merged(ac2) + + assert len(merged_ac.subexpressions) == 3 + assert len(merged_ac.main_assignments) == 2 + assert ps.Assignment(sp.symbols('xi_0'), x * y) in merged_ac.subexpressions + assert ps.Assignment(d[0, 0](0), sp.symbols('xi_0')) in merged_ac.main_assignments + assert a1 in merged_ac.subexpressions + assert a3 in merged_ac.subexpressions