diff --git a/pystencils/assignment.py b/pystencils/assignment.py index 0bf68799491be29886d215d5ff76010c034bd174..2b0ef06d859f1218fbf9c861981f947457bf4dfd 100644 --- a/pystencils/assignment.py +++ b/pystencils/assignment.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- import numpy as np import sympy as sp from sympy.printing.latex import LatexPrinter @@ -24,9 +23,19 @@ def assignment_str(assignment): if Assignment: + _old_new = sp.codegen.ast.Assignment.__new__ + + def _Assignment__new__(cls, lhs, rhs, *args, **kwargs): + if isinstance(lhs, (list, set, tuple, sp.Matrix)) and isinstance(rhs, (list, set, tuple, sp.Matrix)): + return tuple(_old_new(cls, a, b, *args, **kwargs) for a, b in zip(lhs, rhs)) + return _old_new(cls, lhs, rhs, *args, **kwargs) + Assignment.__str__ = assignment_str + Assignment.__new__ = _Assignment__new__ LatexPrinter._print_Assignment = print_assignment_latex + sp.MutableDenseMatrix.__hash__ = lambda self: hash(tuple(self)) + else: # back port for older sympy versions that don't have Assignment yet diff --git a/pystencils/simp/assignment_collection.py b/pystencils/simp/assignment_collection.py index 3a8bb2bd3b315e23fb9e11ef6385f264428c7bd5..22968eb72361b15dd24a0ebc72fd284ddeddd2b7 100644 --- a/pystencils/simp/assignment_collection.py +++ b/pystencils/simp/assignment_collection.py @@ -1,3 +1,4 @@ +import itertools from copy import copy from typing import Any, Dict, Iterable, Iterator, List, Optional, Sequence, Set, Union @@ -43,6 +44,11 @@ class AssignmentCollection: subexpressions = [Assignment(k, v) for k, v in subexpressions.items()] + main_assignments = list(itertools.chain.from_iterable( + [(a if isinstance(a, Iterable) else [a]) for a in main_assignments])) + subexpressions = list(itertools.chain.from_iterable( + [(a if isinstance(a, Iterable) else [a]) for a in subexpressions])) + self.main_assignments = main_assignments self.subexpressions = subexpressions diff --git a/pystencils_tests/test_assignment_collection.py b/pystencils_tests/test_assignment_collection.py index 42e3791ab331cc111343bfe410bc59e3f52ba80b..f9bd8d55caba264af84b9e9be3d735ef34bf2c52 100644 --- a/pystencils_tests/test_assignment_collection.py +++ b/pystencils_tests/test_assignment_collection.py @@ -40,3 +40,28 @@ def test_free_and_defined_symbols(): print(ac) print(ac.__repr__) + + +def test_vector_assignments(): + """From #17 (https://i10git.cs.fau.de/pycodegen/pystencils/issues/17)""" + + import pystencils as ps + import sympy as sp + a, b, c = sp.symbols("a b c") + assignments = ps.Assignment(sp.Matrix([a,b,c]), sp.Matrix([1,2,3])) + print(assignments) + + +def test_vector_assignment_collection(): + """From #17 (https://i10git.cs.fau.de/pycodegen/pystencils/issues/17)""" + + import pystencils as ps + import sympy as sp + a, b, c = sp.symbols("a b c") + y, x = sp.Matrix([a,b,c]), sp.Matrix([1,2,3]) + assignments = ps.AssignmentCollection({y: x}) + print(assignments) + + assignments = ps.AssignmentCollection([ps.Assignment(y,x)]) + print(assignments) +