diff --git a/src/pystencils/backend/ast/analysis.py b/src/pystencils/backend/ast/analysis.py
index 15ee0680edb0b5b8197aec2545d182f90eb6c71a..5b37470cc89dc8b304d9fd033ea66ec67b9f8d8a 100644
--- a/src/pystencils/backend/ast/analysis.py
+++ b/src/pystencils/backend/ast/analysis.py
@@ -1,5 +1,7 @@
+from dataclasses import dataclass
 from typing import cast
 from functools import reduce
+import operator
 
 from .structural import (
     PsAssignment,
@@ -12,11 +14,28 @@ from .structural import (
     PsLoop,
     PsStatement,
 )
-from .expressions import PsSymbolExpr, PsConstantExpr
+from .expressions import (
+    PsAdd,
+    PsArrayAccess,
+    PsCall,
+    PsConstantExpr,
+    PsDiv,
+    PsIntDiv,
+    PsLiteralExpr,
+    PsMul,
+    PsNeg,
+    PsRem,
+    PsSub,
+    PsSymbolExpr,
+    PsTernary,
+)
 
 from ..symbols import PsSymbol
 from ..exceptions import PsInternalCompilerError
 
+from ...types import PsNumericType
+from ...types.exception import PsTypeError
+
 
 class UndefinedSymbolsCollector:
     """Collect undefined symbols.
@@ -120,3 +139,212 @@ def collect_required_headers(node: PsAstNode) -> set[str]:
             return reduce(
                 set.union, (collect_required_headers(c) for c in node.children), set()
             )
+
+
+@dataclass
+class OperationCounts:
+    float_adds: int = 0
+    float_muls: int = 0
+    float_divs: int = 0
+    int_adds: int = 0
+    int_muls: int = 0
+    int_divs: int = 0
+    calls: int = 0
+    branches: int = 0
+    loops_with_dynamic_bounds: int = 0
+
+    def __add__(self, other):
+        if not isinstance(other, OperationCounts):
+            return NotImplemented
+
+        return OperationCounts(
+            float_adds=self.float_adds + other.float_adds,
+            float_muls=self.float_muls + other.float_muls,
+            float_divs=self.float_divs + other.float_divs,
+            int_adds=self.int_adds + other.int_adds,
+            int_muls=self.int_muls + other.int_muls,
+            int_divs=self.int_divs + other.int_divs,
+            calls=self.calls + other.calls,
+            branches=self.branches + other.branches,
+            loops_with_dynamic_bounds=self.loops_with_dynamic_bounds
+            + other.loops_with_dynamic_bounds,
+        )
+
+    def __rmul__(self, other):
+        if not isinstance(other, int):
+            return NotImplemented
+
+        return OperationCounts(
+            float_adds=other * self.float_adds,
+            float_muls=other * self.float_muls,
+            float_divs=other * self.float_divs,
+            int_adds=other * self.int_adds,
+            int_muls=other * self.int_muls,
+            int_divs=other * self.int_divs,
+            calls=other * self.calls,
+            branches=other * self.branches,
+            loops_with_dynamic_bounds=other * self.loops_with_dynamic_bounds,
+        )
+
+
+class OperationCounter:
+    """Counts the number of operations in an AST.
+
+    Assumes that the AST is typed. It is recommended that constant folding is
+    applied prior to this pass.
+
+    The counted operations are:
+      - Additions, multiplications and divisions of floating and integer type.
+        The counts of either type are reported separately and operations on
+        other types are ignored.
+      - Function calls.
+      - Branches.
+        Includes `PsConditional` and `PsTernary`. The operations in all branches
+        are summed up (i.e. the result is an overestimation).
+      - Loops with an unknown number of iterations.
+        The operations in the loop header and body are counted exactly once,
+        i.e. it is assumed that there is one loop iteration.
+
+    If the start, stop and step of the loop are `PsConstantExpr`, then any
+    operation within the body is multiplied by the number of iterations.
+    """
+
+    def __call__(self, node: PsAstNode) -> OperationCounts:
+        """Counts the number of operations in the given AST."""
+        return self.visit(node)
+
+    def visit(self, node: PsAstNode) -> OperationCounts:
+        match node:
+            case PsExpression():
+                return self.visit_expr(node)
+
+            case PsStatement(expr):
+                return self.visit_expr(expr)
+
+            case PsAssignment(lhs, rhs):
+                return self.visit_expr(lhs) + self.visit_expr(rhs)
+
+            case PsBlock(statements):
+                return reduce(
+                    operator.add, (self.visit(s) for s in statements), OperationCounts()
+                )
+
+            case PsLoop(_, start, stop, step, body):
+                if (
+                    isinstance(start, PsConstantExpr)
+                    and isinstance(stop, PsConstantExpr)
+                    and isinstance(step, PsConstantExpr)
+                ):
+                    val_start = start.constant.value
+                    val_stop = stop.constant.value
+                    val_step = step.constant.value
+
+                    if (val_stop - val_start) % val_step == 0:
+                        iteration_count = max(0, int((val_stop - val_start) / val_step))
+                    else:
+                        iteration_count = max(
+                            0, int((val_stop - val_start) / val_step) + 1
+                        )
+
+                    return self.visit_expr(start) + iteration_count * (
+                        OperationCounts(int_adds=1)  # loop counter increment
+                        + self.visit_expr(stop)
+                        + self.visit_expr(step)
+                        + self.visit(body)
+                    )
+                else:
+                    return (
+                        OperationCounts(loops_with_dynamic_bounds=1)
+                        + self.visit_expr(start)
+                        + self.visit_expr(stop)
+                        + self.visit_expr(step)
+                        + self.visit(body)
+                    )
+
+            case PsConditional(cond, branch_true, branch_false):
+                op_counts = (
+                    OperationCounts(branches=1)
+                    + self.visit(cond)
+                    + self.visit(branch_true)
+                )
+                if branch_false is not None:
+                    op_counts += self.visit(branch_false)
+                return op_counts
+
+            case PsEmptyLeafMixIn():
+                return OperationCounts()
+
+            case unknown:
+                raise PsInternalCompilerError(f"Can't count operations in {unknown}")
+
+    def visit_expr(self, expr: PsExpression) -> OperationCounts:
+        match expr:
+            case PsSymbolExpr(_) | PsConstantExpr(_) | PsLiteralExpr(_):
+                return OperationCounts()
+
+            case PsArrayAccess(_, index):
+                return self.visit_expr(index)
+
+            case PsCall(_, args):
+                return OperationCounts(calls=1) + reduce(
+                    operator.add, (self.visit(a) for a in args), OperationCounts()
+                )
+
+            case PsTernary(cond, then, els):
+                return (
+                    OperationCounts(branches=1)
+                    + self.visit_expr(cond)
+                    + self.visit_expr(then)
+                    + self.visit_expr(els)
+                )
+
+            case PsNeg(arg):
+                if expr.dtype is None:
+                    raise PsTypeError(f"Untyped arithmetic expression: {expr}")
+
+                op_counts = self.visit_expr(arg)
+                if isinstance(expr.dtype, PsNumericType) and expr.dtype.is_float():
+                    op_counts.float_muls += 1
+                elif isinstance(expr.dtype, PsNumericType) and expr.dtype.is_int():
+                    op_counts.int_muls += 1
+                return op_counts
+
+            case PsAdd(arg1, arg2) | PsSub(arg1, arg2):
+                if expr.dtype is None:
+                    raise PsTypeError(f"Untyped arithmetic expression: {expr}")
+
+                op_counts = self.visit_expr(arg1) + self.visit_expr(arg2)
+                if isinstance(expr.dtype, PsNumericType) and expr.dtype.is_float():
+                    op_counts.float_adds += 1
+                elif isinstance(expr.dtype, PsNumericType) and expr.dtype.is_int():
+                    op_counts.int_adds += 1
+                return op_counts
+
+            case PsMul(arg1, arg2):
+                if expr.dtype is None:
+                    raise PsTypeError(f"Untyped arithmetic expression: {expr}")
+
+                op_counts = self.visit_expr(arg1) + self.visit_expr(arg2)
+                if isinstance(expr.dtype, PsNumericType) and expr.dtype.is_float():
+                    op_counts.float_muls += 1
+                elif isinstance(expr.dtype, PsNumericType) and expr.dtype.is_int():
+                    op_counts.int_muls += 1
+                return op_counts
+
+            case PsDiv(arg1, arg2) | PsIntDiv(arg1, arg2) | PsRem(arg1, arg2):
+                if expr.dtype is None:
+                    raise PsTypeError(f"Untyped arithmetic expression: {expr}")
+
+                op_counts = self.visit_expr(arg1) + self.visit_expr(arg2)
+                if isinstance(expr.dtype, PsNumericType) and expr.dtype.is_float():
+                    op_counts.float_divs += 1
+                elif isinstance(expr.dtype, PsNumericType) and expr.dtype.is_int():
+                    op_counts.int_divs += 1
+                return op_counts
+
+            case _:
+                return reduce(
+                    operator.add,
+                    (self.visit_expr(cast(PsExpression, c)) for c in expr.children),
+                    OperationCounts(),
+                )
diff --git a/src/pystencils/backend/ast/expressions.py b/src/pystencils/backend/ast/expressions.py
index 908f31052ed4b8d49a66cd1ce801d9841ef4fb7d..3b76e514e566c9c7113c1bbcaf1b63df71a8cb57 100644
--- a/src/pystencils/backend/ast/expressions.py
+++ b/src/pystencils/backend/ast/expressions.py
@@ -226,7 +226,7 @@ class PsSubscript(PsLvalue, PsExpression):
             case 1:
                 self.index = failing_cast(PsExpression, c)
 
-    def __str__(self) -> str:
+    def __repr__(self) -> str:
         return f"Subscript({self._base})[{self._index}]"
 
 
@@ -419,7 +419,7 @@ class PsCall(PsExpression):
         if not isinstance(other, PsCall):
             return False
         return super().structurally_equal(other) and self._function == other._function
-    
+
     def __str__(self):
         args = ", ".join(str(arg) for arg in self._args)
         return f"PsCall({self._function}, ({args}))"
diff --git a/tests/nbackend/test_operation_counter.py b/tests/nbackend/test_operation_counter.py
new file mode 100644
index 0000000000000000000000000000000000000000..98b004c4dbc489adc9c0aa86b0d8baccc92bc60c
--- /dev/null
+++ b/tests/nbackend/test_operation_counter.py
@@ -0,0 +1,61 @@
+from pystencils.backend.ast.analysis import OperationCounter
+from pystencils.backend.ast.expressions import (
+    PsAdd,
+    PsConstant,
+    PsDiv,
+    PsExpression,
+    PsMul,
+    PsTernary,
+)
+from pystencils.backend.ast.structural import (
+    PsBlock,
+    PsDeclaration,
+    PsLoop,
+)
+
+from pystencils.backend.kernelcreation import KernelCreationContext, Typifier
+from pystencils.types import PsBoolType
+
+
+def test_count_operations():
+    ctx = KernelCreationContext()
+    typify = Typifier(ctx)
+    counter = OperationCounter()
+
+    x = PsExpression.make(ctx.get_symbol("x"))
+    y = PsExpression.make(ctx.get_symbol("y"))
+    z = PsExpression.make(ctx.get_symbol("z"))
+
+    i = PsExpression.make(ctx.get_symbol("i", ctx.index_dtype))
+    p = PsExpression.make(ctx.get_symbol("p", PsBoolType()))
+
+    zero = PsExpression.make(PsConstant(0, ctx.index_dtype))
+    two = PsExpression.make(PsConstant(2, ctx.index_dtype))
+    five = PsExpression.make(PsConstant(5, ctx.index_dtype))
+
+    ast = PsLoop(
+        i,
+        zero,
+        five,
+        two,
+        PsBlock(
+            [
+                PsDeclaration(x, PsAdd(y, z)),
+                PsDeclaration(y, PsMul(x, PsMul(y, z))),
+                PsDeclaration(z, PsDiv(PsDiv(PsDiv(x, y), z), PsTernary(p, x, y))),
+            ]
+        ),
+    )
+    ast = typify(ast)
+
+    op_count = counter(ast)
+
+    assert op_count.float_adds == 3 * 1
+    assert op_count.float_muls == 3 * 2
+    assert op_count.float_divs == 3 * 3
+    assert op_count.int_adds == 3 * 1
+    assert op_count.int_muls == 0
+    assert op_count.int_divs == 0
+    assert op_count.calls == 0
+    assert op_count.branches == 3 * 1
+    assert op_count.loops_with_dynamic_bounds == 0