Skip to content
Snippets Groups Projects
Commit de4f8b5d authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Append assignments to KernelFunction (for later analysis etc.)

parent a5ab1070
Branches
No related merge requests found
......@@ -172,7 +172,7 @@ class KernelFunction(Node):
def field_name(self):
return self.fields[0].name
def __init__(self, body, target, backend, compile_function, ghost_layers, function_name="kernel"):
def __init__(self, body, target, backend, compile_function, ghost_layers, function_name="kernel", assignments=None):
super(KernelFunction, self).__init__()
self._body = body
body.parent = self
......@@ -186,6 +186,7 @@ class KernelFunction(Node):
self.instruction_set = None # used in `vectorize` function to tell the backend which i.s. (SSE,AVX) to use
# function that compiles the node to a Python callable, is set by the backends
self._compile_function = compile_function
self.assignments = assignments
@property
def target(self):
......
......@@ -66,7 +66,7 @@ def create_kernel(assignments: AssignmentOrAstNodeList, function_name: str = "ke
loop_node, ghost_layer_info = make_loop_over_domain(body, iteration_slice=iteration_slice,
ghost_layers=ghost_layers, loop_order=loop_order)
ast_node = KernelFunction(loop_node, 'cpu', 'c', compile_function=make_python_function,
ghost_layers=ghost_layer_info, function_name=function_name)
ghost_layers=ghost_layer_info, function_name=function_name, assignments=assignments)
implement_interpolations(body)
if split_groups:
......@@ -147,7 +147,7 @@ def create_indexed_kernel(assignments: AssignmentOrAstNodeList, index_fields, fu
function_body = Block([loop_node])
ast_node = KernelFunction(function_body, "cpu", "c", make_python_function,
ghost_layers=None, function_name=function_name)
ghost_layers=None, function_name=function_name, assignments=assignments)
fixed_coordinate_mapping = {f.name: coordinate_typed_symbols for f in non_index_fields}
......
......@@ -62,7 +62,13 @@ def create_cuda_kernel(assignments,
block = indexing.guard(block, common_shape)
unify_shape_symbols(block, common_shape=common_shape, fields=fields_without_buffers)
ast = KernelFunction(block, 'gpu', 'gpucuda', make_python_function, ghost_layers, function_name)
ast = KernelFunction(block,
'gpu',
'gpucuda',
make_python_function,
ghost_layers,
function_name,
assignments=assignments)
ast.global_variables.update(indexing.index_variables)
implement_interpolations(ast, implement_by_texture_accesses=use_textures_for_interpolation)
......@@ -137,7 +143,8 @@ def created_indexed_cuda_kernel(assignments,
function_body = Block(coordinate_symbol_assignments + assignments)
function_body = indexing.guard(function_body, get_common_shape(index_fields))
ast = KernelFunction(function_body, 'gpu', 'gpucuda', make_python_function, None, function_name)
ast = KernelFunction(function_body, 'gpu', 'gpucuda', make_python_function,
None, function_name, assignments=assignments)
ast.global_variables.update(indexing.index_variables)
implement_interpolations(ast, implement_by_texture_accesses=use_textures_for_interpolation)
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment