diff --git a/pystencils/gpucuda/indexing.py b/pystencils/gpucuda/indexing.py index 9e2df4e0a1183d752f560985ab15acee041f2b76..85b506e90b2965624219d696d0a31fb1acd6ac80 100644 --- a/pystencils/gpucuda/indexing.py +++ b/pystencils/gpucuda/indexing.py @@ -177,6 +177,9 @@ class BlockIndexing(AbstractIndexing): condition = sp.And(condition, c) return Block([Conditional(condition, kernel_content)]) + def iteration_space(self, arr_shape): + return _iteration_space(self._iterationSlice, arr_shape) + @staticmethod def limit_block_size_by_register_restriction(block_size, required_registers_per_thread, device=None): """Shrinks the block_size if there are too many registers used per multiprocessor. @@ -284,6 +287,9 @@ class LineIndexing(AbstractIndexing): def symbolic_parameters(self): return set() + def iteration_space(self, arr_shape): + return _iteration_space(self._iterationSlice, arr_shape) + # -------------------------------------- Helper functions -------------------------------------------------------------- @@ -310,6 +316,13 @@ def _get_end_from_slice(iteration_slice, arr_shape): return res +def _iteration_space(iteration_slice, arr_shape): + starts = _get_start_from_slice(iteration_slice) + ends = _get_end_from_slice(iteration_slice, arr_shape) + steps = [s.step for s in iteration_slice] + return [slice(start, end, step) for start, end, step in zip(starts, ends, steps)] + + def indexing_creator_from_params(gpu_indexing, gpu_indexing_params): if isinstance(gpu_indexing, str): if gpu_indexing == 'block': diff --git a/pystencils/gpucuda/kernelcreation.py b/pystencils/gpucuda/kernelcreation.py index 595cf8cb53a7bcd30c3cb09b80ac0f393bd53fd0..0aa265362f76bad3c1e20596d2a497c81b593be5 100644 --- a/pystencils/gpucuda/kernelcreation.py +++ b/pystencils/gpucuda/kernelcreation.py @@ -92,7 +92,8 @@ def create_cuda_kernel(assignments: Union[AssignmentCollection, NodeCollection], coord_mapping = {f.name: cell_idx_symbols for f in all_fields} if any(FieldType.is_buffer(f) for f in all_fields): - resolve_buffer_accesses(ast, get_base_buffer_index(ast, indexing.coordinates, common_shape), read_only_fields) + iteration_space = indexing.iteration_space(common_shape) + resolve_buffer_accesses(ast, get_base_buffer_index(ast, cell_idx_symbols, iteration_space), read_only_fields) resolve_field_accesses(ast, read_only_fields, field_to_base_pointer_info=base_pointer_info, field_to_fixed_coordinates=coord_mapping) diff --git a/pystencils/transformations.py b/pystencils/transformations.py index 1a690409d2696af5aafa8461a38ee20eb7a655c1..d4c6df5492bdbe1be98b1df89a7c2b2cad2c8c02 100644 --- a/pystencils/transformations.py +++ b/pystencils/transformations.py @@ -179,6 +179,7 @@ def make_loop_over_domain(body, iteration_slice=None, ghost_layers=None, loop_or if iteration_slice is not None: iteration_slice = normalize_slice(iteration_slice, shape) + print(iteration_slice) if ghost_layers is None: required_ghost_layers = max([fa.required_ghost_layers for fa in field_accesses]) @@ -195,6 +196,7 @@ def make_loop_over_domain(body, iteration_slice=None, ghost_layers=None, loop_or current_body = ast.Block([new_loop]) else: slice_component = iteration_slice[loop_coordinate] + print(slice_component) if type(slice_component) is slice: sc = slice_component new_loop = ast.LoopOverCoordinate(current_body, loop_coordinate, sc.start, sc.stop, sc.step) @@ -342,7 +344,7 @@ def get_base_buffer_index(ast_node, loop_counters=None, loop_iterations=None): ast_node: ast before any field accesses are resolved loop_counters: for CPU kernels: leave to default 'None' (can be determined from loop nodes) for GPU kernels: list of 'loop counters' from inner to outer loop - loop_iterations: number of iterations of each loop from inner to outer, for CPU kernels leave to default + loop_iterations: iteration slice for each loop from inner to outer, for CPU kernels leave to default Returns: base buffer index - required by 'resolve_buffer_accesses' function @@ -354,15 +356,14 @@ def get_base_buffer_index(ast_node, loop_counters=None, loop_iterations=None): assert len(loops) == len(parents_of_innermost_loop) assert all(l1 is l2 for l1, l2 in zip(loops, parents_of_innermost_loop)) - actual_sizes = [int_div((loop.stop - loop.start), loop.step) - if loop.step != 1 else loop.stop - loop.start for loop in loops] + loop_counters = [loop.loop_counter_symbol for loop in loops] + loop_iterations = [slice(loop.start, loop.stop, loop.step) for loop in loops] - actual_steps = [int_div((loop.loop_counter_symbol - loop.start), loop.step) - if loop.step != 1 else loop.loop_counter_symbol - loop.start for loop in loops] + actual_sizes = [int_div((s.stop - s.start), s.step) + if s.step != 1 else s.stop - s.start for s in loop_iterations] - else: - actual_sizes = loop_iterations - actual_steps = loop_counters + actual_steps = [int_div((ctr - s.start), s.step) + if s.step != 1 else ctr - s.start for ctr, s in zip(loop_counters, loop_iterations)] field_accesses = ast_node.atoms(Field.Access) buffer_accesses = {fa for fa in field_accesses if FieldType.is_buffer(fa.field)}