From 6a934814e5d7933fdd1c12295c01f9a1b33f46d2 Mon Sep 17 00:00:00 2001
From: Frederik Hennig <frederik.hennig@fau.de>
Date: Wed, 31 Jan 2024 17:02:14 +0100
Subject: [PATCH] type erasure for anonymous structs; full translation pass for
 index kernels

---
 src/pystencils/nbackend/arrays.py             | 42 ++++++++---
 src/pystencils/nbackend/emission.py           | 23 ++++--
 src/pystencils/nbackend/functions.py          | 51 +++++++++++++
 .../nbackend/kernelcreation/kernelcreation.py |  2 +
 .../kernelcreation/transformations.py         | 73 +++++++++++++++++++
 src/pystencils/nbackend/types/basic_types.py  |  6 +-
 tests/nbackend/test_basic_printing.py         |  4 +-
 tests/nbackend/types/test_types.py            |  4 +-
 8 files changed, 185 insertions(+), 20 deletions(-)
 create mode 100644 src/pystencils/nbackend/kernelcreation/transformations.py

diff --git a/src/pystencils/nbackend/arrays.py b/src/pystencils/nbackend/arrays.py
index 21f2035a8..0840b2699 100644
--- a/src/pystencils/nbackend/arrays.py
+++ b/src/pystencils/nbackend/arrays.py
@@ -50,7 +50,7 @@ from abc import ABC
 
 import pymbolic.primitives as pb
 
-from .types import PsAbstractType, PsPointerType, PsIntegerType, PsSignedIntegerType
+from .types import PsAbstractType, PsPointerType, PsIntegerType, PsUnsignedIntegerType, PsSignedIntegerType
 
 from .typed_expressions import PsTypedVariable, ExprOrConstant, PsTypedConstant
 
@@ -110,7 +110,7 @@ class PsLinearizedArray:
     @property
     def name(self):
         return self._name
-    
+
     @property
     def base_pointer(self) -> PsArrayBasePointer:
         return self._base_ptr
@@ -119,9 +119,21 @@ class PsLinearizedArray:
     def shape(self) -> tuple[PsArrayShapeVar | PsTypedConstant, ...]:
         return self._shape
 
+    @property
+    def shape_spec(self) -> tuple[EllipsisType | int, ...]:
+        """Shape with each constant entry given by its value and each non-constant entry by ``...``."""
+        entries = ((dim.value if isinstance(dim, PsTypedConstant) else ...) for dim in self._shape)
+        return tuple(entries)
+
     @property
     def strides(self) -> tuple[PsArrayStrideVar | PsTypedConstant, ...]:
         return self._strides
+    
+    @property
+    def strides_spec(self) -> tuple[EllipsisType | int, ...]:
+        """Strides with each constant entry given by its value and each non-constant entry by ``...``."""
+        entries = ((st.value if isinstance(st, PsTypedConstant) else ...) for st in self._strides)
+        return tuple(entries)
 
     @property
     def element_type(self):
@@ -134,12 +146,8 @@ class PsLinearizedArray:
         if these variables would occur in here, an infinite recursion would follow.
         Hence they are filtered and replaced by the ellipsis.
         """
-        shape_clean = tuple(
-            (s if isinstance(s, PsTypedConstant) else ...) for s in self._shape
-        )
-        strides_clean = tuple(
-            (s if isinstance(s, PsTypedConstant) else ...) for s in self._strides
-        )
+        shape_clean = self.shape_spec
+        strides_clean = self.strides_spec
         return (
             self._name,
             self._element_type,
@@ -156,9 +164,11 @@ class PsLinearizedArray:
 
     def __hash__(self) -> int:
         return hash(self._hashable_contents())
-    
+
     def __repr__(self) -> str:
-        return f"PsLinearizedArray({self._name}: {self.element_type}[{len(self.shape)}D])"
+        return (
+            f"PsLinearizedArray({self._name}: {self.element_type}[{len(self.shape)}D])"
+        )
 
 
 class PsArrayAssocVar(PsTypedVariable, ABC):
@@ -195,6 +205,17 @@ class PsArrayBasePointer(PsArrayAssocVar):
 
     def __getinitargs__(self):
         return self.name, self.array
+    
+
+class TypeErasedBasePointer(PsArrayBasePointer):
+    """Base pointer for arrays whose element type has been erased.
+
+    Typed as pointer to ``PsUnsignedIntegerType(8)``. Used primarily for arrays of anonymous structs."""
+    def __init__(self, name: str, array: PsLinearizedArray):
+        dtype = PsPointerType(PsUnsignedIntegerType(8))
+        #   Start the MRO lookup *after* PsArrayBasePointer, so that class's own __init__ is skipped
+        super(PsArrayBasePointer, self).__init__(name, dtype, array)
+        self._array = array
 
 
 class PsArrayShapeVar(PsArrayAssocVar):
@@ -244,7 +265,6 @@ class PsArrayStrideVar(PsArrayAssocVar):
 
 
 class PsArrayAccess(pb.Subscript):
-
     mapper_method = intern("map_array_access")
 
     def __init__(self, base_ptr: PsArrayBasePointer, index: ExprOrConstant):
diff --git a/src/pystencils/nbackend/emission.py b/src/pystencils/nbackend/emission.py
index e47854506..b89b58224 100644
--- a/src/pystencils/nbackend/emission.py
+++ b/src/pystencils/nbackend/emission.py
@@ -14,21 +14,34 @@ from .ast import (
 )
 from .ast.kernelfunction import PsKernelFunction
 from .typed_expressions import PsTypedVariable
+from .functions import Deref, AddressOf, Cast
 
 
 def emit_code(kernel: PsKernelFunction):
     #   TODO: Specialize for different targets
-    printer = CPrinter()
+    printer = CAstPrinter()
     return printer.print(kernel)
 
 
-class CPrinter:
+class CExpressionsPrinter(CCodeMapper):
+    """C code mapper that prints the `Deref`, `AddressOf` and `Cast` function symbols as `*`, `&` and `(type)`."""
+    def map_deref(self, deref: Deref, enclosing_prec):
+        return "*"
+    
+    def map_address_of(self, addrof: AddressOf, enclosing_prec):
+        return "&"
+    
+    def map_cast(self, cast: Cast, enclosing_prec):
+        return f"({cast.target_type.c_string()})"  # the target type's C spelling, wrapped in parentheses
+
+
+class CAstPrinter:
     def __init__(self, indent_width=3):
         self._indent_width = indent_width
 
         self._current_indent_level = 0
 
-        self._pb_cmapper = CCodeMapper()
+        self._expr_printer = CExpressionsPrinter()
 
     def indent(self, line):
         return " " * self._current_indent_level + line
@@ -60,7 +73,7 @@ class CPrinter:
 
     @visit.case(PsExpression)
     def pymb_expression(self, expr: PsExpression):
-        return self._pb_cmapper(expr.expression)
+        return self._expr_printer(expr.expression)
 
     @visit.case(PsDeclaration)
     def declaration(self, decl: PsDeclaration):
@@ -81,7 +94,7 @@ class CPrinter:
     def loop(self, loop: PsLoop):
         ctr_symbol = loop.counter.symbol
         assert isinstance(ctr_symbol, PsTypedVariable)
-        
+
         ctr = ctr_symbol.name
         start_code = self.visit(loop.start)
         stop_code = self.visit(loop.stop)
diff --git a/src/pystencils/nbackend/functions.py b/src/pystencils/nbackend/functions.py
index e7dc4e6cb..190984373 100644
--- a/src/pystencils/nbackend/functions.py
+++ b/src/pystencils/nbackend/functions.py
@@ -13,12 +13,63 @@ TODO: Maybe add a way for the user to register additional functions
 TODO: Figure out the best way to describe function signatures and overloads for typing
 """
 
+from sys import intern
 import pymbolic.primitives as pb
 from abc import ABC, abstractmethod
 
+from .types import PsAbstractType
+from .typed_expressions import ExprOrConstant
+
 
 class PsFunction(pb.FunctionSymbol, ABC):
     @property
     @abstractmethod
     def arg_count(self) -> int:
         "Number of arguments this function takes"
+
+
+class Deref(PsFunction):
+    """Dereferences a pointer."""
+
+    mapper_method = intern("map_deref")  # expression mappers handle this symbol in their `map_deref`
+
+    @property
+    def arg_count(self) -> int:
+        return 1  # takes exactly one argument
+
+
+deref = Deref()  # module-level instance, applied to an expression as `deref(expr)`
+
+
+class AddressOf(PsFunction):
+    """Takes the address of an object."""
+
+    mapper_method = intern("map_address_of")  # expression mappers handle this symbol in their `map_address_of`
+
+    @property
+    def arg_count(self) -> int:
+        return 1  # takes exactly one argument
+
+
+address_of = AddressOf()  # module-level instance, applied to an expression as `address_of(expr)`
+
+
+class Cast(PsFunction):
+    """An unsafe C-style type cast of its single argument to `target_type`."""
+
+    mapper_method = intern("map_cast")  # expression mappers handle this symbol in their `map_cast`
+
+    def __init__(self, target_type: PsAbstractType):
+        self._target_type = target_type  # NOTE(review): confirm equality/hashing takes target_type into account
+
+    @property
+    def arg_count(self) -> int:
+        return 1  # takes exactly one argument: the expression being cast
+
+    @property
+    def target_type(self) -> PsAbstractType:
+        return self._target_type
+
+
+def cast(target_type: PsAbstractType, arg: ExprOrConstant):
+    return Cast(target_type)(arg)  # apply the cast to `arg` itself, not to the `ExprOrConstant` annotation type
diff --git a/src/pystencils/nbackend/kernelcreation/kernelcreation.py b/src/pystencils/nbackend/kernelcreation/kernelcreation.py
index f29cd9a13..f95619ed7 100644
--- a/src/pystencils/nbackend/kernelcreation/kernelcreation.py
+++ b/src/pystencils/nbackend/kernelcreation/kernelcreation.py
@@ -12,6 +12,7 @@ from .iteration_space import (
     create_sparse_iteration_space,
     create_full_iteration_space,
 )
+from .transformations import EraseAnonymousStructTypes
 
 
 def create_kernel(assignments: AssignmentCollection, options: KernelCreationOptions):
@@ -45,6 +46,7 @@ def create_kernel(assignments: AssignmentCollection, options: KernelCreationOpti
             raise NotImplementedError("Target platform not implemented")
 
     kernel_ast = platform.materialize_iteration_space(kernel_body, ispace)
+    kernel_ast = EraseAnonymousStructTypes(ctx)(kernel_ast)
 
     #   7. Apply optimizations
     #     - Vectorization
diff --git a/src/pystencils/nbackend/kernelcreation/transformations.py b/src/pystencils/nbackend/kernelcreation/transformations.py
new file mode 100644
index 000000000..c01016d12
--- /dev/null
+++ b/src/pystencils/nbackend/kernelcreation/transformations.py
@@ -0,0 +1,73 @@
+from __future__ import annotations
+
+from typing import TypeVar
+
+import pymbolic.primitives as pb
+from pymbolic.mapper import IdentityMapper
+
+from .context import KernelCreationContext
+
+from ..ast import PsAstNode, PsExpression
+from ..arrays import PsArrayAccess, TypeErasedBasePointer
+from ..typed_expressions import PsTypedConstant
+from ..types import PsStructType, PsPointerType
+from ..functions import deref, address_of, Cast
+
+NodeT = TypeVar("NodeT", bound=PsAstNode)
+
+
+class EraseAnonymousStructTypes(IdentityMapper):
+    """Lower anonymous struct arrays to a byte-array representation.
+
+    Member lookups on accesses into arrays of anonymous structs are rewritten into casted, dereferenced
+    byte accesses through a `TypeErasedBasePointer`. Calling the instance on an AST node modifies it in place.
+    """
+
+    def __init__(self, ctx: KernelCreationContext) -> None:
+        self._ctx = ctx
+
+    def __call__(self, node: NodeT) -> NodeT:
+        match node:
+            case PsExpression(expr):
+                #   Expression node: map the wrapped pymbolic expression and store the result back
+                node.expression = self.rec(expr)
+            case other:
+                for c in other.children:
+                    self(c)  # children are updated in place, so the return value is not needed
+
+        return node
+
+    def map_lookup(self, lookup: pb.Lookup) -> pb.Expression:
+        aggr = lookup.aggregate
+        if not isinstance(aggr, PsArrayAccess):
+            return lookup  # not a lookup on an array access: returned unchanged
+
+        arr = aggr.array
+        if (
+            not isinstance(arr.element_type, PsStructType)
+            or not arr.element_type.anonymous
+        ):
+            return lookup  # element type is not an anonymous struct: returned unchanged
+
+        struct_type = arr.element_type
+        struct_size = struct_type.itemsize
+        #   Byte index where the accessed struct starts: element index times struct size
+        bp = aggr.base_ptr
+        type_erased_bp = TypeErasedBasePointer(bp.name, arr)
+        base_index = aggr.index_tuple[0] * PsTypedConstant(struct_size, self._ctx.index_dtype)
+
+        member_name = lookup.name
+        member = struct_type.get_member(member_name)
+        assert member is not None
+        #   numpy's `fields` maps each member name to a tuple whose second entry is the member's byte offset
+        np_struct = struct_type.numpy_dtype
+        assert np_struct is not None
+        assert np_struct.fields is not None
+        member_offset = np_struct.fields[member_name][1]
+
+        byte_index = base_index + PsTypedConstant(member_offset, self._ctx.index_dtype)
+        type_erased_access = PsArrayAccess(type_erased_bp, byte_index)
+        #   Take the byte's address, cast it to a pointer to the member's type, and dereference
+        cast = Cast(PsPointerType(member.dtype))
+
+        return deref(cast(address_of(type_erased_access)))
diff --git a/src/pystencils/nbackend/types/basic_types.py b/src/pystencils/nbackend/types/basic_types.py
index ce27c6fb3..3f080c313 100644
--- a/src/pystencils/nbackend/types/basic_types.py
+++ b/src/pystencils/nbackend/types/basic_types.py
@@ -209,9 +209,13 @@ class PsStructType(PsAbstractType):
         return self._name is None
 
     @property
-    def numpy_dtype(self) -> np.dtype | None:
+    def numpy_dtype(self) -> np.dtype:
         members = [(m.name, m.dtype.numpy_dtype) for m in self._members]
         return np.dtype(members)
+    
+    @property
+    def itemsize(self) -> int:
+        return self.numpy_dtype.itemsize  # size in bytes of the numpy struct dtype built from the members
 
     def c_string(self) -> str:
         if self._name is None:
diff --git a/tests/nbackend/test_basic_printing.py b/tests/nbackend/test_basic_printing.py
index 8d9fc6483..7d1966882 100644
--- a/tests/nbackend/test_basic_printing.py
+++ b/tests/nbackend/test_basic_printing.py
@@ -6,7 +6,7 @@ from pystencils.nbackend.ast import *
 from pystencils.nbackend.typed_expressions import *
 from pystencils.nbackend.arrays import PsLinearizedArray, PsArrayBasePointer, PsArrayAccess
 from pystencils.nbackend.types.quick import *
-from pystencils.nbackend.emission import CPrinter
+from pystencils.nbackend.emission import CAstPrinter
 
 def test_basic_kernel():
 
@@ -32,7 +32,7 @@ def test_basic_kernel():
 
     func = PsKernelFunction(PsBlock([loop]), target=Target.CPU)
 
-    printer = CPrinter()
+    printer = CAstPrinter()
     code = printer.print(func)
 
     paramlist = func.get_parameters().params
diff --git a/tests/nbackend/types/test_types.py b/tests/nbackend/types/test_types.py
index 082b39205..06ec7db16 100644
--- a/tests/nbackend/types/test_types.py
+++ b/tests/nbackend/types/test_types.py
@@ -6,7 +6,9 @@ from pystencils.nbackend.types import *
 from pystencils.nbackend.types.quick import *
 
 
-@pytest.mark.parametrize("Type", [PsSignedIntegerType, PsUnsignedIntegerType, PsIeeeFloatType])
+@pytest.mark.parametrize(
+    "Type", [PsSignedIntegerType, PsUnsignedIntegerType, PsIeeeFloatType]
+)
 def test_widths(Type):
     for width in Type.SUPPORTED_WIDTHS:
         assert Type(width).width == width
-- 
GitLab