opencljit.py 3.31 KB
Newer Older
Stephan Seitz's avatar
Stephan Seitz committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import numpy as np

from pystencils.backends.cbackend import generate_c, get_headers
from pystencils.gpucuda.cudajit import _build_numpy_argument_list, _check_arguments
from pystencils.include import get_pystencils_include_path

# Module-level switch read by make_python_function at build time: when true,
# the OpenCL program is built with the relaxed floating-point options
# (-cl-unsafe-math-optimizations, -cl-mad-enable, -cl-fast-relaxed-math,
# -cl-finite-math-only).
USE_FAST_MATH = True


def make_python_function(kernel_function_node, opencl_queue, opencl_ctx, argument_dict=None, custom_backend=None):
    """
    Creates a kernel function from an abstract syntax tree which
    was created e.g. by :func:`pystencils.gpucuda.create_cuda_kernel`
    or :func:`pystencils.gpucuda.created_indexed_cuda_kernel`

    Note:
        ``kernel_function_node.function_name`` is modified in place: it is prefixed with ``opencl_``
        (only once, even if the same node is compiled repeatedly).

    Args:
        kernel_function_node: the abstract syntax tree
        opencl_queue: command queue that is passed to the compiled kernel on every call
        opencl_ctx: OpenCL context the program is built for
        argument_dict: parameters passed here are already fixed. Remaining parameters have to be passed to the
                       returned kernel functor.
        custom_backend: forwarded unchanged to :func:`generate_c`

    Returns:
        compiled kernel as Python function. It accepts keyword arguments only and additionally exposes the
        attributes ``ast`` (the kernel node) and ``parameters`` (its parameter list).
    """
    import pyopencl as cl

    if argument_dict is None:
        argument_dict = {}

    # Prefix only once: the node is mutated in place, so compiling the same AST again must not
    # keep growing its name ("opencl_opencl_...").
    if not kernel_function_node.function_name.startswith("opencl_"):
        kernel_function_node.function_name = "opencl_" + kernel_function_node.function_name
    header_list = ['"opencl_stdint.h"'] + list(get_headers(kernel_function_node))
    includes = "\n".join(["#include %s" % (include_file,) for include_file in header_list])

    code = includes + "\n"
    code += "#define FUNC_PREFIX __kernel\n"
    code += "#define RESTRICT restrict\n\n"
    code += str(generate_c(kernel_function_node, dialect='opencl', custom_backend=custom_backend))
    options = []
    if USE_FAST_MATH:
        options.append("-cl-unsafe-math-optimizations -cl-mad-enable -cl-fast-relaxed-math -cl-finite-math-only")
    options.append("-I \"" + get_pystencils_include_path() + "\"")
    mod = cl.Program(opencl_ctx, code).build(options=options)
    func = getattr(mod, kernel_function_node.function_name)

    parameters = kernel_function_node.get_parameters()

    # Maps an argument signature to (kernel args, global size, local size).
    cache = {}
    cache_values = []

    def wrapper(**kwargs):
        # Use the tuple itself as key (not its hash): two different argument sets with colliding
        # hashes must never share a cache entry.
        key = tuple((k, v.ctypes.data, v.strides, v.shape) if isinstance(v, np.ndarray) else (k, id(v))
                    for k, v in kwargs.items())

        cached = cache.get(key)
        if cached is None:
            full_arguments = argument_dict.copy()
            full_arguments.update(kwargs)
            shape = _check_arguments(parameters, full_arguments)

            indexing = kernel_function_node.indexing
            block_and_thread_numbers = indexing.call_parameters(shape)
            local_size = tuple(int(i) for i in block_and_thread_numbers['block'])
            # The launch below receives the total number of work items per dimension,
            # i.e. block size times number of blocks.
            global_size = tuple(int(b * g) for (b, g) in zip(local_size, block_and_thread_numbers['grid']))

            args = _build_numpy_argument_list(parameters, full_arguments)
            args = [a.data for a in args if hasattr(a, 'data')]
            cached = (args, global_size, local_size)
            cache[key] = cached
            cache_values.append(kwargs)  # keep objects alive such that ids remain unique

        args, global_size, local_size = cached
        # Launch outside of any try/except: an exception raised by the kernel call (including a
        # KeyError) must propagate instead of being mistaken for a cache miss.
        func(opencl_queue, global_size, local_size, *args)

    wrapper.ast = kernel_function_node
    wrapper.parameters = kernel_function_node.get_parameters()
    return wrapper