From 604bc7c084add9d3f5d50c2aec739353f1b38d72 Mon Sep 17 00:00:00 2001
From: Martin Bauer <martin.bauer@fau.de>
Date: Fri, 14 Jun 2019 13:58:17 +0200
Subject: [PATCH] Support for CUDA block size specification at runtime

---
 pystencils/gpucuda/indexing.py | 17 +++++++++++++++--
 1 file changed, 15 insertions(+), 2 deletions(-)

diff --git a/pystencils/gpucuda/indexing.py b/pystencils/gpucuda/indexing.py
index 80301a5..23a7b51 100644
--- a/pystencils/gpucuda/indexing.py
+++ b/pystencils/gpucuda/indexing.py
@@ -7,7 +7,7 @@ from pystencils.slicing import normalize_slice
 from pystencils.data_types import TypedSymbol, create_type
 from functools import partial
 
-from pystencils.sympyextensions import prod
+from pystencils.sympyextensions import prod, is_integer_sequence
 
 AUTO_BLOCK_SIZE_LIMITING = False
 
@@ -66,6 +66,9 @@ class AbstractIndexing(abc.ABC):
         """Return maximal number of threads per block for launch bounds. If this cannot be determined without
         knowing the array shape return None for unknown """
 
+    @abc.abstractmethod
+    def symbolic_parameters(self):
+        """Set of symbols required in call_parameters code"""
 
 # -------------------------------------------- Implementations ---------------------------------------------------------
 
@@ -230,6 +233,8 @@ class BlockIndexing(AbstractIndexing):
     @staticmethod
     def permute_block_size_according_to_layout(block_size, layout):
         """Returns modified block_size such that the fastest coordinate gets the biggest block dimension"""
+        if not is_integer_sequence(block_size):
+            return block_size
         sorted_block_size = list(sorted(block_size, reverse=True))
         while len(sorted_block_size) > len(layout):
             sorted_block_size[0] *= sorted_block_size[-1]
@@ -241,7 +246,13 @@ class BlockIndexing(AbstractIndexing):
         return tuple(result[:len(layout)])
 
     def max_threads_per_block(self):
-        return prod(self._block_size)
+        if is_integer_sequence(self._block_size):
+            return prod(self._block_size)
+        else:
+            return None
+
+    def symbolic_parameters(self):
+        return set(b for b in self._block_size if isinstance(b, sp.Symbol))
 
 
 class LineIndexing(AbstractIndexing):
@@ -293,6 +304,8 @@ class LineIndexing(AbstractIndexing):
     def max_threads_per_block(self):
         return None
 
+    def symbolic_parameters(self):
+        return set()
 
 # -------------------------------------- Helper functions --------------------------------------------------------------
 
-- 
GitLab