diff --git a/pystencils/astnodes.py b/pystencils/astnodes.py
index 10e855df2fd9a661f053adbdb43dbfb8d969fdf4..b081f752945133d56a4d32407e335b26d13614ec 100644
--- a/pystencils/astnodes.py
+++ b/pystencils/astnodes.py
@@ -310,6 +310,7 @@ class Block(Node):
 
     def insert_before(self, new_node, insert_before):
         new_node.parent = self
+        assert self._nodes.count(insert_before) == 1
         idx = self._nodes.index(insert_before)
 
         # move all assignment (definitions to the top)
@@ -337,6 +338,7 @@ class Block(Node):
         return tmp
 
     def replace(self, child, replacements):
+        assert self._nodes.count(child) == 1
         idx = self._nodes.index(child)
         del self._nodes[idx]
         if type(replacements) is list:
diff --git a/pystencils/kernelcreation.py b/pystencils/kernelcreation.py
index 57a97bcb830af86cd702c412a816917bd0131e37..b968d4c621faca9012d0a5997e877c36b2ae54d8 100644
--- a/pystencils/kernelcreation.py
+++ b/pystencils/kernelcreation.py
@@ -287,7 +287,9 @@ def create_staggered_kernel(assignments, gpu_exclusive_conditions=False, **kwarg
         last_conditional = Conditional(condition(direction), Block(sp_assignments))
         final_assignments.append(last_conditional)
 
-    prepend_optimizations = [remove_conditionals_in_staggered_kernel, move_constants_before_loop]
+    remove_start_conditional = any([gl[0] == 0 for gl in ghost_layers])
+    prepend_optimizations = [lambda ast: remove_conditionals_in_staggered_kernel(ast, remove_start_conditional),
+                             move_constants_before_loop]
     ast = create_kernel(final_assignments, ghost_layers=ghost_layers, cpu_prepend_optimizations=prepend_optimizations,
                         **kwargs)
     return ast
diff --git a/pystencils/transformations.py b/pystencils/transformations.py
index ea340e930631a159cf81cf9d56337b315e27c88b..762c36136cd7f3eb541a6075f4b24021813c82ad 100644
--- a/pystencils/transformations.py
+++ b/pystencils/transformations.py
@@ -630,13 +630,18 @@ def move_constants_before_loop(ast_node):
                     else:
                         target.insert_before(child, child_to_insert_before)
                 elif exists_already and exists_already.rhs == child.rhs:
-                    pass
+                    if target.args.index(exists_already) > target.args.index(child_to_insert_before):
+                        assert target.args.count(exists_already) == 1
+                        assert target.args.count(child_to_insert_before) == 1
+                        target.args.remove(exists_already)
+                        target.insert_before(exists_already, child_to_insert_before)
                 else:
                     # this variable already exists in outer block, but with different rhs
                     # -> symbol has to be renamed
                     assert isinstance(child.lhs, TypedSymbol)
                     new_symbol = TypedSymbol(sp.Dummy().name, child.lhs.dtype)
-                    target.insert_before(ast.SympyAssignment(new_symbol, child.rhs), child_to_insert_before)
+                    target.insert_before(ast.SympyAssignment(new_symbol, child.rhs, is_const=child.is_const),
+                                         child_to_insert_before)
                     substitute_variables[child.lhs] = new_symbol
 
 
@@ -1064,15 +1069,19 @@ def insert_casts(node):
     return node.func(*args)
 
 
-def remove_conditionals_in_staggered_kernel(function_node: ast.KernelFunction) -> None:
-    """Removes conditionals of a kernel that iterates over staggered positions by splitting the loops at last element"""
+def remove_conditionals_in_staggered_kernel(function_node: ast.KernelFunction, include_first=True) -> None:
+    """Removes conditionals of a kernel that iterates over staggered positions by splitting the loops at last or
+       first and last element"""
 
     all_inner_loops = [l for l in function_node.atoms(ast.LoopOverCoordinate) if l.is_innermost_loop]
     assert len(all_inner_loops) == 1, "Transformation works only on kernels with exactly one inner loop"
     inner_loop = all_inner_loops.pop()
 
     for loop in parents_of_type(inner_loop, ast.LoopOverCoordinate, include_current=True):
-        cut_loop(loop, [loop.stop - 1])
+        if include_first:
+            cut_loop(loop, [loop.start + 1, loop.stop - 1])
+        else:
+            cut_loop(loop, [loop.stop - 1])
 
     simplify_conditionals(function_node.body, loop_counter_simplification=True)
     cleanup_blocks(function_node.body)
diff --git a/pystencils_tests/test_staggered_kernel.py b/pystencils_tests/test_staggered_kernel.py
index a914b91806ee53ab200101edfccfcecaae81207b..4db538bf8f83d6779cb7dfd612fa9928acc23965 100644
--- a/pystencils_tests/test_staggered_kernel.py
+++ b/pystencils_tests/test_staggered_kernel.py
@@ -75,3 +75,11 @@ def test_staggered_subexpressions():
     assignments = [ps.Assignment(j.staggered_access("W"), c),
                    ps.Assignment(c, 1)]
     ps.create_staggered_kernel(assignments, target=dh.default_target).compile()
+
+
+def test_staggered_loop_cutting():  # builds a staggered kernel from one "SW" staggered assignment on the CPU target
+    dh = ps.create_data_handling((4, 4), periodicity=True, default_target='cpu')
+    j = dh.add_array('j', values_per_cell=4, field_type=ps.FieldType.STAGGERED)
+    assignments = [ps.Assignment(j.staggered_access("SW"), 1)]
+    ast = ps.create_staggered_kernel(assignments, target=dh.default_target)
+    assert not ast.atoms(ps.astnodes.Conditional)  # the returned AST must contain no Conditional nodes