From adb06363168b7dc3002134557a81f60842c7352f Mon Sep 17 00:00:00 2001
From: Markus Holzer <markus.holzer@fau.de>
Date: Tue, 17 Oct 2023 10:40:22 +0200
Subject: [PATCH] Indexed domain kernel

---
 pystencils/astnodes.py                   | 23 ++++---
 pystencils/backends/cbackend.py          | 17 ++++--
 pystencils/cpu/kernelcreation.py         |  7 ++-
 pystencils/field.py                      |  7 +--
 pystencils/gpu/indexing.py               |  7 ++-
 pystencils/gpu/kernelcreation.py         | 19 ++++--
 pystencils/node_collection.py            |  3 +
 pystencils/transformations.py            | 41 +++++++++++--
 pystencils/typing/types.py               | 24 +++++---
 pystencils_tests/test_indexed_kernels.py | 77 ++++++++++++++++--------
 10 files changed, 161 insertions(+), 64 deletions(-)

diff --git a/pystencils/astnodes.py b/pystencils/astnodes.py
index f50ef65a9..3d1940e69 100644
--- a/pystencils/astnodes.py
+++ b/pystencils/astnodes.py
@@ -428,7 +428,7 @@ class LoopOverCoordinate(Node):
     LOOP_COUNTER_NAME_PREFIX = "ctr"
     BLOCK_LOOP_COUNTER_NAME_PREFIX = "_blockctr"
 
-    def __init__(self, body, coordinate_to_loop_over, start, stop, step=1, is_block_loop=False):
+    def __init__(self, body, coordinate_to_loop_over, start, stop, step=1, is_block_loop=False, custom_loop_ctr=None):
         super(LoopOverCoordinate, self).__init__(parent=None)
         self.body = body
         body.parent = self
@@ -439,10 +439,11 @@ class LoopOverCoordinate(Node):
         self.body.parent = self
         self.prefix_lines = []
         self.is_block_loop = is_block_loop
+        self.custom_loop_ctr = custom_loop_ctr
 
     def new_loop_with_different_body(self, new_body):
         result = LoopOverCoordinate(new_body, self.coordinate_to_loop_over, self.start, self.stop,
-                                    self.step, self.is_block_loop)
+                                    self.step, self.is_block_loop, self.custom_loop_ctr)
         result.prefix_lines = [prefix_line for prefix_line in self.prefix_lines]
         return result
 
@@ -505,10 +506,13 @@ class LoopOverCoordinate(Node):
 
     @property
     def loop_counter_name(self):
-        if self.is_block_loop:
-            return LoopOverCoordinate.get_block_loop_counter_name(self.coordinate_to_loop_over)
+        if self.custom_loop_ctr:
+            return self.custom_loop_ctr.name
         else:
-            return LoopOverCoordinate.get_loop_counter_name(self.coordinate_to_loop_over)
+            if self.is_block_loop:
+                return LoopOverCoordinate.get_block_loop_counter_name(self.coordinate_to_loop_over)
+            else:
+                return LoopOverCoordinate.get_loop_counter_name(self.coordinate_to_loop_over)
 
     @staticmethod
     def is_loop_counter_symbol(symbol):
@@ -532,10 +536,13 @@ class LoopOverCoordinate(Node):
 
     @property
     def loop_counter_symbol(self):
-        if self.is_block_loop:
-            return self.get_block_loop_counter_symbol(self.coordinate_to_loop_over)
+        if self.custom_loop_ctr:
+            return self.custom_loop_ctr
         else:
-            return self.get_loop_counter_symbol(self.coordinate_to_loop_over)
+            if self.is_block_loop:
+                return self.get_block_loop_counter_symbol(self.coordinate_to_loop_over)
+            else:
+                return self.get_loop_counter_symbol(self.coordinate_to_loop_over)
 
     @property
     def is_outermost_loop(self):
diff --git a/pystencils/backends/cbackend.py b/pystencils/backends/cbackend.py
index dab3d50c6..0ee8d0e43 100644
--- a/pystencils/backends/cbackend.py
+++ b/pystencils/backends/cbackend.py
@@ -248,12 +248,13 @@ class CBackend:
         return f"{node.pragma_line}\n{self._print_Block(node)}"
 
     def _print_LoopOverCoordinate(self, node):
-        counter_symbol = node.loop_counter_name
-        start = f"int64_t {counter_symbol} = {self.sympy_printer.doprint(node.start)}"
-        condition = f"{counter_symbol} < {self.sympy_printer.doprint(node.stop)}"
-        update = f"{counter_symbol} += {self.sympy_printer.doprint(node.step)}"
+        counter_name = node.loop_counter_name
+        counter_dtype = node.loop_counter_symbol.dtype.c_name
+        start = f"{counter_dtype} {counter_name} = {self.sympy_printer.doprint(node.start)}"
+        condition = f"{counter_name} < {self.sympy_printer.doprint(node.stop)}"
+        update = f"{counter_name} += {self.sympy_printer.doprint(node.step)}"
         loop_str = f"for ({start}; {condition}; {update})"
-        self._kwargs['loop_counter'] = counter_symbol
+        self._kwargs['loop_counter'] = counter_name
         self._kwargs['loop_stop'] = node.stop
 
         prefix = "\n".join(node.prefix_lines)
@@ -497,7 +498,11 @@ class CustomSympyPrinter(CCodePrinter):
             return expr.to_c(self._print)
         if isinstance(expr, ReinterpretCastFunc):
             arg, data_type = expr.args
-            return f"*(({self._print(PointerType(data_type, restrict=False))})(& {self._print(arg)}))"
+            if isinstance(data_type, PointerType):
+                const_str = "const" if data_type.const else ""
+                return f"(({const_str} {self._print(data_type.base_type)} *)(& {self._print(arg)}))"
+            else:
+                return f"*(({self._print(PointerType(data_type, restrict=False))})(& {self._print(arg)}))"
         elif isinstance(expr, AddressOf):
             assert len(expr.args) == 1, "address_of must only have one argument"
             return f"&({self._print(expr.args[0])})"
diff --git a/pystencils/cpu/kernelcreation.py b/pystencils/cpu/kernelcreation.py
index 608f6bc9a..14e4ea3b7 100644
--- a/pystencils/cpu/kernelcreation.py
+++ b/pystencils/cpu/kernelcreation.py
@@ -11,7 +11,8 @@ from pystencils.field import Field, FieldType
 from pystencils.node_collection import NodeCollection
 from pystencils.transformations import (
     filtered_tree_iteration, iterate_loops_by_depth, get_base_buffer_index, get_optimal_loop_ordering,
-    make_loop_over_domain, move_constants_before_loop, parse_base_pointer_info, resolve_buffer_accesses,
+    make_loop_over_domain, add_outer_loop_over_indexed_elements,
+    move_constants_before_loop, parse_base_pointer_info, resolve_buffer_accesses,
     resolve_field_accesses, split_inner_loop)
 
 
