cudajit.py 5.9 KB
Newer Older
1
2
import numpy as np
import pycuda.driver as cuda
Martin Bauer's avatar
Martin Bauer committed
3
import pycuda.autoinit
4
from pycuda.compiler import SourceModule
Martin Bauer's avatar
Martin Bauer committed
5
from pystencils.backends.cbackend import generateC
6
from pystencils.transformations import symbolNameToVariableName
7
from pystencils.types import StructType, getBaseType
8
9


10
11
12
13
14
def makePythonFunction(kernelFunctionNode, argumentDict=None):
    """
    Creates a kernel function from an abstract syntax tree which
    was created e.g. by :func:`pystencils.gpucuda.createCUDAKernel`
    or :func:`pystencils.gpucuda.createdIndexedCUDAKernel`

    The C code generated from the syntax tree is compiled with pycuda's ``SourceModule`` when this
    function is called. The returned functor only validates/converts its arguments and launches
    the already compiled kernel.

    :param kernelFunctionNode: the abstract syntax tree
    :param argumentDict: parameters passed here are already fixed. Remaining parameters have to be passed to the
                        returned kernel functor. ``None`` (the default) fixes no parameters.
    :return: kernel functor, callable with keyword arguments only
    """
    if argumentDict is None:
        argumentDict = {}

    code = "#include <cstdint>\n"
    code += "#define FUNC_PREFIX __global__\n"
    code += "#define RESTRICT __restrict__\n\n"
    code += str(generateC(kernelFunctionNode))

    mod = SourceModule(code, options=["-w", "-std=c++11"])
    func = mod.get_function(kernelFunctionNode.functionName)

    parameters = kernelFunctionNode.parameters

    # Maps (keyword name, object identity) tuples to the converted argument list and launch parameters,
    # so that repeated calls with the very same objects skip argument checking and conversion.
    cache = {}
    cacheValues = []

    def wrapper(**kwargs):
        # The tuple itself is the key (not its hash), so that two different argument sets
        # can never be mixed up by a hash collision.
        key = tuple((k, id(v)) for k, v in kwargs.items())
        try:
            # Only the lookup is guarded: a KeyError raised while launching the kernel
            # must not be mistaken for a cache miss.
            args, callParameters = cache[key]
        except KeyError:
            fullArguments = argumentDict.copy()
            fullArguments.update(kwargs)
            shape = _checkArguments(parameters, fullArguments)

            indexing = kernelFunctionNode.indexing
            # dict of keyword arguments for the kernel launch, computed by the indexing object
            callParameters = indexing.getCallParameters(shape, func)

            args = _buildNumpyArgumentList(parameters, fullArguments)
            cache[key] = (args, callParameters)
            cacheValues.append(kwargs)  # keep objects alive such that ids remain unique

        func(*args, **callParameters)
        # NOTE: calling cuda.Context.synchronize() here is useful for debugging,
        # to get errors right after the kernel was called
    return wrapper
53
54


55
def _buildNumpyArgumentList(parameters, argumentDict):
    """
    Converts the values passed to a kernel call into the positional argument list for the compiled function.

    For each entry of ``parameters`` exactly one element is appended to the result, in the same order:

    - field pointer argument: the ``gpudata`` attribute of the passed array, after checking its dtype
    - field stride argument: the array strides divided by the item size, wrapped in ``cuda.In``
    - field shape argument: the array shape, wrapped in ``cuda.In``
    - any other argument: the passed value converted to the numpy scalar type of the parameter

    :param parameters: parameter descriptions of the kernel function node
    :param argumentDict: dictionary mapping parameter/field names to the passed values
    :return: list with one entry per parameter
    :raises ValueError: if the dtype of a passed array differs from the dtype expected by the kernel
    :raises KeyError: if a value for a parameter is missing
    """
    argumentDict = {symbolNameToVariableName(k): v for k, v in argumentDict.items()}
    result = []
    for arg in parameters:
        if arg.isFieldArgument:
            field = argumentDict[arg.fieldName]
            if arg.isFieldPtrArgument:
                actualType = field.dtype
                expectedType = arg.dtype.baseType.numpyDtype
                if expectedType != actualType:
                    raise ValueError("Data type mismatch for field '%s'. Expected '%s' got '%s'." %
                                     (arg.fieldName, expectedType, actualType))
                result.append(field.gpudata)
            elif arg.isFieldStrideArgument:
                dtype = getBaseType(arg.dtype).numpyDtype
                # numpy reports strides in bytes, here they are converted to multiples of the item size
                strideArr = np.array(field.strides, dtype=dtype) // field.dtype.itemsize
                result.append(cuda.In(strideArr))
            elif arg.isFieldShapeArgument:
                dtype = getBaseType(arg.dtype).numpyDtype
                shapeArr = np.array(field.shape, dtype=dtype)
                result.append(cuda.In(shapeArr))
            else:
                # Raised explicitly instead of 'assert False': asserts are removed when running with -O,
                # which would silently drop this argument from the list.
                raise AssertionError("Unknown kind of field argument for field '%s'" % (arg.fieldName,))
        else:
            try:
                param = argumentDict[arg.name]
            except KeyError:
                raise KeyError("Missing parameter for kernel call " + arg.name)
            expectedType = arg.dtype.numpyDtype
            result.append(expectedType.type(param))
    return result


86
87
88
89
90
91
92
93
94
95
96
97
98
99
def _checkArguments(parameterSpecification, argumentDict):
    """
    Checks if parameters passed to kernel match the description in the AST function node.
    If not it raises a ValueError, on success it returns the array shape that determines the CUDA blocks and threads

    Only field pointer arguments are inspected. For fields with fixed shape, the shape and strides of the
    passed array have to match the symbolic field exactly. The spatial part of all array shapes has to be
    identical; the same holds for all index arrays. If index arrays were passed, their shape is returned,
    otherwise the shape of the normal arrays.

    :param parameterSpecification: parameter descriptions of the kernel function node
    :param argumentDict: dictionary mapping parameter/field names to the passed values
    :return: spatial shape (tuple) shared by the passed arrays
    :raises KeyError: if the array for a field argument was not passed
    :raises ValueError: if shapes or strides do not match, or if no array was passed at all
    """
    argumentDict = {symbolNameToVariableName(k): v for k, v in argumentDict.items()}
    arrayShapes = set()
    indexArrShapes = set()
    for arg in parameterSpecification:
        if not arg.isFieldArgument:
            continue

        try:
            fieldArr = argumentDict[arg.fieldName]
        except KeyError:
            raise KeyError("Missing field parameter for kernel call " + arg.fieldName)

        if not arg.isFieldPtrArgument:
            continue

        symbolicField = arg.field
        if symbolicField.hasFixedShape:
            # for struct data types the last entry of symbolic shape and strides is not compared
            isStructField = isinstance(symbolicField.dtype, StructType)

            symbolicFieldShape = tuple(int(i) for i in symbolicField.shape)
            if isStructField:
                symbolicFieldShape = symbolicFieldShape[:-1]
            if symbolicFieldShape != fieldArr.shape:
                raise ValueError("Passed array '%s' has shape %s which does not match expected shape %s" %
                                 (arg.fieldName, str(fieldArr.shape), str(symbolicFieldShape)))

            # symbolic strides are multiplied by the item size to compare them with the numpy strides
            symbolicFieldStrides = tuple(int(i) * fieldArr.dtype.itemsize for i in symbolicField.strides)
            if isStructField:
                symbolicFieldStrides = symbolicFieldStrides[:-1]
            if symbolicFieldStrides != fieldArr.strides:
                raise ValueError("Passed array '%s' has strides %s which does not match expected strides %s" %
                                 (arg.fieldName, str(fieldArr.strides), str(symbolicFieldStrides)))

        spatialShape = fieldArr.shape[:symbolicField.spatialDimensions]
        if symbolicField.isIndexField:
            indexArrShapes.add(spatialShape)
        else:
            arrayShapes.add(spatialShape)

    if len(arrayShapes) > 1:
        raise ValueError("All passed arrays have to have the same size " + str(arrayShapes))
    if len(indexArrShapes) > 1:
        raise ValueError("All passed index arrays have to have the same size " + str(indexArrShapes))

    # index arrays take precedence over normal arrays
    if indexArrShapes:
        return next(iter(indexArrShapes))
    if arrayShapes:
        return next(iter(arrayShapes))
    raise ValueError("No array passed to kernel call, cannot determine the shape for CUDA blocks and threads")
132

Martin Bauer's avatar
Martin Bauer committed
133

134