diff --git a/pystencils/astnodes.py b/pystencils/astnodes.py index 8ec80917af20aefc5883c68ac51f360005b4350b..6d1501529ae55cf3d6aeb73cdb0b459958ae4da5 100644 --- a/pystencils/astnodes.py +++ b/pystencils/astnodes.py @@ -172,7 +172,7 @@ class KernelFunction(Node): def field_name(self): return self.fields[0].name - def __init__(self, body, target, backend, compile_function, ghost_layers, function_name="kernel"): + def __init__(self, body, target, backend, compile_function, ghost_layers, function_name="kernel", assignments=None): super(KernelFunction, self).__init__() self._body = body body.parent = self @@ -186,6 +186,7 @@ class KernelFunction(Node): self.instruction_set = None # used in `vectorize` function to tell the backend which i.s. (SSE,AVX) to use # function that compiles the node to a Python callable, is set by the backends self._compile_function = compile_function + self.assignments = assignments @property def target(self): diff --git a/pystencils/cpu/kernelcreation.py b/pystencils/cpu/kernelcreation.py index 9b119ea9a308726b8225ece66cddcb80ee3a4ef3..38ce169af754fae27d80cadfa993ac83a03e0999 100644 --- a/pystencils/cpu/kernelcreation.py +++ b/pystencils/cpu/kernelcreation.py @@ -66,7 +66,7 @@ def create_kernel(assignments: AssignmentOrAstNodeList, function_name: str = "ke loop_node, ghost_layer_info = make_loop_over_domain(body, iteration_slice=iteration_slice, ghost_layers=ghost_layers, loop_order=loop_order) ast_node = KernelFunction(loop_node, 'cpu', 'c', compile_function=make_python_function, - ghost_layers=ghost_layer_info, function_name=function_name) + ghost_layers=ghost_layer_info, function_name=function_name, assignments=assignments) implement_interpolations(body) if split_groups: @@ -147,7 +147,7 @@ def create_indexed_kernel(assignments: AssignmentOrAstNodeList, index_fields, fu function_body = Block([loop_node]) ast_node = KernelFunction(function_body, "cpu", "c", make_python_function, - ghost_layers=None, function_name=function_name) + ghost_layers=None, function_name=function_name, assignments=assignments) fixed_coordinate_mapping = {f.name: coordinate_typed_symbols for f in non_index_fields} diff --git a/pystencils/gpucuda/kernelcreation.py b/pystencils/gpucuda/kernelcreation.py index d002cb0f050c0faaa9ea987d49ff7d0385df7cbe..33db3ad56da5d5e26f43403cf354c3814b2805e9 100644 --- a/pystencils/gpucuda/kernelcreation.py +++ b/pystencils/gpucuda/kernelcreation.py @@ -62,7 +62,13 @@ def create_cuda_kernel(assignments, block = indexing.guard(block, common_shape) unify_shape_symbols(block, common_shape=common_shape, fields=fields_without_buffers) - ast = KernelFunction(block, 'gpu', 'gpucuda', make_python_function, ghost_layers, function_name) + ast = KernelFunction(block, + 'gpu', + 'gpucuda', + make_python_function, + ghost_layers, + function_name, + assignments=assignments) ast.global_variables.update(indexing.index_variables) implement_interpolations(ast, implement_by_texture_accesses=use_textures_for_interpolation) @@ -137,7 +143,8 @@ def created_indexed_cuda_kernel(assignments, function_body = Block(coordinate_symbol_assignments + assignments) function_body = indexing.guard(function_body, get_common_shape(index_fields)) - ast = KernelFunction(function_body, 'gpu', 'gpucuda', make_python_function, None, function_name) + ast = KernelFunction(function_body, 'gpu', 'gpucuda', make_python_function, + None, function_name, assignments=assignments) ast.global_variables.update(indexing.index_variables) implement_interpolations(ast, implement_by_texture_accesses=use_textures_for_interpolation)