From ff641ec9aba11473f5b84ce17aaf9e09c61cc8ee Mon Sep 17 00:00:00 2001
From: Martin Bauer <martin.bauer@fau.de>
Date: Mon, 20 Mar 2017 17:30:52 +0100
Subject: [PATCH] Conditional AST Node & advanced CUDA indexing

- abstraction layer for selecting CUDA block and grid sizes
  - line-based (the scheme that was implemented before)
  - block-based (new, more flexible)
- new conditional (if/else) AST node, which is necessary for indexing schemes
  that require a guarding if (see the usage sketch below)
---
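
Usage sketch (illustration only, not part of the commit; "eqs" stands for a
list of update equations, and the import paths assume the usual pystencils
package layout):

    from pystencils.gpucuda.kernelcreation import createCUDAKernel
    from pystencils.gpucuda.cudajit import makePythonFunction
    from pystencils.gpucuda.indexing import BlockIndexing, LineIndexing

    ast = createCUDAKernel(eqs)                                # block-based indexing (default)
    ast = createCUDAKernel(eqs, indexingCreator=LineIndexing)  # line-based, as before this patch

    # indexingCreator is called as indexingCreator(field, ghostLayers), so a custom
    # block size can be passed via a lambda:
    ast = createCUDAKernel(eqs, indexingCreator=lambda field, ghostLayers:
                           BlockIndexing(field, ghostLayers, blockSize=(32, 4, 2)))

    kernel = makePythonFunction(ast)
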
 astnodes.py               |  43 +++++++++-
 backends/cbackend.py      |   9 +++
 field.py                  |   2 +-
 gpucuda/cudajit.py        |  20 +++--
 gpucuda/indexing.py       | 161 ++++++++++++++++++++++++++++++++++++++
 gpucuda/kernelcreation.py |  62 ++++++---------
 transformations.py        |  43 +++++-----
 7 files changed, 274 insertions(+), 66 deletions(-)
 create mode 100644 gpucuda/indexing.py

diff --git a/astnodes.py b/astnodes.py
index 91f2e494f..8210771c6 100644
--- a/astnodes.py
+++ b/astnodes.py
@@ -21,7 +21,7 @@ class Node(object):
 
     @property
     def undefinedSymbols(self):
-        """Symbols which are use but are not defined inside this node"""
+        """Symbols which are used but are not defined inside this node"""
         raise NotImplementedError()
 
     def atoms(self, argType):
@@ -36,6 +36,47 @@ class Node(object):
         return result
 
 
+class Conditional(Node):
+    """Conditional node: executes a block if a condition is true, with an optional else block"""
+    def __init__(self, conditionExpr, trueBlock, falseBlock=None):
+        """
+        Create a new conditional node
+
+        :param conditionExpr: sympy boolean or relational expression
+        :param trueBlock: block which is executed if the condition is true
+        :param falseBlock: block which is executed if the condition is false, or None if not needed
+        """
+        assert conditionExpr.is_Boolean or conditionExpr.is_Relational
+        self.conditionExpr = conditionExpr
+        self.trueBlock = trueBlock
+        self.falseBlock = falseBlock
+
+    @property
+    def args(self):
+        result = [self.conditionExpr, self.trueBlock]
+        if self.falseBlock:
+            result.append(self.falseBlock)
+        return result
+
+    @property
+    def symbolsDefined(self):
+        return set()
+
+    @property
+    def undefinedSymbols(self):
+        result = set(self.trueBlock.undefinedSymbols)
+        if self.falseBlock:
+            result.update(self.falseBlock.undefinedSymbols)
+        result.update(self.conditionExpr.atoms(sp.Symbol))
+        return result
+
+    def __str__(self):
+        return 'if:({!s}) '.format(self.conditionExpr)
+
+    def __repr__(self):
+        return 'if:({!r}) '.format(self.conditionExpr)
+
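+# Usage sketch: with a sympy symbol x, a shape symbol xSize, and Block instances
+# someBlock and elseBlock (hypothetical names), a guarded block can be built as
+#     Conditional(sp.Lt(x, xSize - 1), someBlock, elseBlock)
+# which the C backend prints roughly as: if (x < xSize - 1) {...} else {...}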
+
 class KernelFunction(Node):
 
     class Argument:
diff --git a/backends/cbackend.py b/backends/cbackend.py
index c85e9be72..821a94733 100644
--- a/backends/cbackend.py
+++ b/backends/cbackend.py
@@ -112,6 +112,15 @@ class CBackend(object):
     def _print_CustomCppCode(self, node):
         return node.code
 
+    def _print_Conditional(self, node):
+        conditionExpr = self.sympyPrinter.doprint(node.conditionExpr)
+        trueBlock = self._print_Block(node.trueBlock)
+        result = "if (%s) \n %s " % (conditionExpr, trueBlock)
+        if node.falseBlock:
+            falseBlock = self._print_Block(node.falseBlock)
+            result += "else " + falseBlock
+        return result
+
 
 # ------------------------------------------ Helper function & classes -------------------------------------------------
 
diff --git a/field.py b/field.py
index 66cc1181d..83a6ba4ff 100644
--- a/field.py
+++ b/field.py
@@ -3,7 +3,7 @@ import numpy as np
 import sympy as sp
 from sympy.core.cache import cacheit
 from sympy.tensor import IndexedBase
-from pystencils.types import TypedSymbol, createType, StructType
+from pystencils.types import TypedSymbol, createType
 
 
 class Field(object):
diff --git a/gpucuda/cudajit.py b/gpucuda/cudajit.py
index a0355d26c..a99aa69cc 100644
--- a/gpucuda/cudajit.py
+++ b/gpucuda/cudajit.py
@@ -26,24 +26,26 @@ def makePythonFunction(kernelFunctionNode, argumentDict={}):
     mod = SourceModule(code, options=["-w", "-std=c++11"])
     func = mod.get_function(kernelFunctionNode.functionName)
 
+    parameters = kernelFunctionNode.parameters
+
     def wrapper(**kwargs):
         from copy import copy
         fullArguments = copy(argumentDict)
         fullArguments.update(kwargs)
-        shape = _checkArguments(kernelFunctionNode.parameters, fullArguments)
-
-        dictWithBlockAndThreadNumbers = kernelFunctionNode.getCallParameters(shape)
+        shape = _checkArguments(parameters, fullArguments)
 
-        args = _buildNumpyArgumentList(kernelFunctionNode, fullArguments)
+        indexing = kernelFunctionNode.indexing
+        dictWithBlockAndThreadNumbers = indexing.getCallParameters(shape)
+        args = _buildNumpyArgumentList(parameters, fullArguments)
         func(*args, **dictWithBlockAndThreadNumbers)
-        # cuda.Context.synchronize() #  useful for debugging, to get errors right after kernel was called
+        # cuda.Context.synchronize()  # useful for debugging, to get errors right after the kernel call
     return wrapper
 
 
-def _buildNumpyArgumentList(kernelFunctionNode, argumentDict):
+def _buildNumpyArgumentList(parameters, argumentDict):
     argumentDict = {symbolNameToVariableName(k): v for k, v in argumentDict.items()}
     result = []
-    for arg in kernelFunctionNode.parameters:
+    for arg in parameters:
         if arg.isFieldArgument:
             field = argumentDict[arg.fieldName]
             if arg.isFieldPtrArgument:
@@ -52,6 +54,10 @@ def _buildNumpyArgumentList(kernelFunctionNode, argumentDict):
                 dtype = getBaseType(arg.dtype).numpyDtype
                 strideArr = np.array(field.strides, dtype=dtype) // field.dtype.itemsize
                 result.append(cuda.In(strideArr))
+            elif arg.isFieldShapeArgument:
+                dtype = getBaseType(arg.dtype).numpyDtype
+                shapeArr = np.array(field.shape, dtype=dtype)
+                result.append(cuda.In(shapeArr))
             else:
                 assert False
         else:
diff --git a/gpucuda/indexing.py b/gpucuda/indexing.py
new file mode 100644
index 000000000..f63e633de
--- /dev/null
+++ b/gpucuda/indexing.py
@@ -0,0 +1,161 @@
+import sympy as sp
+import math
+import pycuda.driver as cuda
+import pycuda.autoinit
+
+from pystencils.astnodes import Conditional, Block
+
+BLOCK_IDX = list(sp.symbols("blockIdx.x blockIdx.y blockIdx.z"))
+THREAD_IDX = list(sp.symbols("threadIdx.x threadIdx.y threadIdx.z"))
+
+# Part 1:
+#  given a field and the number of ghost layers, return the x, y and z coordinates
+#  as expressions of the CUDA thread and block indices
+
+# Part 2:
+#  given the actual field size, determine the call parameters, i.e. the number of blocks and threads
+
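+# Both indexing classes below provide the same informal interface:
+#   coordinates        - sympy expressions for the spatial coordinates (Part 1)
+#   getCallParameters  - dict with 'block' and 'grid' entries for the kernel launch (Part 2)
+#   guard              - wraps the kernel body in a bounds check where needed
+#                        (a no-op for LineIndexing)
+#   indexVariables     - CUDA index symbols registered as kernel global variables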
+
+class LineIndexing:
+    def __init__(self, field, ghostLayers):
+        availableIndices = [THREAD_IDX[0]] + BLOCK_IDX
+        if field.spatialDimensions > 4:
+            raise NotImplementedError("This indexing scheme supports at most 4 spatial dimensions")
+
+        coordinates = availableIndices[:field.spatialDimensions]
+
+        fastestCoordinate = field.layout[-1]
+        coordinates[0], coordinates[fastestCoordinate] = coordinates[fastestCoordinate], coordinates[0]
+
+        self._coordinatesNoGhostLayer = coordinates
+        self._coordinates = [i + ghostLayers for i in coordinates]
+        self._ghostLayers = ghostLayers
+
+    @property
+    def coordinates(self):
+        return self._coordinates
+
+    def getCallParameters(self, arrShape):
+        def getShapeOfCudaIdx(cudaIdx):
+            if cudaIdx not in self._coordinatesNoGhostLayer:
+                return 1
+            else:
+                return arrShape[self._coordinatesNoGhostLayer.index(cudaIdx)] - 2 * self._ghostLayers
+
+        return {'block': tuple([getShapeOfCudaIdx(idx) for idx in THREAD_IDX]),
+                'grid': tuple([getShapeOfCudaIdx(idx) for idx in BLOCK_IDX])}
+
+    def guard(self, kernelContent, arrShape):
+        return kernelContent
+
+    @property
+    def indexVariables(self):
+        return BLOCK_IDX + THREAD_IDX
+
+
+class BlockIndexing:
+    def __init__(self, field, ghostLayers, blockSize=(256, 8, 1), permuteBlockSizeDependentOnLayout=True):
+        if field.spatialDimensions > 3:
+            raise NotImplementedError("This indexing scheme supports at most 3 spatial dimensions")
+
+        if permuteBlockSizeDependentOnLayout:
+            blockSize = self.permuteBlockSizeAccordingToLayout(blockSize, field.layout)
+
+        self._blockSize = self.limitBlockSizeToDeviceMaximum(blockSize)
+        self._coordinates = [blockIndex * bs + threadIndex + ghostLayers
+                             for blockIndex, bs, threadIndex in zip(BLOCK_IDX, self._blockSize, THREAD_IDX)]
+
+        self._coordinates = self._coordinates[:field.spatialDimensions]
+        self._ghostLayers = ghostLayers
+
+    @staticmethod
+    def limitBlockSizeToDeviceMaximum(blockSize):
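+        """Shrinks the given block size to the limits of the current device: the largest
+        (or an over-limit) dimension is halved until MAX_THREADS_PER_BLOCK is respected,
+        then over-limit dimensions are halved while dimensions with headroom are doubled,
+        until all MAX_BLOCK_DIM_X/Y/Z limits are met. Returns the adjusted block size."""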
+        # Get device limits
+        da = cuda.device_attribute
+        device = cuda.Context.get_device()
+
+        blockSize = list(blockSize)
+        maxThreads = device.get_attribute(da.MAX_THREADS_PER_BLOCK)
+        maxBlockSize = [device.get_attribute(a)
+                        for a in (da.MAX_BLOCK_DIM_X, da.MAX_BLOCK_DIM_Y, da.MAX_BLOCK_DIM_Z)]
+
+        def prod(seq):
+            result = 1
+            for e in seq:
+                result *= e
+            return result
+
+        def getIndexOfTooBigElement(blockSize):
+            for i, bs in enumerate(blockSize):
+                if bs > maxBlockSize[i]:
+                    return i
+            return None
+
+        def getIndexOfTooSmallElement(blockSize):
+            # find a dimension that can still be doubled without exceeding its limit
+            for i, bs in enumerate(blockSize):
+                if 2 * bs <= maxBlockSize[i]:
+                    return i
+            return None
+
+        # Reduce the total number of threads if necessary
+        while prod(blockSize) > maxThreads:
+            itemToReduce = blockSize.index(max(blockSize))
+            for i, bs in enumerate(blockSize):
+                if bs > maxBlockSize[i]:
+                    itemToReduce = i
+            blockSize[itemToReduce] //= 2
+
+        # Cap individual elements
+        tooBigElementIndex = getIndexOfTooBigElement(blockSize)
+        while tooBigElementIndex is not None:
+            tooSmallElementIndex = getIndexOfTooSmallElement(blockSize)
+            blockSize[tooSmallElementIndex] *= 2
+            blockSize[tooBigElementIndex] //= 2
+            tooBigElementIndex = getIndexOfTooBigElement(blockSize)
+
+        return tuple(blockSize)
+
+    @staticmethod
+    def permuteBlockSizeAccordingToLayout(blockSize, layout):
+        """The fastest coordinate gets the biggest block dimension"""
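+        # e.g. blockSize=(256, 8, 1) with layout=(0, 1, 2), i.e. coordinate 2 fastest,
+        # yields (1, 8, 256): the widest block dimension runs along the fastest coordinate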
+        sortedBlockSize = list(sorted(blockSize, reverse=True))
+        while len(sortedBlockSize) > len(layout):
+            sortedBlockSize[0] *= sortedBlockSize[-1]
+            sortedBlockSize = sortedBlockSize[:-1]
+
+        result = list(blockSize)
+        for l, bs in zip(reversed(layout), sortedBlockSize):
+            result[l] = bs
+        return tuple(result[:len(layout)])
+
+    @property
+    def coordinates(self):
+        return self._coordinates
+
+    def getCallParameters(self, arrShape):
+        dim = len(self._coordinates)
+        arrShape = arrShape[:dim]
+        grid = tuple(math.ceil(length / blockSize) for length, blockSize in zip(arrShape, self._blockSize))
+        extendBs = (1,) * (3 - len(self._blockSize))
+        extendGr = (1,) * (3 - len(grid))
+        return {'block': self._blockSize + extendBs,
+                'grid': grid + extendGr}
+
+    def guard(self, kernelContent, arrShape):
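+        # Wraps the kernel body in a bounds check: with coordinates (c0, c1, ...) and
+        # spatial shape (s0, s1, ...), the body is only executed where
+        # c_i < s_i - ghostLayers, so threads of partially filled blocks do not
+        # access memory out of bounds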
+        dim = len(self._coordinates)
+        arrShape = arrShape[:dim]
+        conditions = [c < shapeComponent - self._ghostLayers
+                      for c, shapeComponent in zip(self._coordinates, arrShape)]
+        condition = conditions[0]
+        for c in conditions[1:]:
+            condition = sp.And(condition, c)
+        return Block([Conditional(condition, kernelContent)])
+
+    @property
+    def indexVariables(self):
+        return BLOCK_IDX + THREAD_IDX
+
+if __name__ == '__main__':
+    bs = BlockIndexing.permuteBlockSizeAccordingToLayout((256, 8, 1), (0,))
+    bs = BlockIndexing.limitBlockSizeToDeviceMaximum(bs)
+    print(bs)
diff --git a/gpucuda/kernelcreation.py b/gpucuda/kernelcreation.py
index 664752935..5d4a3ac5b 100644
--- a/gpucuda/kernelcreation.py
+++ b/gpucuda/kernelcreation.py
@@ -1,47 +1,28 @@
-import sympy as sp
-
-from pystencils.transformations import resolveFieldAccesses, typeAllEquations, parseBasePointerInfo
+from pystencils.gpucuda.indexing import BlockIndexing, LineIndexing
+from pystencils.transformations import resolveFieldAccesses, typeAllEquations, parseBasePointerInfo, getCommonShape
 from pystencils.astnodes import Block, KernelFunction, SympyAssignment
-from pystencils import Field
 from pystencils.types import TypedSymbol, BasicType, StructType
-
-BLOCK_IDX = list(sp.symbols("blockIdx.x blockIdx.y blockIdx.z"))
-THREAD_IDX = list(sp.symbols("threadIdx.x threadIdx.y threadIdx.z"))
-
-
-def getLinewiseCoordinates(field, ghostLayers):
-    availableIndices = [THREAD_IDX[0]] + BLOCK_IDX
-    assert field.spatialDimensions <= 4, "This indexing scheme supports at most 4 spatial dimensions"
-    result = availableIndices[:field.spatialDimensions]
-
-    fastestCoordinate = field.layout[-1]
-    result[0], result[fastestCoordinate] = result[fastestCoordinate], result[0]
-
-    def getCallParameters(arrShape):
-        def getShapeOfCudaIdx(cudaIdx):
-            if cudaIdx not in result:
-                return 1
-            else:
-                return arrShape[result.index(cudaIdx)] - 2 * ghostLayers
-
-        return {'block': tuple([getShapeOfCudaIdx(idx) for idx in THREAD_IDX]),
-                'grid': tuple([getShapeOfCudaIdx(idx) for idx in BLOCK_IDX])}
-
-    return [i + ghostLayers for i in result], getCallParameters
+from pystencils import Field
 
 
-def createCUDAKernel(listOfEquations, functionName="kernel", typeForSymbol=None):
+def createCUDAKernel(listOfEquations, functionName="kernel", typeForSymbol=None, indexingCreator=BlockIndexing):
     fieldsRead, fieldsWritten, assignments = typeAllEquations(listOfEquations, typeForSymbol)
     allFields = fieldsRead.union(fieldsWritten)
     readOnlyFields = set([f.name for f in fieldsRead - fieldsWritten])
 
-    ast = KernelFunction(Block(assignments), allFields, functionName)
-    ast.globalVariables.update(BLOCK_IDX + THREAD_IDX)
+    fieldAccesses = set()
+    for eq in listOfEquations:
+        fieldAccesses.update(eq.atoms(Field.Access))
 
-    fieldAccesses = ast.atoms(Field.Access)
     requiredGhostLayers = max([fa.requiredGhostLayers for fa in fieldAccesses])
+    indexing = indexingCreator(list(fieldsRead)[0], requiredGhostLayers)
 
-    coordMapping, getCallParameters = getLinewiseCoordinates(list(fieldsRead)[0], requiredGhostLayers)
+    block = Block(assignments)
+    block = indexing.guard(block, getCommonShape(allFields))
+    ast = KernelFunction(block, allFields, functionName)
+    ast.globalVariables.update(indexing.indexVariables)
+
+    coordMapping = indexing.coordinates
     basePointerInfo = [['spatialInner0']]
     basePointerInfos = {f.name: parseBasePointerInfo(basePointerInfo, [2, 1, 0], f) for f in allFields}
 
@@ -50,12 +31,12 @@ def createCUDAKernel(listOfEquations, functionName="kernel", typeForSymbol=None)
                          fieldToBasePointerInfo=basePointerInfos)
     # add the function which determines #blocks and #threads as additional member to KernelFunction node
     # this is used by the jit
-    ast.getCallParameters = getCallParameters
+    ast.indexing = indexing
     return ast
 
 
 def createdIndexedCUDAKernel(listOfEquations, indexFields, functionName="kernel", typeForSymbol=None,
-                             coordinateNames=('x', 'y', 'z')):
+                             coordinateNames=('x', 'y', 'z'), indexingCreator=BlockIndexing):
     fieldsRead, fieldsWritten, assignments = typeAllEquations(listOfEquations, typeForSymbol)
     allFields = fieldsRead.union(fieldsWritten)
     readOnlyFields = set([f.name for f in fieldsRead - fieldsWritten])
@@ -82,11 +63,14 @@ def createdIndexedCUDAKernel(listOfEquations, indexFields, functionName="kernel"
     coordinateSymbolAssignments = [getCoordinateSymbolAssignment(n) for n in coordinateNames[:spatialCoordinates]]
     coordinateTypedSymbols = [eq.lhs for eq in coordinateSymbolAssignments]
 
+    indexing = indexingCreator(list(indexFields)[0], ghostLayers=0)
+
     functionBody = Block(coordinateSymbolAssignments + assignments)
+    functionBody = indexing.guard(functionBody, getCommonShape(indexFields))
     ast = KernelFunction(functionBody, allFields, functionName)
-    ast.globalVariables.update(BLOCK_IDX + THREAD_IDX)
+    ast.globalVariables.update(indexing.indexVariables)
 
-    coordMapping, getCallParameters = getLinewiseCoordinates(list(indexFields)[0], ghostLayers=0)
+    coordMapping = indexing.coordinates
     basePointerInfo = [['spatialInner0']]
     basePointerInfos = {f.name: parseBasePointerInfo(basePointerInfo, [2, 1, 0], f) for f in allFields}
 
@@ -96,5 +80,5 @@ def createdIndexedCUDAKernel(listOfEquations, indexFields, functionName="kernel"
                          fieldToBasePointerInfo=basePointerInfos)
     # add the function which determines #blocks and #threads as additional member to KernelFunction node
     # this is used by the jit
-    ast.getCallParameters = getCallParameters
-    return ast
\ No newline at end of file
+    ast.indexing = indexing
+    return ast
diff --git a/transformations.py b/transformations.py
index 73b684a3c..a48125dea 100644
--- a/transformations.py
+++ b/transformations.py
@@ -24,6 +24,30 @@ def fastSubs(term, subsDict):
     return visit(term)
 
 
+def getCommonShape(fieldSet):
+    """Takes a set of pystencils Fields and returns their common spatial shape if it exists;
+    otherwise a ValueError is raised"""
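+    # e.g. two fixed-shape fields with spatial shape (20, 30) -> (20, 30);
+    # mixing fixed-shape and variable-shape fields raises a ValueError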
+    nrOfFixedShapedFields = 0
+    for f in fieldSet:
+        if f.hasFixedShape:
+            nrOfFixedShapedFields += 1
+
+    if nrOfFixedShapedFields > 0 and nrOfFixedShapedFields != len(fieldSet):
+        fixedFieldNames = ",".join([f.name for f in fieldSet if f.hasFixedShape])
+        varFieldNames = ",".join([f.name for f in fieldSet if not f.hasFixedShape])
+        msg = "Mixing fixed-shape and variable-shape fields in a single kernel is not possible\n"
+        msg += "Variable shaped: %s \nFixed shaped:    %s" % (varFieldNames, fixedFieldNames)
+        raise ValueError(msg)
+
+    shapeSet = set([f.spatialShape for f in fieldSet])
+    if nrOfFixedShapedFields == len(fieldSet):
+        if len(shapeSet) != 1:
+            raise ValueError("Differently sized field accesses in loop body: " + str(shapeSet))
+
+    shape = list(sorted(shapeSet, key=lambda e: str(e[0])))[0]
+    return shape
+
+
 def makeLoopOverDomain(body, functionName, iterationSlice=None, ghostLayers=None, loopOrder=None):
     """
     Uses :class:`pystencils.field.Field.Access` to create (multiple) loops around given AST.
@@ -45,24 +69,7 @@ def makeLoopOverDomain(body, functionName, iterationSlice=None, ghostLayers=None
     if loopOrder is None:
         loopOrder = getOptimalLoopOrdering(fields)
 
-    nrOfFixedShapedFields = 0
-    for f in fields:
-        if f.hasFixedShape:
-            nrOfFixedShapedFields += 1
-
-    if nrOfFixedShapedFields > 0 and nrOfFixedShapedFields != len(fields):
-        fixedFieldNames = ",".join([f.name for f in fields if f.hasFixedShape])
-        varFieldNames = ",".join([f.name for f in fields if not f.hasFixedShape])
-        msg = "Mixing fixed-shaped and variable-shape fields in a single kernel is not possible\n"
-        msg += "Variable shaped: %s \nFixed shaped:    %s" % (varFieldNames, fixedFieldNames)
-        raise ValueError(msg)
-
-    shapeSet = set([f.spatialShape for f in fields])
-    if nrOfFixedShapedFields == len(fields):
-        if len(shapeSet) != 1:
-            raise ValueError("Differently sized field accesses in loop body: " + str(shapeSet))
-
-    shape = list(sorted(shapeSet, key=lambda e: str(e[0])))[0]
+    shape = getCommonShape(fields)
 
     if iterationSlice is not None:
         iterationSlice = normalizeSlice(iterationSlice, shape)
-- 
GitLab