From 0e4677de3a26f669b25c4ab306bbba3862ca8e44 Mon Sep 17 00:00:00 2001
From: Frederik Hennig <frederik.hennig@fau.de>
Date: Wed, 27 Mar 2024 17:01:46 +0100
Subject: [PATCH] Refactor Type Handling and Typification

 - Add a `dtype` member to all expression nodes
 - Make the `Typifier` apply `dtype`s to all expressions
 - Adapt transformations and IterationSpace to set data types on created
   expressions
 - Refactor TypeContext and contextual typing interface to be more
   intuitive
 - Refactor the Typifier to apply more operations through the
   TypeContext

Squashed commit of the following:

commit 3e81188a318aa1dc294cf0cd11bf2ec7f62a9b55
Author: Frederik Hennig <frederik.hennig@fau.de>
Date:   Wed Mar 27 17:00:17 2024 +0100

    Improve typification of integer expressions

     - Check integer type constraint in `_apply_target_type` to correctly catch deferred expressions

commit 63d0cfa5ea1b8a41c9a74bbfcf0618fad03ffa48
Merge: 671f057 075ae35
Author: Frederik Hennig <frederik.hennig@fau.de>
Date:   Wed Mar 27 16:46:28 2024 +0100

    Merge branch 'backend-rework' into b_refactor_typing

commit 671f0578a39e452504243019dab28d93f0114082
Author: Frederik Hennig <frederik.hennig@fau.de>
Date:   Tue Mar 26 16:39:43 2024 +0100

    Fix documentation for Typifier and PsExpression

commit 3ec258517ad8a510118265184b5dc7805128dcd3
Author: Frederik Hennig <frederik.hennig@fau.de>
Date:   Mon Mar 25 17:14:21 2024 +0100

    Typing refactor:

      - Annotate all expressions with types
      - Refactor Typifier for cleaner information flow and better
        readability
      - Have iteration space and transformers typify newly created AST nodes
---
 src/pystencils/backend/ast/expressions.py     |  59 +++-
 .../backend/ast/logical_expressions.py        |   1 +
 .../backend/kernelcreation/iteration_space.py |   6 +-
 .../backend/kernelcreation/typification.py    | 271 ++++++++++--------
 src/pystencils/backend/platforms/__init__.py  |   2 +-
 .../backend/platforms/generic_gpu.py          |  41 ++-
 .../erase_anonymous_structs.py                |   7 +-
 .../transformations/select_intrinsics.py      |   4 +-
 tests/nbackend/types/test_types.py            |   2 -
 9 files changed, 241 insertions(+), 152 deletions(-)

diff --git a/src/pystencils/backend/ast/expressions.py b/src/pystencils/backend/ast/expressions.py
index a8384a014..ed94f9c8b 100644
--- a/src/pystencils/backend/ast/expressions.py
+++ b/src/pystencils/backend/ast/expressions.py
@@ -1,6 +1,6 @@
 from __future__ import annotations
 from abc import ABC, abstractmethod
-from typing import Sequence, overload, Callable, Any
+from typing import Sequence, overload, Callable, Any, cast
 import operator
 
 from ..symbols import PsSymbol
@@ -14,12 +14,42 @@ from ...types import (
     PsTypeError,
 )
 from .util import failing_cast
+from ..exceptions import PsInternalCompilerError
 
 from .astnode import PsAstNode, PsLeafMixIn
 
 
 class PsExpression(PsAstNode, ABC):
-    """Base class for all expressions."""
+    """Base class for all expressions.
+    
+    **Types:** Each expression should be annotated with its type.
+    Upon construction, the `dtype` property of most expression nodes is unset;
+    only constant expressions, symbol expressions, and array accesses immediately inherit their type from
+    their constant, symbol, or array, respectively.
+
+    The canonical way to add types to newly constructed expressions is through the `Typifier`.
+    It should be run at least once on any expression constructed by the backend.
+
+    The type annotations are used by various transformation passes to make decisions, e.g. in
+    function materialization and intrinsic selection.
+    """
+
+    def __init__(self, dtype: PsType | None = None) -> None:
+        self._dtype = dtype
+
+    @property
+    def dtype(self) -> PsType | None:
+        return self._dtype
+
+    @dtype.setter
+    def dtype(self, dt: PsType):
+        self._dtype = dt
+
+    def get_dtype(self) -> PsType:
+        if self._dtype is None:
+            raise PsInternalCompilerError("No dtype set on this expression yet.")
+
+        return self._dtype
 
     def __add__(self, other: PsExpression) -> PsAdd:
         return PsAdd(self, other)
