diff --git a/src/pystencils/backend/ast/astnode.py b/src/pystencils/backend/ast/astnode.py
index 6680c58e5971409a585e0d9787691f646cabe701..3487d4200700ebd1de03e41e374e0a550553a57f 100644
--- a/src/pystencils/backend/ast/astnode.py
+++ b/src/pystencils/backend/ast/astnode.py
@@ -29,6 +29,10 @@ class PsAstNode(ABC):
     def set_child(self, idx: int, c: PsAstNode):
         pass
 
+    @abstractmethod
+    def clone(self) -> PsAstNode:
+        pass
+
     def structurally_equal(self, other: PsAstNode) -> bool:
         """Check two ASTs for structural equality."""
         return (
diff --git a/src/pystencils/backend/ast/expressions.py b/src/pystencils/backend/ast/expressions.py
index 4d4cfb457b82b4dffca341e1f961f1d556d99df1..eab309a1d380eba184a4a0393db7d34b2af53468 100644
--- a/src/pystencils/backend/ast/expressions.py
+++ b/src/pystencils/backend/ast/expressions.py
@@ -1,5 +1,5 @@
 from __future__ import annotations
-from abc import ABC
+from abc import ABC, abstractmethod
 from typing import Sequence, overload
 
 from ..symbols import PsSymbol
@@ -54,10 +54,18 @@ class PsExpression(PsAstNode, ABC):
         else:
             raise ValueError(f"Cannot make expression out of {obj}")
 
+    @abstractmethod
+    def clone(self) -> PsExpression:
+        pass
+
 
 class PsLvalueExpr(PsExpression, ABC):
     """Base class for all expressions that may occur as an lvalue"""
 
+    @abstractmethod
+    def clone(self) -> PsLvalueExpr:
+        pass
+
 
 class PsSymbolExpr(PsLeafMixIn, PsLvalueExpr):
     """A single symbol as an expression."""
@@ -75,6 +83,9 @@ class PsSymbolExpr(PsLeafMixIn, PsLvalueExpr):
     def symbol(self, symbol: PsSymbol):
         self._symbol = symbol
 
+    def clone(self) -> PsSymbolExpr:
+        return PsSymbolExpr(self._symbol)
+
     def structurally_equal(self, other: PsAstNode) -> bool:
         if not isinstance(other, PsSymbolExpr):
             return False
@@ -99,6 +110,9 @@ class PsConstantExpr(PsLeafMixIn, PsExpression):
     def constant(self, c: PsConstant):
         self._constant = c
 
+    def clone(self) -> PsConstantExpr:
+        return PsConstantExpr(self._constant)
+
     def structurally_equal(self, other: PsAstNode) -> bool:
         if not isinstance(other, PsConstantExpr):
             return False
@@ -132,6 +146,9 @@ class PsSubscript(PsLvalueExpr):
     def index(self, expr: PsExpression):
         self._index = expr
 
+    def clone(self) -> PsSubscript:
+        return PsSubscript(self._base.clone(), self._index.clone())
+
     def get_children(self) -> tuple[PsAstNode, ...]:
         return (self._base, self._index)
 
@@ -180,6 +197,9 @@ class PsArrayAccess(PsSubscript):
         """Data type of this expression, i.e. the element type of the underlying array"""
         return self._base_ptr.array.element_type
 
+    def clone(self) -> PsArrayAccess:
+        return PsArrayAccess(self._base_ptr, self._index.clone())
+
     def __repr__(self) -> str:
         return f"ArrayAccess({repr(self._base_ptr)}, {repr(self._index)})"
 
@@ -226,6 +246,15 @@ class PsVectorArrayAccess(PsArrayAccess):
     def alignment(self) -> int:
         return self._alignment
 
+    def clone(self) -> PsVectorArrayAccess:
+        return PsVectorArrayAccess(
+            self._base_ptr,
+            self._index.clone(),
+            self.vector_entries,
+            self._stride,
+            self._alignment,
+        )
+
     def structurally_equal(self, other: PsAstNode) -> bool:
         if not isinstance(other, PsVectorArrayAccess):
             return False
@@ -243,7 +272,7 @@ class PsLookup(PsExpression):
 
     def __init__(self, aggregate: PsExpression, member_name: str) -> None:
         self._aggregate = aggregate
-        self._member = member_name
+        self._member_name = member_name
 
     @property
     def aggregate(self) -> PsExpression:
@@ -255,12 +284,15 @@ class PsLookup(PsExpression):
 
     @property
     def member_name(self) -> str:
-        return self._member
+        return self._member_name
 
     @member_name.setter
     def member_name(self, name: str):
         self._name = name
 
+    def clone(self) -> PsLookup:
+        return PsLookup(self._aggregate.clone(), self._member_name)
+
     def get_children(self) -> tuple[PsAstNode, ...]:
         return (self._aggregate,)
 
@@ -298,6 +330,9 @@ class PsCall(PsExpression):
 
         self._args = list(exprs)
 
+    def clone(self) -> PsCall:
+        return PsCall(self._function, [arg.clone() for arg in self._args])
+
     def get_children(self) -> tuple[PsAstNode, ...]:
         return self.args
 
@@ -324,6 +359,9 @@ class PsUnOp(PsExpression):
     def operand(self, expr: PsExpression):
         self._operand = expr
 
+    def clone(self) -> PsUnOp:
+        return type(self)(self._operand.clone())
+
     def get_children(self) -> tuple[PsAstNode, ...]:
         return (self._operand,)
 
@@ -359,6 +397,9 @@ class PsCast(PsUnOp):
     def target_type(self, dtype: PsType):
         self._target_type = dtype
 
+    def clone(self) -> PsUnOp:
+        return PsCast(self._target_type, self._operand.clone())
+
     def structurally_equal(self, other: PsAstNode) -> bool:
         if not isinstance(other, PsCast):
             return False
@@ -391,6 +432,9 @@ class PsBinOp(PsExpression):
     def operand2(self, expr: PsExpression):
         self._op2 = expr
 
+    def clone(self) -> PsBinOp:
+        return type(self)(self._op1.clone(), self._op2.clone())
+
     def get_children(self) -> tuple[PsAstNode, ...]:
         return (self._op1, self._op2)
 
diff --git a/src/pystencils/backend/ast/structural.py b/src/pystencils/backend/ast/structural.py
index c9da546a829d85436f1efd84a223a5e35b704f0a..2338ee8c4064042827b5017a496755ce985913a8 100644
--- a/src/pystencils/backend/ast/structural.py
+++ b/src/pystencils/backend/ast/structural.py
@@ -20,6 +20,9 @@ class PsBlock(PsAstNode):
     def set_child(self, idx: int, c: PsAstNode):
         self._statements[idx] = c
 
+    def clone(self) -> PsBlock:
+        return PsBlock([stmt.clone() for stmt in self._statements])
+
     @property
     def statements(self) -> list[PsAstNode]:
         return self._statements
@@ -47,6 +50,9 @@ class PsStatement(PsAstNode):
     def expression(self, expr: PsExpression):
         self._expression = expr
 
+    def clone(self) -> PsStatement:
+        return PsStatement(self._expression.clone())
+
     def get_children(self) -> tuple[PsAstNode, ...]:
         return (self._expression,)
 
@@ -82,6 +88,9 @@ class PsAssignment(PsAstNode):
     def rhs(self, expr: PsExpression):
         self._rhs = expr
 
+    def clone(self) -> PsAssignment:
+        return PsAssignment(self._lhs.clone(), self._rhs.clone())
+
     def get_children(self) -> tuple[PsAstNode, ...]:
         return (self._lhs, self._rhs)
 
@@ -123,6 +132,9 @@ class PsDeclaration(PsAssignment):
     def declared_variable(self, lvalue: PsSymbolExpr):
         self._lhs = lvalue
 
+    def clone(self) -> PsDeclaration:
+        return PsDeclaration(cast(PsSymbolExpr, self._lhs.clone()), self.rhs.clone())
+
     def set_child(self, idx: int, c: PsAstNode):
         idx = [0, 1][idx]  # trick to normalize index
         if idx == 0:
@@ -193,6 +205,15 @@ class PsLoop(PsAstNode):
     def body(self, block: PsBlock):
         self._body = block
 
+    def clone(self) -> PsLoop:
+        return PsLoop(
+            self._ctr.clone(),
+            self._start.clone(),
+            self._stop.clone(),
+            self._step.clone(),
+            self._body.clone(),
+        )
+
     def get_children(self) -> tuple[PsAstNode, ...]:
         return (self._ctr, self._start, self._stop, self._step, self._body)
 
@@ -252,6 +273,13 @@ class PsConditional(PsAstNode):
     def branch_false(self, block: PsBlock | None):
         self._branch_false = block
 
+    def clone(self) -> PsConditional:
+        return PsConditional(
+            self._condition.clone(),
+            self._branch_true.clone(),
+            self._branch_false.clone() if self._branch_false is not None else None,
+        )
+
     def get_children(self) -> tuple[PsAstNode, ...]:
         return (self._condition, self._branch_true) + (
             (self._branch_false,) if self._branch_false is not None else ()
@@ -285,6 +313,9 @@ class PsComment(PsLeafMixIn, PsAstNode):
     def lines(self) -> tuple[str, ...]:
         return self._lines
 
+    def clone(self) -> PsComment:
+        return PsComment(self._text)
+
     def structurally_equal(self, other: PsAstNode) -> bool:
         if not isinstance(other, PsComment):
             return False
diff --git a/src/pystencils/backend/kernelcreation/freeze.py b/src/pystencils/backend/kernelcreation/freeze.py
index 8b54da651dd81a325612be6abda57e2d56fa3ed5..63dc9170a0efc1e5bb0b626353d290774c584ac8 100644
--- a/src/pystencils/backend/kernelcreation/freeze.py
+++ b/src/pystencils/backend/kernelcreation/freeze.py
@@ -117,7 +117,9 @@ class FreezeExpressions:
         for summand in expr.args:
             if summand.is_negative:
                 signs.append(-1)
-            elif isinstance(summand, sp.Mul) and any(factor.is_negative for factor in summand.args):
+            elif isinstance(summand, sp.Mul) and any(
+                factor.is_negative for factor in summand.args
+            ):
                 signs.append(-1)
             else:
                 signs.append(1)
@@ -126,18 +128,18 @@ class FreezeExpressions:
 
         for sign, arg in zip(signs[1:], expr.args[1:]):
             if sign == -1:
-                arg = - arg
+                arg = -arg
                 op = sub
             else:
                 op = add
 
             frozen_expr = op(frozen_expr, self.visit_expr(arg))
-        
+
         return frozen_expr
 
     def map_Mul(self, expr: sp.Mul) -> PsExpression:
         return reduce(mul, (self.visit_expr(arg) for arg in expr.args))
-    
+
     def map_Pow(self, expr: sp.Pow) -> PsExpression:
         base = expr.args[0]
         exponent = expr.args[1]
@@ -147,18 +149,29 @@ class FreezeExpressions:
         expand_product = False
 
         if exponent.is_Integer:
+            if exponent == 0:
+                return PsExpression.make(PsConstant(1))
+
             if exponent.is_negative:
                 reciprocal = True
-                exponent = - exponent
+                exponent = -exponent
 
-            if exponent <= sp.Integer(5):
+            if exponent <= sp.Integer(
+                5
+            ):  # TODO: is this a sensible limit? maybe make this configurable.
                 expand_product = True
 
         if expand_product:
-            frozen_expr = reduce(mul, [base_frozen] * int(exponent))
+            frozen_expr = reduce(
+                mul,
+                [base_frozen]
+                + [base_frozen.clone() for _ in range(0, int(exponent) - 1)],
+            )
         else:
             exponent_frozen = self.visit_expr(exponent)
-            frozen_expr = PsMathFunction(MathFunctions.Pow)(base_frozen, exponent_frozen)
+            frozen_expr = PsMathFunction(MathFunctions.Pow)(
+                base_frozen, exponent_frozen
+            )
 
         if reciprocal:
             one = PsExpression.make(PsConstant(1))
diff --git a/src/pystencils/backend/kernelfunction.py b/src/pystencils/backend/kernelfunction.py
index 20d63e85ef802bbe5d5d5a25441fa326b533db77..c38ac60b44dc76007b183ac4e4ce5cfb61554c8b 100644
--- a/src/pystencils/backend/kernelfunction.py
+++ b/src/pystencils/backend/kernelfunction.py
@@ -42,10 +42,10 @@ class KernelParameter:
             type(self) is type(other)
             and self._hashable_contents() == other._hashable_contents()
         )
-    
+
     def __str__(self) -> str:
         return self._name
-    
+
     def __repr__(self) -> str:
         return f"{type(self).__name__}(name = {self._name}, dtype = {self._dtype})"
 
@@ -60,7 +60,7 @@ class FieldParameter(KernelParameter, ABC):
     @property
     def field(self):
         return self._field
-    
+
     def _hashable_contents(self):
         return super()._hashable_contents() + (self._field,)
 
@@ -75,7 +75,7 @@ class FieldShapeParam(FieldParameter):
     @property
     def coordinate(self):
         return self._coordinate
-    
+
     def _hashable_contents(self):
         return super()._hashable_contents() + (self._coordinate,)
 
@@ -90,7 +90,7 @@ class FieldStrideParam(FieldParameter):
     @property
     def coordinate(self):
         return self._coordinate
-    
+
     def _hashable_contents(self):
         return super()._hashable_contents() + (self._coordinate,)
 
diff --git a/src/pystencils/sympyextensions/astnodes.py b/src/pystencils/sympyextensions/astnodes.py
index 4fdc0f612019cede0096a1e9c7afca5256c6ca5d..8483977d8a83067757a7677dc88c1fb2db3e003d 100644
--- a/src/pystencils/sympyextensions/astnodes.py
+++ b/src/pystencils/sympyextensions/astnodes.py
@@ -4,7 +4,7 @@ import uuid
 from typing import Any, Dict, Iterable, Iterator, List, Optional, Sequence, Set, Union
 
 import sympy as sp
-from sympy.codegen.ast import Assignment, AugmentedAssignment
+from sympy.codegen.ast import Assignment, AugmentedAssignment, AddAugmentedAssignment
 from sympy.printing.latex import LatexPrinter
 import numpy as np
 
diff --git a/tests/nbackend/test_ast_nodes.py b/tests/nbackend/test_ast_nodes.py
new file mode 100644
index 0000000000000000000000000000000000000000..fb2dd0e04081050138be1c8efc0210f81c999707
--- /dev/null
+++ b/tests/nbackend/test_ast_nodes.py
@@ -0,0 +1,66 @@
+from pystencils.backend.symbols import PsSymbol
+from pystencils.backend.constants import PsConstant
+from pystencils.backend.ast.expressions import (
+    PsExpression,
+    PsCast,
+    PsDeref,
+    PsSubscript,
+)
+from pystencils.backend.ast.structural import (
+    PsStatement,
+    PsAssignment,
+    PsBlock,
+    PsConditional,
+    PsComment,
+    PsLoop,
+)
+from pystencils.types.quick import Fp, Ptr
+
+
+def test_cloning():
+    x, y, z = [PsExpression.make(PsSymbol(name)) for name in "xyz"]
+    c1 = PsExpression.make(PsConstant(3.0))
+    c2 = PsExpression.make(PsConstant(-1.0))
+    one = PsExpression.make(PsConstant(1))
+
+    def check(orig, clone):
+        assert not (orig is clone)
+        assert type(orig) is type(clone)
+        assert orig.structurally_equal(clone)
+
+        for c1, c2 in zip(orig.children, clone.children, strict=True):
+            check(c1, c2)
+
+    for ast in [
+        x,
+        y,
+        c1,
+        x + y,
+        x / y + c1,
+        c1 + c2,
+        PsStatement(x * y * z + c1),
+        PsAssignment(y, x / c1),
+        PsBlock([PsAssignment(x, c1 * y), PsAssignment(z, c2 + c1 * z)]),
+        PsConditional(
+            y, PsBlock([PsStatement(x + y)]), PsBlock([PsComment("hello world")])
+        ),
+        PsLoop(
+            x,
+            y,
+            z,
+            one,
+            PsBlock(
+                [
+                    PsComment("Loop body"),
+                    PsAssignment(x, y),
+                    PsAssignment(x, y),
+                    PsStatement(
+                        PsDeref(PsCast(Ptr(Fp(32)), z))
+                        + PsSubscript(z, one + one + one)
+                    ),
+                ]
+            ),
+        ),
+    ]:
+        ast_clone = ast.clone()
+        check(ast, ast_clone)