From f780b06678a1298927e79b9085531221f14c99bb Mon Sep 17 00:00:00 2001
From: Markus Holzer <markus.holzer@fau.de>
Date: Fri, 17 Mar 2023 13:53:06 +0100
Subject: [PATCH] Fix buffers for GPU

---
 pystencils/gpucuda/indexing.py       | 13 +++++++++++++
 pystencils/gpucuda/kernelcreation.py |  3 ++-
 pystencils/transformations.py        | 17 +++++++++--------
 3 files changed, 24 insertions(+), 9 deletions(-)

diff --git a/pystencils/gpucuda/indexing.py b/pystencils/gpucuda/indexing.py
index 9e2df4e0a..85b506e90 100644
--- a/pystencils/gpucuda/indexing.py
+++ b/pystencils/gpucuda/indexing.py
@@ -177,6 +177,9 @@ class BlockIndexing(AbstractIndexing):
             condition = sp.And(condition, c)
         return Block([Conditional(condition, kernel_content)])
 
+    def iteration_space(self, arr_shape):
+        return _iteration_space(self._iterationSlice, arr_shape)
+
     @staticmethod
     def limit_block_size_by_register_restriction(block_size, required_registers_per_thread, device=None):
         """Shrinks the block_size if there are too many registers used per multiprocessor.
@@ -284,6 +287,9 @@ class LineIndexing(AbstractIndexing):
     def symbolic_parameters(self):
         return set()
 
+    def iteration_space(self, arr_shape):
+        return _iteration_space(self._iterationSlice, arr_shape)
+
 
 # -------------------------------------- Helper functions --------------------------------------------------------------
 
@@ -310,6 +316,13 @@ def _get_end_from_slice(iteration_slice, arr_shape):
     return res
 
 
+def _iteration_space(iteration_slice, arr_shape):
+    starts = _get_start_from_slice(iteration_slice)
+    ends = _get_end_from_slice(iteration_slice, arr_shape)
+    steps = [s.step for s in iteration_slice]
+    return [slice(start, end, step) for start, end, step in zip(starts, ends, steps)]
+
+
 def indexing_creator_from_params(gpu_indexing, gpu_indexing_params):
     if isinstance(gpu_indexing, str):
         if gpu_indexing == 'block':
diff --git a/pystencils/gpucuda/kernelcreation.py b/pystencils/gpucuda/kernelcreation.py
index 595cf8cb5..0aa265362 100644
--- a/pystencils/gpucuda/kernelcreation.py
+++ b/pystencils/gpucuda/kernelcreation.py
@@ -92,7 +92,8 @@ def create_cuda_kernel(assignments: Union[AssignmentCollection, NodeCollection],
     coord_mapping = {f.name: cell_idx_symbols for f in all_fields}
 
     if any(FieldType.is_buffer(f) for f in all_fields):
-        resolve_buffer_accesses(ast, get_base_buffer_index(ast, indexing.coordinates, common_shape), read_only_fields)
+        iteration_space = indexing.iteration_space(common_shape)
+        resolve_buffer_accesses(ast, get_base_buffer_index(ast, cell_idx_symbols, iteration_space), read_only_fields)
 
     resolve_field_accesses(ast, read_only_fields, field_to_base_pointer_info=base_pointer_info,
                            field_to_fixed_coordinates=coord_mapping)
diff --git a/pystencils/transformations.py b/pystencils/transformations.py
index 1a690409d..d4c6df549 100644
--- a/pystencils/transformations.py
+++ b/pystencils/transformations.py
@@ -179,6 +179,7 @@ def make_loop_over_domain(body, iteration_slice=None, ghost_layers=None, loop_or
 
     if iteration_slice is not None:
         iteration_slice = normalize_slice(iteration_slice, shape)
+    print(iteration_slice)
 
     if ghost_layers is None:
         required_ghost_layers = max([fa.required_ghost_layers for fa in field_accesses])
@@ -195,6 +196,7 @@ def make_loop_over_domain(body, iteration_slice=None, ghost_layers=None, loop_or
             current_body = ast.Block([new_loop])
         else:
             slice_component = iteration_slice[loop_coordinate]
+            print(slice_component)
             if type(slice_component) is slice:
                 sc = slice_component
                 new_loop = ast.LoopOverCoordinate(current_body, loop_coordinate, sc.start, sc.stop, sc.step)
@@ -342,7 +344,7 @@ def get_base_buffer_index(ast_node, loop_counters=None, loop_iterations=None):
         ast_node: ast before any field accesses are resolved
         loop_counters: for CPU kernels: leave to default 'None' (can be determined from loop nodes)
                        for GPU kernels: list of 'loop counters' from inner to outer loop
-        loop_iterations: number of iterations of each loop from inner to outer, for CPU kernels leave to default
+        loop_iterations: iteration slice for each loop from inner to outer, for CPU kernels leave to default
 
     Returns:
         base buffer index - required by 'resolve_buffer_accesses' function
@@ -354,15 +356,14 @@ def get_base_buffer_index(ast_node, loop_counters=None, loop_iterations=None):
         assert len(loops) == len(parents_of_innermost_loop)
         assert all(l1 is l2 for l1, l2 in zip(loops, parents_of_innermost_loop))
 
-        actual_sizes = [int_div((loop.stop - loop.start), loop.step)
-                        if loop.step != 1 else loop.stop - loop.start for loop in loops]
+        loop_counters = [loop.loop_counter_symbol for loop in loops]
+        loop_iterations = [slice(loop.start, loop.stop, loop.step) for loop in loops]
 
-        actual_steps = [int_div((loop.loop_counter_symbol - loop.start), loop.step)
-                        if loop.step != 1 else loop.loop_counter_symbol - loop.start for loop in loops]
+    actual_sizes = [int_div((s.stop - s.start), s.step)
+                    if s.step != 1 else s.stop - s.start for s in loop_iterations]
 
-    else:
-        actual_sizes = loop_iterations
-        actual_steps = loop_counters
+    actual_steps = [int_div((ctr - s.start), s.step)
+                    if s.step != 1 else ctr - s.start for ctr, s in zip(loop_counters, loop_iterations)]
 
     field_accesses = ast_node.atoms(Field.Access)
     buffer_accesses = {fa for fa in field_accesses if FieldType.is_buffer(fa.field)}
-- 
GitLab