Commit 144f46ed authored by Martin Bauer's avatar Martin Bauer
Browse files

Fix: public sp.Indexed instead of sp.tensor.Indexed

parent e4084e94
...@@ -38,11 +38,10 @@ else: ...@@ -38,11 +38,10 @@ else:
def __new__(cls, lhs, rhs=0, **assumptions): def __new__(cls, lhs, rhs=0, **assumptions):
from sympy.matrices.expressions.matexpr import ( from sympy.matrices.expressions.matexpr import (
MatrixElement, MatrixSymbol) MatrixElement, MatrixSymbol)
from sympy.tensor.indexed import Indexed
lhs = sp.sympify(lhs) lhs = sp.sympify(lhs)
rhs = sp.sympify(rhs) rhs = sp.sympify(rhs)
# Tuple of things that can be on the lhs of an assignment # Tuple of things that can be on the lhs of an assignment
assignable = (sp.Symbol, MatrixSymbol, MatrixElement, Indexed) assignable = (sp.Symbol, MatrixSymbol, MatrixElement, sp.Indexed)
if not isinstance(lhs, assignable): if not isinstance(lhs, assignable):
raise TypeError("Cannot assign to lhs of type %s." % type(lhs)) raise TypeError("Cannot assign to lhs of type %s." % type(lhs))
return sp.Rel.__new__(cls, lhs, rhs, **assumptions) return sp.Rel.__new__(cls, lhs, rhs, **assumptions)
......
import sympy as sp import sympy as sp
from sympy.tensor import IndexedBase
from pystencils.field import Field from pystencils.field import Field
from pystencils.data_types import TypedSymbol, create_type, cast_func from pystencils.data_types import TypedSymbol, create_type, cast_func
from pystencils.kernelparameters import FieldStrideSymbol, FieldPointerSymbol, FieldShapeSymbol from pystencils.kernelparameters import FieldStrideSymbol, FieldPointerSymbol, FieldShapeSymbol
...@@ -543,8 +542,8 @@ class SympyAssignment(Node): ...@@ -543,8 +542,8 @@ class SympyAssignment(Node):
class ResolvedFieldAccess(sp.Indexed): class ResolvedFieldAccess(sp.Indexed):
def __new__(cls, base, linearized_index, field, offsets, idx_coordinate_values): def __new__(cls, base, linearized_index, field, offsets, idx_coordinate_values):
if not isinstance(base, IndexedBase): if not isinstance(base, sp.IndexedBase):
base = IndexedBase(base, shape=(1,)) base = sp.IndexedBase(base, shape=(1,))
obj = super(ResolvedFieldAccess, cls).__new__(cls, base, linearized_index) obj = super(ResolvedFieldAccess, cls).__new__(cls, base, linearized_index)
obj.field = field obj.field = field
obj.offsets = offsets obj.offsets = offsets
......
...@@ -486,7 +486,7 @@ def count_operations(term: Union[sp.Expr, List[sp.Expr]], ...@@ -486,7 +486,7 @@ def count_operations(term: Union[sp.Expr, List[sp.Expr]],
pass pass
elif isinstance(t, sp.Symbol): elif isinstance(t, sp.Symbol):
visit_children = False visit_children = False
elif isinstance(t, sp.tensor.Indexed): elif isinstance(t, sp.Indexed):
visit_children = False visit_children = False
elif t.is_integer: elif t.is_integer:
pass pass
......
...@@ -6,7 +6,6 @@ import pickle ...@@ -6,7 +6,6 @@ import pickle
import hashlib import hashlib
import sympy as sp import sympy as sp
from sympy.logic.boolalg import Boolean from sympy.logic.boolalg import Boolean
from sympy.tensor import IndexedBase
from pystencils.simp.assignment_collection import AssignmentCollection from pystencils.simp.assignment_collection import AssignmentCollection
from pystencils.assignment import Assignment from pystencils.assignment import Assignment
from pystencils.field import AbstractField, FieldType, Field from pystencils.field import AbstractField, FieldType, Field
...@@ -663,7 +662,8 @@ def split_inner_loop(ast_node: ast.Node, symbol_groups): ...@@ -663,7 +662,8 @@ def split_inner_loop(ast_node: ast.Node, symbol_groups):
if not isinstance(symbol, AbstractField.AbstractAccess): if not isinstance(symbol, AbstractField.AbstractAccess):
assert type(symbol) is TypedSymbol assert type(symbol) is TypedSymbol
new_ts = TypedSymbol(symbol.name, PointerType(symbol.dtype)) new_ts = TypedSymbol(symbol.name, PointerType(symbol.dtype))
symbols_with_temporary_array[symbol] = IndexedBase(new_ts, shape=(1,))[inner_loop.loop_counter_symbol] symbols_with_temporary_array[symbol] = sp.IndexedBase(new_ts,
shape=(1,))[inner_loop.loop_counter_symbol]
assignment_group = [] assignment_group = []
for assignment in inner_loop.body.args: for assignment in inner_loop.body.args:
...@@ -672,7 +672,7 @@ def split_inner_loop(ast_node: ast.Node, symbol_groups): ...@@ -672,7 +672,7 @@ def split_inner_loop(ast_node: ast.Node, symbol_groups):
if not isinstance(assignment.lhs, AbstractField.AbstractAccess) and assignment.lhs in symbol_group: if not isinstance(assignment.lhs, AbstractField.AbstractAccess) and assignment.lhs in symbol_group:
assert type(assignment.lhs) is TypedSymbol assert type(assignment.lhs) is TypedSymbol
new_ts = TypedSymbol(assignment.lhs.name, PointerType(assignment.lhs.dtype)) new_ts = TypedSymbol(assignment.lhs.name, PointerType(assignment.lhs.dtype))
new_lhs = IndexedBase(new_ts, shape=(1,))[inner_loop.loop_counter_symbol] new_lhs = sp.IndexedBase(new_ts, shape=(1,))[inner_loop.loop_counter_symbol]
else: else:
new_lhs = assignment.lhs new_lhs = assignment.lhs
assignment_group.append(ast.SympyAssignment(new_lhs, new_rhs)) assignment_group.append(ast.SympyAssignment(new_lhs, new_rhs))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment