Commit 0ff425f3 authored by Martin Bauer's avatar Martin Bauer
Browse files

Merge branch 'KernelWrapper' into 'master'

Kernel wrapper

See merge request pycodegen/pystencils!61
parents 9344cb42 aff0c6b8
import os
from collections import Hashable
from collections.abc import Hashable
from functools import partial
from itertools import chain
......
......@@ -60,6 +60,7 @@ from appdirs import user_cache_dir, user_config_dir
from pystencils import FieldType
from pystencils.backends.cbackend import generate_c, get_headers
from pystencils.include import get_pystencils_include_path
from pystencils.kernel_wrapper import KernelWrapper
from pystencils.utils import atomic_file_write, file_handle_for_atomic_write, recursive_dict_update
......@@ -482,16 +483,6 @@ class ExtensionModuleCode:
print(create_module_boilerplate_code(self.module_name, self._function_names), file=file)
class KernelWrapper:
def __init__(self, kernel, parameters, ast_node):
self.kernel = kernel
self.parameters = parameters
self.ast = ast_node
def __call__(self, **kwargs):
return self.kernel(**kwargs)
def compile_module(code, code_hash, base_dir):
compiler_config = get_compiler_config()
extra_flags = ['-I' + get_paths()['include'], '-I' + get_pystencils_include_path()]
......
......@@ -3,6 +3,7 @@ from typing import Any, Dict, Optional
import sympy as sp
from pystencils.astnodes import KernelFunction
from pystencils.kernel_wrapper import KernelWrapper
def to_dot(expr: sp.Expr, graph_style: Optional[Dict[str, Any]] = None, short=True):
......@@ -40,6 +41,10 @@ def show_code(ast: KernelFunction, custom_backend=None):
Can either be displayed as HTML in Jupyter notebooks or printed as normal string.
"""
from pystencils.backends.cbackend import generate_c
if isinstance(ast, KernelWrapper):
ast = ast.ast
dialect = 'cuda' if ast.backend == 'gpucuda' else 'c'
class CodeDisplay:
......
......@@ -6,6 +6,7 @@ from pystencils.field import FieldType
from pystencils.gpucuda.texture_utils import ndarray_to_tex
from pystencils.include import get_pycuda_include_path, get_pystencils_include_path
from pystencils.interpolation_astnodes import TextureAccess
from pystencils.kernel_wrapper import KernelWrapper
from pystencils.kernelparameters import FieldPointerSymbol
USE_FAST_MATH = True
......@@ -93,8 +94,9 @@ def make_python_function(kernel_function_node, argument_dict=None, custom_backen
func(*args, **block_and_thread_numbers)
# import pycuda.driver as cuda
# cuda.Context.synchronize() # useful for debugging, to get errors right after kernel was called
wrapper.ast = kernel_function_node
wrapper.parameters = kernel_function_node.get_parameters()
ast = kernel_function_node
parameters = kernel_function_node.get_parameters()
wrapper = KernelWrapper(wrapper, parameters, ast)
wrapper.num_regs = func.num_regs
return wrapper
......
"""
Light-weight wrapper around a compiled kernel
"""
import pystencils
class KernelWrapper:
def __init__(self, kernel, parameters, ast_node):
self.kernel = kernel
self.parameters = parameters
self.ast = ast_node
self.num_regs = None
def __call__(self, **kwargs):
return self.kernel(**kwargs)
@property
def code(self):
return str(pystencils.show_code(self.ast))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment