From d61a7c0cb93e8d10ab54f25d349e08126285cd86 Mon Sep 17 00:00:00 2001
From: Frederik Hennig <frederik.hennig@fau.de>
Date: Sun, 28 Jan 2024 17:30:57 +0100
Subject: [PATCH] codegen for sparse kernels

---
 src/pystencils/nbackend/ast/collectors.py     |   8 +-
 src/pystencils/nbackend/exceptions.py         |   4 +
 .../nbackend/kernelcreation/analysis.py       |   7 +-
 .../nbackend/kernelcreation/context.py        |  28 +++--
 .../nbackend/kernelcreation/defaults.py       |  18 ++-
 .../nbackend/kernelcreation/freeze.py         |   6 +-
 .../kernelcreation/iteration_space.py         | 115 +++++++++++++++---
 .../nbackend/kernelcreation/kernelcreation.py |  10 +-
 .../nbackend/kernelcreation/options.py        |  17 ++-
 .../kernelcreation/platform/__init__.py       |   4 +-
 .../kernelcreation/platform/basic_cpu.py      |  48 ++++++--
 .../kernelcreation/platform/platform.py       |   4 +-
 .../nbackend/kernelcreation/typification.py   |  23 ++--
 src/pystencils/nbackend/types/__init__.py     |   2 +
 src/pystencils/nbackend/types/basic_types.py  | 115 +++++++++++++++++-
 src/pystencils/nbackend/types/parsing.py      |  17 +++
 src/pystencils/nbackend/types/quick.py        |   8 +-
 17 files changed, 360 insertions(+), 74 deletions(-)

diff --git a/src/pystencils/nbackend/ast/collectors.py b/src/pystencils/nbackend/ast/collectors.py
index 3c995ccf6..c94d0b18b 100644
--- a/src/pystencils/nbackend/ast/collectors.py
+++ b/src/pystencils/nbackend/ast/collectors.py
@@ -30,6 +30,8 @@ class UndefinedVariablesCollector:
     def __call__(self, node: PsAstNode) -> set[PsTypedVariable]:
         """Returns all `PsTypedVariable`s that occur in the given AST without being defined prior to their usage."""
 
+        undefined_vars: set[PsTypedVariable] = set()
+
         match node:
             case PsKernelFunction(block):
                 return self(block)
@@ -46,10 +48,12 @@ class UndefinedVariablesCollector:
                 return cast(set[PsTypedVariable], variables)
 
             case PsAssignment(lhs, rhs):
-                return self(lhs) | self(rhs)
+                undefined_vars = self(lhs) | self(rhs)
+                if isinstance(lhs.expression, PsTypedVariable):
+                    undefined_vars.remove(lhs.expression)
+                return undefined_vars
 
             case PsBlock(statements):
-                undefined_vars: set[PsTypedVariable] = set()
                 for stmt in statements[::-1]:
                     undefined_vars -= self.declared_variables(stmt)
                     undefined_vars |= self(stmt)
diff --git a/src/pystencils/nbackend/exceptions.py b/src/pystencils/nbackend/exceptions.py
index 6fe83b3f9..29434d8ef 100644
--- a/src/pystencils/nbackend/exceptions.py
+++ b/src/pystencils/nbackend/exceptions.py
@@ -10,5 +10,9 @@ class PsInputError(Exception):
     pass
 
 
+class KernelConstraintsError(Exception):
+    pass
+
+
 class PsMalformedAstException(Exception):
     pass
diff --git a/src/pystencils/nbackend/kernelcreation/analysis.py b/src/pystencils/nbackend/kernelcreation/analysis.py
index 82bd769b3..98f4887f7 100644
--- a/src/pystencils/nbackend/kernelcreation/analysis.py
+++ b/src/pystencils/nbackend/kernelcreation/analysis.py
@@ -13,11 +13,7 @@ from ...assignment import Assignment
 from ...simp import AssignmentCollection
 from ...transformations import NestedScopes
 
-from ..exceptions import PsInternalCompilerError
-
-
-class KernelConstraintsError(Exception):
-    pass
+from ..exceptions import PsInternalCompilerError, KernelConstraintsError
 
 
 class KernelAnalysis:
@@ -39,6 +35,7 @@ class KernelAnalysis:
        the same location.
      - **Independence of Writes:** A weaker requirement than access independence; each field may only be written once
        at each index.
+    - **Dimension of index fields:** Index fields occuring in the kernel must have exactly one spatial dimension.
 
     Knowledge Collection
     --------------------
diff --git a/src/pystencils/nbackend/kernelcreation/context.py b/src/pystencils/nbackend/kernelcreation/context.py
index 7e4fad9ba..a77189331 100644
--- a/src/pystencils/nbackend/kernelcreation/context.py
+++ b/src/pystencils/nbackend/kernelcreation/context.py
@@ -1,15 +1,13 @@
 from __future__ import annotations
-from typing import cast
-
 
 from ...field import Field, FieldType
-from ...typing import TypedSymbol, BasicType
+from ...typing import TypedSymbol, BasicType, StructType
 
 from ..arrays import PsLinearizedArray
 from ..types import PsIntegerType
 from ..types.quick import make_type
 from ..constraints import PsKernelConstraint
-from ..exceptions import PsInternalCompilerError
+from ..exceptions import PsInternalCompilerError, KernelConstraintsError
 
 from .options import KernelCreationOptions
 from .iteration_space import IterationSpace, FullIterationSpace, SparseIterationSpace
@@ -80,12 +78,26 @@ class KernelCreationContext:
         match field.field_type:
             case FieldType.GENERIC | FieldType.STAGGERED | FieldType.STAGGERED_FLUX:
                 self._fields_collection.domain_fields.add(field)
+
             case FieldType.BUFFER:
+                if field.spatial_dimensions != 1:
+                    raise KernelConstraintsError(
+                        f"Invalid spatial shape of buffer field {field.name}: {field.spatial_dimensions}. "
+                        "Buffer fields must be one-dimensional."
+                    )
                 self._fields_collection.buffer_fields.add(field)
+
             case FieldType.INDEXED:
+                if field.spatial_dimensions != 1:
+                    raise KernelConstraintsError(
+                        f"Invalid spatial shape of index field {field.name}: {field.spatial_dimensions}. "
+                        "Index fields must be one-dimensional."
+                    )
                 self._fields_collection.index_fields.add(field)
+
             case FieldType.CUSTOM:
                 self._fields_collection.custom_fields.add(field)
+
             case _:
                 assert False, "unreachable code"
 
@@ -105,8 +117,8 @@ class KernelCreationContext:
                 for s in field.strides
             )
 
-            # TODO: frontend should use new type system
-            element_type = make_type(cast(BasicType, field.dtype).numpy_dtype.type)
+            assert isinstance(field.dtype, (BasicType, StructType))
+            element_type = make_type(field.dtype.numpy_dtype)
 
             arr = PsLinearizedArray(
                 field.name, element_type, arr_shape, arr_strides, self.index_dtype
@@ -114,13 +126,11 @@ class KernelCreationContext:
 
             self._arrays[field] = arr
 
-        return self._arrays[field]    
+        return self._arrays[field]
 
     #   Iteration Space
 
     def set_iteration_space(self, ispace: IterationSpace):
-        if self._ispace is not None:
-            raise PsInternalCompilerError("Iteration space was already set.")
         self._ispace = ispace
 
     def get_iteration_space(self) -> IterationSpace:
diff --git a/src/pystencils/nbackend/kernelcreation/defaults.py b/src/pystencils/nbackend/kernelcreation/defaults.py
index c928ef4e4..fc0a602a1 100644
--- a/src/pystencils/nbackend/kernelcreation/defaults.py
+++ b/src/pystencils/nbackend/kernelcreation/defaults.py
@@ -17,7 +17,7 @@ A possibly incomplete list of symbols and types that need to be defined:
 """
 
 from typing import TypeVar, Generic, Callable
-from ..types import PsAbstractType, PsSignedIntegerType
+from ..types import PsAbstractType, PsSignedIntegerType, PsStructType
 from ..typed_expressions import PsTypedVariable
 
 from ...typing import TypedSymbol
@@ -40,14 +40,20 @@ class PsDefaults(Generic[SymbolT]):
         )
         """Default spatial counters"""
 
+        self._index_struct_coordinate_names = ("x", "y", "z")
+        """Default names of spatial coordinate members in index list structures"""
+
         self.index_struct_coordinates = (
-            symcreate("x", self.index_dtype),
-            symcreate("y", self.index_dtype),
-            symcreate("z", self.index_dtype),
+            PsStructType.Member("x", self.index_dtype),
+            PsStructType.Member("y", self.index_dtype),
+            PsStructType.Member("z", self.index_dtype),
         )
-        """Default symbols for spatial coordinates in index list structures"""
+        """Default spatial coordinate members in index list structures"""
+
+        self.sparse_counter_name = "sparse_idx"
+        """Name of the default sparse iteration counter"""
 
-        self.sparse_iteration_counter = symcreate("list_idx", self.index_dtype)
+        self.sparse_counter = symcreate(self.sparse_counter_name, self.index_dtype)
         """Default sparse iteration counter."""
 
 
diff --git a/src/pystencils/nbackend/kernelcreation/freeze.py b/src/pystencils/nbackend/kernelcreation/freeze.py
index 69e1dff7d..64d3fa2e9 100644
--- a/src/pystencils/nbackend/kernelcreation/freeze.py
+++ b/src/pystencils/nbackend/kernelcreation/freeze.py
@@ -100,9 +100,9 @@ class FreezeExpressions(SympyToPymbolicMapper):
                 case FieldType.INDEXED:
                     # flake8: noqa
                     sparse_ispace = self._ctx.get_sparse_iteration_space()
-                    #   TODO: make sure index (and all offsets?) are zero
-                    #   TODO: Add sparse iteration counter
-                    raise NotImplementedError()
+                    #   Add sparse iteration counter to offset
+                    assert len(offsets) == 1 # must have been checked by the context
+                    offsets = [offsets[0] + sparse_ispace.sparse_counter]
                 case FieldType.CUSTOM:
                     raise ValueError("Custom fields support only absolute accesses.")
                 case unknown:
diff --git a/src/pystencils/nbackend/kernelcreation/iteration_space.py b/src/pystencils/nbackend/kernelcreation/iteration_space.py
index d38cb3534..a752a16d6 100644
--- a/src/pystencils/nbackend/kernelcreation/iteration_space.py
+++ b/src/pystencils/nbackend/kernelcreation/iteration_space.py
@@ -15,8 +15,10 @@ from ..typed_expressions import (
     PsTypedConstant,
 )
 from ..arrays import PsLinearizedArray
+from ..ast.util import failing_cast
+from ..types import PsStructType, constify
 from .defaults import Pymbolic as Defaults
-from ..exceptions import PsInputError, PsInternalCompilerError
+from ..exceptions import PsInputError, PsInternalCompilerError, KernelConstraintsError
 
 if TYPE_CHECKING:
     from .context import KernelCreationContext
@@ -36,20 +38,24 @@ class IterationSpace(ABC):
        spatial indices.
     """
 
-    def __init__(self, spatial_indices: tuple[PsTypedVariable, ...]):
+    def __init__(self, spatial_indices: Sequence[PsTypedVariable]):
         if len(spatial_indices) == 0:
             raise ValueError("Iteration space must be at least one-dimensional.")
 
-        self._spatial_indices = spatial_indices
+        self._spatial_indices = tuple(spatial_indices)
 
     @property
     def spatial_indices(self) -> tuple[PsTypedVariable, ...]:
         return self._spatial_indices
 
+    @property
+    def dim(self) -> int:
+        return len(self._spatial_indices)
+
 
 class FullIterationSpace(IterationSpace):
     """N-dimensional full iteration space.
-    
+
     Each dimension of the full iteration space is represented by an instance of `FullIterationSpace.Dimension`.
     Dimensions are ordered slowest-to-fastest: The first dimension corresponds to the slowest coordinate,
     translates to the outermost loop, while the last dimension is the fastest coordinate and translates
@@ -144,24 +150,108 @@ class FullIterationSpace(IterationSpace):
 
 
 class SparseIterationSpace(IterationSpace):
-    #   TODO: To properly implement sparse iteration, we still need struct data types
     def __init__(
         self,
-        spatial_indices: tuple[PsTypedVariable, ...],
+        spatial_indices: Sequence[PsTypedVariable],
         index_list: PsLinearizedArray,
+        coordinate_members: Sequence[PsStructType.Member],
+        sparse_counter: PsTypedVariable,
     ):
         super().__init__(spatial_indices)
         self._index_list = index_list
+        self._coord_members = tuple(coordinate_members)
+        self._sparse_counter = sparse_counter
 
     @property
     def index_list(self) -> PsLinearizedArray:
         return self._index_list
+    
+    @property
+    def coordinate_members(self) -> tuple[PsStructType.Member, ...]:
+        return self._coord_members
+
+    @property
+    def sparse_counter(self) -> PsTypedVariable:
+        return self._sparse_counter
+
+
+def get_archetype_field(
+    fields: set[Field],
+    check_compatible_shapes: bool = True,
+    check_same_layouts: bool = True,
+    check_same_dimensions: bool = True,
+):
+    shapes = set(f.spatial_shape for f in fields)
+    fixed_shapes = set(f.spatial_shape for f in fields if f.has_fixed_shape)
+    layouts = set(f.layout for f in fields)
+    dimensionalities = set(f.spatial_dimensions for f in fields)
+
+    if check_same_dimensions and len(dimensionalities) != 1:
+        raise KernelConstraintsError(
+            "All fields must have the same number of spatial dimensions."
+        )
+
+    if check_same_layouts and len(layouts) != 1:
+        raise KernelConstraintsError("All fields must have the same memory layout.")
+
+    if check_compatible_shapes:
+        if len(fixed_shapes) > 0:
+            if len(fixed_shapes) != len(shapes):
+                raise KernelConstraintsError(
+                    "Cannot mix fixed- and variable-shape fields."
+                )
+            if len(fixed_shapes) != 0:
+                raise KernelConstraintsError(
+                    "Fixed-shape fields of different sizes encountered."
+                )
+
+    archetype_field = sorted(fields, key=lambda f: str(f))[0]
+    return archetype_field
 
 
 def create_sparse_iteration_space(
     ctx: KernelCreationContext, assignments: AssignmentCollection
 ) -> IterationSpace:
-    return NotImplemented
+    #   All domain and custom fields must have the same spatial dimensions
+    #   TODO: Must all domain fields have the same shape?
+    archetype_field = get_archetype_field(
+        ctx.fields.domain_fields | ctx.fields.custom_fields,
+        check_compatible_shapes=False,
+        check_same_layouts=False,
+        check_same_dimensions=True,
+    )
+
+    dim = archetype_field.spatial_dimensions
+    coord_members = [
+        PsStructType.Member(name, ctx.index_dtype)
+        for name in Defaults._index_struct_coordinate_names[:dim]
+    ]
+
+    #   Determine index field
+    if ctx.options.index_field is not None:
+        idx_field = ctx.options.index_field
+        idx_arr = ctx.get_array(idx_field)
+        idx_struct_type: PsStructType = failing_cast(PsStructType, idx_arr.element_type)
+
+        for coord in coord_members:
+            if coord not in idx_struct_type.members:
+                raise PsInputError(
+                    f"Given index field does not provide required coordinate member {coord}"
+                )
+    else:
+        #   TODO: Find index field from the fields list
+        raise NotImplementedError(
+            "Automatic inference of index field for sparse iteration not supported yet."
+        )
+
+    spatial_counters = [
+        PsTypedVariable(name, constify(ctx.index_dtype))
+        for name in Defaults.spatial_counter_names[:dim]
+    ]
+
+    sparse_counter = PsTypedVariable(Defaults.sparse_counter_name, ctx.index_dtype)
+
+    return SparseIterationSpace(spatial_counters, idx_arr, coord_members, sparse_counter)
 
 
 def create_full_iteration_space(
@@ -185,15 +275,12 @@ def create_full_iteration_space(
     # - We have no domain fields, but at least one custom field -> determine common field from custom fields
     # - We have neither domain nor custom fields -> Error
 
-    #   TODO: Re-implement as `get_archetype_field`, check not only shape but also layout equality
-    #   The archetype field must encompass all information about the iteration space: shape, extents, and loop order.
-    from ...transformations import get_common_field
-
     if len(domain_field_accesses) > 0:
-        archetype_field = get_common_field(ctx.fields.domain_fields)
+        archetype_field = get_archetype_field(ctx.fields.domain_fields)
         inferred_gls = max([fa.required_ghost_layers for fa in domain_field_accesses])
     elif len(ctx.fields.custom_fields) > 0:
-        archetype_field = get_common_field(ctx.fields.custom_fields)
+        #   TODO: Warn about inferring iteration space from custom fields
+        archetype_field = get_archetype_field(ctx.fields.custom_fields)
         inferred_gls = 0
     else:
         raise PsInputError(
@@ -204,8 +291,6 @@ def create_full_iteration_space(
     # Otherwise, if an iteration slice was specified, use that
     # Otherwise, use the inferred ghost layers
 
-    from .iteration_space import FullIterationSpace
-
     if ctx.options.ghost_layers is not None:
         return FullIterationSpace.create_with_ghost_layers(
             ctx, archetype_field, ctx.options.ghost_layers
diff --git a/src/pystencils/nbackend/kernelcreation/kernelcreation.py b/src/pystencils/nbackend/kernelcreation/kernelcreation.py
index c87ef2fe4..07614da9a 100644
--- a/src/pystencils/nbackend/kernelcreation/kernelcreation.py
+++ b/src/pystencils/nbackend/kernelcreation/kernelcreation.py
@@ -23,7 +23,7 @@ def create_kernel(assignments: AssignmentCollection, options: KernelCreationOpti
 
     ispace: IterationSpace = (
         create_sparse_iteration_space(ctx, assignments)
-        if len(ctx.fields.index_fields) > 0
+        if len(ctx.fields.index_fields) > 0 or ctx.options.index_field is not None
         else create_full_iteration_space(ctx, assignments)
     )
 
@@ -37,22 +37,22 @@ def create_kernel(assignments: AssignmentCollection, options: KernelCreationOpti
 
     match options.target:
         case Target.CPU:
-            from .platform import BasicCpu
+            from .platform import BasicCpuGen
 
             #   TODO: CPU platform should incorporate instruction set info, OpenMP, etc.
-            platform = BasicCpu(ctx)
+            platform_generator = BasicCpuGen(ctx)
         case _:
             #   TODO: CUDA/HIP platform
             #   TODO: SYCL platform (?)
             raise NotImplementedError("Target platform not implemented")
 
-    kernel_ast = platform.apply_iteration_space(kernel_body, ispace)
+    kernel_ast = platform_generator.materialize_iteration_space(kernel_body, ispace)
 
     #   7. Apply optimizations
     #     - Vectorization
     #     - OpenMP
     #     - Loop Splitting, Tiling, Blocking
-    kernel_ast = platform.optimize(kernel_ast)
+    kernel_ast = platform_generator.optimize(kernel_ast)
 
     function = PsKernelFunction(kernel_ast, options.target, name=options.function_name)
     function.add_constraints(*ctx.constraints)
diff --git a/src/pystencils/nbackend/kernelcreation/options.py b/src/pystencils/nbackend/kernelcreation/options.py
index 355050b7e..5f5028a94 100644
--- a/src/pystencils/nbackend/kernelcreation/options.py
+++ b/src/pystencils/nbackend/kernelcreation/options.py
@@ -2,6 +2,8 @@ from typing import Sequence
 from dataclasses import dataclass
 
 from ...enums import Target
+from ...field import Field
+
 from ..exceptions import PsOptionsError
 from ..types import PsIntegerType, PsNumericType, PsIeeeFloatType
 
@@ -43,6 +45,12 @@ class KernelCreationOptions:
     TODO: Specification of valid slices and their behaviour
     """
 
+    index_field: Field | None = None
+    """Index field for a sparse kernel.
+    
+    If this option is set, a sparse kernel with the given field as index field will be generated.
+    """
+
     """Data Types"""
 
     index_dtype: PsIntegerType = SpDefaults.index_dtype
@@ -55,8 +63,13 @@ class KernelCreationOptions:
     """
 
     def __post_init__(self):
-        if self.iteration_slice is not None and self.ghost_layers is not None:
+        if (
+            int(self.iteration_slice is not None)
+            + int(self.ghost_layers is not None)
+            + int(self.index_field is not None)
+            > 1
+        ):
             raise PsOptionsError(
-                "Parameters `iteration_slice` and `ghost_layers` are mutually exclusive; "
+                "Parameters `iteration_slice`, `ghost_layers` and 'index_field` are mutually exclusive; "
                 "at most one of them may be set."
             )
diff --git a/src/pystencils/nbackend/kernelcreation/platform/__init__.py b/src/pystencils/nbackend/kernelcreation/platform/__init__.py
index 20e2c0aae..85b5af9c0 100644
--- a/src/pystencils/nbackend/kernelcreation/platform/__init__.py
+++ b/src/pystencils/nbackend/kernelcreation/platform/__init__.py
@@ -1,5 +1,5 @@
-from .basic_cpu import BasicCpu
+from .basic_cpu import BasicCpuGen
 
 __all__ = [
-    'BasicCpu'
+    'BasicCpuGen'
 ]
diff --git a/src/pystencils/nbackend/kernelcreation/platform/basic_cpu.py b/src/pystencils/nbackend/kernelcreation/platform/basic_cpu.py
index bba06a8f9..347061f19 100644
--- a/src/pystencils/nbackend/kernelcreation/platform/basic_cpu.py
+++ b/src/pystencils/nbackend/kernelcreation/platform/basic_cpu.py
@@ -1,17 +1,26 @@
-from pystencils.nbackend.ast import PsBlock, PsLoop, PsSymbolExpr, PsExpression
-from pystencils.nbackend.kernelcreation.iteration_space import (
+from .platform import PlatformGen
+
+from ..iteration_space import (
     IterationSpace,
     FullIterationSpace,
+    SparseIterationSpace,
 )
-from .platform import Platform
+
+from ...ast import PsDeclaration, PsSymbolExpr, PsExpression, PsLoop, PsBlock
+from ...typed_expressions import PsTypedConstant
+from ...arrays import PsArrayAccess
 
 
-class BasicCpu(Platform):
-    def apply_iteration_space(self, block: PsBlock, ispace: IterationSpace) -> PsBlock:
+class BasicCpuGen(PlatformGen):
+    def materialize_iteration_space(
+        self, body: PsBlock, ispace: IterationSpace
+    ) -> PsBlock:
         if isinstance(ispace, FullIterationSpace):
-            return self._create_domain_loops(block, ispace)
+            return self._create_domain_loops(body, ispace)
+        elif isinstance(ispace, SparseIterationSpace):
+            return self._create_sparse_loop(body, ispace)
         else:
-            raise NotImplementedError("Iteration space not supported yet.")
+            assert False, "unreachable code"
 
     def optimize(self, kernel: PsBlock) -> PsBlock:
         return kernel
@@ -35,3 +44,28 @@ class BasicCpu(Platform):
             outer_block = PsBlock([loop])
 
         return outer_block
+
+    def _create_sparse_loop(self, body: PsBlock, ispace: SparseIterationSpace):
+        mappings = [
+            PsDeclaration(
+                PsSymbolExpr(ctr),
+                PsExpression(
+                    PsArrayAccess(
+                        ispace.index_list.base_pointer, ispace.sparse_counter
+                    ).a.__getattr__(coord.name)
+                ),
+            )
+            for ctr, coord in zip(ispace.spatial_indices, ispace.coordinate_members)
+        ]
+
+        body = PsBlock(mappings + body.statements)
+
+        loop = PsLoop(
+            PsSymbolExpr(ispace.sparse_counter),
+            PsExpression(PsTypedConstant(0, self._ctx.index_dtype)),
+            PsExpression(ispace.index_list.shape[0]),
+            PsExpression(PsTypedConstant(1, self._ctx.index_dtype)),
+            body,
+        )
+
+        return PsBlock([loop])
diff --git a/src/pystencils/nbackend/kernelcreation/platform/platform.py b/src/pystencils/nbackend/kernelcreation/platform/platform.py
index d7aeec39e..17dfa23f6 100644
--- a/src/pystencils/nbackend/kernelcreation/platform/platform.py
+++ b/src/pystencils/nbackend/kernelcreation/platform/platform.py
@@ -6,7 +6,7 @@ from ..context import KernelCreationContext
 from ..iteration_space import IterationSpace
 
 
-class Platform(ABC):
+class PlatformGen(ABC):
     """Abstract base class for all supported platforms.
     
     The platform performs all target-dependent tasks during code generation:
@@ -18,7 +18,7 @@ class Platform(ABC):
         self._ctx = ctx
 
     @abstractmethod
-    def apply_iteration_space(self, block: PsBlock, ispace: IterationSpace) -> PsBlock:
+    def materialize_iteration_space(self, block: PsBlock, ispace: IterationSpace) -> PsBlock:
         ...
 
     @abstractmethod
diff --git a/src/pystencils/nbackend/kernelcreation/typification.py b/src/pystencils/nbackend/kernelcreation/typification.py
index df26a6471..9bc9a462f 100644
--- a/src/pystencils/nbackend/kernelcreation/typification.py
+++ b/src/pystencils/nbackend/kernelcreation/typification.py
@@ -6,7 +6,7 @@ import pymbolic.primitives as pb
 from pymbolic.mapper import Mapper
 
 from .context import KernelCreationContext
-from ..types import PsAbstractType, PsNumericType
+from ..types import PsAbstractType, PsNumericType, deconstify
 from ..typed_expressions import PsTypedVariable, PsTypedConstant, ExprOrConstant
 from ..arrays import PsArrayAccess
 from ..ast import PsAstNode, PsBlock, PsExpression, PsAssignment
@@ -52,7 +52,6 @@ class Typifier(Mapper):
                 new_lhs, lhs_dtype = self.rec(lhs.expression, None)
                 new_rhs, rhs_dtype = self.rec(rhs.expression, lhs_dtype)
                 if lhs_dtype != rhs_dtype:
-                    #   todo: (optional) automatic cast insertion?
                     raise TypificationError(
                         "Mismatched types in assignment: \n"
                         f"    {lhs} <- {rhs}\n"
@@ -67,7 +66,13 @@ class Typifier(Mapper):
 
         return node
 
-    # def rec(self, expr: Any, target_type: PsNumericType | None)
+    """
+    def rec(self, expr: Any, target_type: PsNumericType | None)
+
+    All visitor methods take an expression and the target type of the current context.
+    They shall return the typified expression together with its type.
+    The returned type shall always be non-const, so make sure to call deconstify if necessary.
+    """
 
     def typify_expression(
         self, expr: Any, target_type: PsNumericType | None = None
@@ -80,7 +85,7 @@ class Typifier(Mapper):
         self, var: PsTypedVariable, target_type: PsNumericType | None
     ):
         self._check_target_type(var, var.dtype, target_type)
-        return var, var.dtype
+        return var, deconstify(var.dtype)
 
     def map_variable(
         self, var: pb.Variable, target_type: PsNumericType | None
@@ -88,20 +93,20 @@ class Typifier(Mapper):
         dtype = self._ctx.options.default_dtype
         typed_var = PsTypedVariable(var.name, dtype)
         self._check_target_type(typed_var, dtype, target_type)
-        return typed_var, dtype
+        return typed_var, deconstify(dtype)
 
     def map_constant(
         self, value: Any, target_type: PsNumericType | None
     ) -> tuple[PsTypedConstant, PsNumericType]:
         if isinstance(value, PsTypedConstant):
             self._check_target_type(value, value.dtype, target_type)
-            return value, value.dtype
+            return value, deconstify(value.dtype)
         elif target_type is None:
             raise TypificationError(
                 f"Unable to typify constant {value}: Unknown target type in this context."
             )
         else:
-            return PsTypedConstant(value, target_type), target_type
+            return PsTypedConstant(value, target_type), deconstify(target_type)
 
     #   Array Access
 
@@ -110,7 +115,7 @@ class Typifier(Mapper):
     ) -> tuple[PsArrayAccess, PsNumericType]:
         self._check_target_type(access, access.dtype, target_type)
         index, _ = self.rec(access.index_tuple[0], self._ctx.options.index_dtype)
-        return PsArrayAccess(access.base_ptr, index), cast(PsNumericType, access.dtype)
+        return PsArrayAccess(access.base_ptr, index), cast(PsNumericType, deconstify(access.dtype))
 
     #   Arithmetic Expressions
 
@@ -157,7 +162,7 @@ class Typifier(Mapper):
         expr_type: PsAbstractType,
         target_type: PsNumericType | None,
     ):
-        if target_type is not None and expr_type != target_type:
+        if target_type is not None and deconstify(expr_type) != deconstify(target_type):
             raise TypificationError(
                 f"Type mismatch at expression {expr}: Expression type did not match the context's target type\n"
                 f"  Expression type: {expr_type}\n"
diff --git a/src/pystencils/nbackend/types/__init__.py b/src/pystencils/nbackend/types/__init__.py
index c398aea9d..13deab6b4 100644
--- a/src/pystencils/nbackend/types/__init__.py
+++ b/src/pystencils/nbackend/types/__init__.py
@@ -1,6 +1,7 @@
 from .basic_types import (
     PsAbstractType,
     PsCustomType,
+    PsStructType,
     PsNumericType,
     PsScalarType,
     PsPointerType,
@@ -19,6 +20,7 @@ from .exception import PsTypeError
 __all__ = [
     "PsAbstractType",
     "PsCustomType",
+    "PsStructType",
     "PsPointerType",
     "PsNumericType",
     "PsScalarType",
diff --git a/src/pystencils/nbackend/types/basic_types.py b/src/pystencils/nbackend/types/basic_types.py
index e6b918080..540f28334 100644
--- a/src/pystencils/nbackend/types/basic_types.py
+++ b/src/pystencils/nbackend/types/basic_types.py
@@ -1,11 +1,13 @@
 from __future__ import annotations
 from abc import ABC, abstractmethod
-from typing import final, TypeVar, Any
+from typing import final, TypeVar, Any, Sequence
+from dataclasses import dataclass
 from copy import copy
 
 import numpy as np
 
 from .exception import PsTypeError
+from ..exceptions import PsInternalCompilerError
 
 
 class PsAbstractType(ABC):
@@ -40,12 +42,20 @@ class PsAbstractType(ABC):
     def required_headers(self) -> set[str]:
         """The set of header files required when this type occurs in generated code."""
         return set()
-    
+
     @property
     def itemsize(self) -> int | None:
         """If this type has a valid in-memory size, return that size."""
         return None
 
+    @property
+    def numpy_dtype(self) -> np.dtype | None:
+        """A np.dtype object representing this data type.
+
+        Available both for backward compatibility and for interaction with the numpy-based runtime system.
+        """
+        return None
+
     #   -------------------------------------------------------------------------------------------
     #   Internal virtual operations
     #   -------------------------------------------------------------------------------------------
@@ -142,6 +152,87 @@ class PsPointerType(PsAbstractType):
         return f"PsPointerType( {repr(self.base_type)}, const={self.const} )"
 
 
+class PsStructType(PsAbstractType):
+    """Class to model structured data types.
+
+    A struct type is defined by its sequence of members.
+    The struct may optionally have a name, although the code generator currently does not support named structs
+    and treats them the same way as anonymous structs.
+    """
+
+    @dataclass(frozen=True)
+    class Member:
+        name: str
+        dtype: PsAbstractType
+
+    def __init__(
+        self,
+        members: Sequence[PsStructType.Member | tuple[str, PsAbstractType]],
+        name: str | None = None,
+        const: bool = False,
+    ):
+        super().__init__(const=const)
+
+        self._name = name
+        self._members = tuple(
+            (PsStructType.Member(m[0], m[1]) if isinstance(m, tuple) else m)
+            for m in members
+        )
+
+        names: set[str] = set()
+        for member in self._members:
+            if member.name in names:
+                raise ValueError(f"Duplicate struct member name: {member.name}")
+            names.add(member.name)
+
+    @property
+    def members(self) -> tuple[PsStructType.Member, ...]:
+        return self._members
+
+    @property
+    def name(self) -> str:
+        if self._name is None:
+            raise PsInternalCompilerError(
+                "Cannot retrieve name from anonymous struct type"
+            )
+        return self._name
+
+    @property
+    def anonymous(self) -> bool:
+        return self._name is None
+
+    @property
+    def numpy_dtype(self) -> np.dtype | None:
+        members = [(m.name, m.dtype.numpy_dtype) for m in self._members]
+        return np.dtype(members)
+
+    def _c_string(self) -> str:
+        if self._name is None:
+            # raise PsInternalCompilerError(
+            #     "Cannot retrieve C string for anonymous struct type"
+            # )
+            return "<anonymous>"
+        return self._name
+
+    def __eq__(self, other: object) -> bool:
+        if not isinstance(other, PsStructType):
+            return False
+
+        return (
+            self._base_equal(other)
+            and self._name == other._name
+            and self._members == other._members
+        )
+
+    def __hash__(self) -> int:
+        return hash(("PsStructTupe", self._name, self._members, self._const))
+
+    def __repr__(self) -> str:
+        members = ", ".join(f"{m.dtype} {m.name}" for m in self._members)
+        name = "<anonymous>" if self.anonymous else f"name={self._name}"
+        return f"PsStructType( [{members}], {name}, const={self.const} )"
+
+
 class PsNumericType(PsAbstractType, ABC):
     """Class to model numeric types, which are all types that may occur at the top-level inside
     arithmetic-logical expressions.
@@ -244,6 +335,10 @@ class PsIntegerType(PsScalarType, ABC):
     def itemsize(self) -> int:
         return self.width // 8
 
+    @property
+    def numpy_dtype(self) -> np.dtype | None:
+        return np.dtype(self.NUMPY_TYPES[self._width])
+
     def create_literal(self, value: Any) -> str:
         np_dtype = self.NUMPY_TYPES[self._width]
         if not isinstance(value, np_dtype):
@@ -360,6 +455,10 @@ class PsIeeeFloatType(PsScalarType):
     def itemsize(self) -> int:
         return self.width // 8
 
+    @property
+    def numpy_dtype(self) -> np.dtype | None:
+        return np.dtype(self.NUMPY_TYPES[self._width])
+
     @property
     def required_headers(self) -> set[str]:
         if self._width == 16:
@@ -373,10 +472,14 @@ class PsIeeeFloatType(PsScalarType):
             raise PsTypeError(f"Given value {value} is not of required type {np_dtype}")
 
         match self.width:
-            case 16: return f"((half) {value})"  # see include/half_precision.h
-            case 32: return f"{value}f"
-            case 64: return str(value)
-            case _: assert False, "unreachable code"
+            case 16:
+                return f"((half) {value})"  # see include/half_precision.h
+            case 32:
+                return f"{value}f"
+            case 64:
+                return str(value)
+            case _:
+                assert False, "unreachable code"
 
     def create_constant(self, value: Any) -> Any:
         np_type = self.NUMPY_TYPES[self._width]
diff --git a/src/pystencils/nbackend/types/parsing.py b/src/pystencils/nbackend/types/parsing.py
index 14db20a92..952438f11 100644
--- a/src/pystencils/nbackend/types/parsing.py
+++ b/src/pystencils/nbackend/types/parsing.py
@@ -3,6 +3,7 @@ import numpy as np
 from .basic_types import (
     PsAbstractType,
     PsPointerType,
+    PsStructType,
     PsUnsignedIntegerType,
     PsSignedIntegerType,
     PsIeeeFloatType,
@@ -41,6 +42,22 @@ def interpret_python_type(t: type) -> PsAbstractType:
     raise ValueError(f"Could not interpret Python data type {t} as a pystencils type.")
 
 
+def interpret_numpy_dtype(t: np.dtype) -> PsAbstractType:
+    if t.fields is not None:
+        #   it's a struct
+        members = []
+        for fname, fspec in t.fields.items():
+            members.append(PsStructType.Member(fname, interpret_numpy_dtype(fspec[0])))
+        return PsStructType(members)
+    else:
+        try:
+            return interpret_python_type(t.type)
+        except ValueError:
+            raise ValueError(
+                f"Could not interpret numpy dtype object {t} as a pystencils type."
+            )
+
+
 def parse_type_string(s: str) -> PsAbstractType:
     tokens = s.rsplit("*", 1)
     match tokens:
diff --git a/src/pystencils/nbackend/types/quick.py b/src/pystencils/nbackend/types/quick.py
index b1da0c5e2..cf65897d7 100644
--- a/src/pystencils/nbackend/types/quick.py
+++ b/src/pystencils/nbackend/types/quick.py
@@ -6,6 +6,8 @@ This module is meant to be included whole, e.g. as `from pystencils.nbackend.typ
 
 from __future__ import annotations
 
+import numpy as np
+
 from .basic_types import (
     PsAbstractType,
     PsCustomType,
@@ -17,7 +19,7 @@ from .basic_types import (
     PsIeeeFloatType,
 )
 
-UserTypeSpec = str | type | PsAbstractType
+UserTypeSpec = str | type | np.dtype | PsAbstractType
 
 
 def make_type(type_spec: UserTypeSpec) -> PsAbstractType:
@@ -33,12 +35,14 @@ def make_type(type_spec: UserTypeSpec) -> PsAbstractType:
             - No others are supported at the moment
         - Supported Numpy scalar data types (see https://numpy.org/doc/stable/reference/arrays.scalars.html)
           are converted to pystencils scalar data types
+        - Instances of `np.dtype`: Attempt to interpret scalar types like above, and structured types as structs.
         - Instances of `PsAbstractType` will be returned as they are
     """
 
     from .parsing import (
         parse_type_string,
         interpret_python_type,
+        interpret_numpy_dtype
     )
 
     if isinstance(type_spec, PsAbstractType):
@@ -47,6 +51,8 @@ def make_type(type_spec: UserTypeSpec) -> PsAbstractType:
         return parse_type_string(type_spec)
     if isinstance(type_spec, type):
         return interpret_python_type(type_spec)
+    if isinstance(type_spec, np.dtype):
+        return interpret_numpy_dtype(type_spec)
     raise ValueError(f"{type_spec} is not a valid type specification.")
 
 
-- 
GitLab