From 10f8f119784caff63fee557fa9f1fa867a60cfa6 Mon Sep 17 00:00:00 2001
From: Markus Holzer <markus.holzer@fau.de>
Date: Fri, 21 Jan 2022 12:12:28 +0100
Subject: [PATCH] Fix indexed kernels

---
 pystencils/backends/cbackend.py          |  2 +-
 pystencils/field.py                      | 14 +++++++++++---
 pystencils/typing/cast_functions.py      |  5 +++--
 pystencils/typing/leaf_typing.py         |  1 +
 pystencils_tests/test_indexed_kernels.py |  2 ++
 pystencils_tests/test_random.py          | 13 ++++++++-----
 6 files changed, 26 insertions(+), 11 deletions(-)

diff --git a/pystencils/backends/cbackend.py b/pystencils/backends/cbackend.py
index a8a3dbc8d..3057f6870 100644
--- a/pystencils/backends/cbackend.py
+++ b/pystencils/backends/cbackend.py
@@ -470,7 +470,7 @@ class CustomSympyPrinter(CCodePrinter):
         else:
             return f'fabs({self._print(expr.args[0])})'
 
-    def _print_Type(self, node):
+    def _print_AbstractType(self, node):
         return str(node)
 
     def _print_Function(self, expr):
diff --git a/pystencils/field.py b/pystencils/field.py
index 4a29a1be2..ceaeb82ac 100644
--- a/pystencils/field.py
+++ b/pystencils/field.py
@@ -13,7 +13,7 @@ from sympy.core.cache import cacheit
 
 import pystencils
 from pystencils.alignedarray import aligned_empty
-from pystencils.typing import StructType, TypedSymbol, create_type
+from pystencils.typing import StructType, TypedSymbol, BasicType, create_type
 from pystencils.typing.typed_sympy import FieldShapeSymbol, FieldStrideSymbol
 from pystencils.stencil import (
     direction_string_to_offset, inverse_direction, offset_to_direction_string)
@@ -673,7 +673,11 @@ class Field:
             if superscript is not None:
                 symbol_name += "^" + superscript
 
-            obj = super(Field.Access, self).__xnew__(self, symbol_name, field.dtype)
+            if dtype:
+                obj = super(Field.Access, self).__xnew__(self, symbol_name, dtype)
+            else:
+                obj = super(Field.Access, self).__xnew__(self, symbol_name, field.dtype)
+
             obj._field = field
             obj._offsets = []
             for o in offsets:
@@ -716,7 +720,11 @@ class Field:
 
             if len(idx) != self.field.index_dimensions:
                 raise ValueError(f"Wrong number of indices: Got {len(idx)}, expected {self.field.index_dimensions}")
-            return Field.Access(self.field, self._offsets, idx, dtype=self.dtype)
+            if len(idx) == 1 and isinstance(idx[0], str):
+                dtype = BasicType(self.field.dtype.numpy_dtype[idx[0]])
+                return Field.Access(self.field, self._offsets, idx, dtype=dtype)
+            else:
+                return Field.Access(self.field, self._offsets, idx, dtype=self.dtype)
 
         def __getitem__(self, *idx):
             return self.__call__(*idx)
diff --git a/pystencils/typing/cast_functions.py b/pystencils/typing/cast_functions.py
index b70c8bb48..8200e9697 100644
--- a/pystencils/typing/cast_functions.py
+++ b/pystencils/typing/cast_functions.py
@@ -15,8 +15,9 @@ class CastFunc(sp.Function):
             pass
         expr, dtype, *other_args = args
 
-        # If we have two consecutive casts, throw the inner one away
-        if isinstance(expr, CastFunc):
+        # If we have two consecutive casts, throw the inner one away.
+        # This optimisation is only available for simple casts. Thus the == is intended here!
+        if expr.__class__ == CastFunc:
             expr = expr.args[0]
         if not isinstance(dtype, AbstractType):
             dtype = BasicType(dtype)
