From 45c9fe80649c1b1b9382909ca4641f57a75e11d1 Mon Sep 17 00:00:00 2001 From: Frederik Hennig <frederik.hennig@fau.de> Date: Tue, 30 Jan 2024 23:19:04 +0100 Subject: [PATCH] contextual typing --- .../nbackend/kernelcreation/freeze.py | 8 +- .../nbackend/kernelcreation/typification.py | 236 ++++++++++-------- tests/nbackend/test_typification.py | 30 ++- 3 files changed, 168 insertions(+), 106 deletions(-) diff --git a/src/pystencils/nbackend/kernelcreation/freeze.py b/src/pystencils/nbackend/kernelcreation/freeze.py index cab2ab70a..11ea24219 100644 --- a/src/pystencils/nbackend/kernelcreation/freeze.py +++ b/src/pystencils/nbackend/kernelcreation/freeze.py @@ -35,11 +35,11 @@ class FreezeExpressions(SympyToPymbolicMapper): ... @overload - def __call__(self, expr: Assignment) -> PsAssignment: + def __call__(self, expr: sp.Expr) -> PsExpression: ... @overload - def __call__(self, expr: sp.Basic) -> pb.Expression: + def __call__(self, expr: Assignment) -> PsAssignment: ... def __call__(self, obj): @@ -47,8 +47,8 @@ class FreezeExpressions(SympyToPymbolicMapper): return PsBlock([self.rec(asm) for asm in obj.all_assignments]) elif isinstance(obj, Assignment): return cast(PsAssignment, self.rec(obj)) - elif isinstance(obj, sp.Basic): - return cast(pb.Expression, self.rec(obj)) + elif isinstance(obj, sp.Expr): + return PsExpression(cast(pb.Expression, self.rec(obj))) else: raise PsInputError(f"Don't know how to freeze {obj}") diff --git a/src/pystencils/nbackend/kernelcreation/typification.py b/src/pystencils/nbackend/kernelcreation/typification.py index 41f14431b..ff467a44d 100644 --- a/src/pystencils/nbackend/kernelcreation/typification.py +++ b/src/pystencils/nbackend/kernelcreation/typification.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TypeVar, Any, Sequence, cast +from typing import TypeVar, Any, NoReturn import pymbolic.primitives as pb from pymbolic.mapper import Mapper @@ -10,6 +10,9 @@ from ..types import PsAbstractType, PsNumericType, deconstify from ..typed_expressions import PsTypedVariable, PsTypedConstant, ExprOrConstant from ..arrays import PsArrayAccess from ..ast import PsAstNode, PsBlock, PsExpression, PsAssignment +from ..exceptions import PsInternalCompilerError + +__all__ = ["Typifier"] class TypificationError(Exception): @@ -19,6 +22,72 @@ class TypificationError(Exception): NodeT = TypeVar("NodeT", bound=PsAstNode) +class UndeterminedType(PsNumericType): + def create_constant(self, value: Any) -> Any: + return None + + def _err(self) -> NoReturn: + raise PsInternalCompilerError("Calling UndeterminedType.") + + def create_literal(self, value: Any) -> str: + self._err() + + def is_int(self) -> bool: + self._err() + + def is_sint(self) -> bool: + self._err() + + def is_uint(self) -> bool: + self._err() + + def is_float(self) -> bool: + self._err() + + def __eq__(self, other: object) -> bool: + self._err() + + def _c_string(self) -> str: + self._err() + + +class DeferredTypedConstant(PsTypedConstant): + """Special subclass for constants whose types cannot be determined yet at the time of their creation. + + Outside of the typifier, a DeferredTypedConstant acts exactly the same way as a PsTypedConstant. + """ + + def __init__(self, value: Any): + self._value_deferred = value + + def resolve(self, dtype: PsNumericType): + super().__init__(self._value_deferred, dtype) + + +class TypeContext: + def __init__(self, target_type: PsNumericType | None): + self._target_type = deconstify(target_type) if target_type is not None else None + self._deferred_constants: list[DeferredTypedConstant] = [] + + def make_constant(self, value: Any) -> PsTypedConstant: + if self._target_type is None: + dc = DeferredTypedConstant(value) + self._deferred_constants.append(dc) + return dc + else: + return PsTypedConstant(value, self._target_type) + + def apply(self, target_type: PsNumericType): + assert self._target_type is None, "Type context was already resolved" + self._target_type = deconstify(target_type) + for dc in self._deferred_constants: + dc.resolve(self._target_type) + + @property + def target_type(self) -> PsNumericType | None: + return self._target_type + + class Typifier(Mapper): """Typifier for untyped expressions. @@ -27,14 +96,33 @@ class Typifier(Mapper): - Plain variables will be assigned a type according to `ctx.options.default_dtype`. - Constants will be converted to typed constants by applying the target type of the current context. - If the target type is unknown, typification of constants will fail. - - The target type for an expression must either be provided by the user or is inferred from the context. - The two primary contexts are an assignment, where the target type of the right-hand side expression is - given by the type of the left-hand side; and the index expression of an array access, where the target - type is given by `ctx.options.index_dtype`. - The target type is propagated upward through the expression tree. It is applied to all untyped constants, - and used to check the correctness of the types of expressions. + + + Contextual Typing + ----------------- + + Starting at an expression's root, the typifier attempts to expand a typing context as far as possible. + This happens implicitly during the recursive traversal of the expression tree. + + At an interior node, which is modelled as a function applied to a number of arguments, producing a result, + that function's signature governs context expansion. Let T be the function's return type; then the context + is expanded to each argument expression that also is of type T. + + If a function parameter is of type S != T, a new type context is created for it. If the type S is already fixed + by the function signature, it will be the target type of the new context. + + At the tree's leaves, types are applied and checked. By the above propagation rule, all leaves that share a typing + context must have the exact same type (modulo constness). This type is checked at variables, and applied to + constants. + + It may happen that the typifier arrives at a constant before the context's target type could be figured out. + In that case, the constant will first be instantiated as a DeferredTypedConstant, and stashed in the context. + As soon as the context learns its target type, it is applied to all deferred constants. + + When a context is 'closed' during the recursive unwinding, it shall be an error if it still contains unresolved + constants. + + TODO: The context shall keep track of it's target type's origin to aid in producing helpful error messages. """ def __init__(self, ctx: KernelCreationContext): @@ -46,18 +134,15 @@ class Typifier(Mapper): node.statements = [self(s) for s in statements] case PsExpression(expr): - node.expression, _ = self.rec(expr) + node.expression = self.rec(expr, TypeContext(None)) case PsAssignment(lhs, rhs): - new_lhs, lhs_dtype = self.rec(lhs.expression, None) - new_rhs, rhs_dtype = self.rec(rhs.expression, lhs_dtype) - if lhs_dtype != rhs_dtype: - raise TypificationError( - "Mismatched types in assignment: \n" - f" {lhs} <- {rhs}\n" - f" dtype(lhs) = {lhs_dtype}\n" - f" dtype(rhs) = {rhs_dtype}\n" - ) + tc = TypeContext(None) + # LHS defines target type; type context carries it to RHS + new_lhs = self.rec(lhs.expression, tc) + assert tc.target_type is not None + new_rhs = self.rec(rhs.expression, tc) + node.lhs.expression = new_lhs node.rhs.expression = new_rhs @@ -67,7 +152,7 @@ class Typifier(Mapper): return node """ - def rec(self, expr: Any, target_type: PsNumericType | None) + def rec(self, expr: Any, tc: TypeContext) -> ExprOrConstant All visitor methods take an expression and the target type of the current context. They shall return the typified expression together with its type. @@ -77,106 +162,59 @@ class Typifier(Mapper): def typify_expression( self, expr: Any, target_type: PsNumericType | None = None ) -> ExprOrConstant: - return self.rec(expr, target_type) + return self.rec(expr, TypeContext(target_type)) # Leaf nodes: Variables, Typed Variables, Constants and TypedConstants - def map_typed_variable( - self, var: PsTypedVariable, target_type: PsNumericType | None - ): - self._check_target_type(var, var.dtype, target_type) - return var, deconstify(var.dtype) + def map_typed_variable(self, var: PsTypedVariable, tc: TypeContext): + self._apply_target_type(var, var.dtype, tc) + return var - def map_variable( - self, var: pb.Variable, target_type: PsNumericType | None - ) -> tuple[PsTypedVariable, PsNumericType]: + def map_variable(self, var: pb.Variable, tc: TypeContext) -> PsTypedVariable: dtype = self._ctx.options.default_dtype typed_var = PsTypedVariable(var.name, dtype) - self._check_target_type(typed_var, dtype, target_type) - return typed_var, deconstify(dtype) + self._apply_target_type(typed_var, dtype, tc) + return typed_var - def map_constant( - self, value: Any, target_type: PsNumericType | None - ) -> tuple[PsTypedConstant, PsNumericType]: + def map_constant(self, value: Any, tc: TypeContext) -> PsTypedConstant: if isinstance(value, PsTypedConstant): - self._check_target_type(value, value.dtype, target_type) - return value, deconstify(value.dtype) - elif target_type is None: - raise TypificationError( - f"Unable to typify constant {value}: Unknown target type in this context." - ) - else: - return PsTypedConstant(value, target_type), deconstify(target_type) + self._apply_target_type(value, value.dtype, tc) + return value + + return tc.make_constant(value) # Array Access - def map_array_access( - self, access: PsArrayAccess, target_type: PsNumericType | None - ) -> tuple[PsArrayAccess, PsNumericType]: - self._check_target_type(access, access.dtype, target_type) - index, _ = self.rec(access.index_tuple[0], self._ctx.options.index_dtype) - return PsArrayAccess(access.base_ptr, index), cast( - PsNumericType, deconstify(access.dtype) + def map_array_access(self, access: PsArrayAccess, tc: TypeContext) -> PsArrayAccess: + self._apply_target_type(access, access.dtype, tc) + index, _ = self.rec( + access.index_tuple[0], TypeContext(self._ctx.options.index_dtype) ) + return PsArrayAccess(access.base_ptr, index) # Arithmetic Expressions - def _homogenize( - self, - expr: pb.Expression, - args: Sequence[Any], - target_type: PsNumericType | None, - ) -> tuple[tuple[ExprOrConstant], PsNumericType]: - """Typify all arguments of a multi-argument expression with the same type.""" - new_args = [None] * len(args) - common_type: PsNumericType | None = None - - for i, c in enumerate(args): - new_args[i], arg_i_type = self.rec(c, target_type) - if common_type is None: - common_type = arg_i_type - elif common_type != arg_i_type: - raise TypificationError( - f"Type mismatch in expression {expr}: Type of operand {i} did not match previous operands\n" - f" Previous type: {common_type}\n" - f" Operand {i} type: {arg_i_type}" - ) - - assert common_type is not None - - return cast(tuple[ExprOrConstant], tuple(new_args)), common_type - - def map_sum( - self, expr: pb.Sum, target_type: PsNumericType | None - ) -> tuple[pb.Sum, PsNumericType]: - new_args, dtype = self._homogenize(expr, expr.children, target_type) - return pb.Sum(new_args), dtype - - def map_product( - self, expr: pb.Product, target_type: PsNumericType | None - ) -> tuple[pb.Product, PsNumericType]: - new_args, dtype = self._homogenize(expr, expr.children, target_type) - return pb.Product(new_args), dtype - - def map_call( - self, expr: pb.Call, target_type: PsNumericType | None - ) -> tuple[pb.Call, PsNumericType]: - """ - TODO: Figure out the best way to typify functions + def map_sum(self, expr: pb.Sum, tc: TypeContext) -> pb.Sum: + return pb.Sum(tuple(self.rec(c, tc) for c in expr.children)) - - How to propagate target_type in the face of multiple overloads? + def map_product(self, expr: pb.Product, tc: TypeContext) -> pb.Product: + return pb.Product(tuple(self.rec(c, tc) for c in expr.children)) + + def map_call(self, expr: pb.Call, tc: TypeContext) -> pb.Call: + """ + TODO: Figure out how to describe function signatures """ raise NotImplementedError() - def _check_target_type( - self, - expr: ExprOrConstant, - expr_type: PsAbstractType, - target_type: PsNumericType | None, + def _apply_target_type( + self, expr: ExprOrConstant, expr_type: PsAbstractType, tc: TypeContext ): - if target_type is not None and deconstify(expr_type) != deconstify(target_type): + if tc.target_type is None: + assert isinstance(expr_type, PsNumericType) + tc.apply(expr_type) + elif deconstify(expr_type) != tc.target_type: raise TypificationError( f"Type mismatch at expression {expr}: Expression type did not match the context's target type\n" f" Expression type: {expr_type}\n" - f" Target type: {target_type}" + f" Target type: {tc.target_type}" ) diff --git a/tests/nbackend/test_typification.py b/tests/nbackend/test_typification.py index f9e8ab517..6caadb084 100644 --- a/tests/nbackend/test_typification.py +++ b/tests/nbackend/test_typification.py @@ -35,12 +35,36 @@ def test_typify_simple(): case PsTypedVariable(name, dtype): assert name in "xyz" assert dtype == ctx.options.default_dtype - case pb.Variable: - pytest.fail("Encountered untyped variable") case pb.Sum(cs) | pb.Product(cs): [check(c) for c in cs] case _: - pytest.fail("Non-exhaustive pattern matcher.") + pytest.fail(f"Unexpected expression: {expr}") check(fasm.lhs.expression) check(fasm.rhs.expression) + + +def test_contextual_typing(): + options = KernelCreationOptions() + ctx = KernelCreationContext(options) + freeze = FreezeExpressions(ctx) + typify = Typifier(ctx) + + x, y, z = sp.symbols("x, y, z") + expr = freeze(2 * x + 3 * y + z - 4) + expr = typify(expr) + + def check(expr): + match expr: + case PsTypedConstant(value, dtype): + assert value in (2, 3, -4) + assert dtype == constify(ctx.options.default_dtype) + case PsTypedVariable(name, dtype): + assert name in "xyz" + assert dtype == ctx.options.default_dtype + case pb.Sum(cs) | pb.Product(cs): + [check(c) for c in cs] + case _: + pytest.fail(f"Unexpected expression: {expr}") + + check(expr.expression) -- GitLab