diff --git a/pystencils/fd/__init__.py b/pystencils/fd/__init__.py
index f47d76a07b195a836e6dcf35bd3d13629308f72c..62b11e4e79f6abe60d0352fb7e3888b02774a0f9 100644
--- a/pystencils/fd/__init__.py
+++ b/pystencils/fd/__init__.py
@@ -3,9 +3,10 @@ from .derivative import Diff, DiffOperator, \
     expand_diff_full, expand_diff_linear, expand_diff_products, combine_diff_products, \
     functional_derivative, diff
 from .finitedifferences import advection, diffusion, transient, Discretization2ndOrder
-from .spatial import discretize_spatial
+from .spatial import discretize_spatial, discretize_spatial_staggered
 
 __all__ = ['Diff', 'diff', 'DiffOperator', 'diff_terms', 'collect_diffs',
            'zero_diffs', 'evaluate_diffs', 'normalize_diff_order', 'expand_diff_full', 'expand_diff_linear',
            'expand_diff_products', 'combine_diff_products', 'functional_derivative',
-           'advection', 'diffusion', 'transient', 'Discretization2ndOrder', 'discretize_spatial']
+           'advection', 'diffusion', 'transient', 'Discretization2ndOrder', 'discretize_spatial',
+           'discretize_spatial_staggered']
diff --git a/pystencils/fd/finitedifferences.py b/pystencils/fd/finitedifferences.py
index fe764361e9bad3b59b11ea4055b8b27e3662fb99..27ae96fb96fe6c00c8cbd4109df7caab98dc8849 100644
--- a/pystencils/fd/finitedifferences.py
+++ b/pystencils/fd/finitedifferences.py
@@ -4,6 +4,8 @@ from typing import Union, Optional
 
 from pystencils import Field, AssignmentCollection
 from pystencils.fd import Diff
+from pystencils.fd.derivative import diff_args
+from pystencils.fd.spatial import fd_stencils_standard
 from pystencils.sympyextensions import fast_subs
 
 
@@ -68,9 +70,10 @@ def transient(scalar, idx=None):
 
 
 class Discretization2ndOrder:
-    def __init__(self, dx=sp.Symbol("dx"), dt=sp.Symbol("dt")):
+    def __init__(self, dx=sp.Symbol("dx"), dt=sp.Symbol("dt"), discretization_stencil_func=fd_stencils_standard):
         self.dx = dx
         self.dt = dt
+        self.spatial_stencil = discretization_stencil_func
 
     @staticmethod
     def _diff_order(e):
@@ -104,7 +107,10 @@ class Discretization2ndOrder:
         elif isinstance(e, Advection):
             return self._discretize_advection(e)
         elif isinstance(e, Diff):
-            return self._discretize_diff(e)
+            arg, *indices = diff_args(e)
+            if not isinstance(arg, Field.Access):
+                raise ValueError("Only derivatives with field or field accesses as arguments can be discretized")
+            return self.spatial_stencil(indices, self.dx, arg)
         else:
             new_args = [self._discretize_spatial(a) for a in e.args]
             return e.func(*new_args) if new_args else e
diff --git a/pystencils/fd/spatial.py b/pystencils/fd/spatial.py
index 610756f12a1e453c64078083400c7f4e072d97b8..e3ef030d1fa64ac619f2a8acedc8504e6c283638 100644
--- a/pystencils/fd/spatial.py
+++ b/pystencils/fd/spatial.py
@@ -1,9 +1,12 @@
 from typing import Tuple
 import sympy as sp
 from functools import partial
+
+from pystencils.astnodes import LoopOverCoordinate
 from pystencils.cache import memorycache
 from pystencils import AssignmentCollection, Field
 from pystencils.fd import Diff
+from pystencils.transformations import generic_visit
 from .derivative import diff_args
 from .derivation import FiniteDifferenceStencilDerivation
 
@@ -107,22 +110,59 @@ def discretize_spatial(expr, dx, stencil=fd_stencils_standard):
         else:
             raise ValueError("Unknown stencil. Supported 'standard' and 'isotropic'")
 
-    if isinstance(expr, list):
-        return [discretize_spatial(e, dx, stencil) for e in expr]
-    elif isinstance(expr, sp.Matrix):
-        return expr.applyfunc(partial(discretize_spatial, dx=dx, stencil=stencil))
-    elif isinstance(expr, AssignmentCollection):
-        return expr.copy(main_assignments=[e for e in expr.main_assignments],
-                         subexpressions=[e for e in expr.subexpressions])
-    elif isinstance(expr, Diff):
-        arg, *indices = diff_args(expr)
-        if not isinstance(arg, Field.Access):
-            raise ValueError("Only derivatives with field or field accesses as arguments can be discretized")
-        return stencil(indices, dx, arg)
-    else:
-        new_args = [discretize_spatial(a, dx, stencil) for a in expr.args]
-        return expr.func(*new_args) if new_args else expr
+    def visitor(e):
+        if isinstance(e, Diff):
+            arg, *indices = diff_args(e)
+            if not isinstance(arg, Field.Access):
+                raise ValueError("Only derivatives with field or field accesses as arguments can be discretized")
+            return stencil(indices, dx, arg)
+        else:
+            new_args = [discretize_spatial(a, dx, stencil) for a in e.args]
+            return e.func(*new_args) if new_args else e
+
+    return generic_visit(expr, visitor)
+
+
+def discretize_spatial_staggered(expr, dx, stencil=fd_stencils_standard):
+
+    def staggered_visitor(e, coordinate, sign):
+        if isinstance(e, Diff):
+            arg, *indices = diff_args(e)
+            if len(indices) != 1:
+                raise ValueError("Function supports only up to second derivatives")
+            if not isinstance(arg, Field.Access):
+                raise ValueError("Argument of inner derivative has to be field access")
+            target = indices[0]
+            if target == coordinate:
+                assert sign in (-1, 1)
+                return (arg.neighbor(coordinate, sign) - arg) / dx * sign
+            else:
+                return (stencil(indices, dx, arg.neighbor(coordinate, sign))
+                        + stencil(indices, dx, arg)) / 2
+        elif isinstance(e, Field.Access):
+            return (e.neighbor(coordinate, sign) + e) / 2
+        elif isinstance(e, sp.Symbol):
+            loop_idx = LoopOverCoordinate.is_loop_counter_symbol(e)
+            return e + sign / 2 if loop_idx == coordinate else e
+        else:
+            new_args = [staggered_visitor(a, coordinate, sign) for a in e.args]
+            return e.func(*new_args) if new_args else e
+
+    def visitor(e):
+        if isinstance(e, Diff):
+            arg, *indices = diff_args(e)
+            if isinstance(arg, Field.Access):
+                return stencil(indices, dx, arg)
+            else:
+                if not len(indices) == 1:
+                    raise ValueError("This term is not support by the staggered discretization strategy")
+                target = indices[0]
+                return (staggered_visitor(arg, target, 1) - staggered_visitor(arg, target, -1)) / dx
+        else:
+            new_args = [visitor(a) for a in e.args]
+            return e.func(*new_args) if new_args else e
 
+    return generic_visit(expr, visitor)
 
 # -------------------------------------- special stencils --------------------------------------------------------------
 
diff --git a/pystencils/field.py b/pystencils/field.py
index 4b06592e1aafc298da5f066cbfece678c945c39f..bd6b0337db25e23827f9943cd85fce02359c8000 100644
--- a/pystencils/field.py
+++ b/pystencils/field.py
@@ -532,7 +532,7 @@ class Field:
             """Value of index coordinates as tuple."""
             return self._index
 
