Newer
Older
from pystencils.backends.cbackend import generate_c, get_headers
from pystencils.kernelparameters import FieldPointerSymbol
from pystencils.data_types import StructType
from pystencils.field import FieldType
from pystencils.include import get_pystencils_include_path
def make_python_function(kernel_function_node, argument_dict=None):
    """
    Creates a kernel function from an abstract syntax tree which
    was created e.g. by :func:`pystencils.gpucuda.create_cuda_kernel`
    or :func:`pystencils.gpucuda.created_indexed_cuda_kernel`

    Args:
        kernel_function_node: the abstract syntax tree
        argument_dict: parameters passed here are already fixed. Remaining parameters have to be passed to the
                       returned kernel functor. ``None`` means that no parameters are fixed.

    Returns:
        compiled kernel as Python function. It takes the remaining parameters as keyword arguments and
        carries the kernel's parameter list in its ``parameters`` attribute.
    """
    # numpy is only needed for the ndarray check in the cache key below
    import numpy as np
    from pycuda.compiler import SourceModule

    # The default must behave like an empty dict; calling .copy() on None would raise.
    if argument_dict is None:
        argument_dict = {}

    header_list = ['<stdint.h>'] + list(get_headers(kernel_function_node))
    includes = "\n".join("#include %s" % (include_file,) for include_file in header_list)

    code = includes + "\n"
    code += "#define FUNC_PREFIX __global__\n"
    code += "#define RESTRICT __restrict__\n\n"
    code += str(generate_c(kernel_function_node, dialect='cuda'))

    # "-use_fast_math" is added only when USE_FAST_MATH is set, so the switch
    # actually takes effect and the flag is never passed twice.
    options = ["-w", "-std=c++11", "-Wno-deprecated-gpu-targets"]
    if USE_FAST_MATH:
        options.append("-use_fast_math")
    mod = SourceModule(code, options=options, include_dirs=[get_pystencils_include_path()])
    func = mod.get_function(kernel_function_node.function_name)

    parameters = kernel_function_node.get_parameters()

    cache = {}
    cache_values = []

    def wrapper(**kwargs):
        # Arrays are identified by data pointer and layout, everything else by identity.
        # The tuple itself is the key, so two different argument sets can never
        # share a cache entry through a hash collision.
        key = tuple((k, v.ctypes.data, v.strides, v.shape) if isinstance(v, np.ndarray) else (k, id(v))
                    for k, v in kwargs.items())
        try:
            # Only the lookup is guarded: a KeyError raised by the kernel launch
            # itself must not be mistaken for a cache miss.
            args, block_and_thread_numbers = cache[key]
        except KeyError:
            full_arguments = argument_dict.copy()
            full_arguments.update(kwargs)
            shape = _check_arguments(parameters, full_arguments)

            indexing = kernel_function_node.indexing
            block_and_thread_numbers = indexing.call_parameters(shape)
            block_and_thread_numbers['block'] = tuple(int(i) for i in block_and_thread_numbers['block'])
            block_and_thread_numbers['grid'] = tuple(int(i) for i in block_and_thread_numbers['grid'])

            args = _build_numpy_argument_list(parameters, full_arguments)
            cache[key] = (args, block_and_thread_numbers)
            cache_values.append(kwargs)  # keep objects alive such that ids remain unique

        func(*args, **block_and_thread_numbers)
        # For debugging, calling pycuda.driver.Context.synchronize() here makes
        # errors show up right after the kernel was called.

    wrapper.parameters = kernel_function_node.get_parameters()
    return wrapper
def _build_numpy_argument_list(parameters, argument_dict):
    """
    Builds the positional argument list for a kernel call.

    Args:
        parameters: kernel parameter descriptions, in the order in which the kernel expects its arguments
        argument_dict: mapping from field name / symbol name to the array or scalar value to pass

    Returns:
        list with one entry per parameter: the array itself for field pointers, the stride (in elements)
        or shape entry cast to the parameter's dtype for stride/shape parameters, and the value cast to
        the parameter's dtype for all remaining parameters

    Raises:
        ValueError: if the dtype of a passed array differs from the dtype of the corresponding field
    """
    argument_dict = dict(argument_dict)
    result = []

    for param in parameters:
        if param.is_field_pointer:
            array = argument_dict[param.field_name]
            actual_type = array.dtype
            expected_type = param.fields[0].dtype.numpy_dtype
            if expected_type != actual_type:
                raise ValueError("Data type mismatch for field '%s'. Expected '%s' got '%s'." %
                                 (param.field_name, expected_type, actual_type))
            result.append(array)
        elif param.is_field_stride:
            cast_to_dtype = param.symbol.dtype.numpy_dtype.type
            array = argument_dict[param.field_name]
            # array.strides is in bytes, the value appended here is in elements
            stride = cast_to_dtype(array.strides[param.symbol.coordinate] // array.dtype.itemsize)
            result.append(stride)
        elif param.is_field_shape:
            cast_to_dtype = param.symbol.dtype.numpy_dtype.type
            array = argument_dict[param.field_name]
            result.append(cast_to_dtype(array.shape[param.symbol.coordinate]))
        else:
            # plain scalar parameter: cast the passed value to the dtype of the symbol
            expected_type = param.symbol.dtype.numpy_dtype
            result.append(expected_type.type(argument_dict[param.symbol.name]))

    # every parameter must contribute exactly one argument
    assert len(result) == len(parameters)
    return result
def _check_arguments(parameter_specification, argument_dict):
    """
    Checks if parameters passed to kernel match the description in the AST function node.
    If not it raises a ValueError, on success it returns the array shape that determines the CUDA blocks and threads

    Args:
        parameter_specification: kernel parameter descriptions; only those whose symbol is a
                                 FieldPointerSymbol are checked
        argument_dict: mapping from field name to the array passed for that field

    Returns:
        the common spatial shape of the passed index arrays if there are any, otherwise the common
        spatial shape of the passed generic arrays

    Raises:
        KeyError: if no array was passed for one of the fields
        ValueError: if shape or strides of an array do not match a field with fixed shape, if the passed
                    arrays have different sizes, or if no passed array determines a shape
    """
    array_shapes = set()
    index_arr_shapes = set()

    for param in parameter_specification:
        if not isinstance(param.symbol, FieldPointerSymbol):
            continue

        symbolic_field = param.fields[0]
        try:
            field_arr = argument_dict[symbolic_field.name]
        except KeyError:
            raise KeyError("Missing field parameter for kernel call " + str(symbolic_field)) from None

        if symbolic_field.has_fixed_shape:
            is_struct = isinstance(symbolic_field.dtype, StructType)

            symbolic_field_shape = tuple(int(i) for i in symbolic_field.shape)
            if is_struct:
                # for struct dtypes the last symbolic entry is not compared against the array
                symbolic_field_shape = symbolic_field_shape[:-1]
            if symbolic_field_shape != field_arr.shape:
                raise ValueError("Passed array '%s' has shape %s which does not match expected shape %s" %
                                 (symbolic_field.name, str(field_arr.shape), str(symbolic_field.shape)))

            # symbolic strides are multiplied by the item size to compare them with the array's strides
            symbolic_field_strides = tuple(int(i) * field_arr.dtype.itemsize for i in symbolic_field.strides)
            if is_struct:
                symbolic_field_strides = symbolic_field_strides[:-1]
            if symbolic_field_strides != field_arr.strides:
                raise ValueError("Passed array '%s' has strides %s which does not match expected strides %s" %
                                 (symbolic_field.name, str(field_arr.strides), str(symbolic_field_strides)))

        if FieldType.is_indexed(symbolic_field):
            index_arr_shapes.add(field_arr.shape[:symbolic_field.spatial_dimensions])
        elif FieldType.is_generic(symbolic_field):
            array_shapes.add(field_arr.shape[:symbolic_field.spatial_dimensions])

    if len(array_shapes) > 1:
        raise ValueError("All passed arrays have to have the same size " + str(array_shapes))

    if len(index_arr_shapes) > 1:
        # report the index array shapes here, not the shapes of the regular arrays
        raise ValueError("All passed index arrays have to have the same size " + str(index_arr_shapes))

    # index arrays take precedence over regular arrays when both are present
    if index_arr_shapes:
        return next(iter(index_arr_shapes))
    if array_shapes:
        return next(iter(array_shapes))
    raise ValueError("No array was passed that determines the shape of the kernel call")