diff --git a/pystencils/assignment.py b/pystencils/assignment.py index 2b0ef06d859f1218fbf9c861981f947457bf4dfd..92027f7773d2202765ab2ef09c5269654098060e 100644 --- a/pystencils/assignment.py +++ b/pystencils/assignment.py @@ -26,7 +26,8 @@ if Assignment: _old_new = sp.codegen.ast.Assignment.__new__ def _Assignment__new__(cls, lhs, rhs, *args, **kwargs): - if isinstance(lhs, (list, set, tuple, sp.Matrix)) and isinstance(rhs, (list, set, tuple, sp.Matrix)): + if isinstance(lhs, (list, tuple, sp.Matrix)) and isinstance(rhs, (list, tuple, sp.Matrix)): + assert len(lhs) == len(rhs), f'{lhs} and {rhs} must have same length when performing vector assignment!' return tuple(_old_new(cls, a, b, *args, **kwargs) for a, b in zip(lhs, rhs)) return _old_new(cls, lhs, rhs, *args, **kwargs) diff --git a/pystencils_tests/test_assignment_collection.py b/pystencils_tests/test_assignment_collection.py index f9bd8d55caba264af84b9e9be3d735ef34bf2c52..a16d51db44bb722ee1a9da9a20ad4192d4f6188e 100644 --- a/pystencils_tests/test_assignment_collection.py +++ b/pystencils_tests/test_assignment_collection.py @@ -1,3 +1,4 @@ +import pytest import sympy as sp from pystencils import Assignment, AssignmentCollection @@ -52,6 +53,18 @@ def test_vector_assignments(): print(assignments) +def test_wrong_vector_assignments(): + """From #17 (https://i10git.cs.fau.de/pycodegen/pystencils/issues/17)""" + + import pystencils as ps + import sympy as sp + a, b = sp.symbols("a b") + + with pytest.raises(AssertionError, + match=r'Matrix(.*) and Matrix(.*) must have same length when performing vector assignment!'): + ps.Assignment(sp.Matrix([a,b]), sp.Matrix([1,2,3])) + + def test_vector_assignment_collection(): """From #17 (https://i10git.cs.fau.de/pycodegen/pystencils/issues/17)""" @@ -64,4 +77,3 @@ def test_vector_assignment_collection(): assignments = ps.AssignmentCollection([ps.Assignment(y,x)]) print(assignments) -