diff --git a/src/pystencils/nbackend/arrays.py b/src/pystencils/nbackend/arrays.py index 3027c0b4096ab0a6bd8578613f2f5acd4e53d1bb..49375209ee3e6feaf9bffd6e78790e65207f5df7 100644 --- a/src/pystencils/nbackend/arrays.py +++ b/src/pystencils/nbackend/arrays.py @@ -42,6 +42,7 @@ kernel_function.add_constraints(*constraints) from __future__ import annotations +from sys import intern from types import EllipsisType @@ -240,6 +241,9 @@ class PsArrayStrideVar(PsArrayAssocVar): class PsArrayAccess(pb.Subscript): + + mapper_method = intern("map_array_access") + def __init__(self, base_ptr: PsArrayBasePointer, index: ExprOrConstant): super(PsArrayAccess, self).__init__(base_ptr, index) self._base_ptr = base_ptr diff --git a/src/pystencils/nbackend/kernelcreation/options.py b/src/pystencils/nbackend/kernelcreation/options.py index c9efd83c489eb9627bbf8883dcc1cd59e18adbca..355050b7ec24b26262c97ebdb6864496024a1477 100644 --- a/src/pystencils/nbackend/kernelcreation/options.py +++ b/src/pystencils/nbackend/kernelcreation/options.py @@ -3,7 +3,7 @@ from dataclasses import dataclass from ...enums import Target from ..exceptions import PsOptionsError -from ..types import PsIntegerType +from ..types import PsIntegerType, PsNumericType, PsIeeeFloatType from .defaults import Sympy as SpDefaults @@ -43,9 +43,17 @@ class KernelCreationOptions: TODO: Specification of valid slices and their behaviour """ + """Data Types""" + index_dtype: PsIntegerType = SpDefaults.index_dtype """Data type used for all index calculations.""" + default_dtype: PsNumericType = PsIeeeFloatType(64) + """Default numeric data type. + + This data type will be applied to all untyped symbols. + """ + def __post_init__(self): if self.iteration_slice is not None and self.ghost_layers is not None: raise PsOptionsError( diff --git a/src/pystencils/nbackend/kernelcreation/typification.py b/src/pystencils/nbackend/kernelcreation/typification.py index 65263cbe8a8b8e9fecb8055976ab23295037ffca..a342136238d5f2d736da18423e736a8c02305137 100644 --- a/src/pystencils/nbackend/kernelcreation/typification.py +++ b/src/pystencils/nbackend/kernelcreation/typification.py @@ -1,17 +1,18 @@ from __future__ import annotations -from typing import TypeVar +from typing import TypeVar, Any, Sequence, cast import pymbolic.primitives as pb from pymbolic.mapper import Mapper from .context import KernelCreationContext -from ..types import PsAbstractType -from ..typed_expressions import PsTypedVariable +from ..types import PsAbstractType, PsNumericType +from ..typed_expressions import PsTypedVariable, PsTypedConstant, ExprOrConstant +from ..arrays import PsArrayAccess from ..ast import PsAstNode, PsExpression, PsAssignment -class TypificationException(Exception): +class TypificationError(Exception): """Indicates a fatal error during typification.""" @@ -19,6 +20,23 @@ NodeT = TypeVar("NodeT", bound=PsAstNode) class Typifier(Mapper): + """Typifier for untyped expressions. + + The typifier, when called with an AST node, will attempt to figure out + the types for all untyped expressions within the node: + + - Plain variables will be assigned a type according to `ctx.options.default_dtype`. + - Constants will be converted to typed constants by applying the target type of the current context. + If the target type is unknown, typification of constants will fail. + + The target type for an expression must either be provided by the user or is inferred from the context. + The two primary contexts are an assignment, where the target type of the right-hand side expression is + given by the type of the left-hand side; and the index expression of an array access, where the target + type is given by `ctx.options.index_dtype`. + The target type is propagated upward through the expression tree. It is applied to all untyped constants, + and used to check the correctness of the types of expressions. + """ + def __init__(self, ctx: KernelCreationContext): self._ctx = ctx @@ -28,24 +46,117 @@ class Typifier(Mapper): node.expression, _ = self.rec(expr) case PsAssignment(lhs, rhs): - lhs, lhs_dtype = self.rec(lhs) - rhs, rhs_dtype = self.rec(rhs) + new_lhs, lhs_dtype = self.rec(lhs.expression, None) + new_rhs, rhs_dtype = self.rec(rhs.expression, lhs_dtype) if lhs_dtype != rhs_dtype: # todo: (optional) automatic cast insertion? - raise TypificationException( + raise TypificationError( "Mismatched types in assignment: \n" f" {lhs} <- {rhs}\n" f" dtype(lhs) = {lhs_dtype}\n" f" dtype(rhs) = {rhs_dtype}\n" ) - node.lhs = lhs - node.rhs = rhs + node.lhs.expression = new_lhs + node.rhs.expression = new_rhs case unknown: raise NotImplementedError(f"Don't know how to typify {unknown}") - + return node - def map_variable(self, var: pb.Variable) -> tuple[pb.Expression, PsAbstractType]: - dtype = NotImplemented # determine variable type - return PsTypedVariable(var.name, dtype), dtype + # def rec(self, expr: Any, target_type: PsNumericType | None) + + def typify_expression( + self, expr: Any, target_type: PsNumericType | None = None + ) -> ExprOrConstant: + return self.rec(expr, target_type) + + # Leaf nodes: Variables, Typed Variables, Constants and TypedConstants + + def map_typed_variable( + self, var: PsTypedVariable, target_type: PsNumericType | None + ): + self._check_target_type(var, var.dtype, target_type) + return var, var.dtype + + def map_variable( + self, var: pb.Variable, target_type: PsNumericType | None + ) -> tuple[PsTypedVariable, PsNumericType]: + dtype = self._ctx.options.default_dtype + typed_var = PsTypedVariable(var.name, dtype) + self._check_target_type(typed_var, dtype, target_type) + return typed_var, dtype + + def map_constant( + self, value: Any, target_type: PsNumericType | None + ) -> tuple[PsTypedConstant, PsNumericType]: + if isinstance(value, PsTypedConstant): + self._check_target_type(value, value.dtype, target_type) + return value, value.dtype + elif target_type is None: + raise TypificationError( + f"Unable to typify constant {value}: Unknown target type in this context." + ) + else: + return PsTypedConstant(value, target_type), target_type + + # Array Access + + def map_array_access( + self, access: PsArrayAccess, target_type: PsNumericType | None + ) -> tuple[PsArrayAccess, PsNumericType]: + self._check_target_type(access, access.array.element_type, target_type) + index, _ = self.rec(access.index_tuple[0], self._ctx.options.index_dtype) + return PsArrayAccess(access.base_ptr, index), cast(PsNumericType, access.array.element_type) + + # Arithmetic Expressions + + def _homogenize( + self, + expr: pb.Expression, + args: Sequence[Any], + target_type: PsNumericType | None, + ) -> tuple[Sequence[ExprOrConstant], PsNumericType]: + """Typify all arguments of a multi-argument expression with the same type.""" + new_args = [None] * len(args) + common_type: PsNumericType | None = None + + for i, c in enumerate(args): + new_args[i], arg_i_type = self.rec(c, target_type) + if common_type is None: + common_type = arg_i_type + elif common_type != arg_i_type: + raise TypificationError( + f"Type mismatch in expression {expr}: Type of operand {i} did not match previous operands\n" + f" Previous type: {common_type}\n" + f" Operand {i} type: {arg_i_type}" + ) + + assert common_type is not None + + return cast(Sequence[ExprOrConstant], new_args), common_type + + def map_sum( + self, expr: pb.Sum, target_type: PsNumericType | None + ) -> tuple[pb.Sum, PsNumericType]: + new_args, dtype = self._homogenize(expr, expr.children, target_type) + return pb.Sum(new_args), dtype + + def map_product( + self, expr: pb.Product, target_type: PsNumericType | None + ) -> tuple[pb.Product, PsNumericType]: + new_args, dtype = self._homogenize(expr, expr.children, target_type) + return pb.Product(new_args), dtype + + def _check_target_type( + self, + expr: ExprOrConstant, + expr_type: PsAbstractType, + target_type: PsNumericType | None, + ): + if target_type is not None and expr_type != target_type: + raise TypificationError( + f"Type mismatch at expression {expr}: Expression type did not match the context's target type\n" + f" Expression type: {expr_type}\n" + f" Target type: {target_type}" + )