import numpy as np
from pystencils.backends.cbackend import generate_c, get_headers
from pystencils.data_types import StructType
from pystencils.field import FieldType
from pystencils.gpucuda.texture_utils import ndarray_to_tex
from pystencils.include import get_pycuda_include_path, get_pystencils_include_path
from pystencils.interpolation_astnodes import InterpolatorAccess, TextureCachedField
from pystencils.kernel_wrapper import KernelWrapper
from pystencils.kernelparameters import FieldPointerSymbol


# When True, make_python_function adds nvcc's "-use_fast_math" option to the compiler options
# of every generated kernel; set to False to compile without it.
USE_FAST_MATH = True


def get_cubic_interpolation_include_paths():
    """
    Returns the include directories of the CubicInterpolationCUDA submodule.

    Returns:
        list of two paths located next to this module: ``CubicInterpolationCUDA/code`` and its
        ``internal`` subdirectory
    """
    from os.path import dirname, join

    code_dir = join(dirname(__file__), "CubicInterpolationCUDA", "code")
    internal_dir = join(code_dir, "internal")
    return [code_dir, internal_dir]


def make_python_function(kernel_function_node, argument_dict=None, custom_backend=None):
    """
    Creates a kernel function from an abstract syntax tree which
    was created e.g. by :func:`pystencils.gpucuda.create_cuda_kernel`
    or :func:`pystencils.gpucuda.created_indexed_cuda_kernel`

    Args:
        kernel_function_node: the abstract syntax tree
        argument_dict: parameters passed here are already fixed. Remaining parameters have to be passed to the
                       returned kernel functor.
        custom_backend: optional custom backend, forwarded to :func:`generate_c` when printing the CUDA code

    Returns:
        compiled kernel as Python function (a :class:`KernelWrapper`, with the register count of the
        compiled kernel stored in its ``num_regs`` attribute)

    Raises:
        FileNotFoundError: if the kernel uses cubic-spline texture interpolation but the
                           CubicInterpolationCUDA submodule is not present next to this module
    """
    import pycuda.autoinit  # NOQA
    from pycuda.compiler import SourceModule
    from os.path import isdir
    from pystencils.interpolation_astnodes import InterpolationMode

    if argument_dict is None:
        argument_dict = {}

    header_list = ['<cstdint>'] + list(get_headers(kernel_function_node))
    includes = "\n".join(["#include %s" % (include_file,) for include_file in header_list])

    code = includes + "\n"
    code += "#define FUNC_PREFIX __global__\n"
    code += "#define RESTRICT __restrict__\n\n"
    code += str(generate_c(kernel_function_node, dialect='cuda', custom_backend=custom_backend))

    textures = set(d.interpolator for d in kernel_function_node.atoms(
        InterpolatorAccess) if isinstance(d.interpolator, TextureCachedField))

    nvcc_options = ["-w", "-std=c++11", "-Wno-deprecated-gpu-targets"]
    if USE_FAST_MATH:
        nvcc_options.append("-use_fast_math")

    # Code for CubicInterpolationCUDA
    cubic_textures = [t for t in textures if t.interpolation_mode == InterpolationMode.CUBIC_SPLINE]
    if cubic_textures:
        cubic_include_paths = get_cubic_interpolation_include_paths()
        # Explicit check instead of `assert`, so that it is not stripped when running with `python -O`.
        if not isdir(cubic_include_paths[0]):
            raise FileNotFoundError("Submodule CubicInterpolationCUDA does not exist.\n"
                                    + "Clone https://github.com/theHamsta/CubicInterpolationCUDA into pystencils.gpucuda")
        nvcc_options += ["-I" + path for path in cubic_include_paths]

        needed_dims = set(t.field.spatial_dimensions for t in cubic_textures)
        for i in needed_dims:
            code = 'extern "C++" {\n#include "cubicTex%iD.cu"\n}\n' % i + code

    mod = SourceModule(code, options=nvcc_options, include_dirs=[
                       get_pystencils_include_path(), get_pycuda_include_path()])
    func = mod.get_function(kernel_function_node.function_name)

    parameters = kernel_function_node.get_parameters()

    # Call signature -> (positional kernel arguments, launch configuration with 'block' and 'grid').
    cache = {}
    # Keeps the kwargs of every cached call alive so that the id()s stored in cache keys remain unique.
    cache_values = []
    # Cache key of the call whose arrays are currently bound to the module's texture references.
    bound_texture_key = None

    def wrapper(**kwargs):
        nonlocal bound_texture_key

        # numpy arrays are identified by data pointer, strides and shape; all other values by identity.
        # The tuple itself is used as key (not its hash), so a hash collision cannot pick wrong arguments.
        key = tuple((k, v.ctypes.data, v.strides, v.shape) if isinstance(v, np.ndarray) else (k, id(v))
                    for k, v in kwargs.items())

        full_arguments = None
        cached = cache.get(key)
        if cached is None:
            full_arguments = argument_dict.copy()
            full_arguments.update(kwargs)
            shape = _check_arguments(parameters, full_arguments)

            indexing = kernel_function_node.indexing
            block_and_thread_numbers = indexing.call_parameters(shape)
            block_and_thread_numbers['block'] = tuple(int(i) for i in block_and_thread_numbers['block'])
            block_and_thread_numbers['grid'] = tuple(int(i) for i in block_and_thread_numbers['grid'])

            args = _build_numpy_argument_list(parameters, full_arguments)
            cached = (args, block_and_thread_numbers)
            cache[key] = cached
            cache_values.append(kwargs)  # keep objects alive such that ids remain unique

        # Texture references belong to the compiled module, not to a cache entry: after calls with other
        # arrays they may point elsewhere, so rebind whenever the call signature differs from the last one.
        # NOTE(review): if ndarray_to_tex copies the data, changed contents of the *same* arrays are still
        # not picked up between identical calls - confirm whether that is intended.
        # TODO: use texture objects:
        # https://devblogs.nvidia.com/cuda-pro-tip-kepler-texture-objects-improve-performance-and-flexibility/
        if textures and key != bound_texture_key:
            if full_arguments is None:
                full_arguments = argument_dict.copy()
                full_arguments.update(kwargs)
            for tex in textures:
                tex_ref = mod.get_texref(str(tex))
                ndarray_to_tex(tex_ref, full_arguments[tex.field.name], tex.address_mode,
                               tex.filter_mode, tex.use_normalized_coordinates, tex.read_as_integer)
            bound_texture_key = key

        args, block_and_thread_numbers = cached
        func(*args, **block_and_thread_numbers)
        # import pycuda.driver as cuda
        # cuda.Context.synchronize() # useful for debugging, to get errors right after kernel was called

    kernel_wrapper = KernelWrapper(wrapper, parameters, kernel_function_node)
    kernel_wrapper.num_regs = func.num_regs
    return kernel_wrapper


def _build_numpy_argument_list(parameters, argument_dict):
    """
    Builds the positional argument list for a compiled CUDA kernel.

    Args:
        parameters: kernel parameters (as returned by ``get_parameters()`` of the kernel AST)
        argument_dict: maps field names to arrays and scalar parameter names to their values

    Returns:
        list with exactly one entry per parameter, in parameter order: the array itself for field pointers,
        strides (converted from bytes to elements) and shape entries cast to the parameter's dtype, and
        scalar values cast to their declared dtype

    Raises:
        ValueError: if an array's dtype does not match the dtype of its field
        KeyError: if an array or scalar parameter is missing from ``argument_dict``
    """
    result = []
    for param in parameters:
        if param.is_field_pointer:
            array = argument_dict[param.field_name]
            actual_type = array.dtype
            expected_type = param.fields[0].dtype.numpy_dtype
            if expected_type != actual_type:
                raise ValueError("Data type mismatch for field '%s'. Expected '%s' got '%s'." %
                                 (param.field_name, expected_type, actual_type))
            result.append(array)
        elif param.is_field_stride:
            cast_to_dtype = param.symbol.dtype.numpy_dtype.type
            array = argument_dict[param.field_name]
            # numpy strides are given in bytes, the kernel indexes in elements
            stride = cast_to_dtype(array.strides[param.symbol.coordinate] // array.dtype.itemsize)
            result.append(stride)
        elif param.is_field_shape:
            cast_to_dtype = param.symbol.dtype.numpy_dtype.type
            array = argument_dict[param.field_name]
            result.append(cast_to_dtype(array.shape[param.symbol.coordinate]))
        else:
            name = param.symbol.name
            try:
                value = argument_dict[name]
            except KeyError:
                raise KeyError("Missing parameter for kernel call " + str(name)) from None
            expected_type = param.symbol.dtype.numpy_dtype
            result.append(expected_type.type(value))

    return result


def _check_arguments(parameter_specification, argument_dict):
    """
    Checks if parameters passed to kernel match the description in the AST function node.
    If not it raises a ValueError, on success it returns the array shape that determines the CUDA blocks and threads

    Args:
        parameter_specification: kernel parameters (as returned by ``get_parameters()`` of the kernel AST)
        argument_dict: maps field names to the arrays passed for them

    Returns:
        the common spatial shape of the passed index arrays if there are any, otherwise the common
        spatial shape of the passed generic field arrays

    Raises:
        KeyError: if no array was passed for a field parameter
        ValueError: if an array's shape/strides differ from its field's fixed layout, if arrays of the same
                    kind have different spatial shapes, or if no array determines the iteration shape
    """
    array_shapes = set()
    index_arr_shapes = set()

    for param in parameter_specification:
        if not isinstance(param.symbol, FieldPointerSymbol):
            continue
        symbolic_field = param.fields[0]

        try:
            field_arr = argument_dict[symbolic_field.name]
        except KeyError:
            raise KeyError("Missing field parameter for kernel call " + str(symbolic_field))

        if symbolic_field.has_fixed_shape:
            # Struct fields carry an extra trailing dimension that the passed array does not have.
            is_struct = isinstance(symbolic_field.dtype, StructType)

            symbolic_field_shape = tuple(int(i) for i in symbolic_field.shape)
            if is_struct:
                symbolic_field_shape = symbolic_field_shape[:-1]
            if symbolic_field_shape != field_arr.shape:
                raise ValueError("Passed array '%s' has shape %s which does not match expected shape %s" %
                                 (symbolic_field.name, str(field_arr.shape), str(symbolic_field_shape)))

            # Field strides are in elements, array strides in bytes.
            symbolic_field_strides = tuple(int(i) * field_arr.dtype.itemsize for i in symbolic_field.strides)
            if is_struct:
                symbolic_field_strides = symbolic_field_strides[:-1]
            if symbolic_field_strides != field_arr.strides:
                raise ValueError("Passed array '%s' has strides %s which does not match expected strides %s" %
                                 (symbolic_field.name, str(field_arr.strides), str(symbolic_field_strides)))

        if FieldType.is_indexed(symbolic_field):
            index_arr_shapes.add(field_arr.shape[:symbolic_field.spatial_dimensions])
        elif FieldType.is_generic(symbolic_field):
            array_shapes.add(field_arr.shape[:symbolic_field.spatial_dimensions])

    if len(array_shapes) > 1:
        raise ValueError("All passed arrays have to have the same size " + str(array_shapes))
    if len(index_arr_shapes) > 1:
        raise ValueError("All passed index arrays have to have the same size " + str(index_arr_shapes))

    # Index arrays, when present, determine the iteration space.
    if index_arr_shapes:
        return next(iter(index_arr_shapes))
    if array_shapes:
        return next(iter(array_shapes))
    raise ValueError("Could not determine the iteration shape: no index array or generic field array was passed")