Commit 38da1c39 authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Fix #17: Allow vector assignments

parent f16225d5
Pipeline #21149 passed with stage
in 2 minutes and 25 seconds
# -*- coding: utf-8 -*-
import numpy as np
import sympy as sp
from sympy.printing.latex import LatexPrinter
......@@ -24,9 +23,19 @@ def assignment_str(assignment):
if Assignment:
_old_new = sp.codegen.ast.Assignment.__new__
def _Assignment__new__(cls, lhs, rhs, *args, **kwargs):
if isinstance(lhs, (list, set, tuple, sp.Matrix)) and isinstance(rhs, (list, set, tuple, sp.Matrix)):
return tuple(_old_new(cls, a, b, *args, **kwargs) for a, b in zip(lhs, rhs))
return _old_new(cls, lhs, rhs, *args, **kwargs)
Assignment.__str__ = assignment_str
Assignment.__new__ = _Assignment__new__
LatexPrinter._print_Assignment = print_assignment_latex
sp.MutableDenseMatrix.__hash__ = lambda self: hash(tuple(self))
else:
# back port for older sympy versions that don't have Assignment yet
......
import itertools
from copy import copy
from typing import Any, Dict, Iterable, Iterator, List, Optional, Sequence, Set, Union
......@@ -43,6 +44,11 @@ class AssignmentCollection:
subexpressions = [Assignment(k, v)
for k, v in subexpressions.items()]
main_assignments = list(itertools.chain.from_iterable(
[(a if isinstance(a, Iterable) else [a]) for a in main_assignments]))
subexpressions = list(itertools.chain.from_iterable(
[(a if isinstance(a, Iterable) else [a]) for a in subexpressions]))
self.main_assignments = main_assignments
self.subexpressions = subexpressions
......
......@@ -40,3 +40,28 @@ def test_free_and_defined_symbols():
print(ac)
print(ac.__repr__)
def test_vector_assignments():
"""From #17 (https://i10git.cs.fau.de/pycodegen/pystencils/issues/17)"""
import pystencils as ps
import sympy as sp
a, b, c = sp.symbols("a b c")
assignments = ps.Assignment(sp.Matrix([a,b,c]), sp.Matrix([1,2,3]))
print(assignments)
def test_vector_assignment_collection():
"""From #17 (https://i10git.cs.fau.de/pycodegen/pystencils/issues/17)"""
import pystencils as ps
import sympy as sp
a, b, c = sp.symbols("a b c")
y, x = sp.Matrix([a,b,c]), sp.Matrix([1,2,3])
assignments = ps.AssignmentCollection({y: x})
print(assignments)
assignments = ps.AssignmentCollection([ps.Assignment(y,x)])
print(assignments)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment