import numpy as np

from pystencils.backends.cbackend import generate_c, get_headers
from pystencils.kernelparameters import FieldPointerSymbol
from pystencils.data_types import StructType
from pystencils.field import FieldType
from pystencils.include import get_pystencils_include_path

USE_FAST_MATH = True


def make_python_function(kernel_function_node, argument_dict=None):
    """
    Creates a kernel function from an abstract syntax tree which was created e.g. by
    :func:`pystencils.gpucuda.create_cuda_kernel` or :func:`pystencils.gpucuda.created_indexed_cuda_kernel`

    Args:
        kernel_function_node: the abstract syntax tree
        argument_dict: parameters passed here are already fixed. Remaining parameters have to be passed to the
                       returned kernel functor.

    Returns:
        compiled kernel as Python function
    """
    import pycuda.autoinit  # NOQA
    from pycuda.compiler import SourceModule

    if argument_dict is None:
        argument_dict = {}

    header_list = ['<stdint.h>'] + list(get_headers(kernel_function_node))
    includes = "\n".join(["#include %s" % (include_file,) for include_file in header_list])

    code = includes + "\n"
    code += "#define FUNC_PREFIX __global__\n"
    code += "#define RESTRICT __restrict__\n\n"
    code += str(generate_c(kernel_function_node, dialect='cuda'))

    options = ["-w", "-std=c++11", "-Wno-deprecated-gpu-targets"]
    if USE_FAST_MATH:
        options.append("-use_fast_math")
    mod = SourceModule(code, options=options, include_dirs=[get_pystencils_include_path()])
    func = mod.get_function(kernel_function_node.function_name)

    parameters = kernel_function_node.get_parameters()

    cache = {}
    cache_values = []

    def wrapper(**kwargs):
        key = hash(tuple((k, v.ctypes.data, v.strides, v.shape) if isinstance(v, np.ndarray) else (k, id(v))
                         for k, v in kwargs.items()))
        try:
            args, block_and_thread_numbers = cache[key]
            func(*args, **block_and_thread_numbers)
        except KeyError:
            full_arguments = argument_dict.copy()
            full_arguments.update(kwargs)
            shape = _check_arguments(parameters, full_arguments)

            indexing = kernel_function_node.indexing
            block_and_thread_numbers = indexing.call_parameters(shape)
            block_and_thread_numbers['block'] = tuple(int(i) for i in block_and_thread_numbers['block'])
            block_and_thread_numbers['grid'] = tuple(int(i) for i in block_and_thread_numbers['grid'])

            args = _build_numpy_argument_list(parameters, full_arguments)
            cache[key] = (args, block_and_thread_numbers)
            cache_values.append(kwargs)  # keep objects alive such that ids remain unique
            func(*args, **block_and_thread_numbers)
        # import pycuda.driver as cuda
        # cuda.Context.synchronize()  # useful for debugging, to get errors right after kernel was called

    wrapper.ast = kernel_function_node
    wrapper.parameters = kernel_function_node.get_parameters()
    wrapper.num_regs = func.num_regs
    return wrapper
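
# Usage sketch (illustrative only, not part of this module): the field names `src`/`dst`, the
# Jacobi-style update and the array sizes below are assumptions. It relies on the pystencils
# high-level API (`pystencils.fields`, `pystencils.Assignment`,
# `pystencils.gpucuda.create_cuda_kernel`) and on pycuda GPU arrays:
#
#     import numpy as np
#     import pycuda.autoinit  # noqa: F401
#     import pycuda.gpuarray as gpuarray
#     import pystencils as ps
#     from pystencils.gpucuda import create_cuda_kernel
#
#     src, dst = ps.fields("src, dst: float64[2D]")
#     update = ps.Assignment(dst.center,
#                            (src[1, 0] + src[-1, 0] + src[0, 1] + src[0, -1]) / 4)
#     kernel = make_python_function(create_cuda_kernel([update]))
#
#     src_gpu = gpuarray.to_gpu(np.random.rand(64, 64))
#     dst_gpu = gpuarray.zeros((64, 64), np.float64)
#     kernel(src=src_gpu, dst=dst_gpu)  # fields are passed as keyword arguments named after the fields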


def _build_numpy_argument_list(parameters, argument_dict):
    argument_dict = {k: v for k, v in argument_dict.items()}
    result = []

    for param in parameters:
        if param.is_field_pointer:
            array = argument_dict[param.field_name]
            actual_type = array.dtype
            expected_type = param.fields[0].dtype.numpy_dtype
            if expected_type != actual_type:
                raise ValueError("Data type mismatch for field '%s'. Expected '%s' got '%s'."
                                 % (param.field_name, expected_type, actual_type))
            result.append(array)
        elif param.is_field_stride:
            cast_to_dtype = param.symbol.dtype.numpy_dtype.type
            array = argument_dict[param.field_name]
            stride = cast_to_dtype(array.strides[param.symbol.coordinate] // array.dtype.itemsize)
            result.append(stride)
        elif param.is_field_shape:
            cast_to_dtype = param.symbol.dtype.numpy_dtype.type
            array = argument_dict[param.field_name]
            result.append(cast_to_dtype(array.shape[param.symbol.coordinate]))
        else:
            expected_type = param.symbol.dtype.numpy_dtype
            result.append(expected_type.type(argument_dict[param.symbol.name]))

    assert len(result) == len(parameters)
    return result


def _check_arguments(parameter_specification, argument_dict):
    """
    Checks if the parameters passed to the kernel match the description in the AST function node.
    If not, a ValueError is raised; on success the array shape that determines the CUDA blocks and
    threads is returned.
    """
    argument_dict = {k: v for k, v in argument_dict.items()}
    array_shapes = set()
    index_arr_shapes = set()

    for param in parameter_specification:
        if isinstance(param.symbol, FieldPointerSymbol):
            symbolic_field = param.fields[0]

            try:
                field_arr = argument_dict[symbolic_field.name]
            except KeyError:
                raise KeyError("Missing field parameter for kernel call " + str(symbolic_field))

            if symbolic_field.has_fixed_shape:
                symbolic_field_shape = tuple(int(i) for i in symbolic_field.shape)
                if isinstance(symbolic_field.dtype, StructType):
                    symbolic_field_shape = symbolic_field_shape[:-1]
                if symbolic_field_shape != field_arr.shape:
                    raise ValueError("Passed array '%s' has shape %s which does not match expected shape %s"
                                     % (symbolic_field.name, str(field_arr.shape), str(symbolic_field_shape)))

                symbolic_field_strides = tuple(int(i) * field_arr.dtype.itemsize for i in symbolic_field.strides)
                if isinstance(symbolic_field.dtype, StructType):
                    symbolic_field_strides = symbolic_field_strides[:-1]
                if symbolic_field_strides != field_arr.strides:
                    raise ValueError("Passed array '%s' has strides %s which does not match expected strides %s"
                                     % (symbolic_field.name, str(field_arr.strides), str(symbolic_field_strides)))

            if FieldType.is_indexed(symbolic_field):
                index_arr_shapes.add(field_arr.shape[:symbolic_field.spatial_dimensions])
            elif FieldType.is_generic(symbolic_field):
                array_shapes.add(field_arr.shape[:symbolic_field.spatial_dimensions])

    if len(array_shapes) > 1:
        raise ValueError("All passed arrays have to have the same size " + str(array_shapes))
    if len(index_arr_shapes) > 1:
        raise ValueError("All passed index arrays have to have the same size " + str(index_arr_shapes))

    if len(index_arr_shapes) > 0:
        return list(index_arr_shapes)[0]
    else:
        return list(array_shapes)[0]
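
# Note on the stride conversion in `_build_numpy_argument_list` above (a small numpy illustration,
# not part of the module API): numpy reports strides in bytes, while the generated kernels index in
# elements, hence the division by `array.dtype.itemsize`.
#
#     a = np.zeros((4, 5), dtype=np.float64)
#     a.strides                                        # (40, 8)  byte strides
#     tuple(s // a.dtype.itemsize for s in a.strides)  # (5, 1)   element strides passed to the kernel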