From f721e16707ce74f813ea3cf814197cd9aa5aaf45 Mon Sep 17 00:00:00 2001
From: Frederik Hennig <frederik.hennig@fau.de>
Date: Thu, 18 Jan 2024 17:25:52 +0100
Subject: [PATCH] literal printing and header collection

---
 .../ast/{analysis.py => collectors.py}        | 54 +++++++++++++++----
 src/pystencils/nbackend/ast/kernelfunction.py |  9 ++--
 src/pystencils/nbackend/typed_expressions.py  |  8 ++-
 src/pystencils/nbackend/types/basic_types.py  | 50 +++++++++++++++--
 4 files changed, 103 insertions(+), 18 deletions(-)
 rename src/pystencils/nbackend/ast/{analysis.py => collectors.py} (63%)

diff --git a/src/pystencils/nbackend/ast/analysis.py b/src/pystencils/nbackend/ast/collectors.py
similarity index 63%
rename from src/pystencils/nbackend/ast/analysis.py
rename to src/pystencils/nbackend/ast/collectors.py
index 6a3162c1b..65bc14d4f 100644
--- a/src/pystencils/nbackend/ast/analysis.py
+++ b/src/pystencils/nbackend/ast/collectors.py
@@ -1,11 +1,14 @@
-from typing import cast
+from typing import cast, Any
+
+from functools import reduce
 
 from pymbolic.primitives import Variable
+from pymbolic.mapper import Collector
 from pymbolic.mapper.dependency import DependencyMapper
 
 from .kernelfunction import PsKernelFunction
 from .nodes import PsAstNode, PsExpression, PsAssignment, PsDeclaration, PsLoop, PsBlock
-from ..typed_expressions import PsTypedVariable
+from ..typed_expressions import PsTypedVariable, PsTypedConstant
 from ..exceptions import PsMalformedAstException, PsInternalCompilerError
 
 
@@ -24,12 +27,12 @@ class UndefinedVariablesCollector:
             include_cses=False,
         )
 
-    def collect(self, node: PsAstNode) -> set[PsTypedVariable]:
+    def __call__(self, node: PsAstNode) -> set[PsTypedVariable]:
         """Returns all `PsTypedVariable`s that occur in the given AST without being defined prior to their usage."""
 
         match node:
             case PsKernelFunction(block):
-                return self.collect(block)
+                return self(block)
 
             case PsExpression(expr):
                 variables: set[Variable] = self._pb_dep_mapper(expr)
@@ -43,22 +46,22 @@ class UndefinedVariablesCollector:
                 return cast(set[PsTypedVariable], variables)
 
             case PsAssignment(lhs, rhs):
-                return self.collect(lhs) | self.collect(rhs)
+                return self(lhs) | self(rhs)
 
             case PsBlock(statements):
                 undefined_vars: set[PsTypedVariable] = set()
                 for stmt in statements[::-1]:
                     undefined_vars -= self.declared_variables(stmt)
-                    undefined_vars |= self.collect(stmt)
+                    undefined_vars |= self(stmt)
 
                 return undefined_vars
 
             case PsLoop(ctr, start, stop, step, body):
                 undefined_vars = (
-                    self.collect(start)
-                    | self.collect(stop)
-                    | self.collect(step)
-                    | self.collect(body)
+                    self(start)
+                    | self(stop)
+                    | self(step)
+                    | self(body)
                 )
                 undefined_vars.remove(ctr.symbol)
                 return undefined_vars
@@ -82,3 +85,34 @@ class UndefinedVariablesCollector:
                 raise PsInternalCompilerError(
                     f"Don't know how to collect declared variables from {unknown}"
                 )
+
+
+def collect_undefined_variables(node: PsAstNode) -> set[PsTypedVariable]:
+    return UndefinedVariablesCollector()(node)
+
+
+class RequiredHeadersCollector(Collector):
+    """Collect all header files required by a given AST for correct compilation.
+
+    Required headers can currently only be defined in subclasses of `PsAbstractType`.
+    """
+
+    def __call__(self, node: PsAstNode) -> set[str]:
+        match node:
+            case PsExpression(expr):
+                return self.rec(expr)
+            case node:
+                return reduce(set.union, (self(c) for c in node.children()), set())
+
+    def map_typed_variable(self, var: PsTypedVariable) -> set[str]:
+        return var.dtype.required_headers
+
+    def map_constant(self, cst: Any):
+        if not isinstance(cst, PsTypedConstant):
+            raise PsMalformedAstException("Untyped constant encountered in expression.")
+
+        return cst.dtype.required_headers
+
+
+def collect_required_headers(node: PsAstNode) -> set[str]:
+    return RequiredHeadersCollector()(node)
diff --git a/src/pystencils/nbackend/ast/kernelfunction.py b/src/pystencils/nbackend/ast/kernelfunction.py
index b911f1100..2729d25a6 100644
--- a/src/pystencils/nbackend/ast/kernelfunction.py
+++ b/src/pystencils/nbackend/ast/kernelfunction.py
@@ -129,9 +129,9 @@ class PsKernelFunction(PsAstNode):
         This function performs a full traversal of the AST.
         To improve performance, make sure to cache the result if necessary.
         """
-        from .analysis import UndefinedVariablesCollector
+        from .collectors import collect_undefined_variables
 
-        params_set = UndefinedVariablesCollector().collect(self)
+        params_set = collect_undefined_variables(self)
         params_list = sorted(params_set, key=lambda p: p.name)
 
         arrays = set(p.array for p in params_list if isinstance(p, PsArrayBasePointer))
@@ -140,5 +140,6 @@ class PsKernelFunction(PsAstNode):
         )
 
     def get_required_headers(self) -> set[str]:
-        #   TODO: Headers from types, vectorizer, ...
-        return set()
+        #   To Do: Headers from target/instruction set/...
+        from .collectors import collect_required_headers
+        return collect_required_headers(self)
diff --git a/src/pystencils/nbackend/typed_expressions.py b/src/pystencils/nbackend/typed_expressions.py
index b33114426..94aa75cf4 100644
--- a/src/pystencils/nbackend/typed_expressions.py
+++ b/src/pystencils/nbackend/typed_expressions.py
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 from typing import TypeAlias, Any
+from sys import intern
 
 import pymbolic.primitives as pb
 
@@ -16,6 +17,7 @@ class PsTypedVariable(pb.Variable):
     init_arg_names: tuple[str, ...] = ("name", "dtype")
 
     __match_args__ = ("name", "dtype")
+    mapper_method = intern("map_typed_variable")
 
     def __init__(self, name: str, dtype: PsAbstractType):
         super(PsTypedVariable, self).__init__(name)
@@ -98,8 +100,12 @@ class PsTypedConstant:
         self._dtype = constify(dtype)
         self._value = self._dtype.create_constant(value)
 
+    @property
+    def dtype(self) -> PsNumericType:
+        return self._dtype
+
     def __str__(self) -> str:
-        return str(self._value)
+        return self._dtype.create_literal(self._value)
 
     def __repr__(self) -> str:
         return f"PsTypedConstant( {self._value}, {repr(self._dtype)} )"
diff --git a/src/pystencils/nbackend/types/basic_types.py b/src/pystencils/nbackend/types/basic_types.py
index be6de603e..a2b219658 100644
--- a/src/pystencils/nbackend/types/basic_types.py
+++ b/src/pystencils/nbackend/types/basic_types.py
@@ -32,6 +32,15 @@ class PsAbstractType(ABC):
     def const(self) -> bool:
         return self._const
 
+    #   -------------------------------------------------------------------------------------------
+    #   Optional Info
+    #   -------------------------------------------------------------------------------------------
+
+    @property
+    def required_headers(self) -> set[str]:
+        """The set of header files required when this type occurs in generated code."""
+        return set()
+
     #   -------------------------------------------------------------------------------------------
     #   Internal virtual operations
     #   -------------------------------------------------------------------------------------------
@@ -154,6 +163,14 @@ class PsNumericType(PsAbstractType, ABC):
             PsTypeError: If the given value cannot be interpreted in this type.
         """
 
+    @abstractmethod
+    def create_literal(self, value: Any) -> str:
+        """Create a C numerical literal for a constant of this type.
+
+        Raises:
+            PsTypeError: If the given value's type is not the numeric type's compiler-internal representation.
+        """
+
     @abstractmethod
     def is_int(self) -> bool:
         ...
@@ -185,7 +202,7 @@ class PsScalarType(PsNumericType, ABC):
 
     def is_float(self) -> bool:
         return isinstance(self, PsIeeeFloatType)
-    
+
     @property
     @abstractmethod
     def itemsize(self) -> int:
@@ -202,6 +219,7 @@ class PsIntegerType(PsScalarType, ABC):
     __match_args__ = ("width",)
 
     SUPPORTED_WIDTHS = (8, 16, 32, 64)
+    NUMPY_TYPES: dict[int, type] = dict()
 
     def __init__(self, width: int, signed: bool = True, const: bool = False):
         if width not in self.SUPPORTED_WIDTHS:
@@ -221,11 +239,19 @@ class PsIntegerType(PsScalarType, ABC):
     @property
     def signed(self) -> bool:
         return self._signed
-    
+
     @property
     def itemsize(self) -> int:
         return self.width // 8
 
+    def create_literal(self, value: Any) -> str:
+        np_dtype = self.NUMPY_TYPES[self._width]
+        if not isinstance(value, np_dtype):
+            raise PsTypeError(f"Given value {value} is not of required type {np_dtype}")
+        unsigned_suffix = "" if self.signed else "u"
+        #   TODO: cast literal to correct type?
+        return str(value) + unsigned_suffix
+
     def __eq__(self, other: object) -> bool:
         if not isinstance(other, PsIntegerType):
             return False
@@ -329,11 +355,29 @@ class PsIeeeFloatType(PsScalarType):
     @property
     def width(self) -> int:
         return self._width
-    
+
     @property
     def itemsize(self) -> int:
         return self.width // 8
 
+    @property
+    def required_headers(self) -> set[str]:
+        if self._width == 16:
+            return {'"half_precision.h"'}
+        else:
+            return set()
+
+    def create_literal(self, value: Any) -> str:
+        np_dtype = self.NUMPY_TYPES[self._width]
+        if not isinstance(value, np_dtype):
+            raise PsTypeError(f"Given value {value} is not of required type {np_dtype}")
+
+        match self.width:
+            case 16: return f"((half) {value})"  # see include/half_precision.h
+            case 32: return f"{value}f"
+            case 64: return str(value)
+            case _: assert False, "unreachable code"
+
     def create_constant(self, value: Any) -> Any:
         np_type = self.NUMPY_TYPES[self._width]
 
-- 
GitLab