From 1c57a059949dfa17b6107e50ff33e9414517232d Mon Sep 17 00:00:00 2001 From: Stephan Seitz <stephan.seitz@fau.de> Date: Sat, 18 Jan 2020 16:59:36 +0100 Subject: [PATCH] Assert same length when performing vector assignment --- pystencils/assignment.py | 3 ++- pystencils_tests/test_assignment_collection.py | 14 +++++++++++++- 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/pystencils/assignment.py b/pystencils/assignment.py index 2b0ef06d8..92027f777 100644 --- a/pystencils/assignment.py +++ b/pystencils/assignment.py @@ -26,7 +26,8 @@ if Assignment: _old_new = sp.codegen.ast.Assignment.__new__ def _Assignment__new__(cls, lhs, rhs, *args, **kwargs): - if isinstance(lhs, (list, set, tuple, sp.Matrix)) and isinstance(rhs, (list, set, tuple, sp.Matrix)): + if isinstance(lhs, (list, tuple, sp.Matrix)) and isinstance(rhs, (list, tuple, sp.Matrix)): + assert len(lhs) == len(rhs), f'{lhs} and {rhs} must have same length when performing vector assignment!' return tuple(_old_new(cls, a, b, *args, **kwargs) for a, b in zip(lhs, rhs)) return _old_new(cls, lhs, rhs, *args, **kwargs) diff --git a/pystencils_tests/test_assignment_collection.py b/pystencils_tests/test_assignment_collection.py index f9bd8d55c..a16d51db4 100644 --- a/pystencils_tests/test_assignment_collection.py +++ b/pystencils_tests/test_assignment_collection.py @@ -1,3 +1,4 @@ +import pytest import sympy as sp from pystencils import Assignment, AssignmentCollection @@ -52,6 +53,18 @@ def test_vector_assignments(): print(assignments) +def test_wrong_vector_assignments(): + """From #17 (https://i10git.cs.fau.de/pycodegen/pystencils/issues/17)""" + + import pystencils as ps + import sympy as sp + a, b = sp.symbols("a b") + + with pytest.raises(AssertionError, + match=r'Matrix(.*) and Matrix(.*) must have same length when performing vector assignment!'): + ps.Assignment(sp.Matrix([a,b]), sp.Matrix([1,2,3])) + + def test_vector_assignment_collection(): """From #17 (https://i10git.cs.fau.de/pycodegen/pystencils/issues/17)""" @@ -64,4 +77,3 @@ def test_vector_assignment_collection(): assignments = ps.AssignmentCollection([ps.Assignment(y,x)]) print(assignments) - -- GitLab