An error occurred while loading the file. Please try again.
-
Martin Bauer authored
- pystencils can now create a non-compilable kernel that can be analyzed by kerncraft
3b4deebe
kernelcreation.py 5.53 KiB
from pystencils.gpucuda.indexing import BlockIndexing
from pystencils.transformations import resolveFieldAccesses, typeAllEquations, parseBasePointerInfo, getCommonShape
from pystencils.astnodes import Block, KernelFunction, SympyAssignment, LoopOverCoordinate
from pystencils.types import TypedSymbol, BasicType, StructType
from pystencils import Field
def createCUDAKernel(listOfEquations, functionName="kernel", typeForSymbol=None, indexingCreator=BlockIndexing,
                     iterationSlice=None, ghostLayers=None):
    """Build the KernelFunction AST of a CUDA kernel from a list of update equations.

    The equations are typed with ``typeAllEquations``, wrapped into a ``Block``, guarded by the
    indexing object and finally all field accesses are resolved against the coordinates that the
    indexing object provides.

    Args:
        listOfEquations: equations of the kernel; forwarded to ``typeAllEquations`` and each of them is
            queried with ``atoms(Field.Access)``
        functionName: name handed to the ``KernelFunction`` node
        typeForSymbol: forwarded unchanged to ``typeAllEquations``
        indexingCreator: callable invoked as ``indexingCreator(field=..., iterationSlice=...)``; the object
            it returns has to offer ``guard``, ``indexVariables`` and ``coordinates``
        iterationSlice: list with one slice per dimension, passed to ``indexingCreator``. If None, it is
            derived from ``ghostLayers``
        ghostLayers: only evaluated if ``iterationSlice`` is None. Either an int (same width at both ends of
            every dimension), a sequence holding one ``(lower, upper)`` pair per dimension, or None - then
            the maximum ``requiredGhostLayers`` over all field accesses of the equations is used

    Returns:
        the ``KernelFunction`` node; the indexing object is stored in its ``indexing`` attribute
    """
    fieldsRead, fieldsWritten, assignments = typeAllEquations(listOfEquations, typeForSymbol)
    allFields = fieldsRead.union(fieldsWritten)
    readOnlyFields = {field.name for field in fieldsRead - fieldsWritten}

    # every field access that shows up anywhere in the passed equations
    fieldAccesses = set().union(*[equation.atoms(Field.Access) for equation in listOfEquations])

    commonShape = getCommonShape(allFields)
    if iterationSlice is None:
        # no explicit slice given: cut the ghost layers off at both ends of each dimension
        if ghostLayers is None:
            # widest ghost layer any access asks for, applied symmetrically in all dimensions
            maxGhostLayers = max(access.requiredGhostLayers for access in fieldAccesses)
            ghostLayers = [(maxGhostLayers, maxGhostLayers)] * len(commonShape)

        dimensions = range(len(commonShape))
        if isinstance(ghostLayers, int):
            # a width that is not positive leaves the upper end open (None) instead of using "-0"
            upperEnd = -ghostLayers if ghostLayers > 0 else None
            iterationSlice = [slice(ghostLayers, upperEnd) for _ in dimensions]
        else:
            iterationSlice = [slice(ghostLayers[d][0], -ghostLayers[d][1] if ghostLayers[d][1] > 0 else None)
                              for d in dimensions]

    # NOTE(review): allFields is presumably a set, so which field ends up as representative is arbitrary
    representativeField = list(allFields)[0]
    indexing = indexingCreator(field=representativeField, iterationSlice=iterationSlice)

    guardedBody = indexing.guard(Block(assignments), commonShape)
    ast = KernelFunction(guardedBody, functionName)
    ast.globalVariables.update(indexing.indexVariables)

    threadCoordinates = indexing.coordinates
    basePointerSpec = [['spatialInner0']]
    basePointerInfos = {field.name: parseBasePointerInfo(basePointerSpec, [2, 1, 0], field) for field in allFields}

    # all fields share the same coordinate expressions coming from the indexing object
    fixedCoordinates = {field.name: threadCoordinates for field in allFields}
    resolveFieldAccesses(ast, readOnlyFields, fieldToFixedCoordinates=fixedCoordinates,
                         fieldToBasePointerInfo=basePointerInfos)

    # If loop counter symbols have been explicitly used in the update equations (e.g. for built in periodicity),
    # they are still undefined at this point - map coordinate index -> symbol and define them below
    loopCounterOfCoordinate = {}
    for symbol in ast.body.undefinedSymbols:
        coordinateIndex = LoopOverCoordinate.isLoopCounterSymbol(symbol)
        if coordinateIndex is not None:
            loopCounterOfCoordinate[coordinateIndex] = symbol

    for coordinateIndex, loopCounter in loopCounterOfCoordinate.items():
        ast.body.insertFront(SympyAssignment(loopCounter, indexing.coordinates[coordinateIndex]))

    # attach the indexing object as additional member to the KernelFunction node; according to the original
    # comment it determines #blocks and #threads and is used by the jit
    ast.indexing = indexing
    return ast
