Skip to content
Snippets Groups Projects
kernelcreation.py 5.72 KiB
Newer Older
from pystencils.gpucuda.indexing import BlockIndexing
from pystencils.transformations import resolveFieldAccesses, typeAllEquations, parseBasePointerInfo, getCommonShape, \
    substituteArrayAccessesWithConstants
from pystencils.astnodes import Block, KernelFunction, SympyAssignment, LoopOverCoordinate
from pystencils.data_types import TypedSymbol, BasicType, StructType
from pystencils import Field
def createCUDAKernel(listOfEquations, functionName="kernel", typeForSymbol=None, indexingCreator=BlockIndexing,
                     iterationSlice=None, ghostLayers=None):
    fieldsRead, fieldsWritten, assignments = typeAllEquations(listOfEquations, typeForSymbol)
    allFields = fieldsRead.union(fieldsWritten)
    readOnlyFields = set([f.name for f in fieldsRead - fieldsWritten])
    fieldAccesses = set()
    for eq in listOfEquations:
        fieldAccesses.update(eq.atoms(Field.Access))
    commonShape = getCommonShape(allFields)
    if iterationSlice is None:
        # determine iteration slice from ghost layers
        if ghostLayers is None:
            # determine required number of ghost layers from field access
            requiredGhostLayers = max([fa.requiredGhostLayers for fa in fieldAccesses])
            ghostLayers = [(requiredGhostLayers, requiredGhostLayers)] * len(commonShape)
        iterationSlice = []
        if isinstance(ghostLayers, int):
            for i in range(len(commonShape)):
                iterationSlice.append(slice(ghostLayers, -ghostLayers if ghostLayers > 0 else None))
        else:
            for i in range(len(commonShape)):
                iterationSlice.append(slice(ghostLayers[i][0], -ghostLayers[i][1] if ghostLayers[i][1] > 0 else None))
    indexing = indexingCreator(field=list(allFields)[0], iterationSlice=iterationSlice)
    block = Block(assignments)
    block = indexing.guard(block, commonShape)
    ast = KernelFunction(block, functionName=functionName, ghostLayers=ghostLayers)
    ast.globalVariables.update(indexing.indexVariables)

    coordMapping = indexing.coordinates
    basePointerInfo = [['spatialInner0']]
    basePointerInfos = {f.name: parseBasePointerInfo(basePointerInfo, [2, 1, 0], f) for f in allFields}

    coordMapping = {f.name: coordMapping for f in allFields}
    resolveFieldAccesses(ast, readOnlyFields, fieldToFixedCoordinates=coordMapping,
                         fieldToBasePointerInfo=basePointerInfos)

    substituteArrayAccessesWithConstants(ast)

    # add the function which determines #blocks and #threads as additional member to KernelFunction node
    # this is used by the jit

    # If loop counter symbols have been explicitly used in the update equations (e.g. for built in periodicity),
    # they are defined here
    undefinedLoopCounters = {LoopOverCoordinate.isLoopCounterSymbol(s): s for s in ast.body.undefinedSymbols
                             if LoopOverCoordinate.isLoopCounterSymbol(s) is not None}
    for i, loopCounter in undefinedLoopCounters.items():
        ast.body.insertFront(SympyAssignment(loopCounter, indexing.coordinates[i]))

    ast.indexing = indexing
    return ast


def createdIndexedCUDAKernel(listOfEquations, indexFields, functionName="kernel", typeForSymbol=None,
                             coordinateNames=('x', 'y', 'z'), indexingCreator=BlockIndexing):
    fieldsRead, fieldsWritten, assignments = typeAllEquations(listOfEquations, typeForSymbol)
    allFields = fieldsRead.union(fieldsWritten)
    readOnlyFields = set([f.name for f in fieldsRead - fieldsWritten])
    for indexField in indexFields:
        indexField.isIndexField = True
        assert indexField.spatialDimensions == 1, "Index fields have to be 1D"

    nonIndexFields = [f for f in allFields if f not in indexFields]
    spatialCoordinates = {f.spatialDimensions for f in nonIndexFields}
    assert len(spatialCoordinates) == 1, "Non-index fields do not have the same number of spatial coordinates"
    spatialCoordinates = list(spatialCoordinates)[0]

    def getCoordinateSymbolAssignment(name):
        for indexField in indexFields:
            assert isinstance(indexField.dtype, StructType), "Index fields have to have a struct datatype"
            dataType = indexField.dtype
            if dataType.hasElement(name):
                rhs = indexField[0](name)
                lhs = TypedSymbol(name, BasicType(dataType.getElementType(name)))
                return SympyAssignment(lhs, rhs)
        raise ValueError("Index %s not found in any of the passed index fields" % (name,))

    coordinateSymbolAssignments = [getCoordinateSymbolAssignment(n) for n in coordinateNames[:spatialCoordinates]]
    coordinateTypedSymbols = [eq.lhs for eq in coordinateSymbolAssignments]

    idxField = list(indexFields)[0]
    indexing = indexingCreator(field=idxField, iterationSlice=[slice(None, None, None)] * len(idxField.spatialShape))
    functionBody = Block(coordinateSymbolAssignments + assignments)
    functionBody = indexing.guard(functionBody, getCommonShape(indexFields))
    ast = KernelFunction(functionBody, functionName=functionName)
    ast.globalVariables.update(indexing.indexVariables)
    coordMapping = indexing.coordinates
    basePointerInfo = [['spatialInner0']]
    basePointerInfos = {f.name: parseBasePointerInfo(basePointerInfo, [2, 1, 0], f) for f in allFields}

    coordMapping = {f.name: coordMapping for f in indexFields}
    coordMapping.update({f.name: coordinateTypedSymbols for f in nonIndexFields})
    resolveFieldAccesses(ast, readOnlyFields, fieldToFixedCoordinates=coordMapping,
                         fieldToBasePointerInfo=basePointerInfos)
    substituteArrayAccessesWithConstants(ast)

    # add the function which determines #blocks and #threads as additional member to KernelFunction node
    # this is used by the jit
    ast.indexing = indexing
    return ast