-        def neighbor(self, coord_id: int, offset: Sequence[int]) -> 'Field.Access':
+        def neighbor(self, coord_id: int, offset: int) -> 'Field.Access':
             """Returns a new Access with changed spatial coordinates.
 
             Args:
diff --git a/pystencils/transformations.py b/pystencils/transformations.py
index e781b3bc044ad56cc50ec6644404ea93e77a4563..6fc9f83c5fcffcef50cfac0f88a2d579058a3452 100644
--- a/pystencils/transformations.py
+++ b/pystencils/transformations.py
@@ -7,6 +7,8 @@ import hashlib
 import sympy as sp
 from sympy.logic.boolalg import Boolean
 from sympy.tensor import IndexedBase
+
+from pystencils.simp.assignment_collection import AssignmentCollection
 from pystencils.assignment import Assignment
 from pystencils.field import Field, FieldType
 from pystencils.data_types import TypedSymbol, PointerType, StructType, get_base_type, reinterpret_cast_func, \
@@ -80,6 +82,21 @@ def filtered_tree_iteration(node, node_type, stop_type=None):
         yield from filtered_tree_iteration(arg, node_type)
 
 
+def generic_visit(term, visitor):
+    if isinstance(term, AssignmentCollection):
+        new_main_assignments = generic_visit(term.main_assignments, visitor)
+        new_subexpressions = generic_visit(term.subexpressions, visitor)
+        return term.copy(new_main_assignments, new_subexpressions)
+    elif isinstance(term, list):
+        return [generic_visit(e, visitor) for e in term]
+    elif isinstance(term, Assignment):
+        return Assignment(term.lhs, generic_visit(term.rhs, visitor))
+    elif isinstance(term, sp.Matrix):
+        return term.applyfunc(lambda e: generic_visit(e, visitor))
+    else:
+        return visitor(term)
+
+
 def unify_shape_symbols(body, common_shape, fields):
     """Replaces symbols for array sizes to ensure they are represented by the same unique symbol.
 
diff --git a/pystencils_tests/test_finite_differences.py b/pystencils_tests/test_finite_differences.py
index 5b7c504dc7d4144d32de5dfcddea69f245658a32..dcf9fa2b735c8512424918052c63348b51355cd7 100644
--- a/pystencils_tests/test_finite_differences.py
+++ b/pystencils_tests/test_finite_differences.py
@@ -1,5 +1,6 @@
 import sympy as sp
 import pystencils as ps
+from pystencils.astnodes import LoopOverCoordinate
 from pystencils.stencils import stencil_coefficients
 from pystencils.fd.spatial import fd_stencils_standard, fd_stencils_isotropic, discretize_spatial
 from pystencils.fd import diff
@@ -38,3 +39,32 @@ def test_spatial_1d_unit_sum():
             discretized = discretize_spatial(term, dx=h, stencil=scheme)
             _, coefficients = stencil_coefficients(discretized)
             assert sum(coefficients) == 0
+
+
+def test_staggered_laplacian():
+    f = ps.fields("f : double[2D]")
+    a, dx = sp.symbols("a, dx")
+
+    factored_version = sum(ps.fd.Diff(a * ps.fd.Diff(f[0, 0], i), i)
+                           for i in range(2))
+    expanded = ps.fd.expand_diff_full(factored_version, constants=[a])
+
+    reference = ps.fd.discretize_spatial(expanded, dx).factor()
+    to_test = ps.fd.discretize_spatial_staggered(factored_version, dx).factor()
+    assert reference == to_test
+
+
+def test_staggered_combined():
+    from pystencils.fd import diff
+    f = ps.fields("f : double[2D]")
+    x, y = [LoopOverCoordinate.get_loop_counter_symbol(i) for i in range(2)]
+    dx = sp.symbols("dx")
+
+    expr = diff(x * diff(f, 0) + y * diff(f, 1), 0)
+
+    right = (x + sp.Rational(1, 2)) * (f[1, 0] - f[0, 0]) + y * (f[1, 1] - f[1, -1] + f[0, 1] - f[0, -1]) / 4
+    left = (x - sp.Rational(1, 2)) * (f[0, 0] - f[-1, 0]) + y * (f[-1, 1] - f[-1, -1] + f[0, 1] - f[0, -1]) / 4
+    reference = (right - left) / (dx ** 2)
+
+    to_test = ps.fd.discretize_spatial_staggered(expr, dx)
+    assert sp.expand(reference - to_test) == 0