From 38da1c39b1a07d57746da5e17b592810d9e3f2e0 Mon Sep 17 00:00:00 2001
From: Stephan Seitz <stephan.seitz@fau.de>
Date: Sat, 18 Jan 2020 16:24:35 +0100
Subject: [PATCH] Fix #17: Allow vector assignments

---
 pystencils/assignment.py                      | 11 +++++++-
 pystencils/simp/assignment_collection.py      |  6 +++++
 .../test_assignment_collection.py             | 25 +++++++++++++++++++
 3 files changed, 41 insertions(+), 1 deletion(-)

diff --git a/pystencils/assignment.py b/pystencils/assignment.py
index 0bf687994..2b0ef06d8 100644
--- a/pystencils/assignment.py
+++ b/pystencils/assignment.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
 import numpy as np
 import sympy as sp
 from sympy.printing.latex import LatexPrinter
@@ -24,9 +23,19 @@ def assignment_str(assignment):
 
 if Assignment:
 
+    _old_new = sp.codegen.ast.Assignment.__new__
+
+    def _Assignment__new__(cls, lhs, rhs, *args, **kwargs):
+        if isinstance(lhs, (list, set, tuple, sp.Matrix)) and isinstance(rhs, (list, set, tuple, sp.Matrix)):
+            return tuple(_old_new(cls, a, b, *args, **kwargs) for a, b in zip(lhs, rhs))
+        return _old_new(cls, lhs, rhs, *args, **kwargs)
+
     Assignment.__str__ = assignment_str
+    Assignment.__new__ = _Assignment__new__
     LatexPrinter._print_Assignment = print_assignment_latex
 
+    sp.MutableDenseMatrix.__hash__ = lambda self: hash(tuple(self))
+
 else:
     # back port for older sympy versions that don't have Assignment  yet
 
diff --git a/pystencils/simp/assignment_collection.py b/pystencils/simp/assignment_collection.py
index 3a8bb2bd3..22968eb72 100644
--- a/pystencils/simp/assignment_collection.py
+++ b/pystencils/simp/assignment_collection.py
@@ -1,3 +1,4 @@
+import itertools
 from copy import copy
 from typing import Any, Dict, Iterable, Iterator, List, Optional, Sequence, Set, Union
 
@@ -43,6 +44,11 @@ class AssignmentCollection:
             subexpressions = [Assignment(k, v)
                               for k, v in subexpressions.items()]
 
+        main_assignments = list(itertools.chain.from_iterable(
+            [(a if isinstance(a, Iterable) else [a]) for a in main_assignments]))
+        subexpressions = list(itertools.chain.from_iterable(
+            [(a if isinstance(a, Iterable) else [a]) for a in subexpressions]))
+
         self.main_assignments = main_assignments
         self.subexpressions = subexpressions
 
diff --git a/pystencils_tests/test_assignment_collection.py b/pystencils_tests/test_assignment_collection.py
index 42e3791ab..f9bd8d55c 100644
--- a/pystencils_tests/test_assignment_collection.py
+++ b/pystencils_tests/test_assignment_collection.py
@@ -40,3 +40,28 @@ def test_free_and_defined_symbols():
 
     print(ac)
     print(ac.__repr__)
+
+
+def test_vector_assignments():
+    """From #17 (https://i10git.cs.fau.de/pycodegen/pystencils/issues/17)"""
+
+    import pystencils as ps
+    import sympy as sp
+    a, b, c = sp.symbols("a b c")
+    assignments = ps.Assignment(sp.Matrix([a,b,c]), sp.Matrix([1,2,3]))
+    print(assignments)
+
+
+def test_vector_assignment_collection():
+    """From #17 (https://i10git.cs.fau.de/pycodegen/pystencils/issues/17)"""
+
+    import pystencils as ps
+    import sympy as sp
+    a, b, c = sp.symbols("a b c")
+    y, x = sp.Matrix([a,b,c]), sp.Matrix([1,2,3])
+    assignments = ps.AssignmentCollection({y: x})
+    print(assignments)
+
+    assignments = ps.AssignmentCollection([ps.Assignment(y,x)])
+    print(assignments)
+
-- 
GitLab