diff --git a/gpucuda/cudajit.py b/gpucuda/cudajit.py index f57ded792780873f1949e9216ae4b2f7531abd86..0b18c1359f88b7ee57a51d359123611e3da45270 100644 --- a/gpucuda/cudajit.py +++ b/gpucuda/cudajit.py @@ -41,7 +41,7 @@ def makePythonFunction(kernelFunctionNode, argumentDict={}): shape = _checkArguments(parameters, fullArguments) indexing = kernelFunctionNode.indexing - dictWithBlockAndThreadNumbers = indexing.getCallParameters(shape) + dictWithBlockAndThreadNumbers = indexing.getCallParameters(shape, func) args = _buildNumpyArgumentList(parameters, fullArguments) cache[key] = (args, dictWithBlockAndThreadNumbers) diff --git a/gpucuda/indexing.py b/gpucuda/indexing.py index 651340d6de5f46f70755534dfe5d52a9fdd08c6e..34de23a2e2414fb6fdd0ed8de7acd2aa0a49eb7b 100644 --- a/gpucuda/indexing.py +++ b/gpucuda/indexing.py @@ -1,4 +1,5 @@ import abc + import sympy as sp import math import pycuda.driver as cuda @@ -30,10 +31,12 @@ class AbstractIndexing(abc.ABCMeta('ABC', (object,), {})): return BLOCK_IDX + THREAD_IDX @abc.abstractmethod - def getCallParameters(self, arrShape): + def getCallParameters(self, arrShape, functionToCall): """ Determine grid and block size for kernel call :param arrShape: the numeric (not symbolic) shape of the array + :param functionToCall: compile kernel function that should be called. Use this object to get information + about required resources like number of registers :return: dict with keys 'blocks' and 'threads' with tuple values for number of (x,y,z) threads and blocks the kernel should be started with """ @@ -84,7 +87,7 @@ class BlockIndexing(AbstractIndexing): return coordinates[:self._dim] - def getCallParameters(self, arrShape): + def getCallParameters(self, arrShape, functionToCall): substitutionDict = {sym: value for sym, value in zip(self._symbolicShape, arrShape) if sym is not None} widths = [end - start for start, end in zip(_getStartFromSlice(self._iterationSlice), @@ -94,6 +97,7 @@ class BlockIndexing(AbstractIndexing): grid = tuple(math.ceil(length / blockSize) for length, blockSize in zip(widths, self._blockSize)) extendBs = (1,) * (3 - len(self._blockSize)) extendGr = (1,) * (3 - len(grid)) + return {'block': self._blockSize + extendBs, 'grid': grid + extendGr} @@ -160,6 +164,32 @@ class BlockIndexing(AbstractIndexing): return tuple(blockSize) + @staticmethod + def limitBlockSizeByRegisterRestriction(blockSize, requiredRegistersPerThread, device=None): + """Shrinks the blockSize if there are too many registers used per multiprocessor. + This is not done automatically, since the requiredRegistersPerThread are not known before compilation. + They can be obtained by ``func.num_regs`` from a pycuda function. + :returns smaller blockSize if too many registers are used. + """ + da = cuda.device_attribute + if device is None: + device = cuda.Context.get_device() + availableRegistersPerMP = device.get_attribute(da.MAX_REGISTERS_PER_MULTIPROCESSOR) + + block = blockSize + + while True: + numThreads = 1 + for t in block: + numThreads *= t + requiredRegistersPerMT = numThreads * requiredRegistersPerThread + if requiredRegistersPerMT <= availableRegistersPerMP: + return block + else: + largestGridEntryIdx = max(range(len(block)), key=lambda e: block[e]) + assert block[largestGridEntryIdx] >= 2 + block[largestGridEntryIdx] //= 2 + @staticmethod def permuteBlockSizeAccordingToLayout(blockSize, layout): """Returns modified blockSize such that the fastest coordinate gets the biggest block dimension""" @@ -200,7 +230,7 @@ class LineIndexing(AbstractIndexing): def coordinates(self): return [i + offset for i, offset in zip(self._coordinates, _getStartFromSlice(self._iterationSlice))] - def getCallParameters(self, arrShape): + def getCallParameters(self, arrShape, functionToCall): substitutionDict = {sym: value for sym, value in zip(self._symbolicShape, arrShape) if sym is not None} widths = [end - start for start, end in zip(_getStartFromSlice(self._iterationSlice), @@ -243,4 +273,5 @@ def _getEndFromSlice(iterationSlice, arrShape): else: assert isinstance(sliceComponent, int) res.append(sliceComponent + 1) - return res \ No newline at end of file + return res +