Skip to content
Snippets Groups Projects
Commit afe8c892 authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Use KernelWrapper also for in the gpucuda backend

parent a2c6f9f6
No related merge requests found
...@@ -60,6 +60,7 @@ from appdirs import user_cache_dir, user_config_dir ...@@ -60,6 +60,7 @@ from appdirs import user_cache_dir, user_config_dir
from pystencils import FieldType from pystencils import FieldType
from pystencils.backends.cbackend import generate_c, get_headers from pystencils.backends.cbackend import generate_c, get_headers
from pystencils.include import get_pystencils_include_path from pystencils.include import get_pystencils_include_path
from pystencils.kernel_wrapper import KernelWrapper
from pystencils.utils import atomic_file_write, file_handle_for_atomic_write, recursive_dict_update from pystencils.utils import atomic_file_write, file_handle_for_atomic_write, recursive_dict_update
...@@ -482,16 +483,6 @@ class ExtensionModuleCode: ...@@ -482,16 +483,6 @@ class ExtensionModuleCode:
print(create_module_boilerplate_code(self.module_name, self._function_names), file=file) print(create_module_boilerplate_code(self.module_name, self._function_names), file=file)
class KernelWrapper:
def __init__(self, kernel, parameters, ast_node):
self.kernel = kernel
self.parameters = parameters
self.ast = ast_node
def __call__(self, **kwargs):
return self.kernel(**kwargs)
def compile_module(code, code_hash, base_dir): def compile_module(code, code_hash, base_dir):
compiler_config = get_compiler_config() compiler_config = get_compiler_config()
extra_flags = ['-I' + get_paths()['include'], '-I' + get_pystencils_include_path()] extra_flags = ['-I' + get_paths()['include'], '-I' + get_pystencils_include_path()]
......
...@@ -6,6 +6,7 @@ from pystencils.field import FieldType ...@@ -6,6 +6,7 @@ from pystencils.field import FieldType
from pystencils.gpucuda.texture_utils import ndarray_to_tex from pystencils.gpucuda.texture_utils import ndarray_to_tex
from pystencils.include import get_pycuda_include_path, get_pystencils_include_path from pystencils.include import get_pycuda_include_path, get_pystencils_include_path
from pystencils.interpolation_astnodes import TextureAccess from pystencils.interpolation_astnodes import TextureAccess
from pystencils.kernel_wrapper import KernelWrapper
from pystencils.kernelparameters import FieldPointerSymbol from pystencils.kernelparameters import FieldPointerSymbol
USE_FAST_MATH = True USE_FAST_MATH = True
...@@ -93,8 +94,9 @@ def make_python_function(kernel_function_node, argument_dict=None, custom_backen ...@@ -93,8 +94,9 @@ def make_python_function(kernel_function_node, argument_dict=None, custom_backen
func(*args, **block_and_thread_numbers) func(*args, **block_and_thread_numbers)
# import pycuda.driver as cuda # import pycuda.driver as cuda
# cuda.Context.synchronize() # useful for debugging, to get errors right after kernel was called # cuda.Context.synchronize() # useful for debugging, to get errors right after kernel was called
wrapper.ast = kernel_function_node ast = kernel_function_node
wrapper.parameters = kernel_function_node.get_parameters() parameters = kernel_function_node.get_parameters()
wrapper = KernelWrapper(wrapper, parameters, ast)
wrapper.num_regs = func.num_regs wrapper.num_regs = func.num_regs
return wrapper return wrapper
......
"""
Light-weight wrapper around a compiled kernel
"""
import pystencils
class KernelWrapper:
def __init__(self, kernel, parameters, ast_node):
self.kernel = kernel
self.parameters = parameters
self.ast = ast_node
self.num_regs = None
def __call__(self, **kwargs):
return self.kernel(**kwargs)
@property
def code(self):
return str(pystencils.show_code(self.ast))
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment