cudajit.py 6.88 KB
Newer Older
1
import numpy as np
2
from pystencils.backends.cbackend import generate_c, get_headers
Martin Bauer's avatar
Martin Bauer committed
3
4
from pystencils.kernelparameters import FieldPointerSymbol
from pystencils.data_types import StructType
5
from pystencils.field import FieldType
6
from pystencils.include import get_pystencils_include_path
7
8


9
10
11
# Checked each time make_python_function compiles a kernel: when true, the
# "-use_fast_math" flag is appended to the compiler options.
USE_FAST_MATH = True


Martin Bauer's avatar
Martin Bauer committed
12
def make_python_function(kernel_function_node, argument_dict=None):
    """
    Creates a kernel function from an abstract syntax tree which
    was created e.g. by :func:`pystencils.gpucuda.create_cuda_kernel`
    or :func:`pystencils.gpucuda.created_indexed_cuda_kernel`

    The CUDA source is generated from the AST, compiled with pycuda's ``SourceModule`` and wrapped in a
    keyword-only Python callable. pycuda is imported inside this function, so importing this module alone
    does not require it.

    Args:
        kernel_function_node: the abstract syntax tree
        argument_dict: parameters passed here are already fixed. Remaining parameters have to be passed to the
                       returned kernel functor.

    Returns:
        compiled kernel as Python function. The returned callable carries the attributes ``ast`` (the kernel
        function node), ``parameters`` (result of ``kernel_function_node.get_parameters()``) and ``num_regs``
        (taken from the compiled pycuda function).
    """
    import pycuda.autoinit  # NOQA
    from pycuda.compiler import SourceModule

    if argument_dict is None:
        argument_dict = {}

    header_list = ['<stdint.h>'] + list(get_headers(kernel_function_node))
    includes = "\n".join(["#include %s" % (include_file,) for include_file in header_list])

    code = includes + "\n"
    code += "#define FUNC_PREFIX __global__\n"
    code += "#define RESTRICT __restrict__\n\n"
    code += str(generate_c(kernel_function_node, dialect='cuda'))
    options = ["-w", "-std=c++11", "-Wno-deprecated-gpu-targets"]
    if USE_FAST_MATH:
        options.append("-use_fast_math")
    mod = SourceModule(code, options=options, include_dirs=[get_pystencils_include_path()])
    func = mod.get_function(kernel_function_node.function_name)

    parameters = kernel_function_node.get_parameters()

    # Maps an argument signature to the prepared (args, launch configuration) pair, so that the argument
    # checks and conversions are done only once per distinct set of call arguments.
    cache = {}
    cache_values = []

    def wrapper(**kwargs):
        # The tuple itself is used as key (instead of its hash) so that two different argument sets whose
        # hashes happen to collide can never be served each other's cached kernel arguments.
        key = tuple((k, v.ctypes.data, v.strides, v.shape) if isinstance(v, np.ndarray) else (k, id(v))
                    for k, v in kwargs.items())
        try:
            # Only the lookup is guarded: a KeyError raised while launching the kernel must not be
            # mistaken for a cache miss (which would launch the kernel a second time).
            args, block_and_thread_numbers = cache[key]
        except KeyError:
            full_arguments = argument_dict.copy()
            full_arguments.update(kwargs)
            shape = _check_arguments(parameters, full_arguments)

            indexing = kernel_function_node.indexing
            block_and_thread_numbers = indexing.call_parameters(shape)
            block_and_thread_numbers['block'] = tuple(int(i) for i in block_and_thread_numbers['block'])
            block_and_thread_numbers['grid'] = tuple(int(i) for i in block_and_thread_numbers['grid'])

            args = _build_numpy_argument_list(parameters, full_arguments)
            cache[key] = (args, block_and_thread_numbers)
            cache_values.append(kwargs)  # keep objects alive such that ids remain unique

        func(*args, **block_and_thread_numbers)
        # import pycuda.driver as cuda
        # cuda.Context.synchronize() # useful for debugging, to get errors right after kernel was called

    wrapper.ast = kernel_function_node
    wrapper.parameters = kernel_function_node.get_parameters()
    wrapper.num_regs = func.num_regs
    return wrapper
76
77


Martin Bauer's avatar
Martin Bauer committed
78
def _build_numpy_argument_list(parameters, argument_dict):
79
    argument_dict = {k: v for k, v in argument_dict.items()}
80
    result = []
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99

    for param in parameters:
        if param.is_field_pointer:
            array = argument_dict[param.field_name]
            actual_type = array.dtype
            expected_type = param.fields[0].dtype.numpy_dtype
            if expected_type != actual_type:
                raise ValueError("Data type mismatch for field '%s'. Expected '%s' got '%s'." %
                                 (param.field_name, expected_type, actual_type))
            result.append(array)
        elif param.is_field_stride:
            cast_to_dtype = param.symbol.dtype.numpy_dtype.type
            array = argument_dict[param.field_name]
            stride = cast_to_dtype(array.strides[param.symbol.coordinate] // array.dtype.itemsize)
            result.append(stride)
        elif param.is_field_shape:
            cast_to_dtype = param.symbol.dtype.numpy_dtype.type
            array = argument_dict[param.field_name]
            result.append(cast_to_dtype(array.shape[param.symbol.coordinate]))
100
        else:
101
102
103
            expected_type = param.symbol.dtype.numpy_dtype
            result.append(expected_type.type(argument_dict[param.symbol.name]))

104
    assert len(result) == len(parameters)
105
106
107
    return result


Martin Bauer's avatar
Martin Bauer committed
108
def _check_arguments(parameter_specification, argument_dict):
    """
    Checks if parameters passed to kernel match the description in the AST function node.
    If not it raises a ValueError, on success it returns the array shape that determines the CUDA blocks and threads

    Only parameters whose symbol is a ``FieldPointerSymbol`` are inspected. For fields with a fixed shape the
    passed array has to match the symbolic shape and strides exactly (for struct dtypes the last symbolic
    dimension is ignored).

    Args:
        parameter_specification: sequence of kernel parameter descriptions
        argument_dict: maps field names to the arrays passed by the caller

    Returns:
        the spatial shape shared by all indexed fields if there is at least one indexed field,
        otherwise the spatial shape shared by all generic fields

    Raises:
        KeyError: if no array was passed for one of the kernel's fields
        ValueError: if shape or strides of an array do not match a fixed-shape field, if the arrays of one
                    category have differing spatial shapes, or if no indexed or generic field array is
                    present to take the shape from
    """
    array_shapes = set()
    index_arr_shapes = set()

    for param in parameter_specification:
        if isinstance(param.symbol, FieldPointerSymbol):
            symbolic_field = param.fields[0]

            try:
                field_arr = argument_dict[symbolic_field.name]
            except KeyError:
                raise KeyError("Missing field parameter for kernel call " + str(symbolic_field))

            if symbolic_field.has_fixed_shape:
                is_struct = isinstance(symbolic_field.dtype, StructType)

                symbolic_field_shape = tuple(int(i) for i in symbolic_field.shape)
                if is_struct:
                    symbolic_field_shape = symbolic_field_shape[:-1]
                if symbolic_field_shape != field_arr.shape:
                    # report the shape that was actually compared (struct dimension already removed)
                    raise ValueError("Passed array '%s' has shape %s which does not match expected shape %s" %
                                     (symbolic_field.name, str(field_arr.shape), str(symbolic_field_shape)))

                # symbolic strides are scaled by the item size to compare against the array's byte strides
                symbolic_field_strides = tuple(int(i) * field_arr.dtype.itemsize for i in symbolic_field.strides)
                if is_struct:
                    symbolic_field_strides = symbolic_field_strides[:-1]
                if symbolic_field_strides != field_arr.strides:
                    raise ValueError("Passed array '%s' has strides %s which does not match expected strides %s" %
                                     (symbolic_field.name, str(field_arr.strides), str(symbolic_field_strides)))

            if FieldType.is_indexed(symbolic_field):
                index_arr_shapes.add(field_arr.shape[:symbolic_field.spatial_dimensions])
            elif FieldType.is_generic(symbolic_field):
                array_shapes.add(field_arr.shape[:symbolic_field.spatial_dimensions])

    if len(array_shapes) > 1:
        raise ValueError("All passed arrays have to have the same size " + str(array_shapes))
    if len(index_arr_shapes) > 1:
        raise ValueError("All passed index arrays have to have the same size " + str(index_arr_shapes))

    # Both sets hold at most one shape at this point; indexed fields take precedence.
    if index_arr_shapes:
        return next(iter(index_arr_shapes))
    if array_shapes:
        return next(iter(array_shapes))
    raise ValueError("Cannot determine the kernel's iteration shape: "
                     "no array for an indexed or generic field was passed")