From 45c9fe80649c1b1b9382909ca4641f57a75e11d1 Mon Sep 17 00:00:00 2001
From: Frederik Hennig <frederik.hennig@fau.de>
Date: Tue, 30 Jan 2024 23:19:04 +0100
Subject: [PATCH] contextual typing

---
 .../nbackend/kernelcreation/freeze.py         |   8 +-
 .../nbackend/kernelcreation/typification.py   | 236 ++++++++++--------
 tests/nbackend/test_typification.py           |  30 ++-
 3 files changed, 168 insertions(+), 106 deletions(-)

diff --git a/src/pystencils/nbackend/kernelcreation/freeze.py b/src/pystencils/nbackend/kernelcreation/freeze.py
index cab2ab70a..11ea24219 100644
--- a/src/pystencils/nbackend/kernelcreation/freeze.py
+++ b/src/pystencils/nbackend/kernelcreation/freeze.py
@@ -35,11 +35,11 @@ class FreezeExpressions(SympyToPymbolicMapper):
         ...
 
     @overload
-    def __call__(self, expr: Assignment) -> PsAssignment:
+    def __call__(self, expr: sp.Expr) -> PsExpression:
         ...
 
     @overload
-    def __call__(self, expr: sp.Basic) -> pb.Expression:
+    def __call__(self, expr: Assignment) -> PsAssignment:
         ...
 
     def __call__(self, obj):
@@ -47,8 +47,8 @@ class FreezeExpressions(SympyToPymbolicMapper):
             return PsBlock([self.rec(asm) for asm in obj.all_assignments])
         elif isinstance(obj, Assignment):
             return cast(PsAssignment, self.rec(obj))
-        elif isinstance(obj, sp.Basic):
-            return cast(pb.Expression, self.rec(obj))
+        elif isinstance(obj, sp.Expr):
+            return PsExpression(cast(pb.Expression, self.rec(obj)))
         else:
             raise PsInputError(f"Don't know how to freeze {obj}")
 
diff --git a/src/pystencils/nbackend/kernelcreation/typification.py b/src/pystencils/nbackend/kernelcreation/typification.py
index 41f14431b..ff467a44d 100644
--- a/src/pystencils/nbackend/kernelcreation/typification.py
+++ b/src/pystencils/nbackend/kernelcreation/typification.py
@@ -1,6 +1,6 @@
 from __future__ import annotations
 
-from typing import TypeVar, Any, Sequence, cast
+from typing import TypeVar, Any, NoReturn
 
 import pymbolic.primitives as pb
 from pymbolic.mapper import Mapper
@@ -10,6 +10,9 @@ from ..types import PsAbstractType, PsNumericType, deconstify
 from ..typed_expressions import PsTypedVariable, PsTypedConstant, ExprOrConstant
 from ..arrays import PsArrayAccess
 from ..ast import PsAstNode, PsBlock, PsExpression, PsAssignment
+from ..exceptions import PsInternalCompilerError
+
+__all__ = ["Typifier"]
 
 
 class TypificationError(Exception):
@@ -19,6 +22,72 @@ class TypificationError(Exception):
 NodeT = TypeVar("NodeT", bound=PsAstNode)
 
 
+class UndeterminedType(PsNumericType):
+    def create_constant(self, value: Any) -> Any:
+        return None
+    
+    def _err(self) -> NoReturn:
+        raise PsInternalCompilerError("Calling UndeterminedType.")
+
+    def create_literal(self, value: Any) -> str:
+        self._err()
+
+    def is_int(self) -> bool:
+        self._err()
+
+    def is_sint(self) -> bool:
+        self._err()
+
+    def is_uint(self) -> bool:
+        self._err()
+
+    def is_float(self) -> bool:
+        self._err()
+
+    def __eq__(self, other: object) -> bool:
+        self._err()
+
+    def _c_string(self) -> str:
+        self._err()
+
+
+class DeferredTypedConstant(PsTypedConstant):
+    """Special subclass for constants whose types cannot be determined yet at the time of their creation.
+
+    Outside of the typifier, a DeferredTypedConstant acts exactly the same way as a PsTypedConstant.
+    """
+
+    def __init__(self, value: Any):
+        self._value_deferred = value
+
+    def resolve(self, dtype: PsNumericType):
+        super().__init__(self._value_deferred, dtype)
+
+
+class TypeContext:
+    def __init__(self, target_type: PsNumericType | None):
+        self._target_type = deconstify(target_type) if target_type is not None else None
+        self._deferred_constants: list[DeferredTypedConstant] = []
+
+    def make_constant(self, value: Any) -> PsTypedConstant:
+        if self._target_type is None:
+            dc = DeferredTypedConstant(value)
+            self._deferred_constants.append(dc)
+            return dc
+        else:
+            return PsTypedConstant(value, self._target_type)
+
+    def apply(self, target_type: PsNumericType):
+        assert self._target_type is None, "Type context was already resolved"
+        self._target_type = deconstify(target_type)
+        for dc in self._deferred_constants:
+            dc.resolve(self._target_type)
+
+    @property
+    def target_type(self) -> PsNumericType | None:
+        return self._target_type
+
+
 class Typifier(Mapper):
     """Typifier for untyped expressions.
 
@@ -27,14 +96,33 @@ class Typifier(Mapper):
 
      - Plain variables will be assigned a type according to `ctx.options.default_dtype`.
      - Constants will be converted to typed constants by applying the target type of the current context.
-       If the target type is unknown, typification of constants will fail.
-
-    The target type for an expression must either be provided by the user or is inferred from the context.
-    The two primary contexts are an assignment, where the target type of the right-hand side expression is
-    given by the type of the left-hand side; and the index expression of an array access, where the target
-    type is given by `ctx.options.index_dtype`.
-    The target type is propagated upward through the expression tree. It is applied to all untyped constants,
-    and used to check the correctness of the types of expressions.
+
+
+    Contextual Typing
+    -----------------
+
+    Starting at an expression's root, the typifier attempts to expand a typing context as far as possible.
+    This happens implicitly during the recursive traversal of the expression tree.
+
+    At an interior node, which is modelled as a function applied to a number of arguments, producing a result,
+    that function's signature governs context expansion. Let T be the function's return type; then the context
+    is expanded to each argument expression that also is of type T.
+
+    If a function parameter is of type S != T, a new type context is created for it. If the type S is already fixed
+    by the function signature, it will be the target type of the new context.
+
+    At the tree's leaves, types are applied and checked. By the above propagation rule, all leaves that share a typing
+    context must have the exact same type (modulo constness). This type is checked at variables, and applied to
+    constants.
+
+    It may happen that the typifier arrives at a constant before the context's target type could be figured out.
+    In that case, the constant will first be instantiated as a DeferredTypedConstant, and stashed in the context.
+    As soon as the context learns its target type, it is applied to all deferred constants.
+
+    When a context is 'closed' during the recursive unwinding, it shall be an error if it still contains unresolved
+    constants.
+
+    TODO: The context shall keep track of it's target type's origin to aid in producing helpful error messages.
     """
 
     def __init__(self, ctx: KernelCreationContext):
@@ -46,18 +134,15 @@ class Typifier(Mapper):
                 node.statements = [self(s) for s in statements]
 
             case PsExpression(expr):
-                node.expression, _ = self.rec(expr)
+                node.expression = self.rec(expr, TypeContext(None))
 
             case PsAssignment(lhs, rhs):
-                new_lhs, lhs_dtype = self.rec(lhs.expression, None)
-                new_rhs, rhs_dtype = self.rec(rhs.expression, lhs_dtype)
-                if lhs_dtype != rhs_dtype:
-                    raise TypificationError(
-                        "Mismatched types in assignment: \n"
-                        f"    {lhs} <- {rhs}\n"
-                        f"    dtype(lhs) = {lhs_dtype}\n"
-                        f"    dtype(rhs) = {rhs_dtype}\n"
-                    )
+                tc = TypeContext(None)
+                #   LHS defines target type; type context carries it to RHS
+                new_lhs = self.rec(lhs.expression, tc)
+                assert tc.target_type is not None
+                new_rhs = self.rec(rhs.expression, tc)
+                
                 node.lhs.expression = new_lhs
                 node.rhs.expression = new_rhs
 
@@ -67,7 +152,7 @@ class Typifier(Mapper):
         return node
 
     """
-    def rec(self, expr: Any, target_type: PsNumericType | None)
+    def rec(self, expr: Any, tc: TypeContext) -> ExprOrConstant
 
     All visitor methods take an expression and the target type of the current context.
     They shall return the typified expression together with its type.
@@ -77,106 +162,59 @@ class Typifier(Mapper):
     def typify_expression(
         self, expr: Any, target_type: PsNumericType | None = None
     ) -> ExprOrConstant:
-        return self.rec(expr, target_type)
+        return self.rec(expr, TypeContext(target_type))
 
     #   Leaf nodes: Variables, Typed Variables, Constants and TypedConstants
 
-    def map_typed_variable(
-        self, var: PsTypedVariable, target_type: PsNumericType | None
-    ):
-        self._check_target_type(var, var.dtype, target_type)
-        return var, deconstify(var.dtype)
+    def map_typed_variable(self, var: PsTypedVariable, tc: TypeContext):
+        self._apply_target_type(var, var.dtype, tc)
+        return var
 
-    def map_variable(
-        self, var: pb.Variable, target_type: PsNumericType | None
-    ) -> tuple[PsTypedVariable, PsNumericType]:
+    def map_variable(self, var: pb.Variable, tc: TypeContext) -> PsTypedVariable:
         dtype = self._ctx.options.default_dtype
         typed_var = PsTypedVariable(var.name, dtype)
-        self._check_target_type(typed_var, dtype, target_type)
-        return typed_var, deconstify(dtype)
+        self._apply_target_type(typed_var, dtype, tc)
+        return typed_var
 
-    def map_constant(
-        self, value: Any, target_type: PsNumericType | None
-    ) -> tuple[PsTypedConstant, PsNumericType]:
+    def map_constant(self, value: Any, tc: TypeContext) -> PsTypedConstant:
         if isinstance(value, PsTypedConstant):
-            self._check_target_type(value, value.dtype, target_type)
-            return value, deconstify(value.dtype)
-        elif target_type is None:
-            raise TypificationError(
-                f"Unable to typify constant {value}: Unknown target type in this context."
-            )
-        else:
-            return PsTypedConstant(value, target_type), deconstify(target_type)
+            self._apply_target_type(value, value.dtype, tc)
+            return value
+
+        return tc.make_constant(value)
 
     #   Array Access
 
-    def map_array_access(
-        self, access: PsArrayAccess, target_type: PsNumericType | None
-    ) -> tuple[PsArrayAccess, PsNumericType]:
-        self._check_target_type(access, access.dtype, target_type)
-        index, _ = self.rec(access.index_tuple[0], self._ctx.options.index_dtype)
-        return PsArrayAccess(access.base_ptr, index), cast(
-            PsNumericType, deconstify(access.dtype)
+    def map_array_access(self, access: PsArrayAccess, tc: TypeContext) -> PsArrayAccess:
+        self._apply_target_type(access, access.dtype, tc)
+        index, _ = self.rec(
+            access.index_tuple[0], TypeContext(self._ctx.options.index_dtype)
         )
+        return PsArrayAccess(access.base_ptr, index)
 
     #   Arithmetic Expressions
 
-    def _homogenize(
-        self,
-        expr: pb.Expression,
-        args: Sequence[Any],
-        target_type: PsNumericType | None,
-    ) -> tuple[tuple[ExprOrConstant], PsNumericType]:
-        """Typify all arguments of a multi-argument expression with the same type."""
-        new_args = [None] * len(args)
-        common_type: PsNumericType | None = None
-
-        for i, c in enumerate(args):
-            new_args[i], arg_i_type = self.rec(c, target_type)
-            if common_type is None:
-                common_type = arg_i_type
-            elif common_type != arg_i_type:
-                raise TypificationError(
-                    f"Type mismatch in expression {expr}: Type of operand {i} did not match previous operands\n"
-                    f"     Previous type: {common_type}\n"
-                    f"  Operand {i} type: {arg_i_type}"
-                )
-
-        assert common_type is not None
-
-        return cast(tuple[ExprOrConstant], tuple(new_args)), common_type
-
-    def map_sum(
-        self, expr: pb.Sum, target_type: PsNumericType | None
-    ) -> tuple[pb.Sum, PsNumericType]:
-        new_args, dtype = self._homogenize(expr, expr.children, target_type)
-        return pb.Sum(new_args), dtype
-
-    def map_product(
-        self, expr: pb.Product, target_type: PsNumericType | None
-    ) -> tuple[pb.Product, PsNumericType]:
-        new_args, dtype = self._homogenize(expr, expr.children, target_type)
-        return pb.Product(new_args), dtype
-
-    def map_call(
-        self, expr: pb.Call, target_type: PsNumericType | None
-    ) -> tuple[pb.Call, PsNumericType]:
-        """
-        TODO: Figure out the best way to typify functions
+    def map_sum(self, expr: pb.Sum, tc: TypeContext) -> pb.Sum:
+        return pb.Sum(tuple(self.rec(c, tc) for c in expr.children))
 
-         - How to propagate target_type in the face of multiple overloads?
+    def map_product(self, expr: pb.Product, tc: TypeContext) -> pb.Product:
+        return pb.Product(tuple(self.rec(c, tc) for c in expr.children))
+
+    def map_call(self, expr: pb.Call, tc: TypeContext) -> pb.Call:
+        """
+        TODO: Figure out how to describe function signatures
         """
         raise NotImplementedError()
 
-    def _check_target_type(
-        self,
-        expr: ExprOrConstant,
-        expr_type: PsAbstractType,
-        target_type: PsNumericType | None,
+    def _apply_target_type(
+        self, expr: ExprOrConstant, expr_type: PsAbstractType, tc: TypeContext
     ):
-        if target_type is not None and deconstify(expr_type) != deconstify(target_type):
+        if tc.target_type is None:
+            assert isinstance(expr_type, PsNumericType)
+            tc.apply(expr_type)
+        elif deconstify(expr_type) != tc.target_type:
             raise TypificationError(
                 f"Type mismatch at expression {expr}: Expression type did not match the context's target type\n"
                 f"  Expression type: {expr_type}\n"
-                f"      Target type: {target_type}"
+                f"      Target type: {tc.target_type}"
             )
diff --git a/tests/nbackend/test_typification.py b/tests/nbackend/test_typification.py
index f9e8ab517..6caadb084 100644
--- a/tests/nbackend/test_typification.py
+++ b/tests/nbackend/test_typification.py
@@ -35,12 +35,36 @@ def test_typify_simple():
             case PsTypedVariable(name, dtype):
                 assert name in "xyz"
                 assert dtype == ctx.options.default_dtype
-            case pb.Variable:
-                pytest.fail("Encountered untyped variable")
             case pb.Sum(cs) | pb.Product(cs):
                 [check(c) for c in cs]
             case _:
-                pytest.fail("Non-exhaustive pattern matcher.")
+                pytest.fail(f"Unexpected expression: {expr}")
 
     check(fasm.lhs.expression)
     check(fasm.rhs.expression)
+
+
+def test_contextual_typing():
+    options = KernelCreationOptions()
+    ctx = KernelCreationContext(options)
+    freeze = FreezeExpressions(ctx)
+    typify = Typifier(ctx)
+
+    x, y, z = sp.symbols("x, y, z")
+    expr = freeze(2 * x + 3 * y + z - 4)
+    expr = typify(expr)
+
+    def check(expr):
+        match expr:
+            case PsTypedConstant(value, dtype):
+                assert value in (2, 3, -4)
+                assert dtype == constify(ctx.options.default_dtype)
+            case PsTypedVariable(name, dtype):
+                assert name in "xyz"
+                assert dtype == ctx.options.default_dtype
+            case pb.Sum(cs) | pb.Product(cs):
+                [check(c) for c in cs]
+            case _:
+                pytest.fail(f"Unexpected expression: {expr}")
+
+    check(expr.expression)
-- 
GitLab