@@ -53,6 +54,8 @@ def create_kernel(assignments: NodeCollection,
     loop_order = get_optimal_loop_ordering(fields_without_buffers)
     loop_node, ghost_layer_info = make_loop_over_domain(body, iteration_slice=iteration_slice,
                                                         ghost_layers=ghost_layers, loop_order=loop_order)
+    loop_node = add_outer_loop_over_indexed_elements(loop_node)
+
     ast_node = KernelFunction(loop_node, Target.CPU, Backend.C, compile_function=make_python_function,
                               ghost_layers=ghost_layer_info, function_name=function_name, assignments=assignments)
 
@@ -219,7 +222,7 @@ def add_pragmas(ast_node, pragma_lines, nesting_depth=-1):
     """Prepends given pragma lines to all loops of specified nesting depth.
     
     Args:
-        ast: pystencils abstract syntax tree
+        ast_node: pystencils abstract syntax tree
         pragma_lines: Iterable of strings containing the pragma lines
         nesting_depth: Nesting depth of the loops the pragmas should be applied to.
                        Outermost loop has depth 0.
diff --git a/pystencils/field.py b/pystencils/field.py
index 33d269bef..b4c040e53 100644
--- a/pystencils/field.py
+++ b/pystencils/field.py
@@ -256,9 +256,7 @@ class Field:
         self.shape = shape
         self.strides = strides
         self.latex_name: Optional[str] = None
-        self.coordinate_origin: tuple[float, sp.Symbol] = sp.Matrix(tuple(
-            0 for _ in range(self.spatial_dimensions)
-        ))
+        self.coordinate_origin = sp.Matrix([0] * self.spatial_dimensions)
         self.coordinate_transform = sp.eye(self.spatial_dimensions)
         if field_type == FieldType.STAGGERED:
             assert self.staggered_stencil
@@ -267,8 +265,7 @@ class Field:
         if self.has_fixed_shape:
             return Field(new_name, self.field_type, self._dtype, self._layout, self.shape, self.strides)
         else:
-            return Field.create_generic(new_name, self.spatial_dimensions, self.dtype.numpy_dtype,
-                                        self.index_dimensions, self._layout, self.index_shape, self.field_type)
+            return Field(new_name, self.field_type, self.dtype, self.layout, self.shape, self.strides)
 
     @property
     def spatial_dimensions(self) -> int:
diff --git a/pystencils/gpu/indexing.py b/pystencils/gpu/indexing.py
index a52cf2a3f..843e77bb8 100644
--- a/pystencils/gpu/indexing.py
+++ b/pystencils/gpu/indexing.py
@@ -217,7 +217,12 @@ class BlockIndexing(AbstractIndexing):
 
     def guard(self, kernel_content, arr_shape):
         arr_shape = arr_shape[:self._dim]
-        numeric_iteration_slice = _get_numeric_iteration_slice(self._iteration_space, arr_shape)
+        if len(self._iteration_space) - 1 == len(arr_shape):
+            numeric_iteration_slice = _get_numeric_iteration_slice(self._iteration_space[1:], arr_shape)
+            numeric_iteration_slice = [self.iteration_space[0]] + numeric_iteration_slice
+        else:
+            assert len(self._iteration_space) == len(arr_shape), "Iteration space must be equal to the array shape"
+            numeric_iteration_slice = _get_numeric_iteration_slice(self._iteration_space, arr_shape)
         end = [s.stop if s.stop != 0 else 1 for s in numeric_iteration_slice]
 
         if self._dim < 4:
diff --git a/pystencils/gpu/kernelcreation.py b/pystencils/gpu/kernelcreation.py
index f819b8f80..c2e6143bc 100644
--- a/pystencils/gpu/kernelcreation.py
+++ b/pystencils/gpu/kernelcreation.py
@@ -1,3 +1,5 @@
+import sympy as sp
+
 from pystencils.astnodes import Block, KernelFunction, LoopOverCoordinate, SympyAssignment
 from pystencils.config import CreateKernelConfig
 from pystencils.typing import StructType, TypedSymbol
@@ -9,7 +11,7 @@ from pystencils.node_collection import NodeCollection
 from pystencils.gpu.indexing import indexing_creator_from_params
 from pystencils.slicing import normalize_slice
 from pystencils.transformations import (
-    get_base_buffer_index, get_common_field, parse_base_pointer_info,
+    get_base_buffer_index, get_common_field, get_common_indexed_element, parse_base_pointer_info,
     resolve_buffer_accesses, resolve_field_accesses, unify_shape_symbols)
 
 
@@ -34,7 +36,9 @@ def create_cuda_kernel(assignments: NodeCollection, config: CreateKernelConfig):
 
     field_accesses = set()
     num_buffer_accesses = 0
+    indexed_elements = set()
     for eq in assignments:
+        indexed_elements.update(eq.atoms(sp.Indexed))
         field_accesses.update(eq.atoms(Field.Access))
         field_accesses = {e for e in field_accesses if not e.is_absolute_access}
         num_buffer_accesses += sum(1 for access in eq.atoms(Field.Access) if FieldType.is_buffer(access.field))
@@ -62,12 +66,19 @@ def create_cuda_kernel(assignments: NodeCollection, config: CreateKernelConfig):
         iteration_space = normalize_slice(iteration_slice, common_shape)
     else:
         iteration_space = normalize_slice(iteration_slice, common_shape)
-
     iteration_space = tuple([s if isinstance(s, slice) else slice(s, s, 1) for s in iteration_space])
+
     loop_counter_symbols = [LoopOverCoordinate.get_loop_counter_symbol(i) for i in range(len(iteration_space))]
 
-    indexing = indexing_creator(iteration_space=iteration_space, data_layout=common_field.layout)
-    loop_counter_assignments = indexing.get_loop_ctr_assignments(loop_counter_symbols)
+    if len(indexed_elements) > 0:
+        common_indexed_element = get_common_indexed_element(indexed_elements)
+        indexing = indexing_creator(iteration_space=(slice(0, common_indexed_element.shape[0], 1), *iteration_space),
+                                    data_layout=common_field.layout)
+        extended_ctrs = [common_indexed_element.indices[0], *loop_counter_symbols]
+        loop_counter_assignments = indexing.get_loop_ctr_assignments(extended_ctrs)
+    else:
+        indexing = indexing_creator(iteration_space=iteration_space, data_layout=common_field.layout)
+        loop_counter_assignments = indexing.get_loop_ctr_assignments(loop_counter_symbols)
     assignments = loop_counter_assignments + assignments
     block = indexing.guard(Block(assignments), common_shape)
 
diff --git a/pystencils/node_collection.py b/pystencils/node_collection.py
index e0af05fd0..61a2f400a 100644
--- a/pystencils/node_collection.py
+++ b/pystencils/node_collection.py
@@ -9,6 +9,7 @@ import pystencils.astnodes as ast
 from pystencils.backends.cbackend import CustomCodeNode
 from pystencils.functions import DivFunc
 from pystencils.simp import AssignmentCollection
+from pystencils.typing import FieldPointerSymbol
 
 
 class NodeCollection:
@@ -20,6 +21,8 @@ class NodeCollection:
             if isinstance(obj, (list, tuple)):
                 return [visit(e) for e in obj]
             if isinstance(obj, Assignment):
+                if isinstance(obj.lhs, FieldPointerSymbol):
+                    return ast.SympyAssignment(obj.lhs, obj.rhs, is_const=obj.lhs.dtype.const)
                 return ast.SympyAssignment(obj.lhs, obj.rhs)
             elif isinstance(obj, AddAugmentedAssignment):
                 return ast.SympyAssignment(obj.lhs, obj.lhs + obj.rhs)
diff --git a/pystencils/transformations.py b/pystencils/transformations.py
index d29a342fd..79c24d146 100644
--- a/pystencils/transformations.py
+++ b/pystencils/transformations.py
@@ -185,7 +185,7 @@ def get_common_field(field_set):
             raise ValueError("Differently sized field accesses in loop body: " + str(shape_set))
 
     # Sort the fields by their name to ensure that always the same field is returned
-    reference_field = list(sorted(field_set, key=lambda e: str(e)))[0]
+    reference_field = sorted(field_set, key=lambda e: str(e))[0]
     return reference_field
 
 
@@ -261,6 +261,26 @@ def make_loop_over_domain(body, iteration_slice=None, ghost_layers=None, loop_or
     return current_body, ghost_layers
 
 
+def get_common_indexed_element(indexed_elements: Set[sp.IndexedBase]) -> sp.IndexedBase:
+    assert len(indexed_elements) > 0, "indexed_elements can not be empty"
+    shape_set = {s.shape for s in indexed_elements}
+    if len(shape_set) != 1:
+        for shape in shape_set:
+            assert not isinstance(shape, int), "If indexed elements are used, they must all have the same shape"
+
+    return sorted(indexed_elements, key=lambda e: str(e))[0]
+
+
+def add_outer_loop_over_indexed_elements(loop_node: ast.Block) -> ast.Block:
+    indexed_elements = loop_node.atoms(sp.Indexed)
+    if len(indexed_elements) == 0:
+        return loop_node
+    reference_element = get_common_indexed_element(indexed_elements)
+    new_loop = ast.LoopOverCoordinate(loop_node, 0, 0,
+                                      reference_element.shape[0], 1, custom_loop_ctr=reference_element.indices[0])
+    return ast.Block([new_loop])
+
+
 def create_intermediate_base_pointer(field_access, coordinates, previous_ptr):
     r"""
     Addressing elements in structured arrays is done with :math:`ptr\left[ \sum_i c_i \cdot s_i \right]`
@@ -411,11 +431,22 @@ def get_base_buffer_index(ast_node, loop_counters=None, loop_iterations=None):
         loop_counters = [loop.loop_counter_symbol for loop in loops]
         loop_iterations = [slice(loop.start, loop.stop, loop.step) for loop in loops]
 
-    actual_sizes = [int_div((s.stop - s.start), s.step)
-                    if s.step != 1 else s.stop - s.start for s in loop_iterations]
+    actual_sizes = list()
+    actual_steps = list()
+    for ctr, s in zip(loop_counters, loop_iterations):
+        if s.step != 1:
+            if (s.stop - s.start) % s.step == 0:
+                actual_sizes.append((s.stop - s.start) // s.step)
+            else:
+                actual_sizes.append(int_div((s.stop - s.start), s.step))
 
-    actual_steps = [int_div((ctr - s.start), s.step)
-                    if s.step != 1 else ctr - s.start for ctr, s in zip(loop_counters, loop_iterations)]
+            if (ctr - s.start) % s.step == 0:
+                actual_steps.append((ctr - s.start) // s.step)
+            else:
+                actual_steps.append(int_div((ctr - s.start), s.step))
+        else:
+            actual_sizes.append(s.stop - s.start)
+            actual_steps.append(ctr - s.start)
 
     field_accesses = ast_node.atoms(Field.Access)
     buffer_accesses = {fa for fa in field_accesses if FieldType.is_buffer(fa.field)}
diff --git a/pystencils/typing/types.py b/pystencils/typing/types.py
index f0f9744a5..4d80daffa 100644
--- a/pystencils/typing/types.py
+++ b/pystencils/typing/types.py
@@ -189,16 +189,17 @@ class VectorType(AbstractType):
 
 
 class PointerType(AbstractType):
-    def __init__(self, base_type: BasicType, const: bool = False, restrict: bool = True):
+    def __init__(self, base_type: BasicType, const: bool = False, restrict: bool = True, double_pointer: bool = False):
         self._base_type = base_type
         self.const = const
         self.restrict = restrict
+        self.double_pointer = double_pointer
 
     def __getnewargs__(self):
-        return self.base_type, self.const, self.restrict
+        return self.base_type, self.const, self.restrict, self.double_pointer
 
     def __getnewargs_ex__(self):
-        return (self.base_type, self.const, self.restrict), {}
+        return (self.base_type, self.const, self.restrict, self.double_pointer), {}
 
     @property
     def alias(self):
@@ -210,16 +211,25 @@ class PointerType(AbstractType):
 
     @property
     def item_size(self):
-        return self.base_type.item_size
+        if self.double_pointer:
+            raise NotImplementedError("The item_size for double_pointer is not implemented")
+        else:
+            return self.base_type.item_size
 
     def __eq__(self, other):
         if not isinstance(other, PointerType):
             return False
         else:
-            return (self.base_type, self.const, self.restrict) == (other.base_type, other.const, other.restrict)
+            own = (self.base_type, self.const, self.restrict, self.double_pointer)
+            return own == (other.base_type, other.const, other.restrict, other.double_pointer)
 
     def __str__(self):
-        return f'{str(self.base_type)} * {"RESTRICT " if self.restrict else "" }{"const" if self.const else ""}'
+        restrict_str = "RESTRICT" if self.restrict else ""
+        const_str = "const" if self.const else ""
+        if self.double_pointer:
+            return f'{str(self.base_type)} ** {restrict_str} {const_str}'
+        else:
+            return f'{str(self.base_type)} * {restrict_str} {const_str}'
 
     def __repr__(self):
         return str(self)
@@ -228,7 +238,7 @@ class PointerType(AbstractType):
         return str(self)
 
     def __hash__(self):
-        return hash((self._base_type, self.const, self.restrict))
+        return hash((self._base_type, self.const, self.restrict, self.double_pointer))
 
 
 class StructType(AbstractType):
diff --git a/pystencils_tests/test_indexed_kernels.py b/pystencils_tests/test_indexed_kernels.py
index c8c88ec86..2c0738dcf 100644
--- a/pystencils_tests/test_indexed_kernels.py
+++ b/pystencils_tests/test_indexed_kernels.py
@@ -1,11 +1,19 @@
+import sympy as sp
 import numpy as np
 import pytest
 
 import pystencils as ps
 from pystencils import Assignment, Field, CreateKernelConfig, create_kernel, Target
+from pystencils.transformations import filtered_tree_iteration
+from pystencils.typing import BasicType, FieldPointerSymbol, PointerType, TypedSymbol
 
 
-def test_indexed_kernel():
+@pytest.mark.parametrize('target', [ps.Target.CPU, ps.Target.GPU])
+def test_indexed_kernel(target):
+    if target == Target.GPU:
+        pytest.importorskip("cupy")
+        import cupy as cp
+
     arr = np.zeros((3, 4))
     dtype = np.dtype([('x', int), ('y', int), ('value', arr.dtype)])
     index_arr = np.zeros((3,), dtype=dtype)
@@ -17,38 +25,55 @@ def test_indexed_kernel():
     normal_field = Field.create_from_numpy_array('f', arr)
     update_rule = Assignment(normal_field[0, 0], indexed_field('value'))
 
-    config = CreateKernelConfig(index_fields=[indexed_field])
+    config = CreateKernelConfig(target=target, index_fields=[indexed_field])
     ast = create_kernel([update_rule], config=config)
     kernel = ast.compile()
-    kernel(f=arr, index=index_arr)
-    code = ps.get_code_str(kernel)
+
+    if target == Target.CPU:
+        kernel(f=arr, index=index_arr)
+    else:
+        gpu_arr = cp.asarray(arr)
+        gpu_index_arr = cp.ndarray(index_arr.shape, dtype=index_arr.dtype)
+        gpu_index_arr.set(index_arr)
+        kernel(f=gpu_arr, index=gpu_index_arr)
+        arr = gpu_arr.get()
     for i in range(index_arr.shape[0]):
         np.testing.assert_allclose(arr[index_arr[i]['x'], index_arr[i]['y']], index_arr[i]['value'], atol=1e-13)
 
 
-def test_indexed_gpu_kernel():
-    pytest.importorskip("cupy")
-    import cupy as cp
+@pytest.mark.parametrize('index_size', ("fixed", "variable"))
+@pytest.mark.parametrize('array_size', ("3D", "2D", "10, 12", "13, 17, 19"))
+@pytest.mark.parametrize('target', (ps.Target.CPU, ps.Target.GPU))
+@pytest.mark.parametrize('dtype', ("float64", "float32"))
+def test_indexed_domain_kernel(index_size, array_size, target, dtype):
+    dtype = BasicType(dtype)
 
-    arr = np.zeros((3, 4))
-    dtype = np.dtype([('x', int), ('y', int), ('value', arr.dtype)])
-    index_arr = np.zeros((3,), dtype=dtype)
-    index_arr[0] = (0, 2, 3.0)
-    index_arr[1] = (1, 3, 42.0)
-    index_arr[2] = (2, 1, 5.0)
+    f = ps.fields(f'f(1): {dtype.numpy_dtype.name}[{array_size}]')
+    g = ps.fields(f'g(1): {dtype.numpy_dtype.name}[{array_size}]')
 
-    indexed_field = Field.create_from_numpy_array('index', index_arr)
-    normal_field = Field.create_from_numpy_array('f', arr)
-    update_rule = Assignment(normal_field[0, 0], indexed_field('value'))
+    index = TypedSymbol("index", dtype=BasicType(np.int16))
+    if index_size == "variable":
+        index_src = TypedSymbol("_size_src", dtype=BasicType(np.int16))
+        index_dst = TypedSymbol("_size_dst", dtype=BasicType(np.int16))
+    else:
+        index_src = 16
+        index_dst = 16
+    pointer_type = PointerType(dtype, const=False, restrict=True, double_pointer=True)
+    const_pointer_type = PointerType(dtype, const=True, restrict=True, double_pointer=True)
 
-    config = CreateKernelConfig(target=Target.GPU, index_fields=[indexed_field])
-    ast = create_kernel([update_rule], config=config)
-    kernel = ast.compile()
+    src = sp.IndexedBase(TypedSymbol(f"_data_{f.name}", dtype=const_pointer_type), shape=index_src)
+    dst = sp.IndexedBase(TypedSymbol(f"_data_{g.name}", dtype=pointer_type), shape=index_dst)
+
+    update_rule = [ps.Assignment(FieldPointerSymbol("f", dtype, const=True), src[index]),
+                   ps.Assignment(FieldPointerSymbol("g", dtype, const=False), dst[index]),
+                   ps.Assignment(g.center, f.center)]
+
+    ast = ps.create_kernel(update_rule, target=target)
+
+    code = ps.get_code_str(ast)
+    assert f"const {dtype.c_name} * RESTRICT _data_f = (({dtype.c_name} * RESTRICT const)(_data_f[index]));" in code
+    assert f"{dtype.c_name} * RESTRICT  _data_g = (({dtype.c_name} * RESTRICT )(_data_g[index]));" in code
+
+    if target == Target.CPU:
+        assert code.count("for") == f.spatial_dimensions + 1
 
-    gpu_arr = cp.asarray(arr)
-    gpu_index_arr = cp.ndarray(index_arr.shape, dtype=index_arr.dtype)
-    gpu_index_arr.set(index_arr)
-    kernel(f=gpu_arr, index=gpu_index_arr)
-    arr = gpu_arr.get()
-    for i in range(index_arr.shape[0]):
-        np.testing.assert_allclose(arr[index_arr[i]['x'], index_arr[i]['y']], index_arr[i]['value'], atol=1e-13)
-- 
GitLab