diff --git a/pystencils/cache.py b/pystencils/cache.py
index 5df15ae7c7498e6e849c93f2f071435560a2c415..41810a583f479232670300022af40751663e0769 100644
--- a/pystencils/cache.py
+++ b/pystencils/cache.py
@@ -1,5 +1,5 @@
 import os
-from collections import Hashable
+from collections.abc import Hashable
 from functools import partial
 from itertools import chain
 
diff --git a/pystencils/cpu/cpujit.py b/pystencils/cpu/cpujit.py
index 99ce2d6fbfbb31f87fdc90aec22fa203386935fe..018100eb24263b2ae951fad1fefdad5520d307ea 100644
--- a/pystencils/cpu/cpujit.py
+++ b/pystencils/cpu/cpujit.py
@@ -60,6 +60,7 @@ from appdirs import user_cache_dir, user_config_dir
 from pystencils import FieldType
 from pystencils.backends.cbackend import generate_c, get_headers
 from pystencils.include import get_pystencils_include_path
+from pystencils.kernel_wrapper import KernelWrapper
 from pystencils.utils import atomic_file_write, file_handle_for_atomic_write, recursive_dict_update
 
 
@@ -482,16 +483,6 @@ class ExtensionModuleCode:
         print(create_module_boilerplate_code(self.module_name, self._function_names), file=file)
 
 
-class KernelWrapper:
-    def __init__(self, kernel, parameters, ast_node):
-        self.kernel = kernel
-        self.parameters = parameters
-        self.ast = ast_node
-
-    def __call__(self, **kwargs):
-        return self.kernel(**kwargs)
-
-
 def compile_module(code, code_hash, base_dir):
     compiler_config = get_compiler_config()
     extra_flags = ['-I' + get_paths()['include'], '-I' + get_pystencils_include_path()]
diff --git a/pystencils/display_utils.py b/pystencils/display_utils.py
index 8cdaa4820444cedd1c4cbf8f2db7d7391e3e6344..638d1290acbbfc4d86bec12028dc59b37e2f98ea 100644
--- a/pystencils/display_utils.py
+++ b/pystencils/display_utils.py
@@ -3,6 +3,7 @@ from typing import Any, Dict, Optional
 import sympy as sp
 
 from pystencils.astnodes import KernelFunction
+from pystencils.kernel_wrapper import KernelWrapper
 
 
 def to_dot(expr: sp.Expr, graph_style: Optional[Dict[str, Any]] = None, short=True):
@@ -40,6 +41,10 @@ def show_code(ast: KernelFunction, custom_backend=None):
     Can either  be displayed as HTML in Jupyter notebooks or printed as normal string.
     """
     from pystencils.backends.cbackend import generate_c
+
+    if isinstance(ast, KernelWrapper):
+        ast = ast.ast
+
     dialect = 'cuda' if ast.backend == 'gpucuda' else 'c'
 
     class CodeDisplay:
diff --git a/pystencils/gpucuda/cudajit.py b/pystencils/gpucuda/cudajit.py
index e38290338236c16231070e13d11a658a5223bc71..28ec47d0ba22e92cd4064b3400a0d055a87be7cb 100644
--- a/pystencils/gpucuda/cudajit.py
+++ b/pystencils/gpucuda/cudajit.py
@@ -6,6 +6,7 @@ from pystencils.field import FieldType
 from pystencils.gpucuda.texture_utils import ndarray_to_tex
 from pystencils.include import get_pycuda_include_path, get_pystencils_include_path
 from pystencils.interpolation_astnodes import TextureAccess
+from pystencils.kernel_wrapper import KernelWrapper
 from pystencils.kernelparameters import FieldPointerSymbol
 
 USE_FAST_MATH = True
@@ -93,8 +94,9 @@ def make_python_function(kernel_function_node, argument_dict=None, custom_backen
             func(*args, **block_and_thread_numbers)
         # import pycuda.driver as cuda
         # cuda.Context.synchronize() # useful for debugging, to get errors right after kernel was called
-    wrapper.ast = kernel_function_node
-    wrapper.parameters = kernel_function_node.get_parameters()
+    ast = kernel_function_node
+    parameters = kernel_function_node.get_parameters()
+    wrapper = KernelWrapper(wrapper, parameters, ast)
     wrapper.num_regs = func.num_regs
     return wrapper
 
diff --git a/pystencils/kernel_wrapper.py b/pystencils/kernel_wrapper.py
new file mode 100644
index 0000000000000000000000000000000000000000..0e327711e5a355219cc2664ac9a6c8a02d88bc09
--- /dev/null
+++ b/pystencils/kernel_wrapper.py
@@ -0,0 +1,19 @@
+"""
+Light-weight wrapper around a compiled kernel
+"""
+import pystencils
+
+
+class KernelWrapper:
+    def __init__(self, kernel, parameters, ast_node):
+        self.kernel = kernel
+        self.parameters = parameters
+        self.ast = ast_node
+        self.num_regs = None
+
+    def __call__(self, **kwargs):
+        return self.kernel(**kwargs)
+
+    @property
+    def code(self):
+        return str(pystencils.show_code(self.ast))