diff --git a/pystencils/cpu/cpujit.py b/pystencils/cpu/cpujit.py index 99ce2d6fbfbb31f87fdc90aec22fa203386935fe..018100eb24263b2ae951fad1fefdad5520d307ea 100644 --- a/pystencils/cpu/cpujit.py +++ b/pystencils/cpu/cpujit.py @@ -60,6 +60,7 @@ from appdirs import user_cache_dir, user_config_dir from pystencils import FieldType from pystencils.backends.cbackend import generate_c, get_headers from pystencils.include import get_pystencils_include_path +from pystencils.kernel_wrapper import KernelWrapper from pystencils.utils import atomic_file_write, file_handle_for_atomic_write, recursive_dict_update @@ -482,16 +483,6 @@ class ExtensionModuleCode: print(create_module_boilerplate_code(self.module_name, self._function_names), file=file) -class KernelWrapper: - def __init__(self, kernel, parameters, ast_node): - self.kernel = kernel - self.parameters = parameters - self.ast = ast_node - - def __call__(self, **kwargs): - return self.kernel(**kwargs) - - def compile_module(code, code_hash, base_dir): compiler_config = get_compiler_config() extra_flags = ['-I' + get_paths()['include'], '-I' + get_pystencils_include_path()] diff --git a/pystencils/gpucuda/cudajit.py b/pystencils/gpucuda/cudajit.py index e38290338236c16231070e13d11a658a5223bc71..28ec47d0ba22e92cd4064b3400a0d055a87be7cb 100644 --- a/pystencils/gpucuda/cudajit.py +++ b/pystencils/gpucuda/cudajit.py @@ -6,6 +6,7 @@ from pystencils.field import FieldType from pystencils.gpucuda.texture_utils import ndarray_to_tex from pystencils.include import get_pycuda_include_path, get_pystencils_include_path from pystencils.interpolation_astnodes import TextureAccess +from pystencils.kernel_wrapper import KernelWrapper from pystencils.kernelparameters import FieldPointerSymbol USE_FAST_MATH = True @@ -93,8 +94,9 @@ def make_python_function(kernel_function_node, argument_dict=None, custom_backen func(*args, **block_and_thread_numbers) # import pycuda.driver as cuda # cuda.Context.synchronize() # useful for debugging, to get errors right after kernel was called - wrapper.ast = kernel_function_node - wrapper.parameters = kernel_function_node.get_parameters() + ast = kernel_function_node + parameters = kernel_function_node.get_parameters() + wrapper = KernelWrapper(wrapper, parameters, ast) wrapper.num_regs = func.num_regs return wrapper diff --git a/pystencils/kernel_wrapper.py b/pystencils/kernel_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..0e327711e5a355219cc2664ac9a6c8a02d88bc09 --- /dev/null +++ b/pystencils/kernel_wrapper.py @@ -0,0 +1,19 @@ +""" +Light-weight wrapper around a compiled kernel +""" +import pystencils + + +class KernelWrapper: + def __init__(self, kernel, parameters, ast_node): + self.kernel = kernel + self.parameters = parameters + self.ast = ast_node + self.num_regs = None + + def __call__(self, **kwargs): + return self.kernel(**kwargs) + + @property + def code(self): + return str(pystencils.show_code(self.ast))