cudajit.py 4.87 KB
Newer Older
1
2
import numpy as np
import pycuda.driver as cuda
Martin Bauer's avatar
Martin Bauer committed
3
import pycuda.autoinit
4
from pycuda.compiler import SourceModule
Martin Bauer's avatar
Martin Bauer committed
5
from pystencils.backends.cbackend import generateC
6
from pystencils.transformations import symbolNameToVariableName
7
from pystencils.types import StructType
8
9


10
11
12
13
14
def makePythonFunction(kernelFunctionNode, argumentDict=None):
    """
    Creates a kernel function from an abstract syntax tree which
    was created e.g. by :func:`pystencils.gpucuda.createCUDAKernel`
    or :func:`pystencils.gpucuda.createdIndexedCUDAKernel`

    The CUDA source is generated and handed to ``SourceModule`` once, when this function is called.
    The returned functor only validates its arguments and launches the compiled function.

    :param kernelFunctionNode: the abstract syntax tree
    :param argumentDict: parameters passed here are already fixed. Remaining parameters have to be passed to the
                        returned kernel functor. The dictionary is never modified; ``None`` means that no
                        parameters are fixed.
    :return: kernel functor, which takes the remaining parameters as keyword arguments
    """
    # A ``None`` sentinel instead of a shared ``{}`` default: no state can leak between unrelated calls.
    fixedArguments = {} if argumentDict is None else argumentDict

    code = "#include <cstdint>\n"
    code += "#define FUNC_PREFIX __global__\n"
    code += "#define RESTRICT __restrict__\n\n"
    code += str(generateC(kernelFunctionNode))

    mod = SourceModule(code, options=["-w", "-std=c++11"])
    func = mod.get_function(kernelFunctionNode.functionName)

    def wrapper(**kwargs):
        # Copy at call time so that the fixed arguments are never modified; call-time arguments take precedence.
        fullArguments = dict(fixedArguments)
        fullArguments.update(kwargs)

        # Validates the passed arrays and yields the shape the launch configuration is derived from.
        shape = _checkArguments(kernelFunctionNode.parameters, fullArguments)
        dictWithBlockAndThreadNumbers = kernelFunctionNode.getCallParameters(shape)

        args = _buildNumpyArgumentList(kernelFunctionNode, fullArguments)
        func(*args, **dictWithBlockAndThreadNumbers)
    return wrapper
40
41


42
def _buildNumpyArgumentList(kernelFunctionNode, argumentDict):
    """
    Converts the keyword arguments of a kernel call into the positional argument list of the compiled kernel.

    :param kernelFunctionNode: AST function node; the order and kind of its ``parameters`` determine the result
    :param argumentDict: maps parameter / field names to the values passed by the caller
    :return: list with one entry per parameter, in the order of ``kernelFunctionNode.parameters``
    :raises KeyError: if a required field or parameter is missing in ``argumentDict``
    """
    argumentDict = {symbolNameToVariableName(k): v for k, v in argumentDict.items()}
    result = []
    for arg in kernelFunctionNode.parameters:
        if arg.isFieldArgument:
            field = argumentDict[arg.fieldName]
            if arg.isFieldPtrArgument:
                result.append(field.gpudata)
            elif arg.isFieldStrideArgument:
                # The array reports strides in bytes, they are passed on in units of elements.
                # Floor division keeps the int32 dtype; true division would produce a float array.
                strideArr = np.array(field.strides, dtype=np.int32) // field.dtype.itemsize
                result.append(cuda.In(strideArr))
            elif arg.isFieldShapeArgument:
                shapeArr = np.array(field.shape, dtype=np.int32)
                result.append(cuda.In(shapeArr))
            else:
                # Raised explicitly instead of ``assert False``: an assert is stripped under ``-O``, which would
                # silently skip the argument and shift all following ones.
                raise AssertionError("Unknown kind of field argument for field " + str(arg.fieldName))
        else:
            param = argumentDict[arg.name]
            # Convert to the exact numpy scalar type of the kernel parameter.
            expectedType = arg.dtype.numpyDtype
            result.append(expectedType(param))
    return result


65
66
67
68
69
70
71
72
73
74
75
76
77
78
def _checkArguments(parameterSpecification, argumentDict):
    """
    Checks if parameters passed to kernel match the description in the AST function node.
    If not it raises a ValueError, on success it returns the array shape that determines the CUDA blocks and threads

    :param parameterSpecification: parameters of the AST function node
    :param argumentDict: maps parameter / field names to the values passed by the caller
    :return: spatial shape of the passed index arrays if there are any, otherwise spatial shape of the passed arrays
    :raises KeyError: if the array for a field argument is missing in ``argumentDict``
    :raises ValueError: if shape or strides of an array do not match a fixed-shape field, if the passed arrays
                        differ in their spatial shape, or if no array was passed at all
    """
    argumentDict = {symbolNameToVariableName(k): v for k, v in argumentDict.items()}
    arrayShapes = set()
    indexArrShapes = set()
    for arg in parameterSpecification:
        if not arg.isFieldArgument:
            continue

        # The presence of the array is checked for every field argument, not only for the pointer argument.
        try:
            fieldArr = argumentDict[arg.fieldName]
        except KeyError:
            raise KeyError("Missing field parameter for kernel call " + arg.fieldName)

        # Shape and strides are validated once per field, at its pointer argument.
        if not arg.isFieldPtrArgument:
            continue

        symbolicField = arg.field
        if symbolicField.hasFixedShape:
            symbolicFieldShape = tuple(int(i) for i in symbolicField.shape)
            if isinstance(symbolicField.dtype, StructType):
                # For struct fields the last symbolic dimension has no counterpart in the passed array.
                symbolicFieldShape = symbolicFieldShape[:-1]
            if symbolicFieldShape != fieldArr.shape:
                # Report the shape that was actually compared, i.e. without the stripped struct dimension.
                raise ValueError("Passed array '%s' has shape %s which does not match expected shape %s" %
                                 (arg.fieldName, str(fieldArr.shape), str(symbolicFieldShape)))
        # NOTE(review): the stride check is guarded by hasFixedShape, exactly like the shape check above -
        # confirm whether a separate "fixed strides" flag was intended here.
        if symbolicField.hasFixedShape:
            # Symbolic strides are multiplied by the item size to compare them with the array strides.
            symbolicFieldStrides = tuple(int(i) * fieldArr.dtype.itemsize for i in symbolicField.strides)
            if isinstance(symbolicField.dtype, StructType):
                symbolicFieldStrides = symbolicFieldStrides[:-1]
            if symbolicFieldStrides != fieldArr.strides:
                raise ValueError("Passed array '%s' has strides %s which does not match expected strides %s" %
                                 (arg.fieldName, str(fieldArr.strides), str(symbolicFieldStrides)))

        spatialShape = fieldArr.shape[:symbolicField.spatialDimensions]
        if symbolicField.isIndexField:
            indexArrShapes.add(spatialShape)
        else:
            arrayShapes.add(spatialShape)

    if len(arrayShapes) > 1:
        raise ValueError("All passed arrays have to have the same size " + str(arrayShapes))
    if len(indexArrShapes) > 1:
        raise ValueError("All passed index arrays have to have the same size " + str(indexArrShapes))

    # Index arrays take precedence when determining the returned shape.
    if indexArrShapes:
        return next(iter(indexArrShapes))
    if arrayShapes:
        return next(iter(arrayShapes))
    raise ValueError("No array was passed to the kernel, the shape for the CUDA blocks and threads "
                     "cannot be determined")
111

Martin Bauer's avatar
Martin Bauer committed
112

113