From 4b5d3d6ff7a5b28083a363242df0cbcb1e8cd258 Mon Sep 17 00:00:00 2001 From: Frederik Hennig <frederik.hennig@fau.de> Date: Mon, 28 Oct 2024 12:27:05 +0100 Subject: [PATCH] Refactor Field Modelling --- docs/source/backend/ast.rst | 15 +- docs/source/backend/index.rst | 3 +- docs/source/backend/objects.rst | 103 +++++++- docs/source/backend/output.rst | 6 + src/pystencils/backend/__init__.py | 8 - src/pystencils/backend/arrays.py | 194 --------------- src/pystencils/backend/ast/analysis.py | 16 +- src/pystencils/backend/ast/expressions.py | 206 +++++++++------- src/pystencils/backend/ast/structural.py | 4 +- src/pystencils/backend/ast/util.py | 12 +- src/pystencils/backend/emission.py | 16 +- src/pystencils/backend/extensions/cpp.py | 2 +- .../backend/jit/cpu_extension_module.py | 44 ++-- src/pystencils/backend/jit/gpu_cupy.py | 64 ++--- .../backend/kernelcreation/ast_factory.py | 4 +- .../backend/kernelcreation/context.py | 228 ++++++++++-------- .../backend/kernelcreation/freeze.py | 19 +- .../backend/kernelcreation/iteration_space.py | 13 +- .../backend/kernelcreation/typification.py | 22 +- src/pystencils/backend/kernelfunction.py | 164 ++++++------- src/pystencils/backend/memory.py | 198 +++++++++++++++ src/pystencils/backend/platforms/cuda.py | 8 +- .../backend/platforms/generic_cpu.py | 14 +- src/pystencils/backend/platforms/sycl.py | 10 +- src/pystencils/backend/platforms/x86.py | 10 +- src/pystencils/backend/properties.py | 41 ++++ src/pystencils/backend/symbols.py | 55 ----- .../backend/transformations/__init__.py | 6 +- .../transformations/canonical_clone.py | 2 +- .../transformations/canonicalize_symbols.py | 2 +- .../transformations/eliminate_constants.py | 2 +- .../erase_anonymous_structs.py | 109 --------- .../hoist_loop_invariant_decls.py | 6 +- .../backend/transformations/lower_to_c.py | 150 ++++++++++++ .../transformations/select_intrinsics.py | 6 +- src/pystencils/boundaries/boundaryhandling.py | 10 +- src/pystencils/kernelcreation.py | 20 +- src/pystencils/types/parsing.py | 4 + src/pystencils/types/types.py | 3 +- tests/nbackend/kernelcreation/test_context.py | 92 ++++--- tests/nbackend/kernelcreation/test_freeze.py | 16 +- .../kernelcreation/test_iteration_space.py | 8 +- .../kernelcreation/test_typification.py | 10 + tests/nbackend/test_ast.py | 91 +++++-- tests/nbackend/test_code_printing.py | 3 +- tests/nbackend/test_cpujit.py | 13 +- tests/nbackend/test_extensions.py | 5 +- tests/nbackend/test_memory.py | 50 ++++ .../test_canonicalize_symbols.py | 2 +- .../test_constant_elimination.py | 2 +- .../transformations/test_lower_to_c.py | 122 ++++++++++ 51 files changed, 1340 insertions(+), 873 deletions(-) create mode 100644 docs/source/backend/output.rst delete mode 100644 src/pystencils/backend/arrays.py create mode 100644 src/pystencils/backend/memory.py create mode 100644 src/pystencils/backend/properties.py delete mode 100644 src/pystencils/backend/symbols.py delete mode 100644 src/pystencils/backend/transformations/erase_anonymous_structs.py create mode 100644 src/pystencils/backend/transformations/lower_to_c.py create mode 100644 tests/nbackend/test_memory.py create mode 100644 tests/nbackend/transformations/test_lower_to_c.py diff --git a/docs/source/backend/ast.rst b/docs/source/backend/ast.rst index 41f230166..44f8f2540 100644 --- a/docs/source/backend/ast.rst +++ b/docs/source/backend/ast.rst @@ -2,29 +2,30 @@ Abstract Syntax Tree ******************** -Inheritance Diagramm -==================== +API Documentation +================= + +Inheritance Diagram +------------------- .. inheritance-diagram:: pystencils.backend.ast.astnode.PsAstNode pystencils.backend.ast.structural pystencils.backend.ast.expressions pystencils.backend.extensions.foreign_ast :top-classes: pystencils.types.PsAstNode :parts: 1 - Base Classes -============ +------------ .. automodule:: pystencils.backend.ast.astnode :members: Structural Nodes -================ +---------------- .. automodule:: pystencils.backend.ast.structural :members: - Expressions -=========== +----------- .. automodule:: pystencils.backend.ast.expressions :members: diff --git a/docs/source/backend/index.rst b/docs/source/backend/index.rst index f2fe9346d..70ed684c6 100644 --- a/docs/source/backend/index.rst +++ b/docs/source/backend/index.rst @@ -15,6 +15,7 @@ who wish to customize or extend the behaviour of the code generator in their app translation platforms transformations + output jit extensions @@ -30,7 +31,7 @@ The IR comprises *symbols*, *constants*, *arrays*, the *iteration space* and the * `PsSymbol` represents a single symbol in the kernel, annotated with a type. Other than in the frontend, uniqueness of symbols is enforced by the backend: of each symbol, at most one instance may exist. * `PsConstant` provides a type-safe representation of constants. -* `PsLinearizedArray` is the backend counterpart to the ubiquitous `Field`, representing a contiguous +* `PsBuffer` is the backend counterpart to the ubiquitous `Field`, representing a contiguous n-dimensional array. These arrays do not occur directly in the IR, but are represented through their *associated symbols*, which are base pointers, shapes, and strides. diff --git a/docs/source/backend/objects.rst b/docs/source/backend/objects.rst index 1b36842b8..11cf8ea5e 100644 --- a/docs/source/backend/objects.rst +++ b/docs/source/backend/objects.rst @@ -2,8 +2,8 @@ Constants and Memory Objects **************************** -Memory Objects: Symbols and Field Arrays -======================================== +Memory Objects: Symbols and Buffers +=================================== The Memory Model ---------------- @@ -12,32 +12,111 @@ In order to reason about memory accesses, mutability, invariance, and aliasing, a very simple memory model. There are three types of memory objects: - Symbols (`PsSymbol`), which act as registers for data storage within the scope of a kernel -- Field arrays (`PsLinearizedArray`), which represent a contiguous block of memory the kernel has access to, and +- Field buffers (`PsBuffer`), which represent a contiguous block of memory the kernel has access to, and - the *unmanaged heap*, which is a global catch-all memory object which all pointers not belonging to a field array point into. All of these objects are disjoint, and cannot alias each other. Each symbol exists in isolation, -field arrays do not overlap, +field buffers do not overlap, and raw pointers are assumed not to point into memory owned by a symbol or field array. Instead, all raw pointers point into unmanaged heap memory, and are assumed to *always* alias one another: Each change brought to unmanaged memory by one raw pointer is assumed to affect the memory pointed to by another raw pointer. -Classes +Symbols ------- -.. autoclass:: pystencils.backend.symbols.PsSymbol - :members: +In the pystencils IR, instances of `PsSymbol` represent what is generally known as "virtual registers". +These are memory locations that are private to a function, cannot be aliased or pointed to, and will finally reside +either in physical registers or on the stack. +Each symbol has a name and a data type. The data type may initially be `None`, in which case it should soon after be +determined by the `Typifier`. -.. automodule:: pystencils.backend.arrays - :members: +Other than their front-end counterpart `sympy.Symbol <sympy.core.symbol.Symbol>`, `PsSymbol` instances are mutable; +their properties can and often will change over time. +As a consequence, they are not comparable by value: +two `PsSymbol` instances with the same name and data type will in general *not* be equal. +In fact, most of the time, it is an error to have two identical symbol instances active. + +Creating Symbols +^^^^^^^^^^^^^^^^ + +During kernel translation, symbols never exist in isolation, but should always be managed by a `KernelCreationContext`. +Symbols can be created and retrieved using `add_symbol <KernelCreationContext.add_symbol>` and `find_symbol <KernelCreationContext.find_symbol>`. +A symbol can also be duplicated using `duplicate_symbol <KernelCreationContext.duplicate_symbol>`, which assigns a new name to the symbol's copy. +The `KernelCreationContext` keeps track of all existing symbols during a kernel translation run +and makes sure that no name and data type conflicts may arise. + +Never call the constructor of `PsSymbol` directly unless you really know what you are doing. + +Symbol Properties +^^^^^^^^^^^^^^^^^ +Symbols can be annotated with arbitrary information using *symbol properties*. +Each symbol property type must be a subclass of `PsSymbolProperty`. +It is strongly recommended to implement property types using frozen +`dataclasses <https://docs.python.org/3/library/dataclasses.html>`_. +For example, this snippet defines a property type that models pointer alignment requirements: -Constants and Literals -====================== +.. code-block:: python + + @dataclass(frozen=True) + class AlignmentProperty(UniqueSymbolProperty) + """Require this pointer symbol to be aligned at a particular byte boundary.""" + + byte_boundary: int + +Inheriting from `UniqueSymbolProperty` ensures that at most one property of this type can be attached to +a symbol at any time. +Properties can be added, queried, and removed using the `PsSymbol` properties API listed below. + +Many symbol properties are more relevant to consumers of generated kernels than to the code generator itself. +The above alignment property, for instance, may be added to a pointer symbol by a vectorization pass +to document its assumption that the pointer be properly aligned, in order to emit aligned load and store instructions. +It then becomes the responsibility of the runtime system embedding the kernel to check this prequesite before calling the kernel. +To make sure this information becomes visible, any properties attached to symbols exposed as kernel parameters will also +be added to their respective `KernelParameter` instance. + +Buffers +------- + +Buffers, as represented by the `PsBuffer` class, represent contiguous, n-dimensional, linearized cuboid blocks of memory. +Each buffer has a fixed name and element data type, +and will be represented in the IR via three sets of symbols: + +- The *base pointer* is a symbol of pointer type which points into the buffer's underlying memory area. + Each buffer has at least one, its primary base pointer, whose pointed-to type must be the same as the + buffer's element type. There may be additional base pointers pointing into subsections of that memory. + These additional base pointers may also have deviating data types, as is for instance required for + type erasure in certain cases. + To communicate its role to the code generation system, + each base pointer needs to be marked as such using the `BufferBasePtr` property, + . +- The buffer *shape* defines the size of the buffer in each dimension. Each shape entry is either a `symbol <PsSymbol>` + or a `constant <PsConstant>`. +- The buffer *strides* define the step size to go from one entry to the next in each dimension. + Like the shape, each stride entry is also either a symbol or a constant. + +The shape and stride symbols must all have the same data type, which will be stored as the buffer's index data type. + +Creating and Managing Buffers +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Similarily to symbols, buffers are typically managed by the `KernelCreationContext`, which associates each buffer +to a front-end `Field`. Buffers for fields can be obtained using `get_buffer <KernelCreationContext.get_buffer>`. +The context makes sure to avoid name conflicts between buffers. + +API Documentation +================= + +.. automodule:: pystencils.backend.properties + :members: + +.. automodule:: pystencils.backend.memory + :members: -.. autoclass:: pystencils.backend.constants.PsConstant +.. automodule:: pystencils.backend.constants :members: .. autoclass:: pystencils.backend.literals.PsLiteral diff --git a/docs/source/backend/output.rst b/docs/source/backend/output.rst new file mode 100644 index 000000000..9875e257b --- /dev/null +++ b/docs/source/backend/output.rst @@ -0,0 +1,6 @@ +********************* +Code Generator Output +********************* + +.. automodule:: pystencils.backend.kernelfunction + :members: diff --git a/src/pystencils/backend/__init__.py b/src/pystencils/backend/__init__.py index a0b1c8f74..b947a112e 100644 --- a/src/pystencils/backend/__init__.py +++ b/src/pystencils/backend/__init__.py @@ -1,9 +1,5 @@ from .kernelfunction import ( KernelParameter, - FieldParameter, - FieldShapeParam, - FieldStrideParam, - FieldPointerParam, KernelFunction, GpuKernelFunction, ) @@ -12,10 +8,6 @@ from .constraints import KernelParamsConstraint __all__ = [ "KernelParameter", - "FieldParameter", - "FieldShapeParam", - "FieldStrideParam", - "FieldPointerParam", "KernelFunction", "GpuKernelFunction", "KernelParamsConstraint", diff --git a/src/pystencils/backend/arrays.py b/src/pystencils/backend/arrays.py deleted file mode 100644 index 9aefeaf62..000000000 --- a/src/pystencils/backend/arrays.py +++ /dev/null @@ -1,194 +0,0 @@ -from __future__ import annotations - -from typing import Sequence -from types import EllipsisType - -from abc import ABC - -from .constants import PsConstant -from ..types import ( - PsType, - PsPointerType, - PsIntegerType, - PsUnsignedIntegerType, -) - -from .symbols import PsSymbol -from ..defaults import DEFAULTS - - -class PsLinearizedArray: - """Class to model N-dimensional contiguous arrays. - - **Memory Layout, Shape and Strides** - - The memory layout of an array is defined by its shape and strides. - Both shape and stride entries may either be constants or special variables associated with - exactly one array. - - Shape and strides may be specified at construction in the following way. - For constant entries, their value must be given as an integer. - For variable shape entries and strides, the Ellipsis `...` must be passed instead. - Internally, the passed ``index_dtype`` will be used to create typed constants (`PsConstant`) - and variables (`PsArrayShapeSymbol` and `PsArrayStrideSymbol`) from the passed values. - """ - - def __init__( - self, - name: str, - element_type: PsType, - shape: Sequence[int | str | EllipsisType], - strides: Sequence[int | str | EllipsisType], - index_dtype: PsIntegerType = DEFAULTS.index_dtype, - ): - self._name = name - self._element_type = element_type - self._index_dtype = index_dtype - - if len(shape) != len(strides): - raise ValueError("Shape and stride tuples must have the same length") - - def make_shape(coord, name_or_val): - match name_or_val: - case EllipsisType(): - return PsArrayShapeSymbol(DEFAULTS.field_shape_name(name, coord), self, coord) - case str(): - return PsArrayShapeSymbol(name_or_val, self, coord) - case _: - return PsConstant(name_or_val, index_dtype) - - self._shape: tuple[PsArrayShapeSymbol | PsConstant, ...] = tuple( - make_shape(i, s) for i, s in enumerate(shape) - ) - - def make_stride(coord, name_or_val): - match name_or_val: - case EllipsisType(): - return PsArrayStrideSymbol(DEFAULTS.field_stride_name(name, coord), self, coord) - case str(): - return PsArrayStrideSymbol(name_or_val, self, coord) - case _: - return PsConstant(name_or_val, index_dtype) - - self._strides: tuple[PsArrayStrideSymbol | PsConstant, ...] = tuple( - make_stride(i, s) for i, s in enumerate(strides) - ) - - self._base_ptr = PsArrayBasePointer(DEFAULTS.field_pointer_name(name), self) - - @property - def name(self): - """The array's name""" - return self._name - - @property - def base_pointer(self) -> PsArrayBasePointer: - """The array's base pointer""" - return self._base_ptr - - @property - def shape(self) -> tuple[PsArrayShapeSymbol | PsConstant, ...]: - """The array's shape, expressed using `PsConstant` and `PsArrayShapeSymbol`""" - return self._shape - - @property - def strides(self) -> tuple[PsArrayStrideSymbol | PsConstant, ...]: - """The array's strides, expressed using `PsConstant` and `PsArrayStrideSymbol`""" - return self._strides - - @property - def index_type(self) -> PsIntegerType: - return self._index_dtype - - @property - def element_type(self) -> PsType: - return self._element_type - - def __repr__(self) -> str: - return ( - f"PsLinearizedArray({self._name}: {self.element_type}[{len(self.shape)}D])" - ) - - -class PsArrayAssocSymbol(PsSymbol, ABC): - """A variable that is associated to an array. - - Instances of this class represent pointers and indexing information bound - to a particular array. - """ - - __match_args__ = ("name", "dtype", "array") - - def __init__(self, name: str, dtype: PsType, array: PsLinearizedArray): - super().__init__(name, dtype) - self._array = array - - @property - def array(self) -> PsLinearizedArray: - return self._array - - -class PsArrayBasePointer(PsArrayAssocSymbol): - def __init__(self, name: str, array: PsLinearizedArray): - dtype = PsPointerType(array.element_type) - super().__init__(name, dtype, array) - - self._array = array - - -class TypeErasedBasePointer(PsArrayBasePointer): - """Base pointer for arrays whose element type has been erased. - - Used primarily for arrays of anonymous structs.""" - - def __init__(self, name: str, array: PsLinearizedArray): - dtype = PsPointerType(PsUnsignedIntegerType(8)) - super(PsArrayBasePointer, self).__init__(name, dtype, array) - - self._array = array - - -class PsArrayShapeSymbol(PsArrayAssocSymbol): - """Variable that represents an array's shape in one coordinate. - - Do not instantiate this class yourself, but only use its instances - as provided by `PsLinearizedArray.shape`. - """ - - __match_args__ = PsArrayAssocSymbol.__match_args__ + ("coordinate",) - - def __init__( - self, - name: str, - array: PsLinearizedArray, - coordinate: int, - ): - super().__init__(name, array.index_type, array) - self._coordinate = coordinate - - @property - def coordinate(self) -> int: - return self._coordinate - - -class PsArrayStrideSymbol(PsArrayAssocSymbol): - """Variable that represents an array's stride in one coordinate. - - Do not instantiate this class yourself, but only use its instances - as provided by `PsLinearizedArray.strides`. - """ - - __match_args__ = PsArrayAssocSymbol.__match_args__ + ("coordinate",) - - def __init__( - self, - name: str, - array: PsLinearizedArray, - coordinate: int, - ): - super().__init__(name, array.index_type, array) - self._coordinate = coordinate - - @property - def coordinate(self) -> int: - return self._coordinate diff --git a/src/pystencils/backend/ast/analysis.py b/src/pystencils/backend/ast/analysis.py index 5b37470cc..3c6d2ef55 100644 --- a/src/pystencils/backend/ast/analysis.py +++ b/src/pystencils/backend/ast/analysis.py @@ -16,7 +16,7 @@ from .structural import ( ) from .expressions import ( PsAdd, - PsArrayAccess, + PsBufferAcc, PsCall, PsConstantExpr, PsDiv, @@ -28,9 +28,11 @@ from .expressions import ( PsSub, PsSymbolExpr, PsTernary, + PsSubscript, + PsMemAcc ) -from ..symbols import PsSymbol +from ..memory import PsSymbol from ..exceptions import PsInternalCompilerError from ...types import PsNumericType @@ -282,8 +284,14 @@ class OperationCounter: case PsSymbolExpr(_) | PsConstantExpr(_) | PsLiteralExpr(_): return OperationCounts() - case PsArrayAccess(_, index): - return self.visit_expr(index) + case PsBufferAcc(_, indices) | PsSubscript(_, indices): + return reduce( + operator.add, + (self.visit_expr(idx) for idx in indices) + ) + + case PsMemAcc(_, offset): + return self.visit_expr(offset) case PsCall(_, args): return OperationCounts(calls=1) + reduce( diff --git a/src/pystencils/backend/ast/expressions.py b/src/pystencils/backend/ast/expressions.py index 151f86c6e..d73b1faa7 100644 --- a/src/pystencils/backend/ast/expressions.py +++ b/src/pystencils/backend/ast/expressions.py @@ -7,16 +7,13 @@ import operator import numpy as np from numpy.typing import NDArray -from ..symbols import PsSymbol +from ..memory import PsSymbol, PsBuffer, BufferBasePtr from ..constants import PsConstant from ..literals import PsLiteral -from ..arrays import PsLinearizedArray, PsArrayBasePointer from ..functions import PsFunction from ...types import ( PsType, - PsScalarType, PsVectorType, - PsTypeError, ) from .util import failing_cast from ..exceptions import PsInternalCompilerError @@ -37,6 +34,10 @@ class PsExpression(PsAstNode, ABC): The type annotations are used by various transformation passes to make decisions, e.g. in function materialization and intrinsic selection. + + .. attention:: + The ``structurally_equal`` check currently does not take expression data types into + account. This may change in the future. """ def __init__(self, dtype: PsType | None = None) -> None: @@ -97,8 +98,26 @@ class PsExpression(PsAstNode, ABC): else: raise ValueError(f"Cannot make expression out of {obj}") + def clone(self): + """Clone this expression. + + .. note:: + Subclasses of `PsExpression` should not override this method, + but implement `_clone_expr` instead. + That implementation shall call `clone` on any of its subexpressions, + but does not need to fix the `dtype` property. + The `dtype` is correctly applied by `PsExpression.clone` internally. + """ + cloned = self._clone_expr() + cloned._dtype = self.dtype + return cloned + @abstractmethod - def clone(self) -> PsExpression: + def _clone_expr(self) -> PsExpression: + """Implementation of expression cloning. + + :meta public: + """ pass @@ -124,7 +143,7 @@ class PsSymbolExpr(PsLeafMixIn, PsLvalue, PsExpression): def symbol(self, symbol: PsSymbol): self._symbol = symbol - def clone(self) -> PsSymbolExpr: + def _clone_expr(self) -> PsSymbolExpr: return PsSymbolExpr(self._symbol) def structurally_equal(self, other: PsAstNode) -> bool: @@ -152,7 +171,7 @@ class PsConstantExpr(PsLeafMixIn, PsExpression): def constant(self, c: PsConstant): self._constant = c - def clone(self) -> PsConstantExpr: + def _clone_expr(self) -> PsConstantExpr: return PsConstantExpr(self._constant) def structurally_equal(self, other: PsAstNode) -> bool: @@ -180,7 +199,7 @@ class PsLiteralExpr(PsLeafMixIn, PsExpression): def literal(self, lit: PsLiteral): self._literal = lit - def clone(self) -> PsLiteralExpr: + def _clone_expr(self) -> PsLiteralExpr: return PsLiteralExpr(self._literal) def structurally_equal(self, other: PsAstNode) -> bool: @@ -193,6 +212,63 @@ class PsLiteralExpr(PsLeafMixIn, PsExpression): return f"PsLiteralExpr({repr(self._literal)})" +class PsBufferAcc(PsLvalue, PsExpression): + """Access into a `PsBuffer`.""" + + __match_args__ = ("base_pointer", "index") + + def __init__(self, base_ptr: PsSymbol, index: Sequence[PsExpression]): + super().__init__() + bptr_prop = cast(BufferBasePtr, base_ptr.get_properties(BufferBasePtr).pop()) + + if len(index) != bptr_prop.buffer.dim: + raise ValueError("Number of index expressions must equal buffer shape.") + + self._base_ptr = PsExpression.make(base_ptr) + self._index = list(index) + self._dtype = bptr_prop.buffer.element_type + + @property + def base_pointer(self) -> PsSymbolExpr: + return self._base_ptr + + @base_pointer.setter + def base_pointer(self, expr: PsSymbolExpr): + bptr_prop = cast(BufferBasePtr, expr.symbol.get_properties(BufferBasePtr).pop()) + if bptr_prop.buffer != self.buffer: + raise ValueError( + "Cannot replace a buffer access's base pointer with one belonging to a different buffer." + ) + + self._base_ptr = expr + + @property + def buffer(self) -> PsBuffer: + return cast( + BufferBasePtr, self._base_ptr.symbol.get_properties(BufferBasePtr).pop() + ).buffer + + @property + def index(self) -> list[PsExpression]: + return self._index + + def get_children(self) -> tuple[PsAstNode, ...]: + return (self._base_ptr,) + tuple(self._index) + + def set_child(self, idx: int, c: PsAstNode): + idx = range(len(self._index) + 1)[idx] + if idx == 0: + self.base_pointer = failing_cast(PsSymbolExpr, c) + else: + self._index[idx - 1] = failing_cast(PsExpression, c) + + def _clone_expr(self) -> PsBufferAcc: + return PsBufferAcc(self._base_ptr.symbol, [i.clone() for i in self._index]) + + def __repr__(self) -> str: + return f"PsBufferAcc({repr(self._base_ptr)}, {repr(self._index)})" + + class PsSubscript(PsLvalue, PsExpression): """N-dimensional subscript into an array.""" @@ -223,7 +299,7 @@ class PsSubscript(PsLvalue, PsExpression): def index(self, idx: Sequence[PsExpression]): self._index = list(idx) - def clone(self) -> PsSubscript: + def _clone_expr(self) -> PsSubscript: return PsSubscript(self._arr.clone(), [i.clone() for i in self._index]) def get_children(self) -> tuple[PsAstNode, ...]: @@ -239,7 +315,7 @@ class PsSubscript(PsLvalue, PsExpression): def __repr__(self) -> str: idx = ", ".join(repr(i) for i in self._index) - return f"PsSubscript({self._arr}, ({idx}))" + return f"PsSubscript({repr(self._arr)}, {repr(idx)})" class PsMemAcc(PsLvalue, PsExpression): @@ -268,7 +344,7 @@ class PsMemAcc(PsLvalue, PsExpression): def offset(self, expr: PsExpression): self._offset = expr - def clone(self) -> PsMemAcc: + def _clone_expr(self) -> PsMemAcc: return PsMemAcc(self._ptr.clone(), self._offset.clone()) def get_children(self) -> tuple[PsAstNode, ...]: @@ -286,83 +362,28 @@ class PsMemAcc(PsLvalue, PsExpression): return f"PsMemAcc({repr(self._ptr)}, {repr(self._offset)})" -class PsArrayAccess(PsMemAcc): - __match_args__ = ("base_ptr", "index") - - def __init__(self, base_ptr: PsArrayBasePointer, index: PsExpression): - super().__init__(PsExpression.make(base_ptr), index) - self._base_ptr = base_ptr - self._dtype = base_ptr.array.element_type - - @property - def base_ptr(self) -> PsArrayBasePointer: - return self._base_ptr +class PsVectorMemAcc(PsMemAcc): + """Pointer-based vectorized memory access.""" - @property - def pointer(self) -> PsExpression: - return self._ptr - - @pointer.setter - def pointer(self, expr: PsExpression): - if not isinstance(expr, PsSymbolExpr) or not isinstance( - expr.symbol, PsArrayBasePointer - ): - raise ValueError( - "Base expression of PsArrayAccess must be an array base pointer" - ) - - self._base_ptr = expr.symbol - self._ptr = expr - - @property - def array(self) -> PsLinearizedArray: - return self._base_ptr.array - - @property - def index(self) -> PsExpression: - return self._offset - - @index.setter - def index(self, expr: PsExpression): - self._offset = expr - - def clone(self) -> PsArrayAccess: - return PsArrayAccess(self._base_ptr, self._offset.clone()) - - def __repr__(self) -> str: - return f"PsArrayAccess({repr(self._base_ptr)}, {repr(self._offset)})" - - -class PsVectorArrayAccess(PsArrayAccess): __match_args__ = ("base_ptr", "base_index") def __init__( self, - base_ptr: PsArrayBasePointer, + base_ptr: PsExpression, base_index: PsExpression, vector_entries: int, stride: int = 1, alignment: int = 0, ): super().__init__(base_ptr, base_index) - element_type = base_ptr.array.element_type - if not isinstance(element_type, PsScalarType): - raise PsTypeError( - "Cannot generate vector accesses to arrays with non-scalar elements" - ) - - self._vector_type = PsVectorType( - element_type, vector_entries, const=element_type.const - ) + self._vector_entries = vector_entries self._stride = stride self._alignment = alignment - self._dtype = self._vector_type - @property def vector_entries(self) -> int: - return self._vector_type.vector_entries + return self._vector_entries @property def stride(self) -> int: @@ -375,9 +396,9 @@ class PsVectorArrayAccess(PsArrayAccess): def get_vector_type(self) -> PsVectorType: return cast(PsVectorType, self._dtype) - def clone(self) -> PsVectorArrayAccess: - return PsVectorArrayAccess( - self._base_ptr, + def _clone_expr(self) -> PsVectorMemAcc: + return PsVectorMemAcc( + self._ptr.clone(), self._offset.clone(), self.vector_entries, self._stride, @@ -385,12 +406,12 @@ class PsVectorArrayAccess(PsArrayAccess): ) def structurally_equal(self, other: PsAstNode) -> bool: - if not isinstance(other, PsVectorArrayAccess): + if not isinstance(other, PsVectorMemAcc): return False return ( super().structurally_equal(other) - and self._vector_type == other._vector_type + and self._vector_entries == other._vector_entries and self._stride == other._stride and self._alignment == other._alignment ) @@ -420,7 +441,7 @@ class PsLookup(PsExpression, PsLvalue): def member_name(self, name: str): self._name = name - def clone(self) -> PsLookup: + def _clone_expr(self) -> PsLookup: return PsLookup(self._aggregate.clone(), self._member_name) def get_children(self) -> tuple[PsAstNode, ...]: @@ -430,6 +451,9 @@ class PsLookup(PsExpression, PsLvalue): idx = [0][idx] self._aggregate = failing_cast(PsExpression, c) + def __repr__(self) -> str: + return f"PsLookup({repr(self._aggregate)}, {repr(self._member_name)})" + class PsCall(PsExpression): __match_args__ = ("function", "args") @@ -470,7 +494,7 @@ class PsCall(PsExpression): self._args = list(exprs) - def clone(self) -> PsCall: + def _clone_expr(self) -> PsCall: return PsCall(self._function, [arg.clone() for arg in self._args]) def get_children(self) -> tuple[PsAstNode, ...]: @@ -514,7 +538,7 @@ class PsTernary(PsExpression): def case_else(self) -> PsExpression: return self._else - def clone(self) -> PsExpression: + def _clone_expr(self) -> PsExpression: return PsTernary(self._cond.clone(), self._then.clone(), self._else.clone()) def get_children(self) -> tuple[PsExpression, ...]: @@ -564,7 +588,7 @@ class PsUnOp(PsExpression): def operand(self, expr: PsExpression): self._operand = expr - def clone(self) -> PsUnOp: + def _clone_expr(self) -> PsUnOp: return type(self)(self._operand.clone()) def get_children(self) -> tuple[PsAstNode, ...]: @@ -591,14 +615,15 @@ class PsNeg(PsUnOp, PsNumericOpTrait): class PsAddressOf(PsUnOp): """Take the address of a memory location. - + .. DANGER:: Taking the address of a memory location owned by a symbol or field array introduces an alias to that memory location. As pystencils assumes its symbols and fields to never be aliased, this can - subtly change the semantics of a kernel. + subtly change the semantics of a kernel. Use the address-of operator with utmost care. """ + pass @@ -617,7 +642,7 @@ class PsCast(PsUnOp): def target_type(self, dtype: PsType): self._target_type = dtype - def clone(self) -> PsUnOp: + def _clone_expr(self) -> PsUnOp: return PsCast(self._target_type, self._operand.clone()) def structurally_equal(self, other: PsAstNode) -> bool: @@ -653,7 +678,7 @@ class PsBinOp(PsExpression): def operand2(self, expr: PsExpression): self._op2 = expr - def clone(self) -> PsBinOp: + def _clone_expr(self) -> PsBinOp: return type(self)(self._op1.clone(), self._op2.clone()) def get_children(self) -> tuple[PsAstNode, ...]: @@ -813,18 +838,21 @@ class PsArrayInitList(PsExpression): __match_args__ = ("items",) - def __init__(self, items: Sequence[PsExpression | Sequence[PsExpression | Sequence[PsExpression]]]): + def __init__( + self, + items: Sequence[PsExpression | Sequence[PsExpression | Sequence[PsExpression]]], + ): super().__init__() self._items = np.array(items, dtype=np.object_) @property def items_grid(self) -> NDArray[np.object_]: return self._items - + @property def shape(self) -> tuple[int, ...]: return self._items.shape - + @property def items(self) -> tuple[PsExpression, ...]: return tuple(self._items.flat) # type: ignore @@ -835,8 +863,8 @@ class PsArrayInitList(PsExpression): def set_child(self, idx: int, c: PsAstNode): self._items.flat[idx] = failing_cast(PsExpression, c) - def clone(self) -> PsExpression: - return PsArrayInitList( + def _clone_expr(self) -> PsExpression: + return PsArrayInitList( np.array([expr.clone() for expr in self.children]).reshape( # type: ignore self._items.shape ) diff --git a/src/pystencils/backend/ast/structural.py b/src/pystencils/backend/ast/structural.py index cd3aae30d..3ae462c41 100644 --- a/src/pystencils/backend/ast/structural.py +++ b/src/pystencils/backend/ast/structural.py @@ -4,7 +4,7 @@ from types import NoneType from .astnode import PsAstNode, PsLeafMixIn from .expressions import PsExpression, PsLvalue, PsSymbolExpr -from ..symbols import PsSymbol +from ..memory import PsSymbol from .util import failing_cast @@ -320,7 +320,7 @@ class PsPragma(PsLeafMixIn, PsEmptyLeafMixIn, PsAstNode): Example usage: ``PsPragma("omp parallel for")`` translates to ``#pragma omp parallel for``. Args: - text: The pragma's text, without the ``#pragma ``. + text: The pragma's text, without the ``#pragma``. """ __match_args__ = ("text",) diff --git a/src/pystencils/backend/ast/util.py b/src/pystencils/backend/ast/util.py index 72aff0a01..288097a90 100644 --- a/src/pystencils/backend/ast/util.py +++ b/src/pystencils/backend/ast/util.py @@ -2,8 +2,8 @@ from __future__ import annotations from typing import Any, TYPE_CHECKING, cast from ..exceptions import PsInternalCompilerError -from ..symbols import PsSymbol -from ..arrays import PsLinearizedArray +from ..memory import PsSymbol +from ..memory import PsBuffer from ...types import PsDereferencableType @@ -47,7 +47,7 @@ class AstEqWrapper: def determine_memory_object( expr: PsExpression, -) -> tuple[PsSymbol | PsLinearizedArray | None, bool]: +) -> tuple[PsSymbol | PsBuffer | None, bool]: """Return the memory object accessed by the given expression, together with its constness Returns: @@ -59,7 +59,7 @@ def determine_memory_object( PsLookup, PsSymbolExpr, PsMemAcc, - PsArrayAccess, + PsBufferAcc, ) while isinstance(expr, (PsSubscript, PsLookup)): @@ -74,9 +74,9 @@ def determine_memory_object( return symb, symb.get_dtype().const case PsMemAcc(ptr, _): return None, cast(PsDereferencableType, ptr.get_dtype()).base_type.const - case PsArrayAccess(ptr, _): + case PsBufferAcc(ptr, _): return ( - expr.array, + expr.buffer, cast(PsDereferencableType, ptr.get_dtype()).base_type.const, ) case _: diff --git a/src/pystencils/backend/emission.py b/src/pystencils/backend/emission.py index 579d47648..6196d69be 100644 --- a/src/pystencils/backend/emission.py +++ b/src/pystencils/backend/emission.py @@ -39,7 +39,7 @@ from .ast.expressions import ( PsSub, PsSymbolExpr, PsLiteralExpr, - PsVectorArrayAccess, + PsVectorMemAcc, PsTernary, PsAnd, PsOr, @@ -50,12 +50,14 @@ from .ast.expressions import ( PsLt, PsGe, PsLe, - PsSubscript + PsSubscript, + PsBufferAcc, ) from .extensions.foreign_ast import PsForeignExpression -from .symbols import PsSymbol +from .exceptions import PsInternalCompilerError +from .memory import PsSymbol from ..types import PsScalarType, PsArrayType from .kernelfunction import KernelFunction, GpuKernelFunction @@ -268,7 +270,7 @@ class CAstPrinter: case PsLiteralExpr(lit): return lit.text - case PsVectorArrayAccess(): + case PsVectorMemAcc(): raise EmissionError("Cannot print vectorized array accesses") case PsMemAcc(base, offset): @@ -386,6 +388,12 @@ class CAstPrinter: foreign_code = node.get_code(self.visit(c, pc) for c in children) pc.pop_op() return foreign_code + + case PsBufferAcc(): + raise PsInternalCompilerError( + f"Unable to print C code for buffer access {node}.\n" + f"Buffer accesses must be lowered using the `LowerToC` pass before emission." + ) case _: raise NotImplementedError(f"Don't know how to print {node}") diff --git a/src/pystencils/backend/extensions/cpp.py b/src/pystencils/backend/extensions/cpp.py index 1055b79e9..025f4a3fb 100644 --- a/src/pystencils/backend/extensions/cpp.py +++ b/src/pystencils/backend/extensions/cpp.py @@ -25,7 +25,7 @@ class CppMethodCall(PsForeignExpression): return super().structurally_equal(other) and self._method == other._method - def clone(self) -> CppMethodCall: + def _clone_expr(self) -> CppMethodCall: return CppMethodCall( cast(PsExpression, self.children[0]), self._method, diff --git a/src/pystencils/backend/jit/cpu_extension_module.py b/src/pystencils/backend/jit/cpu_extension_module.py index b9b793589..d7f644550 100644 --- a/src/pystencils/backend/jit/cpu_extension_module.py +++ b/src/pystencils/backend/jit/cpu_extension_module.py @@ -13,11 +13,8 @@ from ..exceptions import PsInternalCompilerError from ..kernelfunction import ( KernelFunction, KernelParameter, - FieldParameter, - FieldShapeParam, - FieldStrideParam, - FieldPointerParam, ) +from ..properties import FieldBasePtr, FieldShape, FieldStride from ..constraints import KernelParamsConstraint from ...types import ( PsType, @@ -209,7 +206,7 @@ if( !kwargs || !PyDict_Check(kwargs) ) {{ self._array_extractions: dict[Field, str] = dict() self._array_frees: dict[Field, str] = dict() - self._array_assoc_var_extractions: dict[FieldParameter, str] = dict() + self._array_assoc_var_extractions: dict[KernelParameter, str] = dict() self._scalar_extractions: dict[KernelParameter, str] = dict() self._constraint_checks: list[str] = [] @@ -282,31 +279,34 @@ if( !kwargs || !PyDict_Check(kwargs) ) {{ return param.name - def extract_array_assoc_var(self, param: FieldParameter) -> str: + def extract_array_assoc_var(self, param: KernelParameter) -> str: if param not in self._array_assoc_var_extractions: - field = param.field + field = param.fields[0] buffer = self.extract_field(field) - match param: - case FieldPointerParam(): - code = f"{param.dtype} {param.name} = ({param.dtype}) {buffer}.buf;" - case FieldShapeParam(): - coord = param.coordinate - code = f"{param.dtype} {param.name} = {buffer}.shape[{coord}];" - case FieldStrideParam(): - coord = param.coordinate - code = ( - f"{param.dtype} {param.name} = " - f"{buffer}.strides[{coord}] / {field.dtype.itemsize};" - ) - case _: - assert False, "unreachable code" + code: str | None = None + + for prop in param.properties: + match prop: + case FieldBasePtr(): + code = f"{param.dtype} {param.name} = ({param.dtype}) {buffer}.buf;" + break + case FieldShape(_, coord): + code = f"{param.dtype} {param.name} = {buffer}.shape[{coord}];" + break + case FieldStride(_, coord): + code = ( + f"{param.dtype} {param.name} = " + f"{buffer}.strides[{coord}] / {field.dtype.itemsize};" + ) + break + assert code is not None self._array_assoc_var_extractions[param] = code return param.name def extract_parameter(self, param: KernelParameter): - if isinstance(param, FieldParameter): + if param.is_field_parameter: self.extract_array_assoc_var(param) else: self.extract_scalar(param) diff --git a/src/pystencils/backend/jit/gpu_cupy.py b/src/pystencils/backend/jit/gpu_cupy.py index d6aaac2d2..7f38d9d43 100644 --- a/src/pystencils/backend/jit/gpu_cupy.py +++ b/src/pystencils/backend/jit/gpu_cupy.py @@ -16,11 +16,9 @@ from .jit import JitBase, JitError, KernelWrapper from ..kernelfunction import ( KernelFunction, GpuKernelFunction, - FieldPointerParam, - FieldShapeParam, - FieldStrideParam, KernelParameter, ) +from ..properties import FieldShape, FieldStride, FieldBasePtr from ..emission import emit_code from ...types import PsStructType @@ -98,8 +96,8 @@ class CupyKernelWrapper(KernelWrapper): field_shapes = set() index_shapes = set() - def check_shape(field_ptr: FieldPointerParam, arr: cp.ndarray): - field = field_ptr.field + def check_shape(field_ptr: KernelParameter, arr: cp.ndarray): + field = field_ptr.fields[0] if field.has_fixed_shape: expected_shape = tuple(int(s) for s in field.shape) @@ -118,7 +116,7 @@ class CupyKernelWrapper(KernelWrapper): if isinstance(field.dtype, PsStructType): assert expected_strides[-1] == 1 expected_strides = expected_strides[:-1] - + actual_strides = tuple(s // arr.dtype.itemsize for s in arr.strides) if expected_strides != actual_strides: raise ValueError( @@ -149,28 +147,38 @@ class CupyKernelWrapper(KernelWrapper): arr: cp.ndarray for kparam in self._kfunc.parameters: - match kparam: - case FieldPointerParam(_, dtype, field): - arr = kwargs[field.name] - if arr.dtype != field.dtype.numpy_dtype: - raise JitError( - f"Data type mismatch at array argument {field.name}:" - f"Expected {field.dtype}, got {arr.dtype}" - ) - check_shape(kparam, arr) - args.append(arr) - - case FieldShapeParam(name, dtype, field, coord): - arr = kwargs[field.name] - add_arg(name, arr.shape[coord], dtype) - - case FieldStrideParam(name, dtype, field, coord): - arr = kwargs[field.name] - add_arg(name, arr.strides[coord] // arr.dtype.itemsize, dtype) - - case KernelParameter(name, dtype): - val: Any = kwargs[name] - add_arg(name, val, dtype) + if kparam.is_field_parameter: + # Determine field-associated data to pass in + for prop in kparam.properties: + match prop: + case FieldBasePtr(field): + arr = kwargs[field.name] + if arr.dtype != field.dtype.numpy_dtype: + raise JitError( + f"Data type mismatch at array argument {field.name}:" + f"Expected {field.dtype}, got {arr.dtype}" + ) + check_shape(kparam, arr) + args.append(arr) + break + + case FieldShape(field, coord): + arr = kwargs[field.name] + add_arg(kparam.name, arr.shape[coord], kparam.dtype) + break + + case FieldStride(field, coord): + arr = kwargs[field.name] + add_arg( + kparam.name, + arr.strides[coord] // arr.dtype.itemsize, + kparam.dtype, + ) + break + else: + # scalar parameter + val: Any = kwargs[kparam.name] + add_arg(kparam.name, val, kparam.dtype) # Determine launch grid from ..ast.expressions import evaluate_expression diff --git a/src/pystencils/backend/kernelcreation/ast_factory.py b/src/pystencils/backend/kernelcreation/ast_factory.py index 0dc60b1b1..2462e5e66 100644 --- a/src/pystencils/backend/kernelcreation/ast_factory.py +++ b/src/pystencils/backend/kernelcreation/ast_factory.py @@ -8,7 +8,7 @@ from ..ast import PsAstNode from ..ast.expressions import PsExpression, PsSymbolExpr, PsConstantExpr from ..ast.structural import PsLoop, PsBlock, PsAssignment -from ..symbols import PsSymbol +from ..memory import PsSymbol from ..constants import PsConstant from .context import KernelCreationContext @@ -170,6 +170,8 @@ class AstFactory: raise ValueError( "Cannot parse a slice with `stop == None` if no normalization limit is given" ) + + assert stop is not None # for mypy return start, stop, step diff --git a/src/pystencils/backend/kernelcreation/context.py b/src/pystencils/backend/kernelcreation/context.py index 73e3c70cc..839b8fd98 100644 --- a/src/pystencils/backend/kernelcreation/context.py +++ b/src/pystencils/backend/kernelcreation/context.py @@ -9,14 +9,14 @@ from ...defaults import DEFAULTS from ...field import Field, FieldType from ...sympyextensions.typed_sympy import TypedSymbol, DynamicType -from ..symbols import PsSymbol -from ..arrays import PsLinearizedArray +from ..memory import PsSymbol, PsBuffer +from ..properties import FieldShape, FieldStride +from ..constants import PsConstant from ...types import ( PsType, PsIntegerType, PsNumericType, - PsScalarType, - PsStructType, + PsPointerType, deconstify, ) from ..constraints import KernelParamsConstraint @@ -221,64 +221,14 @@ class KernelCreationContext: else: return - arr_shape: list[str | int] | None = None - arr_strides: list[str | int] | None = None - - def normalize_type(s: TypedSymbol) -> PsIntegerType: - match s.dtype: - case DynamicType.INDEX_TYPE: - return self.index_dtype - case DynamicType.NUMERIC_TYPE: - if isinstance(self.default_dtype, PsIntegerType): - return self.default_dtype - else: - raise KernelConstraintsError( - f"Cannot use non-integer default numeric type {self.default_dtype} " - f"in field indexing symbol {s}." - ) - case PsIntegerType(): - return deconstify(s.dtype) - case _: - raise KernelConstraintsError( - f"Invalid data type for field indexing symbol {s}: {s.dtype}" - ) - - # Check field constraints and add to collection + # Check field constraints, create buffer, and add them to the collection match field.field_type: case FieldType.GENERIC | FieldType.STAGGERED | FieldType.STAGGERED_FLUX: + buf = self._create_regular_field_buffer(field) self._fields_collection.domain_fields.add(field) case FieldType.BUFFER: - if field.spatial_dimensions != 1: - raise KernelConstraintsError( - f"Invalid spatial shape of buffer field {field.name}: {field.spatial_dimensions}. " - "Buffer fields must be one-dimensional." - ) - - if field.index_dimensions > 1: - raise KernelConstraintsError( - f"Invalid index shape of buffer field {field.name}: {field.spatial_dimensions}. " - "Buffer fields can have at most one index dimension." - ) - - num_entries = field.index_shape[0] if field.index_shape else 1 - if not isinstance(num_entries, int): - raise KernelConstraintsError( - f"Invalid index shape of buffer field {field.name}: {num_entries}. " - "Buffer fields cannot have variable index shape." - ) - - buffer_len = field.spatial_shape[0] - - if isinstance(buffer_len, TypedSymbol): - idx_type = normalize_type(buffer_len) - arr_shape = [buffer_len.name, num_entries] - else: - idx_type = DEFAULTS.index_dtype - arr_shape = [buffer_len, num_entries] - - arr_strides = [num_entries, 1] - + buf = self._create_buffer_field_buffer(field) self._fields_collection.buffer_fields.add(field) case FieldType.INDEXED: @@ -287,67 +237,24 @@ class KernelCreationContext: f"Invalid spatial shape of index field {field.name}: {field.spatial_dimensions}. " "Index fields must be one-dimensional." ) + buf = self._create_regular_field_buffer(field) self._fields_collection.index_fields.add(field) case FieldType.CUSTOM: + buf = self._create_regular_field_buffer(field) self._fields_collection.custom_fields.add(field) case _: assert False, "unreachable code" - # For non-buffer fields, determine shape and strides - - if arr_shape is None: - idx_types = set( - normalize_type(s) - for s in chain(field.shape, field.strides) - if isinstance(s, TypedSymbol) - ) - - if len(idx_types) > 1: - raise KernelConstraintsError( - f"Multiple incompatible types found in index symbols of field {field}: " - f"{idx_types}" - ) - idx_type = idx_types.pop() if len(idx_types) > 0 else self.index_dtype - - arr_shape = [ - (s.name if isinstance(s, TypedSymbol) else s) for s in field.shape - ] - - arr_strides = [ - (s.name if isinstance(s, TypedSymbol) else s) for s in field.strides - ] - - # The frontend doesn't quite agree with itself on how to model - # fields with trivial index dimensions. Sometimes the index_shape is empty, - # sometimes its (1,). This is canonicalized here. - if not field.index_shape: - arr_shape += [1] - arr_strides += [1] - - # Add array - assert arr_strides is not None - assert idx_type is not None - - assert isinstance(field.dtype, (PsScalarType, PsStructType)) - element_type = field.dtype - - arr = PsLinearizedArray( - field.name, element_type, arr_shape, arr_strides, idx_type - ) - - self._fields_and_arrays[field.name] = FieldArrayPair(field, arr) - for symb in chain([arr.base_pointer], arr.shape, arr.strides): - if isinstance(symb, PsSymbol): - self.add_symbol(symb) + self._fields_and_arrays[field.name] = FieldArrayPair(field, buf) @property - def arrays(self) -> Iterable[PsLinearizedArray]: + def arrays(self) -> Iterable[PsBuffer]: # return self._fields_and_arrays.values() yield from (item.array for item in self._fields_and_arrays.values()) - def get_array(self, field: Field) -> PsLinearizedArray: + def get_buffer(self, field: Field) -> PsBuffer: """Retrieve the underlying array for a given field. If the given field was not previously registered using `add_field`, @@ -393,3 +300,114 @@ class KernelCreationContext: def require_header(self, header: str): self._req_headers.add(header) + + # ----------- Internals --------------------------------------------------------------------- + + def _normalize_type(self, s: TypedSymbol) -> PsIntegerType: + match s.dtype: + case DynamicType.INDEX_TYPE: + return self.index_dtype + case DynamicType.NUMERIC_TYPE: + if isinstance(self.default_dtype, PsIntegerType): + return self.default_dtype + else: + raise KernelConstraintsError( + f"Cannot use non-integer default numeric type {self.default_dtype} " + f"in field indexing symbol {s}." + ) + case PsIntegerType(): + return deconstify(s.dtype) + case _: + raise KernelConstraintsError( + f"Invalid data type for field indexing symbol {s}: {s.dtype}" + ) + + def _create_regular_field_buffer(self, field: Field) -> PsBuffer: + idx_types = set( + self._normalize_type(s) + for s in chain(field.shape, field.strides) + if isinstance(s, TypedSymbol) + ) + + if len(idx_types) > 1: + raise KernelConstraintsError( + f"Multiple incompatible types found in index symbols of field {field}: " + f"{idx_types}" + ) + + idx_type = idx_types.pop() if len(idx_types) > 0 else self.index_dtype + + def convert_size(s: TypedSymbol | int) -> PsSymbol | PsConstant: + if isinstance(s, TypedSymbol): + return self.get_symbol(s.name, idx_type) + else: + return PsConstant(s, idx_type) + + buf_shape = [convert_size(s) for s in field.shape] + buf_strides = [convert_size(s) for s in field.strides] + + # The frontend doesn't quite agree with itself on how to model + # fields with trivial index dimensions. Sometimes the index_shape is empty, + # sometimes its (1,). This is canonicalized here. + if not field.index_shape: + buf_shape += [convert_size(1)] + buf_strides += [convert_size(1)] + + for i, size in enumerate(buf_shape): + if isinstance(size, PsSymbol): + size.add_property(FieldShape(field, i)) + + for i, stride in enumerate(buf_strides): + if isinstance(stride, PsSymbol): + stride.add_property(FieldStride(field, i)) + + base_ptr = self.get_symbol( + DEFAULTS.field_pointer_name(field.name), + PsPointerType(field.dtype, restrict=True), + ) + + return PsBuffer(field.name, field.dtype, base_ptr, buf_shape, buf_strides) + + def _create_buffer_field_buffer(self, field: Field) -> PsBuffer: + if field.spatial_dimensions != 1: + raise KernelConstraintsError( + f"Invalid spatial shape of buffer field {field.name}: {field.spatial_dimensions}. " + "Buffer fields must be one-dimensional." + ) + + if field.index_dimensions > 1: + raise KernelConstraintsError( + f"Invalid index shape of buffer field {field.name}: {field.spatial_dimensions}. " + "Buffer fields can have at most one index dimension." + ) + + num_entries = field.index_shape[0] if field.index_shape else 1 + if not isinstance(num_entries, int): + raise KernelConstraintsError( + f"Invalid index shape of buffer field {field.name}: {num_entries}. " + "Buffer fields cannot have variable index shape." + ) + + buffer_len = field.spatial_shape[0] + buf_shape: list[PsSymbol | PsConstant] + + if isinstance(buffer_len, TypedSymbol): + idx_type = self._normalize_type(buffer_len) + len_symb = self.get_symbol(buffer_len.name, idx_type) + len_symb.add_property(FieldShape(field, 0)) + buf_shape = [len_symb, PsConstant(num_entries, idx_type)] + else: + idx_type = DEFAULTS.index_dtype + buf_shape = [ + PsConstant(buffer_len, idx_type), + PsConstant(num_entries, idx_type), + ] + + buf_strides = [PsConstant(num_entries, idx_type), PsConstant(1, idx_type)] + + base_ptr = self.get_symbol( + DEFAULTS.field_pointer_name(field.name), + PsPointerType(field.dtype, restrict=True), + ) + + return PsBuffer(field.name, field.dtype, base_ptr, buf_shape, buf_strides) diff --git a/src/pystencils/backend/kernelcreation/freeze.py b/src/pystencils/backend/kernelcreation/freeze.py index 0ae5a0d1b..bdc8f1133 100644 --- a/src/pystencils/backend/kernelcreation/freeze.py +++ b/src/pystencils/backend/kernelcreation/freeze.py @@ -28,7 +28,7 @@ from ..ast.structural import ( PsSymbolExpr, ) from ..ast.expressions import ( - PsArrayAccess, + PsBufferAcc, PsArrayInitList, PsBitwiseAnd, PsBitwiseOr, @@ -43,7 +43,7 @@ from ..ast.expressions import ( PsLookup, PsRightShift, PsSubscript, - PsVectorArrayAccess, + PsVectorMemAcc, PsTernary, PsRel, PsEq, @@ -158,7 +158,7 @@ class FreezeExpressions: if isinstance(lhs, PsSymbolExpr): return PsDeclaration(lhs, rhs) - elif isinstance(lhs, (PsArrayAccess, PsLookup, PsVectorArrayAccess)): # todo + elif isinstance(lhs, (PsBufferAcc, PsLookup, PsVectorMemAcc)): # todo return PsAssignment(lhs, rhs) else: raise FreezeError( @@ -309,7 +309,7 @@ class FreezeExpressions: def map_Access(self, access: Field.Access): field = access.field - array = self._ctx.get_array(field) + array = self._ctx.get_buffer(field) ptr = array.base_pointer offsets: list[PsExpression] = [ @@ -363,18 +363,11 @@ class FreezeExpressions: # For canonical representation, there must always be at least one index dimension indices = [PsExpression.make(PsConstant(0))] - summands = tuple( - idx * PsExpression.make(stride) - for idx, stride in zip(offsets + indices, array.strides, strict=True) - ) - - index = summands[0] if len(summands) == 1 else reduce(add, summands) - if struct_member_name is not None: # Produce a Lookup here, don't check yet if the member name is valid. That's the typifier's job. - return PsLookup(PsArrayAccess(ptr, index), struct_member_name) + return PsLookup(PsBufferAcc(ptr, offsets + indices), struct_member_name) else: - return PsArrayAccess(ptr, index) + return PsBufferAcc(ptr, offsets + indices) def map_ConditionalFieldAccess(self, acc: ConditionalFieldAccess): facc = self.visit_expr(acc.access) diff --git a/src/pystencils/backend/kernelcreation/iteration_space.py b/src/pystencils/backend/kernelcreation/iteration_space.py index 8175fffed..bae0328e4 100644 --- a/src/pystencils/backend/kernelcreation/iteration_space.py +++ b/src/pystencils/backend/kernelcreation/iteration_space.py @@ -9,10 +9,9 @@ from ...defaults import DEFAULTS from ...simp import AssignmentCollection from ...field import Field, FieldType -from ..symbols import PsSymbol +from ..memory import PsSymbol, PsBuffer from ..constants import PsConstant from ..ast.expressions import PsExpression, PsConstantExpr, PsTernary, PsEq, PsRem -from ..arrays import PsLinearizedArray from ..ast.util import failing_cast from ...types import PsStructType, constify from ..exceptions import PsInputError, KernelConstraintsError @@ -74,7 +73,7 @@ class FullIterationSpace(IterationSpace): ) -> FullIterationSpace: """Create an iteration space over an archetype field with ghost layers.""" - archetype_array = ctx.get_array(archetype_field) + archetype_array = ctx.get_buffer(archetype_field) dim = archetype_field.spatial_dimensions counters = [ @@ -142,7 +141,7 @@ class FullIterationSpace(IterationSpace): archetype_size: tuple[PsSymbol | PsConstant | None, ...] if archetype_field is not None: - archetype_array = ctx.get_array(archetype_field) + archetype_array = ctx.get_buffer(archetype_field) if archetype_field.spatial_dimensions != dim: raise ValueError( @@ -281,7 +280,7 @@ class SparseIterationSpace(IterationSpace): def __init__( self, spatial_indices: Sequence[PsSymbol], - index_list: PsLinearizedArray, + index_list: PsBuffer, coordinate_members: Sequence[PsStructType.Member], sparse_counter: PsSymbol, ): @@ -291,7 +290,7 @@ class SparseIterationSpace(IterationSpace): self._sparse_counter = sparse_counter @property - def index_list(self) -> PsLinearizedArray: + def index_list(self) -> PsBuffer: return self._index_list @property @@ -365,7 +364,7 @@ def create_sparse_iteration_space( # Determine index field if index_field is not None: - idx_arr = ctx.get_array(index_field) + idx_arr = ctx.get_buffer(index_field) idx_struct_type: PsStructType = failing_cast(PsStructType, idx_arr.element_type) for coord in coord_members: diff --git a/src/pystencils/backend/kernelcreation/typification.py b/src/pystencils/backend/kernelcreation/typification.py index 819d4a12b..c8fad68f1 100644 --- a/src/pystencils/backend/kernelcreation/typification.py +++ b/src/pystencils/backend/kernelcreation/typification.py @@ -23,10 +23,11 @@ from ..ast.structural import ( PsExpression, PsAssignment, PsDeclaration, + PsStatement, PsEmptyLeafMixIn, ) from ..ast.expressions import ( - PsArrayAccess, + PsBufferAcc, PsArrayInitList, PsBinOp, PsIntOpTrait, @@ -301,6 +302,12 @@ class Typifier: for s in statements: self.visit(s) + case PsStatement(expr): + tc = TypeContext() + self.visit_expr(expr, tc) + if tc.target_type is None: + tc.apply_dtype(self._ctx.default_dtype) + case PsDeclaration(lhs, rhs) if isinstance(rhs, PsArrayInitList): # Special treatment for array declarations assert isinstance(lhs, PsSymbolExpr) @@ -337,6 +344,8 @@ class Typifier: decl_tc.apply_dtype( PsArrayType(items_tc.target_type, rhs.shape), rhs ) + else: + decl_tc.infer_dtype(rhs) case PsDeclaration(lhs, rhs) | PsAssignment(lhs, rhs): # Only if the LHS is an untyped symbol, infer its type from the RHS @@ -413,9 +422,10 @@ class Typifier: case PsLiteralExpr(lit): tc.apply_dtype(lit.dtype, expr) - case PsArrayAccess(bptr, idx): - tc.apply_dtype(bptr.array.element_type, expr) - self._handle_idx(idx) + case PsBufferAcc(_, indices): + tc.apply_dtype(expr.buffer.element_type, expr) + for idx in indices: + self._handle_idx(idx) case PsMemAcc(ptr, offset): ptr_tc = TypeContext() @@ -464,7 +474,7 @@ class Typifier: self._handle_idx(idx) case PsAddressOf(arg): - if not isinstance(arg, (PsSymbolExpr, PsSubscript, PsMemAcc, PsLookup)): + if not isinstance(arg, (PsSymbolExpr, PsSubscript, PsMemAcc, PsBufferAcc, PsLookup)): raise TypificationError( f"Illegal expression below AddressOf operator: {arg}" ) @@ -481,7 +491,7 @@ class Typifier: match arg: case PsSymbolExpr(s): pointed_to_type = s.get_dtype() - case PsSubscript(ptr, _) | PsMemAcc(ptr, _): + case PsSubscript(ptr, _) | PsMemAcc(ptr, _) | PsBufferAcc(ptr, _): arr_type = ptr.get_dtype() assert isinstance(arr_type, PsDereferencableType) pointed_to_type = arr_type.base_type diff --git a/src/pystencils/backend/kernelfunction.py b/src/pystencils/backend/kernelfunction.py index a3213350e..9275c55ec 100644 --- a/src/pystencils/backend/kernelfunction.py +++ b/src/pystencils/backend/kernelfunction.py @@ -1,15 +1,21 @@ from __future__ import annotations from warnings import warn -from abc import ABC from typing import Callable, Sequence, Iterable, Any, TYPE_CHECKING +from itertools import chain from .._deprecation import _deprecated from .ast.structural import PsBlock from .ast.analysis import collect_required_headers, collect_undefined_symbols -from .arrays import PsArrayShapeSymbol, PsArrayStrideSymbol, PsArrayBasePointer -from .symbols import PsSymbol +from .memory import PsSymbol +from .properties import ( + PsSymbolProperty, + _FieldProperty, + FieldShape, + FieldStride, + FieldBasePtr, +) from .kernelcreation.context import KernelCreationContext from .platforms import Platform, GpuThreadsRange @@ -25,11 +31,29 @@ if TYPE_CHECKING: class KernelParameter: - __match_args__ = ("name", "dtype") + """Parameter to a `KernelFunction`.""" - def __init__(self, name: str, dtype: PsType): + __match_args__ = ("name", "dtype", "properties") + + def __init__( + self, name: str, dtype: PsType, properties: Iterable[PsSymbolProperty] = () + ): self._name = name self._dtype = dtype + self._properties: frozenset[PsSymbolProperty] = ( + frozenset(properties) if properties is not None else frozenset() + ) + self._fields: tuple[Field, ...] = tuple( + sorted( + set( + p.field # type: ignore + for p in filter( + lambda p: isinstance(p, _FieldProperty), self._properties + ) + ), + key=lambda f: f.name + ) + ) @property def name(self): @@ -40,8 +64,9 @@ class KernelParameter: return self._dtype def _hashable_contents(self): - return (self._name, self._dtype) + return (self._name, self._dtype, self._properties) + # TODO: Need? def __hash__(self) -> int: return hash(self._hashable_contents()) @@ -64,110 +89,63 @@ class KernelParameter: def symbol(self) -> TypedSymbol: return TypedSymbol(self.name, self.dtype) + @property + def fields(self) -> tuple[Field, ...]: + """Set of fields associated with this parameter.""" + return self._fields + + def get_properties( + self, prop_type: type[PsSymbolProperty] | tuple[type[PsSymbolProperty], ...] + ) -> set[PsSymbolProperty]: + """Retrieve all properties of the given type(s) attached to this parameter""" + return set(filter(lambda p: isinstance(p, prop_type), self._properties)) + + @property + def properties(self) -> frozenset[PsSymbolProperty]: + return self._properties + @property def is_field_parameter(self) -> bool: - warn( - "`is_field_parameter` is deprecated and will be removed in a future version of pystencils. " - "Use `isinstance(param, FieldParameter)` instead.", - DeprecationWarning, - ) - return isinstance(self, FieldParameter) + return bool(self._fields) + + # Deprecated legacy properties + # These are kept mostly for the legacy waLBerla code generation system @property def is_field_pointer(self) -> bool: warn( "`is_field_pointer` is deprecated and will be removed in a future version of pystencils. " - "Use `isinstance(param, FieldPointerParam)` instead.", + "Use `param.get_properties(FieldBasePtr)` instead.", DeprecationWarning, ) - return isinstance(self, FieldPointerParam) + return bool(self.get_properties(FieldBasePtr)) @property def is_field_stride(self) -> bool: warn( "`is_field_stride` is deprecated and will be removed in a future version of pystencils. " - "Use `isinstance(param, FieldStrideParam)` instead.", + "Use `param.get_properties(FieldStride)` instead.", DeprecationWarning, ) - return isinstance(self, FieldStrideParam) + return bool(self.get_properties(FieldStride)) @property def is_field_shape(self) -> bool: warn( "`is_field_shape` is deprecated and will be removed in a future version of pystencils. " - "Use `isinstance(param, FieldShapeParam)` instead.", - DeprecationWarning, - ) - return isinstance(self, FieldShapeParam) - - -class FieldParameter(KernelParameter, ABC): - __match_args__ = KernelParameter.__match_args__ + ("field",) - - def __init__(self, name: str, dtype: PsType, field: Field): - super().__init__(name, dtype) - self._field = field - - @property - def field(self): - return self._field - - @property - def fields(self): - warn( - "`fields` is deprecated and will be removed in a future version of pystencils. " - "In pystencils >= 2.0, field parameters are only associated with a single field." - "Use the `field` property instead.", + "Use `param.get_properties(FieldShape)` instead.", DeprecationWarning, ) - return [self._field] + return bool(self.get_properties(FieldShape)) @property def field_name(self) -> str: warn( "`field_name` is deprecated and will be removed in a future version of pystencils. " - "Use `field.name` instead.", + "Use `param.fields[0].name` instead.", DeprecationWarning, ) - return self._field.name - - def _hashable_contents(self): - return super()._hashable_contents() + (self._field,) - - -class FieldShapeParam(FieldParameter): - __match_args__ = FieldParameter.__match_args__ + ("coordinate",) - - def __init__(self, name: str, dtype: PsType, field: Field, coordinate: int): - super().__init__(name, dtype, field) - self._coordinate = coordinate - - @property - def coordinate(self): - return self._coordinate - - def _hashable_contents(self): - return super()._hashable_contents() + (self._coordinate,) - - -class FieldStrideParam(FieldParameter): - __match_args__ = FieldParameter.__match_args__ + ("coordinate",) - - def __init__(self, name: str, dtype: PsType, field: Field, coordinate: int): - super().__init__(name, dtype, field) - self._coordinate = coordinate - - @property - def coordinate(self): - return self._coordinate - - def _hashable_contents(self): - return super()._hashable_contents() + (self._coordinate,) - - -class FieldPointerParam(FieldParameter): - def __init__(self, name: str, dtype: PsType, field: Field): - super().__init__(name, dtype, field) + return self._fields[0].name class KernelFunction: @@ -236,7 +214,7 @@ class KernelFunction: return self.parameters def get_fields(self) -> set[Field]: - return set(p.field for p in self._params if isinstance(p, FieldParameter)) + return set(chain.from_iterable(p.fields for p in self._params)) @property def fields_accessed(self) -> set[Field]: @@ -333,19 +311,19 @@ def create_gpu_kernel_function( def _get_function_params(ctx: KernelCreationContext, symbols: Iterable[PsSymbol]): params: list[KernelParameter] = [] + + from pystencils.backend.memory import BufferBasePtr + for symb in symbols: - match symb: - case PsArrayShapeSymbol(name, _, arr, coord): - field = ctx.find_field(arr.name) - params.append(FieldShapeParam(name, symb.get_dtype(), field, coord)) - case PsArrayStrideSymbol(name, _, arr, coord): - field = ctx.find_field(arr.name) - params.append(FieldStrideParam(name, symb.get_dtype(), field, coord)) - case PsArrayBasePointer(name, _, arr): - field = ctx.find_field(arr.name) - params.append(FieldPointerParam(name, symb.get_dtype(), field)) - case PsSymbol(name, _): - params.append(KernelParameter(name, symb.get_dtype())) + props: set[PsSymbolProperty] = set() + for prop in symb.properties: + match prop: + case FieldShape() | FieldStride(): + props.add(prop) + case BufferBasePtr(buf): + field = ctx.find_field(buf.name) + props.add(FieldBasePtr(field)) + params.append(KernelParameter(symb.name, symb.get_dtype(), props)) params.sort(key=lambda p: p.name) return params diff --git a/src/pystencils/backend/memory.py b/src/pystencils/backend/memory.py new file mode 100644 index 000000000..9b72a4e43 --- /dev/null +++ b/src/pystencils/backend/memory.py @@ -0,0 +1,198 @@ +from __future__ import annotations +from typing import Sequence +from itertools import chain +from dataclasses import dataclass + +from ..types import PsType, PsTypeError, deconstify, PsIntegerType, PsPointerType +from .exceptions import PsInternalCompilerError +from .constants import PsConstant +from .properties import PsSymbolProperty, UniqueSymbolProperty + + +class PsSymbol: + """A mutable symbol with name and data type. + + Do not create objects of this class directly unless you know what you are doing; + instead obtain them from a `KernelCreationContext` through `KernelCreationContext.get_symbol`. + This way, the context can keep track of all symbols used in the translation run, + and uniqueness of symbols is ensured. + """ + + __match_args__ = ("name", "dtype") + + def __init__(self, name: str, dtype: PsType | None = None): + self._name = name + self._dtype = dtype + self._properties: set[PsSymbolProperty] = set() + + @property + def name(self) -> str: + return self._name + + @property + def dtype(self) -> PsType | None: + return self._dtype + + @dtype.setter + def dtype(self, value: PsType): + self._dtype = value + + def apply_dtype(self, dtype: PsType): + """Apply the given data type to this symbol, + raising a TypeError if it conflicts with a previously set data type.""" + + if self._dtype is not None and self._dtype != dtype: + raise PsTypeError( + f"Incompatible symbol data types: {self._dtype} and {dtype}" + ) + + self._dtype = dtype + + def get_dtype(self) -> PsType: + if self._dtype is None: + raise PsInternalCompilerError( + f"Symbol {self.name} had no type assigned yet" + ) + return self._dtype + + @property + def properties(self) -> frozenset[PsSymbolProperty]: + """Set of properties attached to this symbol""" + return frozenset(self._properties) + + def get_properties( + self, prop_type: type[PsSymbolProperty] + ) -> set[PsSymbolProperty]: + """Retrieve all properties of the given type attached to this symbol""" + return set(filter(lambda p: isinstance(p, prop_type), self._properties)) + + def add_property(self, property: PsSymbolProperty): + """Attach a property to this symbol""" + if isinstance(property, UniqueSymbolProperty) and not self.get_properties( + type(property) + ) <= {property}: + raise ValueError( + f"Cannot add second instance of unique property {type(property)} to symbol {self._name}." + ) + + self._properties.add(property) + + def remove_property(self, property: PsSymbolProperty): + """Remove a property from this symbol. Does nothing if the property is not attached.""" + self._properties.discard(property) + + def __str__(self) -> str: + dtype_str = "<untyped>" if self._dtype is None else str(self._dtype) + return f"{self._name}: {dtype_str}" + + def __repr__(self) -> str: + return f"PsSymbol({repr(self._name)}, {repr(self._dtype)})" + + +@dataclass(frozen=True) +class BufferBasePtr(UniqueSymbolProperty): + """Symbol acts as a base pointer to a buffer.""" + + buffer: PsBuffer + + +class PsBuffer: + """N-dimensional contiguous linearized buffer in heap memory. + + `PsBuffer` models the memory buffers underlying the `Field` class + to the backend. Each buffer represents a contiguous block of memory + that is non-aliased and disjoint from all other buffers. + + Buffer shape and stride information are given either as constants or as symbols. + All indexing expressions must have the same data type, which will be selected as the buffer's + `index_dtype`. + + Each buffer has at least one base pointer, which can be retrieved via the `PsBuffer.base_pointer` + property. + """ + + def __init__( + self, + name: str, + element_type: PsType, + base_ptr: PsSymbol, + shape: Sequence[PsSymbol | PsConstant], + strides: Sequence[PsSymbol | PsConstant], + ): + bptr_type = base_ptr.get_dtype() + + if not isinstance(bptr_type, PsPointerType): + raise ValueError( + f"Type of buffer base pointer {base_ptr} was not a pointer type: {bptr_type}" + ) + + if bptr_type.base_type != element_type: + raise ValueError( + f"Base type of primary buffer base pointer {base_ptr} " + f"did not equal buffer element type {element_type}." + ) + + if len(shape) != len(strides): + raise ValueError("Buffer shape and stride tuples must have the same length") + + idx_types: set[PsType] = set( + deconstify(s.get_dtype()) for s in chain(shape, strides) + ) + if len(idx_types) > 1: + raise ValueError( + f"Conflicting data types in indexing symbols to buffer {name}: {idx_types}" + ) + + idx_dtype = idx_types.pop() + if not isinstance(idx_dtype, PsIntegerType): + raise ValueError( + f"Invalid index data type for buffer {name}: {idx_dtype}. Must be an integer type." + ) + + self._name = name + self._element_type = element_type + self._index_dtype = idx_dtype + + self._shape = tuple(shape) + self._strides = tuple(strides) + + base_ptr.add_property(BufferBasePtr(self)) + self._base_ptr = base_ptr + + @property + def name(self): + """The buffer's name""" + return self._name + + @property + def base_pointer(self) -> PsSymbol: + """Primary base pointer""" + return self._base_ptr + + @property + def shape(self) -> tuple[PsSymbol | PsConstant, ...]: + """Buffer shape symbols and/or constants""" + return self._shape + + @property + def strides(self) -> tuple[PsSymbol | PsConstant, ...]: + """Buffer stride symbols and/or constants""" + return self._strides + + @property + def dim(self) -> int: + """Dimensionality of this buffer""" + return len(self._shape) + + @property + def index_type(self) -> PsIntegerType: + """Index data type of this buffer; i.e. data type of its shape and stride symbols""" + return self._index_dtype + + @property + def element_type(self) -> PsType: + """Element type of this buffer""" + return self._element_type + + def __repr__(self) -> str: + return f"PsBuffer({self._name}: {self.element_type}[{len(self.shape)}D])" diff --git a/src/pystencils/backend/platforms/cuda.py b/src/pystencils/backend/platforms/cuda.py index c89d22788..323dcc5a9 100644 --- a/src/pystencils/backend/platforms/cuda.py +++ b/src/pystencils/backend/platforms/cuda.py @@ -7,6 +7,7 @@ from ..kernelcreation import ( IterationSpace, FullIterationSpace, SparseIterationSpace, + AstFactory ) from ..kernelcreation.context import KernelCreationContext @@ -17,7 +18,7 @@ from ..ast.expressions import ( PsCast, PsCall, PsLookup, - PsArrayAccess, + PsBufferAcc, ) from ..ast.expressions import PsLt, PsAnd from ...types import PsSignedIntegerType, PsIeeeFloatType @@ -159,6 +160,7 @@ class CudaPlatform(GenericGpu): def _prepend_sparse_translation( self, body: PsBlock, ispace: SparseIterationSpace ) -> tuple[PsBlock, GpuThreadsRange]: + factory = AstFactory(self._ctx) ispace.sparse_counter.dtype = constify(ispace.sparse_counter.get_dtype()) sparse_ctr = PsExpression.make(ispace.sparse_counter) @@ -171,9 +173,9 @@ class CudaPlatform(GenericGpu): PsDeclaration( PsExpression.make(ctr), PsLookup( - PsArrayAccess( + PsBufferAcc( ispace.index_list.base_pointer, - sparse_ctr, + (sparse_ctr, factory.parse_index(0)), ), coord.name, ), diff --git a/src/pystencils/backend/platforms/generic_cpu.py b/src/pystencils/backend/platforms/generic_cpu.py index a1505e672..f8cae89fc 100644 --- a/src/pystencils/backend/platforms/generic_cpu.py +++ b/src/pystencils/backend/platforms/generic_cpu.py @@ -21,8 +21,8 @@ from ..ast.structural import PsDeclaration, PsLoop, PsBlock from ..ast.expressions import ( PsSymbolExpr, PsExpression, - PsArrayAccess, - PsVectorArrayAccess, + PsBufferAcc, + PsVectorMemAcc, PsLookup, PsGe, PsLe, @@ -124,13 +124,15 @@ class GenericCpu(Platform): return PsBlock([loops]) def _create_sparse_loop(self, body: PsBlock, ispace: SparseIterationSpace): + factory = AstFactory(self._ctx) + mappings = [ PsDeclaration( PsSymbolExpr(ctr), PsLookup( - PsArrayAccess( + PsBufferAcc( ispace.index_list.base_pointer, - PsExpression.make(ispace.sparse_counter), + (PsExpression.make(ispace.sparse_counter), factory.parse_index(0)), ), coord.name, ), @@ -173,11 +175,11 @@ class GenericVectorCpu(GenericCpu, ABC): or raise an `MaterializationError` if not supported.""" @abstractmethod - def vector_load(self, acc: PsVectorArrayAccess) -> PsExpression: + def vector_load(self, acc: PsVectorMemAcc) -> PsExpression: """Return an expression intrinsically performing a vector load, or raise an `MaterializationError` if not supported.""" @abstractmethod - def vector_store(self, acc: PsVectorArrayAccess, arg: PsExpression) -> PsExpression: + def vector_store(self, acc: PsVectorMemAcc, arg: PsExpression) -> PsExpression: """Return an expression intrinsically performing a vector store, or raise an `MaterializationError` if not supported.""" diff --git a/src/pystencils/backend/platforms/sycl.py b/src/pystencils/backend/platforms/sycl.py index fa42ed021..ec5e7eda0 100644 --- a/src/pystencils/backend/platforms/sycl.py +++ b/src/pystencils/backend/platforms/sycl.py @@ -16,11 +16,11 @@ from ..ast.expressions import ( PsLe, PsTernary, PsLookup, - PsArrayAccess + PsBufferAcc ) from ..extensions.cpp import CppMethodCall -from ..kernelcreation.context import KernelCreationContext +from ..kernelcreation import KernelCreationContext, AstFactory from ..constants import PsConstant from .generic_gpu import GenericGpu, GpuThreadsRange from ..exceptions import MaterializationError @@ -147,6 +147,8 @@ class SyclPlatform(GenericGpu): def _prepend_sparse_translation( self, body: PsBlock, ispace: SparseIterationSpace ) -> tuple[PsBlock, GpuThreadsRange]: + factory = AstFactory(self._ctx) + id_type = PsCustomType("sycl::id< 1 >", const=True) id_symbol = PsExpression.make(self._ctx.get_symbol("id", id_type)) @@ -163,9 +165,9 @@ class SyclPlatform(GenericGpu): PsDeclaration( PsExpression.make(ctr), PsLookup( - PsArrayAccess( + PsBufferAcc( ispace.index_list.base_pointer, - sparse_ctr, + (sparse_ctr, factory.parse_index(0)), ), coord.name, ), diff --git a/src/pystencils/backend/platforms/x86.py b/src/pystencils/backend/platforms/x86.py index 0c6f6883d..33838df08 100644 --- a/src/pystencils/backend/platforms/x86.py +++ b/src/pystencils/backend/platforms/x86.py @@ -5,7 +5,7 @@ from typing import Sequence from ..ast.expressions import ( PsExpression, - PsVectorArrayAccess, + PsVectorMemAcc, PsAddressOf, PsMemAcc, ) @@ -141,20 +141,20 @@ class X86VectorCpu(GenericVectorCpu): func = _x86_op_intrin(self._vector_arch, op, vtype) return func(*args) - def vector_load(self, acc: PsVectorArrayAccess) -> PsExpression: + def vector_load(self, acc: PsVectorMemAcc) -> PsExpression: if acc.stride == 1: load_func = _x86_packed_load(self._vector_arch, acc.dtype, False) return load_func( - PsAddressOf(PsMemAcc(PsExpression.make(acc.base_ptr), acc.index)) + PsAddressOf(PsMemAcc(acc.pointer, acc.offset)) ) else: raise NotImplementedError("Gather loads not implemented yet.") - def vector_store(self, acc: PsVectorArrayAccess, arg: PsExpression) -> PsExpression: + def vector_store(self, acc: PsVectorMemAcc, arg: PsExpression) -> PsExpression: if acc.stride == 1: store_func = _x86_packed_store(self._vector_arch, acc.dtype, False) return store_func( - PsAddressOf(PsMemAcc(PsExpression.make(acc.base_ptr), acc.index)), + PsAddressOf(PsMemAcc(acc.pointer, acc.offset)), arg, ) else: diff --git a/src/pystencils/backend/properties.py b/src/pystencils/backend/properties.py new file mode 100644 index 000000000..d377fb3d3 --- /dev/null +++ b/src/pystencils/backend/properties.py @@ -0,0 +1,41 @@ +from __future__ import annotations +from dataclasses import dataclass + +from ..field import Field + + +@dataclass(frozen=True) +class PsSymbolProperty: + """Base class for symbol properties, which can be used to add additional information to symbols""" + + +@dataclass(frozen=True) +class UniqueSymbolProperty(PsSymbolProperty): + """Base class for unique properties, of which only one instance may be registered at a time.""" + + +@dataclass(frozen=True) +class FieldShape(PsSymbolProperty): + """Symbol acts as a shape parameter to a field.""" + + field: Field + coordinate: int + + +@dataclass(frozen=True) +class FieldStride(PsSymbolProperty): + """Symbol acts as a stride parameter to a field.""" + + field: Field + coordinate: int + + +@dataclass(frozen=True) +class FieldBasePtr(UniqueSymbolProperty): + """Symbol acts as a base pointer to a field.""" + + field: Field + + +FieldProperty = FieldShape | FieldStride | FieldBasePtr +_FieldProperty = (FieldShape, FieldStride, FieldBasePtr) diff --git a/src/pystencils/backend/symbols.py b/src/pystencils/backend/symbols.py deleted file mode 100644 index b007e3fcf..000000000 --- a/src/pystencils/backend/symbols.py +++ /dev/null @@ -1,55 +0,0 @@ -from ..types import PsType, PsTypeError -from .exceptions import PsInternalCompilerError - - -class PsSymbol: - """A mutable symbol with name and data type. - - Do not create objects of this class directly unless you know what you are doing; - instead obtain them from a `KernelCreationContext` through `KernelCreationContext.get_symbol`. - This way, the context can keep track of all symbols used in the translation run, - and uniqueness of symbols is ensured. - """ - - __match_args__ = ("name", "dtype") - - def __init__(self, name: str, dtype: PsType | None = None): - self._name = name - self._dtype = dtype - - @property - def name(self) -> str: - return self._name - - @property - def dtype(self) -> PsType | None: - return self._dtype - - @dtype.setter - def dtype(self, value: PsType): - self._dtype = value - - def apply_dtype(self, dtype: PsType): - """Apply the given data type to this symbol, - raising a TypeError if it conflicts with a previously set data type.""" - - if self._dtype is not None and self._dtype != dtype: - raise PsTypeError( - f"Incompatible symbol data types: {self._dtype} and {dtype}" - ) - - self._dtype = dtype - - def get_dtype(self) -> PsType: - if self._dtype is None: - raise PsInternalCompilerError( - f"Symbol {self.name} had no type assigned yet" - ) - return self._dtype - - def __str__(self) -> str: - dtype_str = "<untyped>" if self._dtype is None else str(self._dtype) - return f"{self._name}: {dtype_str}" - - def __repr__(self) -> str: - return f"PsSymbol({self._name}, {self._dtype})" diff --git a/src/pystencils/backend/transformations/__init__.py b/src/pystencils/backend/transformations/__init__.py index 88ad9348f..7375af618 100644 --- a/src/pystencils/backend/transformations/__init__.py +++ b/src/pystencils/backend/transformations/__init__.py @@ -69,7 +69,7 @@ Loop Reshaping Transformations Code Lowering and Materialization --------------------------------- -.. autoclass:: EraseAnonymousStructTypes +.. autoclass:: LowerToC :members: __call__ .. autoclass:: SelectFunctions @@ -84,7 +84,7 @@ from .eliminate_branches import EliminateBranches from .hoist_loop_invariant_decls import HoistLoopInvariantDeclarations from .reshape_loops import ReshapeLoops from .add_pragmas import InsertPragmasAtLoops, LoopPragma, AddOpenMP -from .erase_anonymous_structs import EraseAnonymousStructTypes +from .lower_to_c import LowerToC from .select_functions import SelectFunctions from .select_intrinsics import MaterializeVectorIntrinsics @@ -98,7 +98,7 @@ __all__ = [ "InsertPragmasAtLoops", "LoopPragma", "AddOpenMP", - "EraseAnonymousStructTypes", + "LowerToC", "SelectFunctions", "MaterializeVectorIntrinsics", ] diff --git a/src/pystencils/backend/transformations/canonical_clone.py b/src/pystencils/backend/transformations/canonical_clone.py index b21fd115f..2cf9bcf0c 100644 --- a/src/pystencils/backend/transformations/canonical_clone.py +++ b/src/pystencils/backend/transformations/canonical_clone.py @@ -1,7 +1,7 @@ from typing import TypeVar, cast from ..kernelcreation import KernelCreationContext -from ..symbols import PsSymbol +from ..memory import PsSymbol from ..exceptions import PsInternalCompilerError from ..ast import PsAstNode diff --git a/src/pystencils/backend/transformations/canonicalize_symbols.py b/src/pystencils/backend/transformations/canonicalize_symbols.py index e55807ef4..f5b356432 100644 --- a/src/pystencils/backend/transformations/canonicalize_symbols.py +++ b/src/pystencils/backend/transformations/canonicalize_symbols.py @@ -1,5 +1,5 @@ from ..kernelcreation import KernelCreationContext -from ..symbols import PsSymbol +from ..memory import PsSymbol from ..exceptions import PsInternalCompilerError from ..ast import PsAstNode diff --git a/src/pystencils/backend/transformations/eliminate_constants.py b/src/pystencils/backend/transformations/eliminate_constants.py index bd3b2bb58..222f4a378 100644 --- a/src/pystencils/backend/transformations/eliminate_constants.py +++ b/src/pystencils/backend/transformations/eliminate_constants.py @@ -34,7 +34,7 @@ from ..ast.expressions import ( from ..ast.util import AstEqWrapper from ..constants import PsConstant -from ..symbols import PsSymbol +from ..memory import PsSymbol from ..functions import PsMathFunction from ...types import ( PsIntegerType, diff --git a/src/pystencils/backend/transformations/erase_anonymous_structs.py b/src/pystencils/backend/transformations/erase_anonymous_structs.py deleted file mode 100644 index 7404abd94..000000000 --- a/src/pystencils/backend/transformations/erase_anonymous_structs.py +++ /dev/null @@ -1,109 +0,0 @@ -from __future__ import annotations - -from ..kernelcreation.context import KernelCreationContext - -from ..constants import PsConstant -from ..ast.structural import PsAstNode -from ..ast.expressions import ( - PsArrayAccess, - PsLookup, - PsExpression, - PsMemAcc, - PsAddressOf, - PsCast, -) -from ..kernelcreation import Typifier -from ..arrays import PsArrayBasePointer, TypeErasedBasePointer -from ...types import PsStructType, PsPointerType - - -class EraseAnonymousStructTypes: - """Lower anonymous struct arrays to a byte-array representation. - - For arrays whose element type is an anonymous struct, the struct type is erased from the base pointer, - making it a pointer to uint8_t. - Member lookups on accesses into these arrays are then transformed using type casts. - """ - - def __init__(self, ctx: KernelCreationContext) -> None: - self._ctx = ctx - - self._substitutions: dict[PsArrayBasePointer, TypeErasedBasePointer] = dict() - - def __call__(self, node: PsAstNode) -> PsAstNode: - self._substitutions = dict() - - # Check if AST traversal is even necessary - if not any( - (isinstance(arr.element_type, PsStructType) and arr.element_type.anonymous) - for arr in self._ctx.arrays - ): - return node - - node = self.visit(node) - - for old, new in self._substitutions.items(): - self._ctx.replace_symbol(old, new) - - return node - - def visit(self, node: PsAstNode) -> PsAstNode: - match node: - case PsLookup(): - # descend into expr - return self.handle_lookup(node) - case _: - node.children = [self.visit(c) for c in node.children] - - return node - - def handle_lookup(self, lookup: PsLookup) -> PsExpression: - aggr = lookup.aggregate - if not isinstance(aggr, PsArrayAccess): - return lookup - - arr = aggr.array - if ( - not isinstance(arr.element_type, PsStructType) - or not arr.element_type.anonymous - ): - return lookup - - struct_type = arr.element_type - struct_size = struct_type.itemsize - - bp = aggr.base_ptr - - # Need to keep track of base pointers already seen, since symbols must be unique - if bp not in self._substitutions: - type_erased_bp = TypeErasedBasePointer(bp.name, arr) - self._substitutions[bp] = type_erased_bp - else: - type_erased_bp = self._substitutions[bp] - - base_index = aggr.index * PsExpression.make( - PsConstant(struct_size, self._ctx.index_dtype) - ) - - member_name = lookup.member_name - member = struct_type.find_member(member_name) - assert member is not None - - np_struct = struct_type.numpy_dtype - assert np_struct is not None - assert np_struct.fields is not None - member_offset = np_struct.fields[member_name][1] - - byte_index = base_index + PsExpression.make( - PsConstant(member_offset, self._ctx.index_dtype) - ) - type_erased_access = PsArrayAccess(type_erased_bp, byte_index) - - deref = PsMemAcc( - PsCast(PsPointerType(member.dtype), PsAddressOf(type_erased_access)), - PsExpression.make(PsConstant(0)) - ) - - typify = Typifier(self._ctx) - deref = typify(deref) - return deref diff --git a/src/pystencils/backend/transformations/hoist_loop_invariant_decls.py b/src/pystencils/backend/transformations/hoist_loop_invariant_decls.py index d4dfd3d04..f0e4cc9f1 100644 --- a/src/pystencils/backend/transformations/hoist_loop_invariant_decls.py +++ b/src/pystencils/backend/transformations/hoist_loop_invariant_decls.py @@ -9,7 +9,7 @@ from ..ast.expressions import ( PsConstantExpr, PsLiteralExpr, PsCall, - PsArrayAccess, + PsBufferAcc, PsSubscript, PsLookup, PsUnOp, @@ -19,7 +19,7 @@ from ..ast.expressions import ( from ..ast.util import determine_memory_object from ...types import PsDereferencableType -from ..symbols import PsSymbol +from ..memory import PsSymbol from ..functions import PsMathFunction __all__ = ["HoistLoopInvariantDeclarations"] @@ -53,7 +53,7 @@ class HoistContext: case PsSubscript() | PsLookup(): return determine_memory_object(expr)[1] and args_invariant(expr) - case PsArrayAccess(ptr, _): + case PsBufferAcc(ptr, _): # Regular pointer derefs are never invariant, since we cannot reason about aliasing ptr_type = cast(PsDereferencableType, ptr.get_dtype()) return ptr_type.base_type.const and args_invariant(expr) diff --git a/src/pystencils/backend/transformations/lower_to_c.py b/src/pystencils/backend/transformations/lower_to_c.py new file mode 100644 index 000000000..ea832355b --- /dev/null +++ b/src/pystencils/backend/transformations/lower_to_c.py @@ -0,0 +1,150 @@ +from __future__ import annotations + +from typing import cast +from functools import reduce +import operator + +from ..kernelcreation import KernelCreationContext, Typifier + +from ..constants import PsConstant +from ..memory import PsSymbol, PsBuffer, BufferBasePtr +from ..ast.structural import PsAstNode +from ..ast.expressions import ( + PsBufferAcc, + PsLookup, + PsExpression, + PsMemAcc, + PsAddressOf, + PsCast, + PsSymbolExpr, +) +from ...types import PsStructType, PsPointerType, PsUnsignedIntegerType + + +class LowerToC: + """Lower high-level IR constructs to C language concepts. + + This pass will replace a number of IR constructs that have no direct counterpart in the C language + to lower-level AST nodes. These include: + + - *Linearization of Buffer Accesses:* `PsBufferAcc` buffer accesses are linearized according to + their buffers' stride information and replaced by `PsMemAcc`. + - *Erasure of Anonymous Structs:* + For buffers whose element type is an anonymous struct, the struct type is erased from the base pointer, + making it a pointer to uint8_t. + Member lookups on accesses into these buffers are then transformed using type casts. + """ + + def __init__(self, ctx: KernelCreationContext) -> None: + self._ctx = ctx + self._typify = Typifier(ctx) + + self._substitutions: dict[PsSymbol, PsSymbol] = dict() + + def __call__(self, node: PsAstNode) -> PsAstNode: + self._substitutions = dict() + + node = self.visit(node) + + for old, new in self._substitutions.items(): + self._ctx.replace_symbol(old, new) + + return node + + def visit(self, node: PsAstNode) -> PsAstNode: + match node: + case PsBufferAcc(bptr, indices): + # Linearize + buf = node.buffer + + # Typifier allows different data types in each index + def maybe_cast(i: PsExpression): + if i.get_dtype() != buf.index_type: + return PsCast(buf.index_type, i) + else: + return i + + summands: list[PsExpression] = [ + maybe_cast(cast(PsExpression, self.visit(idx))) * PsExpression.make(stride) + for idx, stride in zip(indices, buf.strides, strict=True) + ] + + linearized_idx: PsExpression = ( + summands[0] + if len(summands) == 1 + else reduce(operator.add, summands) + ) + + mem_acc = PsMemAcc(bptr, linearized_idx) + + return self._typify.typify_expression( + mem_acc, target_type=buf.element_type + )[0] + + case PsLookup(aggr, member_name) if isinstance( + aggr, PsBufferAcc + ) and isinstance( + aggr.buffer.element_type, PsStructType + ) and aggr.buffer.element_type.anonymous: + # Need to lower this buffer-lookup + linearized_acc = self.visit(aggr) + return self._lower_anon_lookup( + cast(PsMemAcc, linearized_acc), aggr.buffer, member_name + ) + + case _: + node.children = [self.visit(c) for c in node.children] + + return node + + def _lower_anon_lookup( + self, aggr: PsMemAcc, buf: PsBuffer, member_name: str + ) -> PsExpression: + struct_type = cast(PsStructType, buf.element_type) + struct_size = struct_type.itemsize + + assert isinstance(aggr.pointer, PsSymbolExpr) + bp = aggr.pointer.symbol + bp_type = bp.get_dtype() + assert isinstance(bp_type, PsPointerType) + + # Need to keep track of base pointers already seen, since symbols must be unique + if bp not in self._substitutions: + erased_type = PsPointerType( + PsUnsignedIntegerType(8, const=bp_type.base_type.const), + const=bp_type.const, + restrict=bp_type.restrict, + ) + type_erased_bp = PsSymbol( + bp.name, + erased_type + ) + type_erased_bp.add_property(BufferBasePtr(buf)) + self._substitutions[bp] = type_erased_bp + else: + type_erased_bp = self._substitutions[bp] + + base_index = aggr.offset * PsExpression.make( + PsConstant(struct_size, self._ctx.index_dtype) + ) + + member = struct_type.find_member(member_name) + assert member is not None + + np_struct = struct_type.numpy_dtype + assert np_struct is not None + assert np_struct.fields is not None + member_offset = np_struct.fields[member_name][1] + + byte_index = base_index + PsExpression.make( + PsConstant(member_offset, self._ctx.index_dtype) + ) + type_erased_access = PsMemAcc(PsExpression.make(type_erased_bp), byte_index) + + deref = PsMemAcc( + PsCast(PsPointerType(member.dtype), PsAddressOf(type_erased_access)), + PsExpression.make(PsConstant(0)), + ) + + deref = self._typify(deref) + return deref diff --git a/src/pystencils/backend/transformations/select_intrinsics.py b/src/pystencils/backend/transformations/select_intrinsics.py index 7972de069..3fb484c15 100644 --- a/src/pystencils/backend/transformations/select_intrinsics.py +++ b/src/pystencils/backend/transformations/select_intrinsics.py @@ -6,7 +6,7 @@ from ..ast.structural import PsAstNode, PsAssignment, PsStatement from ..ast.expressions import PsExpression from ...types import PsVectorType, deconstify from ..ast.expressions import ( - PsVectorArrayAccess, + PsVectorMemAcc, PsSymbolExpr, PsConstantExpr, PsBinOp, @@ -66,7 +66,7 @@ class MaterializeVectorIntrinsics: def visit(self, node: PsAstNode) -> PsAstNode: match node: - case PsAssignment(lhs, rhs) if isinstance(lhs, PsVectorArrayAccess): + case PsAssignment(lhs, rhs) if isinstance(lhs, PsVectorMemAcc): vc = VecTypeCtx() vc.set(lhs.get_vector_type()) store_arg = self.visit_expr(rhs, vc) @@ -94,7 +94,7 @@ class MaterializeVectorIntrinsics: else: return expr - case PsVectorArrayAccess(): + case PsVectorMemAcc(): vc.set(expr.get_vector_type()) return self._platform.vector_load(expr) diff --git a/src/pystencils/boundaries/boundaryhandling.py b/src/pystencils/boundaries/boundaryhandling.py index c7657ec51..52ded8ab2 100644 --- a/src/pystencils/boundaries/boundaryhandling.py +++ b/src/pystencils/boundaries/boundaryhandling.py @@ -12,7 +12,7 @@ from pystencils.types import PsIntegerType from pystencils.types.quick import Arr, SInt from pystencils.gpu.gpu_array_handler import GPUArrayHandler from pystencils.field import Field, FieldType -from pystencils.backend.kernelfunction import FieldPointerParam +from pystencils.backend.properties import FieldBasePtr try: # noinspection PyPep8Naming @@ -244,9 +244,9 @@ class BoundaryHandling: for b_obj, idx_arr in b[self._index_array_name].boundary_object_to_index_list.items(): kwargs[self._field_name] = b[self._field_name] kwargs['indexField'] = idx_arr - data_used_in_kernel = (p.field.name + data_used_in_kernel = (p.fields.pop().name for p in self._boundary_object_to_boundary_info[b_obj].kernel.parameters - if isinstance(p, FieldPointerParam) and p.field.name not in kwargs) + if bool(p.get_properties(FieldBasePtr)) and p.fields.pop().name not in kwargs) kwargs.update({name: b[name] for name in data_used_in_kernel}) self._boundary_object_to_boundary_info[b_obj].kernel(**kwargs) @@ -260,9 +260,9 @@ class BoundaryHandling: arguments = kwargs.copy() arguments[self._field_name] = b[self._field_name] arguments['indexField'] = idx_arr - data_used_in_kernel = (p.field.name + data_used_in_kernel = (p.fields.pop().name for p in self._boundary_object_to_boundary_info[b_obj].kernel.parameters - if isinstance(p, FieldPointerParam) and p.field.name not in arguments) + if bool(p.get_properties(FieldBasePtr)) and p.fields.pop().name not in arguments) arguments.update({name: b[name] for name in data_used_in_kernel if name not in arguments}) kernel = self._boundary_object_to_boundary_info[b_obj].kernel diff --git a/src/pystencils/kernelcreation.py b/src/pystencils/kernelcreation.py index ae64bdea3..7d9ac7aa4 100644 --- a/src/pystencils/kernelcreation.py +++ b/src/pystencils/kernelcreation.py @@ -20,8 +20,9 @@ from .backend.kernelcreation.iteration_space import ( from .backend.transformations import ( EliminateConstants, - EraseAnonymousStructTypes, + LowerToC, SelectFunctions, + CanonicalizeSymbols, ) from .backend.kernelfunction import ( create_cpu_kernel_function, @@ -131,7 +132,7 @@ def create_kernel( f"Code generation for target {target} not implemented" ) - # Simplifying transformations + # Fold and extract constants elim_constants = EliminateConstants(ctx, extract_constant_exprs=True) kernel_ast = cast(PsBlock, elim_constants(kernel_ast)) @@ -143,12 +144,23 @@ def create_kernel( kernel_ast = optimize_cpu(ctx, platform, kernel_ast, config.cpu_optim) - erase_anons = EraseAnonymousStructTypes(ctx) - kernel_ast = cast(PsBlock, erase_anons(kernel_ast)) + # Lowering + lower_to_c = LowerToC(ctx) + kernel_ast = cast(PsBlock, lower_to_c(kernel_ast)) select_functions = SelectFunctions(platform) kernel_ast = cast(PsBlock, select_functions(kernel_ast)) + # Late canonicalization and constant elimination passes + # * Since lowering introduces new index calculations and indexing symbols into the AST, + # * these need to be handled here + + canonicalize = CanonicalizeSymbols(ctx, True) + kernel_ast = cast(PsBlock, canonicalize(kernel_ast)) + + late_fold_constants = EliminateConstants(ctx, extract_constant_exprs=False) + kernel_ast = cast(PsBlock, late_fold_constants(kernel_ast)) + if config.target.is_cpu(): return create_cpu_kernel_function( ctx, diff --git a/src/pystencils/types/parsing.py b/src/pystencils/types/parsing.py index 5771eaca8..8e7d27f58 100644 --- a/src/pystencils/types/parsing.py +++ b/src/pystencils/types/parsing.py @@ -8,6 +8,7 @@ from .types import ( PsUnsignedIntegerType, PsSignedIntegerType, PsIeeeFloatType, + PsBoolType, ) UserTypeSpec = str | type | np.dtype | PsType @@ -143,6 +144,9 @@ def parse_type_string(s: str) -> PsType: def parse_type_name(typename: str, const: bool): match typename: + case "bool": + return PsBoolType(const=const) + case "int" | "int64" | "int64_t": return PsSignedIntegerType(64, const=const) case "int32" | "int32_t": diff --git a/src/pystencils/types/types.py b/src/pystencils/types/types.py index e6fc4bb78..d3d18720c 100644 --- a/src/pystencils/types/types.py +++ b/src/pystencils/types/types.py @@ -91,7 +91,8 @@ class PsPointerType(PsDereferencableType): def c_string(self) -> str: base_str = self._base_type.c_string() restrict_str = " RESTRICT" if self._restrict else "" - return f"{base_str} *{restrict_str} {self._const_string()}" + const_str = " const" if self.const else "" + return f"{base_str} *{restrict_str}{const_str}" def __repr__(self) -> str: return f"PsPointerType( {repr(self.base_type)}, const={self.const}, restrict={self.restrict} )" diff --git a/tests/nbackend/kernelcreation/test_context.py b/tests/nbackend/kernelcreation/test_context.py index 9701013b0..384fc9315 100644 --- a/tests/nbackend/kernelcreation/test_context.py +++ b/tests/nbackend/kernelcreation/test_context.py @@ -5,6 +5,8 @@ from pystencils import Field, TypedSymbol, FieldType, DynamicType from pystencils.backend.kernelcreation import KernelCreationContext from pystencils.backend.constants import PsConstant +from pystencils.backend.memory import PsSymbol +from pystencils.backend.properties import FieldShape, FieldStride from pystencils.backend.exceptions import KernelConstraintsError from pystencils.types.quick import SInt, Fp from pystencils.types import deconstify @@ -14,7 +16,7 @@ def test_field_arrays(): ctx = KernelCreationContext(index_dtype=SInt(16)) f = Field.create_generic("f", 3, Fp(32)) - f_arr = ctx.get_array(f) + f_arr = ctx.get_buffer(f) assert f_arr.element_type == f.dtype == Fp(32) assert len(f_arr.shape) == len(f.shape) + 1 == 4 @@ -23,9 +25,17 @@ def test_field_arrays(): assert f_arr.index_type == ctx.index_dtype == SInt(16) assert f_arr.shape[0].dtype == ctx.index_dtype == SInt(16) + for i, s in enumerate(f_arr.shape[:1]): + assert isinstance(s, PsSymbol) + assert FieldShape(f, i) in s.properties + + for i, s in enumerate(f_arr.strides[:1]): + assert isinstance(s, PsSymbol) + assert FieldStride(f, i) in s.properties + g = Field.create_generic("g", 3, index_shape=(2, 4), dtype=Fp(16)) - g_arr = ctx.get_array(g) - + g_arr = ctx.get_buffer(g) + assert g_arr.element_type == g.dtype == Fp(16) assert len(g_arr.shape) == len(g.spatial_shape) + len(g.index_shape) == 5 assert isinstance(g_arr.shape[3], PsConstant) and g_arr.shape[3].value == 2 @@ -39,26 +49,23 @@ def test_field_arrays(): FieldType.GENERIC, Fp(32), (0, 1), - ( - TypedSymbol("nx", SInt(32)), - TypedSymbol("ny", SInt(32)), - 1 - ), - ( - TypedSymbol("sx", SInt(32)), - TypedSymbol("sy", SInt(32)), - 1 - ) - ) - - h_arr = ctx.get_array(h) + (TypedSymbol("nx", SInt(32)), TypedSymbol("ny", SInt(32)), 1), + (TypedSymbol("sx", SInt(32)), TypedSymbol("sy", SInt(32)), 1), + ) + + h_arr = ctx.get_buffer(h) assert h_arr.index_type == SInt(32) - + for s in chain(h_arr.shape, h_arr.strides): assert deconstify(s.get_dtype()) == SInt(32) - assert [s.name for s in chain(h_arr.shape[:2], h_arr.strides[:2])] == ["nx", "ny", "sx", "sy"] + assert [s.name for s in chain(h_arr.shape[:2], h_arr.strides[:2])] == [ + "nx", + "ny", + "sx", + "sy", + ] def test_invalid_fields(): @@ -70,11 +77,11 @@ def test_invalid_fields(): Fp(32), (0,), (TypedSymbol("nx", SInt(32)),), - (TypedSymbol("sx", SInt(64)),) + (TypedSymbol("sx", SInt(64)),), ) - + with pytest.raises(KernelConstraintsError): - _ = ctx.get_array(h) + _ = ctx.get_buffer(h) h = Field( "h", @@ -82,11 +89,11 @@ def test_invalid_fields(): Fp(32), (0,), (TypedSymbol("nx", Fp(32)),), - (TypedSymbol("sx", Fp(32)),) + (TypedSymbol("sx", Fp(32)),), ) - + with pytest.raises(KernelConstraintsError): - _ = ctx.get_array(h) + _ = ctx.get_buffer(h) h = Field( "h", @@ -94,8 +101,39 @@ def test_invalid_fields(): Fp(32), (0,), (TypedSymbol("nx", DynamicType.NUMERIC_TYPE),), - (TypedSymbol("sx", DynamicType.NUMERIC_TYPE),) + (TypedSymbol("sx", DynamicType.NUMERIC_TYPE),), ) - + with pytest.raises(KernelConstraintsError): - _ = ctx.get_array(h) + _ = ctx.get_buffer(h) + + +def test_duplicate_fields(): + f = Field.create_generic("f", 3) + g = f.new_field_with_different_name("g") + + # f and g have the same indexing symbols + assert f.shape == g.shape + assert f.strides == g.strides + + ctx = KernelCreationContext() + + f_buf = ctx.get_buffer(f) + g_buf = ctx.get_buffer(g) + + for sf, sg in zip(chain(f_buf.shape, f_buf.strides), chain(g_buf.shape, g_buf.strides)): + # Must be the same + assert sf == sg + + for i, s in enumerate(f_buf.shape[:-1]): + assert isinstance(s, PsSymbol) + assert FieldShape(f, i) in s.properties + assert FieldShape(g, i) in s.properties + + for i, s in enumerate(f_buf.strides[:-1]): + assert isinstance(s, PsSymbol) + assert FieldStride(f, i) in s.properties + assert FieldStride(g, i) in s.properties + + # Base pointers must be different, though! + assert f_buf.base_pointer != g_buf.base_pointer diff --git a/tests/nbackend/kernelcreation/test_freeze.py b/tests/nbackend/kernelcreation/test_freeze.py index 341b75601..ce4f61785 100644 --- a/tests/nbackend/kernelcreation/test_freeze.py +++ b/tests/nbackend/kernelcreation/test_freeze.py @@ -17,7 +17,7 @@ from pystencils.backend.ast.structural import ( PsDeclaration, ) from pystencils.backend.ast.expressions import ( - PsArrayAccess, + PsBufferAcc, PsBitwiseAnd, PsBitwiseOr, PsBitwiseXor, @@ -106,22 +106,20 @@ def test_freeze_fields(): f, g = fields("f, g : [1D]") asm = Assignment(f.center(0), g.center(0)) - f_arr = ctx.get_array(f) - g_arr = ctx.get_array(g) + f_arr = ctx.get_buffer(f) + g_arr = ctx.get_buffer(g) fasm = freeze(asm) zero = PsExpression.make(PsConstant(0)) - lhs = PsArrayAccess( + lhs = PsBufferAcc( f_arr.base_pointer, - (PsExpression.make(counter) + zero) * PsExpression.make(f_arr.strides[0]) - + zero * one, + (PsExpression.make(counter) + zero, zero) ) - rhs = PsArrayAccess( + rhs = PsBufferAcc( g_arr.base_pointer, - (PsExpression.make(counter) + zero) * PsExpression.make(g_arr.strides[0]) - + zero * one, + (PsExpression.make(counter) + zero, zero) ) should = PsAssignment(lhs, rhs) diff --git a/tests/nbackend/kernelcreation/test_iteration_space.py b/tests/nbackend/kernelcreation/test_iteration_space.py index 8ff678fba..5d56abd2b 100644 --- a/tests/nbackend/kernelcreation/test_iteration_space.py +++ b/tests/nbackend/kernelcreation/test_iteration_space.py @@ -23,7 +23,7 @@ def test_slices_over_field(): islice = (slice(1, -1, 1), slice(3, -3, 3), slice(0, None, 1)) ispace = FullIterationSpace.create_from_slice(ctx, islice, archetype_field) - archetype_arr = ctx.get_array(archetype_field) + archetype_arr = ctx.get_buffer(archetype_field) dims = ispace.dimensions @@ -58,7 +58,7 @@ def test_slices_with_fixed_size_field(): islice = (slice(1, -1, 1), slice(3, -3, 3), slice(0, None, 1)) ispace = FullIterationSpace.create_from_slice(ctx, islice, archetype_field) - archetype_arr = ctx.get_array(archetype_field) + archetype_arr = ctx.get_buffer(archetype_field) dims = ispace.dimensions @@ -87,7 +87,7 @@ def test_singular_slice_over_field(): archetype_field = Field.create_generic("f", spatial_dimensions=2, layout="fzyx") ctx.add_field(archetype_field) - archetype_arr = ctx.get_array(archetype_field) + archetype_arr = ctx.get_buffer(archetype_field) islice = (4, -3) ispace = FullIterationSpace.create_from_slice(ctx, islice, archetype_field) @@ -113,7 +113,7 @@ def test_slices_with_negative_start(): archetype_field = Field.create_generic("f", spatial_dimensions=2, layout="fzyx") ctx.add_field(archetype_field) - archetype_arr = ctx.get_array(archetype_field) + archetype_arr = ctx.get_buffer(archetype_field) islice = (slice(-3, -1, 1), slice(-4, None, 1)) ispace = FullIterationSpace.create_from_slice(ctx, islice, archetype_field) diff --git a/tests/nbackend/kernelcreation/test_typification.py b/tests/nbackend/kernelcreation/test_typification.py index 5ea2aa15e..988fa4bb8 100644 --- a/tests/nbackend/kernelcreation/test_typification.py +++ b/tests/nbackend/kernelcreation/test_typification.py @@ -371,6 +371,7 @@ def test_array_declarations(): decl = typify(decl) assert ctx.get_symbol("arr1").dtype == Arr(Fp(32), (4,)) + assert decl.lhs.dtype == decl.rhs.dtype == Arr(Fp(32), (4,)) # Array type determined by default-typed symbol arr2 = sp.Symbol("arr2") @@ -378,6 +379,7 @@ def test_array_declarations(): decl = typify(decl) assert ctx.get_symbol("arr2").dtype == Arr(Fp(32), (2, 3)) + assert decl.lhs.dtype == decl.rhs.dtype == Arr(Fp(32), (2, 3)) # Array type determined by pre-typed symbol q = TypedSymbol("q", Fp(16)) @@ -386,6 +388,14 @@ def test_array_declarations(): decl = typify(decl) assert ctx.get_symbol("arr3").dtype == Arr(Fp(16), (2, 2)) + assert decl.lhs.dtype == decl.rhs.dtype == Arr(Fp(16), (2, 2)) + + # Array type determined by LHS symbol + arr4 = TypedSymbol("arr4", Arr(Int(16), 4)) + decl = freeze(Assignment(arr4, sp.Tuple(11, 1, 4, 2))) + decl = typify(decl) + + assert decl.lhs.dtype == decl.rhs.dtype == Arr(Int(16), 4) def test_erronous_typing(): diff --git a/tests/nbackend/test_ast.py b/tests/nbackend/test_ast.py index 02f03bfa9..2408b8d86 100644 --- a/tests/nbackend/test_ast.py +++ b/tests/nbackend/test_ast.py @@ -1,4 +1,8 @@ -from pystencils.backend.symbols import PsSymbol +import pytest + +from pystencils import create_type +from pystencils.backend.kernelcreation import KernelCreationContext, AstFactory, Typifier +from pystencils.backend.memory import PsSymbol, BufferBasePtr from pystencils.backend.constants import PsConstant from pystencils.backend.ast.expressions import ( PsExpression, @@ -6,10 +10,13 @@ from pystencils.backend.ast.expressions import ( PsMemAcc, PsArrayInitList, PsSubscript, + PsBufferAcc, + PsSymbolExpr, ) from pystencils.backend.ast.structural import ( PsStatement, PsAssignment, + PsDeclaration, PsBlock, PsConditional, PsComment, @@ -20,15 +27,25 @@ from pystencils.types.quick import Fp, Ptr def test_cloning(): - x, y, z = [PsExpression.make(PsSymbol(name)) for name in "xyz"] + ctx = KernelCreationContext() + typify = Typifier(ctx) + + x, y, z, m = [PsExpression.make(ctx.get_symbol(name)) for name in "xyzm"] + q = PsExpression.make(ctx.get_symbol("q", create_type("bool"))) + a, b, c = [PsExpression.make(ctx.get_symbol(name, ctx.index_dtype)) for name in "abc"] c1 = PsExpression.make(PsConstant(3.0)) c2 = PsExpression.make(PsConstant(-1.0)) - one = PsExpression.make(PsConstant(1)) + one_f = PsExpression.make(PsConstant(1.0)) + one_i = PsExpression.make(PsConstant(1)) def check(orig, clone): assert not (orig is clone) assert type(orig) is type(clone) assert orig.structurally_equal(clone) + + if isinstance(orig, PsExpression): + # Regression: Expression data types used to not be cloned + assert orig.dtype == clone.dtype for c1, c2 in zip(orig.children, clone.children, strict=True): check(c1, c2) @@ -44,18 +61,21 @@ def test_cloning(): PsAssignment(y, x / c1), PsBlock([PsAssignment(x, c1 * y), PsAssignment(z, c2 + c1 * z)]), PsConditional( - y, PsBlock([PsStatement(x + y)]), PsBlock([PsComment("hello world")]) + q, PsBlock([PsStatement(x + y)]), PsBlock([PsComment("hello world")]) + ), + PsDeclaration( + m, + PsArrayInitList([ + [x, y, one_f + x], + [one_f, c2, z] + ]) ), - PsArrayInitList([ - [x, y, one + x], - [one, c2, z] - ]), PsPragma("omp parallel for"), PsLoop( - x, - y, - z, - one, + a, + b, + c, + one_i, PsBlock( [ PsComment("Loop body"), @@ -63,12 +83,55 @@ def test_cloning(): PsAssignment(x, y), PsPragma("#pragma clang loop vectorize(enable)"), PsStatement( - PsMemAcc(PsCast(Ptr(Fp(32)), z), one) - + PsSubscript(z, (one + one + one, y + one)) + PsMemAcc(PsCast(Ptr(Fp(32)), z), one_i) + + PsCast(Fp(32), PsSubscript(m, (one_i + one_i + one_i, b + one_i))) ), ] ), ), ]: + ast = typify(ast) ast_clone = ast.clone() check(ast, ast_clone) + + +def test_buffer_acc(): + ctx = KernelCreationContext() + factory = AstFactory(ctx) + + from pystencils import fields + + f, g = fields("f, g(3): [2D]") + a, b = [ctx.get_symbol(n, ctx.index_dtype) for n in "ab"] + + f_buf = ctx.get_buffer(f) + + f_acc = PsBufferAcc(f_buf.base_pointer, [PsExpression.make(i) for i in (a, b)] + [factory.parse_index(0)]) + assert f_acc.buffer == f_buf + assert f_acc.base_pointer.structurally_equal(PsSymbolExpr(f_buf.base_pointer)) + + f_acc_clone = f_acc.clone() + assert f_acc_clone is not f_acc + + assert f_acc_clone.buffer == f_buf + assert f_acc_clone.base_pointer.structurally_equal(PsSymbolExpr(f_buf.base_pointer)) + assert len(f_acc_clone.index) == 3 + assert f_acc_clone.index[0].structurally_equal(PsSymbolExpr(ctx.get_symbol("a"))) + assert f_acc_clone.index[1].structurally_equal(PsSymbolExpr(ctx.get_symbol("b"))) + + g_buf = ctx.get_buffer(g) + + g_acc = PsBufferAcc(g_buf.base_pointer, [PsExpression.make(i) for i in (a, b)] + [factory.parse_index(2)]) + assert g_acc.buffer == g_buf + assert g_acc.base_pointer.structurally_equal(PsSymbolExpr(g_buf.base_pointer)) + + second_bptr = PsExpression.make(ctx.get_symbol("data_g_interior", g_buf.base_pointer.dtype)) + second_bptr.symbol.add_property(BufferBasePtr(g_buf)) + g_acc.base_pointer = second_bptr + + assert g_acc.base_pointer == second_bptr + assert g_acc.buffer == g_buf + + # cannot change base pointer to different buffer + with pytest.raises(ValueError): + g_acc.base_pointer = PsExpression.make(f_buf.base_pointer) diff --git a/tests/nbackend/test_code_printing.py b/tests/nbackend/test_code_printing.py index dc7a86b0b..ef4806314 100644 --- a/tests/nbackend/test_code_printing.py +++ b/tests/nbackend/test_code_printing.py @@ -3,10 +3,9 @@ from pystencils import Target from pystencils.backend.ast.expressions import PsExpression from pystencils.backend.ast.structural import PsAssignment, PsLoop, PsBlock from pystencils.backend.kernelfunction import KernelFunction -from pystencils.backend.symbols import PsSymbol +from pystencils.backend.memory import PsSymbol, PsBuffer from pystencils.backend.constants import PsConstant from pystencils.backend.literals import PsLiteral -from pystencils.backend.arrays import PsLinearizedArray, PsArrayBasePointer from pystencils.types.quick import Fp, SInt, UInt, Bool from pystencils.backend.emission import CAstPrinter diff --git a/tests/nbackend/test_cpujit.py b/tests/nbackend/test_cpujit.py index b621829ad..648112ef9 100644 --- a/tests/nbackend/test_cpujit.py +++ b/tests/nbackend/test_cpujit.py @@ -3,11 +3,10 @@ import pytest from pystencils import Target # from pystencils.backend.constraints import PsKernelParamsConstraint -from pystencils.backend.symbols import PsSymbol +from pystencils.backend.memory import PsSymbol, PsBuffer from pystencils.backend.constants import PsConstant -from pystencils.backend.arrays import PsLinearizedArray, PsArrayBasePointer -from pystencils.backend.ast.expressions import PsArrayAccess, PsExpression +from pystencils.backend.ast.expressions import PsBufferAcc, PsExpression from pystencils.backend.ast.structural import PsAssignment, PsBlock, PsLoop from pystencils.backend.kernelfunction import KernelFunction @@ -21,8 +20,8 @@ import numpy as np def test_pairwise_addition(): idx_type = SInt(64) - u = PsLinearizedArray("u", Fp(64, const=True), (...,), (...,), index_dtype=idx_type) - v = PsLinearizedArray("v", Fp(64), (...,), (...,), index_dtype=idx_type) + u = PsBuffer("u", Fp(64, const=True), (...,), (...,), index_dtype=idx_type) + v = PsBuffer("v", Fp(64), (...,), (...,), index_dtype=idx_type) u_data = PsArrayBasePointer("u_data", u) v_data = PsArrayBasePointer("v_data", v) @@ -34,8 +33,8 @@ def test_pairwise_addition(): two = PsExpression.make(PsConstant(2, idx_type)) update = PsAssignment( - PsArrayAccess(v_data, loop_ctr), - PsArrayAccess(u_data, two * loop_ctr) + PsArrayAccess(u_data, two * loop_ctr + one) + PsBufferAcc(v_data, loop_ctr), + PsBufferAcc(u_data, two * loop_ctr) + PsBufferAcc(u_data, two * loop_ctr + one) ) loop = PsLoop( diff --git a/tests/nbackend/test_extensions.py b/tests/nbackend/test_extensions.py index 914d05594..b1403185c 100644 --- a/tests/nbackend/test_extensions.py +++ b/tests/nbackend/test_extensions.py @@ -3,7 +3,7 @@ import sympy as sp from pystencils import make_slice, Field, Assignment from pystencils.backend.kernelcreation import KernelCreationContext, AstFactory, FullIterationSpace -from pystencils.backend.transformations import CanonicalizeSymbols, HoistLoopInvariantDeclarations +from pystencils.backend.transformations import CanonicalizeSymbols, HoistLoopInvariantDeclarations, LowerToC from pystencils.backend.literals import PsLiteral from pystencils.backend.emission import CAstPrinter from pystencils.backend.ast.expressions import PsExpression, PsSubscript @@ -46,6 +46,9 @@ def test_literals(): hoist = HoistLoopInvariantDeclarations(ctx) ast = hoist(ast) + lower = LowerToC(ctx) + ast = lower(ast) + assert isinstance(ast, PsBlock) assert len(ast.statements) == 2 assert ast.statements[0] == x_decl diff --git a/tests/nbackend/test_memory.py b/tests/nbackend/test_memory.py new file mode 100644 index 000000000..5841e0f4f --- /dev/null +++ b/tests/nbackend/test_memory.py @@ -0,0 +1,50 @@ +import pytest + +from dataclasses import dataclass +from pystencils.backend.memory import PsSymbol, PsSymbolProperty, UniqueSymbolProperty + + +def test_properties(): + @dataclass(frozen=True) + class NumbersProperty(PsSymbolProperty): + n: int + x: float + + @dataclass(frozen=True) + class StringProperty(PsSymbolProperty): + s: str + + @dataclass(frozen=True) + class MyUniqueProperty(UniqueSymbolProperty): + val: int + + s = PsSymbol("s") + + assert not s.properties + + s.add_property(NumbersProperty(42, 8.71)) + assert s.properties == {NumbersProperty(42, 8.71)} + + # no duplicates + s.add_property(NumbersProperty(42, 8.71)) + assert s.properties == {NumbersProperty(42, 8.71)} + + s.add_property(StringProperty("pystencils")) + assert s.properties == {NumbersProperty(42, 8.71), StringProperty("pystencils")} + + assert s.get_properties(NumbersProperty) == {NumbersProperty(42, 8.71)} + + assert not s.get_properties(MyUniqueProperty) + + s.add_property(MyUniqueProperty(13)) + assert s.get_properties(MyUniqueProperty) == {MyUniqueProperty(13)} + + # Adding the same one again does not raise + s.add_property(MyUniqueProperty(13)) + assert s.get_properties(MyUniqueProperty) == {MyUniqueProperty(13)} + + with pytest.raises(ValueError): + s.add_property(MyUniqueProperty(14)) + + s.remove_property(MyUniqueProperty(13)) + assert not s.get_properties(MyUniqueProperty) diff --git a/tests/nbackend/transformations/test_canonicalize_symbols.py b/tests/nbackend/transformations/test_canonicalize_symbols.py index 6a4785564..a11e9bd13 100644 --- a/tests/nbackend/transformations/test_canonicalize_symbols.py +++ b/tests/nbackend/transformations/test_canonicalize_symbols.py @@ -52,7 +52,7 @@ def test_deduplication(): assert canonicalize.get_last_live_symbols() == { ctx.find_symbol("y"), ctx.find_symbol("z"), - ctx.get_array(f).base_pointer, + ctx.get_buffer(f).base_pointer, } assert ctx.find_symbol("x") is not None diff --git a/tests/nbackend/transformations/test_constant_elimination.py b/tests/nbackend/transformations/test_constant_elimination.py index 92bb5c947..4c1897008 100644 --- a/tests/nbackend/transformations/test_constant_elimination.py +++ b/tests/nbackend/transformations/test_constant_elimination.py @@ -1,6 +1,6 @@ from pystencils.backend.kernelcreation import KernelCreationContext, Typifier from pystencils.backend.ast.expressions import PsExpression, PsConstantExpr -from pystencils.backend.symbols import PsSymbol +from pystencils.backend.memory import PsSymbol from pystencils.backend.constants import PsConstant from pystencils.backend.transformations import EliminateConstants diff --git a/tests/nbackend/transformations/test_lower_to_c.py b/tests/nbackend/transformations/test_lower_to_c.py new file mode 100644 index 000000000..b557a7493 --- /dev/null +++ b/tests/nbackend/transformations/test_lower_to_c.py @@ -0,0 +1,122 @@ +from functools import reduce +from operator import add + +from pystencils import fields, Assignment, make_slice, Field, FieldType +from pystencils.types import PsStructType, create_type + +from pystencils.backend.memory import BufferBasePtr +from pystencils.backend.kernelcreation import ( + KernelCreationContext, + AstFactory, + FullIterationSpace, +) +from pystencils.backend.transformations import LowerToC + +from pystencils.backend.ast import dfs_preorder +from pystencils.backend.ast.expressions import ( + PsBufferAcc, + PsMemAcc, + PsSymbolExpr, + PsExpression, + PsLookup, + PsAddressOf, + PsCast, +) +from pystencils.backend.ast.structural import PsAssignment + + +def test_lower_buffer_accesses(): + ctx = KernelCreationContext() + factory = AstFactory(ctx) + + ispace = FullIterationSpace.create_from_slice(ctx, make_slice[:42, :31]) + ctx.set_iteration_space(ispace) + + lower = LowerToC(ctx) + + f, g = fields("f(2), g(3): [2D]") + asm = Assignment(f.center(1), g[-1, 1](2)) + + f_buf = ctx.get_buffer(f) + g_buf = ctx.get_buffer(g) + + fasm = factory.parse_sympy(asm) + assert isinstance(fasm.lhs, PsBufferAcc) + assert isinstance(fasm.rhs, PsBufferAcc) + + fasm_lowered = lower(fasm) + assert isinstance(fasm_lowered, PsAssignment) + + assert isinstance(fasm_lowered.lhs, PsMemAcc) + assert isinstance(fasm_lowered.lhs.pointer, PsSymbolExpr) + assert fasm_lowered.lhs.pointer.symbol == f_buf.base_pointer + + zero = factory.parse_index(0) + expected_offset = reduce( + add, + ( + (PsExpression.make(dm.counter) + zero) * PsExpression.make(stride) + for dm, stride in zip(ispace.dimensions, f_buf.strides) + ), + ) + factory.parse_index(1) * PsExpression.make(f_buf.strides[-1]) + assert fasm_lowered.lhs.offset.structurally_equal(expected_offset) + + assert isinstance(fasm_lowered.rhs, PsMemAcc) + assert isinstance(fasm_lowered.rhs.pointer, PsSymbolExpr) + assert fasm_lowered.rhs.pointer.symbol == g_buf.base_pointer + + expected_offset = ( + (PsExpression.make(ispace.dimensions[0].counter) + factory.parse_index(-1)) + * PsExpression.make(g_buf.strides[0]) + + (PsExpression.make(ispace.dimensions[1].counter) + factory.parse_index(1)) + * PsExpression.make(g_buf.strides[1]) + + factory.parse_index(2) * PsExpression.make(g_buf.strides[-1]) + ) + assert fasm_lowered.rhs.offset.structurally_equal(expected_offset) + + +def test_lower_anonymous_structs(): + ctx = KernelCreationContext() + factory = AstFactory(ctx) + + ispace = FullIterationSpace.create_from_slice(ctx, make_slice[:12]) + ctx.set_iteration_space(ispace) + + lower = LowerToC(ctx) + + stype = PsStructType( + [ + ("val", ctx.default_dtype), + ("x", ctx.index_dtype), + ] + ) + sfield = Field.create_generic("s", spatial_dimensions=1, dtype=stype) + f = Field.create_generic("f", 1, ctx.default_dtype, field_type=FieldType.CUSTOM) + + asm = Assignment(sfield.center("val"), f.absolute_access((sfield.center("x"),), (0,))) + + fasm = factory.parse_sympy(asm) + + sbuf = ctx.get_buffer(sfield) + + assert isinstance(fasm, PsAssignment) + assert isinstance(fasm.lhs, PsLookup) + + lowered_fasm = lower(fasm.clone()) + assert isinstance(lowered_fasm, PsAssignment) + + # Check type of sfield data pointer + for expr in dfs_preorder(lowered_fasm, lambda n: isinstance(n, PsSymbolExpr)): + if expr.symbol.name == sbuf.base_pointer.name: + assert expr.symbol.dtype == create_type("uint8_t * restrict") + + # Check LHS + assert isinstance(lowered_fasm.lhs, PsMemAcc) + assert isinstance(lowered_fasm.lhs.pointer, PsCast) + assert isinstance(lowered_fasm.lhs.pointer.operand, PsAddressOf) + assert isinstance(lowered_fasm.lhs.pointer.operand.operand, PsMemAcc) + type_erased_pointer = lowered_fasm.lhs.pointer.operand.operand.pointer + + assert isinstance(type_erased_pointer, PsSymbolExpr) + assert BufferBasePtr(sbuf) in type_erased_pointer.symbol.properties + assert type_erased_pointer.symbol.dtype == create_type("uint8_t * restrict") -- GitLab