diff --git a/docs/source/backend/index.rst b/docs/source/backend/index.rst
index 1e3968bc0e4137b7a43796911de646cdddcb9ab6..df194bde9b9c50cc9ee3f1064ff5b5361205c227 100644
--- a/docs/source/backend/index.rst
+++ b/docs/source/backend/index.rst
@@ -14,6 +14,7 @@ who wish to customize or extend the behaviour of the code generator in their app
     iteration_space
     translation
     platforms
+    transformations
     jit
 
 Internal Representation
diff --git a/docs/source/backend/transformations.rst b/docs/source/backend/transformations.rst
new file mode 100644
index 0000000000000000000000000000000000000000..44bf4da23e160edaac5e2dd9918fbd389aba94d6
--- /dev/null
+++ b/docs/source/backend/transformations.rst
@@ -0,0 +1,7 @@
+*******************
+AST Transformations
+*******************
+
+`pystencils.backend.transformations`
+
+.. automodule:: pystencils.backend.transformations
diff --git a/src/pystencils/backend/kernelcreation/ast_factory.py b/src/pystencils/backend/kernelcreation/ast_factory.py
index c2334f54c34d476207eddc5466b2b13bff0d39d8..83c406b0a99d52cee9599f321d6c32477f6dbf8a 100644
--- a/src/pystencils/backend/kernelcreation/ast_factory.py
+++ b/src/pystencils/backend/kernelcreation/ast_factory.py
@@ -1,10 +1,11 @@
 from typing import Any, Sequence, cast, overload
 
+import numpy as np
 import sympy as sp
 from sympy.codegen.ast import AssignmentBase
 
 from ..ast import PsAstNode
-from ..ast.expressions import PsExpression, PsSymbolExpr
+from ..ast.expressions import PsExpression, PsSymbolExpr, PsConstantExpr
 from ..ast.structural import PsLoop, PsBlock, PsAssignment
 
 from ..symbols import PsSymbol
@@ -16,6 +17,10 @@ from .typification import Typifier
 from .iteration_space import FullIterationSpace
 
 
+IndexParsable = PsExpression | PsSymbol | PsConstant | sp.Expr | int | np.integer
+_IndexParsable = (PsExpression, PsSymbol, PsConstant, sp.Expr, int, np.integer)
+
+
 class AstFactory:
     """Factory providing a convenient interface for building syntax trees.
 
@@ -51,6 +56,45 @@ class AstFactory:
         """
         return self._typify(self._freeze(sp_obj))
 
+    @overload
+    def parse_index(self, idx: sp.Symbol | PsSymbol | PsSymbolExpr) -> PsSymbolExpr:
+        pass
+
+    @overload
+    def parse_index(
+        self, idx: int | np.integer | PsConstant | PsConstantExpr
+    ) -> PsConstantExpr:
+        pass
+
+    @overload
+    def parse_index(self, idx: sp.Expr | PsExpression) -> PsExpression:
+        pass
+
+    def parse_index(self, idx: IndexParsable):
+        """Parse the given object as an expression with data type `ctx.index_dtype`."""
+
+        if not isinstance(idx, _IndexParsable):
+            raise TypeError(
+                f"Cannot parse object of type {type(idx)} as an index expression"
+            )
+
+        match idx:
+            case PsExpression():
+                return self._typify.typify_expression(idx, self._ctx.index_dtype)[0]
+            case PsSymbol() | PsConstant():
+                return self._typify.typify_expression(
+                    PsExpression.make(idx), self._ctx.index_dtype
+                )[0]
+            case sp.Expr():
+                return self._typify.typify_expression(
+                    self._freeze(idx), self._ctx.index_dtype
+                )[0]
+            case _:
+                return PsExpression.make(PsConstant(idx, self._ctx.index_dtype))
+
+    def _parse_any_index(self, idx: Any) -> PsExpression:
+        return self.parse_index(cast(IndexParsable, idx))
+
     def parse_slice(
         self, slic: slice, upper_limit: Any | None = None
     ) -> tuple[PsExpression, PsExpression, PsExpression]:
@@ -75,27 +119,16 @@ class AstFactory:
                     "Must specify an upper iteration limit if `slice.stop` is `None` or a negative `int`"
                 )
 
-        def make_expr(val: Any) -> PsExpression:
-            match val:
-                case PsExpression():
-                    return self._typify.typify_expression(val, self._ctx.index_dtype)[0]
-                case PsSymbol() | PsConstant():
-                    return self._typify.typify_expression(
-                        PsExpression.make(val), self._ctx.index_dtype
-                    )[0]
-                case sp.Expr():
-                    return self._typify.typify_expression(
-                        self._freeze(val), self._ctx.index_dtype
-                    )[0]
-                case _:
-                    return PsExpression.make(PsConstant(val, self._ctx.index_dtype))
-
-        start = make_expr(slic.start if slic.start is not None else 0)
-        stop = make_expr(slic.stop) if slic.stop is not None else make_expr(upper_limit)
-        step = make_expr(slic.step if slic.step is not None else 1)
+        start = self._parse_any_index(slic.start if slic.start is not None else 0)
+        stop = (
+            self._parse_any_index(slic.stop)
+            if slic.stop is not None
+            else self._parse_any_index(upper_limit)
+        )
+        step = self._parse_any_index(slic.step if slic.step is not None else 1)
 
         if isinstance(slic.stop, int) and slic.stop < 0:
-            stop = make_expr(upper_limit) + stop
+            stop = self._parse_any_index(upper_limit) + stop
 
         return start, stop, step
 
diff --git a/src/pystencils/backend/kernelcreation/context.py b/src/pystencils/backend/kernelcreation/context.py
index d48953a5b2ccaf1325006f09ec63f5f8fea90f26..263c2f48ecfff15dd4c6271f77ca2a7578b86d09 100644
--- a/src/pystencils/backend/kernelcreation/context.py
+++ b/src/pystencils/backend/kernelcreation/context.py
@@ -1,9 +1,10 @@
 from __future__ import annotations
 
 from typing import Iterable, Iterator
-from itertools import chain
+from itertools import chain, count
 from types import EllipsisType
-from collections import namedtuple
+from collections import namedtuple, defaultdict
+import re
 
 from ...defaults import DEFAULTS
 from ...field import Field, FieldType
@@ -67,6 +68,9 @@ class KernelCreationContext:
 
         self._symbols: dict[str, PsSymbol] = dict()
 
+        self._symbol_ctr_pattern = re.compile(r"__[0-9]+$")
+        self._symbol_dup_table: defaultdict[str, int] = defaultdict(lambda: 0)
+
         self._fields_and_arrays: dict[str, FieldArrayPair] = dict()
         self._fields_collection = FieldsInKernel()
 
@@ -95,6 +99,21 @@ class KernelCreationContext:
     #   Symbols
 
     def get_symbol(self, name: str, dtype: PsType | None = None) -> PsSymbol:
+        """Retrieve the symbol with the given name and data type from the symbol table.
+
+        If no symbol named ``name`` exists, a new symbol with the given data type is created.
+
+        If a symbol with the given ``name`` already exists and ``dtype`` is not `None`,
+        the given data type will be applied to it, and it is returned.
+        If the symbol already has a different data type, an error will be raised.
+
+        If the symbol already exists and ``dtype`` is `None`, the existing symbol is returned
+        without checking or altering its data type.
+
+        Args:
+            name: The symbol's name
+            dtype: The symbol's data type, or `None`
+        """
         if name not in self._symbols:
             symb = PsSymbol(name, None)
             self._symbols[name] = symb
@@ -115,12 +134,20 @@ class KernelCreationContext:
         return self._symbols.get(name, None)
 
     def add_symbol(self, symbol: PsSymbol):
+        """Add an existing symbol to the symbol table.
+
+        If a symbol with the same name already exists, an error will be raised.
+        """
         if symbol.name in self._symbols:
             raise PsInternalCompilerError(f"Duplicate symbol: {symbol.name}")
 
         self._symbols[symbol.name] = symbol
 
     def replace_symbol(self, old: PsSymbol, new: PsSymbol):
+        """Replace one symbol by another.
+
+        The two symbols ``old`` and ``new`` must have the same name, but may have different data types.
+        """
         if old.name != new.name:
             raise PsInternalCompilerError(
                 "replace_symbol: Old and new symbol must have the same name"
@@ -131,8 +158,30 @@ class KernelCreationContext:
 
         self._symbols[old.name] = new
 
+    def duplicate_symbol(self, symb: PsSymbol) -> PsSymbol:
+        """Canonically duplicates the given symbol.
+
+        A new symbol with the same data type, and new name ``symb.name + "__<counter>"`` is created,
+        added to the symbol table, and returned.
+        The ``counter`` reflects the number of previously created duplicates of this symbol.
+        """
+        if (result := self._symbol_ctr_pattern.search(symb.name)) is not None:
+            span = result.span()
+            basename = symb.name[: span[0]]
+        else:
+            basename = symb.name
+
+        initial_count = self._symbol_dup_table[basename]
+        for i in count(initial_count):
+            dup_name = f"{basename}__{i}"
+            if self.find_symbol(dup_name) is None:
+                self._symbol_dup_table[basename] = i + 1
+                return self.get_symbol(dup_name, symb.dtype)
+        assert False, "unreachable code"
+
     @property
     def symbols(self) -> Iterable[PsSymbol]:
+        """Return an iterable of all symbols listed in the symbol table."""
         return self._symbols.values()
 
     #   Fields and Arrays
diff --git a/src/pystencils/backend/kernelcreation/iteration_space.py b/src/pystencils/backend/kernelcreation/iteration_space.py
index ba215f822ea7372211bf764425d44e44487cc46b..6adac2a519ffc04505c0e0adac3484d78f30d013 100644
--- a/src/pystencils/backend/kernelcreation/iteration_space.py
+++ b/src/pystencils/backend/kernelcreation/iteration_space.py
@@ -121,7 +121,7 @@ class FullIterationSpace(IterationSpace):
     @staticmethod
     def create_from_slice(
         ctx: KernelCreationContext,
-        iteration_slice: Sequence[slice],
+        iteration_slice: slice | Sequence[slice],
         archetype_field: Field | None = None,
     ):
         """Create an iteration space from a sequence of slices, optionally over an archetype field.
@@ -131,6 +131,9 @@ class FullIterationSpace(IterationSpace):
             iteration_slice: The iteration slices for each dimension; for valid formats, see `AstFactory.parse_slice`
             archetype_field: Optionally, an archetype field that dictates the upper slice limits and loop order.
         """
+        if isinstance(iteration_slice, slice):
+            iteration_slice = (iteration_slice,)
+
         dim = len(iteration_slice)
         if dim == 0:
             raise ValueError(
diff --git a/src/pystencils/backend/transformations/__init__.py b/src/pystencils/backend/transformations/__init__.py
index 01b69509991eaa762a093f50f427f6e4050dc34a..518c402d27e8828cc129c06a20bd10e2ba3d3168 100644
--- a/src/pystencils/backend/transformations/__init__.py
+++ b/src/pystencils/backend/transformations/__init__.py
@@ -1,16 +1,94 @@
+"""
+This module contains various transformation and optimization passes that can be
+executed on the backend AST.
+
+Canonical Form
+==============
+
+Many transformations in this module require that their input AST is in *canonical form*.
+This means that:
+
+- Each symbol, constant, and expression node is annotated with a data type;
+- Each symbol has at most one declaration;
+- Each symbol that is never written to apart from its declaration has a ``const`` type; and
+- Each symbol whose type is *not* ``const`` has at least one non-declaring assignment.
+
+The first requirement can be ensured by running the `Typifier` on each newly constructed subtree.
+The other three requirements are ensured by the `CanonicalizeSymbols` pass,
+which should be run first before applying any optimizing transformations.
+All transformations in this module retain canonicality of the AST.
+
+Canonicality allows transformations to forego various checks that would otherwise be necessary
+to prove their legality.
+
+Certain transformations, like the auto-vectorizer (TODO), state additional requirements, e.g.
+the absence of loop-carried dependencies.
+
+Transformations
+===============
+
+Canonicalization
+----------------
+
+.. autoclass:: CanonicalizeSymbols
+    :members: __call__
+
+AST Cloning
+-----------
+
+.. autoclass:: CanonicalClone
+    :members: __call__
+
+Simplifying Transformations
+---------------------------
+
+.. autoclass:: EliminateConstants
+    :members: __call__
+
+.. autoclass:: EliminateBranches
+    :members: __call__
+
+Code Motion
+-----------
+
+.. autoclass:: HoistLoopInvariantDeclarations
+    :members: __call__
+
+Loop Reshaping Transformations
+------------------------------
+
+.. autoclass:: ReshapeLoops
+    :members:
+
+
+Code Lowering and Materialization
+---------------------------------
+
+.. autoclass:: EraseAnonymousStructTypes
+    :members: __call__
+
+.. autoclass:: SelectFunctions
+    :members: __call__
+
+"""
+
+from .canonicalize_symbols import CanonicalizeSymbols
+from .canonical_clone import CanonicalClone
 from .eliminate_constants import EliminateConstants
 from .eliminate_branches import EliminateBranches
-from .canonicalize_symbols import CanonicalizeSymbols
 from .hoist_loop_invariant_decls import HoistLoopInvariantDeclarations
+from .reshape_loops import ReshapeLoops
 from .erase_anonymous_structs import EraseAnonymousStructTypes
 from .select_functions import SelectFunctions
 from .select_intrinsics import MaterializeVectorIntrinsics
 
 __all__ = [
+    "CanonicalizeSymbols",
+    "CanonicalClone",
     "EliminateConstants",
     "EliminateBranches",
-    "CanonicalizeSymbols",
     "HoistLoopInvariantDeclarations",
+    "ReshapeLoops",
     "EraseAnonymousStructTypes",
     "SelectFunctions",
     "MaterializeVectorIntrinsics",
diff --git a/src/pystencils/backend/transformations/canonical_clone.py b/src/pystencils/backend/transformations/canonical_clone.py
new file mode 100644
index 0000000000000000000000000000000000000000..538bb2779314fc0fe1d7b83810dd6a4b031ca46a
--- /dev/null
+++ b/src/pystencils/backend/transformations/canonical_clone.py
@@ -0,0 +1,112 @@
+from typing import TypeVar, cast
+
+from ..kernelcreation import KernelCreationContext
+from ..symbols import PsSymbol
+from ..exceptions import PsInternalCompilerError
+
+from ..ast import PsAstNode
+from ..ast.structural import (
+    PsBlock,
+    PsConditional,
+    PsLoop,
+    PsDeclaration,
+    PsAssignment,
+    PsComment,
+)
+from ..ast.expressions import PsExpression, PsSymbolExpr
+
+__all__ = ["CanonicalClone"]
+
+
+class CloneContext:
+    def __init__(self, ctx: KernelCreationContext) -> None:
+        self._ctx = ctx
+        self._dup_table: dict[PsSymbol, PsSymbol] = dict()
+
+    def symbol_decl(self, declared_symbol: PsSymbol):
+        self._dup_table[declared_symbol] = self._ctx.duplicate_symbol(declared_symbol)
+
+    def get_replacement(self, symb: PsSymbol):
+        return self._dup_table.get(symb, symb)
+
+
+Node_T = TypeVar("Node_T", bound=PsAstNode)
+
+
+class CanonicalClone:
+    """Clone a subtree, and rename all symbols declared inside it to retain canonicality."""
+
+    def __init__(self, ctx: KernelCreationContext) -> None:
+        self._ctx = ctx
+
+    def __call__(self, node: Node_T) -> Node_T:
+        return self.visit(node, CloneContext(self._ctx))
+
+    def visit(self, node: Node_T, cc: CloneContext) -> Node_T:
+        match node:
+            case PsBlock(statements):
+                return cast(
+                    Node_T, PsBlock([self.visit(stmt, cc) for stmt in statements])
+                )
+
+            case PsLoop(ctr, start, stop, step, body):
+                cc.symbol_decl(ctr.symbol)
+                return cast(
+                    Node_T,
+                    PsLoop(
+                        self.visit(ctr, cc),
+                        self.visit(start, cc),
+                        self.visit(stop, cc),
+                        self.visit(step, cc),
+                        self.visit(body, cc),
+                    ),
+                )
+
+            case PsConditional(cond, then, els):
+                return cast(
+                    Node_T,
+                    PsConditional(
+                        self.visit(cond, cc),
+                        self.visit(then, cc),
+                        self.visit(els, cc) if els is not None else None,
+                    ),
+                )
+
+            case PsComment():
+                return cast(Node_T, node.clone())
+
+            case PsDeclaration(lhs, rhs):
+                cc.symbol_decl(node.declared_symbol)
+                return cast(
+                    Node_T,
+                    PsDeclaration(
+                        cast(PsSymbolExpr, self.visit(lhs, cc)),
+                        self.visit(rhs, cc),
+                    ),
+                )
+
+            case PsAssignment(lhs, rhs):
+                return cast(
+                    Node_T,
+                    PsAssignment(
+                        self.visit(lhs, cc),
+                        self.visit(rhs, cc),
+                    ),
+                )
+
+            case PsExpression():
+                expr_clone = node.clone()
+                self._replace_symbols(expr_clone, cc)
+                return cast(Node_T, expr_clone)
+
+            case _:
+                raise PsInternalCompilerError(
+                    f"Don't know how to canonically clone {type(node)}"
+                )
+
+    def _replace_symbols(self, expr: PsExpression, cc: CloneContext):
+        if isinstance(expr, PsSymbolExpr):
+            expr.symbol = cc.get_replacement(expr.symbol)
+        else:
+            for c in expr.children:
+                self._replace_symbols(cast(PsExpression, c), cc)
diff --git a/src/pystencils/backend/transformations/canonicalize_symbols.py b/src/pystencils/backend/transformations/canonicalize_symbols.py
index 6fe922f28cfbfd0a42c77dbd3d1606c910abf298..3900105b8f64b7cd33e154de02eaec7cf826d0fb 100644
--- a/src/pystencils/backend/transformations/canonicalize_symbols.py
+++ b/src/pystencils/backend/transformations/canonicalize_symbols.py
@@ -1,5 +1,3 @@
-from itertools import count
-
 from ..kernelcreation import KernelCreationContext
 from ..symbols import PsSymbol
 from ..exceptions import PsInternalCompilerError
@@ -29,14 +27,10 @@ class CanonContext:
             self.live_symbols_map[symb] = symb
             return symb
         else:
-            for i in count():
-                replacement_name = f"{symb.name}__{i}"
-                if self._ctx.find_symbol(replacement_name) is None:
-                    replacement = self._ctx.get_symbol(replacement_name, symb.dtype)
-                    self.live_symbols_map[symb] = replacement
-                    self.encountered_symbols.add(replacement)
-                    return replacement
-            assert False, "unreachable code"
+            replacement = self._ctx.duplicate_symbol(symb)
+            self.live_symbols_map[symb] = replacement
+            self.encountered_symbols.add(replacement)
+            return replacement
 
     def mark_as_updated(self, symb: PsSymbol):
         self.updated_symbols.add(self.deduplicate(symb))
diff --git a/src/pystencils/backend/transformations/eliminate_constants.py b/src/pystencils/backend/transformations/eliminate_constants.py
index 7678dbd8c6ce783585fb7095b201e9f92e65e485..7fa4766eb305954f56d10b8cf8052c2fb26cb8fe 100644
--- a/src/pystencils/backend/transformations/eliminate_constants.py
+++ b/src/pystencils/backend/transformations/eliminate_constants.py
@@ -1,4 +1,4 @@
-from typing import cast, Iterable
+from typing import cast, Iterable, overload
 from collections import defaultdict
 
 from ..kernelcreation import KernelCreationContext, Typifier
@@ -116,6 +116,14 @@ class EliminateConstants:
         self._fold_floats = False
         self._extract_constant_exprs = extract_constant_exprs
 
+    @overload
+    def __call__(self, node: PsExpression) -> PsExpression:
+        pass
+
+    @overload
+    def __call__(self, node: PsAstNode) -> PsAstNode:
+        pass
+
     def __call__(self, node: PsAstNode) -> PsAstNode:
         ecc = ECContext(self._ctx)
 
diff --git a/src/pystencils/backend/transformations/hoist_loop_invariant_decls.py b/src/pystencils/backend/transformations/hoist_loop_invariant_decls.py
index 5920038150ee6c5295c3f46f4530639ed02fca25..5824239e40ff8365a20658defab344892973f58f 100644
--- a/src/pystencils/backend/transformations/hoist_loop_invariant_decls.py
+++ b/src/pystencils/backend/transformations/hoist_loop_invariant_decls.py
@@ -69,7 +69,7 @@ class HoistLoopInvariantDeclarations:
     in particular, each symbol may have at most one declaration.
     To ensure this, a `CanonicalizeSymbols` pass should be run before `HoistLoopInvariantDeclarations`.
 
-    `HoistLoopInvariants` assumes that all `PsMathFunction`s are pure (have no side effects),
+    `HoistLoopInvariantDeclarations` assumes that all `PsMathFunction` s are pure (have no side effects),
     but makes no such assumption about instances of `CFunction`.
     """
 
diff --git a/src/pystencils/backend/transformations/reshape_loops.py b/src/pystencils/backend/transformations/reshape_loops.py
new file mode 100644
index 0000000000000000000000000000000000000000..6963bee0b2e43bc6bac58a6c96de5f4a35e57148
--- /dev/null
+++ b/src/pystencils/backend/transformations/reshape_loops.py
@@ -0,0 +1,138 @@
+from typing import Sequence
+
+from ..kernelcreation import KernelCreationContext, Typifier
+from ..kernelcreation.ast_factory import AstFactory, IndexParsable
+
+from ..ast.structural import PsLoop, PsBlock, PsConditional, PsDeclaration
+from ..ast.expressions import PsExpression, PsConstantExpr, PsLt
+from ..constants import PsConstant
+
+from .canonical_clone import CanonicalClone, CloneContext
+from .eliminate_constants import EliminateConstants
+
+
+class ReshapeLoops:
+    """Various transformations for reshaping loop nests."""
+
+    def __init__(self, ctx: KernelCreationContext) -> None:
+        self._ctx = ctx
+        self._typify = Typifier(ctx)
+        self._factory = AstFactory(ctx)
+        self._canon_clone = CanonicalClone(ctx)
+        self._elim_constants = EliminateConstants(ctx)
+
+    def peel_loop_front(
+        self, loop: PsLoop, num_iterations: int, omit_range_check: bool = False
+    ) -> tuple[Sequence[PsBlock], PsLoop]:
+        """Peel off iterations from the front of a loop.
+
+        Removes ``num_iterations`` from the front of the given loop and returns them as a sequence of
+        independent blocks.
+
+        Args:
+            loop: The loop node from which to peel iterations
+            num_iterations: The number of iterations to peel off
+            omit_range_check: If set to `True`, assume that the peeled-off iterations will always
+              be executed, and omit their enclosing conditional.
+
+        Returns:
+            Tuple containing the peeled-off iterations as a sequence of blocks,
+            and the remaining loop.
+        """
+
+        peeled_iters: list[PsBlock] = []
+
+        for i in range(num_iterations):
+            cc = CloneContext(self._ctx)
+            cc.symbol_decl(loop.counter.symbol)
+            peeled_ctr = self._factory.parse_index(
+                cc.get_replacement(loop.counter.symbol)
+            )
+            peeled_idx = self._typify(loop.start + PsExpression.make(PsConstant(i)))
+
+            counter_decl = PsDeclaration(peeled_ctr, peeled_idx)
+            peeled_block = self._canon_clone.visit(loop.body, cc)
+
+            if omit_range_check:
+                peeled_block.statements = [counter_decl] + peeled_block.statements
+            else:
+                iter_condition = PsLt(peeled_ctr, loop.stop)
+                peeled_block.statements = [
+                    counter_decl,
+                    PsConditional(iter_condition, PsBlock(peeled_block.statements)),
+                ]
+
+            peeled_iters.append(peeled_block)
+
+        loop.start = self._elim_constants(
+            self._typify(loop.start + PsExpression.make(PsConstant(num_iterations)))
+        )
+
+        return peeled_iters, loop
+
+    def cut_loop(
+        self, loop: PsLoop, cutting_points: Sequence[IndexParsable]
+    ) -> Sequence[PsLoop | PsBlock]:
+        """Cut a loop at the given cutting points.
+
+        Cut the given loop at the iterations specified by the given cutting points,
+        producing ``n`` new subtrees representing the iterations
+        ``(loop.start:cutting_points[0]), (cutting_points[0]:cutting_points[1]), ..., (cutting_points[-1]:loop.stop)``.
+
+        Resulting subtrees representing zero iterations are dropped; subtrees representing exactly one iteration are
+        returned without the trivial loop structure.
+
+        Currently, `cut_loop` performs no checks to ensure that the given cutting points are in fact inside
+        the loop's iteration range.
+
+        Returns:
+            Sequence of ``n`` subtrees representing the respective iteration ranges
+        """
+
+        if not (
+            isinstance(loop.step, PsConstantExpr) and loop.step.constant.value == 1
+        ):
+            raise NotImplementedError(
+                "Loop cutting for loops with step != 1 is not implemented"
+            )
+
+        result: list[PsLoop | PsBlock] = []
+        new_start = loop.start
+        cutting_points = [self._factory.parse_index(idx) for idx in cutting_points] + [
+            loop.stop
+        ]
+
+        for new_end in cutting_points:
+            if new_end.structurally_equal(new_start):
+                continue
+
+            num_iters = self._elim_constants(self._typify(new_end - new_start))
+            skip = False
+
+            if isinstance(num_iters, PsConstantExpr):
+                if num_iters.constant.value == 0:
+                    skip = True
+                elif num_iters.constant.value == 1:
+                    skip = True
+                    cc = CloneContext(self._ctx)
+                    cc.symbol_decl(loop.counter.symbol)
+                    local_counter = self._factory.parse_index(
+                        cc.get_replacement(loop.counter.symbol)
+                    )
+                    ctr_decl = PsDeclaration(
+                        local_counter,
+                        new_start,
+                    )
+                    cloned_body = self._canon_clone.visit(loop.body, cc)
+                    cloned_body.statements = [ctr_decl] + cloned_body.statements
+                    result.append(cloned_body)
+
+            if not skip:
+                loop_clone = self._canon_clone(loop)
+                loop_clone.start = new_start.clone()
+                loop_clone.stop = new_end.clone()
+                result.append(loop_clone)
+
+            new_start = new_end
+
+        return result
diff --git a/tests/nbackend/kernelcreation/test_iteration_space.py b/tests/nbackend/kernelcreation/test_iteration_space.py
index 1dfcfea2a9bfd8e9db70dab7ca61732523855843..7fd6d778ff62f7fb2fcbc24a55af5225fb9f870e 100644
--- a/tests/nbackend/kernelcreation/test_iteration_space.py
+++ b/tests/nbackend/kernelcreation/test_iteration_space.py
@@ -52,7 +52,7 @@ def test_invalid_slices():
     ctx.add_field(archetype_field)
 
     islice = (slice(1, -1, 0.5),)
-    with pytest.raises(PsTypeError):
+    with pytest.raises(TypeError):
         FullIterationSpace.create_from_slice(ctx, islice, archetype_field)
 
     islice = (slice(1, -1, TypedSymbol("w", dtype=create_type("double"))),)
diff --git a/tests/nbackend/transformations/test_canonical_clone.py b/tests/nbackend/transformations/test_canonical_clone.py
new file mode 100644
index 0000000000000000000000000000000000000000..b158b91781b49f8d589a3da3b266e8c2137fceab
--- /dev/null
+++ b/tests/nbackend/transformations/test_canonical_clone.py
@@ -0,0 +1,63 @@
+import sympy as sp
+from pystencils import Field, Assignment, make_slice, TypedSymbol
+from pystencils.types.quick import Arr
+
+from pystencils.backend.kernelcreation import (
+    KernelCreationContext,
+    AstFactory,
+    FullIterationSpace,
+)
+from pystencils.backend.transformations import CanonicalClone
+from pystencils.backend.ast.structural import PsBlock, PsComment
+from pystencils.backend.ast.expressions import PsSymbolExpr
+from pystencils.backend.ast.iteration import dfs_preorder
+
+
+def test_clone_entire_ast():
+    ctx = KernelCreationContext()
+    factory = AstFactory(ctx)
+    canon_clone = CanonicalClone(ctx)
+
+    f = Field.create_generic("f", 2, index_shape=(5,))
+    rho = sp.Symbol("rho")
+    u = sp.symbols("u_:2")
+
+    cx = TypedSymbol("cx", Arr(ctx.default_dtype))
+    cy = TypedSymbol("cy", Arr(ctx.default_dtype))
+    cxs = sp.IndexedBase(cx, shape=(5,))
+    cys = sp.IndexedBase(cy, shape=(5,))
+
+    rho_out = Field.create_generic("rho", 2, index_shape=(1,))
+    u_out = Field.create_generic("u", 2, index_shape=(2,))
+
+    ispace = FullIterationSpace.create_from_slice(ctx, make_slice[:, :], f)
+    ctx.set_iteration_space(ispace)
+
+    asms = [
+        Assignment(cx, (0, 1, -1, 0, 0)),
+        Assignment(cy, (0, 0, 0, 1, -1)),
+        Assignment(rho, sum(f.center(i) for i in range(5))),
+        Assignment(u[0], 1 / rho * sum((f.center(i) * cxs[i]) for i in range(5))),
+        Assignment(u[1], 1 / rho * sum((f.center(i) * cys[i]) for i in range(5))),
+        Assignment(rho_out.center(0), rho),
+        Assignment(u_out.center(0), u[0]),
+        Assignment(u_out.center(1), u[1]),
+    ]
+
+    body = PsBlock(
+        [PsComment("Compute and export density and velocity")]
+        + [factory.parse_sympy(asm) for asm in asms]
+    )
+
+    ast = factory.loops_from_ispace(ispace, body)
+    ast_clone = canon_clone(ast)
+
+    for orig, clone in zip(dfs_preorder(ast), dfs_preorder(ast_clone), strict=True):
+        assert type(orig) is type(clone)
+        assert orig is not clone
+
+        if isinstance(orig, PsSymbolExpr):
+            assert isinstance(clone, PsSymbolExpr)
+
+            if orig.symbol.name in ("ctr_0", "ctr_1", "rho", "u_0", "u_1", "cx", "cy"):
+                assert clone.symbol.name == orig.symbol.name + "__0"
diff --git a/tests/nbackend/transformations/test_reshape_loops.py b/tests/nbackend/transformations/test_reshape_loops.py
new file mode 100644
index 0000000000000000000000000000000000000000..e68cff1b64acbb4f9bbf30dee9ef3f2abe9e59d3
--- /dev/null
+++ b/tests/nbackend/transformations/test_reshape_loops.py
@@ -0,0 +1,101 @@
+import sympy as sp
+
+from pystencils import Field, Assignment, make_slice
+from pystencils.backend.kernelcreation import (
+    KernelCreationContext,
+    AstFactory,
+    FullIterationSpace,
+)
+from pystencils.backend.transformations import ReshapeLoops
+
+from pystencils.backend.ast.structural import PsDeclaration, PsBlock, PsLoop, PsConditional
+from pystencils.backend.ast.expressions import PsConstantExpr, PsLt
+
+
+def test_loop_cutting():
+    ctx = KernelCreationContext()
+    factory = AstFactory(ctx)
+    reshape = ReshapeLoops(ctx)
+
+    x, y, z = sp.symbols("x, y, z")
+
+    f = Field.create_generic("f", 1, index_shape=(2,))
+    ispace = FullIterationSpace.create_from_slice(ctx, make_slice[:], archetype_field=f)
+    ctx.set_iteration_space(ispace)
+
+    loop_body = PsBlock(
+        [
+            factory.parse_sympy(Assignment(x, 2 * z)),
+            factory.parse_sympy(Assignment(f.center(0), x + y)),
+        ]
+    )
+
+    loop = factory.loops_from_ispace(ispace, loop_body)
+
+    subloops = reshape.cut_loop(loop, [1, 1, 3])
+    assert len(subloops) == 3
+
+    subloop = subloops[0]
+    assert isinstance(subloop, PsBlock)
+    assert isinstance(subloop.statements[0], PsDeclaration)
+    assert subloop.statements[0].declared_symbol.name == "ctr_0__0"
+
+    x_decl = subloop.statements[1]
+    assert isinstance(x_decl, PsDeclaration)
+    assert x_decl.declared_symbol.name == "x__0"
+    
+    subloop = subloops[1]
+    assert isinstance(subloop, PsLoop)
+    assert isinstance(subloop.start, PsConstantExpr) and subloop.start.constant.value == 1
+    assert isinstance(subloop.stop, PsConstantExpr) and subloop.stop.constant.value == 3
+
+    x_decl = subloop.body.statements[0]
+    assert isinstance(x_decl, PsDeclaration)
+    assert x_decl.declared_symbol.name == "x__1"
+
+    subloop = subloops[2]
+    assert isinstance(subloop, PsLoop)
+    assert isinstance(subloop.start, PsConstantExpr) and subloop.start.constant.value == 3
+    assert subloop.stop.structurally_equal(loop.stop)
+
+
+def test_loop_peeling():
+    ctx = KernelCreationContext()
+    factory = AstFactory(ctx)
+    reshape = ReshapeLoops(ctx)
+
+    x, y, z = sp.symbols("x, y, z")
+
+    f = Field.create_generic("f", 1, index_shape=(2,))
+    ispace = FullIterationSpace.create_from_slice(ctx, make_slice[:], archetype_field=f)
+    ctx.set_iteration_space(ispace)
+
+    loop_body = PsBlock([
+        factory.parse_sympy(Assignment(x, 2 * z)),
+        factory.parse_sympy(Assignment(f.center(0), x + y)),
+    ])
+
+    loop = factory.loops_from_ispace(ispace, loop_body)
+
+    num_iters = 3
+    peeled_iters, peeled_loop = reshape.peel_loop_front(loop, num_iters)
+    assert len(peeled_iters) == 3
+
+    for i, iter in enumerate(peeled_iters):
+        assert isinstance(iter, PsBlock)
+
+        ctr_decl = iter.statements[0]
+        assert isinstance(ctr_decl, PsDeclaration)
+        assert ctr_decl.declared_symbol.name == f"ctr_0__{i}"
+
+        cond = iter.statements[1]
+        assert isinstance(cond, PsConditional)
+        assert cond.condition.structurally_equal(PsLt(ctr_decl.lhs, loop.stop))
+
+        subblock = cond.branch_true
+        assert isinstance(subblock.statements[0], PsDeclaration)
+        assert subblock.statements[0].declared_symbol.name == f"x__{i}"
+
+    assert peeled_loop.start.structurally_equal(factory.parse_index(num_iters))
+    assert peeled_loop.stop.structurally_equal(loop.stop)
+    assert peeled_loop.body.structurally_equal(loop.body)