diff --git a/pystencils/cache.py b/pystencils/cache.py index 5df15ae7c7498e6e849c93f2f071435560a2c415..41810a583f479232670300022af40751663e0769 100644 --- a/pystencils/cache.py +++ b/pystencils/cache.py @@ -1,5 +1,5 @@ import os -from collections import Hashable +from collections.abc import Hashable from functools import partial from itertools import chain diff --git a/pystencils/cpu/cpujit.py b/pystencils/cpu/cpujit.py index 99ce2d6fbfbb31f87fdc90aec22fa203386935fe..018100eb24263b2ae951fad1fefdad5520d307ea 100644 --- a/pystencils/cpu/cpujit.py +++ b/pystencils/cpu/cpujit.py @@ -60,6 +60,7 @@ from appdirs import user_cache_dir, user_config_dir from pystencils import FieldType from pystencils.backends.cbackend import generate_c, get_headers from pystencils.include import get_pystencils_include_path +from pystencils.kernel_wrapper import KernelWrapper from pystencils.utils import atomic_file_write, file_handle_for_atomic_write, recursive_dict_update @@ -482,16 +483,6 @@ class ExtensionModuleCode: print(create_module_boilerplate_code(self.module_name, self._function_names), file=file) -class KernelWrapper: - def __init__(self, kernel, parameters, ast_node): - self.kernel = kernel - self.parameters = parameters - self.ast = ast_node - - def __call__(self, **kwargs): - return self.kernel(**kwargs) - - def compile_module(code, code_hash, base_dir): compiler_config = get_compiler_config() extra_flags = ['-I' + get_paths()['include'], '-I' + get_pystencils_include_path()] diff --git a/pystencils/display_utils.py b/pystencils/display_utils.py index 8cdaa4820444cedd1c4cbf8f2db7d7391e3e6344..638d1290acbbfc4d86bec12028dc59b37e2f98ea 100644 --- a/pystencils/display_utils.py +++ b/pystencils/display_utils.py @@ -3,6 +3,7 @@ from typing import Any, Dict, Optional import sympy as sp from pystencils.astnodes import KernelFunction +from pystencils.kernel_wrapper import KernelWrapper def to_dot(expr: sp.Expr, graph_style: Optional[Dict[str, Any]] = None, short=True): @@ -40,6 +41,10 
class KernelWrapper:
    """Callable object pairing a compiled kernel with its parameters and AST.

    Instances expose four plain attributes: ``kernel``, ``parameters``,
    ``ast`` and ``num_regs``.
    """

    def __init__(self, kernel, parameters, ast_node):
        """Store the wrapped callable together with its metadata.

        Args:
            kernel: callable invoked by ``__call__``.
            parameters: stored unchanged as ``self.parameters``.
            ast_node: stored unchanged as ``self.ast``.
        """
        self.kernel, self.parameters, self.ast = kernel, parameters, ast_node
        # Starts out unset; this class never assigns it again.
        self.num_regs = None

    def __call__(self, **kwargs):
        """Pass the keyword arguments on to ``self.kernel`` and return its result."""
        target = self.kernel
        return target(**kwargs)

    @property
    def code(self):
        """Return ``str()`` of ``pystencils.show_code`` applied to the stored AST."""
        rendered = pystencils.show_code(self.ast)
        return str(rendered)