Commit 46b7b88d authored by Martin Bauer's avatar Martin Bauer
Browse files

Merge branch 'vector-assignments' into 'master'

Allow vector assignments

Closes #17

See merge request pycodegen/pystencils!133
parents 75b3dad8 1c57a059
# -*- coding: utf-8 -*-
import numpy as np
import sympy as sp
from sympy.printing.latex import LatexPrinter
......@@ -24,9 +23,20 @@ def assignment_str(assignment):
if Assignment:
_old_new = sp.codegen.ast.Assignment.__new__
def _Assignment__new__(cls, lhs, rhs, *args, **kwargs):
if isinstance(lhs, (list, tuple, sp.Matrix)) and isinstance(rhs, (list, tuple, sp.Matrix)):
assert len(lhs) == len(rhs), f'{lhs} and {rhs} must have same length when performing vector assignment!'
return tuple(_old_new(cls, a, b, *args, **kwargs) for a, b in zip(lhs, rhs))
return _old_new(cls, lhs, rhs, *args, **kwargs)
Assignment.__str__ = assignment_str
Assignment.__new__ = _Assignment__new__
LatexPrinter._print_Assignment = print_assignment_latex
sp.MutableDenseMatrix.__hash__ = lambda self: hash(tuple(self))
else:
# back port for older sympy versions that don't have Assignment yet
......
import itertools
from copy import copy
from typing import Any, Dict, Iterable, Iterator, List, Optional, Sequence, Set, Union
......@@ -43,6 +44,11 @@ class AssignmentCollection:
subexpressions = [Assignment(k, v)
for k, v in subexpressions.items()]
main_assignments = list(itertools.chain.from_iterable(
[(a if isinstance(a, Iterable) else [a]) for a in main_assignments]))
subexpressions = list(itertools.chain.from_iterable(
[(a if isinstance(a, Iterable) else [a]) for a in subexpressions]))
self.main_assignments = main_assignments
self.subexpressions = subexpressions
......
import pytest
import sympy as sp
from pystencils import Assignment, AssignmentCollection
......@@ -40,3 +41,39 @@ def test_free_and_defined_symbols():
print(ac)
print(ac.__repr__)
def test_vector_assignments():
"""From #17 (https://i10git.cs.fau.de/pycodegen/pystencils/issues/17)"""
import pystencils as ps
import sympy as sp
a, b, c = sp.symbols("a b c")
assignments = ps.Assignment(sp.Matrix([a,b,c]), sp.Matrix([1,2,3]))
print(assignments)
def test_wrong_vector_assignments():
"""From #17 (https://i10git.cs.fau.de/pycodegen/pystencils/issues/17)"""
import pystencils as ps
import sympy as sp
a, b = sp.symbols("a b")
with pytest.raises(AssertionError,
match=r'Matrix(.*) and Matrix(.*) must have same length when performing vector assignment!'):
ps.Assignment(sp.Matrix([a,b]), sp.Matrix([1,2,3]))
def test_vector_assignment_collection():
"""From #17 (https://i10git.cs.fau.de/pycodegen/pystencils/issues/17)"""
import pystencils as ps
import sympy as sp
a, b, c = sp.symbols("a b c")
y, x = sp.Matrix([a,b,c]), sp.Matrix([1,2,3])
assignments = ps.AssignmentCollection({y: x})
print(assignments)
assignments = ps.AssignmentCollection([ps.Assignment(y,x)])
print(assignments)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment