From 04fffe916365a3746c955def164a250ef54d6eb0 Mon Sep 17 00:00:00 2001
From: Frederik Hennig <frederik.hennig@fau.de>
Date: Tue, 27 Feb 2024 16:30:46 +0100
Subject: [PATCH] Removal of pymbolic; Introduction of native expression tree.

---
 conftest.py                                   |   5 +-
 doc/sphinx/nbackend/rationale.rst             |   5 +-
 mypy.ini                                      |  11 +-
 pyproject.toml                                |   2 +-
 src/pystencils/__init__.py                    |  69 ++-
 src/pystencils/backend/arrays.py              | 122 +----
 src/pystencils/backend/ast/__init__.py        |  28 +-
 src/pystencils/backend/ast/analysis.py        | 105 +++++
 src/pystencils/backend/ast/astnode.py         |  55 +++
 src/pystencils/backend/ast/collectors.py      | 134 ------
 src/pystencils/backend/ast/dispatcher.py      |  41 --
 src/pystencils/backend/ast/expressions.py     | 423 ++++++++++++++++++
 .../ast/{tree_iteration.py => iteration.py}   |   2 +-
 src/pystencils/backend/ast/kernelfunction.py  |  40 +-
 .../backend/ast/{nodes.py => structural.py}   | 122 +----
 src/pystencils/backend/ast/transformations.py |  58 ---
 src/pystencils/backend/constants.py           |  53 +++
 src/pystencils/backend/constraints.py         |  19 +-
 src/pystencils/backend/emission.py            | 387 +++++++++++-----
 src/pystencils/backend/functions.py           |  52 +--
 .../backend/jit/cpu_extension_module.py       |  47 +-
 src/pystencils/backend/jit/legacy_cpu.py      |   1 +
 .../backend/kernelcreation/__init__.py        |   5 +-
 .../backend/kernelcreation/context.py         |  54 ++-
 .../backend/kernelcreation/freeze.py          | 137 ++++--
 .../backend/kernelcreation/iteration_space.py |  80 ++--
 .../backend/kernelcreation/typification.py    | 270 ++++++-----
 .../backend/platforms/generic_cpu.py          |  46 +-
 src/pystencils/backend/platforms/platform.py  |   2 +-
 src/pystencils/backend/platforms/x86.py       |  21 +-
 src/pystencils/backend/symbols.py             |  53 +++
 .../erase_anonymous_structs.py                |  74 +--
 .../transformations/vector_intrinsics.py      | 149 +++---
 src/pystencils/backend/typed_expressions.py   | 244 ----------
 src/pystencils/config.py                      |   4 +-
 .../{backend/kernelcreation => }/defaults.py  |  27 +-
 src/pystencils/kernelcreation.py              |   2 +-
 .../kernelcreation/platform/test_basic_cpu.py |  12 +-
 tests/nbackend/kernelcreation/test_freeze.py  |  41 +-
 .../kernelcreation/test_iteration_space.py    |  69 +--
 .../kernelcreation/test_typification.py       |  47 +-
 tests/nbackend/test_basic_printing.py         |  43 --
 tests/nbackend/test_code_printing.py          |  77 ++++
 tests/nbackend/test_constant_folding.py       |  79 +---
 tests/nbackend/test_cpujit.py                 |  49 +-
 tests/nbackend/test_expressions.py            |  51 ---
 tests/nbackend/types/test_constants.py        | 111 ++---
 47 files changed, 1884 insertions(+), 1644 deletions(-)
 create mode 100644 src/pystencils/backend/ast/analysis.py
 create mode 100644 src/pystencils/backend/ast/astnode.py
 delete mode 100644 src/pystencils/backend/ast/collectors.py
 delete mode 100644 src/pystencils/backend/ast/dispatcher.py
 create mode 100644 src/pystencils/backend/ast/expressions.py
 rename src/pystencils/backend/ast/{tree_iteration.py => iteration.py} (96%)
 rename src/pystencils/backend/ast/{nodes.py => structural.py} (69%)
 delete mode 100644 src/pystencils/backend/ast/transformations.py
 create mode 100644 src/pystencils/backend/constants.py
 create mode 100644 src/pystencils/backend/symbols.py
 delete mode 100644 src/pystencils/backend/typed_expressions.py
 rename src/pystencils/{backend/kernelcreation => }/defaults.py (55%)
 delete mode 100644 tests/nbackend/test_basic_printing.py
 create mode 100644 tests/nbackend/test_code_printing.py
 delete mode 100644 tests/nbackend/test_expressions.py

diff --git a/conftest.py b/conftest.py
index ca7c153b4..ef075c534 100644
--- a/conftest.py
+++ b/conftest.py
@@ -3,6 +3,7 @@ import runpy
 import sys
 import tempfile
 import warnings
+import pathlib
 
 import nbformat
 import pytest
@@ -137,7 +138,7 @@ class IPyNbFile(pytest.File):
         exporter.exclude_markdown = True
         exporter.exclude_input_prompt = True
 
-        notebook_contents = self.fspath.open(encoding='utf-8')
+        notebook_contents = self.path.open(encoding='utf-8')
 
         with warnings.catch_warnings():
             warnings.filterwarnings("ignore", "IPython.core.inputsplitter is deprecated")
@@ -156,6 +157,6 @@ def pytest_collect_file(path, parent):
     glob_exprs = ["*demo*.ipynb", "*tutorial*.ipynb", "test_*.ipynb"]
     if any(path.fnmatch(g) for g in glob_exprs):
         if pytest_version >= 50403:
-            return IPyNbFile.from_parent(fspath=path, parent=parent)
+            return IPyNbFile.from_parent(path=pathlib.Path(path), parent=parent)
         else:
             return IPyNbFile(path, parent)
diff --git a/doc/sphinx/nbackend/rationale.rst b/doc/sphinx/nbackend/rationale.rst
index 39c4e0d1f..a69b73cd5 100644
--- a/doc/sphinx/nbackend/rationale.rst
+++ b/doc/sphinx/nbackend/rationale.rst
@@ -27,9 +27,8 @@ The primary problems caused by using SymPy for expression manipulation are these
    and parenthesize operations for numerical or performance benefits. Another often-observed effect is that
    SymPy distributes constant factors across sums, strongly increasing the number of FLOPs.
 
