Skip to content
Snippets Groups Projects
Commit f721e167 authored by Frederik Hennig's avatar Frederik Hennig
Browse files

literal printing and header collection

parent a3843b10
No related merge requests found
Pipeline #60466 canceled with stages
from typing import cast from typing import cast, Any
from functools import reduce
from pymbolic.primitives import Variable from pymbolic.primitives import Variable
from pymbolic.mapper import Collector
from pymbolic.mapper.dependency import DependencyMapper from pymbolic.mapper.dependency import DependencyMapper
from .kernelfunction import PsKernelFunction from .kernelfunction import PsKernelFunction
from .nodes import PsAstNode, PsExpression, PsAssignment, PsDeclaration, PsLoop, PsBlock from .nodes import PsAstNode, PsExpression, PsAssignment, PsDeclaration, PsLoop, PsBlock
from ..typed_expressions import PsTypedVariable from ..typed_expressions import PsTypedVariable, PsTypedConstant
from ..exceptions import PsMalformedAstException, PsInternalCompilerError from ..exceptions import PsMalformedAstException, PsInternalCompilerError
...@@ -24,12 +27,12 @@ class UndefinedVariablesCollector: ...@@ -24,12 +27,12 @@ class UndefinedVariablesCollector:
include_cses=False, include_cses=False,
) )
def collect(self, node: PsAstNode) -> set[PsTypedVariable]: def __call__(self, node: PsAstNode) -> set[PsTypedVariable]:
"""Returns all `PsTypedVariable`s that occur in the given AST without being defined prior to their usage.""" """Returns all `PsTypedVariable`s that occur in the given AST without being defined prior to their usage."""
match node: match node:
case PsKernelFunction(block): case PsKernelFunction(block):
return self.collect(block) return self(block)
case PsExpression(expr): case PsExpression(expr):
variables: set[Variable] = self._pb_dep_mapper(expr) variables: set[Variable] = self._pb_dep_mapper(expr)
...@@ -43,22 +46,22 @@ class UndefinedVariablesCollector: ...@@ -43,22 +46,22 @@ class UndefinedVariablesCollector:
return cast(set[PsTypedVariable], variables) return cast(set[PsTypedVariable], variables)
case PsAssignment(lhs, rhs): case PsAssignment(lhs, rhs):
return self.collect(lhs) | self.collect(rhs) return self(lhs) | self(rhs)
case PsBlock(statements): case PsBlock(statements):
undefined_vars: set[PsTypedVariable] = set() undefined_vars: set[PsTypedVariable] = set()
for stmt in statements[::-1]: for stmt in statements[::-1]:
undefined_vars -= self.declared_variables(stmt) undefined_vars -= self.declared_variables(stmt)
undefined_vars |= self.collect(stmt) undefined_vars |= self(stmt)
return undefined_vars return undefined_vars
case PsLoop(ctr, start, stop, step, body): case PsLoop(ctr, start, stop, step, body):
undefined_vars = ( undefined_vars = (
self.collect(start) self(start)
| self.collect(stop) | self(stop)
| self.collect(step) | self(step)
| self.collect(body) | self(body)
) )
undefined_vars.remove(ctr.symbol) undefined_vars.remove(ctr.symbol)
return undefined_vars return undefined_vars
...@@ -82,3 +85,34 @@ class UndefinedVariablesCollector: ...@@ -82,3 +85,34 @@ class UndefinedVariablesCollector:
raise PsInternalCompilerError( raise PsInternalCompilerError(
f"Don't know how to collect declared variables from {unknown}" f"Don't know how to collect declared variables from {unknown}"
) )
def collect_undefined_variables(node: PsAstNode) -> set[PsTypedVariable]:
return UndefinedVariablesCollector()(node)
class RequiredHeadersCollector(Collector):
"""Collect all header files required by a given AST for correct compilation.
Required headers can currently only be defined in subclasses of `PsAbstractType`.
"""
def __call__(self, node: PsAstNode) -> set[str]:
match node:
case PsExpression(expr):
return self.rec(expr)
case node:
return reduce(set.union, (self(c) for c in node.children()), set())
def map_typed_variable(self, var: PsTypedVariable) -> set[str]:
return var.dtype.required_headers
def map_constant(self, cst: Any):
if not isinstance(cst, PsTypedConstant):
raise PsMalformedAstException("Untyped constant encountered in expression.")
return cst.dtype.required_headers
def collect_required_headers(node: PsAstNode) -> set[str]:
return RequiredHeadersCollector()(node)
...@@ -129,9 +129,9 @@ class PsKernelFunction(PsAstNode): ...@@ -129,9 +129,9 @@ class PsKernelFunction(PsAstNode):
This function performs a full traversal of the AST. This function performs a full traversal of the AST.
To improve performance, make sure to cache the result if necessary. To improve performance, make sure to cache the result if necessary.
""" """
from .analysis import UndefinedVariablesCollector from .collectors import collect_undefined_variables
params_set = UndefinedVariablesCollector().collect(self) params_set = collect_undefined_variables(self)
params_list = sorted(params_set, key=lambda p: p.name) params_list = sorted(params_set, key=lambda p: p.name)
arrays = set(p.array for p in params_list if isinstance(p, PsArrayBasePointer)) arrays = set(p.array for p in params_list if isinstance(p, PsArrayBasePointer))
...@@ -140,5 +140,6 @@ class PsKernelFunction(PsAstNode): ...@@ -140,5 +140,6 @@ class PsKernelFunction(PsAstNode):
) )
def get_required_headers(self) -> set[str]: def get_required_headers(self) -> set[str]:
# TODO: Headers from types, vectorizer, ... # To Do: Headers from target/instruction set/...
return set() from .collectors import collect_required_headers
return collect_required_headers(self)
from __future__ import annotations from __future__ import annotations
from typing import TypeAlias, Any from typing import TypeAlias, Any
from sys import intern
import pymbolic.primitives as pb import pymbolic.primitives as pb
...@@ -16,6 +17,7 @@ class PsTypedVariable(pb.Variable): ...@@ -16,6 +17,7 @@ class PsTypedVariable(pb.Variable):
init_arg_names: tuple[str, ...] = ("name", "dtype") init_arg_names: tuple[str, ...] = ("name", "dtype")
__match_args__ = ("name", "dtype") __match_args__ = ("name", "dtype")
mapper_method = intern("map_typed_variable")
def __init__(self, name: str, dtype: PsAbstractType): def __init__(self, name: str, dtype: PsAbstractType):
super(PsTypedVariable, self).__init__(name) super(PsTypedVariable, self).__init__(name)
...@@ -98,8 +100,12 @@ class PsTypedConstant: ...@@ -98,8 +100,12 @@ class PsTypedConstant:
self._dtype = constify(dtype) self._dtype = constify(dtype)
self._value = self._dtype.create_constant(value) self._value = self._dtype.create_constant(value)
@property
def dtype(self) -> PsNumericType:
return self._dtype
def __str__(self) -> str: def __str__(self) -> str:
return str(self._value) return self._dtype.create_literal(self._value)
def __repr__(self) -> str: def __repr__(self) -> str:
return f"PsTypedConstant( {self._value}, {repr(self._dtype)} )" return f"PsTypedConstant( {self._value}, {repr(self._dtype)} )"
......
...@@ -32,6 +32,15 @@ class PsAbstractType(ABC): ...@@ -32,6 +32,15 @@ class PsAbstractType(ABC):
def const(self) -> bool: def const(self) -> bool:
return self._const return self._const
# -------------------------------------------------------------------------------------------
# Optional Info
# -------------------------------------------------------------------------------------------
@property
def required_headers(self) -> set[str]:
"""The set of header files required when this type occurs in generated code."""
return set()
# ------------------------------------------------------------------------------------------- # -------------------------------------------------------------------------------------------
# Internal virtual operations # Internal virtual operations
# ------------------------------------------------------------------------------------------- # -------------------------------------------------------------------------------------------
...@@ -154,6 +163,14 @@ class PsNumericType(PsAbstractType, ABC): ...@@ -154,6 +163,14 @@ class PsNumericType(PsAbstractType, ABC):
PsTypeError: If the given value cannot be interpreted in this type. PsTypeError: If the given value cannot be interpreted in this type.
""" """
@abstractmethod
def create_literal(self, value: Any) -> str:
"""Create a C numerical literal for a constant of this type.
Raises:
PsTypeError: If the given value's type is not the numeric type's compiler-internal representation.
"""
@abstractmethod @abstractmethod
def is_int(self) -> bool: def is_int(self) -> bool:
... ...
...@@ -185,7 +202,7 @@ class PsScalarType(PsNumericType, ABC): ...@@ -185,7 +202,7 @@ class PsScalarType(PsNumericType, ABC):
def is_float(self) -> bool: def is_float(self) -> bool:
return isinstance(self, PsIeeeFloatType) return isinstance(self, PsIeeeFloatType)
@property @property
@abstractmethod @abstractmethod
def itemsize(self) -> int: def itemsize(self) -> int:
...@@ -202,6 +219,7 @@ class PsIntegerType(PsScalarType, ABC): ...@@ -202,6 +219,7 @@ class PsIntegerType(PsScalarType, ABC):
__match_args__ = ("width",) __match_args__ = ("width",)
SUPPORTED_WIDTHS = (8, 16, 32, 64) SUPPORTED_WIDTHS = (8, 16, 32, 64)
NUMPY_TYPES: dict[int, type] = dict()
def __init__(self, width: int, signed: bool = True, const: bool = False): def __init__(self, width: int, signed: bool = True, const: bool = False):
if width not in self.SUPPORTED_WIDTHS: if width not in self.SUPPORTED_WIDTHS:
...@@ -221,11 +239,19 @@ class PsIntegerType(PsScalarType, ABC): ...@@ -221,11 +239,19 @@ class PsIntegerType(PsScalarType, ABC):
@property @property
def signed(self) -> bool: def signed(self) -> bool:
return self._signed return self._signed
@property @property
def itemsize(self) -> int: def itemsize(self) -> int:
return self.width // 8 return self.width // 8
def create_literal(self, value: Any) -> str:
np_dtype = self.NUMPY_TYPES[self._width]
if not isinstance(value, np_dtype):
raise PsTypeError(f"Given value {value} is not of required type {np_dtype}")
unsigned_suffix = "" if self.signed else "u"
# TODO: cast literal to correct type?
return str(value) + unsigned_suffix
def __eq__(self, other: object) -> bool: def __eq__(self, other: object) -> bool:
if not isinstance(other, PsIntegerType): if not isinstance(other, PsIntegerType):
return False return False
...@@ -329,11 +355,29 @@ class PsIeeeFloatType(PsScalarType): ...@@ -329,11 +355,29 @@ class PsIeeeFloatType(PsScalarType):
@property @property
def width(self) -> int: def width(self) -> int:
return self._width return self._width
@property @property
def itemsize(self) -> int: def itemsize(self) -> int:
return self.width // 8 return self.width // 8
@property
def required_headers(self) -> set[str]:
if self._width == 16:
return {'"half_precision.h"'}
else:
return set()
def create_literal(self, value: Any) -> str:
np_dtype = self.NUMPY_TYPES[self._width]
if not isinstance(value, np_dtype):
raise PsTypeError(f"Given value {value} is not of required type {np_dtype}")
match self.width:
case 16: return f"((half) {value})" # see include/half_precision.h
case 32: return f"{value}f"
case 64: return str(value)
case _: assert False, "unreachable code"
def create_constant(self, value: Any) -> Any: def create_constant(self, value: Any) -> Any:
np_type = self.NUMPY_TYPES[self._width] np_type = self.NUMPY_TYPES[self._width]
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment