Commit 144f46ed authored by Martin Bauer's avatar Martin Bauer
Browse files

Fix: public sp.Indexed instead of sp.tensor.Indexed

parent e4084e94
......@@ -38,11 +38,10 @@ else:
def __new__(cls, lhs, rhs=0, **assumptions):
from sympy.matrices.expressions.matexpr import (
MatrixElement, MatrixSymbol)
from sympy.tensor.indexed import Indexed
lhs = sp.sympify(lhs)
rhs = sp.sympify(rhs)
# Tuple of things that can be on the lhs of an assignment
assignable = (sp.Symbol, MatrixSymbol, MatrixElement, Indexed)
assignable = (sp.Symbol, MatrixSymbol, MatrixElement, sp.Indexed)
if not isinstance(lhs, assignable):
raise TypeError("Cannot assign to lhs of type %s." % type(lhs))
return sp.Rel.__new__(cls, lhs, rhs, **assumptions)
......
import sympy as sp
from sympy.tensor import IndexedBase
from pystencils.field import Field
from pystencils.data_types import TypedSymbol, create_type, cast_func
from pystencils.kernelparameters import FieldStrideSymbol, FieldPointerSymbol, FieldShapeSymbol
......@@ -543,8 +542,8 @@ class SympyAssignment(Node):
class ResolvedFieldAccess(sp.Indexed):
def __new__(cls, base, linearized_index, field, offsets, idx_coordinate_values):
if not isinstance(base, IndexedBase):
base = IndexedBase(base, shape=(1,))
if not isinstance(base, sp.IndexedBase):
base = sp.IndexedBase(base, shape=(1,))
obj = super(ResolvedFieldAccess, cls).__new__(cls, base, linearized_index)
obj.field = field
obj.offsets = offsets
......
......@@ -486,7 +486,7 @@ def count_operations(term: Union[sp.Expr, List[sp.Expr]],
pass
elif isinstance(t, sp.Symbol):
visit_children = False
elif isinstance(t, sp.tensor.Indexed):
elif isinstance(t, sp.Indexed):
visit_children = False
elif t.is_integer:
pass
......
......@@ -6,7 +6,6 @@ import pickle
import hashlib
import sympy as sp
from sympy.logic.boolalg import Boolean
from sympy.tensor import IndexedBase
from pystencils.simp.assignment_collection import AssignmentCollection
from pystencils.assignment import Assignment
from pystencils.field import AbstractField, FieldType, Field
......@@ -663,7 +662,8 @@ def split_inner_loop(ast_node: ast.Node, symbol_groups):
if not isinstance(symbol, AbstractField.AbstractAccess):
assert type(symbol) is TypedSymbol
new_ts = TypedSymbol(symbol.name, PointerType(symbol.dtype))
symbols_with_temporary_array[symbol] = IndexedBase(new_ts, shape=(1,))[inner_loop.loop_counter_symbol]
symbols_with_temporary_array[symbol] = sp.IndexedBase(new_ts,
shape=(1,))[inner_loop.loop_counter_symbol]
assignment_group = []
for assignment in inner_loop.body.args:
......@@ -672,7 +672,7 @@ def split_inner_loop(ast_node: ast.Node, symbol_groups):
if not isinstance(assignment.lhs, AbstractField.AbstractAccess) and assignment.lhs in symbol_group:
assert type(assignment.lhs) is TypedSymbol
new_ts = TypedSymbol(assignment.lhs.name, PointerType(assignment.lhs.dtype))
new_lhs = IndexedBase(new_ts, shape=(1,))[inner_loop.loop_counter_symbol]
new_lhs = sp.IndexedBase(new_ts, shape=(1,))[inner_loop.loop_counter_symbol]
else:
new_lhs = assignment.lhs
assignment_group.append(ast.SympyAssignment(new_lhs, new_rhs))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment