Skip to content
Snippets Groups Projects
Commit a3843b10 authored by Frederik Hennig's avatar Frederik Hennig
Browse files

slightly extended typifier

parent 03affb6c
No related merge requests found
Pipeline #60440 failed with stages
in 8 minutes and 42 seconds
...@@ -9,8 +9,11 @@ from ..ast import PsBlock ...@@ -9,8 +9,11 @@ from ..ast import PsBlock
from .context import KernelCreationContext, FullIterationSpace from .context import KernelCreationContext, FullIterationSpace
from .freeze import FreezeExpressions from .freeze import FreezeExpressions
from .typification import Typifier
# flake8: noqa # flake8: noqa
def create_domain_kernel(assignments: AssignmentCollection): def create_domain_kernel(assignments: AssignmentCollection):
# TODO: Assemble configuration # TODO: Assemble configuration
...@@ -46,7 +49,8 @@ def create_domain_kernel(assignments: AssignmentCollection): ...@@ -46,7 +49,8 @@ def create_domain_kernel(assignments: AssignmentCollection):
# 5. Typify # 5. Typify
# Also the same for both types of kernels # Also the same for both types of kernels
# determine_types(kernel_body) typify = Typifier(ctx)
kernel_body = typify(kernel_body)
# Up to this point, all was target-agnostic, but now the target becomes relevant. # Up to this point, all was target-agnostic, but now the target becomes relevant.
# Here we might hand off the compilation to a target-specific part of the compiler # Here we might hand off the compilation to a target-specific part of the compiler
......
from __future__ import annotations
from typing import TypeVar
import pymbolic.primitives as pb import pymbolic.primitives as pb
from pymbolic.mapper import Mapper from pymbolic.mapper import Mapper
from .context import KernelCreationContext from .context import KernelCreationContext
from ..types import PsAbstractType from ..types import PsAbstractType
from ..typed_expressions import PsTypedVariable from ..typed_expressions import PsTypedVariable
from ..ast import PsAstNode, PsExpression, PsAssignment
class TypificationException(Exception):
"""Indicates a fatal error during typification."""
NodeT = TypeVar("NodeT", bound=PsAstNode)
class Typifier(Mapper): class Typifier(Mapper):
def __init__(self, ctx: KernelCreationContext): def __init__(self, ctx: KernelCreationContext):
self._ctx = ctx self._ctx = ctx
def __call__(self, expr: pb.Expression) -> tuple[pb.Expression, PsAbstractType]: def __call__(self, node: NodeT) -> NodeT:
return self.rec(expr) match node:
case PsExpression(expr):
node.expression, _ = self.rec(expr)
case PsAssignment(lhs, rhs):
lhs, lhs_dtype = self.rec(lhs)
rhs, rhs_dtype = self.rec(rhs)
if lhs_dtype != rhs_dtype:
# todo: (optional) automatic cast insertion?
raise TypificationException(
"Mismatched types in assignment: \n"
f" {lhs} <- {rhs}\n"
f" dtype(lhs) = {lhs_dtype}\n"
f" dtype(rhs) = {rhs_dtype}\n"
)
node.lhs = lhs
node.rhs = rhs
case unknown:
raise NotImplementedError(f"Don't know how to typify {unknown}")
return node
def map_variable(self, var: pb.Variable) -> tuple[pb.Expression, PsAbstractType]: def map_variable(self, var: pb.Variable) -> tuple[pb.Expression, PsAbstractType]:
dtype = NotImplemented # determine variable type dtype = NotImplemented # determine variable type
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment