import numpy as np import pycuda.driver as cuda import pycuda.autoinit from pycuda.compiler import SourceModule from pystencils.backends.cbackend import generateC from pystencils.transformations import symbolNameToVariableName from pystencils.data_types import StructType, getBaseType def makePythonFunction(kernelFunctionNode, argumentDict={}): """ Creates a kernel function from an abstract syntax tree which was created e.g. by :func:`pystencils.gpucuda.createCUDAKernel` or :func:`pystencils.gpucuda.createdIndexedCUDAKernel` :param kernelFunctionNode: the abstract syntax tree :param argumentDict: parameters passed here are already fixed. Remaining parameters have to be passed to the returned kernel functor. :return: kernel functor """ code = "#include <cstdint>\n" code += "#define FUNC_PREFIX __global__\n" code += "#define RESTRICT __restrict__\n\n" code += str(generateC(kernelFunctionNode)) mod = SourceModule(code, options=["-w", "-std=c++11"]) func = mod.get_function(kernelFunctionNode.functionName) parameters = kernelFunctionNode.parameters cache = {} cacheValues = [] def wrapper(**kwargs): key = hash(tuple((k, id(v)) for k, v in kwargs.items())) try: args, dictWithBlockAndThreadNumbers = cache[key] func(*args, **dictWithBlockAndThreadNumbers) except KeyError: fullArguments = argumentDict.copy() fullArguments.update(kwargs) shape = _checkArguments(parameters, fullArguments) indexing = kernelFunctionNode.indexing dictWithBlockAndThreadNumbers = indexing.getCallParameters(shape) dictWithBlockAndThreadNumbers['block'] = tuple(int(i) for i in dictWithBlockAndThreadNumbers['block']) dictWithBlockAndThreadNumbers['grid'] = tuple(int(i) for i in dictWithBlockAndThreadNumbers['grid']) args = _buildNumpyArgumentList(parameters, fullArguments) cache[key] = (args, dictWithBlockAndThreadNumbers) cacheValues.append(kwargs) # keep objects alive such that ids remain unique func(*args, **dictWithBlockAndThreadNumbers) #cuda.Context.synchronize() # useful for debugging, to get errors right after kernel was called return wrapper def _buildNumpyArgumentList(parameters, argumentDict): argumentDict = {symbolNameToVariableName(k): v for k, v in argumentDict.items()} result = [] for arg in parameters: if arg.isFieldArgument: field = argumentDict[arg.fieldName] if arg.isFieldPtrArgument: actualType = field.dtype expectedType = arg.dtype.baseType.numpyDtype if expectedType != actualType: raise ValueError("Data type mismatch for field '%s'. Expected '%s' got '%s'." % (arg.fieldName, expectedType, actualType)) result.append(field.gpudata) elif arg.isFieldStrideArgument: dtype = getBaseType(arg.dtype).numpyDtype strideArr = np.array(field.strides, dtype=dtype) // field.dtype.itemsize result.append(cuda.In(strideArr)) elif arg.isFieldShapeArgument: dtype = getBaseType(arg.dtype).numpyDtype shapeArr = np.array(field.shape, dtype=dtype) result.append(cuda.In(shapeArr)) else: assert False else: param = argumentDict[arg.name] expectedType = arg.dtype.numpyDtype result.append(expectedType.type(param)) assert len(result) == len(parameters) return result def _checkArguments(parameterSpecification, argumentDict): """ Checks if parameters passed to kernel match the description in the AST function node. If not it raises a ValueError, on success it returns the array shape that determines the CUDA blocks and threads """ argumentDict = {symbolNameToVariableName(k): v for k, v in argumentDict.items()} arrayShapes = set() indexArrShapes = set() for arg in parameterSpecification: if arg.isFieldArgument: try: fieldArr = argumentDict[arg.fieldName] except KeyError: raise KeyError("Missing field parameter for kernel call " + arg.fieldName) symbolicField = arg.field if arg.isFieldPtrArgument: if symbolicField.hasFixedShape: symbolicFieldShape = tuple(int(i) for i in symbolicField.shape) if isinstance(symbolicField.dtype, StructType): symbolicFieldShape = symbolicFieldShape[:-1] if symbolicFieldShape != fieldArr.shape: raise ValueError("Passed array '%s' has shape %s which does not match expected shape %s" % (arg.fieldName, str(fieldArr.shape), str(symbolicField.shape))) if symbolicField.hasFixedShape: symbolicFieldStrides = tuple(int(i) * fieldArr.dtype.itemsize for i in symbolicField.strides) if isinstance(symbolicField.dtype, StructType): symbolicFieldStrides = symbolicFieldStrides[:-1] if symbolicFieldStrides != fieldArr.strides: raise ValueError("Passed array '%s' has strides %s which does not match expected strides %s" % (arg.fieldName, str(fieldArr.strides), str(symbolicFieldStrides))) if symbolicField.isIndexField: indexArrShapes.add(fieldArr.shape[:symbolicField.spatialDimensions]) else: arrayShapes.add(fieldArr.shape[:symbolicField.spatialDimensions]) if len(arrayShapes) > 1: raise ValueError("All passed arrays have to have the same size " + str(arrayShapes)) if len(indexArrShapes) > 1: raise ValueError("All passed index arrays have to have the same size " + str(arrayShapes)) if len(indexArrShapes) > 0: return list(indexArrShapes)[0] else: return list(arrayShapes)[0]