From 604bc7c084add9d3f5d50c2aec739353f1b38d72 Mon Sep 17 00:00:00 2001 From: Martin Bauer <martin.bauer@fau.de> Date: Fri, 14 Jun 2019 13:58:17 +0200 Subject: [PATCH] Support for CUDA block size specification at runtime --- pystencils/gpucuda/indexing.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/pystencils/gpucuda/indexing.py b/pystencils/gpucuda/indexing.py index 80301a5..23a7b51 100644 --- a/pystencils/gpucuda/indexing.py +++ b/pystencils/gpucuda/indexing.py @@ -7,7 +7,7 @@ from pystencils.slicing import normalize_slice from pystencils.data_types import TypedSymbol, create_type from functools import partial -from pystencils.sympyextensions import prod +from pystencils.sympyextensions import prod, is_integer_sequence AUTO_BLOCK_SIZE_LIMITING = False @@ -66,6 +66,9 @@ class AbstractIndexing(abc.ABC): """Return maximal number of threads per block for launch bounds. If this cannot be determined without knowing the array shape return None for unknown """ + @abc.abstractmethod + def symbolic_parameters(self): + """Set of symbols required in call_parameters code""" # -------------------------------------------- Implementations --------------------------------------------------------- @@ -230,6 +233,8 @@ class BlockIndexing(AbstractIndexing): @staticmethod def permute_block_size_according_to_layout(block_size, layout): """Returns modified block_size such that the fastest coordinate gets the biggest block dimension""" + if not is_integer_sequence(block_size): + return block_size sorted_block_size = list(sorted(block_size, reverse=True)) while len(sorted_block_size) > len(layout): sorted_block_size[0] *= sorted_block_size[-1] @@ -241,7 +246,13 @@ class BlockIndexing(AbstractIndexing): return tuple(result[:len(layout)]) def max_threads_per_block(self): - return prod(self._block_size) + if is_integer_sequence(self._block_size): + return prod(self._block_size) + else: + return None + + def symbolic_parameters(self): + return set(b for b in self._block_size if isinstance(b, sp.Symbol)) class LineIndexing(AbstractIndexing): @@ -293,6 +304,8 @@ class LineIndexing(AbstractIndexing): def max_threads_per_block(self): return None + def symbolic_parameters(self): + return set() # -------------------------------------- Helper functions -------------------------------------------------------------- -- GitLab