From 6a934814e5d7933fdd1c12295c01f9a1b33f46d2 Mon Sep 17 00:00:00 2001 From: Frederik Hennig <frederik.hennig@fau.de> Date: Wed, 31 Jan 2024 17:02:14 +0100 Subject: [PATCH] type erasure for anonymous structs; full translation pass for index kernels --- src/pystencils/nbackend/arrays.py | 42 ++++++++--- src/pystencils/nbackend/emission.py | 23 ++++-- src/pystencils/nbackend/functions.py | 51 +++++++++++++ .../nbackend/kernelcreation/kernelcreation.py | 2 + .../kernelcreation/transformations.py | 73 +++++++++++++++++++ src/pystencils/nbackend/types/basic_types.py | 6 +- tests/nbackend/test_basic_printing.py | 4 +- tests/nbackend/types/test_types.py | 4 +- 8 files changed, 185 insertions(+), 20 deletions(-) create mode 100644 src/pystencils/nbackend/kernelcreation/transformations.py diff --git a/src/pystencils/nbackend/arrays.py b/src/pystencils/nbackend/arrays.py index 21f2035a8..0840b2699 100644 --- a/src/pystencils/nbackend/arrays.py +++ b/src/pystencils/nbackend/arrays.py @@ -50,7 +50,7 @@ from abc import ABC import pymbolic.primitives as pb -from .types import PsAbstractType, PsPointerType, PsIntegerType, PsSignedIntegerType +from .types import PsAbstractType, PsPointerType, PsIntegerType, PsUnsignedIntegerType, PsSignedIntegerType from .typed_expressions import PsTypedVariable, ExprOrConstant, PsTypedConstant @@ -110,7 +110,7 @@ class PsLinearizedArray: @property def name(self): return self._name - + @property def base_pointer(self) -> PsArrayBasePointer: return self._base_ptr @@ -119,9 +119,21 @@ class PsLinearizedArray: def shape(self) -> tuple[PsArrayShapeVar | PsTypedConstant, ...]: return self._shape + @property + def shape_spec(self) -> tuple[EllipsisType | int, ...]: + return tuple( + (s.value if isinstance(s, PsTypedConstant) else ...) for s in self._shape + ) + @property def strides(self) -> tuple[PsArrayStrideVar | PsTypedConstant, ...]: return self._strides + + @property + def strides_spec(self) -> tuple[EllipsisType | int, ...]: + return tuple( + (s.value if isinstance(s, PsTypedConstant) else ...) for s in self._strides + ) @property def element_type(self): @@ -134,12 +146,8 @@ class PsLinearizedArray: if these variables would occur in here, an infinite recursion would follow. Hence they are filtered and replaced by the ellipsis. """ - shape_clean = tuple( - (s if isinstance(s, PsTypedConstant) else ...) for s in self._shape - ) - strides_clean = tuple( - (s if isinstance(s, PsTypedConstant) else ...) for s in self._strides - ) + shape_clean = self.shape_spec + strides_clean = self.strides_spec return ( self._name, self._element_type, @@ -156,9 +164,11 @@ class PsLinearizedArray: def __hash__(self) -> int: return hash(self._hashable_contents()) - + def __repr__(self) -> str: - return f"PsLinearizedArray({self._name}: {self.element_type}[{len(self.shape)}D])" + return ( + f"PsLinearizedArray({self._name}: {self.element_type}[{len(self.shape)}D])" + ) class PsArrayAssocVar(PsTypedVariable, ABC): @@ -195,6 +205,17 @@ class PsArrayBasePointer(PsArrayAssocVar): def __getinitargs__(self): return self.name, self.array + + +class TypeErasedBasePointer(PsArrayBasePointer): + """Base pointer for arrays whose element type has been erased. + + Used primarily for arrays of anonymous structs.""" + def __init__(self, name: str, array: PsLinearizedArray): + dtype = PsPointerType(PsUnsignedIntegerType(8)) + super(PsArrayBasePointer, self).__init__(name, dtype, array) + + self._array = array class PsArrayShapeVar(PsArrayAssocVar): @@ -244,7 +265,6 @@ class PsArrayStrideVar(PsArrayAssocVar): class PsArrayAccess(pb.Subscript): - mapper_method = intern("map_array_access") def __init__(self, base_ptr: PsArrayBasePointer, index: ExprOrConstant): diff --git a/src/pystencils/nbackend/emission.py b/src/pystencils/nbackend/emission.py index e47854506..b89b58224 100644 --- a/src/pystencils/nbackend/emission.py +++ b/src/pystencils/nbackend/emission.py @@ -14,21 +14,34 @@ from .ast import ( ) from .ast.kernelfunction import PsKernelFunction from .typed_expressions import PsTypedVariable +from .functions import Deref, AddressOf, Cast def emit_code(kernel: PsKernelFunction): # TODO: Specialize for different targets - printer = CPrinter() + printer = CAstPrinter() return printer.print(kernel) -class CPrinter: +class CExpressionsPrinter(CCodeMapper): + + def map_deref(self, deref: Deref, enclosing_prec): + return "*" + + def map_address_of(self, addrof: AddressOf, enclosing_prec): + return "&" + + def map_cast(self, cast: Cast, enclosing_prec): + return f"({cast.target_type.c_string()})" + + +class CAstPrinter: def __init__(self, indent_width=3): self._indent_width = indent_width self._current_indent_level = 0 - self._pb_cmapper = CCodeMapper() + self._expr_printer = CExpressionsPrinter() def indent(self, line): return " " * self._current_indent_level + line @@ -60,7 +73,7 @@ class CPrinter: @visit.case(PsExpression) def pymb_expression(self, expr: PsExpression): - return self._pb_cmapper(expr.expression) + return self._expr_printer(expr.expression) @visit.case(PsDeclaration) def declaration(self, decl: PsDeclaration): @@ -81,7 +94,7 @@ class CPrinter: def loop(self, loop: PsLoop): ctr_symbol = loop.counter.symbol assert isinstance(ctr_symbol, PsTypedVariable) - + ctr = ctr_symbol.name start_code = self.visit(loop.start) stop_code = self.visit(loop.stop) diff --git a/src/pystencils/nbackend/functions.py b/src/pystencils/nbackend/functions.py index e7dc4e6cb..190984373 100644 --- a/src/pystencils/nbackend/functions.py +++ b/src/pystencils/nbackend/functions.py @@ -13,12 +13,63 @@ TODO: Maybe add a way for the user to register additional functions TODO: Figure out the best way to describe function signatures and overloads for typing """ +from sys import intern import pymbolic.primitives as pb from abc import ABC, abstractmethod +from .types import PsAbstractType +from .typed_expressions import ExprOrConstant + class PsFunction(pb.FunctionSymbol, ABC): @property @abstractmethod def arg_count(self) -> int: "Number of arguments this function takes" + + +class Deref(PsFunction): + """Dereferences a pointer.""" + + mapper_method = intern("map_deref") + + @property + def arg_count(self) -> int: + return 1 + + +deref = Deref() + + +class AddressOf(PsFunction): + """Take the address of an object""" + + mapper_method = intern("map_address_of") + + @property + def arg_count(self) -> int: + return 1 + + +address_of = AddressOf() + + +class Cast(PsFunction): + mapper_method = intern("map_cast") + + """An unsafe C-style type cast""" + + def __init__(self, target_type: PsAbstractType): + self._target_type = target_type + + @property + def arg_count(self) -> int: + return 1 + + @property + def target_type(self) -> PsAbstractType: + return self._target_type + + +def cast(target_type: PsAbstractType, arg: ExprOrConstant): + return Cast(target_type)(ExprOrConstant) diff --git a/src/pystencils/nbackend/kernelcreation/kernelcreation.py b/src/pystencils/nbackend/kernelcreation/kernelcreation.py index f29cd9a13..f95619ed7 100644 --- a/src/pystencils/nbackend/kernelcreation/kernelcreation.py +++ b/src/pystencils/nbackend/kernelcreation/kernelcreation.py @@ -12,6 +12,7 @@ from .iteration_space import ( create_sparse_iteration_space, create_full_iteration_space, ) +from .transformations import EraseAnonymousStructTypes def create_kernel(assignments: AssignmentCollection, options: KernelCreationOptions): @@ -45,6 +46,7 @@ def create_kernel(assignments: AssignmentCollection, options: KernelCreationOpti raise NotImplementedError("Target platform not implemented") kernel_ast = platform.materialize_iteration_space(kernel_body, ispace) + kernel_ast = EraseAnonymousStructTypes(ctx)(kernel_ast) # 7. Apply optimizations # - Vectorization diff --git a/src/pystencils/nbackend/kernelcreation/transformations.py b/src/pystencils/nbackend/kernelcreation/transformations.py new file mode 100644 index 000000000..c01016d12 --- /dev/null +++ b/src/pystencils/nbackend/kernelcreation/transformations.py @@ -0,0 +1,73 @@ +from __future__ import annotations + +from typing import TypeVar + +import pymbolic.primitives as pb +from pymbolic.mapper import IdentityMapper + +from .context import KernelCreationContext + +from ..ast import PsAstNode, PsExpression +from ..arrays import PsArrayAccess, TypeErasedBasePointer +from ..typed_expressions import PsTypedConstant +from ..types import PsStructType, PsPointerType +from ..functions import deref, address_of, Cast + +NodeT = TypeVar("NodeT", bound=PsAstNode) + + +class EraseAnonymousStructTypes(IdentityMapper): + """Lower anonymous struct arrays to a byte-array representation. + + Arrays whose element type is an anonymous struct are transformed to arrays with element type UInt(8). + Lookups on accesses into these arrays are transformed using type casts. + """ + + def __init__(self, ctx: KernelCreationContext) -> None: + self._ctx = ctx + + def __call__(self, node: NodeT) -> NodeT: + match node: + case PsExpression(expr): + # descend into expr + node.expression = self.rec(expr) + case other: + for c in other.children: + self(c) + + return node + + def map_lookup(self, lookup: pb.Lookup) -> pb.Expression: + aggr = lookup.aggregate + if not isinstance(aggr, PsArrayAccess): + return lookup + + arr = aggr.array + if ( + not isinstance(arr.element_type, PsStructType) + or not arr.element_type.anonymous + ): + return lookup + + struct_type = arr.element_type + struct_size = struct_type.itemsize + + bp = aggr.base_ptr + type_erased_bp = TypeErasedBasePointer(bp.name, arr) + base_index = aggr.index_tuple[0] * PsTypedConstant(struct_size, self._ctx.index_dtype) + + member_name = lookup.name + member = struct_type.get_member(member_name) + assert member is not None + + np_struct = struct_type.numpy_dtype + assert np_struct is not None + assert np_struct.fields is not None + member_offset = np_struct.fields[member_name][1] + + byte_index = base_index + PsTypedConstant(member_offset, self._ctx.index_dtype) + type_erased_access = PsArrayAccess(type_erased_bp, byte_index) + + cast = Cast(PsPointerType(member.dtype)) + + return deref(cast(address_of(type_erased_access))) diff --git a/src/pystencils/nbackend/types/basic_types.py b/src/pystencils/nbackend/types/basic_types.py index ce27c6fb3..3f080c313 100644 --- a/src/pystencils/nbackend/types/basic_types.py +++ b/src/pystencils/nbackend/types/basic_types.py @@ -209,9 +209,13 @@ class PsStructType(PsAbstractType): return self._name is None @property - def numpy_dtype(self) -> np.dtype | None: + def numpy_dtype(self) -> np.dtype: members = [(m.name, m.dtype.numpy_dtype) for m in self._members] return np.dtype(members) + + @property + def itemsize(self) -> int: + return self.numpy_dtype.itemsize def c_string(self) -> str: if self._name is None: diff --git a/tests/nbackend/test_basic_printing.py b/tests/nbackend/test_basic_printing.py index 8d9fc6483..7d1966882 100644 --- a/tests/nbackend/test_basic_printing.py +++ b/tests/nbackend/test_basic_printing.py @@ -6,7 +6,7 @@ from pystencils.nbackend.ast import * from pystencils.nbackend.typed_expressions import * from pystencils.nbackend.arrays import PsLinearizedArray, PsArrayBasePointer, PsArrayAccess from pystencils.nbackend.types.quick import * -from pystencils.nbackend.emission import CPrinter +from pystencils.nbackend.emission import CAstPrinter def test_basic_kernel(): @@ -32,7 +32,7 @@ def test_basic_kernel(): func = PsKernelFunction(PsBlock([loop]), target=Target.CPU) - printer = CPrinter() + printer = CAstPrinter() code = printer.print(func) paramlist = func.get_parameters().params diff --git a/tests/nbackend/types/test_types.py b/tests/nbackend/types/test_types.py index 082b39205..06ec7db16 100644 --- a/tests/nbackend/types/test_types.py +++ b/tests/nbackend/types/test_types.py @@ -6,7 +6,9 @@ from pystencils.nbackend.types import * from pystencils.nbackend.types.quick import * -@pytest.mark.parametrize("Type", [PsSignedIntegerType, PsUnsignedIntegerType, PsIeeeFloatType]) +@pytest.mark.parametrize( + "Type", [PsSignedIntegerType, PsUnsignedIntegerType, PsIeeeFloatType] +) def test_widths(Type): for width in Type.SUPPORTED_WIDTHS: assert Type(width).width == width -- GitLab