def createdIndexedCUDAKernel(listOfEquations, indexFields, functionName="kernel", typeForSymbol=None,
                             coordinateNames=('x', 'y', 'z'), indexingCreator=BlockIndexing):
    """Build the KernelFunction AST of a CUDA kernel that iterates over index fields.

    The spatial coordinates of all non-index fields are not taken from the indexing object, but are read
    from struct elements of the index fields: for every used coordinate name an assignment
    ``name = indexField[0](name)`` is put in front of the typed equations.

    Args:
        listOfEquations: equations of the kernel; forwarded to ``typeAllEquations``
        indexFields: re-iterable collection of fields; each one gets ``isIndexField = True`` set, has to
            have exactly one spatial dimension and a ``StructType`` dtype
        functionName: name handed to the ``KernelFunction`` node
        typeForSymbol: forwarded unchanged to ``typeAllEquations``
        coordinateNames: names of the struct elements that hold the coordinates; only the first
            ``spatialDimensions`` (of the non-index fields) entries are used
        indexingCreator: callable invoked as ``indexingCreator(field=..., iterationSlice=...)``; the object
            it returns has to offer ``guard``, ``indexVariables`` and ``coordinates``

    Returns:
        the ``KernelFunction`` node; the indexing object is stored in its ``indexing`` attribute

    Raises:
        ValueError: if a required coordinate name is not an element of any index field
    """
    fieldsRead, fieldsWritten, assignments = typeAllEquations(listOfEquations, typeForSymbol)
    allFields = fieldsRead.union(fieldsWritten)
    readOnlyFields = {field.name for field in fieldsRead - fieldsWritten}

    # mark the passed fields as index fields (mutates them) and check their dimensionality
    for field in indexFields:
        field.isIndexField = True
        assert field.spatialDimensions == 1, "Index fields have to be 1D"

    nonIndexFields = [field for field in allFields if field not in indexFields]
    distinctDimensions = {field.spatialDimensions for field in nonIndexFields}
    assert len(distinctDimensions) == 1, "Non-index fields do not have the same number of spatial coordinates"
    spatialCoordinates = list(distinctDimensions)[0]

    def getCoordinateSymbolAssignment(name):
        """Return the assignment that loads coordinate `name` from the first index field providing it."""
        for candidate in indexFields:
            structType = candidate.dtype
            assert isinstance(structType, StructType), "Index fields have to have a struct datatype"
            if not structType.hasElement(name):
                continue
            typedCoordinate = TypedSymbol(name, BasicType(structType.getElementType(name)))
            return SympyAssignment(typedCoordinate, candidate[0](name))
        raise ValueError("Index %s not found in any of the passed index fields" % (name,))

    coordinateSymbolAssignments = []
    for coordinateName in coordinateNames[:spatialCoordinates]:
        coordinateSymbolAssignments.append(getCoordinateSymbolAssignment(coordinateName))
    coordinateTypedSymbols = [assignment.lhs for assignment in coordinateSymbolAssignments]

    # the first index field determines the iteration space; it is traversed completely
    firstIndexField = list(indexFields)[0]
    fullSlices = [slice(None, None, None)] * len(firstIndexField.spatialShape)
    indexing = indexingCreator(field=firstIndexField, iterationSlice=fullSlices)

    # coordinate loads have to come before the actual update equations
    kernelBody = Block(coordinateSymbolAssignments + assignments)
    kernelBody = indexing.guard(kernelBody, getCommonShape(indexFields))
    ast = KernelFunction(kernelBody, functionName)
    ast.globalVariables.update(indexing.indexVariables)

    threadCoordinates = indexing.coordinates
    basePointerSpec = [['spatialInner0']]
    basePointerInfos = {field.name: parseBasePointerInfo(basePointerSpec, [2, 1, 0], field) for field in allFields}

    # index fields are addressed with the coordinates of the indexing object,
    # all other fields with the coordinates loaded from the index fields
    fixedCoordinates = {field.name: threadCoordinates for field in indexFields}
    for field in nonIndexFields:
        fixedCoordinates[field.name] = coordinateTypedSymbols
    resolveFieldAccesses(ast, readOnlyFields, fieldToFixedCoordinates=fixedCoordinates,
                         fieldToBasePointerInfo=basePointerInfos)

    # attach the indexing object as additional member to the KernelFunction node; according to the original
    # comment it determines #blocks and #threads and is used by the jit
    ast.indexing = indexing
    return ast