diff --git a/pystencils/typing/leaf_typing.py b/pystencils/typing/leaf_typing.py
index f92a0db73..b620c9c7e 100644
--- a/pystencils/typing/leaf_typing.py
+++ b/pystencils/typing/leaf_typing.py
@@ -113,6 +113,7 @@ class TypeAdder:
 
         # TOOO: check the access
         if isinstance(expr, Field.Access):
+            # TODO if Struct, look at the reinterpreted dtype
             return expr, expr.dtype
         elif isinstance(expr, TypedSymbol):
             return expr, expr.dtype
diff --git a/pystencils_tests/test_indexed_kernels.py b/pystencils_tests/test_indexed_kernels.py
index 87b24b354..fa06a8f16 100644
--- a/pystencils_tests/test_indexed_kernels.py
+++ b/pystencils_tests/test_indexed_kernels.py
@@ -1,4 +1,5 @@
 import numpy as np
+import pystencils as ps
 from pystencils import Assignment, Field, CreateKernelConfig, create_kernel, Target
 
 
@@ -18,6 +19,7 @@ def test_indexed_kernel():
     ast = create_kernel([update_rule], config=config)
     kernel = ast.compile()
     kernel(f=arr, index=index_arr)
+    code = ps.get_code_str(kernel)
     for i in range(index_arr.shape[0]):
         np.testing.assert_allclose(arr[index_arr[i]['x'], index_arr[i]['y']], index_arr[i]['value'], atol=1e-13)
 
diff --git a/pystencils_tests/test_random.py b/pystencils_tests/test_random.py
index 2310f4c1d..5d839da2f 100644
--- a/pystencils_tests/test_random.py
+++ b/pystencils_tests/test_random.py
@@ -4,6 +4,7 @@ import pytest
 
 import pystencils as ps
 from pystencils.astnodes import SympyAssignment
+from pystencils.node_collection import NodeCollection
 from pystencils.rng import PhiloxFourFloats, PhiloxTwoDoubles, AESNIFourFloats, AESNITwoDoubles, random_symbol
 from pystencils.backends.simd_instruction_sets import get_supported_instruction_sets
 from pystencils.cpu.cpujit import get_compiler_config
@@ -163,13 +164,15 @@ def test_rng_symbol(vectorized):
 
     dh = ps.create_data_handling((8, 8), default_ghost_layers=0, default_target=Target.CPU)
     f = dh.add_array("f", values_per_cell=2 * dh.dim, alignment=True)
-    ac = ps.AssignmentCollection([ps.Assignment(f(i), 0) for i in range(f.shape[-1])])
-    rng_symbol_gen = random_symbol(ac.subexpressions, dim=dh.dim)
+    nc = NodeCollection([SympyAssignment(f(i), 0) for i in range(f.shape[-1])])
+    subexpressions = []
+    rng_symbol_gen = random_symbol(subexpressions, dim=dh.dim)
     for i in range(f.shape[-1]):
-        ac.main_assignments[i] = ps.Assignment(ac.main_assignments[i].lhs, next(rng_symbol_gen))
-    symbols = [a.rhs for a in ac.main_assignments]
+        nc.all_assignments[i] = SympyAssignment(nc.all_assignments[i].lhs, next(rng_symbol_gen))
+    symbols = [a.rhs for a in nc.all_assignments]
+    [nc.all_assignments.insert(0, subexpression) for subexpression in subexpressions]
     assert len(symbols) == f.shape[-1] and len(set(symbols)) == f.shape[-1]
-    ps.create_kernel(ac, target=dh.default_target, cpu_vectorize_info=cpu_vectorize_info).compile()
+    ps.create_kernel(nc, target=dh.default_target, cpu_vectorize_info=cpu_vectorize_info).compile()
 
 
 @pytest.mark.parametrize('vectorized', (False, True))
-- 
GitLab