diff --git a/src/pystencils/nbackend/kernelcreation/domain_kernels.py b/src/pystencils/nbackend/kernelcreation/domain_kernels.py index a66cbd12ef05e57ffc972910045aff964d7fb2a1..1ee1cfae95463b39261b75a0dc6557e9431fe203 100644 --- a/src/pystencils/nbackend/kernelcreation/domain_kernels.py +++ b/src/pystencils/nbackend/kernelcreation/domain_kernels.py @@ -9,8 +9,11 @@ from ..ast import PsBlock from .context import KernelCreationContext, FullIterationSpace from .freeze import FreezeExpressions +from .typification import Typifier # flake8: noqa + + def create_domain_kernel(assignments: AssignmentCollection): # TODO: Assemble configuration @@ -46,7 +49,8 @@ def create_domain_kernel(assignments: AssignmentCollection): # 5. Typify # Also the same for both types of kernels - # determine_types(kernel_body) + typify = Typifier(ctx) + kernel_body = typify(kernel_body) # Up to this point, all was target-agnostic, but now the target becomes relevant. # Here we might hand off the compilation to a target-specific part of the compiler diff --git a/src/pystencils/nbackend/kernelcreation/typification.py b/src/pystencils/nbackend/kernelcreation/typification.py index 4a228e9e24d64a32853d03aea18aed3050f7ea76..65263cbe8a8b8e9fecb8055976ab23295037ffca 100644 --- a/src/pystencils/nbackend/kernelcreation/typification.py +++ b/src/pystencils/nbackend/kernelcreation/typification.py @@ -1,17 +1,50 @@ +from __future__ import annotations + +from typing import TypeVar + import pymbolic.primitives as pb from pymbolic.mapper import Mapper from .context import KernelCreationContext from ..types import PsAbstractType from ..typed_expressions import PsTypedVariable +from ..ast import PsAstNode, PsExpression, PsAssignment + + +class TypificationException(Exception): + """Indicates a fatal error during typification.""" + + +NodeT = TypeVar("NodeT", bound=PsAstNode) class Typifier(Mapper): def __init__(self, ctx: KernelCreationContext): self._ctx = ctx - def __call__(self, expr: pb.Expression) -> tuple[pb.Expression, PsAbstractType]: - return self.rec(expr) + def __call__(self, node: NodeT) -> NodeT: + match node: + case PsExpression(expr): + node.expression, _ = self.rec(expr) + + case PsAssignment(lhs, rhs): + lhs, lhs_dtype = self.rec(lhs) + rhs, rhs_dtype = self.rec(rhs) + if lhs_dtype != rhs_dtype: + # todo: (optional) automatic cast insertion? + raise TypificationException( + "Mismatched types in assignment: \n" + f" {lhs} <- {rhs}\n" + f" dtype(lhs) = {lhs_dtype}\n" + f" dtype(rhs) = {rhs_dtype}\n" + ) + node.lhs = lhs + node.rhs = rhs + + case unknown: + raise NotImplementedError(f"Don't know how to typify {unknown}") + + return node def map_variable(self, var: pb.Variable) -> tuple[pb.Expression, PsAbstractType]: dtype = NotImplemented # determine variable type