Commit afe8c892 authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Use KernelWrapper also in the gpucuda backend

parent a2c6f9f6
...@@ -60,6 +60,7 @@ from appdirs import user_cache_dir, user_config_dir ...@@ -60,6 +60,7 @@ from appdirs import user_cache_dir, user_config_dir
from pystencils import FieldType from pystencils import FieldType
from pystencils.backends.cbackend import generate_c, get_headers from pystencils.backends.cbackend import generate_c, get_headers
from pystencils.include import get_pystencils_include_path from pystencils.include import get_pystencils_include_path
from pystencils.kernel_wrapper import KernelWrapper
from pystencils.utils import atomic_file_write, file_handle_for_atomic_write, recursive_dict_update from pystencils.utils import atomic_file_write, file_handle_for_atomic_write, recursive_dict_update
...@@ -482,16 +483,6 @@ class ExtensionModuleCode: ...@@ -482,16 +483,6 @@ class ExtensionModuleCode:
print(create_module_boilerplate_code(self.module_name, self._function_names), file=file) print(create_module_boilerplate_code(self.module_name, self._function_names), file=file)
class KernelWrapper:
    """Callable handle that keeps a compiled kernel next to its metadata.

    Attributes:
        kernel: the callable that is invoked when the wrapper is called.
        parameters: the parameter description passed in at construction.
        ast: the AST node passed in at construction.
    """

    def __init__(self, kernel, parameters, ast_node):
        """Store the kernel callable together with its parameters and AST node."""
        self.kernel, self.parameters, self.ast = kernel, parameters, ast_node

    def __call__(self, **kwargs):
        """Call the wrapped kernel with keyword arguments only and return its result."""
        compiled = self.kernel
        return compiled(**kwargs)
def compile_module(code, code_hash, base_dir): def compile_module(code, code_hash, base_dir):
compiler_config = get_compiler_config() compiler_config = get_compiler_config()
extra_flags = ['-I' + get_paths()['include'], '-I' + get_pystencils_include_path()] extra_flags = ['-I' + get_paths()['include'], '-I' + get_pystencils_include_path()]
......
...@@ -6,6 +6,7 @@ from pystencils.field import FieldType ...@@ -6,6 +6,7 @@ from pystencils.field import FieldType
from pystencils.gpucuda.texture_utils import ndarray_to_tex from pystencils.gpucuda.texture_utils import ndarray_to_tex
from pystencils.include import get_pycuda_include_path, get_pystencils_include_path from pystencils.include import get_pycuda_include_path, get_pystencils_include_path
from pystencils.interpolation_astnodes import TextureAccess from pystencils.interpolation_astnodes import TextureAccess
from pystencils.kernel_wrapper import KernelWrapper
from pystencils.kernelparameters import FieldPointerSymbol from pystencils.kernelparameters import FieldPointerSymbol
USE_FAST_MATH = True USE_FAST_MATH = True
...@@ -93,8 +94,9 @@ def make_python_function(kernel_function_node, argument_dict=None, custom_backen ...@@ -93,8 +94,9 @@ def make_python_function(kernel_function_node, argument_dict=None, custom_backen
func(*args, **block_and_thread_numbers) func(*args, **block_and_thread_numbers)
# import pycuda.driver as cuda # import pycuda.driver as cuda
# cuda.Context.synchronize() # useful for debugging, to get errors right after kernel was called # cuda.Context.synchronize() # useful for debugging, to get errors right after kernel was called
wrapper.ast = kernel_function_node ast = kernel_function_node
wrapper.parameters = kernel_function_node.get_parameters() parameters = kernel_function_node.get_parameters()
wrapper = KernelWrapper(wrapper, parameters, ast)
wrapper.num_regs = func.num_regs wrapper.num_regs = func.num_regs
return wrapper return wrapper
......
"""
Light-weight wrapper around a compiled kernel
"""
import pystencils
class KernelWrapper:
    """Callable handle that keeps a compiled kernel next to its metadata.

    Attributes:
        kernel: the callable that is invoked when the wrapper is called.
        parameters: the parameter description passed in at construction.
        ast: the AST node passed in at construction.
        num_regs: ``None`` after construction; may be assigned afterwards.
    """

    def __init__(self, kernel, parameters, ast_node):
        """Store the kernel callable together with its parameters and AST node."""
        self.kernel, self.parameters, self.ast = kernel, parameters, ast_node
        # Nothing is known about register usage at this point.
        self.num_regs = None

    def __call__(self, **kwargs):
        """Call the wrapped kernel with keyword arguments only and return its result."""
        compiled = self.kernel
        return compiled(**kwargs)

    @property
    def code(self):
        """Return ``pystencils.show_code`` applied to the stored AST, converted to ``str``."""
        listing = pystencils.show_code(self.ast)
        return str(listing)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment