From f780b06678a1298927e79b9085531221f14c99bb Mon Sep 17 00:00:00 2001
From: Markus Holzer <markus.holzer@fau.de>
Date: Fri, 17 Mar 2023 13:53:06 +0100
Subject: [PATCH] Fix buffers for GPU

Buffer index computation for GPU kernels now receives the full iteration
space (start, stop, step per coordinate) instead of only the common shape.

* indexing.py: BlockIndexing and LineIndexing gain iteration_space(arr_shape),
  which returns one slice(start, end, step) per coordinate, built by the new
  module-level helper _iteration_space from the indexing's iteration slice.
* kernelcreation.py: create_cuda_kernel passes cell_idx_symbols and the
  iteration space to get_base_buffer_index, replacing indexing.coordinates
  and common_shape.
* transformations.py: get_base_buffer_index treats loop_iterations as slices.
  When they are not supplied, counters and slices are derived from the loop
  nodes; sizes and steps are then computed by one shared code path that
  subtracts the start and divides by the step when the step is not 1.
---
Reviewer notes (ignored by git am):
 - Dropped the two debugging print() calls that the original patch added to
   make_loop_over_domain; transformations.py hunk offsets and the diffstat
   were adjusted accordingly.
 - _iteration_space falls back to step 1 for components that are not slice
   objects or whose step is None, so the slices handed to
   get_base_buffer_index always carry a usable step.
 - Post-image blob ids on the "index" lines were not regenerated, and
   indentation / blank context lines were reconstructed from a
   whitespace-flattened copy. Verify against the tree before applying.

 pystencils/gpucuda/indexing.py       | 13 +++++++++++++
 pystencils/gpucuda/kernelcreation.py |  3 ++-
 pystencils/transformations.py        | 15 +++++++--------
 3 files changed, 22 insertions(+), 9 deletions(-)

diff --git a/pystencils/gpucuda/indexing.py b/pystencils/gpucuda/indexing.py
index 9e2df4e0a..85b506e90 100644
--- a/pystencils/gpucuda/indexing.py
+++ b/pystencils/gpucuda/indexing.py
@@ -177,6 +177,9 @@ class BlockIndexing(AbstractIndexing):
             condition = sp.And(condition, c)
         return Block([Conditional(condition, kernel_content)])
 
+    def iteration_space(self, arr_shape):
+        return _iteration_space(self._iterationSlice, arr_shape)
+
     @staticmethod
     def limit_block_size_by_register_restriction(block_size, required_registers_per_thread, device=None):
         """Shrinks the block_size if there are too many registers used per multiprocessor.
@@ -284,6 +287,9 @@ class LineIndexing(AbstractIndexing):
     def symbolic_parameters(self):
         return set()
 
+    def iteration_space(self, arr_shape):
+        return _iteration_space(self._iterationSlice, arr_shape)
+
 
 # -------------------------------------- Helper functions --------------------------------------------------------------
 
@@ -310,6 +316,13 @@ def _get_end_from_slice(iteration_slice, arr_shape):
     return res
 
 
+def _iteration_space(iteration_slice, arr_shape):
+    starts = _get_start_from_slice(iteration_slice)
+    ends = _get_end_from_slice(iteration_slice, arr_shape)
+    steps = [s.step if isinstance(s, slice) and s.step is not None else 1 for s in iteration_slice]
+    return [slice(start, end, step) for start, end, step in zip(starts, ends, steps)]
+
+
 def indexing_creator_from_params(gpu_indexing, gpu_indexing_params):
     if isinstance(gpu_indexing, str):
         if gpu_indexing == 'block':
diff --git a/pystencils/gpucuda/kernelcreation.py b/pystencils/gpucuda/kernelcreation.py
index 595cf8cb5..0aa265362 100644
--- a/pystencils/gpucuda/kernelcreation.py
+++ b/pystencils/gpucuda/kernelcreation.py
@@ -92,7 +92,8 @@ def create_cuda_kernel(assignments: Union[AssignmentCollection, NodeCollection],
     coord_mapping = {f.name: cell_idx_symbols for f in all_fields}
 
     if any(FieldType.is_buffer(f) for f in all_fields):
-        resolve_buffer_accesses(ast, get_base_buffer_index(ast, indexing.coordinates, common_shape), read_only_fields)
+        iteration_space = indexing.iteration_space(common_shape)
+        resolve_buffer_accesses(ast, get_base_buffer_index(ast, cell_idx_symbols, iteration_space), read_only_fields)
 
     resolve_field_accesses(ast, read_only_fields, field_to_base_pointer_info=base_pointer_info,
                            field_to_fixed_coordinates=coord_mapping)
diff --git a/pystencils/transformations.py b/pystencils/transformations.py
index 1a690409d..d4c6df549 100644
--- a/pystencils/transformations.py
+++ b/pystencils/transformations.py
@@ -342,7 +342,7 @@ def get_base_buffer_index(ast_node, loop_counters=None, loop_iterations=None):
         ast_node: ast before any field accesses are resolved
         loop_counters: for CPU kernels: leave to default 'None' (can be determined from loop nodes)
                        for GPU kernels: list of 'loop counters' from inner to outer loop
-        loop_iterations: number of iterations of each loop from inner to outer, for CPU kernels leave to default
+        loop_iterations: iteration slice for each loop from inner to outer, for CPU kernels leave to default
 
     Returns:
         base buffer index - required by 'resolve_buffer_accesses' function
@@ -354,15 +354,14 @@ def get_base_buffer_index(ast_node, loop_counters=None, loop_iterations=None):
         assert len(loops) == len(parents_of_innermost_loop)
         assert all(l1 is l2 for l1, l2 in zip(loops, parents_of_innermost_loop))
 
-        actual_sizes = [int_div((loop.stop - loop.start), loop.step)
-                        if loop.step != 1 else loop.stop - loop.start for loop in loops]
+        loop_counters = [loop.loop_counter_symbol for loop in loops]
+        loop_iterations = [slice(loop.start, loop.stop, loop.step) for loop in loops]
 
-        actual_steps = [int_div((loop.loop_counter_symbol - loop.start), loop.step)
-                        if loop.step != 1 else loop.loop_counter_symbol - loop.start for loop in loops]
+    actual_sizes = [int_div((s.stop - s.start), s.step)
+                    if s.step != 1 else s.stop - s.start for s in loop_iterations]
 
-    else:
-        actual_sizes = loop_iterations
-        actual_steps = loop_counters
+    actual_steps = [int_div((ctr - s.start), s.step)
+                    if s.step != 1 else ctr - s.start for ctr, s in zip(loop_counters, loop_iterations)]
 
     field_accesses = ast_node.atoms(Field.Access)
     buffer_accesses = {fa for fa in field_accesses if FieldType.is_buffer(fa.field)}
-- 
GitLab