From afe8c8925f92eaea925cef9c28a006fffa987e37 Mon Sep 17 00:00:00 2001 From: Stephan Seitz <stephan.seitz@fau.de> Date: Thu, 26 Sep 2019 16:46:03 +0200 Subject: [PATCH] Use KernelWrapper also for in the gpucuda backend --- pystencils/cpu/cpujit.py | 11 +---------- pystencils/gpucuda/cudajit.py | 6 ++++-- pystencils/kernel_wrapper.py | 19 +++++++++++++++++++ 3 files changed, 24 insertions(+), 12 deletions(-) create mode 100644 pystencils/kernel_wrapper.py diff --git a/pystencils/cpu/cpujit.py b/pystencils/cpu/cpujit.py index 99ce2d6fb..018100eb2 100644 --- a/pystencils/cpu/cpujit.py +++ b/pystencils/cpu/cpujit.py @@ -60,6 +60,7 @@ from appdirs import user_cache_dir, user_config_dir from pystencils import FieldType from pystencils.backends.cbackend import generate_c, get_headers from pystencils.include import get_pystencils_include_path +from pystencils.kernel_wrapper import KernelWrapper from pystencils.utils import atomic_file_write, file_handle_for_atomic_write, recursive_dict_update @@ -482,16 +483,6 @@ class ExtensionModuleCode: print(create_module_boilerplate_code(self.module_name, self._function_names), file=file) -class KernelWrapper: - def __init__(self, kernel, parameters, ast_node): - self.kernel = kernel - self.parameters = parameters - self.ast = ast_node - - def __call__(self, **kwargs): - return self.kernel(**kwargs) - - def compile_module(code, code_hash, base_dir): compiler_config = get_compiler_config() extra_flags = ['-I' + get_paths()['include'], '-I' + get_pystencils_include_path()] diff --git a/pystencils/gpucuda/cudajit.py b/pystencils/gpucuda/cudajit.py index e38290338..28ec47d0b 100644 --- a/pystencils/gpucuda/cudajit.py +++ b/pystencils/gpucuda/cudajit.py @@ -6,6 +6,7 @@ from pystencils.field import FieldType from pystencils.gpucuda.texture_utils import ndarray_to_tex from pystencils.include import get_pycuda_include_path, get_pystencils_include_path from pystencils.interpolation_astnodes import TextureAccess +from pystencils.kernel_wrapper import KernelWrapper from pystencils.kernelparameters import FieldPointerSymbol USE_FAST_MATH = True @@ -93,8 +94,9 @@ def make_python_function(kernel_function_node, argument_dict=None, custom_backen func(*args, **block_and_thread_numbers) # import pycuda.driver as cuda # cuda.Context.synchronize() # useful for debugging, to get errors right after kernel was called - wrapper.ast = kernel_function_node - wrapper.parameters = kernel_function_node.get_parameters() + ast = kernel_function_node + parameters = kernel_function_node.get_parameters() + wrapper = KernelWrapper(wrapper, parameters, ast) wrapper.num_regs = func.num_regs return wrapper diff --git a/pystencils/kernel_wrapper.py b/pystencils/kernel_wrapper.py new file mode 100644 index 000000000..0e327711e --- /dev/null +++ b/pystencils/kernel_wrapper.py @@ -0,0 +1,19 @@ +""" +Light-weight wrapper around a compiled kernel +""" +import pystencils + + +class KernelWrapper: + def __init__(self, kernel, parameters, ast_node): + self.kernel = kernel + self.parameters = parameters + self.ast = ast_node + self.num_regs = None + + def __call__(self, **kwargs): + return self.kernel(**kwargs) + + @property + def code(self): + return str(pystencils.show_code(self.ast)) -- GitLab