From 0e4677de3a26f669b25c4ab306bbba3862ca8e44 Mon Sep 17 00:00:00 2001 From: Frederik Hennig <frederik.hennig@fau.de> Date: Wed, 27 Mar 2024 17:01:46 +0100 Subject: [PATCH] Refactor Type Handling and Typification - Add a `dtype` member to all expression nodes - Make the `Typifier` apply `dtype`s to all expressions - Adapt transformations and IterationSpace to set data types on created expressions - Refactor TypeContext and contextual typing interface to be more intuitive - Refactor the Typifier to apply more operations through the TypeContext Squashed commit of the following: commit 3e81188a318aa1dc294cf0cd11bf2ec7f62a9b55 Author: Frederik Hennig <frederik.hennig@fau.de> Date: Wed Mar 27 17:00:17 2024 +0100 Improve typification of integer expressions - Check integer type constraint in `_apply_target_type` to correctly catch deferred expressions commit 63d0cfa5ea1b8a41c9a74bbfcf0618fad03ffa48 Merge: 671f057 075ae35 Author: Frederik Hennig <frederik.hennig@fau.de> Date: Wed Mar 27 16:46:28 2024 +0100 Merge branch 'backend-rework' into b_refactor_typing commit 671f0578a39e452504243019dab28d93f0114082 Author: Frederik Hennig <frederik.hennig@fau.de> Date: Tue Mar 26 16:39:43 2024 +0100 Fix documentation for Typifier and PsExpression commit 3ec258517ad8a510118265184b5dc7805128dcd3 Author: Frederik Hennig <frederik.hennig@fau.de> Date: Mon Mar 25 17:14:21 2024 +0100 Typing refactor: - Annotate all expressions with types - Refactor Typifier for cleaner information flow and better readability - Have iteration space and transformers typify newly created AST nodes --- src/pystencils/backend/ast/expressions.py | 59 +++- .../backend/ast/logical_expressions.py | 1 + .../backend/kernelcreation/iteration_space.py | 6 +- .../backend/kernelcreation/typification.py | 271 ++++++++++-------- src/pystencils/backend/platforms/__init__.py | 2 +- .../backend/platforms/generic_gpu.py | 41 ++- .../erase_anonymous_structs.py | 7 +- .../transformations/select_intrinsics.py | 4 +- tests/nbackend/types/test_types.py | 2 - 9 files changed, 241 insertions(+), 152 deletions(-) diff --git a/src/pystencils/backend/ast/expressions.py b/src/pystencils/backend/ast/expressions.py index a8384a014..ed94f9c8b 100644 --- a/src/pystencils/backend/ast/expressions.py +++ b/src/pystencils/backend/ast/expressions.py @@ -1,6 +1,6 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import Sequence, overload, Callable, Any +from typing import Sequence, overload, Callable, Any, cast import operator from ..symbols import PsSymbol @@ -14,12 +14,42 @@ from ...types import ( PsTypeError, ) from .util import failing_cast +from ..exceptions import PsInternalCompilerError from .astnode import PsAstNode, PsLeafMixIn class PsExpression(PsAstNode, ABC): - """Base class for all expressions.""" + """Base class for all expressions. + + **Types:** Each expression should be annotated with its type. + Upon construction, the `dtype` property of most expression nodes is unset; + only constant expressions, symbol expressions, and array accesses immediately inherit their type from + their constant, symbol, or array, respectively. + + The canonical way to add types to newly constructed expressions is through the `Typifier`. + It should be run at least once on any expression constructed by the backend. + + The type annotations are used by various transformation passes to make decisions, e.g. in + function materialization and intrinsic selection. + """ + + def __init__(self, dtype: PsType | None = None) -> None: + self._dtype = dtype + + @property + def dtype(self) -> PsType | None: + return self._dtype + + @dtype.setter + def dtype(self, dt: PsType): + self._dtype = dt + + def get_dtype(self) -> PsType: + if self._dtype is None: + raise PsInternalCompilerError("No dtype set on this expression yet.") + + return self._dtype def __add__(self, other: PsExpression) -> PsAdd: return PsAdd(self, other) @@ -70,6 +100,7 @@ class PsSymbolExpr(PsLeafMixIn, PsLvalue, PsExpression): __match_args__ = ("symbol",) def __init__(self, symbol: PsSymbol): + super().__init__(symbol.dtype) self._symbol = symbol @property @@ -97,6 +128,7 @@ class PsConstantExpr(PsLeafMixIn, PsExpression): __match_args__ = ("constant",) def __init__(self, constant: PsConstant): + super().__init__(constant.dtype) self._constant = constant @property @@ -124,6 +156,7 @@ class PsSubscript(PsLvalue, PsExpression): __match_args__ = ("base", "index") def __init__(self, base: PsExpression, index: PsExpression): + super().__init__() self._base = base self._index = index @@ -167,6 +200,7 @@ class PsArrayAccess(PsSubscript): def __init__(self, base_ptr: PsArrayBasePointer, index: PsExpression): super().__init__(PsExpression.make(base_ptr), index) self._base_ptr = base_ptr + self._dtype = base_ptr.array.element_type @property def base_ptr(self) -> PsArrayBasePointer: @@ -192,11 +226,6 @@ class PsArrayAccess(PsSubscript): def array(self) -> PsLinearizedArray: return self._base_ptr.array - @property - def dtype(self) -> PsType: - """Data type of this expression, i.e. the element type of the underlying array""" - return self._base_ptr.array.element_type - def clone(self) -> PsArrayAccess: return PsArrayAccess(self._base_ptr, self._index.clone()) @@ -229,15 +258,12 @@ class PsVectorArrayAccess(PsArrayAccess): self._stride = stride self._alignment = alignment + self._dtype = self._vector_type + @property def vector_entries(self) -> int: return self._vector_type.vector_entries - @property - def dtype(self) -> PsVectorType: - """Data type of this expression, i.e. the resulting generic vector type""" - return self._vector_type - @property def stride(self) -> int: return self._stride @@ -245,6 +271,9 @@ class PsVectorArrayAccess(PsArrayAccess): @property def alignment(self) -> int: return self._alignment + + def get_vector_type(self) -> PsVectorType: + return cast(PsVectorType, self._dtype) def clone(self) -> PsVectorArrayAccess: return PsVectorArrayAccess( @@ -271,6 +300,7 @@ class PsLookup(PsExpression, PsLvalue): __match_args__ = ("aggregate", "member_name") def __init__(self, aggregate: PsExpression, member_name: str) -> None: + super().__init__() self._aggregate = aggregate self._member_name = member_name @@ -310,6 +340,8 @@ class PsCall(PsExpression): f"Argument count mismatch: Cannot apply function {function} to {len(args)} arguments." ) + super().__init__() + self._function = function self._args = list(args) @@ -349,6 +381,7 @@ class PsUnOp(PsExpression): __match_args__ = ("operand",) def __init__(self, operand: PsExpression): + super().__init__() self._operand = operand @property @@ -419,6 +452,7 @@ class PsBinOp(PsExpression): __match_args__ = ("operand1", "operand2") def __init__(self, op1: PsExpression, op2: PsExpression): + super().__init__() self._op1 = op1 self._op2 = op2 @@ -527,6 +561,7 @@ class PsArrayInitList(PsExpression): __match_args__ = ("items",) def __init__(self, items: Sequence[PsExpression]): + super().__init__() self._items = list(items) @property diff --git a/src/pystencils/backend/ast/logical_expressions.py b/src/pystencils/backend/ast/logical_expressions.py index 49fbf68f0..2d739e020 100644 --- a/src/pystencils/backend/ast/logical_expressions.py +++ b/src/pystencils/backend/ast/logical_expressions.py @@ -10,6 +10,7 @@ class PsLogicalExpression(PsExpression): __match_args__ = ("operand1", "operand2") def __init__(self, op1: PsExpression, op2: PsExpression): + super().__init__() self._op1 = op1 self._op2 = op2 diff --git a/src/pystencils/backend/kernelcreation/iteration_space.py b/src/pystencils/backend/kernelcreation/iteration_space.py index 7d6404380..382adf7b6 100644 --- a/src/pystencils/backend/kernelcreation/iteration_space.py +++ b/src/pystencils/backend/kernelcreation/iteration_space.py @@ -105,9 +105,13 @@ class FullIterationSpace(IterationSpace): spatial_shape = archetype_array.shape[:dim] + from .typification import Typifier + + typify = Typifier(ctx) + dimensions = [ FullIterationSpace.Dimension( - gl_left, PsExpression.make(shape) - gl_right, one, ctr + gl_left, typify(PsExpression.make(shape) - gl_right), one, ctr ) for (gl_left, gl_right), shape, ctr in zip( ghost_layer_exprs, spatial_shape, counters, strict=True diff --git a/src/pystencils/backend/kernelcreation/typification.py b/src/pystencils/backend/kernelcreation/typification.py index 3fbb9c1a8..dcfb0f548 100644 --- a/src/pystencils/backend/kernelcreation/typification.py +++ b/src/pystencils/backend/kernelcreation/typification.py @@ -10,6 +10,7 @@ from ...types import ( PsIntegerType, PsArrayType, PsDereferencableType, + PsPointerType, PsBoolType, deconstify, ) @@ -30,6 +31,8 @@ from ..ast.expressions import ( PsBitwiseXor, PsCall, PsCast, + PsDeref, + PsAddressOf, PsConstantExpr, PsIntDiv, PsLeftShift, @@ -53,45 +56,97 @@ NodeT = TypeVar("NodeT", bound=PsAstNode) class TypeContext: def __init__(self, target_type: PsType | None = None): self._target_type = deconstify(target_type) if target_type is not None else None - self._deferred_constants: list[PsConstantExpr] = [] + self._deferred_exprs: list[PsExpression] = [] - def typify_constant(self, constexpr: PsConstantExpr) -> None: - if self._target_type is None: - self._deferred_constants.append(constexpr) - elif not isinstance(self._target_type, PsNumericType): - raise TypificationError( - f"Can't typify constant with non-numeric type {self._target_type}" - ) - else: - constexpr.constant.apply_dtype(self._target_type) + def apply_dtype(self, expr: PsExpression | None, dtype: PsType): + """Applies the given ``dtype`` to the given expression inside this type context. - def apply_and_check(self, expr: PsExpression, expr_type: PsType): + The given expression will be covered by this type context. + If the context's target_type is already known, it must be compatible with the given dtype. + If the target type is still unknown, target_type is set to dtype and retroactively applied + to all deferred expressions. """ - If no target type has been set yet, establishes expr_type as the target type - and typifies all deferred expressions. - Otherwise, checks if expression type and target type are compatible. + dtype = deconstify(dtype) + + if self._target_type is not None and dtype != self._target_type: + raise TypificationError( + f"Type mismatch at expression {expr}: Expression type did not match the context's target type\n" + f" Expression type: {dtype}\n" + f" Target type: {self._target_type}" + ) + else: + self._target_type = dtype + self._propagate_target_type() + + if expr is not None: + if expr.dtype is None: + self._apply_target_type(expr) + elif deconstify(expr.dtype) != self._target_type: + raise TypificationError( + "Type conflict: Predefined expression type did not match the context's target type\n" + f" Expression type: {dtype}\n" + f" Target type: {self._target_type}" + ) + + def infer_dtype(self, expr: PsExpression): + """Infer the data type for the given expression. + + If the target_type of this context is already known, it will be applied to the given expression. + Otherwise, the expression is deferred, and a type will be applied to it as soon as `apply_type` is + called on this context. + + It the expression already has a data type set, it must be equal to the inferred type. """ - expr_type = deconstify(expr_type) if self._target_type is None: - self._target_type = expr_type - - for dc in self._deferred_constants: - if not isinstance(self._target_type, PsNumericType): + self._deferred_exprs.append(expr) + else: + self._apply_target_type(expr) + + def _propagate_target_type(self): + for expr in self._deferred_exprs: + self._apply_target_type(expr) + self._deferred_exprs = [] + + def _apply_target_type(self, expr: PsExpression): + assert self._target_type is not None + + if expr.dtype is not None: + if deconstify(expr.dtype) != self.target_type: + raise TypificationError( + f"Type mismatch at expression {expr}: Expression type did not match the context's target type\n" + f" Expression type: {expr.dtype}\n" + f" Target type: {self._target_type}" + ) + else: + match expr: + case PsConstantExpr(c): + if not isinstance(self._target_type, PsNumericType): + raise TypificationError( + f"Can't typify constant with non-numeric type {self._target_type}" + ) + c.apply_dtype(self._target_type) + + case PsSymbolExpr(symb): + symb.apply_dtype(self._target_type) + + case ( + PsIntDiv() + | PsLeftShift() + | PsRightShift() + | PsBitwiseAnd() + | PsBitwiseXor() + | PsBitwiseOr() + ) if not isinstance(self._target_type, PsIntegerType): raise TypificationError( - f"Can't typify constant with non-numeric type {self._target_type}" + f"Integer operation encountered in non-integer type context:\n" + f" Expression: {expr}" + f" Type Context: {self._target_type}" ) - dc.constant.apply_dtype(self._target_type) - self._deferred_constants = [] - - elif expr_type != self._target_type: - raise TypificationError( - f"Type mismatch at expression {expr}: Expression type did not match the context's target type\n" - f" Expression type: {expr_type}\n" - f" Target type: {self._target_type}" - ) + expr.dtype = self._target_type + # endif @property def target_type(self) -> PsType | None: @@ -99,46 +154,28 @@ class TypeContext: class Typifier: - """Typifier for untyped expressions. - - The typifier, when called with an AST node, will attempt to figure out - the types for all untyped expressions within the node. - Plain variables will be assigned a type according to `ctx.options.default_dtype`, - constants will be converted to typed constants according to the contextual typing scheme - described below. - - Contextual Typing - ----------------- - - The contextual typifier covers the expression tree with disjoint typing contexts. - The idea is that all nodes covered by a typing context must have the exact same type. - Starting at an expression's root, the typifier attempts to expand a typing context as far as possible - toward the leaves. - This happens implicitly during the recursive traversal of the expression tree. - - At an interior node, which is modelled as a function applied to a number of arguments, producing a result, - that function's signature governs context expansion. Let T be the function's return type; then the context - is expanded to each argument expression that also is of type T. - - If a function parameter is of type S != T, a new type context is created for it. If the type S is already fixed - by the function signature, it will be the target type of the new context. - - At the tree's leaves, types are applied and checked. By the above propagation rule, all leaves that share a typing - context must have the exact same type (modulo constness). There the actual type checking happens. - If a variable is encountered and the context does not yet have a target type, it is set to the variable's type. - If a constant is encountered, it is typified using the current target type. - If no target type is known yet, the constant will first be instantiated as a DeferredTypedConstant, - and stashed in the context. - As soon as the context learns its target type, it is applied to all deferred constants. - - In addition to leaves, some interior nodes may also have to be checked against the target type. - In particular, these are array accesses, struct member accesses, and calls to functions with a fixed - return type. - - When a context is 'closed' during the recursive unwinding, it shall be an error if it still contains unresolved - constants. - - TODO: The context shall keep track of it's target type's origin to aid in producing helpful error messages. + """Apply data types to expressions. + + The Typifier will traverse the AST and apply a contextual typing scheme to figure out + the data types of all encountered expressions. + To this end, it covers each expression tree with a set of disjoint typing contexts. + All nodes covered by the same typing context must have the same type. + + Starting from an expression's root, a typing context is implicitly expanded through + the recursive descent into a node's children. In particular, a child is typified within + the same context as its parent if the node's semantics require parent and child to have + the same type (e.g. at arithmetic operators, mathematical functions, etc.). + If a node's child is required to have a different type, a new context is opened. + + For each typing context, its target type is prescribed by the first node encountered during traversal + whose type is fixed according to its typing rules. All other nodes covered by the context must share + that type. + + The types of arithmetic operators, mathematical functions, and untyped constants are + inferred from their context's target type. If one of these is encountered while no target type is set yet + in the context, the expression is deferred by storing it in the context, and will be assigned a type as soon + as the target type is fixed. + """ def __init__(self, ctx: KernelCreationContext): @@ -152,7 +189,7 @@ class Typifier: return node def typify_expression( - self, expr: PsExpression, target_type: PsNumericType | None = None + self, expr: PsExpression, target_type: PsType | None = None ) -> tuple[PsExpression, PsType]: tc = TypeContext(target_type) self.visit_expr(expr, tc) @@ -202,25 +239,22 @@ class Typifier: def visit_expr(self, expr: PsExpression, tc: TypeContext) -> None: """Recursive processing of expression nodes""" match expr: - case PsSymbolExpr(symb): - if symb.dtype is None: - dtype = self._ctx.default_dtype - symb.apply_dtype(dtype) - tc.apply_and_check(expr, symb.get_dtype()) - - case PsConstantExpr(constant): - if constant.dtype is not None: - tc.apply_and_check(expr, constant.get_dtype()) + case PsSymbolExpr(_): + if expr.dtype is None: + tc.apply_dtype(expr, self._ctx.default_dtype) else: - tc.typify_constant(expr) + tc.apply_dtype(expr, expr.dtype) + + case PsConstantExpr(_): + tc.infer_dtype(expr) - case PsArrayAccess(_, idx): - tc.apply_and_check(expr, expr.dtype) + case PsArrayAccess(bptr, idx): + tc.apply_dtype(expr, bptr.array.element_type) index_tc = TypeContext() self.visit_expr(idx, index_tc) if index_tc.target_type is None: - index_tc.apply_and_check(idx, self._ctx.index_dtype) + index_tc.apply_dtype(idx, self._ctx.index_dtype) elif not isinstance(index_tc.target_type, PsIntegerType): raise TypificationError( f"Array index is not of integer type: {idx} has type {index_tc.target_type}" @@ -235,17 +269,40 @@ class Typifier: "Type of subscript base is not subscriptable." ) - tc.apply_and_check(expr, arr_tc.target_type.base_type) + tc.apply_dtype(expr, arr_tc.target_type.base_type) index_tc = TypeContext() self.visit_expr(idx, index_tc) if index_tc.target_type is None: - index_tc.apply_and_check(idx, self._ctx.index_dtype) + index_tc.apply_dtype(idx, self._ctx.index_dtype) elif not isinstance(index_tc.target_type, PsIntegerType): raise TypificationError( f"Subscript index is not of integer type: {idx} has type {index_tc.target_type}" ) + case PsDeref(ptr): + ptr_tc = TypeContext() + self.visit_expr(ptr, ptr_tc) + + if not isinstance(ptr_tc.target_type, PsDereferencableType): + raise TypificationError( + "Type of argument to a Deref is not dereferencable" + ) + + tc.apply_dtype(expr, ptr_tc.target_type.base_type) + + case PsAddressOf(arg): + arg_tc = TypeContext() + self.visit_expr(arg, arg_tc) + + if arg_tc.target_type is None: + raise TypificationError( + f"Unable to determine type of argument to AddressOf: {arg}" + ) + + ptr_type = PsPointerType(arg_tc.target_type, True) + tc.apply_dtype(expr, ptr_type) + case PsLookup(aggr, member_name): aggr_tc = TypeContext(None) self.visit_expr(aggr, aggr_tc) @@ -262,47 +319,19 @@ class Typifier: f"Aggregate of type {aggr_type} does not have a member {member}." ) - tc.apply_and_check(expr, member.dtype) - - # integer operations - case ( - PsIntDiv(op1, op2) - | PsLeftShift(op1, op2) - | PsRightShift(op1, op2) - | PsBitwiseAnd(op1, op2) - | PsBitwiseXor(op1, op2) - | PsBitwiseOr(op1, op2) - ): - if tc.target_type is not None and not isinstance( - tc.target_type, PsIntegerType - ): - raise TypificationError( - f"Integer expression used in non-integer context.\n" - f" Integer expression: {expr}\n" - f" Context type: {tc.target_type}" - ) - - self.visit_expr(op1, tc) - self.visit_expr(op2, tc) - - if tc.target_type is None: - raise TypificationError( - f"Unable to infer type of integer expression {expr}." - ) - elif not isinstance(tc.target_type, PsIntegerType): - raise TypificationError( - f"Argument(s) to integer function are non-integer in expression {expr}." - ) + tc.apply_dtype(expr, member.dtype) case PsBinOp(op1, op2): self.visit_expr(op1, tc) self.visit_expr(op2, tc) + tc.infer_dtype(expr) case PsCall(function, args): match function: case PsMathFunction(): for arg in args: self.visit_expr(arg, tc) + tc.infer_dtype(expr) case _: raise TypificationError( f"Don't know how to typify calls to {function}" @@ -329,14 +358,14 @@ class Typifier: f"{len(items)} items as {tc.target_type}" ) else: - items_tc.apply_and_check(expr, tc.target_type.base_type) + items_tc.apply_dtype(None, tc.target_type.base_type) else: arr_type = PsArrayType(items_tc.target_type, len(items)) - tc.apply_and_check(expr, arr_type) + tc.apply_dtype(expr, arr_type) - case PsCast(dtype, operand): - self.visit_expr(operand, TypeContext()) - tc.apply_and_check(expr, dtype) + case PsCast(dtype, arg): + self.visit_expr(arg, TypeContext()) + tc.apply_dtype(expr, dtype) case _: raise NotImplementedError(f"Can't typify {expr}") diff --git a/src/pystencils/backend/platforms/__init__.py b/src/pystencils/backend/platforms/__init__.py index 355c28d8f..0b816bf93 100644 --- a/src/pystencils/backend/platforms/__init__.py +++ b/src/pystencils/backend/platforms/__init__.py @@ -9,5 +9,5 @@ __all__ = [ "GenericVectorCpu", "X86VectorCpu", "X86VectorArch", - "GenericGpu" + "GenericGpu", ] diff --git a/src/pystencils/backend/platforms/generic_gpu.py b/src/pystencils/backend/platforms/generic_gpu.py index 1e7d958f7..79ab6f9ec 100644 --- a/src/pystencils/backend/platforms/generic_gpu.py +++ b/src/pystencils/backend/platforms/generic_gpu.py @@ -3,11 +3,12 @@ from .platform import Platform from ..kernelcreation.iteration_space import ( IterationSpace, FullIterationSpace, - SparseIterationSpace, + # SparseIterationSpace, ) from ..ast.structural import PsBlock, PsConditional from ..ast.expressions import ( + PsExpression, PsSymbolExpr, PsAdd, ) @@ -17,10 +18,18 @@ from ..symbols import PsSymbol int32 = PsSignedIntegerType(width=32, const=False) -BLOCK_IDX = [PsSymbolExpr(PsSymbol(f"blockIdx.{coord}", int32)) for coord in ('x', 'y', 'z')] -THREAD_IDX = [PsSymbolExpr(PsSymbol(f"threadIdx.{coord}", int32)) for coord in ('x', 'y', 'z')] -BLOCK_DIM = [PsSymbolExpr(PsSymbol(f"blockDim.{coord}", int32)) for coord in ('x', 'y', 'z')] -GRID_DIM = [PsSymbolExpr(PsSymbol(f"gridDim.{coord}", int32)) for coord in ('x', 'y', 'z')] +BLOCK_IDX = [ + PsSymbolExpr(PsSymbol(f"blockIdx.{coord}", int32)) for coord in ("x", "y", "z") +] +THREAD_IDX = [ + PsSymbolExpr(PsSymbol(f"threadIdx.{coord}", int32)) for coord in ("x", "y", "z") +] +BLOCK_DIM = [ + PsSymbolExpr(PsSymbol(f"blockDim.{coord}", int32)) for coord in ("x", "y", "z") +] +GRID_DIM = [ + PsSymbolExpr(PsSymbol(f"gridDim.{coord}", int32)) for coord in ("x", "y", "z") +] class GenericGpu(Platform): @@ -29,7 +38,9 @@ class GenericGpu(Platform): def required_headers(self) -> set[str]: return {"gpu_defines.h"} - def materialize_iteration_space(self, body: PsBlock, ispace: IterationSpace) -> PsBlock: + def materialize_iteration_space( + self, body: PsBlock, ispace: IterationSpace + ) -> PsBlock: if isinstance(ispace, FullIterationSpace): return self._guard_full_iteration_space(body, ispace) else: @@ -37,13 +48,17 @@ class GenericGpu(Platform): def cuda_indices(self, dim): block_size = BLOCK_DIM - indices = [block_index * bs + thread_idx - for block_index, bs, thread_idx in zip(BLOCK_IDX, block_size, THREAD_IDX)] + indices = [ + block_index * bs + thread_idx + for block_index, bs, thread_idx in zip(BLOCK_IDX, block_size, THREAD_IDX) + ] return indices[:dim] # Internals - def _guard_full_iteration_space(self, body: PsBlock, ispace: FullIterationSpace) -> PsBlock: + def _guard_full_iteration_space( + self, body: PsBlock, ispace: FullIterationSpace + ) -> PsBlock: dimensions = ispace.dimensions @@ -53,12 +68,14 @@ class GenericGpu(Platform): loop_order = archetype_field.layout dimensions = [dimensions[coordinate] for coordinate in loop_order] - start = [PsAdd(c, d.start) for c, d in zip(self.cuda_indices(len(dimensions)), dimensions[::-1])] + start = [ + PsAdd(c, d.start) + for c, d in zip(self.cuda_indices(len(dimensions)), dimensions[::-1]) + ] conditions = [PsLt(c, d.stop) for c, d in zip(start, dimensions[::-1])] - condition = conditions[0] + condition: PsExpression = conditions[0] for c in conditions[1:]: condition = PsAnd(condition, c) return PsBlock([PsConditional(condition, body)]) - diff --git a/src/pystencils/backend/transformations/erase_anonymous_structs.py b/src/pystencils/backend/transformations/erase_anonymous_structs.py index c946ae7bb..03d79a689 100644 --- a/src/pystencils/backend/transformations/erase_anonymous_structs.py +++ b/src/pystencils/backend/transformations/erase_anonymous_structs.py @@ -12,6 +12,7 @@ from ..ast.expressions import ( PsAddressOf, PsCast, ) +from ..kernelcreation import Typifier from ..arrays import PsArrayBasePointer, TypeErasedBasePointer from ...types import PsStructType, PsPointerType @@ -98,6 +99,10 @@ class EraseAnonymousStructTypes: ) type_erased_access = PsArrayAccess(type_erased_bp, byte_index) - return PsDeref( + deref = PsDeref( PsCast(PsPointerType(member.dtype), PsAddressOf(type_erased_access)) ) + + typify = Typifier(self._ctx) + deref = typify(deref) + return deref diff --git a/src/pystencils/backend/transformations/select_intrinsics.py b/src/pystencils/backend/transformations/select_intrinsics.py index e587ba129..7972de069 100644 --- a/src/pystencils/backend/transformations/select_intrinsics.py +++ b/src/pystencils/backend/transformations/select_intrinsics.py @@ -68,7 +68,7 @@ class MaterializeVectorIntrinsics: match node: case PsAssignment(lhs, rhs) if isinstance(lhs, PsVectorArrayAccess): vc = VecTypeCtx() - vc.set(lhs.dtype) + vc.set(lhs.get_vector_type()) store_arg = self.visit_expr(rhs, vc) return PsStatement(self._platform.vector_store(lhs, store_arg)) case PsExpression(): @@ -95,7 +95,7 @@ class MaterializeVectorIntrinsics: return expr case PsVectorArrayAccess(): - vc.set(expr.dtype) + vc.set(expr.get_vector_type()) return self._platform.vector_load(expr) case PsBinOp(op1, op2): diff --git a/tests/nbackend/types/test_types.py b/tests/nbackend/types/test_types.py index 487f77783..204ee24cf 100644 --- a/tests/nbackend/types/test_types.py +++ b/tests/nbackend/types/test_types.py @@ -32,10 +32,8 @@ def test_parsing_negative(): "const notatype * const", "cnost uint32_t", "uint45_t", - "int", # plain ints are ambiguous "float float", "double * int", - "bool", ] for spec in bad_specs: -- GitLab