diff --git a/pystencils/gpucuda/indexing.py b/pystencils/gpucuda/indexing.py index f74b6e5071cf90b09074b088eeaf6c52c32d0eb1..9e2df4e0a1183d752f560985ab15acee041f2b76 100644 --- a/pystencils/gpucuda/indexing.py +++ b/pystencils/gpucuda/indexing.py @@ -124,12 +124,18 @@ class BlockIndexing(AbstractIndexing): self._symbolic_shape = [e if isinstance(e, sp.Basic) else None for e in field.spatial_shape] self._compile_time_block_size = compile_time_block_size + @property + def cuda_indices(self): + block_size = self._block_size if self._compile_time_block_size else BLOCK_DIM + indices = [block_index * bs + thread_idx + for block_index, bs, thread_idx in zip(BLOCK_IDX, block_size, THREAD_IDX)] + + return indices[:self._dim] + @property def coordinates(self): offsets = _get_start_from_slice(self._iterationSlice) - block_size = self._block_size if self._compile_time_block_size else BLOCK_DIM - coordinates = [block_index * bs + thread_idx + off - for block_index, bs, thread_idx, off in zip(BLOCK_IDX, block_size, THREAD_IDX, offsets)] + coordinates = [c + off for c, off in zip(self.cuda_indices, offsets)] return coordinates[:self._dim] @@ -159,8 +165,13 @@ class BlockIndexing(AbstractIndexing): def guard(self, kernel_content, arr_shape): arr_shape = arr_shape[:self._dim] - conditions = [c < end - for c, end in zip(self.coordinates, _get_end_from_slice(self._iterationSlice, arr_shape))] + end = _get_end_from_slice(self._iterationSlice, arr_shape) + + conditions = [c < e for c, e in zip(self.coordinates, end)] + for cuda_index, iter_slice in zip(self.cuda_indices, self._iterationSlice): + if iter_slice.step > 1: + conditions.append(sp.Eq(sp.Mod(cuda_index, iter_slice.step), 0)) + condition = conditions[0] for c in conditions[1:]: condition = sp.And(condition, c)