-To avoid these problems, ``nbackend`` uses the [pymbolic](https://pypi.org/project/pymbolic/) package for expression
-manipulation. Pymblic has similar capabilities for writing mathematic expressions as SymPy, however its expression
-trees are much simpler, completely static, and easier to extend.
+To avoid these problems, ``nbackend`` no longer uses SymPy for expression manipulation, but contains a native
+AST data structure for modelling expressions as in C code.
 
 Structure and Architecture of the Code Generator
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
diff --git a/mypy.ini b/mypy.ini
index 96d343aae..75e6a5646 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -2,11 +2,14 @@
 python_version = 3.10
 exclude = "src/pystencils/old"
 
-[mypy-pymbolic.*]
-ignore_missing_imports=true
-
 [mypy-pystencils.*]
 ignore_errors=true
 
-[mypy-pystencils.nbackend.*]
+[mypy-setuptools.*]
+ignore_missing_imports=true
+
+[mypy-appdirs.*]
+ignore_missing_imports=true
+
+[mypy-pystencils.backend.*]
 ignore_errors = False
diff --git a/pyproject.toml b/pyproject.toml
index 87f425ca8..0ccdca068 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -12,7 +12,7 @@ authors = [
 ]
 license = { file = "COPYING.txt" }
 requires-python = ">=3.10"
-dependencies = ["sympy>=1.6,<=1.11.1", "numpy>=1.8.0", "appdirs", "joblib", "pyyaml", "pymbolic>=2022.2"]
+dependencies = ["sympy>=1.6,<=1.11.1", "numpy>=1.8.0", "appdirs", "joblib", "pyyaml"]
 classifiers = [
     "Development Status :: 4 - Beta",
     "Framework :: Jupyter",
diff --git a/src/pystencils/__init__.py b/src/pystencils/__init__.py
index a01bad607..59d6a32e5 100644
--- a/src/pystencils/__init__.py
+++ b/src/pystencils/__init__.py
@@ -1,5 +1,7 @@
 """Module to generate stencil kernels in C or CUDA using sympy expressions and call them as Python functions"""
+
 from .enums import Backend, Target
+from .defaults import DEFAULTS
 from . import fd
 from . import stencil as stencil
 from .display_utils import get_code_obj, get_code_str, show_code, to_dot
@@ -9,34 +11,59 @@ from .config import CreateKernelConfig
 from .kernel_decorator import kernel, kernel_config
 from .kernelcreation import create_kernel
 from .slicing import make_slice
-from .spatial_coordinates import x_, x_staggered, x_staggered_vector, x_vector, y_, y_staggered, z_, z_staggered
+from .spatial_coordinates import (
+    x_,
+    x_staggered,
+    x_staggered_vector,
+    x_vector,
+    y_,
+    y_staggered,
+    z_,
+    z_staggered,
+)
 from .sympyextensions import Assignment, AssignmentCollection, AddAugmentedAssignment
 from .sympyextensions.astnodes import assignment_from_stencil
 from .sympyextensions.typed_sympy import TypedSymbol
 from .sympyextensions.math import SymbolCreator
 from .datahandling import create_data_handling
 
-__all__ = ['Field', 'FieldType', 'fields',
-           'TypedSymbol',
-           'make_slice',
-           'CreateKernelConfig',
-           'create_kernel',
-           'Target', 'Backend',
-           'show_code', 'to_dot', 'get_code_obj', 'get_code_str',
-           'AssignmentCollection',
-           'Assignment', 'AddAugmentedAssignment',
-           'assignment_from_stencil',
-           'SymbolCreator',
-           'create_data_handling',
-           'clear_cache',
-           'kernel', 'kernel_config',
-           'x_', 'y_', 'z_',
-           'x_staggered', 'y_staggered', 'z_staggered',
-           'x_vector', 'x_staggered_vector',
-           'fd',
-           'stencil']
+__all__ = [
+    "Field",
+    "FieldType",
+    "fields",
+    "DEFAULTS",
+    "TypedSymbol",
+    "make_slice",
+    "CreateKernelConfig",
+    "create_kernel",
+    "Target",
+    "Backend",
+    "show_code",
+    "to_dot",
+    "get_code_obj",
+    "get_code_str",
+    "AssignmentCollection",
+    "Assignment",
+    "AddAugmentedAssignment",
+    "assignment_from_stencil",
+    "SymbolCreator",
+    "create_data_handling",
+    "clear_cache",
+    "kernel",
+    "kernel_config",
+    "x_",
+    "y_",
+    "z_",
+    "x_staggered",
+    "y_staggered",
+    "z_staggered",
+    "x_vector",
+    "x_staggered_vector",
+    "fd",
+    "stencil",
+]
 
 from ._version import get_versions
 
-__version__ = get_versions()['version']
+__version__ = get_versions()["version"]
 del get_versions
diff --git a/src/pystencils/backend/arrays.py b/src/pystencils/backend/arrays.py
index 4cc3ad037..8f36bc76d 100644
--- a/src/pystencils/backend/arrays.py
+++ b/src/pystencils/backend/arrays.py
@@ -36,27 +36,22 @@ all occurences of the shape and stride variables with their constant value::
 """
 
 from __future__ import annotations
-from sys import intern
 
 from typing import Sequence
 from types import EllipsisType
 
 from abc import ABC
 
-import pymbolic.primitives as pb
-
+from .constants import PsConstant
 from .types import (
     PsAbstractType,
     PsPointerType,
     PsIntegerType,
     PsUnsignedIntegerType,
     PsSignedIntegerType,
-    PsScalarType,
-    PsVectorType,
-    PsTypeError,
 )
 
-from .typed_expressions import PsTypedVariable, ExprOrConstant, PsTypedConstant
+from .symbols import PsSymbol
 
 
 class PsLinearizedArray:
@@ -91,20 +86,20 @@ class PsLinearizedArray:
         if len(shape) != len(strides):
             raise ValueError("Shape and stride tuples must have the same length")
 
-        self._shape: tuple[PsArrayShapeVar | PsTypedConstant, ...] = tuple(
+        self._shape: tuple[PsArrayShapeVar | PsConstant, ...] = tuple(
             (
                 PsArrayShapeVar(self, i, index_dtype)
                 if s == Ellipsis
-                else PsTypedConstant(s, index_dtype)
+                else PsConstant(s, index_dtype)
             )
             for i, s in enumerate(shape)
         )
 
-        self._strides: tuple[PsArrayStrideVar | PsTypedConstant, ...] = tuple(
+        self._strides: tuple[PsArrayStrideVar | PsConstant, ...] = tuple(
             (
                 PsArrayStrideVar(self, i, index_dtype)
                 if s == Ellipsis
-                else PsTypedConstant(s, index_dtype)
+                else PsConstant(s, index_dtype)
             )
             for i, s in enumerate(strides)
         )
@@ -122,7 +117,7 @@ class PsLinearizedArray:
         return self._base_ptr
 
     @property
-    def shape(self) -> tuple[PsArrayShapeVar | PsTypedConstant, ...]:
+    def shape(self) -> tuple[PsArrayShapeVar | PsConstant, ...]:
         """The array's shape, expressed using `PsTypedConstant` and `PsArrayShapeVar`"""
         return self._shape
 
@@ -130,11 +125,11 @@ class PsLinearizedArray:
     def shape_spec(self) -> tuple[EllipsisType | int, ...]:
         """The array's shape, expressed using `int` and `...`"""
         return tuple(
-            (s.value if isinstance(s, PsTypedConstant) else ...) for s in self._shape
+            (s.value if isinstance(s, PsConstant) else ...) for s in self._shape
         )
 
     @property
-    def strides(self) -> tuple[PsArrayStrideVar | PsTypedConstant, ...]:
+    def strides(self) -> tuple[PsArrayStrideVar | PsConstant, ...]:
         """The array's strides, expressed using `PsTypedConstant` and `PsArrayStrideVar`"""
         return self._strides
 
@@ -142,7 +137,7 @@ class PsLinearizedArray:
     def strides_spec(self) -> tuple[EllipsisType | int, ...]:
         """The array's strides, expressed using `int` and `...`"""
         return tuple(
-            (s.value if isinstance(s, PsTypedConstant) else ...) for s in self._strides
+            (s.value if isinstance(s, PsConstant) else ...) for s in self._strides
         )
 
     @property
@@ -181,7 +176,7 @@ class PsLinearizedArray:
         )
 
 
-class PsArrayAssocVar(PsTypedVariable, ABC):
+class PsArrayAssocSymbol(PsSymbol, ABC):
     """A variable that is associated to an array.
 
     Instances of this class represent pointers and indexing information bound
@@ -203,7 +198,7 @@ class PsArrayAssocVar(PsTypedVariable, ABC):
         return self._array
 
 
-class PsArrayBasePointer(PsArrayAssocVar):
+class PsArrayBasePointer(PsArrayAssocSymbol):
     init_arg_names: tuple[str, ...] = ("name", "array")
     __match_args__ = ("name", "array")
 
@@ -229,7 +224,7 @@ class TypeErasedBasePointer(PsArrayBasePointer):
         self._array = array
 
 
-class PsArrayShapeVar(PsArrayAssocVar):
+class PsArrayShapeVar(PsArrayAssocSymbol):
     """Variable that represents an array's shape in one coordinate.
 
     Do not instantiate this class yourself, but only use its instances
@@ -252,7 +247,7 @@ class PsArrayShapeVar(PsArrayAssocVar):
         return self.array, self.coordinate, self.dtype
 
 
-class PsArrayStrideVar(PsArrayAssocVar):
+class PsArrayStrideVar(PsArrayAssocSymbol):
     """Variable that represents an array's stride in one coordinate.
 
     Do not instantiate this class yourself, but only use its instances
@@ -273,92 +268,3 @@ class PsArrayStrideVar(PsArrayAssocVar):
 
     def __getinitargs__(self):
         return self.array, self.coordinate, self.dtype
-
-
-class PsArrayAccess(pb.Subscript):
-    mapper_method = intern("map_array_access")
-
-    def __init__(self, base_ptr: PsArrayBasePointer, index: ExprOrConstant):
-        super(PsArrayAccess, self).__init__(base_ptr, index)
-        self._base_ptr = base_ptr
-        self._index = index
-
-    @property
-    def base_ptr(self):
-        return self._base_ptr
-
-    @property
-    def array(self) -> PsLinearizedArray:
-        return self._base_ptr.array
-
-    @property
-    def dtype(self) -> PsAbstractType:
-        """Data type of this expression, i.e. the element type of the underlying array"""
-        return self._base_ptr.array.element_type
-
-
-class PsVectorArrayAccess(pb.AlgebraicLeaf):
-    mapper_method = intern("map_vector_array_access")
-
-    init_arg_names = ("base_ptr", "base_index", "vector_entries", "stride", "alignment")
-
-    def __getinitargs__(self):
-        return (
-            self._base_ptr,
-            self._base_index,
-            self._vector_type.vector_entries,
-            self._stride,
-            self._alignment,
-        )
-
-    def __init__(
-        self,
-        base_ptr: PsArrayBasePointer,
-        base_index: ExprOrConstant,
-        vector_entries: int,
-        stride: int = 1,
-        alignment: int = 0,
-    ):
-        element_type = base_ptr.array.element_type
-
-        if not isinstance(element_type, PsScalarType):
-            raise PsTypeError(
-                "Cannot generate vector accesses to arrays with non-scalar elements"
-            )
-
-        self._base_ptr = base_ptr
-        self._base_index = base_index
-        self._vector_type = PsVectorType(
-            element_type, vector_entries, const=element_type.const
-        )
-        self._stride = stride
-        self._alignment = alignment
-
-    @property
-    def base_ptr(self) -> PsArrayBasePointer:
-        return self._base_ptr
-
-    @property
-    def array(self) -> PsLinearizedArray:
-        return self._base_ptr.array
-
-    @property
-    def base_index(self) -> ExprOrConstant:
-        return self._base_index
-
-    @property
-    def vector_entries(self) -> int:
-        return self._vector_type.vector_entries
-
-    @property
-    def dtype(self) -> PsVectorType:
-        """Data type of this expression, i.e. the resulting generic vector type"""
-        return self._vector_type
-
-    @property
-    def stride(self) -> int:
-        return self._stride
-
-    @property
-    def alignment(self) -> int:
-        return self._alignment
diff --git a/src/pystencils/backend/ast/__init__.py b/src/pystencils/backend/ast/__init__.py
index 7c372d0f4..2f25c3943 100644
--- a/src/pystencils/backend/ast/__init__.py
+++ b/src/pystencils/backend/ast/__init__.py
@@ -1,35 +1,9 @@
-from .nodes import (
-    PsAstNode,
-    PsBlock,
-    PsExpression,
-    PsLvalueExpr,
-    PsSymbolExpr,
-    PsStatement,
-    PsAssignment,
-    PsDeclaration,
-    PsLoop,
-    PsConditional,
-    PsComment,
-)
 from .kernelfunction import PsKernelFunction
 
-from .tree_iteration import dfs_preorder, dfs_postorder
-from .dispatcher import ast_visitor
+from .iteration import dfs_preorder, dfs_postorder
 
 __all__ = [
-    "ast_visitor",
     "PsKernelFunction",
-    "PsAstNode",
-    "PsBlock",
-    "PsExpression",
-    "PsLvalueExpr",
-    "PsSymbolExpr",
-    "PsStatement",
-    "PsAssignment",
-    "PsDeclaration",
-    "PsLoop",
-    "PsConditional",
-    "PsComment",
     "dfs_preorder",
     "dfs_postorder",
 ]
diff --git a/src/pystencils/backend/ast/analysis.py b/src/pystencils/backend/ast/analysis.py
new file mode 100644
index 000000000..718a9397e
--- /dev/null
+++ b/src/pystencils/backend/ast/analysis.py
@@ -0,0 +1,105 @@
+from typing import cast
+from functools import reduce
+
+from .kernelfunction import PsKernelFunction
+from .structural import (
+    PsAstNode,
+    PsExpression,
+    PsStatement,
+    PsAssignment,
+    PsDeclaration,
+    PsLoop,
+    PsBlock,
+)
+from .expressions import PsSymbolExpr, PsConstantExpr
+
+from ..symbols import PsSymbol
+from ..exceptions import PsInternalCompilerError
+
+
+class UndefinedSymbolsCollector:
+    """Collector for undefined variables.
+
+    This class implements an AST visitor that collects all `PsTypedVariable`s that have been used
+    in the AST without being defined prior to their usage.
+    """
+
+    def __call__(self, node: PsAstNode) -> set[PsSymbol]:
+        """Returns all `PsTypedVariable`s that occur in the given AST without being defined prior to their usage."""
+        return self.visit(node)
+
+    def visit(self, node: PsAstNode) -> set[PsSymbol]:
+        undefined_vars: set[PsSymbol] = set()
+
+        match node:
+            case PsKernelFunction(block):
+                return self.visit(block)
+
+            case PsExpression():
+                return self.visit_expr(node)
+
+            case PsStatement(expr):
+                return self.visit_expr(expr)
+
+            case PsAssignment(lhs, rhs):
+                undefined_vars = self(lhs) | self(rhs)
+                if isinstance(lhs, PsSymbolExpr):
+                    undefined_vars.remove(lhs.symbol)
+                return undefined_vars
+
+            case PsBlock(statements):
+                for stmt in statements[::-1]:
+                    undefined_vars -= self.declared_variables(stmt)
+                    undefined_vars |= self(stmt)
+
+                return undefined_vars
+
+            case PsLoop(ctr, start, stop, step, body):
+                undefined_vars = self(start) | self(stop) | self(step) | self(body)
+                undefined_vars.discard(ctr.symbol)
+                return undefined_vars
+
+            case unknown:
+                raise PsInternalCompilerError(
+                    f"Don't know how to collect undefined variables from {unknown}"
+                )
+
+    def visit_expr(self, expr: PsExpression) -> set[PsSymbol]:
+        match expr:
+            case PsSymbolExpr(symb):
+                return {symb}
+            case _:
+                return reduce(
+                    set.union, (self.visit_expr(cast(PsExpression, c)) for c in expr.children), set()
+                )
+
+    def declared_variables(self, node: PsAstNode) -> set[PsSymbol]:
+        """Returns the set of variables declared by the given node which are visible in the enclosing scope."""
+
+        match node:
+            case PsDeclaration(lhs, _):
+                return {lhs.symbol}
+
+            case PsStatement() | PsAssignment() | PsExpression() | PsLoop() | PsBlock():
+                return set()
+
+            case unknown:
+                raise PsInternalCompilerError(
+                    f"Don't know how to collect declared variables from {unknown}"
+                )
+
+
+def collect_undefined_variables(node: PsAstNode) -> set[PsSymbol]:
+    return UndefinedSymbolsCollector()(node)
+
+
+def collect_required_headers(node: PsAstNode) -> set[str]:
+    match node:
+        case PsSymbolExpr(symb):
+            return symb.get_dtype().required_headers
+        case PsConstantExpr(cs):
+            return cs.get_dtype().required_headers
+        case _:
+            return reduce(
+                set.union, (collect_required_headers(c) for c in node.children), set()
+            )
diff --git a/src/pystencils/backend/ast/astnode.py b/src/pystencils/backend/ast/astnode.py
new file mode 100644
index 000000000..6680c58e5
--- /dev/null
+++ b/src/pystencils/backend/ast/astnode.py
@@ -0,0 +1,55 @@
+from __future__ import annotations
+from typing import Sequence
+from abc import ABC, abstractmethod
+
+
+class PsAstNode(ABC):
+    """Base class for all nodes in the pystencils AST.
+
+    This base class provides a common interface to inspect and update the AST's branching structure.
+    The two methods `get_children` and `set_child` must be implemented by each subclass.
+    Subclasses are also responsible for doing the necessary type checks if they place restrictions on
+    the types of their children.
+    """
+
+    @property
+    def children(self) -> Sequence[PsAstNode]:
+        return self.get_children()
+
+    @children.setter
+    def children(self, cs: Sequence[PsAstNode]):
+        for i, c in enumerate(cs):
+            self.set_child(i, c)
+
+    @abstractmethod
+    def get_children(self) -> tuple[PsAstNode, ...]:
+        pass
+
+    @abstractmethod
+    def set_child(self, idx: int, c: PsAstNode):
+        pass
+
+    def structurally_equal(self, other: PsAstNode) -> bool:
+        """Check two ASTs for structural equality."""
+        return (
+            (type(self) is type(other))
+            and len(self.children) == len(other.children)
+            and all(
+                c1.structurally_equal(c2)
+                for c1, c2 in zip(self.children, other.children)
+            )
+        )
+
+
+class PsLeafMixIn(ABC):
+    """Mix-in for AST leaves."""
+
+    def get_children(self) -> tuple[PsAstNode, ...]:
+        return ()
+
+    def set_child(self, idx: int, c: PsAstNode):
+        raise IndexError("Child index out of bounds: Leaf nodes have no children.")
+
+    @abstractmethod
+    def structurally_equal(self, other: PsAstNode) -> bool:
+        pass
diff --git a/src/pystencils/backend/ast/collectors.py b/src/pystencils/backend/ast/collectors.py
deleted file mode 100644
index e2488f95b..000000000
--- a/src/pystencils/backend/ast/collectors.py
+++ /dev/null
@@ -1,134 +0,0 @@
-from typing import cast, Any
-
-from functools import reduce
-
-from pymbolic.primitives import Variable
-from pymbolic.mapper import Collector
-from pymbolic.mapper.dependency import DependencyMapper
-
-from .kernelfunction import PsKernelFunction
-from .nodes import (
-    PsAstNode,
-    PsExpression,
-    PsStatement,
-    PsAssignment,
-    PsDeclaration,
-    PsLoop,
-    PsBlock,
-)
-from ..arrays import PsVectorArrayAccess
-from ..typed_expressions import PsTypedVariable, PsTypedConstant
-from ..exceptions import PsMalformedAstException, PsInternalCompilerError
-
-
-class UndefinedVariablesCollector(DependencyMapper):
-    """Collector for undefined variables.
-
-    This class implements an AST visitor that collects all `PsTypedVariable`s that have been used
-    in the AST without being defined prior to their usage.
-    """
-
-    def __init__(self) -> None:
-        super().__init__(
-            include_subscripts=False,
-            include_lookups=False,
-            include_calls=False,
-            include_cses=False,
-        )
-
-    def __call__(self, node: PsAstNode) -> set[PsTypedVariable]:
-        """Returns all `PsTypedVariable`s that occur in the given AST without being defined prior to their usage."""
-
-        undefined_vars: set[PsTypedVariable] = set()
-
-        match node:
-            case PsKernelFunction(block):
-                return self(block)
-
-            case PsExpression(expr):
-                variables: set[Variable] = self.rec(expr)
-
-                for var in variables:
-                    if not isinstance(var, PsTypedVariable):
-                        raise PsMalformedAstException(
-                            f"Non-typed variable {var} encountered"
-                        )
-
-                return cast(set[PsTypedVariable], variables)
-
-            case PsStatement(expr):
-                return self(expr)
-
-            case PsAssignment(lhs, rhs):
-                undefined_vars = self(lhs) | self(rhs)
-                if isinstance(lhs.expression, PsTypedVariable):
-                    undefined_vars.remove(lhs.expression)
-                return undefined_vars
-
-            case PsBlock(statements):
-                for stmt in statements[::-1]:
-                    undefined_vars -= self.declared_variables(stmt)
-                    undefined_vars |= self(stmt)
-
-                return undefined_vars
-
-            case PsLoop(ctr, start, stop, step, body):
-                undefined_vars = self(start) | self(stop) | self(step) | self(body)
-                undefined_vars.discard(ctr)
-                return undefined_vars
-
-            case unknown:
-                raise PsInternalCompilerError(
-                    f"Don't know how to collect undefined variables from {unknown}"
-                )
-
-    def declared_variables(self, node: PsAstNode) -> set[PsTypedVariable]:
-        """Returns the set of variables declared by the given node which are visible in the enclosing scope."""
-
-        match node:
-            case PsDeclaration(lhs, _):
-                return {lhs.symbol}
-
-            case PsStatement() | PsAssignment() | PsExpression() | PsLoop() | PsBlock():
-                return set()
-
-            case unknown:
-                raise PsInternalCompilerError(
-                    f"Don't know how to collect declared variables from {unknown}"
-                )
-
-    def map_vector_array_access(
-        self, vacc: PsVectorArrayAccess
-    ) -> set[PsTypedVariable]:
-        return {vacc.base_ptr} | self.rec(vacc.base_index)
-
-
-def collect_undefined_variables(node: PsAstNode) -> set[PsTypedVariable]:
-    return UndefinedVariablesCollector()(node)
-
-
-class RequiredHeadersCollector(Collector):
-    """Collect all header files required by a given AST for correct compilation.
-
-    Required headers can currently only be defined in subclasses of `PsAbstractType`.
-    """
-
-    def __call__(self, node: PsAstNode) -> set[str]:
-        match node:
-            case PsExpression(expr):
-                return self.rec(expr)
-            case node:
-                return reduce(set.union, (self(c) for c in node.children), set())
-
-    def map_typed_variable(self, var: PsTypedVariable) -> set[str]:
-        return var.dtype.required_headers
-
-    def map_constant(self, cst: Any):
-        if not isinstance(cst, PsTypedConstant):
-            raise PsMalformedAstException("Untyped constant encountered in expression.")
-
-        return cst.dtype.required_headers
-
-
-def collect_required_headers(node: PsAstNode) -> set[str]:
-    return RequiredHeadersCollector()(node)
diff --git a/src/pystencils/backend/ast/dispatcher.py b/src/pystencils/backend/ast/dispatcher.py
deleted file mode 100644
index f5a23b4c7..000000000
--- a/src/pystencils/backend/ast/dispatcher.py
+++ /dev/null
@@ -1,41 +0,0 @@
-from __future__ import annotations
-
-from functools import wraps
-
-from typing import Callable
-from types import MethodType
-
-from .nodes import PsAstNode
-
-
-class VisitorDispatcher:
-    def __init__(self, wrapped_method):
-        self._dispatch_dict = {}
-        self._wrapped_method = wrapped_method
-
-    def case(self, node_type: type):
-        """Decorator for visitor's methods"""
-
-        def decorate(handler: Callable):
-            if node_type in self._dispatch_dict:
-                raise ValueError(f"Duplicate visitor case {node_type}")
-            self._dispatch_dict[node_type] = handler
-            return handler
-
-        return decorate
-
-    def __call__(self, instance, node: PsAstNode, *args, **kwargs):
-        for cls in node.__class__.mro():
-            if cls in self._dispatch_dict:
-                return self._dispatch_dict[cls](instance, node, *args, **kwargs)
-
-        return self._wrapped_method(instance, node, *args, **kwargs)
-
-    def __get__(self, obj, objtype=None):
-        if obj is None:
-            return self
-        return MethodType(self, obj)
-
-
-def ast_visitor(method):
-    return wraps(method)(VisitorDispatcher(method))
diff --git a/src/pystencils/backend/ast/expressions.py b/src/pystencils/backend/ast/expressions.py
new file mode 100644
index 000000000..a4b784dc4
--- /dev/null
+++ b/src/pystencils/backend/ast/expressions.py
@@ -0,0 +1,423 @@
+from __future__ import annotations
+from abc import ABC
+from typing import Sequence, overload
+
+from ..symbols import PsSymbol
+from ..constants import PsConstant
+from ..arrays import PsLinearizedArray, PsArrayBasePointer
+from ..functions import PsFunction
+from ..types import (
+    PsAbstractType,
+    PsScalarType,
+    PsVectorType,
+    PsTypeError,
+)
+from .util import failing_cast
+
+from .astnode import PsAstNode, PsLeafMixIn
+
+
+class PsExpression(PsAstNode, ABC):
+    """Base class for all expressions."""
+
+    def __add__(self, other: PsExpression) -> PsAdd:
+        return PsAdd(self, other)
+
+    def __sub__(self, other: PsExpression) -> PsSub:
+        return PsSub(self, other)
+
+    def __mul__(self, other: PsExpression) -> PsMul:
+        return PsMul(self, other)
+
+    def __truediv__(self, other: PsExpression) -> PsDiv:
+        return PsDiv(self, other)
+
+    def __neg__(self) -> PsNeg:
+        return PsNeg(self)
+
+    @overload
+    @staticmethod
+    def make(obj: PsSymbol) -> PsSymbolExpr:
+        pass
+
+    @overload
+    @staticmethod
+    def make(obj: PsConstant) -> PsConstantExpr:
+        pass
+
+    @staticmethod
+    def make(obj: PsSymbol | PsConstant) -> PsSymbolExpr | PsConstantExpr:
+        if isinstance(obj, PsSymbol):
+            return PsSymbolExpr(obj)
+        elif isinstance(obj, PsConstant):
+            return PsConstantExpr(obj)
+        else:
+            raise ValueError(f"Cannot make expression out of {obj}")
+
+
+class PsLvalueExpr(PsExpression, ABC):
+    """Base class for all expressions that may occur as an lvalue"""
+
+
+class PsSymbolExpr(PsLeafMixIn, PsLvalueExpr):
+    """A single symbol as an expression."""
+
+    __match_args__ = ("symbol",)
+
+    def __init__(self, symbol: PsSymbol):
+        self._symbol = symbol
+
+    @property
+    def symbol(self) -> PsSymbol:
+        return self._symbol
+
+    @symbol.setter
+    def symbol(self, symbol: PsSymbol):
+        self._symbol = symbol
+
+    def structurally_equal(self, other: PsAstNode) -> bool:
+        if not isinstance(other, PsSymbolExpr):
+            return False
+
+        return self._symbol == other._symbol
+
+    def __repr__(self) -> str:
+        return f"Symbol({repr(self._symbol)})"
+
+
+class PsConstantExpr(PsLeafMixIn, PsExpression):
+    __match_args__ = ("constant",)
+
+    def __init__(self, constant: PsConstant):
+        self._constant = constant
+
+    @property
+    def constant(self) -> PsConstant:
+        return self._constant
+
+    @constant.setter
+    def constant(self, c: PsConstant):
+        self._constant = c
+
+    def structurally_equal(self, other: PsAstNode) -> bool:
+        if not isinstance(other, PsConstantExpr):
+            return False
+
+        return self._constant == other._constant
+
+    def __repr__(self) -> str:
+        return f"Constant({repr(self._constant)})"
+
+
+class PsSubscript(PsLvalueExpr):
+    __match_args__ = ("base", "index")
+
+    def __init__(self, base: PsExpression, index: PsExpression):
+        self._base = base
+        self._index = index
+
+    @property
+    def base(self) -> PsExpression:
+        return self._base
+
+    @base.setter
+    def base(self, expr: PsExpression):
+        self._base = expr
+
+    @property
+    def index(self) -> PsExpression:
+        return self._index
+
+    @index.setter
+    def index(self, expr: PsExpression):
+        self._index = expr
+
+    def get_children(self) -> tuple[PsAstNode, ...]:
+        return (self._base, self._index)
+
+    def set_child(self, idx: int, c: PsAstNode):
+        idx = [0, 1][idx]
+        match idx:
+            case 0:
+                self.base = failing_cast(PsExpression, c)
+            case 1:
+                self.index = failing_cast(PsExpression, c)
+
+
+class PsArrayAccess(PsSubscript):
+    __match_args__ = ("base_ptr", "index")
+
+    def __init__(self, base_ptr: PsArrayBasePointer, index: PsExpression):
+        super().__init__(PsExpression.make(base_ptr), index)
+        self._base_ptr = base_ptr
+
+    @property
+    def base_ptr(self) -> PsArrayBasePointer:
+        return self._base_ptr
+
+    @property
+    def base(self) -> PsExpression:
+        return self._base
+
+    @base.setter
+    def base(self, expr: PsExpression):
+        if not isinstance(expr, PsSymbolExpr) or not isinstance(
+            expr.symbol, PsArrayBasePointer
+        ):
+            raise ValueError(
+                "Base expression of PsArrayAccess must be an array base pointer"
+            )
+
+        self._base_ptr = expr.symbol
+        self._base = expr
+
+    @property
+    def array(self) -> PsLinearizedArray:
+        return self._base_ptr.array
+
+    @property
+    def dtype(self) -> PsAbstractType:
+        """Data type of this expression, i.e. the element type of the underlying array"""
+        return self._base_ptr.array.element_type
+
+    def __repr__(self) -> str:
+        return f"ArrayAccess({repr(self._base_ptr)}, {repr(self._index)})"
+
+
+class PsVectorArrayAccess(PsArrayAccess):
+    __match_args__ = ("base_ptr", "base_index")
+
+    def __init__(
+        self,
+        base_ptr: PsArrayBasePointer,
+        base_index: PsExpression,
+        vector_entries: int,
+        stride: int = 1,
+        alignment: int = 0,
+    ):
+        super().__init__(base_ptr, base_index)
+        element_type = base_ptr.array.element_type
+
+        if not isinstance(element_type, PsScalarType):
+            raise PsTypeError(
+                "Cannot generate vector accesses to arrays with non-scalar elements"
+            )
+
+        self._vector_type = PsVectorType(
+            element_type, vector_entries, const=element_type.const
+        )
+        self._stride = stride
+        self._alignment = alignment
+
+    @property
+    def vector_entries(self) -> int:
+        return self._vector_type.vector_entries
+
+    @property
+    def dtype(self) -> PsVectorType:
+        """Data type of this expression, i.e. the resulting generic vector type"""
+        return self._vector_type
+
+    @property
+    def stride(self) -> int:
+        return self._stride
+
+    @property
+    def alignment(self) -> int:
+        return self._alignment
+
+    def structurally_equal(self, other: PsAstNode) -> bool:
+        if not isinstance(other, PsVectorArrayAccess):
+            return False
+
+        return (
+            super().structurally_equal(other)
+            and self._vector_type == other._vector_type
+            and self._stride == other._stride
+            and self._alignment == other._alignment
+        )
+
+
+class PsLookup(PsExpression):
+    __match_args__ = ("aggregate", "member_name")
+
+    def __init__(self, aggregate: PsExpression, member_name: str) -> None:
+        self._aggregate = aggregate
+        self._member = member_name
+
+    @property
+    def aggregate(self) -> PsExpression:
+        return self._aggregate
+
+    @aggregate.setter
+    def aggregate(self, aggr: PsExpression):
+        self._aggregate = aggr
+
+    @property
+    def member_name(self) -> str:
+        return self._member
+
+    @member_name.setter
+    def member_name(self, name: str):
+        self._name = name
+
+    def get_children(self) -> tuple[PsAstNode, ...]:
+        return (self._aggregate,)
+
+    def set_child(self, idx: int, c: PsAstNode):
+        idx = [0][idx]
+        self._aggregate = failing_cast(PsExpression, c)
+
+
+class PsCall(PsExpression):
+    __match_args__ = ("function", "args")
+
+    def __init__(self, function: PsFunction, args: Sequence[PsExpression]) -> None:
+        if len(args) != function.arg_count:
+            raise ValueError(
+                f"Argument count mismatch: Cannot apply function {function} to {len(args)} arguments."
+            )
+
+        self._function = function
+        self._args = list(args)
+
+    @property
+    def function(self) -> PsFunction:
+        return self._function
+
+    @property
+    def args(self) -> tuple[PsExpression, ...]:
+        return tuple(self._args)
+
+    @args.setter
+    def args(self, exprs: Sequence[PsExpression]):
+        if len(exprs) != self._function.arg_count:
+            raise ValueError(
+                f"Argument count mismatch: Cannot apply function {self._function} to {len(exprs)} arguments."
+            )
+
+        self._args = list(exprs)
+
+    def get_children(self) -> tuple[PsAstNode, ...]:
+        return self.args
+
+    def set_child(self, idx: int, c: PsAstNode):
+        self._args[idx] = failing_cast(PsExpression, c)
+
+    def structurally_equal(self, other: PsAstNode) -> bool:
+        if not isinstance(other, PsCall):
+            return False
+        return super().structurally_equal(other) and self._function == other._function
+
+
+class PsUnOp(PsExpression):
+    __match_args__ = ("operand",)
+
+    def __init__(self, operand: PsExpression):
+        self._operand = operand
+
+    @property
+    def operand(self) -> PsExpression:
+        return self._operand
+
+    @operand.setter
+    def operand(self, expr: PsExpression):
+        self._operand = expr
+
+    def get_children(self) -> tuple[PsAstNode, ...]:
+        return (self._operand,)
+
+    def set_child(self, idx: int, c: PsAstNode):
+        idx = [0][idx]
+        self._operand = failing_cast(PsExpression, c)
+
+
+class PsNeg(PsUnOp):
+    pass
+
+
+class PsDeref(PsUnOp):
+    pass
+
+
+class PsAddressOf(PsUnOp):
+    pass
+
+
+class PsCast(PsUnOp):
+    __match_args__ = ("target_type", "operand")
+
+    def __init__(self, target_type: PsAbstractType, operand: PsExpression):
+        super().__init__(operand)
+        self._target_type = target_type
+
+    @property
+    def target_type(self) -> PsAbstractType:
+        return self._target_type
+
+    @target_type.setter
+    def target_type(self, dtype: PsAbstractType):
+        self._target_type = dtype
+
+    def structurally_equal(self, other: PsAstNode) -> bool:
+        if not isinstance(other, PsCast):
+            return False
+        return (
+            super().structurally_equal(other)
+            and self._target_type == other._target_type
+        )
+
+
+class PsBinOp(PsExpression):
+    __match_args__ = ("operand1", "operand2")
+
+    def __init__(self, op1: PsExpression, op2: PsExpression):
+        self._op1 = op1
+        self._op2 = op2
+
+    @property
+    def operand1(self) -> PsExpression:
+        return self._op1
+
+    @operand1.setter
+    def operand1(self, expr: PsExpression):
+        self._op1 = expr
+
+    @property
+    def operand2(self) -> PsExpression:
+        return self._op2
+
+    @operand2.setter
+    def operand2(self, expr: PsExpression):
+        self._op2 = expr
+
+    def get_children(self) -> tuple[PsAstNode, ...]:
+        return (self._op1, self._op2)
+
+    def set_child(self, idx: int, c: PsAstNode):
+        idx = [0, 1][idx]
+        match idx:
+            case 0:
+                self._op1 = failing_cast(PsExpression, c)
+            case 1:
+                self._op2 = failing_cast(PsExpression, c)
+
+    def __repr__(self) -> str:
+        opname = self.__class__.__name__
+        return f"{opname}({repr(self._op1)}, {repr(self._op2)})"
+
+
+class PsAdd(PsBinOp):
+    pass
+
+
+class PsSub(PsBinOp):
+    pass
+
+
+class PsMul(PsBinOp):
+    pass
+
+
+class PsDiv(PsBinOp):
+    pass
diff --git a/src/pystencils/backend/ast/tree_iteration.py b/src/pystencils/backend/ast/iteration.py
similarity index 96%
rename from src/pystencils/backend/ast/tree_iteration.py
rename to src/pystencils/backend/ast/iteration.py
index 1549d7580..6c1c406ed 100644
--- a/src/pystencils/backend/ast/tree_iteration.py
+++ b/src/pystencils/backend/ast/iteration.py
@@ -1,6 +1,6 @@
 from typing import Callable, Generator
 
-from .nodes import PsAstNode
+from .structural import PsAstNode
 
 
 def dfs_preorder(
diff --git a/src/pystencils/backend/ast/kernelfunction.py b/src/pystencils/backend/ast/kernelfunction.py
index deca2cf18..2a7997ff5 100644
--- a/src/pystencils/backend/ast/kernelfunction.py
+++ b/src/pystencils/backend/ast/kernelfunction.py
@@ -3,13 +3,11 @@ from __future__ import annotations
 from typing import Callable
 from dataclasses import dataclass
 
-from pymbolic.mapper.dependency import DependencyMapper
+from .structural import PsAstNode, PsBlock, failing_cast
 
-from .nodes import PsAstNode, PsBlock, failing_cast
-
-from ..constraints import PsKernelConstraint
-from ..typed_expressions import PsTypedVariable
-from ..arrays import PsLinearizedArray, PsArrayBasePointer, PsArrayAssocVar
+from ..symbols import PsSymbol
+from ..constraints import PsKernelParamsConstraint
+from ..arrays import PsLinearizedArray, PsArrayBasePointer, PsArrayAssocSymbol
 from ..jit import JitBase, no_jit
 from ..exceptions import PsInternalCompilerError
 
@@ -21,40 +19,38 @@ class PsKernelParametersSpec:
     """Specification of a kernel function's parameters.
 
     Contains:
-        - Verbatim parameter list, a list of `PsTypedVariables`
+        - Verbatim parameter list, a list of `PsSymbol`s
         - List of Arrays used in the kernel, in canonical order
         - A set of constraints on the kernel parameters, used to e.g. express relations of array
           shapes, alignment properties, ...
     """
 
-    params: tuple[PsTypedVariable, ...]
+    params: tuple[PsSymbol, ...]
     arrays: tuple[PsLinearizedArray, ...]
-    constraints: tuple[PsKernelConstraint, ...]
+    constraints: tuple[PsKernelParamsConstraint, ...]
 
     def params_for_array(self, arr: PsLinearizedArray):
-        def pred(p: PsTypedVariable):
-            return isinstance(p, PsArrayAssocVar) and p.array == arr
+        def pred(s: PsSymbol):
+            return isinstance(s, PsArrayAssocSymbol) and s.array == arr
 
         return tuple(filter(pred, self.params))
 
     def __post_init__(self):
-        dep_mapper = DependencyMapper(False, False, False, False)
-
         #   Check constraints
         for constraint in self.constraints:
-            variables: set[PsTypedVariable] = dep_mapper(constraint.condition)
-            for var in variables:
-                if isinstance(var, PsArrayAssocVar):
-                    if var.array in self.arrays:
+            symbols = constraint.get_symbols()
+            for sym in symbols:
+                if isinstance(sym, PsArrayAssocSymbol):
+                    if sym.array in self.arrays:
                         continue
 
-                elif var in self.params:
+                elif sym in self.params:
                     continue
 
                 raise PsInternalCompilerError(
                     "Constrained parameter was neither contained in kernel parameter list "
                     "nor associated with a kernel array.\n"
-                    f"    Parameter: {var}\n"
+                    f"    Parameter: {sym}\n"
                     f"    Constraint: {constraint.condition}"
                 )
 
@@ -82,7 +78,7 @@ class PsKernelFunction(PsAstNode):
         self._jit = jit
 
         self._required_headers = required_headers
-        self._constraints: list[PsKernelConstraint] = []
+        self._constraints: list[PsKernelParamsConstraint] = []
 
     @property
     def target(self) -> Target:
@@ -127,7 +123,7 @@ class PsKernelFunction(PsAstNode):
             raise IndexError(f"Child index out of bounds: {idx}")
         self._body = failing_cast(PsBlock, c)
 
-    def add_constraints(self, *constraints: PsKernelConstraint):
+    def add_constraints(self, *constraints: PsKernelParamsConstraint):
         self._constraints += constraints
 
     def get_parameters(self) -> PsKernelParametersSpec:
@@ -136,7 +132,7 @@ class PsKernelFunction(PsAstNode):
         This function performs a full traversal of the AST.
         To improve performance, make sure to cache the result if necessary.
         """
-        from .collectors import collect_undefined_variables
+        from .analysis import collect_undefined_variables
 
         params_set = collect_undefined_variables(self)
         params_list = sorted(params_set, key=lambda p: p.name)
diff --git a/src/pystencils/backend/ast/nodes.py b/src/pystencils/backend/ast/structural.py
similarity index 69%
rename from src/pystencils/backend/ast/nodes.py
rename to src/pystencils/backend/ast/structural.py
index f3609ffb7..c9da546a8 100644
--- a/src/pystencils/backend/ast/nodes.py
+++ b/src/pystencils/backend/ast/structural.py
@@ -1,52 +1,13 @@
 from __future__ import annotations
-from typing import Sequence, Iterable, cast, TypeAlias
+from typing import Sequence, cast
 from types import NoneType
 
-from pymbolic.primitives import Variable
+from .astnode import PsAstNode, PsLeafMixIn
+from .expressions import PsExpression, PsLvalueExpr, PsSymbolExpr
 
-from abc import ABC, abstractmethod
-
-from ..typed_expressions import ExprOrConstant
-from ..arrays import PsArrayAccess, PsVectorArrayAccess
 from .util import failing_cast
 
 
-class PsAstNode(ABC):
-    """Base class for all nodes in the pystencils AST.
-
-    This base class provides a common interface to inspect and update the AST's branching structure.
-    The two methods `get_children` and `set_child` must be implemented by each subclass.
-    Subclasses are also responsible for doing the necessary type checks if they place restrictions on
-    the types of their children.
-    """
-
-    @property
-    def children(self) -> tuple[PsAstNode, ...]:
-        return self.get_children()
-
-    @children.setter
-    def children(self, cs: Iterable[PsAstNode]):
-        for i, c in enumerate(cs):
-            self.set_child(i, c)
-
-    @abstractmethod
-    def get_children(self) -> tuple[PsAstNode, ...]:
-        pass
-
-    @abstractmethod
-    def set_child(self, idx: int, c: PsAstNode):
-        pass
-
-    def __eq__(self, other: object) -> bool:
-        if not isinstance(other, PsAstNode):
-            return False
-
-        return type(self) is type(other) and self.children == other.children
-
-    def __hash__(self) -> int:
-        return hash((type(self), self.children))
-
-
 class PsBlock(PsAstNode):
     __match_args__ = ("statements",)
 
@@ -72,69 +33,6 @@ class PsBlock(PsAstNode):
         return f"PsBlock( {contents} )"
 
 
-class PsLeafNode(PsAstNode):
-    def get_children(self) -> tuple[PsAstNode, ...]:
-        return ()
-
-    def set_child(self, idx: int, c: PsAstNode):
-        raise IndexError("Child index out of bounds: Leaf nodes have no children.")
-
-
-class PsExpression(PsLeafNode):
-    """Wrapper around pymbolics expressions."""
-
-    __match_args__ = ("expression",)
-
-    def __init__(self, expr: ExprOrConstant):
-        self._expr = expr
-
-    @property
-    def expression(self) -> ExprOrConstant:
-        return self._expr
-
-    @expression.setter
-    def expression(self, expr: ExprOrConstant):
-        self._expr = expr
-
-    def __repr__(self) -> str:
-        return f"Expr({repr(self._expr)})"
-
-    def __eq__(self, other: object) -> bool:
-        if not isinstance(other, PsExpression):
-            return False
-        return type(self) is type(other) and self._expr == other._expr
-
-    def __hash__(self) -> int:
-        return hash((type(self), self._expr))
-
-
-class PsLvalueExpr(PsExpression):
-    """Wrapper around pymbolics expressions that may occur at the left-hand side of an assignment"""
-
-    def __init__(self, expr: PsLvalue):
-        if not isinstance(expr, (Variable, PsArrayAccess, PsVectorArrayAccess)):
-            raise TypeError("Expression was not a valid lvalue")
-
-        super(PsLvalueExpr, self).__init__(expr)
-
-
-class PsSymbolExpr(PsLvalueExpr):
-    """Wrapper around PsTypedSymbols"""
-
-    __match_args__ = ("symbol",)
-
-    def __init__(self, symbol: Variable):
-        super().__init__(symbol)
-
-    @property
-    def symbol(self) -> Variable:
-        return cast(Variable, self._expr)
-
-    @symbol.setter
-    def symbol(self, symbol: Variable):
-        self._expr = symbol
-
-
 class PsStatement(PsAstNode):
     __match_args__ = ("expression",)
 
@@ -158,10 +56,6 @@ class PsStatement(PsAstNode):
         self._expression = failing_cast(PsExpression, c)
 
 
-PsLvalue: TypeAlias = Variable | PsArrayAccess | PsVectorArrayAccess
-"""Types of expressions that may occur on the left-hand side of assignments."""
-
-
 class PsAssignment(PsAstNode):
     __match_args__ = (
         "lhs",
@@ -376,7 +270,9 @@ class PsConditional(PsAstNode):
                 assert False, "unreachable code"
 
 
-class PsComment(PsLeafNode):
+class PsComment(PsLeafMixIn, PsAstNode):
+    __match_args__ = ("lines",)
+
     def __init__(self, text: str) -> None:
         self._text = text
         self._lines = tuple(text.splitlines())
@@ -388,3 +284,9 @@ class PsComment(PsLeafNode):
     @property
     def lines(self) -> tuple[str, ...]:
         return self._lines
+
+    def structurally_equal(self, other: PsAstNode) -> bool:
+        if not isinstance(other, PsComment):
+            return False
+
+        return self._text == other._text
diff --git a/src/pystencils/backend/ast/transformations.py b/src/pystencils/backend/ast/transformations.py
deleted file mode 100644
index dc438e52e..000000000
--- a/src/pystencils/backend/ast/transformations.py
+++ /dev/null
@@ -1,58 +0,0 @@
-from abc import ABC
-
-from typing import Dict
-
-from pymbolic.primitives import Expression
-from pymbolic.mapper.substitutor import CachedSubstitutionMapper
-
-from ..typed_expressions import PsTypedVariable
-from .dispatcher import ast_visitor
-from .nodes import PsAstNode, PsAssignment, PsLoop, PsExpression
-
-
-class PsAstTransformer(ABC):
-    def transform_children(self, node: PsAstNode, *args, **kwargs):
-        node.children = tuple(self.visit(c, *args, **kwargs) for c in node.children)
-
-    @ast_visitor
-    def visit(self, node, *args, **kwargs):
-        self.transform_children(node, *args, **kwargs)
-        return node
-
-
-class PsVariablesSubstitutor(PsAstTransformer):
-    def __init__(self, subs_dict: Dict[PsTypedVariable, Expression]):
-        self._subs_dict = subs_dict
-        self._mapper = CachedSubstitutionMapper(lambda s: self._subs_dict.get(s, None))
-
-    def subs(self, node: PsAstNode):
-        return self.visit(node)
-
-    visit = PsAstTransformer.visit
-
-    @visit.case(PsAssignment)
-    def assignment(self, asm: PsAssignment):
-        lhs_expr = asm.lhs.expression
-        if isinstance(lhs_expr, PsTypedVariable) and lhs_expr in self._subs_dict:
-            raise ValueError(
-                f"Cannot substitute symbol {lhs_expr} that occurs on a left-hand side of an assignment."
-            )
-        self.transform_children(asm)
-        return asm
-
-    @visit.case(PsLoop)
-    def loop(self, loop: PsLoop):
-        if loop.counter.expression in self._subs_dict:
-            raise ValueError(
-                f"Cannot substitute symbol {loop.counter.expression} that is defined as a loop counter."
-            )
-        self.transform_children(loop)
-        return loop
-
-    @visit.case(PsExpression)
-    def expression(self, expr_node: PsExpression):
-        self._mapper(expr_node.expression)
-
-
-def ast_subs(node: PsAstNode, subs_dict: dict):
-    return PsVariablesSubstitutor(subs_dict).subs(node)
diff --git a/src/pystencils/backend/constants.py b/src/pystencils/backend/constants.py
new file mode 100644
index 000000000..2ebe855ae
--- /dev/null
+++ b/src/pystencils/backend/constants.py
@@ -0,0 +1,53 @@
+from typing import Any
+
+from .types import PsNumericType, constify
+from .exceptions import PsInternalCompilerError
+
+
+class PsConstant:
+    __match_args__ = ("value", "dtype")
+
+    def __init__(self, value: Any, dtype: PsNumericType | None = None):
+        self._dtype: PsNumericType | None = None
+        self._value = value
+
+        if dtype is not None:
+            self.apply_dtype(dtype)
+
+    @property
+    def value(self) -> Any:
+        return self._value
+
+    @property
+    def dtype(self) -> PsNumericType | None:
+        return self._dtype
+
+    def get_dtype(self) -> PsNumericType:
+        if self._dtype is None:
+            raise PsInternalCompilerError("Data type of constant was not set.")
+        return self._dtype
+
+    def apply_dtype(self, dtype: PsNumericType):
+        if self._dtype is not None:
+            raise PsInternalCompilerError(
+                "Attempt to apply data type to already typed constant."
+            )
+
+        self._dtype = constify(dtype)
+        self._value = self._dtype.create_constant(self._value)
+
+    def __str__(self) -> str:
+        type_str = "<untyped>" if self._dtype is None else str(self._dtype)
+        return f"{str(self._value)}: {type_str}"
+
+    def __repr__(self) -> str:
+        return str(self)
+
+    def __hash__(self) -> int:
+        return hash((self._dtype, self._value))
+
+    def __eq__(self, other) -> bool:
+        if not isinstance(other, PsConstant):
+            return False
+
+        return (self._value, self._dtype) == (other._value, other._dtype)
diff --git a/src/pystencils/backend/constraints.py b/src/pystencils/backend/constraints.py
index 0225420b4..9e5c82cfd 100644
--- a/src/pystencils/backend/constraints.py
+++ b/src/pystencils/backend/constraints.py
@@ -1,22 +1,19 @@
+from typing import Any
 from dataclasses import dataclass
 
-import pymbolic.primitives as pb
-from pymbolic.mapper.c_code import CCodeMapper
-from pymbolic.mapper.dependency import DependencyMapper
-
-from .typed_expressions import PsTypedVariable
+from .symbols import PsSymbol
 
 
 @dataclass
-class PsKernelConstraint:
-    condition: pb.Comparison
+class PsKernelParamsConstraint:
+    condition: Any  # FIXME Implement conditions
     message: str = ""
 
-    def print_c_condition(self):
-        return CCodeMapper()(self.condition)
+    def to_code(self):
+        raise NotImplementedError()
 
-    def get_variables(self) -> set[PsTypedVariable]:
-        return DependencyMapper(False, False, False, False)(self.condition)
+    def get_symbols(self) -> set[PsSymbol]:
+        raise NotImplementedError()
 
     def __str__(self) -> str:
         return f"{self.message} [{self.condition}]"
diff --git a/src/pystencils/backend/emission.py b/src/pystencils/backend/emission.py
index 3d1a21279..8de626e1f 100644
--- a/src/pystencils/backend/emission.py
+++ b/src/pystencils/backend/emission.py
@@ -1,12 +1,9 @@
 from __future__ import annotations
+from enum import Enum
 
-from pymbolic.mapper.c_code import CCodeMapper
-
-from .ast import (
-    ast_visitor,
+from .ast.structural import (
     PsAstNode,
     PsBlock,
-    PsExpression,
     PsStatement,
     PsDeclaration,
     PsAssignment,
@@ -14,129 +11,289 @@ from .ast import (
     PsConditional,
     PsComment,
 )
+
+from .ast.expressions import (
+    PsSymbolExpr,
+    PsConstantExpr,
+    PsSubscript,
+    PsVectorArrayAccess,
+    PsLookup,
+    PsCall,
+    PsBinOp,
+    PsAdd,
+    PsSub,
+    PsMul,
+    PsDiv,
+    PsNeg,
+    PsDeref,
+    PsAddressOf,
+    PsCast,
+)
+
+from .types import PsScalarType
+
 from .ast.kernelfunction import PsKernelFunction
-from .typed_expressions import PsTypedVariable
-from .functions import Deref, AddressOf, Cast, CFunction
+
+
+__all__ = ["emit_code", "CAstPrinter"]
 
 
 def emit_code(kernel: PsKernelFunction):
     printer = CAstPrinter()
-    return printer.print(kernel)
+    return printer(kernel)
+
+
+class EmissionError(Exception):
+    """Indicates a fatal error during code printing"""
+
+
+class LR(Enum):
+    Left = 0
+    Right = 1
+    Middle = 2
+
+
+class Ops(Enum):
+    """Operator precedence and associativity in C/C++.
+
+    See also https://en.cppreference.com/w/cpp/language/operator_precedence
+    """
+
+    Weakest = (0, LR.Middle)
+
+    Add = (1, LR.Left)
+    Sub = (1, LR.Left)
 
+    Mul = (2, LR.Left)
+    Div = (2, LR.Left)
+    Rem = (2, LR.Left)
 
-class CExpressionsPrinter(CCodeMapper):
-    def map_deref(self, deref: Deref, enclosing_prec):
-        return "*"
+    Neg = (3, LR.Right)
+    AddressOf = (3, LR.Right)
+    Deref = (3, LR.Right)
+    Cast = (3, LR.Right)
 
-    def map_address_of(self, addrof: AddressOf, enclosing_prec):
-        return "&"
+    Call = (4, LR.Left)
+    Subscript = (4, LR.Left)
+    Lookup = (4, LR.Left)
 
-    def map_cast(self, cast: Cast, enclosing_prec):
-        return f"({cast.target_type.c_string()})"
+    def __init__(self, pred: int, assoc: LR) -> None:
+        self.precedence = pred
+        self.assoc = assoc
 
-    def map_c_function(self, func: CFunction, enclosing_prec):
-        return func.qualified_name
+
+class PrinterCtx:
+    def __init__(self) -> None:
+        self.operator_stack = [Ops.Weakest]
+        self.branch_stack: list[LR] = []
+        self.indent_level = 0
+
+    def push_op(self, operator: Ops, branch: LR):
+        self.operator_stack.append(operator)
+        self.branch_stack.append(branch)
+
+    def pop_op(self) -> None:
+        self.operator_stack.pop()
+        self.branch_stack.pop()
+
+    def switch_branch(self, branch: LR):
+        self.branch_stack[-1] = branch
+
+    @property
+    def current_op(self) -> Ops:
+        return self.operator_stack[-1]
+
+    @property
+    def current_branch(self) -> LR:
+        return self.branch_stack[-1]
+
+    def parenthesize(self, expr: str, next_operator: Ops) -> str:
+        if next_operator.precedence < self.current_op.precedence:
+            return f"({expr})"
+        elif (
+            next_operator.precedence == self.current_op.precedence
+            and self.current_branch != self.current_op.assoc
+        ):
+            return f"({expr})"
+
+        return expr
+
+    def indent(self, line: str) -> str:
+        return " " * self.indent_level + line
 
 
 class CAstPrinter:
     def __init__(self, indent_width=3):
         self._indent_width = indent_width
 
-        self._current_indent_level = 0
-
-        self._expr_printer = CExpressionsPrinter()
-
-    def indent(self, line):
-        return " " * self._current_indent_level + line
-
-    def print(self, node: PsAstNode) -> str:
-        return self.visit(node)
-
-    @ast_visitor
-    def visit(self, _: PsAstNode) -> str:
-        raise ValueError("Cannot print this node.")
-
-    @visit.case(PsKernelFunction)
-    def function(self, func: PsKernelFunction) -> str:
-        params_spec = func.get_parameters()
-        params_str = ", ".join(
-            f"{p.dtype.c_string()} {p.name}" for p in params_spec.params
-        )
-        decl = f"FUNC_PREFIX void {func.name} ({params_str})"
-        body = self.visit(func.body)
-        return f"{decl}\n{body}"
-
-    @visit.case(PsBlock)
-    def block(self, block: PsBlock):
-        if not block.children:
-            return self.indent("{ }")
-
-        self._current_indent_level += self._indent_width
-        interior = "\n".join(self.visit(c) for c in block.children) + "\n"
-        self._current_indent_level -= self._indent_width
-        return self.indent("{\n") + interior + self.indent("}\n")
-
-    @visit.case(PsExpression)
-    def pymb_expression(self, expr: PsExpression):
-        return self._expr_printer(expr.expression)
-
-    @visit.case(PsStatement)
-    def statement(self, stmt: PsStatement):
-        return self.indent(f"{self.visit(stmt.expression)};")
-
-    @visit.case(PsDeclaration)
-    def declaration(self, decl: PsDeclaration):
-        lhs_symb = decl.declared_variable.symbol
-        assert isinstance(lhs_symb, PsTypedVariable)
-        lhs_dtype = lhs_symb.dtype
-        rhs_code = self.visit(decl.rhs)
-
-        return self.indent(f"{lhs_dtype.c_string()} {lhs_symb.name} = {rhs_code};")
-
-    @visit.case(PsAssignment)
-    def assignment(self, asm: PsAssignment):
-        lhs_code = self.visit(asm.lhs)
-        rhs_code = self.visit(asm.rhs)
-        return self.indent(f"{lhs_code} = {rhs_code};")
-
-    @visit.case(PsLoop)
-    def loop(self, loop: PsLoop):
-        ctr_symbol = loop.counter.symbol
-        assert isinstance(ctr_symbol, PsTypedVariable)
-
-        ctr = ctr_symbol.name
-        start_code = self.visit(loop.start)
-        stop_code = self.visit(loop.stop)
-        step_code = self.visit(loop.step)
-
-        body_code = self.visit(loop.body)
-
-        code = (
-            f"for({ctr_symbol.dtype} {ctr} = {start_code};"
-            + f" {ctr} < {stop_code};"
-            + f" {ctr} += {step_code})\n"
-            + body_code
-        )
-        return self.indent(code)
-
-    @visit.case(PsConditional)
-    def conditional(self, node: PsConditional):
-        cond_code = self.visit(node.condition)
-        then_code = self.visit(node.branch_true)
-
-        code = f"if({cond_code})\n{then_code}"
-
-        if node.branch_false is not None:
-            else_code = self.visit(node.branch_false)
-            code += f"\nelse\n{else_code}"
-
-        return self.indent(code)
-
-    @visit.case(PsComment)
-    def comment(self, node: PsComment):
-        lines = list(node.lines)
-        lines[0] = "/* " + lines[0]
-        for i in range(1, len(lines)):
-            lines[i] = "   " + lines[i]
-        lines[-1] = lines[-1] + " */"
-        return self.indent("\n".join(lines))
+    def __call__(self, node: PsAstNode) -> str:
+        return self.visit(node, PrinterCtx())
+
+    def visit(self, node: PsAstNode, pc: PrinterCtx) -> str:
+        match node:
+            case PsKernelFunction(body):
+                params_spec = node.get_parameters()
+                params_str = ", ".join(
+                    f"{p.get_dtype().c_string()} {p.name}" for p in params_spec.params
+                )
+                decl = f"FUNC_PREFIX void {node.name} ({params_str})"
+                body_code = self.visit(body, pc)
+                return f"{decl}\n{body_code}"
+
+            case PsBlock(statements):
+                if not statements:
+                    return pc.indent("{ }")
+
+                pc.indent_level += self._indent_width
+                interior = "\n".join(self.visit(stmt, pc) for stmt in statements) + "\n"
+                pc.indent_level -= self._indent_width
+                return pc.indent("{\n") + interior + pc.indent("}\n")
+
+            case PsStatement(expr):
+                return pc.indent(f"{self.visit(expr, pc)};")
+
+            case PsDeclaration(lhs, rhs):
+                lhs_symb = lhs.symbol
+                lhs_dtype = lhs_symb.get_dtype()
+                rhs_code = self.visit(rhs, pc)
+
+                return pc.indent(
+                    f"{lhs_dtype.c_string()} {lhs_symb.name} = {rhs_code};"
+                )
+
+            case PsAssignment(lhs, rhs):
+                lhs_code = self.visit(lhs, pc)
+                rhs_code = self.visit(rhs, pc)
+                return pc.indent(f"{lhs_code} = {rhs_code};")
+
+            case PsLoop(ctr, start, stop, step, body):
+                ctr_symbol = ctr.symbol
+
+                start_code = self.visit(start, pc)
+                stop_code = self.visit(stop, pc)
+                step_code = self.visit(step, pc)
+                body_code = self.visit(body, pc)
+
+                code = (
+                    f"for({ctr_symbol.dtype} {ctr_symbol.name} = {start_code};"
+                    + f" {ctr.symbol.name} < {stop_code};"
+                    + f" {ctr.symbol.name} += {step_code})\n"
+                    + body_code
+                )
+                return pc.indent(code)
+
+            case PsConditional(condition, branch_true, branch_false):
+                cond_code = self.visit(condition, pc)
+                then_code = self.visit(branch_true, pc)
+
+                code = f"if({cond_code})\n{then_code}"
+
+                if branch_false is not None:
+                    else_code = self.visit(branch_false, pc)
+                    code += f"\nelse\n{else_code}"
+
+                return pc.indent(code)
+
+            case PsComment(lines):
+                lines_list = list(lines)
+                lines_list[0] = "/* " + lines_list[0]
+                for i in range(1, len(lines_list)):
+                    lines_list[i] = "   " + lines_list[i]
+                lines_list[-1] = lines_list[-1] + " */"
+                return pc.indent("\n".join(lines_list))
+
+            case PsSymbolExpr(symbol):
+                return symbol.name
+
+            case PsConstantExpr(constant):
+                dtype = constant.get_dtype()
+                if not isinstance(dtype, PsScalarType):
+                    raise EmissionError(
+                        "Cannot print literals for non-scalar constants."
+                    )
+
+                return dtype.create_literal(constant.value)
+
+            case PsVectorArrayAccess():
+                raise EmissionError("Cannot print vectorized array accesses")
+
+            case PsSubscript(base, index):
+                pc.push_op(Ops.Subscript, LR.Left)
+                base_code = self.visit(base, pc)
+                pc.pop_op()
+
+                pc.push_op(Ops.Weakest, LR.Middle)
+                index_code = self.visit(index, pc)
+                pc.pop_op()
+
+                return pc.parenthesize(f"{base_code}[{index_code}]", Ops.Subscript)
+
+            case PsLookup(aggr, member_name):
+                pc.push_op(Ops.Lookup, LR.Left)
+                aggr_code = self.visit(aggr, pc)
+                pc.pop_op()
+
+                return pc.parenthesize(f"{aggr_code}.{member_name}", Ops.Lookup)
+
+            case PsCall(function, args):
+                pc.push_op(Ops.Weakest, LR.Middle)
+                args_string = ", ".join(self.visit(arg, pc) for arg in args)
+                pc.pop_op()
+
+                return pc.parenthesize(f"{function.name}({args_string})", Ops.Call)
+
+            case PsBinOp(op1, op2):
+                op_char, op = self._char_and_op(node)
+
+                pc.push_op(op, LR.Left)
+                op1_code = self.visit(op1, pc)
+                pc.switch_branch(LR.Right)
+                op2_code = self.visit(op2, pc)
+                pc.pop_op()
+
+                return pc.parenthesize(f"{op1_code} {op_char} {op2_code}", op)
+
+            case PsNeg(operand):
+                pc.push_op(Ops.Neg, LR.Right)
+                operand_code = self.visit(operand, pc)
+                pc.pop_op()
+
+                return pc.parenthesize(f"-{operand_code}", Ops.Neg)
+
+            case PsDeref(operand):
+                pc.push_op(Ops.Deref, LR.Right)
+                operand_code = self.visit(operand, pc)
+                pc.pop_op()
+
+                return pc.parenthesize(f"*{operand_code}", Ops.Deref)
+
+            case PsAddressOf(operand):
+                pc.push_op(Ops.AddressOf, LR.Right)
+                operand_code = self.visit(operand, pc)
+                pc.pop_op()
+
+                return pc.parenthesize(f"&{operand_code}", Ops.AddressOf)
+
+            case PsCast(target_type, operand):
+                pc.push_op(Ops.Cast, LR.Right)
+                operand_code = self.visit(operand, pc)
+                pc.pop_op()
+
+                type_str = target_type.c_string()
+                return pc.parenthesize(f"({type_str}) {operand_code}", Ops.Cast)
+
+            case _:
+                raise NotImplementedError(f"Don't know how to print {node}")
+
+    def _char_and_op(self, node: PsBinOp) -> tuple[str, Ops]:
+        match node:
+            case PsAdd():
+                return ("+", Ops.Add)
+            case PsSub():
+                return ("-", Ops.Sub)
+            case PsMul():
+                return ("*", Ops.Mul)
+            case PsDiv():
+                return ("/", Ops.Div)
+            case _:
+                assert False
diff --git a/src/pystencils/backend/functions.py b/src/pystencils/backend/functions.py
index 90bc09c5e..bddb5ca1a 100644
--- a/src/pystencils/backend/functions.py
+++ b/src/pystencils/backend/functions.py
@@ -14,13 +14,15 @@ TODO: Maybe add a way for the user to register additional functions
 TODO: Figure out the best way to describe function signatures and overloads for typing
 """
 
-from sys import intern
-import pymbolic.primitives as pb
-from abc import ABC, abstractmethod
+from __future__ import annotations
+from typing import Any, TYPE_CHECKING
+from abc import ABC
 from enum import Enum
 
 from .types import PsAbstractType
-from .typed_expressions import ExprOrConstant
+
+if TYPE_CHECKING:
+    from .ast.expressions import PsExpression
 
 
 class MathFunctions(Enum):
@@ -44,21 +46,31 @@ class MathFunctions(Enum):
         self.arg_count = arg_count
 
 
-class PsFunction(pb.FunctionSymbol, ABC):
+class PsFunction(ABC):
+    __match_args__ = ("name", "arg_count")
+
+    def __init__(self, name: str, num_args: int):
+        self._name = name
+        self._num_args = num_args
 
-    mapper_method = intern("map_ps_function")
+    @property
+    def name(self) -> str:
+        return self._name
 
     @property
-    @abstractmethod
     def arg_count(self) -> int:
         "Number of arguments this function takes"
+        return self._num_args
+
+    def __call__(self, *args: PsExpression) -> Any:
+        from .ast.expressions import PsCall
+
+        return PsCall(self, args)
 
 
 class CFunction(PsFunction):
     """A concrete C function."""
 
-    mapper_method = intern("map_c_function")
-
     def __init__(self, qualified_name: str, arg_count: int):
         self._qname = qualified_name
         self._arg_count = arg_count
@@ -75,9 +87,6 @@ class CFunction(PsFunction):
 class PsMathFunction(PsFunction):
     """Homogenously typed mathematical functions."""
 
-    init_arg_names = ("func",)
-    mapper_method = intern("map_math_function")
-
     def __init__(self, func: MathFunctions) -> None:
         self._func = func
 
@@ -93,11 +102,8 @@ class PsMathFunction(PsFunction):
 class Deref(PsFunction):
     """Dereferences a pointer."""
 
-    mapper_method = intern("map_deref")
-
-    @property
-    def arg_count(self) -> int:
-        return 1
+    def __init__(self):
+        super().__init__("deref", 1)
 
 
 deref = Deref()
@@ -106,22 +112,18 @@ deref = Deref()
 class AddressOf(PsFunction):
     """Take the address of an object"""
 
-    mapper_method = intern("map_address_of")
-
-    @property
-    def arg_count(self) -> int:
-        return 1
+    def __init__(self):
+        super().__init__("address_of", 1)
 
 
 address_of = AddressOf()
 
 
 class Cast(PsFunction):
-    mapper_method = intern("map_cast")
-
     """An unsafe C-style type cast"""
 
     def __init__(self, target_type: PsAbstractType):
+        super().__init__("cast", 1)
         self._target_type = target_type
 
     @property
@@ -133,5 +135,5 @@ class Cast(PsFunction):
         return self._target_type
 
 
-def cast(target_type: PsAbstractType, arg: ExprOrConstant):
+def cast(target_type: PsAbstractType, arg: PsExpression):
     return Cast(target_type)(arg)
diff --git a/src/pystencils/backend/jit/cpu_extension_module.py b/src/pystencils/backend/jit/cpu_extension_module.py
index 680b48fe7..6004d1d17 100644
--- a/src/pystencils/backend/jit/cpu_extension_module.py
+++ b/src/pystencils/backend/jit/cpu_extension_module.py
@@ -11,11 +11,11 @@ import numpy as np
 
 from ..exceptions import PsInternalCompilerError
 from ..ast import PsKernelFunction
-from ..constraints import PsKernelConstraint
-from ..typed_expressions import PsTypedVariable
+from ..symbols import PsSymbol
+from ..constraints import PsKernelParamsConstraint
 from ..arrays import (
     PsLinearizedArray,
-    PsArrayAssocVar,
+    PsArrayAssocSymbol,
     PsArrayBasePointer,
     PsArrayShapeVar,
     PsArrayStrideVar,
@@ -210,8 +210,8 @@ if( !kwargs || !PyDict_Check(kwargs) ) {{
         self._array_extractions: dict[PsLinearizedArray, str] = dict()
         self._array_frees: dict[PsLinearizedArray, str] = dict()
 
-        self._array_assoc_var_extractions: dict[PsArrayAssocVar, str] = dict()
-        self._scalar_extractions: dict[PsTypedVariable, str] = dict()
+        self._array_assoc_var_extractions: dict[PsArrayAssocSymbol, str] = dict()
+        self._scalar_extractions: dict[PsSymbol, str] = dict()
 
         self._constraint_checks: list[str] = []
 
@@ -271,19 +271,19 @@ if( !kwargs || !PyDict_Check(kwargs) ) {{
 
         return self._array_buffers[arr]
 
-    def extract_scalar(self, variable: PsTypedVariable) -> str:
-        if variable not in self._scalar_extractions:
-            extract_func = self._scalar_extractor(variable.dtype)
+    def extract_scalar(self, symbol: PsSymbol) -> str:
+        if symbol not in self._scalar_extractions:
+            extract_func = self._scalar_extractor(symbol.get_dtype())
             code = self.TMPL_EXTRACT_SCALAR.format(
-                name=variable.name,
-                target_type=str(variable.dtype),
+                name=symbol.name,
+                target_type=str(symbol.dtype),
                 extract_function=extract_func,
             )
-            self._scalar_extractions[variable] = code
+            self._scalar_extractions[symbol] = code
 
-        return variable.name
+        return symbol.name
 
-    def extract_array_assoc_var(self, variable: PsArrayAssocVar) -> str:
+    def extract_array_assoc_var(self, variable: PsArrayAssocSymbol) -> str:
         if variable not in self._array_assoc_var_extractions:
             arr = variable.array
             buffer = self.extract_array(arr)
@@ -308,22 +308,19 @@ if( !kwargs || !PyDict_Check(kwargs) ) {{
 
         return variable.name
 
-    def extract_parameter(self, variable: PsTypedVariable):
-        match variable:
-            case PsArrayAssocVar():
-                self.extract_array_assoc_var(variable)
-            case PsTypedVariable():
-                self.extract_scalar(variable)
-            case _:
-                assert False, "Invalid variable encountered."
+    def extract_parameter(self, symbol: PsSymbol):
+        if isinstance(symbol, PsArrayAssocSymbol):
+            self.extract_array_assoc_var(symbol)
+        else:
+            self.extract_scalar(symbol)
 
-    def check_constraint(self, constraint: PsKernelConstraint):
-        variables = constraint.get_variables()
+    def check_constraint(self, constraint: PsKernelParamsConstraint):
+        variables = constraint.get_symbols()
 
         for var in variables:
             self.extract_parameter(var)
 
-        cond = constraint.print_c_condition()
+        cond = constraint.to_code()
 
         code = f"""
 if(!({cond}))
@@ -335,7 +332,7 @@ if(!({cond}))
 
         self._constraint_checks.append(code)
 
-    def call(self, kernel: PsKernelFunction, params: tuple[PsTypedVariable, ...]):
+    def call(self, kernel: PsKernelFunction, params: tuple[PsSymbol, ...]):
         param_list = ", ".join(p.name for p in params)
         self._call = f"{kernel.name} ({param_list});"
 
diff --git a/src/pystencils/backend/jit/legacy_cpu.py b/src/pystencils/backend/jit/legacy_cpu.py
index 6a7e63d14..771e8d1ca 100644
--- a/src/pystencils/backend/jit/legacy_cpu.py
+++ b/src/pystencils/backend/jit/legacy_cpu.py
@@ -1,3 +1,4 @@
+# mypy: ignore-errors
 r"""
 
 *pystencils* automatically searches for a compiler, so in most cases no explicit configuration is required.
diff --git a/src/pystencils/backend/kernelcreation/__init__.py b/src/pystencils/backend/kernelcreation/__init__.py
index c04cab827..c3846e2c9 100644
--- a/src/pystencils/backend/kernelcreation/__init__.py
+++ b/src/pystencils/backend/kernelcreation/__init__.py
@@ -38,7 +38,7 @@ and any parameter constraints introduced by later transformation passes.
 Analysis Passes
 ^^^^^^^^^^^^^^^
 
-Before the actual translation of the SymPy-based assignment collection to the pymbolic-based expression trees begins,
+Before the actual translation of the SymPy-based assignment collection to the backend's AST begins,
 the kernel's assignments are checked for consistency with the translator's prequesites.
 In this case, the `KernelAnalysis` pass
 checks the static single assignment-form (SSA) requirement and the absence of loop-carried dependencies.
@@ -59,8 +59,7 @@ the current iteration. It will only be instantiated in the form of a loop nest o
 Freeze and Typification
 ^^^^^^^^^^^^^^^^^^^^^^^
 
-The transformation of the SymPy-expressions to the backend's pymbolic-based expression trees is handled by
-`FreezeExpressions`.
+The transformation of the SymPy-expressions to the backend's expression trees is handled by `FreezeExpressions`.
 This class instantiates field accesses according to the iteration space, maps SymPy operators and functions to their
 backend instances if supported, and raises an exception if asked to translate something the backend can't handle.
 
diff --git a/src/pystencils/backend/kernelcreation/context.py b/src/pystencils/backend/kernelcreation/context.py
index 115a50b9a..3bde2a135 100644
--- a/src/pystencils/backend/kernelcreation/context.py
+++ b/src/pystencils/backend/kernelcreation/context.py
@@ -1,16 +1,19 @@
 from __future__ import annotations
 
+from itertools import chain
 from types import EllipsisType
 
+from ...defaults import DEFAULTS
 from ...field import Field, FieldType
 from ...sympyextensions.typed_sympy import TypedSymbol, BasicType, StructType
+
+from ..symbols import PsSymbol
 from ..arrays import PsLinearizedArray
-from ..types import PsIntegerType, PsNumericType
+from ..types import PsAbstractType, PsIntegerType, PsNumericType
 from ..types.quick import make_type
-from ..constraints import PsKernelConstraint
+from ..constraints import PsKernelParamsConstraint
 from ..exceptions import PsInternalCompilerError, KernelConstraintsError
 
-from .defaults import Pymbolic as PbDefaults
 from .iteration_space import IterationSpace, FullIterationSpace, SparseIterationSpace
 
 
@@ -47,12 +50,14 @@ class KernelCreationContext:
 
     def __init__(
         self,
-        default_dtype: PsNumericType = PbDefaults.numeric_dtype,
-        index_dtype: PsIntegerType = PbDefaults.index_dtype,
+        default_dtype: PsNumericType = DEFAULTS.numeric_dtype,
+        index_dtype: PsIntegerType = DEFAULTS.index_dtype,
     ):
         self._default_dtype = default_dtype
         self._index_dtype = index_dtype
-        self._constraints: list[PsKernelConstraint] = []
+        self._constraints: list[PsKernelParamsConstraint] = []
+
+        self._symbols: dict[str, PsSymbol] = dict()
 
         self._field_arrays: dict[Field, PsLinearizedArray] = dict()
         self._fields_collection = FieldsInKernel()
@@ -67,13 +72,43 @@ class KernelCreationContext:
     def index_dtype(self) -> PsIntegerType:
         return self._index_dtype
 
-    def add_constraints(self, *constraints: PsKernelConstraint):
+    def add_constraints(self, *constraints: PsKernelParamsConstraint):
         self._constraints += constraints
 
     @property
-    def constraints(self) -> tuple[PsKernelConstraint, ...]:
+    def constraints(self) -> tuple[PsKernelParamsConstraint, ...]:
         return tuple(self._constraints)
 
+    #   Symbols
+    def get_symbol(self, name: str, dtype: PsAbstractType | None = None) -> PsSymbol:
+        if name not in self._symbols:
+            symb = PsSymbol(name, None)
+            self._symbols[name] = symb
+        else:
+            symb = self._symbols[name]
+
+        if dtype is not None:
+            symb.apply_dtype(dtype)
+
+        return symb
+
+    def add_symbol(self, symbol: PsSymbol):
+        if symbol.name in self._symbols:
+            raise PsInternalCompilerError(f"Duplicate symbol: {symbol.name}")
+
+        self._symbols[symbol.name] = symbol
+
+    def replace_symbol(self, old: PsSymbol, new: PsSymbol):
+        if old.name != new.name:
+            raise PsInternalCompilerError(
+                "replace_symbol: Old and new symbol must have the same name"
+            )
+
+        if old.name not in self._symbols:
+            raise PsInternalCompilerError("Trying to replace an unknown symbol")
+
+        self._symbols[old.name] = new
+
     #   Fields and Arrays
 
     @property
@@ -175,6 +210,9 @@ class KernelCreationContext:
         )
 
         self._field_arrays[field] = arr
+        for symb in chain([arr.base_pointer], arr.shape, arr.strides):
+            if isinstance(symb, PsSymbol):
+                self.add_symbol(symb)
 
     def get_array(self, field: Field) -> PsLinearizedArray:
         """Retrieve the underlying array for a given field.
diff --git a/src/pystencils/backend/kernelcreation/freeze.py b/src/pystencils/backend/kernelcreation/freeze.py
index bdcf5120a..b69c6c2dd 100644
--- a/src/pystencils/backend/kernelcreation/freeze.py
+++ b/src/pystencils/backend/kernelcreation/freeze.py
@@ -1,8 +1,8 @@
 from typing import overload, cast
+from functools import reduce
+from operator import add, mul
 
 import sympy as sp
-import pymbolic.primitives as pb
-from pymbolic.interop.sympy import SympyToPymbolicMapper
 
 from ...sympyextensions import Assignment, AssignmentCollection
 from ...sympyextensions.typed_sympy import BasicType
@@ -10,17 +10,24 @@ from ...field import Field, FieldType
 
 from .context import KernelCreationContext
 
-from ..ast.nodes import (
+from ..ast.structural import (
+    PsAstNode,
     PsBlock,
     PsAssignment,
     PsDeclaration,
-    PsSymbolExpr,
-    PsLvalueExpr,
     PsExpression,
+    PsSymbolExpr,
+)
+from ..ast.expressions import (
+    PsArrayAccess,
+    PsVectorArrayAccess,
+    PsLookup,
+    PsCall,
+    PsConstantExpr,
 )
-from ..types import constify, make_type, PsStructType
-from ..typed_expressions import PsTypedVariable
-from ..arrays import PsArrayAccess, PsVectorArrayAccess
+
+from ..constants import PsConstant
+from ..types import constify, make_type, PsAbstractType, PsStructType
 from ..exceptions import PsInputError
 from ..functions import PsMathFunction, MathFunctions
 
@@ -29,47 +36,89 @@ class FreezeError(Exception):
     """Signifies an error during expression freezing."""
 
 
-class FreezeExpressions(SympyToPymbolicMapper):
+class FreezeExpressions:
     def __init__(self, ctx: KernelCreationContext):
         self._ctx = ctx
 
     @overload
-    def __call__(self, asms: AssignmentCollection) -> PsBlock:
+    def __call__(self, obj: AssignmentCollection) -> PsBlock:
         pass
 
     @overload
-    def __call__(self, expr: sp.Expr) -> PsExpression:
+    def __call__(self, obj: sp.Expr) -> PsExpression:
         pass
 
     @overload
-    def __call__(self, asm: Assignment) -> PsAssignment:
+    def __call__(self, obj: Assignment) -> PsAssignment:
         pass
 
-    def __call__(self, obj):
+    def __call__(self, obj: AssignmentCollection | sp.Basic) -> PsAstNode:
         if isinstance(obj, AssignmentCollection):
-            return PsBlock([self.rec(asm) for asm in obj.all_assignments])
+            return PsBlock([self.visit(asm) for asm in obj.all_assignments])
         elif isinstance(obj, Assignment):
-            return cast(PsAssignment, self.rec(obj))
+            return cast(PsAssignment, self.visit(obj))
         elif isinstance(obj, sp.Expr):
-            return PsExpression(cast(pb.Expression, self.rec(obj)))
+            return cast(PsExpression, self.visit(obj))
         else:
             raise PsInputError(f"Don't know how to freeze {obj}")
 
-    def freeze_expression(self, expr: sp.Basic) -> pb.Expression:
-        return self.rec(expr)
+    def visit(self, node: sp.Basic) -> PsAstNode:
+        mro = list(type(node).__mro__)
+
+        while mro:
+            method_name = "map_" + mro.pop(0).__name__
+
+            try:
+                method = self.__getattribute__(method_name)
+            except AttributeError:
+                pass
+            else:
+                return method(node)
+
+        raise FreezeError(f"Don't know how to freeze expression {node}")
+
+    def visit_expr(self, expr: sp.Basic):
+        if not isinstance(expr, sp.Expr):
+            raise FreezeError(f"Cannot freeze {expr} to an expression")
+        return cast(PsExpression, self.visit(expr))
+
+    def freeze_expression(self, expr: sp.Expr) -> PsExpression:
+        return cast(PsExpression, self.visit(expr))
 
     def map_Assignment(self, expr: Assignment):  # noqa
-        lhs = self.rec(expr.lhs)
-        rhs = self.rec(expr.rhs)
+        lhs = self.visit(expr.lhs)
+        rhs = self.visit(expr.rhs)
 
-        if isinstance(lhs, pb.Variable):
-            return PsDeclaration(PsSymbolExpr(lhs), PsExpression(rhs))
+        assert isinstance(lhs, PsExpression)
+        assert isinstance(rhs, PsExpression)
+
+        if isinstance(lhs, PsSymbolExpr):
+            return PsDeclaration(lhs, rhs)
         elif isinstance(lhs, (PsArrayAccess, PsVectorArrayAccess)):  # todo
-            return PsAssignment(PsLvalueExpr(lhs), PsExpression(rhs))
+            return PsAssignment(lhs, rhs)
         else:
             assert False, "That should not have happened."
 
-    def map_BasicType(self, expr: BasicType):
+    def map_Symbol(self, spsym: sp.Symbol) -> PsSymbolExpr:
+        symb = self._ctx.get_symbol(spsym.name)
+        return PsSymbolExpr(symb)
+
+    def map_Add(self, expr: sp.Add) -> PsExpression:
+        return reduce(add, (self.visit_expr(arg) for arg in expr.args))
+
+    def map_Mul(self, expr: sp.Mul) -> PsExpression:
+        return reduce(mul, (self.visit_expr(arg) for arg in expr.args))
+
+    def map_Integer(self, expr: sp.Integer) -> PsConstantExpr:
+        value = int(expr)
+        return PsConstantExpr(PsConstant(value))
+
+    def map_Rational(self, expr: sp.Rational) -> PsExpression:
+        num = PsConstantExpr(PsConstant(expr.numerator))
+        denom = PsConstantExpr(PsConstant(expr.denominator))
+        return num / denom
+
+    def map_type(self, expr: BasicType) -> PsAbstractType:
         #   TODO: This should not be necessary; the frontend should use the new type system.
         dtype = make_type(expr.numpy_dtype.type)
         if expr.const:
@@ -77,27 +126,25 @@ class FreezeExpressions(SympyToPymbolicMapper):
         else:
             return dtype
 
-    def map_FieldShapeSymbol(self, expr):
-        dtype = self.rec(expr.dtype)
-        return PsTypedVariable(expr.name, dtype)
-
     def map_TypedSymbol(self, expr):
-        dtype = self.rec(expr.dtype)
-        return PsTypedVariable(expr.name, dtype)
+        dtype = self.map_type(expr.dtype)
+        symb = self._ctx.get_symbol(expr.name, dtype)
+        return PsSymbolExpr(symb)
 
     def map_Access(self, access: Field.Access):
         field = access.field
         array = self._ctx.get_array(field)
         ptr = array.base_pointer
 
-        offsets: list[pb.Expression] = [self.rec(o) for o in access.offsets]
+        offsets: list[PsExpression] = [self.visit_expr(o) for o in access.offsets]
+        indices: list[PsExpression]
 
         if not access.is_absolute_access:
             match field.field_type:
                 case FieldType.GENERIC:
                     #   Add the iteration counters
                     offsets = [
-                        i + o
+                        PsExpression.make(i) + o
                         for i, o in zip(
                             self._ctx.get_iteration_space().spatial_indices, offsets
                         )
@@ -106,7 +153,9 @@ class FreezeExpressions(SympyToPymbolicMapper):
                     sparse_ispace = self._ctx.get_sparse_iteration_space()
                     #   Add sparse iteration counter to offset
                     assert len(offsets) == 1  # must have been checked by the context
-                    offsets = [offsets[0] + sparse_ispace.sparse_counter]
+                    offsets = [
+                        offsets[0] + PsExpression.make(sparse_ispace.sparse_counter)
+                    ]
                 case FieldType.BUFFER:
                     ispace = self._ctx.get_full_iteration_space()
                     compressed_ctr = ispace.compressed_counter()
@@ -123,35 +172,35 @@ class FreezeExpressions(SympyToPymbolicMapper):
         if isinstance(array.element_type, PsStructType):
             if isinstance(access.index, str):
                 struct_member_name = access.index
-                indices = [0]
+                indices = [PsExpression.make(PsConstant(0))]
             elif len(access.index) == 1 and isinstance(access.index[0], str):
                 struct_member_name = access.index[0]
-                indices = [0]
+                indices = [PsExpression.make(PsConstant(0))]
             else:
                 raise FreezeError(
                     f"Unsupported access into field with struct-type elements: {access}"
                 )
         else:
             struct_member_name = None
-            indices = [self.rec(i) for i in access.index]
+            indices = [self.visit_expr(i) for i in access.index]
             if not indices:
                 # For canonical representation, there must always be at least one index dimension
-                indices = [0]
+                indices = [PsExpression.make(PsConstant(0))]
 
         summands = tuple(
-            idx * stride
+            idx * PsExpression.make(stride)
             for idx, stride in zip(offsets + indices, array.strides, strict=True)
         )
 
-        index = summands[0] if len(summands) == 1 else pb.Sum(summands)
+        index = summands[0] if len(summands) == 1 else reduce(add, summands)
 
         if struct_member_name is not None:
-            # Produce a pb.Lookup here, don't check yet if the member name is valid. That's the typifier's job.
-            return pb.Lookup(PsArrayAccess(ptr, index), struct_member_name)
+            # Produce a Lookup here, don't check yet if the member name is valid. That's the typifier's job.
+            return PsLookup(PsArrayAccess(ptr, index), struct_member_name)
         else:
             return PsArrayAccess(ptr, index)
 
-    def map_Function(self, func: sp.Function) -> pb.Call:
+    def map_Function(self, func: sp.Function) -> PsCall:
         """Map SymPy function calls by mapping sympy function classes to backend-supported function symbols.
 
         SymPy functions are frozen to an instance of `nbackend.functions.PsFunction`.
@@ -174,5 +223,5 @@ class FreezeExpressions(SympyToPymbolicMapper):
             case _:
                 raise FreezeError(f"Unsupported function: {func}")
 
-        args = tuple(self.rec(arg) for arg in func.args)
-        return pb.Call(func_symbol, args)
+        args = tuple(self.visit_expr(arg) for arg in func.args)
+        return PsCall(func_symbol, args)
diff --git a/src/pystencils/backend/kernelcreation/iteration_space.py b/src/pystencils/backend/kernelcreation/iteration_space.py
index 63593c548..7d99b47ca 100644
--- a/src/pystencils/backend/kernelcreation/iteration_space.py
+++ b/src/pystencils/backend/kernelcreation/iteration_space.py
@@ -7,19 +7,16 @@ from operator import mul
 
 import sympy as sp
 
+from ...defaults import DEFAULTS
 from ...sympyextensions import AssignmentCollection
 from ...field import Field, FieldType
 
-from ..typed_expressions import (
-    PsTypedVariable,
-    VarOrConstant,
-    ExprOrConstant,
-    PsTypedConstant,
-)
+from ..symbols import PsSymbol
+from ..constants import PsConstant
+from ..ast.expressions import PsExpression, PsConstantExpr
 from ..arrays import PsLinearizedArray
 from ..ast.util import failing_cast
 from ..types import PsStructType, constify
-from .defaults import Pymbolic as Defaults
 from ..exceptions import PsInputError, KernelConstraintsError
 
 if TYPE_CHECKING:
@@ -40,14 +37,14 @@ class IterationSpace(ABC):
        spatial indices.
     """
 
-    def __init__(self, spatial_indices: Sequence[PsTypedVariable]):
+    def __init__(self, spatial_indices: Sequence[PsSymbol]):
         if len(spatial_indices) == 0:
             raise ValueError("Iteration space must be at least one-dimensional.")
 
         self._spatial_indices = tuple(spatial_indices)
 
     @property
-    def spatial_indices(self) -> tuple[PsTypedVariable, ...]:
+    def spatial_indices(self) -> tuple[PsSymbol, ...]:
         return self._spatial_indices
 
     @property
@@ -66,10 +63,10 @@ class FullIterationSpace(IterationSpace):
 
     @dataclass
     class Dimension:
-        start: ExprOrConstant
-        stop: ExprOrConstant
-        step: ExprOrConstant
-        counter: PsTypedVariable
+        start: PsExpression
+        stop: PsExpression
+        step: PsExpression
+        counter: PsSymbol
 
     @staticmethod
     def create_with_ghost_layers(
@@ -83,8 +80,8 @@ class FullIterationSpace(IterationSpace):
         dim = archetype_field.spatial_dimensions
 
         counters = [
-            PsTypedVariable(name, ctx.index_dtype)
-            for name in Defaults.spatial_counter_names[:dim]
+            ctx.get_symbol(name, ctx.index_dtype)
+            for name in DEFAULTS.spatial_counter_names[:dim]
         ]
 
         if isinstance(ghost_layers, int):
@@ -96,12 +93,12 @@ class FullIterationSpace(IterationSpace):
                 ((gl, gl) if isinstance(gl, int) else gl) for gl in ghost_layers
             ]
 
-        one = PsTypedConstant(1, ctx.index_dtype)
+        one = PsConstantExpr(PsConstant(1, ctx.index_dtype))
 
         ghost_layer_exprs = [
             (
-                PsTypedConstant(gl_left, ctx.index_dtype),
-                PsTypedConstant(gl_right, ctx.index_dtype),
+                PsConstantExpr(PsConstant(gl_left, ctx.index_dtype)),
+                PsConstantExpr(PsConstant(gl_right, ctx.index_dtype)),
             )
             for (gl_left, gl_right) in ghost_layers_spec
         ]
@@ -109,7 +106,9 @@ class FullIterationSpace(IterationSpace):
         spatial_shape = archetype_array.shape[:dim]
 
         dimensions = [
-            FullIterationSpace.Dimension(gl_left, shape - gl_right, one, ctr)
+            FullIterationSpace.Dimension(
+                gl_left, PsExpression.make(shape) - gl_right, one, ctr
+            )
             for (gl_left, gl_right), shape, ctr in zip(
                 ghost_layer_exprs, spatial_shape, counters, strict=True
             )
@@ -137,8 +136,8 @@ class FullIterationSpace(IterationSpace):
             )
 
         counters = [
-            PsTypedVariable(name, ctx.index_dtype)
-            for name in Defaults.spatial_counter_names[:dim]
+            ctx.get_symbol(name, ctx.index_dtype)
+            for name in DEFAULTS.spatial_counter_names[:dim]
         ]
 
         from .freeze import FreezeExpressions
@@ -147,9 +146,9 @@ class FullIterationSpace(IterationSpace):
         freeze = FreezeExpressions(ctx)
         typifier = Typifier(ctx)
 
-        def to_pb(expr):
+        def expr_convert(expr) -> PsExpression:
             if isinstance(expr, int):
-                return PsTypedConstant(expr, ctx.index_dtype)
+                return PsConstantExpr(PsConstant(expr, ctx.index_dtype))
             elif isinstance(expr, sp.Expr):
                 return typifier.typify_expression(
                     freeze.freeze_expression(expr), ctx.index_dtype
@@ -157,13 +156,15 @@ class FullIterationSpace(IterationSpace):
             else:
                 raise ValueError(f"Invalid entry in slice: {expr}")
 
-        def to_dim(slic: slice, size: VarOrConstant, ctr: PsTypedVariable):
-            start = to_pb(slic.start if slic.start is not None else 0)
-            stop = to_pb(slic.stop) if slic.stop is not None else size
-            step = to_pb(slic.step if slic.step is not None else 1)
+        def to_dim(slic: slice, size: PsSymbol | PsConstant, ctr: PsSymbol):
+            size_expr = PsExpression.make(size)
+
+            start = expr_convert(slic.start if slic.start is not None else 0)
+            stop = expr_convert(slic.stop) if slic.stop is not None else size_expr
+            step = expr_convert(slic.step if slic.step is not None else 1)
 
             if isinstance(slic.stop, int) and slic.stop < 0:
-                stop = size + stop
+                stop = size_expr + stop  # todo
 
             return FullIterationSpace.Dimension(start, stop, step, ctr)
 
@@ -202,23 +203,24 @@ class FullIterationSpace(IterationSpace):
     def steps(self):
         return (dim.step for dim in self._dimensions)
 
-    def actual_iterations(self, dimension: int | None = None) -> ExprOrConstant:
+    def actual_iterations(self, dimension: int | None = None) -> PsExpression:
         if dimension is None:
             return reduce(
                 mul, (self.actual_iterations(d) for d in range(len(self.dimensions)))
             )
         else:
             dim = self.dimensions[dimension]
-            one = PsTypedConstant(1, self._ctx.index_dtype)
+            one = PsConstantExpr(PsConstant(1, self._ctx.index_dtype))
             return one + (dim.stop - dim.start - one) / dim.step
 
-    def compressed_counter(self) -> ExprOrConstant:
+    def compressed_counter(self) -> PsExpression:
         """Expression counting the actual number of items processed at the iteration defined by the counter tuple.
 
         Used primarily for indexing buffers."""
         actual_iters = [self.actual_iterations(d) for d in range(self.dim)]
         compressed_counters = [
-            (dim.counter - dim.start) / dim.step for dim in self.dimensions
+            (PsExpression.make(dim.counter) - dim.start) / dim.step
+            for dim in self.dimensions
         ]
         compressed_idx = compressed_counters[0]
         for ctr, iters in zip(compressed_counters[1:], actual_iters[1:]):
@@ -229,10 +231,10 @@ class FullIterationSpace(IterationSpace):
 class SparseIterationSpace(IterationSpace):
     def __init__(
         self,
-        spatial_indices: Sequence[PsTypedVariable],
+        spatial_indices: Sequence[PsSymbol],
         index_list: PsLinearizedArray,
         coordinate_members: Sequence[PsStructType.Member],
-        sparse_counter: PsTypedVariable,
+        sparse_counter: PsSymbol,
     ):
         super().__init__(spatial_indices)
         self._index_list = index_list
@@ -248,7 +250,7 @@ class SparseIterationSpace(IterationSpace):
         return self._coord_members
 
     @property
-    def sparse_counter(self) -> PsTypedVariable:
+    def sparse_counter(self) -> PsSymbol:
         return self._sparse_counter
 
 
@@ -303,7 +305,7 @@ def create_sparse_iteration_space(
     dim = archetype_field.spatial_dimensions
     coord_members = [
         PsStructType.Member(name, ctx.index_dtype)
-        for name in Defaults._index_struct_coordinate_names[:dim]
+        for name in DEFAULTS._index_struct_coordinate_names[:dim]
     ]
 
     #   Determine index field
@@ -323,11 +325,11 @@ def create_sparse_iteration_space(
         )
 
     spatial_counters = [
-        PsTypedVariable(name, constify(ctx.index_dtype))
-        for name in Defaults.spatial_counter_names[:dim]
+        ctx.get_symbol(name, constify(ctx.index_dtype))
+        for name in DEFAULTS.spatial_counter_names[:dim]
     ]
 
-    sparse_counter = PsTypedVariable(Defaults.sparse_counter_name, ctx.index_dtype)
+    sparse_counter = ctx.get_symbol(DEFAULTS.sparse_counter_name, ctx.index_dtype)
 
     return SparseIterationSpace(
         spatial_counters, idx_arr, coord_members, sparse_counter
diff --git a/src/pystencils/backend/kernelcreation/typification.py b/src/pystencils/backend/kernelcreation/typification.py
index d251f0acd..9e49ef758 100644
--- a/src/pystencils/backend/kernelcreation/typification.py
+++ b/src/pystencils/backend/kernelcreation/typification.py
@@ -1,15 +1,18 @@
 from __future__ import annotations
 
-from typing import TypeVar, Any
-
-import pymbolic.primitives as pb
-from pymbolic.mapper import Mapper
+from typing import TypeVar
 
 from .context import KernelCreationContext
 from ..types import PsAbstractType, PsNumericType, PsStructType, deconstify
-from ..typed_expressions import PsTypedVariable, PsTypedConstant, ExprOrConstant
-from ..arrays import PsArrayAccess, PsVectorArrayAccess
-from ..ast import PsAstNode, PsBlock, PsExpression, PsAssignment
+from ..ast.structural import PsAstNode, PsBlock, PsLoop, PsExpression, PsAssignment
+from ..ast.expressions import (
+    PsSymbolExpr,
+    PsConstantExpr,
+    PsBinOp,
+    PsArrayAccess,
+    PsLookup,
+    PsCall,
+)
 from ..functions import PsMathFunction
 
 __all__ = ["Typifier"]
@@ -22,55 +25,55 @@ class TypificationError(Exception):
 NodeT = TypeVar("NodeT", bound=PsAstNode)
 
 
-class DeferredTypedConstant(PsTypedConstant):
-    """Special subclass for constants whose types cannot be determined yet at the time of their creation.
-
-    Outside of the typifier, a DeferredTypedConstant acts exactly the same way as a PsTypedConstant.
-    """
-
-    def __init__(self, value: Any):
-        self._value_deferred = value
-
-    def resolve(self, dtype: PsNumericType):
-        super().__init__(self._value_deferred, dtype)
-
-
 class TypeContext:
-    def __init__(self, target_type: PsAbstractType | None):
+    def __init__(self, target_type: PsAbstractType | None = None):
         self._target_type = deconstify(target_type) if target_type is not None else None
-        self._deferred_constants: list[DeferredTypedConstant] = []
+        self._deferred_constants: list[PsConstantExpr] = []
 
-    def make_constant(self, value: Any) -> PsTypedConstant:
+    def typify_constant(self, constexpr: PsConstantExpr) -> None:
         if self._target_type is None:
-            dc = DeferredTypedConstant(value)
-            self._deferred_constants.append(dc)
-            return dc
+            self._deferred_constants.append(constexpr)
         elif not isinstance(self._target_type, PsNumericType):
             raise TypificationError(
                 f"Can't typify constant with non-numeric type {self._target_type}"
             )
         else:
-            return PsTypedConstant(value, self._target_type)
+            constexpr.constant.apply_dtype(self._target_type)
 
-    def apply(self, target_type: PsAbstractType):
-        assert self._target_type is None, "Type context was already resolved"
-        self._target_type = deconstify(target_type)
+    def apply_and_check(self, expr: PsExpression, expr_type: PsAbstractType):
+        """
+        If no target type has been set yet, establishes expr_type as the target type
+        and typifies all deferred expressions.
 
-        for dc in self._deferred_constants:
-            if not isinstance(self._target_type, PsNumericType):
-                raise TypificationError(
-                    f"Can't typify constant with non-numeric type {self._target_type}"
-                )
-            dc.resolve(self._target_type)
+        Otherwise, checks if expression type and target type are compatible.
+        """
+        expr_type = deconstify(expr_type)
 
-        self._deferred_constants = []
+        if self._target_type is None:
+            self._target_type = expr_type
+
+            for dc in self._deferred_constants:
+                if not isinstance(self._target_type, PsNumericType):
+                    raise TypificationError(
+                        f"Can't typify constant with non-numeric type {self._target_type}"
+                    )
+                dc.constant.apply_dtype(self._target_type)
+
+            self._deferred_constants = []
+
+        elif expr_type != self._target_type:
+            raise TypificationError(
+                f"Type mismatch at expression {expr}: Expression type did not match the context's target type\n"
+                f"  Expression type: {expr_type}\n"
+                f"      Target type: {self._target_type}"
+            )
 
     @property
     def target_type(self) -> PsAbstractType | None:
         return self._target_type
 
 
-class Typifier(Mapper):
+class Typifier:
     """Typifier for untyped expressions.
 
     The typifier, when called with an AST node, will attempt to figure out
@@ -117,28 +120,19 @@ class Typifier(Mapper):
         self._ctx = ctx
 
     def __call__(self, node: NodeT) -> NodeT:
-        match node:
-            case PsBlock([*statements]):
-                node.statements = [self(s) for s in statements]
-
-            case PsExpression(expr):
-                node.expression = self.rec(expr, TypeContext(None))
-
-            case PsAssignment(lhs, rhs):
-                tc = TypeContext(None)
-                #   LHS defines target type; type context carries it to RHS
-                new_lhs = self.rec(lhs.expression, tc)
-                assert tc.target_type is not None
-                new_rhs = self.rec(rhs.expression, tc)
-
-                node.lhs.expression = new_lhs
-                node.rhs.expression = new_rhs
-
-            case unknown:
-                raise NotImplementedError(f"Don't know how to typify {unknown}")
-
+        if isinstance(node, PsExpression):
+            self.visit_expr(node, TypeContext())
+        else:
+            self.visit(node)
         return node
 
+    def typify_expression(
+        self, expr: PsExpression, target_type: PsNumericType | None = None
+    ) -> PsExpression:
+        tc = TypeContext(target_type)
+        self.visit_expr(expr, tc)
+        return expr
+
     """
     def rec(self, expr: Any, tc: TypeContext) -> ExprOrConstant
 
@@ -146,96 +140,84 @@ class Typifier(Mapper):
     They shall return the typified expression, or throw `TypificationError` if typification fails.
     """
 
-    def typify_expression(
-        self, expr: Any, target_type: PsNumericType | None = None
-    ) -> ExprOrConstant:
-        tc = TypeContext(target_type)
-        return self.rec(expr, tc)
-
-    #   Leaf nodes: Variables, Typed Variables, Constants and TypedConstants
-
-    def map_typed_variable(self, var: PsTypedVariable, tc: TypeContext):
-        self._apply_target_type(var, var.dtype, tc)
-        return var
-
-    def map_variable(self, var: pb.Variable, tc: TypeContext) -> PsTypedVariable:
-        dtype = self._ctx.default_dtype
-        typed_var = PsTypedVariable(var.name, dtype)
-        self._apply_target_type(typed_var, dtype, tc)
-        return typed_var
-
-    def map_constant(self, value: Any, tc: TypeContext) -> PsTypedConstant:
-        if isinstance(value, PsTypedConstant):
-            self._apply_target_type(value, value.dtype, tc)
-            return value
-
-        return tc.make_constant(value)
-
-    #   Array Accesses and Lookups
-
-    def map_array_access(self, access: PsArrayAccess, tc: TypeContext) -> PsArrayAccess:
-        self._apply_target_type(access, access.dtype, tc)
-        index = self.rec(access.index_tuple[0], TypeContext(self._ctx.index_dtype))
-        return PsArrayAccess(access.base_ptr, index)
-
-    def map_vector_array_access(
-        self, access: PsVectorArrayAccess, tc: TypeContext
-    ) -> PsVectorArrayAccess:
-        self._apply_target_type(access, access.dtype, tc)
-        base_index = self.rec(access.base_index, TypeContext(self._ctx.index_dtype))
-        return PsVectorArrayAccess(
-            access.base_ptr, base_index, access.dtype.vector_entries, access.stride
-        )
-
-    def map_lookup(self, lookup: pb.Lookup, tc: TypeContext) -> pb.Lookup:
-        aggr_tc = TypeContext(None)
-        aggregate = self.rec(lookup.aggregate, aggr_tc)
-        aggr_type = aggr_tc.target_type
-
-        if not isinstance(aggr_type, PsStructType):
-            raise TypificationError("Aggregate type of lookup was not a struct type.")
-
-        member = aggr_type.get_member(lookup.name)
-        if member is None:
-            raise TypificationError(
-                f"Aggregate of type {aggr_type} does not have a member {member}."
-            )
-
-        self._apply_target_type(lookup, member.dtype, tc)
-        return pb.Lookup(aggregate, member.name)
-
-    #   Arithmetic Expressions
+    def visit(self, node: PsAstNode) -> None:
+        """Recursive processing of structural nodes"""
+        match node:
+            case PsBlock([*statements]):
+                for s in statements:
+                    self.visit(s)
 
-    def map_sum(self, expr: pb.Sum, tc: TypeContext) -> pb.Sum:
-        return pb.Sum(tuple(self.rec(c, tc) for c in expr.children))
+            case PsAssignment(lhs, rhs):
+                tc = TypeContext()
+                #   LHS defines target type; type context carries it to RHS
+                self.visit_expr(lhs, tc)
+                assert tc.target_type is not None
+                self.visit_expr(rhs, tc)
 
-    def map_product(self, expr: pb.Product, tc: TypeContext) -> pb.Product:
-        return pb.Product(tuple(self.rec(c, tc) for c in expr.children))
+            case PsLoop(ctr, start, stop, step, body):
+                if ctr.symbol.dtype is None:
+                    ctr.symbol.apply_dtype(self._ctx.index_dtype)
 
-    def map_quotient(self, expr: pb.Quotient, tc: TypeContext) -> pb.Quotient:
-        return pb.Quotient(self.rec(expr.num, tc), self.rec(expr.den, tc))
+                tc = TypeContext(ctr.symbol.dtype)
+                self.visit_expr(start, tc)
+                self.visit_expr(stop, tc)
+                self.visit_expr(step, tc)
 
-    #   Functions
+                self.visit(body)
 
-    def map_call(self, expr: pb.Call, tc: TypeContext) -> pb.Call:
-        func = expr.function
-        args = expr.parameters
-        match func:
-            case PsMathFunction():
-                return pb.Call(func, tuple(self.rec(arg, tc) for arg in args))
             case _:
-                raise TypificationError(f"Don't know how to typify calls to {func}")
-
-    #   Internals
+                raise NotImplementedError(f"Can't typify {node}")
+
+    def visit_expr(self, expr: PsExpression, tc: TypeContext) -> None:
+        """Recursive processing of expression nodes"""
+        match expr:
+            case PsSymbolExpr(symb):
+                if symb.dtype is None:
+                    dtype = self._ctx.default_dtype
+                    symb.apply_dtype(dtype)
+                tc.apply_and_check(expr, symb.get_dtype())
+
+            case PsConstantExpr(constant):
+                if constant.dtype is not None:
+                    tc.apply_and_check(expr, constant.get_dtype())
+                else:
+                    tc.typify_constant(expr)
+
+            case PsArrayAccess(_, idx):
+                tc.apply_and_check(expr, expr.dtype)
+                self.visit_expr(idx, TypeContext(self._ctx.index_dtype))
+
+            case PsLookup(aggr, member_name):
+                aggr_tc = TypeContext(None)
+                self.visit_expr(aggr, aggr_tc)
+                aggr_type = aggr_tc.target_type
+
+                if not isinstance(aggr_type, PsStructType):
+                    raise TypificationError(
+                        "Aggregate type of lookup was not a struct type."
+                    )
+
+                member = aggr_type.get_member(member_name)
+                if member is None:
+                    raise TypificationError(
+                        f"Aggregate of type {aggr_type} does not have a member {member}."
+                    )
+
+                tc.apply_and_check(expr, member.dtype)
+
+            case PsBinOp(op1, op2):
+                self.visit_expr(op1, tc)
+                self.visit_expr(op2, tc)
+
+            case PsCall(function, args):
+                match function:
+                    case PsMathFunction():
+                        for arg in args:
+                            self.visit_expr(arg, tc)
+                    case _:
+                        raise TypificationError(
+                            f"Don't know how to typify calls to {function}"
+                        )
 
-    def _apply_target_type(
-        self, expr: ExprOrConstant, expr_type: PsAbstractType, tc: TypeContext
-    ):
-        if tc.target_type is None:
-            tc.apply(expr_type)
-        elif deconstify(expr_type) != tc.target_type:
-            raise TypificationError(
-                f"Type mismatch at expression {expr}: Expression type did not match the context's target type\n"
-                f"  Expression type: {expr_type}\n"
-                f"      Target type: {tc.target_type}"
-            )
+            case _:
+                raise NotImplementedError(f"Can't typify {expr}")
diff --git a/src/pystencils/backend/platforms/generic_cpu.py b/src/pystencils/backend/platforms/generic_cpu.py
index e9fe2ef14..b3a49cb65 100644
--- a/src/pystencils/backend/platforms/generic_cpu.py
+++ b/src/pystencils/backend/platforms/generic_cpu.py
@@ -1,8 +1,6 @@
 from typing import Sequence
 from abc import ABC, abstractmethod
 
-import pymbolic.primitives as pb
-
 from .platform import Platform
 
 from ..kernelcreation.iteration_space import (
@@ -11,10 +9,16 @@ from ..kernelcreation.iteration_space import (
     SparseIterationSpace,
 )
 
-from ..ast import PsDeclaration, PsSymbolExpr, PsExpression, PsLoop, PsBlock
+from ..constants import PsConstant
+from ..ast.structural import PsDeclaration, PsLoop, PsBlock
+from ..ast.expressions import (
+    PsSymbolExpr,
+    PsExpression,
+    PsArrayAccess,
+    PsVectorArrayAccess,
+    PsLookup,
+)
 from ..types import PsVectorType, PsCustomType
-from ..typed_expressions import PsTypedConstant
-from ..arrays import PsArrayAccess, PsVectorArrayAccess
 from ..transformations.vector_intrinsics import IntrinsicOps
 
 
@@ -48,9 +52,9 @@ class GenericCpu(Platform):
         for dimension in dimensions[::-1]:
             loop = PsLoop(
                 PsSymbolExpr(dimension.counter),
-                PsExpression(dimension.start),
-                PsExpression(dimension.stop),
-                PsExpression(dimension.step),
+                dimension.start,
+                dimension.stop,
+                dimension.step,
                 outer_block,
             )
             outer_block = PsBlock([loop])
@@ -61,10 +65,12 @@ class GenericCpu(Platform):
         mappings = [
             PsDeclaration(
                 PsSymbolExpr(ctr),
-                PsExpression(
+                PsLookup(
                     PsArrayAccess(
-                        ispace.index_list.base_pointer, ispace.sparse_counter
-                    ).a.__getattr__(coord.name)
+                        ispace.index_list.base_pointer,
+                        PsExpression.make(ispace.sparse_counter),
+                    ),
+                    coord.name,
                 ),
             )
             for ctr, coord in zip(ispace.spatial_indices, ispace.coordinate_members)
@@ -74,9 +80,9 @@ class GenericCpu(Platform):
 
         loop = PsLoop(
             PsSymbolExpr(ispace.sparse_counter),
-            PsExpression(PsTypedConstant(0, self._ctx.index_dtype)),
-            PsExpression(ispace.index_list.shape[0]),
-            PsExpression(PsTypedConstant(1, self._ctx.index_dtype)),
+            PsExpression.make(PsConstant(0, self._ctx.index_dtype)),
+            PsExpression.make(ispace.index_list.shape[0]),
+            PsExpression.make(PsConstant(1, self._ctx.index_dtype)),
             body,
         )
 
@@ -95,26 +101,24 @@ class GenericVectorCpu(GenericCpu, ABC):
         or raise an `IntrinsicsError` if type is not supported."""
 
     @abstractmethod
-    def constant_vector(self, c: PsTypedConstant) -> pb.Expression:
+    def constant_vector(self, c: PsConstant) -> PsExpression:
         """Return an expression that initializes a constant vector,
         or raise an `IntrinsicsError` if not supported."""
 
     @abstractmethod
     def op_intrinsic(
-        self, op: IntrinsicOps, vtype: PsVectorType, args: Sequence[pb.Expression]
-    ) -> pb.Expression:
+        self, op: IntrinsicOps, vtype: PsVectorType, args: Sequence[PsExpression]
+    ) -> PsExpression:
         """Return an expression intrinsically invoking the given operation
         on the given arguments with the given vector type,
         or raise an `IntrinsicsError` if not supported."""
 
     @abstractmethod
-    def vector_load(self, acc: PsVectorArrayAccess) -> pb.Expression:
+    def vector_load(self, acc: PsVectorArrayAccess) -> PsExpression:
         """Return an expression intrinsically performing a vector load,
         or raise an `IntrinsicsError` if not supported."""
 
     @abstractmethod
-    def vector_store(
-        self, acc: PsVectorArrayAccess, arg: pb.Expression
-    ) -> pb.Expression:
+    def vector_store(self, acc: PsVectorArrayAccess, arg: PsExpression) -> PsExpression:
         """Return an expression intrinsically performing a vector store,
         or raise an `IntrinsicsError` if not supported."""
diff --git a/src/pystencils/backend/platforms/platform.py b/src/pystencils/backend/platforms/platform.py
index d6bf54a57..7c6d3a2ee 100644
--- a/src/pystencils/backend/platforms/platform.py
+++ b/src/pystencils/backend/platforms/platform.py
@@ -1,6 +1,6 @@
 from abc import ABC, abstractmethod
 
-from ..ast import PsBlock
+from ..ast.structural import PsBlock
 
 from ..kernelcreation.context import KernelCreationContext
 from ..kernelcreation.iteration_space import IterationSpace
diff --git a/src/pystencils/backend/platforms/x86.py b/src/pystencils/backend/platforms/x86.py
index c4a9f0d4c..f0e42bccb 100644
--- a/src/pystencils/backend/platforms/x86.py
+++ b/src/pystencils/backend/platforms/x86.py
@@ -3,13 +3,10 @@ from enum import Enum
 from functools import cache
 from typing import Sequence
 
-from pymbolic.primitives import Expression
-
-from ..arrays import PsVectorArrayAccess
+from ..ast.expressions import PsExpression, PsVectorArrayAccess, PsAddressOf, PsSubscript
 from ..transformations.vector_intrinsics import IntrinsicOps
-from ..typed_expressions import PsTypedConstant
 from ..types import PsCustomType, PsVectorType
-from ..functions import address_of
+from ..constants import PsConstant
 
 from .generic_cpu import GenericVectorCpu, IntrinsicsError
 
@@ -118,7 +115,7 @@ class X86VectorCpu(GenericVectorCpu):
             )
         return PsCustomType(f"__m{vector_type.width}{suffix}")
 
-    def constant_vector(self, c: PsTypedConstant) -> Expression:
+    def constant_vector(self, c: PsConstant) -> PsExpression:
         vtype = c.dtype
         assert isinstance(vtype, PsVectorType)
 
@@ -130,22 +127,22 @@ class X86VectorCpu(GenericVectorCpu):
         return set_func(*values)
 
     def op_intrinsic(
-        self, op: IntrinsicOps, vtype: PsVectorType, args: Sequence[Expression]
-    ) -> Expression:
+        self, op: IntrinsicOps, vtype: PsVectorType, args: Sequence[PsExpression]
+    ) -> PsExpression:
         func = _x86_op_intrin(self._vector_arch, op, vtype)
         return func(*args)
 
-    def vector_load(self, acc: PsVectorArrayAccess) -> Expression:
+    def vector_load(self, acc: PsVectorArrayAccess) -> PsExpression:
         if acc.stride == 1:
             load_func = _x86_packed_load(self._vector_arch, acc.dtype, False)
-            return load_func(address_of(acc.base_ptr[acc.base_index]))
+            return load_func(PsAddressOf(PsSubscript(PsExpression.make(acc.base_ptr), acc.index)))
         else:
             raise NotImplementedError("Gather loads not implemented yet.")
 
-    def vector_store(self, acc: PsVectorArrayAccess, arg: Expression) -> Expression:
+    def vector_store(self, acc: PsVectorArrayAccess, arg: PsExpression) -> PsExpression:
         if acc.stride == 1:
             store_func = _x86_packed_store(self._vector_arch, acc.dtype, False)
-            return store_func(address_of(acc.base_ptr[acc.base_index]), arg)
+            return store_func(PsAddressOf(PsSubscript(PsExpression.make(acc.base_ptr), acc.index)), arg)
         else:
             raise NotImplementedError("Scatter stores not implemented yet.")
 
diff --git a/src/pystencils/backend/symbols.py b/src/pystencils/backend/symbols.py
new file mode 100644
index 000000000..d4ff5eafd
--- /dev/null
+++ b/src/pystencils/backend/symbols.py
@@ -0,0 +1,53 @@
+from .types import PsAbstractType, PsTypeError
+from .exceptions import PsInternalCompilerError
+
+
+class PsSymbol:
+    """A mutable symbol with name and data type.
+
+    Be advised to not create objects of this class directly unless you know what you are doing;
+    instead obtain them from a `KernelCreationContext` through `KernelCreationContext.get_symbol`.
+    This way, the context can keep track of all symbols used in the translation run,
+    and uniqueness of symbols is ensured.
+    """
+
+    __match_args__ = ("name", "dtype")
+
+    def __init__(self, name: str, dtype: PsAbstractType | None = None):
+        self._name = name
+        self._dtype = dtype
+
+    @property
+    def name(self) -> str:
+        return self._name
+
+    @property
+    def dtype(self) -> PsAbstractType | None:
+        return self._dtype
+
+    @dtype.setter
+    def dtype(self, value: PsAbstractType):
+        self._dtype = value
+
+    def apply_dtype(self, dtype: PsAbstractType):
+        """Apply the given data type to this symbol,
+        raising a TypeError if it conflicts with a previously set data type."""
+
+        if self._dtype is not None and self._dtype != dtype:
+            raise PsTypeError(
+                f"Incompatible symbol data types: {self._dtype} and {dtype}"
+            )
+
+        self._dtype = dtype
+
+    def get_dtype(self) -> PsAbstractType:
+        if self._dtype is None:
+            raise PsInternalCompilerError("Symbol had no type assigned yet")
+        return self._dtype
+
+    def __str__(self) -> str:
+        dtype_str = "<untyped>" if self._dtype is None else str(self._dtype)
+        return f"{self._name}: {dtype_str}"
+
+    def __repr__(self) -> str:
+        return str(self)
diff --git a/src/pystencils/backend/transformations/erase_anonymous_structs.py b/src/pystencils/backend/transformations/erase_anonymous_structs.py
index 97c50b7eb..8b039a1dc 100644
--- a/src/pystencils/backend/transformations/erase_anonymous_structs.py
+++ b/src/pystencils/backend/transformations/erase_anonymous_structs.py
@@ -1,22 +1,22 @@
 from __future__ import annotations
 
-from typing import TypeVar
-
-import pymbolic.primitives as pb
-from pymbolic.mapper import IdentityMapper
-
 from ..kernelcreation.context import KernelCreationContext
 
-from ..ast import PsAstNode, PsExpression
-from ..arrays import PsArrayAccess, TypeErasedBasePointer
-from ..typed_expressions import PsTypedConstant
+from ..constants import PsConstant
+from ..ast.structural import PsAstNode
+from ..ast.expressions import (
+    PsArrayAccess,
+    PsLookup,
+    PsExpression,
+    PsDeref,
+    PsAddressOf,
+    PsCast,
+)
+from ..arrays import PsArrayBasePointer, TypeErasedBasePointer
 from ..types import PsStructType, PsPointerType
-from ..functions import deref, address_of, Cast
 
-NodeT = TypeVar("NodeT", bound=PsAstNode)
 
-
-class EraseAnonymousStructTypes(IdentityMapper):
+class EraseAnonymousStructTypes:
     """Lower anonymous struct arrays to a byte-array representation.
 
     For arrays whose element type is an anonymous struct, the struct type is erased from the base pointer,
@@ -27,18 +27,29 @@ class EraseAnonymousStructTypes(IdentityMapper):
     def __init__(self, ctx: KernelCreationContext) -> None:
         self._ctx = ctx
 
-    def __call__(self, node: NodeT) -> NodeT:
+        self._substitutions: dict[PsArrayBasePointer, TypeErasedBasePointer] = dict()
+
+    def __call__(self, node: PsAstNode) -> PsAstNode:
+        self._substitutions = dict()
+
+        node = self.visit(node)
+
+        for old, new in self._substitutions.items():
+            self._ctx.replace_symbol(old, new)
+
+        return node
+
+    def visit(self, node: PsAstNode) -> PsAstNode:
         match node:
-            case PsExpression(expr):
+            case PsLookup():
                 # descend into expr
-                node.expression = self.rec(expr)
-            case other:
-                for c in other.children:
-                    self(c)
+                return self.handle_lookup(node)
+            case _:
+                node.children = [self.visit(c) for c in node.children]
 
         return node
 
-    def map_lookup(self, lookup: pb.Lookup) -> pb.Expression:
+    def handle_lookup(self, lookup: PsLookup) -> PsExpression:
         aggr = lookup.aggregate
         if not isinstance(aggr, PsArrayAccess):
             return lookup
@@ -54,12 +65,19 @@ class EraseAnonymousStructTypes(IdentityMapper):
         struct_size = struct_type.itemsize
 
         bp = aggr.base_ptr
-        type_erased_bp = TypeErasedBasePointer(bp.name, arr)
-        base_index = aggr.index_tuple[0] * PsTypedConstant(
-            struct_size, self._ctx.index_dtype
+
+        #   Need to keep track of base pointers already seen, since symbols must be unique
+        if bp not in self._substitutions:
+            type_erased_bp = TypeErasedBasePointer(bp.name, arr)
+            self._substitutions[bp] = type_erased_bp
+        else:
+            type_erased_bp = self._substitutions[bp]
+
+        base_index = aggr.index * PsExpression.make(
+            PsConstant(struct_size, self._ctx.index_dtype)
         )
 
-        member_name = lookup.name
+        member_name = lookup.member_name
         member = struct_type.get_member(member_name)
         assert member is not None
 
@@ -68,9 +86,11 @@ class EraseAnonymousStructTypes(IdentityMapper):
         assert np_struct.fields is not None
         member_offset = np_struct.fields[member_name][1]
 
-        byte_index = base_index + PsTypedConstant(member_offset, self._ctx.index_dtype)
+        byte_index = base_index + PsExpression.make(
+            PsConstant(member_offset, self._ctx.index_dtype)
+        )
         type_erased_access = PsArrayAccess(type_erased_bp, byte_index)
 
-        cast = Cast(PsPointerType(member.dtype))
-
-        return deref(cast(address_of(type_erased_access)))
+        return PsDeref(
+            PsCast(PsPointerType(member.dtype), PsAddressOf(type_erased_access))
+        )
diff --git a/src/pystencils/backend/transformations/vector_intrinsics.py b/src/pystencils/backend/transformations/vector_intrinsics.py
index 03c290150..d2f58d2c3 100644
--- a/src/pystencils/backend/transformations/vector_intrinsics.py
+++ b/src/pystencils/backend/transformations/vector_intrinsics.py
@@ -1,14 +1,20 @@
 from __future__ import annotations
-from typing import TypeVar, TYPE_CHECKING
+from typing import TypeVar, TYPE_CHECKING, cast
 from enum import Enum, auto
 
-import pymbolic.primitives as pb
-from pymbolic.mapper import IdentityMapper
-
-from ..ast import PsAstNode, PsExpression, PsAssignment, PsStatement
+from ..ast.structural import PsAstNode, PsAssignment, PsStatement
+from ..ast.expressions import PsExpression
 from ..types import PsVectorType, deconstify
-from ..typed_expressions import PsTypedVariable, PsTypedConstant, ExprOrConstant
-from ..arrays import PsVectorArrayAccess
+from ..ast.expressions import (
+    PsVectorArrayAccess,
+    PsSymbolExpr,
+    PsConstantExpr,
+    PsBinOp,
+    PsAdd,
+    PsSub,
+    PsMul,
+    PsDiv,
+)
 from ..exceptions import PsInternalCompilerError
 
 if TYPE_CHECKING:
@@ -33,7 +39,7 @@ class VectorizationError(Exception):
 
 
 class VecTypeCtx:
-    def __init__(self):
+    def __init__(self) -> None:
         self._dtype: None | PsVectorType = None
 
     def get(self) -> PsVectorType | None:
@@ -51,68 +57,77 @@ class VecTypeCtx:
         self._dtype = None
 
 
-class MaterializeVectorIntrinsics(IdentityMapper):
+class MaterializeVectorIntrinsics:
     def __init__(self, platform: GenericVectorCpu):
         self._platform = platform
 
     def __call__(self, node: PsAstNode) -> PsAstNode:
+        return self.visit(node)
+
+    def visit(self, node: PsAstNode) -> PsAstNode:
         match node:
-            case PsExpression(expr):
-                # descend into expr
-                node.expression = self.rec(expr, VecTypeCtx())
-                return node
-            case PsAssignment(lhs, rhs) if isinstance(
-                lhs.expression, PsVectorArrayAccess
-            ):
+            case PsAssignment(lhs, rhs) if isinstance(lhs, PsVectorArrayAccess):
                 vc = VecTypeCtx()
-                vc.set(lhs.expression.dtype)
-                store_arg = self.rec(rhs.expression, vc)
-                return PsStatement(
-                    PsExpression(self._platform.vector_store(lhs.expression, store_arg))
-                )
-            case other:
-                other.children = (self(c) for c in other.children)
-        return node
-
-    def map_typed_variable(
-        self, tv: PsTypedVariable, vc: VecTypeCtx
-    ) -> PsTypedVariable:
-        if isinstance(tv.dtype, PsVectorType):
-            intrin_type = self._platform.type_intrinsic(tv.dtype)
-            vc.set(tv.dtype)
-            return PsTypedVariable(tv.name, intrin_type)
-        else:
-            return tv
-
-    def map_constant(self, c: PsTypedConstant, vc: VecTypeCtx) -> ExprOrConstant:
-        if isinstance(c.dtype, PsVectorType):
-            vc.set(c.dtype)
-            return self._platform.constant_vector(c)
-        else:
-            return c
-
-    def map_vector_array_access(
-        self, acc: PsVectorArrayAccess, vc: VecTypeCtx
-    ) -> pb.Expression:
-        vc.set(acc.dtype)
-        return self._platform.vector_load(acc)
-
-    def map_sum(self, expr: pb.Sum, vc: VecTypeCtx) -> pb.Expression:
-        args = [self.rec(arg, vc) for arg in expr.children]
-        vtype = vc.get()
-        if vtype is not None:
-            if len(args) != 2:
-                raise VectorizationError("Cannot vectorize non-binary sums")
-            return self._platform.op_intrinsic(IntrinsicOps.ADD, vtype, args)
-        else:
-            return expr
-
-    def map_product(self, expr: pb.Product, vc: VecTypeCtx) -> pb.Expression:
-        args = [self.rec(arg, vc) for arg in expr.children]
-        vtype = vc.get()
-        if vtype is not None:
-            if len(args) != 2:
-                raise VectorizationError("Cannot vectorize non-binary products")
-            return self._platform.op_intrinsic(IntrinsicOps.MUL, vtype, args)
-        else:
-            return expr
+                vc.set(lhs.dtype)
+                store_arg = self.visit_expr(rhs, vc)
+                return PsStatement(self._platform.vector_store(lhs, store_arg))
+            case PsExpression():
+                return self.visit_expr(node, VecTypeCtx())
+            case _:
+                node.children = [self(c) for c in node.children]
+                return node
+
+    def visit_expr(self, expr: PsExpression, vc: VecTypeCtx) -> PsExpression:
+        match expr:
+            case PsSymbolExpr(symb):
+                if isinstance(symb.dtype, PsVectorType):
+                    intrin_type = self._platform.type_intrinsic(symb.dtype)
+                    vc.set(symb.dtype)
+                    symb.dtype = intrin_type
+
+                return expr
+
+            case PsConstantExpr(c):
+                if isinstance(c.dtype, PsVectorType):
+                    vc.set(c.dtype)
+                    return self._platform.constant_vector(c)
+                else:
+                    return expr
+
+            case PsVectorArrayAccess():
+                vc.set(expr.dtype)
+                return self._platform.vector_load(expr)
+
+            case PsBinOp(op1, op2):
+                op1 = self.visit_expr(op1, vc)
+                op2 = self.visit_expr(op2, vc)
+
+                vtype = vc.get()
+                if vtype is not None:
+                    return self._platform.op_intrinsic(
+                        _intrin_op(expr), vtype, [op1, op2]
+                    )
+                else:
+                    return expr
+
+            case expr:
+                expr.children = [
+                    self.visit_expr(cast(PsExpression, c), vc) for c in expr.children
+                ]
+                if vc.get() is not None:
+                    raise VectorizationError(f"Don't know how to vectorize {expr}")
+                return expr
+
+
+def _intrin_op(expr: PsBinOp) -> IntrinsicOps:
+    match expr:
+        case PsAdd():
+            return IntrinsicOps.ADD
+        case PsSub():
+            return IntrinsicOps.SUB
+        case PsMul():
+            return IntrinsicOps.MUL
+        case PsDiv():
+            return IntrinsicOps.DIV
+        case _:
+            assert False
diff --git a/src/pystencils/backend/typed_expressions.py b/src/pystencils/backend/typed_expressions.py
deleted file mode 100644
index 15e4278fc..000000000
--- a/src/pystencils/backend/typed_expressions.py
+++ /dev/null
@@ -1,244 +0,0 @@
-from __future__ import annotations
-
-from typing import TypeAlias, Any
-from sys import intern
-
-import pymbolic.primitives as pb
-
-from .types import (
-    PsAbstractType,
-    PsNumericType,
-    constify,
-    PsTypeError,
-)
-
-
-class PsTypedVariable(pb.Variable):
-    init_arg_names: tuple[str, ...] = ("name", "dtype")
-
-    __match_args__ = ("name", "dtype")
-    mapper_method = intern("map_typed_variable")
-
-    def __init__(self, name: str, dtype: PsAbstractType):
-        super(PsTypedVariable, self).__init__(name)
-        self._dtype = dtype
-
-    def __getinitargs__(self):
-        return self.name, self._dtype
-
-    @property
-    def dtype(self) -> PsAbstractType:
-        return self._dtype
-
-
-class PsTypedConstant:
-    """Represents typed constants occuring in the pystencils AST.
-
-    Internal Representation of Constants
-    ------------------------------------
-
-    Each `PsNumericType` acts as a factory for the code generator's internal representation of that type's
-    constants. The `PsTypedConstant` class embedds these into the expression trees.
-    Upon construction, this class's constructor attempts to interpret the given value in the given data type
-    by passing it to the data type's factory, which in turn may throw an exception if the value's type does
-    not match.
-
-    Operations and Constant Folding
-    -------------------------------
-
-    The `PsTypedConstant` class overrides the basic arithmetic operations for use during a constant folding pass.
-    Their implementations are very strict regarding types: No implicit conversions take place, and both operands
-    must always have the exact same type.
-    The only exception to this rule are the values ``0``, ``1``, and ``-1``, which are promoted to `PsTypedConstant`
-    (pymbolic injects those at times).
-
-    A Note On Divisions
-    -------------------
-
-    The semantics of divisions in C and Python differ greatly.
-    Python has two division operators: ``/`` (``truediv``) and ``//`` (``floordiv``).
-    `truediv` is pure floating-point division, and so applied to floating-point numbers maps exactly to
-    floating-point division in C, but not when applied to integers.
-    ``floordiv`` has no C equivalent:
-    While ``floordiv`` performs euclidean division and always rounds its result
-    downward (``3 // 2 == 1``, and ``-3 // 2 = -2``),
-    the C ``/`` operator on integers always rounds toward zero (in C, ``-3 / 2 = -1``.)
-
-    The same applies to the ``%`` operator:
-    In Python, ``%`` computes the euclidean modulus (e.g. ``-3 % 2 = 1``),
-    while in C, ``%`` computes the remainder (e.g. ``-3 % 2 = -1``).
-    These two differ whenever negative numbers are involved.
-
-    Pymbolic provides ``Quotient`` to model Python's ``/``, ``FloorDiv`` to model ``//``,
-    and ``Remainder`` to model ``%``. The last one is a misnomer: it should instead be called ``Modulus``.
-
-    Since the pystencils backend has to accurately capture the behaviour of C,
-    the behaviour of ``/`` is overridden in `PsTypedConstant`.
-    In a floating-point context, it behaves as usual, while in an integer context,
-    it implements C-style integer division.
-    Similarily, ``%`` is only legal in integer contexts, where it implements the C-style remainder.
-    Usage of ``//`` and the pymbolic ``FloorDiv`` is illegal.
-    """
-
-    __match_args__ = ("value", "dtype")
-
-    @staticmethod
-    def try_create(value: Any, dtype: PsNumericType):
-        try:
-            return PsTypedConstant(value, dtype)
-        except PsTypeError:
-            return None
-
-    def __init__(self, value: Any, dtype: PsNumericType):
-        """Create a new `PsTypedConstant`.
-
-        The constructor of `PsTypedConstant` will first convert the given `dtype` to its const version.
-        The given `value` will then be interpreted as that data type. The constructor will fail with an
-        exception if that is not possible.
-        """
-        if not isinstance(dtype, PsNumericType):
-            raise ValueError(f"Cannot create constant of type {dtype}")
-
-        self._dtype = constify(dtype)
-        self._value = self._dtype.create_constant(value)
-
-    @property
-    def value(self) -> Any:
-        return self._value
-
-    @property
-    def dtype(self) -> PsNumericType:
-        return self._dtype
-
-    def __str__(self) -> str:
-        return self._dtype.create_literal(self._value)
-
-    def __repr__(self) -> str:
-        return f"PsTypedConstant( {self._value}, {repr(self._dtype)} )"
-
-    def _fix(self, v: Any) -> PsTypedConstant:
-        """In binary operations, checks for type equality and, if necessary, promotes the values
-        ``0``, ``1`` and ``-1`` to `PsTypedConstant`."""
-        if not isinstance(v, PsTypedConstant):
-            return PsTypedConstant(v, self._dtype)
-        elif v._dtype != self._dtype:
-            raise PsTypeError(
-                f"Incompatible operand types in constant folding: {self._dtype} and {v._dtype}"
-            )
-        else:
-            return v
-
-    def _rfix(self, v: Any) -> PsTypedConstant:
-        """Same as `_fix`, but for use with the `r...` versions of the binary ops. Only changes the order of the
-        types in the exception string."""
-        if not isinstance(v, PsTypedConstant):
-            return PsTypedConstant(v, self._dtype)
-        elif v._dtype != self._dtype:
-            raise PsTypeError(
-                f"Incompatible operand types in constant folding: {v._dtype} and {self._dtype}"
-            )
-        else:
-            return v
-
-    def __add__(self, other: Any):
-        #   TODO: During freeze, expressions like `int + PsTypedConstant` can
-        #   occur. To cope with these, make the arithmetic operators of PsTypedConstant
-        #   purely symbolic? -> investigate
-        if isinstance(other, pb.Expression):  # let pymbolic handle this case
-            return NotImplemented
-
-        return PsTypedConstant(self._value + self._fix(other)._value, self._dtype)
-
-    def __radd__(self, other: Any):
-        return PsTypedConstant(self._rfix(other)._value + self._value, self._dtype)
-
-    def __mul__(self, other: Any):
-        if isinstance(other, pb.Expression):  # let pymbolic handle this case
-            return NotImplemented
-
-        return PsTypedConstant(self._value * self._fix(other)._value, self._dtype)
-
-    def __rmul__(self, other: Any):
-        return PsTypedConstant(self._rfix(other)._value * self._value, self._dtype)
-
-    def __sub__(self, other: Any):
-        if isinstance(other, pb.Expression):  # let pymbolic handle this case
-            return NotImplemented
-
-        return PsTypedConstant(self._value - self._fix(other)._value, self._dtype)
-
-    def __rsub__(self, other: Any):
-        return PsTypedConstant(self._rfix(other)._value - self._value, self._dtype)
-
-    @staticmethod
-    def _divrem(dividend, divisor):
-        quotient = abs(dividend) // abs(divisor)
-        quotient = quotient if (dividend * divisor > 0) else (-quotient)
-        rem = abs(dividend) % abs(divisor)
-        rem = rem if dividend >= 0 else (-rem)
-        return quotient, rem
-
-    def __truediv__(self, other: Any):
-        if isinstance(other, pb.Expression):  # let pymbolic handle this case
-            return NotImplemented
-
-        if self._dtype.is_float():
-            return PsTypedConstant(self._value / self._fix(other)._value, self._dtype)
-        elif self._dtype.is_uint():
-            #   For unsigned integers, `//` does the correct thing
-            return PsTypedConstant(self._value // self._fix(other)._value, self._dtype)
-        elif self._dtype.is_sint():
-            dividend = self._value
-            divisor = self._fix(other)._value
-            quotient, _ = self._divrem(dividend, divisor)
-            return PsTypedConstant(quotient, self._dtype)
-        else:
-            return NotImplemented
-
-    def __rtruediv__(self, other: Any):
-        if self._dtype.is_float():
-            return PsTypedConstant(self._rfix(other)._value / self._value, self._dtype)
-        elif self._dtype.is_uint():
-            return PsTypedConstant(self._rfix(other)._value // self._value, self._dtype)
-        elif self._dtype.is_sint():
-            dividend = self._fix(other)._value
-            divisor = self._value
-            quotient, _ = self._divrem(dividend, divisor)
-            return PsTypedConstant(quotient, self._dtype)
-        else:
-            return NotImplemented
-
-    def __mod__(self, other: Any):
-        if isinstance(other, pb.Expression):  # let pymbolic handle this case
-            return NotImplemented
-
-        if self._dtype.is_uint():
-            return PsTypedConstant(self._value % self._fix(other)._value, self._dtype)
-        else:
-            dividend = self._value
-            divisor = self._fix(other)._value
-            _, rem = self._divrem(dividend, divisor)
-            return PsTypedConstant(rem, self._dtype)
-
-    def __neg__(self):
-        return PsTypedConstant(-self._value, self._dtype)
-
-    def __bool__(self):
-        return bool(self._value)
-
-    def __eq__(self, other: object) -> bool:
-        if not isinstance(other, PsTypedConstant):
-            return False
-
-        return self._dtype == other._dtype and self._value == other._value
-
-    def __hash__(self) -> int:
-        return hash((self._value, self._dtype))
-
-
-pb.register_constant_class(PsTypedConstant)
-
-ExprOrConstant: TypeAlias = pb.Expression | PsTypedConstant
-"""Required since `PsTypedConstant` does not derive from ``pb.Expression``."""
-
-VarOrConstant: TypeAlias = PsTypedVariable | PsTypedConstant
diff --git a/src/pystencils/config.py b/src/pystencils/config.py
index ad5795a87..1e811f231 100644
--- a/src/pystencils/config.py
+++ b/src/pystencils/config.py
@@ -8,7 +8,7 @@ from .backend.jit import JitBase
 from .backend.exceptions import PsOptionsError
 from .backend.types import PsIntegerType, PsNumericType, PsIeeeFloatType
 
-from .backend.kernelcreation.defaults import Sympy as SpDefaults
+from .defaults import DEFAULTS
 
 
 @dataclass
@@ -61,7 +61,7 @@ class CreateKernelConfig:
 
     """Data Types"""
 
-    index_dtype: PsIntegerType = SpDefaults.index_dtype
+    index_dtype: PsIntegerType = DEFAULTS.index_dtype
     """Data type used for all index calculations."""
 
     default_dtype: PsNumericType = PsIeeeFloatType(64)
diff --git a/src/pystencils/backend/kernelcreation/defaults.py b/src/pystencils/defaults.py
similarity index 55%
rename from src/pystencils/backend/kernelcreation/defaults.py
rename to src/pystencils/defaults.py
index fe0e8ed4a..a031fd58c 100644
--- a/src/pystencils/backend/kernelcreation/defaults.py
+++ b/src/pystencils/defaults.py
@@ -1,31 +1,12 @@
-"""This module defines various default types, symbols and variables for use in pystencils kernels.
-
-On many occasions the SymPy frontend uses canonical symbols and types.
-With the pymbolic-based backend, these symbols have to exist in two
-variants; as `sp.Symbol` or `TypedSymbol`, and as `PsTypedVariable`s.
-Therefore, for conciseness, this module should collect and provide each of these symbols.
-
-We might furthermore consider making the defaults collection configurable.
-
-A possibly incomplete list of symbols and types that need to be defined:
-
- - The default indexing data type (currently loosely defined as `int`)
- - The default spatial iteration counters (currently defined by `LoopOverCoordinate`)
- - The names of the coordinate members of index lists (currently in `CreateKernelConfig.coordinate_names`)
- - The sparse iteration counter (doesn't even exist yet)
- - ...
-"""
-
 from typing import TypeVar, Generic, Callable
-from ..types import PsAbstractType, PsIeeeFloatType, PsSignedIntegerType, PsStructType
-from ..typed_expressions import PsTypedVariable
+from .backend.types import PsAbstractType, PsIeeeFloatType, PsSignedIntegerType, PsStructType
 
 from pystencils.sympyextensions.typed_sympy import TypedSymbol
 
 SymbolT = TypeVar("SymbolT")
 
 
-class PsDefaults(Generic[SymbolT]):
+class GenericDefaults(Generic[SymbolT]):
     def __init__(self, symcreate: Callable[[str, PsAbstractType], SymbolT]):
         self.numeric_dtype = PsIeeeFloatType(64)
         """Default data type for numerical computations"""
@@ -60,5 +41,5 @@ class PsDefaults(Generic[SymbolT]):
         """Default sparse iteration counter."""
 
 
-Sympy = PsDefaults[TypedSymbol](TypedSymbol)
-Pymbolic = PsDefaults[PsTypedVariable](PsTypedVariable)
+DEFAULTS = GenericDefaults[TypedSymbol](TypedSymbol)
+"""Default names and symbols used throughout code generation"""
diff --git a/src/pystencils/kernelcreation.py b/src/pystencils/kernelcreation.py
index 5c1910d2a..293578982 100644
--- a/src/pystencils/kernelcreation.py
+++ b/src/pystencils/kernelcreation.py
@@ -12,7 +12,7 @@ from .backend.kernelcreation.iteration_space import (
     create_full_iteration_space,
 )
 
-from .backend.ast.collectors import collect_required_headers
+from .backend.ast.analysis import collect_required_headers
 from .backend.transformations import EraseAnonymousStructTypes
 
 from .enums import Target
diff --git a/tests/nbackend/kernelcreation/platform/test_basic_cpu.py b/tests/nbackend/kernelcreation/platform/test_basic_cpu.py
index 028ffc122..e69a07c97 100644
--- a/tests/nbackend/kernelcreation/platform/test_basic_cpu.py
+++ b/tests/nbackend/kernelcreation/platform/test_basic_cpu.py
@@ -7,7 +7,9 @@ from pystencils.backend.kernelcreation import (
     FullIterationSpace
 )
 
-from pystencils.backend.ast import PsBlock, PsLoop, PsComment, dfs_preorder
+from pystencils.backend.ast.structural import PsBlock, PsLoop, PsComment
+from pystencils.backend.ast.expressions import PsExpression
+from pystencils.backend.ast import dfs_preorder
 
 from pystencils.backend.platforms import GenericCpu
 
@@ -27,7 +29,7 @@ def test_loop_nest(layout):
     loops = dfs_preorder(loop_nest, lambda n: isinstance(n, PsLoop))
     for loop, dim in zip(loops, ispace.dimensions, strict=True):
         assert isinstance(loop, PsLoop)
-        assert loop.start.expression == dim.start
-        assert loop.stop.expression == dim.stop
-        assert loop.step.expression == dim.step
-        assert loop.counter.expression == dim.counter
+        assert loop.start.structurally_equal(dim.start)
+        assert loop.stop.structurally_equal(dim.stop)
+        assert loop.step.structurally_equal(dim.step)
+        assert loop.counter.structurally_equal(PsExpression.make(dim.counter))
diff --git a/tests/nbackend/kernelcreation/test_freeze.py b/tests/nbackend/kernelcreation/test_freeze.py
index acecc6503..c8f39ffc8 100644
--- a/tests/nbackend/kernelcreation/test_freeze.py
+++ b/tests/nbackend/kernelcreation/test_freeze.py
@@ -1,17 +1,16 @@
 import sympy as sp
-import pymbolic.primitives as pb
 
 from pystencils import Assignment, fields
 
-from pystencils.backend.ast import (
+from pystencils.backend.ast.structural import (
     PsAssignment,
     PsDeclaration,
+)
+from pystencils.backend.ast.expressions import (
     PsExpression,
-    PsSymbolExpr,
-    PsLvalueExpr,
+    PsArrayAccess
 )
-from pystencils.backend.typed_expressions import PsTypedConstant, PsTypedVariable
-from pystencils.backend.arrays import PsArrayAccess
+from pystencils.backend.constants import PsConstant
 from pystencils.backend.kernelcreation import (
     KernelCreationContext,
     FreezeExpressions,
@@ -28,19 +27,25 @@ def test_freeze_simple():
 
     fasm = freeze(asm)
 
-    pb_x, pb_y, pb_z = pb.variables("x y z")
+    x2 = PsExpression.make(ctx.get_symbol("x"))
+    y2 = PsExpression.make(ctx.get_symbol("y"))
+    z2 = PsExpression.make(ctx.get_symbol("z"))
+
+    two = PsExpression.make(PsConstant(2))
 
-    assert fasm == PsDeclaration(PsSymbolExpr(pb_z), PsExpression(pb_y + 2 * pb_x))
-    assert fasm != PsAssignment(PsSymbolExpr(pb_z), PsExpression(pb_y + 2 * pb_x))
+    should = PsDeclaration(z2, y2 + two * x2)
+
+    assert fasm.structurally_equal(should)
+    assert not fasm.structurally_equal(PsAssignment(z2, two * x2 + y2))
 
 
 def test_freeze_fields():
     ctx = KernelCreationContext()
 
-    zero = PsTypedConstant(0, ctx.index_dtype)
-    forty_two = PsTypedConstant(42, ctx.index_dtype)
-    one = PsTypedConstant(1, ctx.index_dtype)
-    counter = PsTypedVariable("ctr", ctx.index_dtype)
+    zero = PsExpression.make(PsConstant(0, ctx.index_dtype))
+    forty_two = PsExpression.make(PsConstant(42, ctx.index_dtype))
+    one = PsExpression.make(PsConstant(1, ctx.index_dtype))
+    counter = ctx.get_symbol("ctr", ctx.index_dtype)
     ispace = FullIterationSpace(
         ctx, [FullIterationSpace.Dimension(zero, forty_two, one, counter)]
     )
@@ -56,9 +61,11 @@ def test_freeze_fields():
 
     fasm = freeze(asm)
 
-    lhs = PsArrayAccess(f_arr.base_pointer, pb.Sum((counter * f_arr.strides[0], zero)))
-    rhs = PsArrayAccess(g_arr.base_pointer, pb.Sum((counter * g_arr.strides[0], zero)))
+    zero = PsExpression.make(PsConstant(0))
+
+    lhs = PsArrayAccess(f_arr.base_pointer, (PsExpression.make(counter) + zero) * PsExpression.make(f_arr.strides[0]) + zero * one)
+    rhs = PsArrayAccess(g_arr.base_pointer, (PsExpression.make(counter) + zero) * PsExpression.make(g_arr.strides[0]) + zero * one)
 
-    should = PsAssignment(PsLvalueExpr(lhs), PsExpression(rhs))
+    should = PsAssignment(lhs, rhs)
 
-    assert fasm == should
+    assert fasm.structurally_equal(should)
diff --git a/tests/nbackend/kernelcreation/test_iteration_space.py b/tests/nbackend/kernelcreation/test_iteration_space.py
index 477684ff3..9413b9def 100644
--- a/tests/nbackend/kernelcreation/test_iteration_space.py
+++ b/tests/nbackend/kernelcreation/test_iteration_space.py
@@ -1,65 +1,72 @@
 import pytest
 
-import pymbolic.primitives as pb
+from pystencils.defaults import DEFAULTS
 from pystencils.field import Field
 from pystencils.sympyextensions.typed_sympy import TypedSymbol, create_type
 
-from pystencils.backend.kernelcreation import (
-    KernelCreationContext,
-    FullIterationSpace
-)
+from pystencils.backend.kernelcreation import KernelCreationContext, FullIterationSpace
 
+from pystencils.backend.ast.expressions import PsAdd, PsConstantExpr, PsExpression
 from pystencils.backend.kernelcreation.typification import TypificationError
-from pystencils.backend.kernelcreation.defaults import Pymbolic as PbDefaults
-
-from pystencils.backend.typed_expressions import PsTypedConstant
 
 
 def test_loop_order():
     ctx = KernelCreationContext()
-    ctr_symbols = PbDefaults.spatial_counters
+    ctr_symbols = [
+        ctx.get_symbol(sname, ctx.index_dtype)
+        for sname in DEFAULTS.spatial_counter_names
+    ]
 
     #   FZYX Order
-    archetype_field = Field.create_generic("fzyx_field", spatial_dimensions=3, layout='fzyx')
+    archetype_field = Field.create_generic(
+        "fzyx_field", spatial_dimensions=3, layout="fzyx"
+    )
     ispace = FullIterationSpace.create_with_ghost_layers(ctx, archetype_field, 0)
 
     for dim, ctr in zip(ispace.dimensions, ctr_symbols[::-1]):
         assert dim.counter == ctr
 
     #   ZYXF Order
-    archetype_field = Field.create_generic("zyxf_field", spatial_dimensions=3, layout='zyxf')
+    archetype_field = Field.create_generic(
+        "zyxf_field", spatial_dimensions=3, layout="zyxf"
+    )
     ispace = FullIterationSpace.create_with_ghost_layers(ctx, archetype_field, 0)
 
     for dim, ctr in zip(ispace.dimensions, ctr_symbols[::-1]):
         assert dim.counter == ctr
 
     #   C Order
-    archetype_field = Field.create_generic("c_field", spatial_dimensions=3, layout='c')
+    archetype_field = Field.create_generic("c_field", spatial_dimensions=3, layout="c")
     ispace = FullIterationSpace.create_with_ghost_layers(ctx, archetype_field, 0)
 
     for dim, ctr in zip(ispace.dimensions, ctr_symbols):
         assert dim.counter == ctr
 
     #   Fortran Order
-    archetype_field = Field.create_generic("fortran_field", spatial_dimensions=3, layout='f')
+    archetype_field = Field.create_generic(
+        "fortran_field", spatial_dimensions=3, layout="f"
+    )
     ispace = FullIterationSpace.create_with_ghost_layers(ctx, archetype_field, 0)
 
     for dim, ctr in zip(ispace.dimensions, ctr_symbols[::-1]):
         assert dim.counter == ctr
 
     #   Scrambled Layout
-    archetype_field = Field.create_generic("scrambled_field", spatial_dimensions=3, layout=(2, 0, 1))
+    archetype_field = Field.create_generic(
+        "scrambled_field", spatial_dimensions=3, layout=(2, 0, 1)
+    )
     ispace = FullIterationSpace.create_with_ghost_layers(ctx, archetype_field, 0)
 
-    for dim, ctr in zip(ispace.dimensions, [ctr_symbols[2], ctr_symbols[0], ctr_symbols[1]]):
+    for dim, ctr in zip(
+        ispace.dimensions, [ctr_symbols[2], ctr_symbols[0], ctr_symbols[1]]
+    ):
         assert dim.counter == ctr
 
 
 def test_slices():
     ctx = KernelCreationContext()
-    ctr_symbols = PbDefaults.spatial_counters
 
-    archetype_field = Field.create_generic("f", spatial_dimensions=3, layout='fzyx')
+    archetype_field = Field.create_generic("f", spatial_dimensions=3, layout="fzyx")
     ctx.add_field(archetype_field)
 
     islice = (slice(1, -1, 1), slice(3, -3, 3), slice(0, None, -1))
@@ -70,19 +77,31 @@ def test_slices():
     dims = ispace.dimensions[::-1]
 
     for sl, size, dim in zip(islice, archetype_arr.shape, dims):
-        assert isinstance(dim.start, PsTypedConstant) and dim.start.value == sl.start
-        assert isinstance(dim.step, PsTypedConstant) and dim.step.value == sl.step
-
-    assert isinstance(dims[0].stop, pb.Sum) and archetype_arr.shape[0] in dims[0].stop.children
-    assert isinstance(dims[1].stop, pb.Sum) and archetype_arr.shape[1] in dims[1].stop.children
-    assert dims[2].stop == archetype_arr.shape[2]
+        assert (
+            isinstance(dim.start, PsConstantExpr)
+            and dim.start.constant.value == sl.start
+        )
+        assert (
+            isinstance(dim.step, PsConstantExpr) and dim.step.constant.value == sl.step
+        )
+
+    assert isinstance(dims[0].stop, PsAdd) and any(
+        op.structurally_equal(PsExpression.make(archetype_arr.shape[0]))
+        for op in dims[0].stop.children
+    )
+    
+    assert isinstance(dims[1].stop, PsAdd) and any(
+        op.structurally_equal(PsExpression.make(archetype_arr.shape[1]))
+        for op in dims[1].stop.children
+    )
+    
+    assert dims[2].stop.structurally_equal(PsExpression.make(archetype_arr.shape[2]))
 
 
 def test_invalid_slices():
     ctx = KernelCreationContext()
-    ctr_symbols = PbDefaults.spatial_counters
 
-    archetype_field = Field.create_generic("f", spatial_dimensions=1, layout='fzyx')
+    archetype_field = Field.create_generic("f", spatial_dimensions=1, layout="fzyx")
     ctx.add_field(archetype_field)
 
     islice = (slice(1, -1, 0.5),)
diff --git a/tests/nbackend/kernelcreation/test_typification.py b/tests/nbackend/kernelcreation/test_typification.py
index 6e24a876f..9b9aba35d 100644
--- a/tests/nbackend/kernelcreation/test_typification.py
+++ b/tests/nbackend/kernelcreation/test_typification.py
@@ -1,14 +1,13 @@
 import pytest
 import sympy as sp
 import numpy as np
-import pymbolic.primitives as pb
 
 from pystencils import Assignment, TypedSymbol, Field, FieldType
 
-from pystencils.backend.ast import PsDeclaration
+from pystencils.backend.ast.structural import PsDeclaration
+from pystencils.backend.ast.expressions import PsConstantExpr, PsSymbolExpr, PsBinOp
 from pystencils.backend.types import constify
-from pystencils.backend.types.quick import *
-from pystencils.backend.typed_expressions import PsTypedConstant, PsTypedVariable
+from pystencils.backend.types.quick import Fp, make_numeric_type
 from pystencils.backend.kernelcreation.context import KernelCreationContext
 from pystencils.backend.kernelcreation.freeze import FreezeExpressions
 from pystencils.backend.kernelcreation.typification import Typifier, TypificationError
@@ -29,19 +28,20 @@ def test_typify_simple():
 
     def check(expr):
         match expr:
-            case PsTypedConstant(value, dtype):
-                assert value == 2
-                assert dtype == constify(ctx.default_dtype)
-            case PsTypedVariable(name, dtype):
-                assert name in "xyz"
-                assert dtype == ctx.default_dtype
-            case pb.Sum(cs) | pb.Product(cs):
-                [check(c) for c in cs]
+            case PsConstantExpr(cs):
+                assert cs.value == 2
+                assert cs.dtype == constify(ctx.default_dtype)
+            case PsSymbolExpr(symb):
+                assert symb.name in "xyz"
+                assert symb.dtype == ctx.default_dtype
+            case PsBinOp(op1, op2):
+                check(op1)
+                check(op2)
             case _:
                 pytest.fail(f"Unexpected expression: {expr}")
 
-    check(fasm.lhs.expression)
-    check(fasm.rhs.expression)
+    check(fasm.lhs)
+    check(fasm.rhs)
 
 
 def test_typify_structs():
@@ -76,18 +76,19 @@ def test_contextual_typing():
 
     def check(expr):
         match expr:
-            case PsTypedConstant(value, dtype):
-                assert value in (2, 3, -4)
-                assert dtype == constify(ctx.default_dtype)
-            case PsTypedVariable(name, dtype):
-                assert name in "xyz"
-                assert dtype == ctx.default_dtype
-            case pb.Sum(cs) | pb.Product(cs):
-                [check(c) for c in cs]
+            case PsConstantExpr(cs):
+                assert cs.value in (2, 3, -4)
+                assert cs.dtype == constify(ctx.default_dtype)
+            case PsSymbolExpr(symb):
+                assert symb.name in "xyz"
+                assert symb.dtype == ctx.default_dtype
+            case PsBinOp(op1, op2):
+                check(op1)
+                check(op2)
             case _:
                 pytest.fail(f"Unexpected expression: {expr}")
 
-    check(expr.expression)
+    check(expr)
 
 
 def test_erronous_typing():
diff --git a/tests/nbackend/test_basic_printing.py b/tests/nbackend/test_basic_printing.py
deleted file mode 100644
index 36a51b44b..000000000
--- a/tests/nbackend/test_basic_printing.py
+++ /dev/null
@@ -1,43 +0,0 @@
-from pystencils import Target
-
-from pystencils.backend.ast import *
-from pystencils.backend.typed_expressions import *
-from pystencils.backend.arrays import PsLinearizedArray, PsArrayBasePointer, PsArrayAccess
-from pystencils.backend.types.quick import *
-from pystencils.backend.emission import CAstPrinter
-
-
-def test_basic_kernel():
-
-    u_arr = PsLinearizedArray("u", Fp(64), (..., ), (1, ))
-    u_size = u_arr.shape[0]
-    u_base = PsArrayBasePointer("u_data", u_arr)
-
-    loop_ctr = PsTypedVariable("ctr", UInt(32))
-    one = PsTypedConstant(1, SInt(32))
-
-    update = PsAssignment(
-        PsLvalueExpr(PsArrayAccess(u_base, loop_ctr)),
-        PsExpression(PsArrayAccess(u_base, loop_ctr + one) + PsArrayAccess(u_base, loop_ctr - one)),
-    )
-
-    loop = PsLoop(
-        PsSymbolExpr(loop_ctr),
-        PsExpression(one),
-        PsExpression(u_size - one),
-        PsExpression(one),
-        PsBlock([update])
-    )
-
-    func = PsKernelFunction(PsBlock([loop]), target=Target.CPU)
-
-    printer = CAstPrinter()
-    code = printer.print(func)
-
-    paramlist = func.get_parameters().params
-    params_str = ", ".join(f"{p.dtype} {p.name}" for p in paramlist)
-
-    assert code.find("(" + params_str + ")") >= 0
-    
-    assert code.find("u_data[ctr] = u_data[ctr + 1] + u_data[ctr + -1];") >= 0
-
diff --git a/tests/nbackend/test_code_printing.py b/tests/nbackend/test_code_printing.py
new file mode 100644
index 000000000..be690db7c
--- /dev/null
+++ b/tests/nbackend/test_code_printing.py
@@ -0,0 +1,77 @@
+from pystencils import Target
+
+from pystencils.backend.ast.expressions import PsExpression, PsArrayAccess
+from pystencils.backend.ast.structural import PsAssignment, PsLoop, PsBlock
+from pystencils.backend.ast.kernelfunction import PsKernelFunction
+from pystencils.backend.symbols import PsSymbol
+from pystencils.backend.constants import PsConstant
+from pystencils.backend.arrays import PsLinearizedArray, PsArrayBasePointer
+from pystencils.backend.types.quick import Fp, SInt, UInt
+from pystencils.backend.emission import CAstPrinter
+
+
+def test_basic_kernel():
+
+    u_arr = PsLinearizedArray("u", Fp(64), (..., ), (1, ))
+    u_size = PsExpression.make(u_arr.shape[0])
+    u_base = PsArrayBasePointer("u_data", u_arr)
+
+    loop_ctr = PsExpression.make(PsSymbol("ctr", UInt(32)))
+    one = PsExpression.make(PsConstant(1, SInt(32)))
+
+    update = PsAssignment(
+        PsArrayAccess(u_base, loop_ctr),
+        PsArrayAccess(u_base, loop_ctr + one) + PsArrayAccess(u_base, loop_ctr - one),
+    )
+
+    loop = PsLoop(
+        loop_ctr,
+        one,
+        u_size - one,
+        one,
+        PsBlock([update])
+    )
+
+    func = PsKernelFunction(PsBlock([loop]), Target.CPU, "kernel", set())
+
+    printer = CAstPrinter()
+    code = printer(func)
+
+    paramlist = func.get_parameters().params
+    params_str = ", ".join(f"{p.dtype} {p.name}" for p in paramlist)
+
+    assert code.find("(" + params_str + ")") >= 0
+    assert code.find("u_data[ctr] = u_data[ctr + 1] + u_data[ctr - 1];") >= 0
+
+
+def test_arithmetic_precedence():
+    (a, b, c, d, e, f) = [PsExpression.make(PsSymbol(x, Fp(64))) for x in "abcdef"]
+    cprint = CAstPrinter()
+
+    expr = (a + b) + (c + d)
+    code = cprint(expr)
+    assert code == "a + b + (c + d)"
+
+    expr = ((a + b) + c) + d
+    code = cprint(expr)
+    assert code == "a + b + c + d"
+
+    expr = a + (b + (c + d))
+    code = cprint(expr)
+    assert code == "a + (b + (c + d))"
+
+    expr = a - (b - c) - d
+    code = cprint(expr)
+    assert code == "a - (b - c) - d"
+
+    expr = a + b * (c + d * (e + f))
+    code = cprint(expr)
+    assert code == "a + b * (c + d * (e + f))"
+
+    expr = (-a) + b + (-c) + -(e + f)
+    code = cprint(expr)
+    assert code == "-a + b + -c + -(e + f)"
+
+    expr = (a / b) + (c / (d + e) * f)
+    code = cprint(expr)
+    assert code == "a / b + c / (d + e) * f"
diff --git a/tests/nbackend/test_constant_folding.py b/tests/nbackend/test_constant_folding.py
index 686a4346f..b71b36ea8 100644
--- a/tests/nbackend/test_constant_folding.py
+++ b/tests/nbackend/test_constant_folding.py
@@ -1,67 +1,26 @@
-import pytest
+#   TODO: Reimplement for constant folder
+# import pytest
 
-import pymbolic.primitives as pb
-from pymbolic.mapper.constant_folder import ConstantFoldingMapper
+# from pystencils.backend.types.quick import *
+# from pystencils.backend.constants import PsConstant
 
-from pystencils.backend.types.quick import *
-from pystencils.backend.typed_expressions import PsTypedConstant
 
+# @pytest.mark.parametrize("width", (8, 16, 32, 64))
+# def test_constant_folding_int(width):
+#     folder = ConstantFoldingMapper()
 
-@pytest.mark.parametrize("width", (8, 16, 32, 64))
-def test_constant_folding_int(width):
-    folder = ConstantFoldingMapper()
+#     expr = pb.Sum(
+#         (
+#             PsTypedConstant(13, UInt(width)),
+#             PsTypedConstant(5, UInt(width)),
+#             PsTypedConstant(3, UInt(width)),
+#         )
+#     )
 
-    expr = pb.Sum(
-        (
-            PsTypedConstant(13, UInt(width)),
-            PsTypedConstant(5, UInt(width)),
-            PsTypedConstant(3, UInt(width)),
-        )
-    )
+#     assert folder(expr) == PsTypedConstant(21, UInt(width))
 
-    assert folder(expr) == PsTypedConstant(21, UInt(width))
+#     expr = pb.Product(
+#         (PsTypedConstant(-1, SInt(width)), PsTypedConstant(41, SInt(width)))
+#     ) - PsTypedConstant(12, SInt(width))
 
-    expr = pb.Product(
-        (PsTypedConstant(-1, SInt(width)), PsTypedConstant(41, SInt(width)))
-    ) - PsTypedConstant(12, SInt(width))
-
-    assert folder(expr) == PsTypedConstant(-53, SInt(width))
-
-@pytest.mark.xfail(reason="Current constant folder does not handle products")
-@pytest.mark.parametrize("width", (8, 16, 32, 64))
-def test_constant_folding_product(width):
-    """
-    The pymbolic constant folder shows inconsistent behaviour when folding products.
-    This test both describes the required behaviour and serves as a reminder to fix it.
-    """
-    folder = ConstantFoldingMapper()
-
-    expr = pb.Product(
-        (
-            PsTypedConstant(2, SInt(width)),
-            PsTypedConstant(-3, SInt(width)),
-            PsTypedConstant(4, SInt(width))
-        )
-    )
-
-    assert folder(expr) == PsTypedConstant(-24, SInt(width))
-
-
-@pytest.mark.xfail(reason="Current constant folder does not handle divisions")
-@pytest.mark.parametrize("width", (32, 64))
-def test_constant_folding_float(width):
-    """The pymbolic constant folder does not fold quotients. This test serves as a reminder
-    to consider that behaviour"""
-    folder = ConstantFoldingMapper()
-
-    expr = pb.Quotient(
-        PsTypedConstant(14.0, Fp(width)),
-        pb.Sum(
-            (
-                PsTypedConstant(2.5, Fp(width)),
-                PsTypedConstant(4.5, Fp(width)),
-            )
-        ),
-    )
-
-    assert folder(expr) == PsTypedConstant(7.0, Fp(width))
+#     assert folder(expr) == PsTypedConstant(-53, SInt(width))
diff --git a/tests/nbackend/test_cpujit.py b/tests/nbackend/test_cpujit.py
index 865b698c4..1163c40ce 100644
--- a/tests/nbackend/test_cpujit.py
+++ b/tests/nbackend/test_cpujit.py
@@ -2,15 +2,22 @@ import pytest
 
 from pystencils import Target
 
-from pystencils.backend.ast import *
-from pystencils.backend.constraints import PsKernelConstraint
-from pystencils.backend.typed_expressions import *
-from pystencils.backend.arrays import PsLinearizedArray, PsArrayBasePointer, PsArrayAccess
-from pystencils.backend.types.quick import *
+# from pystencils.backend.constraints import PsKernelParamsConstraint
+from pystencils.backend.symbols import PsSymbol
+from pystencils.backend.constants import PsConstant
+from pystencils.backend.arrays import PsLinearizedArray, PsArrayBasePointer
+
+from pystencils.backend.ast.expressions import PsArrayAccess, PsExpression
+from pystencils.backend.ast.structural import PsAssignment, PsBlock, PsLoop
+from pystencils.backend.ast.kernelfunction import PsKernelFunction
+
+from pystencils.backend.types.quick import SInt, Fp
 from pystencils.backend.jit import LegacyCpuJit
 
 import numpy as np
 
+
+@pytest.mark.xfail(reason="Fails until constraints are reimplemented")
 def test_pairwise_addition():
     idx_type = SInt(64)
 
@@ -20,33 +27,33 @@ def test_pairwise_addition():
     u_data = PsArrayBasePointer("u_data", u)
     v_data = PsArrayBasePointer("v_data", v)
 
-    loop_ctr = PsTypedVariable("ctr", idx_type)
+    loop_ctr = PsExpression.make(PsSymbol("ctr", idx_type))
     
-    zero = PsTypedConstant(0, idx_type)
-    one = PsTypedConstant(1, idx_type)
-    two = PsTypedConstant(2, idx_type)
+    zero = PsExpression.make(PsConstant(0, idx_type))
+    one = PsExpression.make(PsConstant(1, idx_type))
+    two = PsExpression.make(PsConstant(2, idx_type))
 
     update = PsAssignment(
-        PsLvalueExpr(PsArrayAccess(v_data, loop_ctr)),
-        PsExpression(PsArrayAccess(u_data, two * loop_ctr) + PsArrayAccess(u_data, two * loop_ctr + one))
+        PsArrayAccess(v_data, loop_ctr),
+        PsArrayAccess(u_data, two * loop_ctr) + PsArrayAccess(u_data, two * loop_ctr + one)
     )
 
     loop = PsLoop(
-        PsSymbolExpr(loop_ctr),
-        PsExpression(zero),
-        PsExpression(v.shape[0]),
-        PsExpression(one),
+        loop_ctr,
+        zero,
+        PsExpression.make(v.shape[0]),
+        one,
         PsBlock([update])
     )
 
-    func = PsKernelFunction(PsBlock([loop]), target=Target.CPU)
+    func = PsKernelFunction(PsBlock([loop]), Target.CPU, "kernel", set())
 
-    sizes_constraint = PsKernelConstraint(
-        u.shape[0].eq(2 * v.shape[0]),
-        "Array `u` must have twice the length of array `v`"
-    )
+    # sizes_constraint = PsKernelParamsConstraint(
+    #     u.shape[0].eq(2 * v.shape[0]),
+    #     "Array `u` must have twice the length of array `v`"
+    # )
 
-    func.add_constraints(sizes_constraint)
+    # func.add_constraints(sizes_constraint)
 
     jit = LegacyCpuJit()
     kernel = jit.compile(func)
diff --git a/tests/nbackend/test_expressions.py b/tests/nbackend/test_expressions.py
deleted file mode 100644
index 65a0a4e1e..000000000
--- a/tests/nbackend/test_expressions.py
+++ /dev/null
@@ -1,51 +0,0 @@
-from pystencils.backend.typed_expressions import PsTypedVariable
-from pystencils.backend.arrays import PsLinearizedArray, PsArrayBasePointer, PsArrayShapeVar, PsArrayStrideVar
-from pystencils.backend.types.quick import *
-
-def test_variable_equality():
-    var1 = PsTypedVariable("x", Fp(32))
-    var2 = PsTypedVariable("x", Fp(32))
-    assert var1 == var2
-
-    shape = (..., ..., ...)
-    strides = (..., ..., ...)
-
-    arr = PsLinearizedArray("arr", Fp(64), shape, strides)
-    bp1 = PsArrayBasePointer("arr_data", arr)
-    bp2 = PsArrayBasePointer("arr_data", arr)
-    assert bp1 == bp2
-
-    arr1 = PsLinearizedArray("arr", Fp(64), shape, strides)
-    bp1 = PsArrayBasePointer("arr_data", arr1)
-
-    arr2 = PsLinearizedArray("arr", Fp(64), shape, strides)
-    bp2 = PsArrayBasePointer("arr_data", arr2)
-    assert bp1 == bp2
-
-    for v1, v2 in zip(arr1.shape, arr2.shape):
-        assert v1 == v2
-
-    for v1, v2 in zip(arr1.strides, arr2.strides):
-        assert v1 == v2
-
-
-def test_variable_inequality():
-    shape = (..., ..., ...)
-    strides = (..., ..., ...)
-
-    var1 = PsTypedVariable("x", Fp(32))
-    var2 = PsTypedVariable("x", Fp(64))
-    assert var1 != var2
-
-    var1 = PsTypedVariable("x", Fp(32, True))
-    var2 = PsTypedVariable("x", Fp(32, False))
-    assert var1 != var2
-
-    #   Arrays 
-    arr1 = PsLinearizedArray("arr", Fp(64), shape, strides)
-    bp1 = PsArrayBasePointer("arr_data", arr1)
-
-    arr2 = PsLinearizedArray("arr", Fp(32), shape, strides)
-    bp2 = PsArrayBasePointer("arr_data", arr2)
-    assert bp1 != bp2
-
diff --git a/tests/nbackend/types/test_constants.py b/tests/nbackend/types/test_constants.py
index a5399c66e..4353973f5 100644
--- a/tests/nbackend/types/test_constants.py
+++ b/tests/nbackend/types/test_constants.py
@@ -1,79 +1,80 @@
-import pytest
+# import pytest
 
-from pystencils.backend.types.quick import *
-from pystencils.backend.types import PsTypeError
-from pystencils.backend.typed_expressions import PsTypedConstant
+# TODO: Re-implement for constant folder
+# from pystencils.backend.types.quick import *
+# from pystencils.backend.types import PsTypeError
+# from pystencils.backend.typed_expressions import PsTypedConstant
 
 
-@pytest.mark.parametrize("width", (8, 16, 32, 64))
-def test_integer_constants(width):
-    dtype = SInt(width)
-    a = PsTypedConstant(42, dtype)
-    b = PsTypedConstant(2, dtype)
+# @pytest.mark.parametrize("width", (8, 16, 32, 64))
+# def test_integer_constants(width):
+#     dtype = SInt(width)
+#     a = PsTypedConstant(42, dtype)
+#     b = PsTypedConstant(2, dtype)
 
-    assert a + b == PsTypedConstant(44, dtype)
-    assert a - b == PsTypedConstant(40, dtype)
-    assert a * b == PsTypedConstant(84, dtype)
+#     assert a + b == PsTypedConstant(44, dtype)
+#     assert a - b == PsTypedConstant(40, dtype)
+#     assert a * b == PsTypedConstant(84, dtype)
 
-    assert a - b != PsTypedConstant(-12, dtype)
+#     assert a - b != PsTypedConstant(-12, dtype)
 
-    #   Typed constants only compare to themselves
-    assert a + b != 44
+#     #   Typed constants only compare to themselves
+#     assert a + b != 44
 
 
-@pytest.mark.parametrize("width", (32, 64))
-def test_float_constants(width):
-    a = PsTypedConstant(32.0, Fp(width))
-    b = PsTypedConstant(0.5, Fp(width))
-    c = PsTypedConstant(2.0, Fp(width))
+# @pytest.mark.parametrize("width", (32, 64))
+# def test_float_constants(width):
+#     a = PsTypedConstant(32.0, Fp(width))
+#     b = PsTypedConstant(0.5, Fp(width))
+#     c = PsTypedConstant(2.0, Fp(width))
 
-    assert a + b == PsTypedConstant(32.5, Fp(width))
-    assert a * b == PsTypedConstant(16.0, Fp(width))
-    assert a - b == PsTypedConstant(31.5, Fp(width))
-    assert a / c == PsTypedConstant(16.0, Fp(width))
+#     assert a + b == PsTypedConstant(32.5, Fp(width))
+#     assert a * b == PsTypedConstant(16.0, Fp(width))
+#     assert a - b == PsTypedConstant(31.5, Fp(width))
+#     assert a / c == PsTypedConstant(16.0, Fp(width))
 
 
-def test_illegal_ops():
-    #   Cannot interpret negative numbers as unsigned types
-    with pytest.raises(PsTypeError):
-        _ = PsTypedConstant(-3, UInt(32))
+# def test_illegal_ops():
+#     #   Cannot interpret negative numbers as unsigned types
+#     with pytest.raises(PsTypeError):
+#         _ = PsTypedConstant(-3, UInt(32))
 
-    #   Mixed ops are illegal
-    with pytest.raises(PsTypeError):
-        _ = PsTypedConstant(32.0, Fp(32)) + PsTypedConstant(2, UInt(32))
+#     #   Mixed ops are illegal
+#     with pytest.raises(PsTypeError):
+#         _ = PsTypedConstant(32.0, Fp(32)) + PsTypedConstant(2, UInt(32))
 
-    with pytest.raises(PsTypeError):
-        _ = PsTypedConstant(32.0, Fp(32)) - PsTypedConstant(2, UInt(32))
+#     with pytest.raises(PsTypeError):
+#         _ = PsTypedConstant(32.0, Fp(32)) - PsTypedConstant(2, UInt(32))
 
-    with pytest.raises(PsTypeError):
-        _ = PsTypedConstant(32.0, Fp(32)) * PsTypedConstant(2, UInt(32))
+#     with pytest.raises(PsTypeError):
+#         _ = PsTypedConstant(32.0, Fp(32)) * PsTypedConstant(2, UInt(32))
 
-    with pytest.raises(PsTypeError):
-        _ = PsTypedConstant(32.0, Fp(32)) / PsTypedConstant(2, UInt(32))
+#     with pytest.raises(PsTypeError):
+#         _ = PsTypedConstant(32.0, Fp(32)) / PsTypedConstant(2, UInt(32))
 
 
-@pytest.mark.parametrize("width", (8, 16, 32, 64))
-def test_unsigned_integer_division(width):
-    a = PsTypedConstant(8, UInt(width))
-    b = PsTypedConstant(3, UInt(width))
+# @pytest.mark.parametrize("width", (8, 16, 32, 64))
+# def test_unsigned_integer_division(width):
+#     a = PsTypedConstant(8, UInt(width))
+#     b = PsTypedConstant(3, UInt(width))
 
-    assert a / b == PsTypedConstant(2, UInt(width))
-    assert a % b == PsTypedConstant(2, UInt(width))
+#     assert a / b == PsTypedConstant(2, UInt(width))
+#     assert a % b == PsTypedConstant(2, UInt(width))
 
 
-@pytest.mark.parametrize("width", (8, 16, 32, 64))
-def test_signed_integer_division(width):
-    five = PsTypedConstant(5, SInt(width))
-    two = PsTypedConstant(2, SInt(width))
+# @pytest.mark.parametrize("width", (8, 16, 32, 64))
+# def test_signed_integer_division(width):
+#     five = PsTypedConstant(5, SInt(width))
+#     two = PsTypedConstant(2, SInt(width))
 
-    assert five / two == PsTypedConstant(2, SInt(width))
-    assert five % two == PsTypedConstant(1, SInt(width))
+#     assert five / two == PsTypedConstant(2, SInt(width))
+#     assert five % two == PsTypedConstant(1, SInt(width))
 
-    assert (- five) / two == PsTypedConstant(-2, SInt(width))
-    assert (- five) % two == PsTypedConstant(-1, SInt(width))
+#     assert (- five) / two == PsTypedConstant(-2, SInt(width))
+#     assert (- five) % two == PsTypedConstant(-1, SInt(width))
 
-    assert five / (- two) == PsTypedConstant(-2, SInt(width))
-    assert five % (- two) == PsTypedConstant(1, SInt(width))
+#     assert five / (- two) == PsTypedConstant(-2, SInt(width))
+#     assert five % (- two) == PsTypedConstant(1, SInt(width))
 
-    assert (- five) / (- two) == PsTypedConstant(2, SInt(width))
-    assert (- five) % (- two) == PsTypedConstant(-1, SInt(width))
+#     assert (- five) / (- two) == PsTypedConstant(2, SInt(width))
+#     assert (- five) % (- two) == PsTypedConstant(-1, SInt(width))
-- 
GitLab