From afe8c8925f92eaea925cef9c28a006fffa987e37 Mon Sep 17 00:00:00 2001
From: Stephan Seitz <stephan.seitz@fau.de>
Date: Thu, 26 Sep 2019 16:46:03 +0200
Subject: [PATCH] Use KernelWrapper also for in the gpucuda backend

---
 pystencils/cpu/cpujit.py      | 11 +----------
 pystencils/gpucuda/cudajit.py |  6 ++++--
 pystencils/kernel_wrapper.py  | 19 +++++++++++++++++++
 3 files changed, 24 insertions(+), 12 deletions(-)
 create mode 100644 pystencils/kernel_wrapper.py

diff --git a/pystencils/cpu/cpujit.py b/pystencils/cpu/cpujit.py
index 99ce2d6fb..018100eb2 100644
--- a/pystencils/cpu/cpujit.py
+++ b/pystencils/cpu/cpujit.py
@@ -60,6 +60,7 @@ from appdirs import user_cache_dir, user_config_dir
 from pystencils import FieldType
 from pystencils.backends.cbackend import generate_c, get_headers
 from pystencils.include import get_pystencils_include_path
+from pystencils.kernel_wrapper import KernelWrapper
 from pystencils.utils import atomic_file_write, file_handle_for_atomic_write, recursive_dict_update
 
 
@@ -482,16 +483,6 @@ class ExtensionModuleCode:
         print(create_module_boilerplate_code(self.module_name, self._function_names), file=file)
 
 
-class KernelWrapper:
-    def __init__(self, kernel, parameters, ast_node):
-        self.kernel = kernel
-        self.parameters = parameters
-        self.ast = ast_node
-
-    def __call__(self, **kwargs):
-        return self.kernel(**kwargs)
-
-
 def compile_module(code, code_hash, base_dir):
     compiler_config = get_compiler_config()
     extra_flags = ['-I' + get_paths()['include'], '-I' + get_pystencils_include_path()]
diff --git a/pystencils/gpucuda/cudajit.py b/pystencils/gpucuda/cudajit.py
index e38290338..28ec47d0b 100644
--- a/pystencils/gpucuda/cudajit.py
+++ b/pystencils/gpucuda/cudajit.py
@@ -6,6 +6,7 @@ from pystencils.field import FieldType
 from pystencils.gpucuda.texture_utils import ndarray_to_tex
 from pystencils.include import get_pycuda_include_path, get_pystencils_include_path
 from pystencils.interpolation_astnodes import TextureAccess
+from pystencils.kernel_wrapper import KernelWrapper
 from pystencils.kernelparameters import FieldPointerSymbol
 
 USE_FAST_MATH = True
@@ -93,8 +94,9 @@ def make_python_function(kernel_function_node, argument_dict=None, custom_backen
             func(*args, **block_and_thread_numbers)
         # import pycuda.driver as cuda
         # cuda.Context.synchronize() # useful for debugging, to get errors right after kernel was called
-    wrapper.ast = kernel_function_node
-    wrapper.parameters = kernel_function_node.get_parameters()
+    ast = kernel_function_node
+    parameters = kernel_function_node.get_parameters()
+    wrapper = KernelWrapper(wrapper, parameters, ast)
     wrapper.num_regs = func.num_regs
     return wrapper
 
diff --git a/pystencils/kernel_wrapper.py b/pystencils/kernel_wrapper.py
new file mode 100644
index 000000000..0e327711e
--- /dev/null
+++ b/pystencils/kernel_wrapper.py
@@ -0,0 +1,19 @@
+"""
+Light-weight wrapper around a compiled kernel
+"""
+import pystencils
+
+
+class KernelWrapper:
+    def __init__(self, kernel, parameters, ast_node):
+        self.kernel = kernel
+        self.parameters = parameters
+        self.ast = ast_node
+        self.num_regs = None
+
+    def __call__(self, **kwargs):
+        return self.kernel(**kwargs)
+
+    @property
+    def code(self):
+        return str(pystencils.show_code(self.ast))
-- 
GitLab