Skip to content
Snippets Groups Projects
Commit a3843b10 authored by Frederik Hennig's avatar Frederik Hennig
Browse files

slightly extended typifier

parent 03affb6c
No related merge requests found
Pipeline #60440 failed with stages
in 8 minutes and 42 seconds
......@@ -9,8 +9,11 @@ from ..ast import PsBlock
from .context import KernelCreationContext, FullIterationSpace
from .freeze import FreezeExpressions
from .typification import Typifier
# flake8: noqa
def create_domain_kernel(assignments: AssignmentCollection):
# TODO: Assemble configuration
......@@ -46,7 +49,8 @@ def create_domain_kernel(assignments: AssignmentCollection):
# 5. Typify
# Also the same for both types of kernels
# determine_types(kernel_body)
typify = Typifier(ctx)
kernel_body = typify(kernel_body)
# Up to this point, all was target-agnostic, but now the target becomes relevant.
# Here we might hand off the compilation to a target-specific part of the compiler
......
from __future__ import annotations
from typing import TypeVar
import pymbolic.primitives as pb
from pymbolic.mapper import Mapper
from .context import KernelCreationContext
from ..types import PsAbstractType
from ..typed_expressions import PsTypedVariable
from ..ast import PsAstNode, PsExpression, PsAssignment
class TypificationException(Exception):
"""Indicates a fatal error during typification."""
NodeT = TypeVar("NodeT", bound=PsAstNode)
class Typifier(Mapper):
def __init__(self, ctx: KernelCreationContext):
self._ctx = ctx
def __call__(self, expr: pb.Expression) -> tuple[pb.Expression, PsAbstractType]:
return self.rec(expr)
def __call__(self, node: NodeT) -> NodeT:
match node:
case PsExpression(expr):
node.expression, _ = self.rec(expr)
case PsAssignment(lhs, rhs):
lhs, lhs_dtype = self.rec(lhs)
rhs, rhs_dtype = self.rec(rhs)
if lhs_dtype != rhs_dtype:
# todo: (optional) automatic cast insertion?
raise TypificationException(
"Mismatched types in assignment: \n"
f" {lhs} <- {rhs}\n"
f" dtype(lhs) = {lhs_dtype}\n"
f" dtype(rhs) = {rhs_dtype}\n"
)
node.lhs = lhs
node.rhs = rhs
case unknown:
raise NotImplementedError(f"Don't know how to typify {unknown}")
return node
def map_variable(self, var: pb.Variable) -> tuple[pb.Expression, PsAbstractType]:
dtype = NotImplemented # determine variable type
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment