From f721e16707ce74f813ea3cf814197cd9aa5aaf45 Mon Sep 17 00:00:00 2001 From: Frederik Hennig <frederik.hennig@fau.de> Date: Thu, 18 Jan 2024 17:25:52 +0100 Subject: [PATCH] literal printing and header collection --- .../ast/{analysis.py => collectors.py} | 54 +++++++++++++++---- src/pystencils/nbackend/ast/kernelfunction.py | 9 ++-- src/pystencils/nbackend/typed_expressions.py | 8 ++- src/pystencils/nbackend/types/basic_types.py | 50 +++++++++++++++-- 4 files changed, 103 insertions(+), 18 deletions(-) rename src/pystencils/nbackend/ast/{analysis.py => collectors.py} (63%) diff --git a/src/pystencils/nbackend/ast/analysis.py b/src/pystencils/nbackend/ast/collectors.py similarity index 63% rename from src/pystencils/nbackend/ast/analysis.py rename to src/pystencils/nbackend/ast/collectors.py index 6a3162c1b..65bc14d4f 100644 --- a/src/pystencils/nbackend/ast/analysis.py +++ b/src/pystencils/nbackend/ast/collectors.py @@ -1,11 +1,14 @@ -from typing import cast +from typing import cast, Any + +from functools import reduce from pymbolic.primitives import Variable +from pymbolic.mapper import Collector from pymbolic.mapper.dependency import DependencyMapper from .kernelfunction import PsKernelFunction from .nodes import PsAstNode, PsExpression, PsAssignment, PsDeclaration, PsLoop, PsBlock -from ..typed_expressions import PsTypedVariable +from ..typed_expressions import PsTypedVariable, PsTypedConstant from ..exceptions import PsMalformedAstException, PsInternalCompilerError @@ -24,12 +27,12 @@ class UndefinedVariablesCollector: include_cses=False, ) - def collect(self, node: PsAstNode) -> set[PsTypedVariable]: + def __call__(self, node: PsAstNode) -> set[PsTypedVariable]: """Returns all `PsTypedVariable`s that occur in the given AST without being defined prior to their usage.""" match node: case PsKernelFunction(block): - return self.collect(block) + return self(block) case PsExpression(expr): variables: set[Variable] = self._pb_dep_mapper(expr) @@ -43,22 +46,22 @@ class UndefinedVariablesCollector: return cast(set[PsTypedVariable], variables) case PsAssignment(lhs, rhs): - return self.collect(lhs) | self.collect(rhs) + return self(lhs) | self(rhs) case PsBlock(statements): undefined_vars: set[PsTypedVariable] = set() for stmt in statements[::-1]: undefined_vars -= self.declared_variables(stmt) - undefined_vars |= self.collect(stmt) + undefined_vars |= self(stmt) return undefined_vars case PsLoop(ctr, start, stop, step, body): undefined_vars = ( - self.collect(start) - | self.collect(stop) - | self.collect(step) - | self.collect(body) + self(start) + | self(stop) + | self(step) + | self(body) ) undefined_vars.remove(ctr.symbol) return undefined_vars @@ -82,3 +85,34 @@ class UndefinedVariablesCollector: raise PsInternalCompilerError( f"Don't know how to collect declared variables from {unknown}" ) + + +def collect_undefined_variables(node: PsAstNode) -> set[PsTypedVariable]: + return UndefinedVariablesCollector()(node) + + +class RequiredHeadersCollector(Collector): + """Collect all header files required by a given AST for correct compilation. + + Required headers can currently only be defined in subclasses of `PsAbstractType`. + """ + + def __call__(self, node: PsAstNode) -> set[str]: + match node: + case PsExpression(expr): + return self.rec(expr) + case node: + return reduce(set.union, (self(c) for c in node.children()), set()) + + def map_typed_variable(self, var: PsTypedVariable) -> set[str]: + return var.dtype.required_headers + + def map_constant(self, cst: Any): + if not isinstance(cst, PsTypedConstant): + raise PsMalformedAstException("Untyped constant encountered in expression.") + + return cst.dtype.required_headers + + +def collect_required_headers(node: PsAstNode) -> set[str]: + return RequiredHeadersCollector()(node) diff --git a/src/pystencils/nbackend/ast/kernelfunction.py b/src/pystencils/nbackend/ast/kernelfunction.py index b911f1100..2729d25a6 100644 --- a/src/pystencils/nbackend/ast/kernelfunction.py +++ b/src/pystencils/nbackend/ast/kernelfunction.py @@ -129,9 +129,9 @@ class PsKernelFunction(PsAstNode): This function performs a full traversal of the AST. To improve performance, make sure to cache the result if necessary. """ - from .analysis import UndefinedVariablesCollector + from .collectors import collect_undefined_variables - params_set = UndefinedVariablesCollector().collect(self) + params_set = collect_undefined_variables(self) params_list = sorted(params_set, key=lambda p: p.name) arrays = set(p.array for p in params_list if isinstance(p, PsArrayBasePointer)) @@ -140,5 +140,6 @@ class PsKernelFunction(PsAstNode): ) def get_required_headers(self) -> set[str]: - # TODO: Headers from types, vectorizer, ... - return set() + # To Do: Headers from target/instruction set/... + from .collectors import collect_required_headers + return collect_required_headers(self) diff --git a/src/pystencils/nbackend/typed_expressions.py b/src/pystencils/nbackend/typed_expressions.py index b33114426..94aa75cf4 100644 --- a/src/pystencils/nbackend/typed_expressions.py +++ b/src/pystencils/nbackend/typed_expressions.py @@ -1,6 +1,7 @@ from __future__ import annotations from typing import TypeAlias, Any +from sys import intern import pymbolic.primitives as pb @@ -16,6 +17,7 @@ class PsTypedVariable(pb.Variable): init_arg_names: tuple[str, ...] = ("name", "dtype") __match_args__ = ("name", "dtype") + mapper_method = intern("map_typed_variable") def __init__(self, name: str, dtype: PsAbstractType): super(PsTypedVariable, self).__init__(name) @@ -98,8 +100,12 @@ class PsTypedConstant: self._dtype = constify(dtype) self._value = self._dtype.create_constant(value) + @property + def dtype(self) -> PsNumericType: + return self._dtype + def __str__(self) -> str: - return str(self._value) + return self._dtype.create_literal(self._value) def __repr__(self) -> str: return f"PsTypedConstant( {self._value}, {repr(self._dtype)} )" diff --git a/src/pystencils/nbackend/types/basic_types.py b/src/pystencils/nbackend/types/basic_types.py index be6de603e..a2b219658 100644 --- a/src/pystencils/nbackend/types/basic_types.py +++ b/src/pystencils/nbackend/types/basic_types.py @@ -32,6 +32,15 @@ class PsAbstractType(ABC): def const(self) -> bool: return self._const + # ------------------------------------------------------------------------------------------- + # Optional Info + # ------------------------------------------------------------------------------------------- + + @property + def required_headers(self) -> set[str]: + """The set of header files required when this type occurs in generated code.""" + return set() + # ------------------------------------------------------------------------------------------- # Internal virtual operations # ------------------------------------------------------------------------------------------- @@ -154,6 +163,14 @@ class PsNumericType(PsAbstractType, ABC): PsTypeError: If the given value cannot be interpreted in this type. """ + @abstractmethod + def create_literal(self, value: Any) -> str: + """Create a C numerical literal for a constant of this type. + + Raises: + PsTypeError: If the given value's type is not the numeric type's compiler-internal representation. + """ + @abstractmethod def is_int(self) -> bool: ... @@ -185,7 +202,7 @@ class PsScalarType(PsNumericType, ABC): def is_float(self) -> bool: return isinstance(self, PsIeeeFloatType) - + @property @abstractmethod def itemsize(self) -> int: @@ -202,6 +219,7 @@ class PsIntegerType(PsScalarType, ABC): __match_args__ = ("width",) SUPPORTED_WIDTHS = (8, 16, 32, 64) + NUMPY_TYPES: dict[int, type] = dict() def __init__(self, width: int, signed: bool = True, const: bool = False): if width not in self.SUPPORTED_WIDTHS: @@ -221,11 +239,19 @@ class PsIntegerType(PsScalarType, ABC): @property def signed(self) -> bool: return self._signed - + @property def itemsize(self) -> int: return self.width // 8 + def create_literal(self, value: Any) -> str: + np_dtype = self.NUMPY_TYPES[self._width] + if not isinstance(value, np_dtype): + raise PsTypeError(f"Given value {value} is not of required type {np_dtype}") + unsigned_suffix = "" if self.signed else "u" + # TODO: cast literal to correct type? + return str(value) + unsigned_suffix + def __eq__(self, other: object) -> bool: if not isinstance(other, PsIntegerType): return False @@ -329,11 +355,29 @@ class PsIeeeFloatType(PsScalarType): @property def width(self) -> int: return self._width - + @property def itemsize(self) -> int: return self.width // 8 + @property + def required_headers(self) -> set[str]: + if self._width == 16: + return {'"half_precision.h"'} + else: + return set() + + def create_literal(self, value: Any) -> str: + np_dtype = self.NUMPY_TYPES[self._width] + if not isinstance(value, np_dtype): + raise PsTypeError(f"Given value {value} is not of required type {np_dtype}") + + match self.width: + case 16: return f"((half) {value})" # see include/half_precision.h + case 32: return f"{value}f" + case 64: return str(value) + case _: assert False, "unreachable code" + def create_constant(self, value: Any) -> Any: np_type = self.NUMPY_TYPES[self._width] -- GitLab