diff --git a/pystencils/kernelcreation.py b/pystencils/kernelcreation.py
index 866f1a4ef5357e6a1a74d21b174d24785c398677..198926e7f013f22881e30e7f67203ecf377543c1 100644
--- a/pystencils/kernelcreation.py
+++ b/pystencils/kernelcreation.py
@@ -8,6 +8,7 @@ from pystencils.astnodes import Block, Conditional, LoopOverCoordinate, SympyAss
 from pystencils.cpu.vectorization import vectorize
 from pystencils.gpucuda.indexing import indexing_creator_from_params
 from pystencils.simp.assignment_collection import AssignmentCollection
+from pystencils.stencil import direction_string_to_offset, inverse_direction_string
 from pystencils.transformations import (
     loop_blocking, move_constants_before_loop, remove_conditionals_in_staggered_kernel)
 
@@ -287,3 +288,64 @@ def create_staggered_kernel(staggered_field, expressions, subexpressions=(), tar
         elif isinstance(cpu_vectorize_info, dict):
             vectorize(ast, **cpu_vectorize_info)
     return ast
+
+
+def create_staggered_kernel_2(assignments, **kwargs):
+    """Kernel that updates a staggered field.
+
+    .. image:: /img/staggered_grid.svg
+
+    For a staggered field, the first index coordinate defines the location of the staggered value.
+    Further index coordinates can be used to store vectors/tensors at each point.
+
+    Args:
+        assignments: a sequence of assignments or AssignmentCollection with one item for each staggered grid point.
+                     When storing vectors/tensors, the number of items expected is multiplied with the number of
+                     components.
+        kwargs: passed directly to create_kernel
+    """
+    assert 'ghost_layers' not in kwargs
+
+    subexpressions = ()
+    if isinstance(assignments, AssignmentCollection):
+        assignments = assignments.main_assignments
+        subexpressions = assignments.subexpressions
+    if len(set([a.lhs.field for a in assignments])) != 1:
+        raise ValueError("All assignments need to be made to the same staggered field")
+    staggered_field = assignments[0].lhs.field
+    dim = staggered_field.spatial_dimensions
+    points = staggered_field.index_shape[0]
+    values_per_point = sp.Mul(*staggered_field.index_shape[1:])
+    assert len(assignments) == points * values_per_point
+
+    counters = [LoopOverCoordinate.get_loop_counter_symbol(i) for i in range(dim)]
+
+    final_assignments = []
+
+    def condition(direction):
+        """exclude those staggered points that correspond to fluxes between ghost cells"""
+        exclusions = set(["E", "W", "N", "S"])
+        if dim == 3:
+            exclusions.update("T", "B")
+
+        for elementary_direction in direction:
+            exclusions.remove(inverse_direction_string(elementary_direction))
+        conditions = []
+        for e in exclusions:
+            offset = direction_string_to_offset(e)
+            for i, o in enumerate(offset):
+                if o == 1:
+                    conditions.append(counters[i] < staggered_field.shape[i] - 1)
+                elif o == -1:
+                    conditions.append(counters[i] > 0)
+        return sp.And(*conditions)
+
+    for d, direction in zip(range(points), staggered_field.staggered_stencil):
+        sp_assignments = [SympyAssignment(assignments[d].lhs, assignments[d].rhs)] + \
+                         [SympyAssignment(s.lhs, s.rhs) for s in subexpressions]
+        last_conditional = Conditional(condition(direction), Block(sp_assignments))
+        final_assignments.append(last_conditional)
+
+    ghost_layers = [(1, 0)] * dim
+    ast = create_kernel(final_assignments, ghost_layers=ghost_layers, **kwargs)
+    return ast
diff --git a/pystencils/stencil.py b/pystencils/stencil.py
index 9f70336f3a3b3043fad322bdc67e134d5e11906b..32b1283fd969deba2f45df1ed4526cb000817072 100644
--- a/pystencils/stencil.py
+++ b/pystencils/stencil.py
@@ -16,6 +16,11 @@ def inverse_direction(direction):
     return tuple([-i for i in direction])
 
 
+def inverse_direction_string(direction):
+    """Returns inverse of given direction string"""
+    return offset_to_direction_string(inverse_direction(direction_string_to_offset(direction)))
+
+
 def is_valid(stencil, max_neighborhood=None):
     """
     Tests if a nested sequence is a valid stencil i.e. all the inner sequences have the same length.
diff --git a/pystencils_tests/test_staggered_diffusion.py b/pystencils_tests/test_staggered_diffusion.py
new file mode 100644
index 0000000000000000000000000000000000000000..458e0fd79c9791738310db15a7b5b91b5d0df33a
--- /dev/null
+++ b/pystencils_tests/test_staggered_diffusion.py
@@ -0,0 +1,69 @@
+import pystencils as ps
+import numpy as np
+
+
+class TestDiffusion:
+    def _run(self, num_neighbors):
+        L = (40, 40)
+        D = 0.066
+        dt = 1
+        T = 100
+
+        dh = ps.create_data_handling(L, periodicity=True, default_target='cpu')
+
+        c = dh.add_array('c', values_per_cell=1)
+        j = dh.add_array('j', values_per_cell=num_neighbors, field_type=ps.FieldType.STAGGERED_FLUX)
+
+        x_staggered = - c[-1, 0] + c[0, 0]
+        y_staggered = - c[0, -1] + c[0, 0]
+        xy_staggered = - c[-1, -1] + c[0, 0]
+        xY_staggered = - c[-1, 1] + c[0, 0]
+
+        jj = j.staggered_access
+        divergence = -1 * D / (1 + np.sqrt(2) if j.index_shape[0] == 4 else 1) * \
+            sum([jj(d) for d in j.staggered_stencil +
+                [ps.stencil.inverse_direction_string(d) for d in j.staggered_stencil]])
+
+        update = [ps.Assignment(c.center, c.center + dt * divergence)]
+        flux = [ps.Assignment(j.staggered_access("W"), x_staggered),
+                ps.Assignment(j.staggered_access("S"), y_staggered)]
+        if j.index_shape[0] == 4:
+            flux += [ps.Assignment(j.staggered_access("SW"), xy_staggered),
+                     ps.Assignment(j.staggered_access("NW"), xY_staggered)]
+
+        staggered_kernel = ps.kernelcreation.create_staggered_kernel_2(flux, target=dh.default_target).compile()
+        div_kernel = ps.create_kernel(update, target=dh.default_target).compile()
+
+        def time_loop(steps):
+            sync = dh.synchronization_function([c.name])
+            dh.all_to_gpu()
+            for i in range(steps):
+                sync()
+                dh.run_kernel(staggered_kernel)
+                dh.run_kernel(div_kernel)
+            dh.all_to_cpu()
+
+        def init():
+            dh.fill(c.name, 0)
+            dh.fill(j.name, np.nan)
+            dh.cpu_arrays[c.name][L[0] // 2:L[0] // 2 + 2, L[1] // 2:L[1] // 2 + 2] = 1.0
+
+        init()
+        time_loop(T)
+
+        reference = np.empty(L)
+        for x in range(L[0]):
+            for y in range(L[1]):
+                r = np.array([x, y]) - L[0] / 2 + 0.5
+                reference[x, y] = (4 * np.pi * D * T)**(-dh.dim / 2) * np.exp(-np.dot(r, r) / (4 * D * T)) * (2**dh.dim)
+
+        if num_neighbors == 2:
+            assert np.abs(dh.gather_array(c.name) - reference).max() < 1e-3
+        else:
+            assert np.abs(dh.gather_array(c.name) - reference).max() < 1e-2
+
+    def test_diffusion_2(self):
+        self._run(2)
+
+    def test_diffusion_4(self):
+        self._run(4)