Skip to content
Snippets Groups Projects
Commit 604bc7c0 authored by Martin Bauer's avatar Martin Bauer
Browse files

Support for CUDA block size specification at runtime

parent 754c7767
No related merge requests found
...@@ -7,7 +7,7 @@ from pystencils.slicing import normalize_slice ...@@ -7,7 +7,7 @@ from pystencils.slicing import normalize_slice
from pystencils.data_types import TypedSymbol, create_type from pystencils.data_types import TypedSymbol, create_type
from functools import partial from functools import partial
from pystencils.sympyextensions import prod from pystencils.sympyextensions import prod, is_integer_sequence
AUTO_BLOCK_SIZE_LIMITING = False AUTO_BLOCK_SIZE_LIMITING = False
...@@ -66,6 +66,9 @@ class AbstractIndexing(abc.ABC): ...@@ -66,6 +66,9 @@ class AbstractIndexing(abc.ABC):
"""Return maximal number of threads per block for launch bounds. If this cannot be determined without """Return maximal number of threads per block for launch bounds. If this cannot be determined without
knowing the array shape return None for unknown """ knowing the array shape return None for unknown """
@abc.abstractmethod
def symbolic_parameters(self):
"""Set of symbols required in call_parameters code"""
# -------------------------------------------- Implementations --------------------------------------------------------- # -------------------------------------------- Implementations ---------------------------------------------------------
...@@ -230,6 +233,8 @@ class BlockIndexing(AbstractIndexing): ...@@ -230,6 +233,8 @@ class BlockIndexing(AbstractIndexing):
@staticmethod @staticmethod
def permute_block_size_according_to_layout(block_size, layout): def permute_block_size_according_to_layout(block_size, layout):
"""Returns modified block_size such that the fastest coordinate gets the biggest block dimension""" """Returns modified block_size such that the fastest coordinate gets the biggest block dimension"""
if not is_integer_sequence(block_size):
return block_size
sorted_block_size = list(sorted(block_size, reverse=True)) sorted_block_size = list(sorted(block_size, reverse=True))
while len(sorted_block_size) > len(layout): while len(sorted_block_size) > len(layout):
sorted_block_size[0] *= sorted_block_size[-1] sorted_block_size[0] *= sorted_block_size[-1]
...@@ -241,7 +246,13 @@ class BlockIndexing(AbstractIndexing): ...@@ -241,7 +246,13 @@ class BlockIndexing(AbstractIndexing):
return tuple(result[:len(layout)]) return tuple(result[:len(layout)])
def max_threads_per_block(self): def max_threads_per_block(self):
return prod(self._block_size) if is_integer_sequence(self._block_size):
return prod(self._block_size)
else:
return None
def symbolic_parameters(self):
return set(b for b in self._block_size if isinstance(b, sp.Symbol))
class LineIndexing(AbstractIndexing): class LineIndexing(AbstractIndexing):
...@@ -293,6 +304,8 @@ class LineIndexing(AbstractIndexing): ...@@ -293,6 +304,8 @@ class LineIndexing(AbstractIndexing):
def max_threads_per_block(self): def max_threads_per_block(self):
return None return None
def symbolic_parameters(self):
return set()
# -------------------------------------- Helper functions -------------------------------------------------------------- # -------------------------------------- Helper functions --------------------------------------------------------------
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment