diff --git a/pystencils/kernelcreation.py b/pystencils/kernelcreation.py
index e4ade2f1bb66d46c35ee4234e8417ecfa189215c..b5cfc5a4eda444f615474503ecbb79800c8e3969 100644
--- a/pystencils/kernelcreation.py
+++ b/pystencils/kernelcreation.py
@@ -1,4 +1,3 @@
-import itertools
 from types import MappingProxyType
 
 import sympy as sp
@@ -192,120 +191,7 @@ def create_indexed_kernel(assignments,
         raise ValueError("Unknown target %s. Has to be either 'cpu' or 'gpu'" % (target,))
 
 
-def create_staggered_kernel(*args, **kwargs):
-    """Kernel that updates a staggered field. Dispatches to either create_staggered_kernel_1 or
-       create_staggered_kernel_2 depending on the argument types.
-    """
-    if 'staggered_field' in kwargs or type(args[0]) is Field:
-        return create_staggered_kernel_1(*args, **kwargs)
-    else:
-        return create_staggered_kernel_2(*args, **kwargs)
-
-
-def create_staggered_kernel_1(staggered_field, expressions, subexpressions=(), target='cpu',
-                              gpu_exclusive_conditions=False, **kwargs):
-    """Kernel that updates a staggered field.
-
-    .. image:: /img/staggered_grid.svg
-
-    Args:
-        staggered_field: field where the first index coordinate defines the location of the staggered value
-                can have 1 or 2 index coordinates, in case of two index coordinates at every staggered location
-                a vector is stored, expressions parameter has to be a sequence of sequences then
-                where e.g. ``f[0,0](0)`` is interpreted as value at the left cell boundary, ``f[1,0](0)`` the right cell
-                boundary and ``f[0,0](1)`` the southern cell boundary etc.
-        expressions: sequence of expressions of length dim, defining how the west, southern, (bottom) cell boundary
-                     should be updated.
-        subexpressions: optional sequence of Assignments, that define subexpressions used in the main expressions
-        target: 'cpu' or 'gpu'
-        gpu_exclusive_conditions: if/else construct to have only one code block for each of 2**dim code paths
-        kwargs: passed directly to create_kernel, iteration slice and ghost_layers parameters are not allowed
-
-    Returns:
-        AST, see `create_kernel`
-    """
-    assert 'iteration_slice' not in kwargs and 'ghost_layers' not in kwargs
-    assert staggered_field.index_dimensions in (1, 2), 'Staggered field must have one or two index dimensions'
-    dim = staggered_field.spatial_dimensions
-
-    counters = [LoopOverCoordinate.get_loop_counter_symbol(i) for i in range(dim)]
-    conditions = [counters[i] < staggered_field.shape[i] - 1 for i in range(dim)]
-    assert len(expressions) == dim
-    if staggered_field.index_dimensions == 2:
-        assert all(len(sublist) == len(expressions[0]) for sublist in expressions), \
-            "If staggered field has two index dimensions expressions has to be a sequence of sequences of all the " \
-            "same length."
-
-    final_assignments = []
-    last_conditional = None
-
-    def add(condition, dimensions, as_else_block=False):
-        nonlocal last_conditional
-        if staggered_field.index_dimensions == 1:
-            assignments = [Assignment(staggered_field(d), expressions[d]) for d in dimensions]
-            a_coll = AssignmentCollection(assignments, list(subexpressions))
-            a_coll = a_coll.new_filtered([staggered_field(d) for d in dimensions])
-        elif staggered_field.index_dimensions == 2:
-            assert staggered_field.has_fixed_index_shape
-            assignments = [Assignment(staggered_field(d, i), expr)
-                           for d in dimensions
-                           for i, expr in enumerate(expressions[d])]
-            a_coll = AssignmentCollection(assignments, list(subexpressions))
-            a_coll = a_coll.new_filtered([staggered_field(d, i) for i in range(staggered_field.index_shape[1])
-                                          for d in dimensions])
-        sp_assignments = [SympyAssignment(a.lhs, a.rhs) for a in a_coll.all_assignments]
-        if as_else_block and last_conditional:
-            new_cond = Conditional(condition, Block(sp_assignments))
-            last_conditional.false_block = Block([new_cond])
-            last_conditional = new_cond
-        else:
-            last_conditional = Conditional(condition, Block(sp_assignments))
-            final_assignments.append(last_conditional)
-
-    if target == 'cpu' or not gpu_exclusive_conditions:
-        for d in range(dim):
-            cond = sp.And(*[conditions[i] for i in range(dim) if d != i])
-            add(cond, [d])
-    elif target == 'gpu':
-        full_conditions = [sp.And(*[conditions[i] for i in range(dim) if d != i]) for d in range(dim)]
-        for include in itertools.product(*[[1, 0]] * dim):
-            case_conditions = sp.And(*[c if value else sp.Not(c) for c, value in zip(full_conditions, include)])
-            dimensions_to_include = [i for i in range(dim) if include[i]]
-            if dimensions_to_include:
-                add(case_conditions, dimensions_to_include, True)
-
-    ghost_layers = [(1, 0)] * dim
-
-    blocking = kwargs.get('cpu_blocking', None)
-    if blocking:
-        del kwargs['cpu_blocking']
-
-    cpu_vectorize_info = kwargs.get('cpu_vectorize_info', None)
-    if cpu_vectorize_info:
-        del kwargs['cpu_vectorize_info']
-    openmp = kwargs.get('cpu_openmp', None)
-    if openmp:
-        del kwargs['cpu_openmp']
-
-    ast = create_kernel(final_assignments, ghost_layers=ghost_layers, target=target, **kwargs)
-
-    if target == 'cpu':
-        remove_conditionals_in_staggered_kernel(ast)
-        move_constants_before_loop(ast)
-        omp_collapse = None
-        if blocking:
-            omp_collapse = loop_blocking(ast, blocking)
-        if openmp:
-            from pystencils.cpu import add_openmp
-            add_openmp(ast, num_threads=openmp, collapse=omp_collapse, assume_single_outer_loop=False)
-        if cpu_vectorize_info is True:
-            vectorize(ast)
-        elif isinstance(cpu_vectorize_info, dict):
-            vectorize(ast, **cpu_vectorize_info)
-    return ast
-
-
-def create_staggered_kernel_2(assignments, gpu_exclusive_conditions=False, **kwargs):
+def create_staggered_kernel(assignments, gpu_exclusive_conditions=False, **kwargs):
     """Kernel that updates a staggered field.
 
     .. image:: /img/staggered_grid.svg
@@ -330,15 +216,15 @@ def create_staggered_kernel_2(assignments, gpu_exclusive_conditions=False, **kwa
     subexpressions = ()
     if isinstance(assignments, AssignmentCollection):
         subexpressions = assignments.subexpressions + [a for a in assignments.main_assignments
-                                                       if type(a.lhs) is not Field.Access and
-                                                       not FieldType.is_staggered(a.lhs.field)]
-        assignments = [a for a in assignments.main_assignments if type(a.lhs) is Field.Access and
-                       FieldType.is_staggered(a.lhs.field)]
+                                                       if type(a.lhs) is not Field.Access
+                                                       and not FieldType.is_staggered(a.lhs.field)]
+        assignments = [a for a in assignments.main_assignments if type(a.lhs) is Field.Access
+                       and FieldType.is_staggered(a.lhs.field)]
     else:
-        subexpressions = [a for a in assignments if type(a.lhs) is not Field.Access and
-                          not FieldType.is_staggered(a.lhs.field)]
-        assignments = [a for a in assignments if type(a.lhs) is Field.Access and
-                       FieldType.is_staggered(a.lhs.field)]
+        subexpressions = [a for a in assignments if type(a.lhs) is not Field.Access
+                          and not FieldType.is_staggered(a.lhs.field)]
+        assignments = [a for a in assignments if type(a.lhs) is Field.Access 
+                       and FieldType.is_staggered(a.lhs.field)]
     if len(set([tuple(a.lhs.field.staggered_stencil) for a in assignments])) != 1:
         raise ValueError("All assignments need to be made to staggered fields with the same stencil")
     if len(set([a.lhs.field.shape for a in assignments])) != 1:
diff --git a/pystencils_tests/test_blocking_staggered.py b/pystencils_tests/test_blocking_staggered.py
index 76ec8abf0e3b76d30415571bb43e975f317e9a3f..a79efe7c4445faa9baeb8323383b382a42f2cf33 100644
--- a/pystencils_tests/test_blocking_staggered.py
+++ b/pystencils_tests/test_blocking_staggered.py
@@ -11,8 +11,9 @@ def test_blocking_staggered():
        f[0, 0, 0] - f[0, -1, 0],
        f[0, 0, 0] - f[0, 0, -1],
     ]
-    kernel = ps.create_staggered_kernel(stag, terms, cpu_blocking=(3, 16, 8)).compile()
-    reference_kernel = ps.create_staggered_kernel(stag, terms).compile()
+    assignments = [ps.Assignment(stag.staggered_access(d), terms[i]) for i, d in enumerate(stag.staggered_stencil)]
+    kernel = ps.create_staggered_kernel(assignments, cpu_blocking=(3, 16, 8)).compile()
+    reference_kernel = ps.create_staggered_kernel(assignments).compile()
     print(ps.show_code(kernel.ast))
 
     f_arr = np.random.rand(80, 33, 19)
diff --git a/pystencils_tests/test_loop_cutting.py b/pystencils_tests/test_loop_cutting.py
index 999e7b52a8b40111243c09aca1aa3fc1549a0cc2..cd89f37f6f365b4223e1463db68874f50e81c46d 100644
--- a/pystencils_tests/test_loop_cutting.py
+++ b/pystencils_tests/test_loop_cutting.py
@@ -55,7 +55,8 @@ def test_staggered_iteration():
         for d in range(dim):
             expressions.append(sum(f[o] for o in offsets_in_plane(d, 0, dim)) -
                                sum(f[o] for o in offsets_in_plane(d, -1, dim)))
-        func_optimized = create_staggered_kernel(s, expressions).compile()
+        assignments = [ps.Assignment(s.staggered_access(d), expressions[i]) for i, d in enumerate(s.staggered_stencil)]
+        func_optimized = create_staggered_kernel(assignments).compile()
         assert not func_optimized.ast.atoms(Conditional), "Loop cutting optimization did not work"
 
         func(f=f_arr, s=s_arr_ref)
@@ -111,8 +112,10 @@ def test_staggered_gpu():
     s = ps.fields("s({dim}): double[{dim}D]".format(dim=dim), field_type=FieldType.STAGGERED)
     expressions = [(f[0, 0] + f[-1, 0]) / 2,
                    (f[0, 0] + f[0, -1]) / 2]
-    kernel_ast = ps.create_staggered_kernel(s, expressions, target='gpu', gpu_exclusive_conditions=True)
+    assignments = [ps.Assignment(s.staggered_access(d), expressions[i]) for i, d in enumerate(s.staggered_stencil)]
+    kernel_ast = ps.create_staggered_kernel(assignments, target='gpu', gpu_exclusive_conditions=True)
     assert len(kernel_ast.atoms(Conditional)) == 4
 
-    kernel_ast = ps.create_staggered_kernel(s, expressions, target='gpu', gpu_exclusive_conditions=False)
+    assignments = [ps.Assignment(s.staggered_access(d), expressions[i]) for i, d in enumerate(s.staggered_stencil)]
+    kernel_ast = ps.create_staggered_kernel(assignments, target='gpu', gpu_exclusive_conditions=False)
     assert len(kernel_ast.atoms(Conditional)) == 3