From 144f46edd272bef44c39dbf5add20cf94624c484 Mon Sep 17 00:00:00 2001
From: Martin Bauer <martin.bauer@fau.de>
Date: Fri, 3 May 2019 10:21:48 +0200
Subject: [PATCH] Fix: public sp.Indexed instead of sp.tensor.Indexed

---
 pystencils/assignment.py      | 3 +--
 pystencils/astnodes.py        | 5 ++---
 pystencils/sympyextensions.py | 2 +-
 pystencils/transformations.py | 6 +++---
 4 files changed, 7 insertions(+), 9 deletions(-)

diff --git a/pystencils/assignment.py b/pystencils/assignment.py
index 9900a24e0..757413abd 100644
--- a/pystencils/assignment.py
+++ b/pystencils/assignment.py
@@ -38,11 +38,10 @@ else:
         def __new__(cls, lhs, rhs=0, **assumptions):
             from sympy.matrices.expressions.matexpr import (
                 MatrixElement, MatrixSymbol)
-            from sympy.tensor.indexed import Indexed
             lhs = sp.sympify(lhs)
             rhs = sp.sympify(rhs)
             # Tuple of things that can be on the lhs of an assignment
-            assignable = (sp.Symbol, MatrixSymbol, MatrixElement, Indexed)
+            assignable = (sp.Symbol, MatrixSymbol, MatrixElement, sp.Indexed)
             if not isinstance(lhs, assignable):
                 raise TypeError("Cannot assign to lhs of type %s." % type(lhs))
             return sp.Rel.__new__(cls, lhs, rhs, **assumptions)
diff --git a/pystencils/astnodes.py b/pystencils/astnodes.py
index 727539373..c08c13789 100644
--- a/pystencils/astnodes.py
+++ b/pystencils/astnodes.py
@@ -1,5 +1,4 @@
 import sympy as sp
-from sympy.tensor import IndexedBase
 from pystencils.field import Field
 from pystencils.data_types import TypedSymbol, create_type, cast_func
 from pystencils.kernelparameters import FieldStrideSymbol, FieldPointerSymbol, FieldShapeSymbol
@@ -543,8 +542,8 @@ class SympyAssignment(Node):
 
 class ResolvedFieldAccess(sp.Indexed):
     def __new__(cls, base, linearized_index, field, offsets, idx_coordinate_values):
-        if not isinstance(base, IndexedBase):
-            base = IndexedBase(base, shape=(1,))
+        if not isinstance(base, sp.IndexedBase):
+            base = sp.IndexedBase(base, shape=(1,))
         obj = super(ResolvedFieldAccess, cls).__new__(cls, base, linearized_index)
         obj.field = field
         obj.offsets = offsets
diff --git a/pystencils/sympyextensions.py b/pystencils/sympyextensions.py
index 86614a212..6df8eedab 100644
--- a/pystencils/sympyextensions.py
+++ b/pystencils/sympyextensions.py
@@ -486,7 +486,7 @@ def count_operations(term: Union[sp.Expr, List[sp.Expr]],
             pass
         elif isinstance(t, sp.Symbol):
             visit_children = False
-        elif isinstance(t, sp.tensor.Indexed):
+        elif isinstance(t, sp.Indexed):
             visit_children = False
         elif t.is_integer:
             pass
diff --git a/pystencils/transformations.py b/pystencils/transformations.py
index 4490e3e1a..8f05b20f4 100644
--- a/pystencils/transformations.py
+++ b/pystencils/transformations.py
@@ -6,7 +6,6 @@ import pickle
 import hashlib
 import sympy as sp
 from sympy.logic.boolalg import Boolean
-from sympy.tensor import IndexedBase
 from pystencils.simp.assignment_collection import AssignmentCollection
 from pystencils.assignment import Assignment
 from pystencils.field import AbstractField, FieldType, Field
@@ -663,7 +662,8 @@ def split_inner_loop(ast_node: ast.Node, symbol_groups):
             if not isinstance(symbol, AbstractField.AbstractAccess):
                 assert type(symbol) is TypedSymbol
                 new_ts = TypedSymbol(symbol.name, PointerType(symbol.dtype))
-                symbols_with_temporary_array[symbol] = IndexedBase(new_ts, shape=(1,))[inner_loop.loop_counter_symbol]
+                symbols_with_temporary_array[symbol] = sp.IndexedBase(new_ts,
+                                                                      shape=(1,))[inner_loop.loop_counter_symbol]
 
         assignment_group = []
         for assignment in inner_loop.body.args:
@@ -672,7 +672,7 @@ def split_inner_loop(ast_node: ast.Node, symbol_groups):
                 if not isinstance(assignment.lhs, AbstractField.AbstractAccess) and assignment.lhs in symbol_group:
                     assert type(assignment.lhs) is TypedSymbol
                     new_ts = TypedSymbol(assignment.lhs.name, PointerType(assignment.lhs.dtype))
-                    new_lhs = IndexedBase(new_ts, shape=(1,))[inner_loop.loop_counter_symbol]
+                    new_lhs = sp.IndexedBase(new_ts, shape=(1,))[inner_loop.loop_counter_symbol]
                 else:
                     new_lhs = assignment.lhs
                 assignment_group.append(ast.SympyAssignment(new_lhs, new_rhs))
-- 
GitLab