@@ -70,6 +100,7 @@ class PsSymbolExpr(PsLeafMixIn, PsLvalue, PsExpression):
     __match_args__ = ("symbol",)
 
     def __init__(self, symbol: PsSymbol):
+        super().__init__(symbol.dtype)
         self._symbol = symbol
 
     @property
@@ -97,6 +128,7 @@ class PsConstantExpr(PsLeafMixIn, PsExpression):
     __match_args__ = ("constant",)
 
     def __init__(self, constant: PsConstant):
+        super().__init__(constant.dtype)
         self._constant = constant
 
     @property
@@ -124,6 +156,7 @@ class PsSubscript(PsLvalue, PsExpression):
     __match_args__ = ("base", "index")
 
     def __init__(self, base: PsExpression, index: PsExpression):
+        super().__init__()
         self._base = base
         self._index = index
 
@@ -167,6 +200,7 @@ class PsArrayAccess(PsSubscript):
     def __init__(self, base_ptr: PsArrayBasePointer, index: PsExpression):
         super().__init__(PsExpression.make(base_ptr), index)
         self._base_ptr = base_ptr
+        self._dtype = base_ptr.array.element_type
 
     @property
     def base_ptr(self) -> PsArrayBasePointer:
@@ -192,11 +226,6 @@ class PsArrayAccess(PsSubscript):
     def array(self) -> PsLinearizedArray:
         return self._base_ptr.array
 
-    @property
-    def dtype(self) -> PsType:
-        """Data type of this expression, i.e. the element type of the underlying array"""
-        return self._base_ptr.array.element_type
-
     def clone(self) -> PsArrayAccess:
         return PsArrayAccess(self._base_ptr, self._index.clone())
 
@@ -229,15 +258,12 @@ class PsVectorArrayAccess(PsArrayAccess):
         self._stride = stride
         self._alignment = alignment
 
+        self._dtype = self._vector_type
+
     @property
     def vector_entries(self) -> int:
         return self._vector_type.vector_entries
 
-    @property
-    def dtype(self) -> PsVectorType:
-        """Data type of this expression, i.e. the resulting generic vector type"""
-        return self._vector_type
-
     @property
     def stride(self) -> int:
         return self._stride
@@ -245,6 +271,9 @@ class PsVectorArrayAccess(PsArrayAccess):
     @property
     def alignment(self) -> int:
         return self._alignment
+    
+    def get_vector_type(self) -> PsVectorType:
+        return cast(PsVectorType, self._dtype)
 
     def clone(self) -> PsVectorArrayAccess:
         return PsVectorArrayAccess(
@@ -271,6 +300,7 @@ class PsLookup(PsExpression, PsLvalue):
     __match_args__ = ("aggregate", "member_name")
 
     def __init__(self, aggregate: PsExpression, member_name: str) -> None:
+        super().__init__()
         self._aggregate = aggregate
         self._member_name = member_name
 
@@ -310,6 +340,8 @@ class PsCall(PsExpression):
                 f"Argument count mismatch: Cannot apply function {function} to {len(args)} arguments."
             )
 
+        super().__init__()
+
         self._function = function
         self._args = list(args)
 
@@ -349,6 +381,7 @@ class PsUnOp(PsExpression):
     __match_args__ = ("operand",)
 
     def __init__(self, operand: PsExpression):
+        super().__init__()
         self._operand = operand
 
     @property
@@ -419,6 +452,7 @@ class PsBinOp(PsExpression):
     __match_args__ = ("operand1", "operand2")
 
     def __init__(self, op1: PsExpression, op2: PsExpression):
+        super().__init__()
         self._op1 = op1
         self._op2 = op2
 
@@ -527,6 +561,7 @@ class PsArrayInitList(PsExpression):
     __match_args__ = ("items",)
 
     def __init__(self, items: Sequence[PsExpression]):
+        super().__init__()
         self._items = list(items)
 
     @property
diff --git a/src/pystencils/backend/ast/logical_expressions.py b/src/pystencils/backend/ast/logical_expressions.py
index 49fbf68f0..2d739e020 100644
--- a/src/pystencils/backend/ast/logical_expressions.py
+++ b/src/pystencils/backend/ast/logical_expressions.py
@@ -10,6 +10,7 @@ class PsLogicalExpression(PsExpression):
     __match_args__ = ("operand1", "operand2")
 
     def __init__(self, op1: PsExpression, op2: PsExpression):
+        super().__init__()
         self._op1 = op1
         self._op2 = op2
 
diff --git a/src/pystencils/backend/kernelcreation/iteration_space.py b/src/pystencils/backend/kernelcreation/iteration_space.py
index 7d6404380..382adf7b6 100644
--- a/src/pystencils/backend/kernelcreation/iteration_space.py
+++ b/src/pystencils/backend/kernelcreation/iteration_space.py
@@ -105,9 +105,13 @@ class FullIterationSpace(IterationSpace):
 
         spatial_shape = archetype_array.shape[:dim]
 
