From 1c57a059949dfa17b6107e50ff33e9414517232d Mon Sep 17 00:00:00 2001
From: Stephan Seitz <stephan.seitz@fau.de>
Date: Sat, 18 Jan 2020 16:59:36 +0100
Subject: [PATCH] Assert same length when performing vector assignment

---
 pystencils/assignment.py                       |  3 ++-
 pystencils_tests/test_assignment_collection.py | 14 +++++++++++++-
 2 files changed, 15 insertions(+), 2 deletions(-)

diff --git a/pystencils/assignment.py b/pystencils/assignment.py
index 2b0ef06d8..92027f777 100644
--- a/pystencils/assignment.py
+++ b/pystencils/assignment.py
@@ -26,7 +26,8 @@ if Assignment:
     _old_new = sp.codegen.ast.Assignment.__new__
 
     def _Assignment__new__(cls, lhs, rhs, *args, **kwargs):
-        if isinstance(lhs, (list, set, tuple, sp.Matrix)) and isinstance(rhs, (list, set, tuple, sp.Matrix)):
+        if isinstance(lhs, (list, tuple, sp.Matrix)) and isinstance(rhs, (list, tuple, sp.Matrix)):
+            assert len(lhs) == len(rhs), f'{lhs} and {rhs} must have same length when performing vector assignment!'
             return tuple(_old_new(cls, a, b, *args, **kwargs) for a, b in zip(lhs, rhs))
         return _old_new(cls, lhs, rhs, *args, **kwargs)
 
diff --git a/pystencils_tests/test_assignment_collection.py b/pystencils_tests/test_assignment_collection.py
index f9bd8d55c..a16d51db4 100644
--- a/pystencils_tests/test_assignment_collection.py
+++ b/pystencils_tests/test_assignment_collection.py
@@ -1,3 +1,4 @@
+import pytest
 import sympy as sp
 
 from pystencils import Assignment, AssignmentCollection
@@ -52,6 +53,18 @@ def test_vector_assignments():
     print(assignments)
 
 
+def test_wrong_vector_assignments():
+    """From #17 (https://i10git.cs.fau.de/pycodegen/pystencils/issues/17)"""
+
+    import pystencils as ps
+    import sympy as sp
+    a, b = sp.symbols("a b")
+    
+    with pytest.raises(AssertionError,
+            match=r'Matrix(.*) and Matrix(.*) must have same length when performing vector assignment!'):
+        ps.Assignment(sp.Matrix([a,b]), sp.Matrix([1,2,3]))
+
+
 def test_vector_assignment_collection():
     """From #17 (https://i10git.cs.fau.de/pycodegen/pystencils/issues/17)"""
 
@@ -64,4 +77,3 @@ def test_vector_assignment_collection():
 
     assignments = ps.AssignmentCollection([ps.Assignment(y,x)])
     print(assignments)
-
-- 
GitLab