diff --git a/pystencils/astnodes.py b/pystencils/astnodes.py
index 8ec80917af20aefc5883c68ac51f360005b4350b..6d1501529ae55cf3d6aeb73cdb0b459958ae4da5 100644
--- a/pystencils/astnodes.py
+++ b/pystencils/astnodes.py
@@ -172,7 +172,7 @@ class KernelFunction(Node):
         def field_name(self):
             return self.fields[0].name
 
-    def __init__(self, body, target, backend, compile_function, ghost_layers, function_name="kernel"):
+    def __init__(self, body, target, backend, compile_function, ghost_layers, function_name="kernel", assignments=None):
         super(KernelFunction, self).__init__()
         self._body = body
         body.parent = self
@@ -186,6 +186,7 @@ class KernelFunction(Node):
         self.instruction_set = None  # used in `vectorize` function to tell the backend which i.s. (SSE,AVX) to use
         # function that compiles the node to a Python callable, is set by the backends
         self._compile_function = compile_function
+        self.assignments = assignments
 
     @property
     def target(self):
diff --git a/pystencils/cpu/kernelcreation.py b/pystencils/cpu/kernelcreation.py
index 9b119ea9a308726b8225ece66cddcb80ee3a4ef3..38ce169af754fae27d80cadfa993ac83a03e0999 100644
--- a/pystencils/cpu/kernelcreation.py
+++ b/pystencils/cpu/kernelcreation.py
@@ -66,7 +66,7 @@ def create_kernel(assignments: AssignmentOrAstNodeList, function_name: str = "ke
     loop_node, ghost_layer_info = make_loop_over_domain(body, iteration_slice=iteration_slice,
                                                         ghost_layers=ghost_layers, loop_order=loop_order)
     ast_node = KernelFunction(loop_node, 'cpu', 'c', compile_function=make_python_function,
-                              ghost_layers=ghost_layer_info, function_name=function_name)
+                              ghost_layers=ghost_layer_info, function_name=function_name, assignments=assignments)
     implement_interpolations(body)
 
     if split_groups:
@@ -147,7 +147,7 @@ def create_indexed_kernel(assignments: AssignmentOrAstNodeList, index_fields, fu
 
     function_body = Block([loop_node])
     ast_node = KernelFunction(function_body, "cpu", "c", make_python_function,
-                              ghost_layers=None, function_name=function_name)
+                              ghost_layers=None, function_name=function_name, assignments=assignments)
 
     fixed_coordinate_mapping = {f.name: coordinate_typed_symbols for f in non_index_fields}
 
diff --git a/pystencils/gpucuda/kernelcreation.py b/pystencils/gpucuda/kernelcreation.py
index d002cb0f050c0faaa9ea987d49ff7d0385df7cbe..33db3ad56da5d5e26f43403cf354c3814b2805e9 100644
--- a/pystencils/gpucuda/kernelcreation.py
+++ b/pystencils/gpucuda/kernelcreation.py
@@ -62,7 +62,13 @@ def create_cuda_kernel(assignments,
     block = indexing.guard(block, common_shape)
     unify_shape_symbols(block, common_shape=common_shape, fields=fields_without_buffers)
 
-    ast = KernelFunction(block, 'gpu', 'gpucuda', make_python_function, ghost_layers, function_name)
+    ast = KernelFunction(block,
+                         'gpu',
+                         'gpucuda',
+                         make_python_function,
+                         ghost_layers,
+                         function_name,
+                         assignments=assignments)
     ast.global_variables.update(indexing.index_variables)
 
     implement_interpolations(ast, implement_by_texture_accesses=use_textures_for_interpolation)
@@ -137,7 +143,8 @@ def created_indexed_cuda_kernel(assignments,
 
     function_body = Block(coordinate_symbol_assignments + assignments)
     function_body = indexing.guard(function_body, get_common_shape(index_fields))
-    ast = KernelFunction(function_body, 'gpu', 'gpucuda', make_python_function, None, function_name)
+    ast = KernelFunction(function_body, 'gpu', 'gpucuda', make_python_function,
+                         None, function_name, assignments=assignments)
     ast.global_variables.update(indexing.index_variables)
 
     implement_interpolations(ast, implement_by_texture_accesses=use_textures_for_interpolation)