+        from .typification import Typifier
+
+        typify = Typifier(ctx)
+
         dimensions = [
             FullIterationSpace.Dimension(
-                gl_left, PsExpression.make(shape) - gl_right, one, ctr
+                gl_left, typify(PsExpression.make(shape) - gl_right), one, ctr
             )
             for (gl_left, gl_right), shape, ctr in zip(
                 ghost_layer_exprs, spatial_shape, counters, strict=True
diff --git a/src/pystencils/backend/kernelcreation/typification.py b/src/pystencils/backend/kernelcreation/typification.py
index 3fbb9c1a8..dcfb0f548 100644
--- a/src/pystencils/backend/kernelcreation/typification.py
+++ b/src/pystencils/backend/kernelcreation/typification.py
@@ -10,6 +10,7 @@ from ...types import (
     PsIntegerType,
     PsArrayType,
     PsDereferencableType,
+    PsPointerType,
     PsBoolType,
     deconstify,
 )
@@ -30,6 +31,8 @@ from ..ast.expressions import (
     PsBitwiseXor,
     PsCall,
     PsCast,
+    PsDeref,
+    PsAddressOf,
     PsConstantExpr,
     PsIntDiv,
     PsLeftShift,
@@ -53,45 +56,97 @@ NodeT = TypeVar("NodeT", bound=PsAstNode)
 class TypeContext:
     def __init__(self, target_type: PsType | None = None):
         self._target_type = deconstify(target_type) if target_type is not None else None
-        self._deferred_constants: list[PsConstantExpr] = []
+        self._deferred_exprs: list[PsExpression] = []
 
-    def typify_constant(self, constexpr: PsConstantExpr) -> None:
-        if self._target_type is None:
-            self._deferred_constants.append(constexpr)
-        elif not isinstance(self._target_type, PsNumericType):
-            raise TypificationError(
-                f"Can't typify constant with non-numeric type {self._target_type}"
-            )
-        else:
-            constexpr.constant.apply_dtype(self._target_type)
+    def apply_dtype(self, expr: PsExpression | None, dtype: PsType):
+        """Applies the given ``dtype`` to the given expression inside this type context.
 
-    def apply_and_check(self, expr: PsExpression, expr_type: PsType):
+        The given expression will be covered by this type context.
+        If the context's target_type is already known, it must be compatible with the given dtype.
+        If the target type is still unknown, target_type is set to dtype and retroactively applied
+        to all deferred expressions.
         """
-        If no target type has been set yet, establishes expr_type as the target type
-        and typifies all deferred expressions.
 
-        Otherwise, checks if expression type and target type are compatible.
+        dtype = deconstify(dtype)
+
+        if self._target_type is not None and dtype != self._target_type:
+            raise TypificationError(
+                f"Type mismatch at expression {expr}: Expression type did not match the context's target type\n"
+                f"  Expression type: {dtype}\n"
+                f"      Target type: {self._target_type}"
+            )
+        else:
+            self._target_type = dtype
+            self._propagate_target_type()
+
+        if expr is not None:
+            if expr.dtype is None:
+                self._apply_target_type(expr)
+            elif deconstify(expr.dtype) != self._target_type:
+                raise TypificationError(
+                    "Type conflict: Predefined expression type did not match the context's target type\n"
+                    f"  Expression type: {dtype}\n"
+                    f"      Target type: {self._target_type}"
+                )
+
+    def infer_dtype(self, expr: PsExpression):
+        """Infer the data type for the given expression.
+
+        If the target_type of this context is already known, it will be applied to the given expression.
+        Otherwise, the expression is deferred, and a type will be applied to it as soon as `apply_type` is
+        called on this context.
+
+        It the expression already has a data type set, it must be equal to the inferred type.
         """
-        expr_type = deconstify(expr_type)
 
         if self._target_type is None:
-            self._target_type = expr_type
-
-            for dc in self._deferred_constants:
-                if not isinstance(self._target_type, PsNumericType):
+            self._deferred_exprs.append(expr)
+        else:
+            self._apply_target_type(expr)
+
+    def _propagate_target_type(self):
+        for expr in self._deferred_exprs:
+            self._apply_target_type(expr)
+        self._deferred_exprs = []
+
+    def _apply_target_type(self, expr: PsExpression):
+        assert self._target_type is not None
+
+        if expr.dtype is not None:
+            if deconstify(expr.dtype) != self.target_type:
+                raise TypificationError(
+                    f"Type mismatch at expression {expr}: Expression type did not match the context's target type\n"
+                    f"  Expression type: {expr.dtype}\n"
+                    f"      Target type: {self._target_type}"
+                )
+        else:
+            match expr:
+                case PsConstantExpr(c):
+                    if not isinstance(self._target_type, PsNumericType):
+                        raise TypificationError(
+                            f"Can't typify constant with non-numeric type {self._target_type}"
+                        )
+                    c.apply_dtype(self._target_type)
+                
+                case PsSymbolExpr(symb):
+                    symb.apply_dtype(self._target_type)
+
+                case (
+                    PsIntDiv()
+                    | PsLeftShift()
+                    | PsRightShift()
+                    | PsBitwiseAnd()
+                    | PsBitwiseXor()
+                    | PsBitwiseOr()
+                ) if not isinstance(self._target_type, PsIntegerType):
                     raise TypificationError(
-                        f"Can't typify constant with non-numeric type {self._target_type}"
+                        f"Integer operation encountered in non-integer type context:\n"
+                        f"    Expression: {expr}"
+                        f"  Type Context: {self._target_type}"
                     )
-                dc.constant.apply_dtype(self._target_type)
 
-            self._deferred_constants = []
-
-        elif expr_type != self._target_type:
-            raise TypificationError(
-                f"Type mismatch at expression {expr}: Expression type did not match the context's target type\n"
-                f"  Expression type: {expr_type}\n"
-                f"      Target type: {self._target_type}"
-            )
+            expr.dtype = self._target_type
+        # endif
 
     @property
     def target_type(self) -> PsType | None:
@@ -99,46 +154,28 @@ class TypeContext:
 
 
 class Typifier:
-    """Typifier for untyped expressions.
-
-    The typifier, when called with an AST node, will attempt to figure out
-    the types for all untyped expressions within the node.
-    Plain variables will be assigned a type according to `ctx.options.default_dtype`,
-    constants will be converted to typed constants according to the contextual typing scheme
-    described below.
-
-    Contextual Typing
-    -----------------
-
-    The contextual typifier covers the expression tree with disjoint typing contexts.
-    The idea is that all nodes covered by a typing context must have the exact same type.
-    Starting at an expression's root, the typifier attempts to expand a typing context as far as possible
-    toward the leaves.
-    This happens implicitly during the recursive traversal of the expression tree.
-
-    At an interior node, which is modelled as a function applied to a number of arguments, producing a result,
-    that function's signature governs context expansion. Let T be the function's return type; then the context
-    is expanded to each argument expression that also is of type T.
-
-    If a function parameter is of type S != T, a new type context is created for it. If the type S is already fixed
-    by the function signature, it will be the target type of the new context.
-
-    At the tree's leaves, types are applied and checked. By the above propagation rule, all leaves that share a typing
-    context must have the exact same type (modulo constness). There the actual type checking happens.
-    If a variable is encountered and the context does not yet have a target type, it is set to the variable's type.
-    If a constant is encountered, it is typified using the current target type.
-    If no target type is known yet, the constant will first be instantiated as a DeferredTypedConstant,
-    and stashed in the context.
-    As soon as the context learns its target type, it is applied to all deferred constants.
-
-    In addition to leaves, some interior nodes may also have to be checked against the target type.
-    In particular, these are array accesses, struct member accesses, and calls to functions with a fixed
-    return type.
-
-    When a context is 'closed' during the recursive unwinding, it shall be an error if it still contains unresolved
-    constants.
-
-    TODO: The context shall keep track of it's target type's origin to aid in producing helpful error messages.
+    """Apply data types to expressions.
+
+    The Typifier will traverse the AST and apply a contextual typing scheme to figure out
+    the data types of all encountered expressions.
+    To this end, it covers each expression tree with a set of disjoint typing contexts.
+    All nodes covered by the same typing context must have the same type.
+
+    Starting from an expression's root, a typing context is implicitly expanded through
+    the recursive descent into a node's children. In particular, a child is typified within
+    the same context as its parent if the node's semantics require parent and child to have
+    the same type (e.g. at arithmetic operators, mathematical functions, etc.).
+    If a node's child is required to have a different type, a new context is opened.
+
+    For each typing context, its target type is prescribed by the first node encountered during traversal
+    whose type is fixed according to its typing rules. All other nodes covered by the context must share
+    that type.
+
+    The types of arithmetic operators, mathematical functions, and untyped constants are
+    inferred from their context's target type. If one of these is encountered while no target type is set yet
+    in the context, the expression is deferred by storing it in the context, and will be assigned a type as soon
+    as the target type is fixed.
+
     """
 
     def __init__(self, ctx: KernelCreationContext):
@@ -152,7 +189,7 @@ class Typifier:
         return node
 
     def typify_expression(
-        self, expr: PsExpression, target_type: PsNumericType | None = None
+        self, expr: PsExpression, target_type: PsType | None = None
     ) -> tuple[PsExpression, PsType]:
         tc = TypeContext(target_type)
         self.visit_expr(expr, tc)
@@ -202,25 +239,22 @@ class Typifier:
     def visit_expr(self, expr: PsExpression, tc: TypeContext) -> None:
         """Recursive processing of expression nodes"""
         match expr:
-            case PsSymbolExpr(symb):
-                if symb.dtype is None:
-                    dtype = self._ctx.default_dtype
-                    symb.apply_dtype(dtype)
-                tc.apply_and_check(expr, symb.get_dtype())
-
-            case PsConstantExpr(constant):
-                if constant.dtype is not None:
-                    tc.apply_and_check(expr, constant.get_dtype())
+            case PsSymbolExpr(_):
+                if expr.dtype is None:
+                    tc.apply_dtype(expr, self._ctx.default_dtype)
                 else:
-                    tc.typify_constant(expr)
+                    tc.apply_dtype(expr, expr.dtype)
+
+            case PsConstantExpr(_):
+                tc.infer_dtype(expr)
 
-            case PsArrayAccess(_, idx):
-                tc.apply_and_check(expr, expr.dtype)
+            case PsArrayAccess(bptr, idx):
+                tc.apply_dtype(expr, bptr.array.element_type)
 
                 index_tc = TypeContext()
                 self.visit_expr(idx, index_tc)
                 if index_tc.target_type is None:
-                    index_tc.apply_and_check(idx, self._ctx.index_dtype)
+                    index_tc.apply_dtype(idx, self._ctx.index_dtype)
                 elif not isinstance(index_tc.target_type, PsIntegerType):
                     raise TypificationError(
                         f"Array index is not of integer type: {idx} has type {index_tc.target_type}"
@@ -235,17 +269,40 @@ class Typifier:
                         "Type of subscript base is not subscriptable."
                     )
 
-                tc.apply_and_check(expr, arr_tc.target_type.base_type)
+                tc.apply_dtype(expr, arr_tc.target_type.base_type)
 
                 index_tc = TypeContext()
                 self.visit_expr(idx, index_tc)
                 if index_tc.target_type is None:
-                    index_tc.apply_and_check(idx, self._ctx.index_dtype)
+                    index_tc.apply_dtype(idx, self._ctx.index_dtype)
                 elif not isinstance(index_tc.target_type, PsIntegerType):
                     raise TypificationError(
                         f"Subscript index is not of integer type: {idx} has type {index_tc.target_type}"
                     )
 
+            case PsDeref(ptr):
+                ptr_tc = TypeContext()
+                self.visit_expr(ptr, ptr_tc)
+
+                if not isinstance(ptr_tc.target_type, PsDereferencableType):
+                    raise TypificationError(
+                        "Type of argument to a Deref is not dereferencable"
+                    )
+
+                tc.apply_dtype(expr, ptr_tc.target_type.base_type)
+
+            case PsAddressOf(arg):
+                arg_tc = TypeContext()
+                self.visit_expr(arg, arg_tc)
+
+                if arg_tc.target_type is None:
+                    raise TypificationError(
+                        f"Unable to determine type of argument to AddressOf: {arg}"
+                    )
+
+                ptr_type = PsPointerType(arg_tc.target_type, True)
+                tc.apply_dtype(expr, ptr_type)
+
             case PsLookup(aggr, member_name):
                 aggr_tc = TypeContext(None)
                 self.visit_expr(aggr, aggr_tc)
@@ -262,47 +319,19 @@ class Typifier:
                         f"Aggregate of type {aggr_type} does not have a member {member}."
                     )
 
-                tc.apply_and_check(expr, member.dtype)
-
-            # integer operations
-            case (
-                PsIntDiv(op1, op2)
-                | PsLeftShift(op1, op2)
-                | PsRightShift(op1, op2)
-                | PsBitwiseAnd(op1, op2)
-                | PsBitwiseXor(op1, op2)
-                | PsBitwiseOr(op1, op2)
-            ):
-                if tc.target_type is not None and not isinstance(
-                    tc.target_type, PsIntegerType
-                ):
-                    raise TypificationError(
-                        f"Integer expression used in non-integer context.\n"
-                        f"  Integer expression: {expr}\n"
-                        f"        Context type: {tc.target_type}"
-                    )
-
-                self.visit_expr(op1, tc)
-                self.visit_expr(op2, tc)
-
-                if tc.target_type is None:
-                    raise TypificationError(
-                        f"Unable to infer type of integer expression {expr}."
-                    )
-                elif not isinstance(tc.target_type, PsIntegerType):
-                    raise TypificationError(
-                        f"Argument(s) to integer function are non-integer in expression {expr}."
-                    )
+                tc.apply_dtype(expr, member.dtype)
 
             case PsBinOp(op1, op2):
                 self.visit_expr(op1, tc)
                 self.visit_expr(op2, tc)
+                tc.infer_dtype(expr)
 
             case PsCall(function, args):
                 match function:
                     case PsMathFunction():
                         for arg in args:
                             self.visit_expr(arg, tc)
+                        tc.infer_dtype(expr)
                     case _:
                         raise TypificationError(
                             f"Don't know how to typify calls to {function}"
@@ -329,14 +358,14 @@ class Typifier:
                             f"{len(items)} items as {tc.target_type}"
                         )
                     else:
-                        items_tc.apply_and_check(expr, tc.target_type.base_type)
+                        items_tc.apply_dtype(None, tc.target_type.base_type)
                 else:
                     arr_type = PsArrayType(items_tc.target_type, len(items))
-                    tc.apply_and_check(expr, arr_type)
+                    tc.apply_dtype(expr, arr_type)
 
-            case PsCast(dtype, operand):
-                self.visit_expr(operand, TypeContext())
-                tc.apply_and_check(expr, dtype)
+            case PsCast(dtype, arg):
+                self.visit_expr(arg, TypeContext())
+                tc.apply_dtype(expr, dtype)
 
             case _:
                 raise NotImplementedError(f"Can't typify {expr}")
diff --git a/src/pystencils/backend/platforms/__init__.py b/src/pystencils/backend/platforms/__init__.py
index 355c28d8f..0b816bf93 100644
--- a/src/pystencils/backend/platforms/__init__.py
+++ b/src/pystencils/backend/platforms/__init__.py
@@ -9,5 +9,5 @@ __all__ = [
     "GenericVectorCpu",
     "X86VectorCpu",
     "X86VectorArch",
-    "GenericGpu"
+    "GenericGpu",
 ]
diff --git a/src/pystencils/backend/platforms/generic_gpu.py b/src/pystencils/backend/platforms/generic_gpu.py
index 1e7d958f7..79ab6f9ec 100644
--- a/src/pystencils/backend/platforms/generic_gpu.py
+++ b/src/pystencils/backend/platforms/generic_gpu.py
@@ -3,11 +3,12 @@ from .platform import Platform
 from ..kernelcreation.iteration_space import (
     IterationSpace,
     FullIterationSpace,
-    SparseIterationSpace,
+    # SparseIterationSpace,
 )
 
 from ..ast.structural import PsBlock, PsConditional
 from ..ast.expressions import (
+    PsExpression,
     PsSymbolExpr,
     PsAdd,
 )
@@ -17,10 +18,18 @@ from ..symbols import PsSymbol
 
 int32 = PsSignedIntegerType(width=32, const=False)
 
-BLOCK_IDX = [PsSymbolExpr(PsSymbol(f"blockIdx.{coord}", int32)) for coord in ('x', 'y', 'z')]
-THREAD_IDX = [PsSymbolExpr(PsSymbol(f"threadIdx.{coord}", int32)) for coord in ('x', 'y', 'z')]
-BLOCK_DIM = [PsSymbolExpr(PsSymbol(f"blockDim.{coord}", int32)) for coord in ('x', 'y', 'z')]
-GRID_DIM = [PsSymbolExpr(PsSymbol(f"gridDim.{coord}", int32)) for coord in ('x', 'y', 'z')]
+BLOCK_IDX = [
+    PsSymbolExpr(PsSymbol(f"blockIdx.{coord}", int32)) for coord in ("x", "y", "z")
+]
+THREAD_IDX = [
+    PsSymbolExpr(PsSymbol(f"threadIdx.{coord}", int32)) for coord in ("x", "y", "z")
+]
+BLOCK_DIM = [
+    PsSymbolExpr(PsSymbol(f"blockDim.{coord}", int32)) for coord in ("x", "y", "z")
+]
+GRID_DIM = [
+    PsSymbolExpr(PsSymbol(f"gridDim.{coord}", int32)) for coord in ("x", "y", "z")
+]
 
 
 class GenericGpu(Platform):
@@ -29,7 +38,9 @@ class GenericGpu(Platform):
     def required_headers(self) -> set[str]:
         return {"gpu_defines.h"}
 
-    def materialize_iteration_space(self, body: PsBlock, ispace: IterationSpace) -> PsBlock:
+    def materialize_iteration_space(
+        self, body: PsBlock, ispace: IterationSpace
+    ) -> PsBlock:
         if isinstance(ispace, FullIterationSpace):
             return self._guard_full_iteration_space(body, ispace)
         else:
@@ -37,13 +48,17 @@ class GenericGpu(Platform):
 
     def cuda_indices(self, dim):
         block_size = BLOCK_DIM
-        indices = [block_index * bs + thread_idx
-                   for block_index, bs, thread_idx in zip(BLOCK_IDX, block_size, THREAD_IDX)]
+        indices = [
+            block_index * bs + thread_idx
+            for block_index, bs, thread_idx in zip(BLOCK_IDX, block_size, THREAD_IDX)
+        ]
 
         return indices[:dim]
 
     #   Internals
-    def _guard_full_iteration_space(self, body: PsBlock, ispace: FullIterationSpace) -> PsBlock:
+    def _guard_full_iteration_space(
+        self, body: PsBlock, ispace: FullIterationSpace
+    ) -> PsBlock:
 
         dimensions = ispace.dimensions
 
@@ -53,12 +68,14 @@ class GenericGpu(Platform):
             loop_order = archetype_field.layout
             dimensions = [dimensions[coordinate] for coordinate in loop_order]
 
-        start = [PsAdd(c, d.start) for c, d in zip(self.cuda_indices(len(dimensions)), dimensions[::-1])]
+        start = [
+            PsAdd(c, d.start)
+            for c, d in zip(self.cuda_indices(len(dimensions)), dimensions[::-1])
+        ]
         conditions = [PsLt(c, d.stop) for c, d in zip(start, dimensions[::-1])]
 
-        condition = conditions[0]
+        condition: PsExpression = conditions[0]
         for c in conditions[1:]:
             condition = PsAnd(condition, c)
 
         return PsBlock([PsConditional(condition, body)])
-
diff --git a/src/pystencils/backend/transformations/erase_anonymous_structs.py b/src/pystencils/backend/transformations/erase_anonymous_structs.py
index c946ae7bb..03d79a689 100644
--- a/src/pystencils/backend/transformations/erase_anonymous_structs.py
+++ b/src/pystencils/backend/transformations/erase_anonymous_structs.py
@@ -12,6 +12,7 @@ from ..ast.expressions import (
     PsAddressOf,
     PsCast,
 )
+from ..kernelcreation import Typifier
 from ..arrays import PsArrayBasePointer, TypeErasedBasePointer
 from ...types import PsStructType, PsPointerType
 
@@ -98,6 +99,10 @@ class EraseAnonymousStructTypes:
         )
         type_erased_access = PsArrayAccess(type_erased_bp, byte_index)
 
-        return PsDeref(
+        deref = PsDeref(
             PsCast(PsPointerType(member.dtype), PsAddressOf(type_erased_access))
         )
+
+        typify = Typifier(self._ctx)
+        deref = typify(deref)
+        return deref
diff --git a/src/pystencils/backend/transformations/select_intrinsics.py b/src/pystencils/backend/transformations/select_intrinsics.py
index e587ba129..7972de069 100644
--- a/src/pystencils/backend/transformations/select_intrinsics.py
+++ b/src/pystencils/backend/transformations/select_intrinsics.py
@@ -68,7 +68,7 @@ class MaterializeVectorIntrinsics:
         match node:
             case PsAssignment(lhs, rhs) if isinstance(lhs, PsVectorArrayAccess):
                 vc = VecTypeCtx()
-                vc.set(lhs.dtype)
+                vc.set(lhs.get_vector_type())
                 store_arg = self.visit_expr(rhs, vc)
                 return PsStatement(self._platform.vector_store(lhs, store_arg))
             case PsExpression():
@@ -95,7 +95,7 @@ class MaterializeVectorIntrinsics:
                     return expr
 
             case PsVectorArrayAccess():
-                vc.set(expr.dtype)
+                vc.set(expr.get_vector_type())
                 return self._platform.vector_load(expr)
 
             case PsBinOp(op1, op2):
diff --git a/tests/nbackend/types/test_types.py b/tests/nbackend/types/test_types.py
index 487f77783..204ee24cf 100644
--- a/tests/nbackend/types/test_types.py
+++ b/tests/nbackend/types/test_types.py
@@ -32,10 +32,8 @@ def test_parsing_negative():
         "const notatype * const",
         "cnost uint32_t",
         "uint45_t",
-        "int",  # plain ints are ambiguous
         "float float",
         "double * int",
-        "bool",
     ]
 
     for spec in bad_specs:
-- 
GitLab