From cd599101e09cd1ae253eb5a027bb91bb72f74a03 Mon Sep 17 00:00:00 2001 From: Frederik Hennig <frederik.hennig@fau.de> Date: Wed, 17 Jan 2024 16:09:27 +0100 Subject: [PATCH] first sketch of translation context and iteration domain --- src/pystencils/nbackend/arrays.py | 150 ++++++++---------- src/pystencils/nbackend/ast/kernelfunction.py | 8 +- .../nbackend/{ast => }/constraints.py | 4 +- .../nbackend/jit/cpu_extension_module.py | 4 +- .../nbackend/translation/__init__.py | 0 .../nbackend/translation/context.py | 24 ++- .../nbackend/translation/field_array_pair.py | 21 +++ .../nbackend/translation/iteration_domain.py | 130 +++++++++++++++ src/pystencils/nbackend/typed_expressions.py | 3 +- tests/nbackend/test_basic_printing.py | 4 +- tests/nbackend/test_cpujit.py | 8 +- tests/nbackend/test_expressions.py | 16 +- 12 files changed, 267 insertions(+), 105 deletions(-) rename src/pystencils/nbackend/{ast => }/constraints.py (88%) create mode 100644 src/pystencils/nbackend/translation/__init__.py create mode 100644 src/pystencils/nbackend/translation/field_array_pair.py create mode 100644 src/pystencils/nbackend/translation/iteration_domain.py diff --git a/src/pystencils/nbackend/arrays.py b/src/pystencils/nbackend/arrays.py index fd416a20c..e206b28fd 100644 --- a/src/pystencils/nbackend/arrays.py +++ b/src/pystencils/nbackend/arrays.py @@ -31,8 +31,8 @@ all occurences of the shape and stride variables with their constant value: ``` constraints = ( - [PsParamConstraint(s.eq(f)) for s, f in zip(arr.shape, fixed_size)] - + [PsParamConstraint(s.eq(f)) for s, f in zip(arr.strides, fixed_strides)] + [PsKernelConstraint(s.eq(f)) for s, f in zip(arr.shape, fixed_size)] + + [PsKernelConstraint(s.eq(f)) for s, f in zip(arr.strides, fixed_strides)] ) kernel_function.add_constraints(*constraints) @@ -43,6 +43,8 @@ kernel_function.add_constraints(*constraints) from __future__ import annotations +from types import EllipsisType + from abc import ABC import pymbolic.primitives as pb @@ -56,78 +58,94 @@ from .types import ( constify, ) -from .typed_expressions import PsTypedVariable, ExprOrConstant +from .typed_expressions import PsTypedVariable, ExprOrConstant, PsTypedConstant class PsLinearizedArray: - """N-dimensional contiguous array""" + """Class to model N-dimensional contiguous arrays. + + Memory Layout, Shape and Strides + -------------------------------- + + The memory layout of an array is defined by its shape and strides. + Both shape and stride entries may either be constants or special variables associated with + exactly one array. + + Shape and strides may be specified at construction in the following way. + For constant entries, their value must be given as an integer. + For variable shape entries and strides, the Ellipsis `...` must be passed instead. + Internally, the passed `index_dtype` will be used to create typed constants (`PsTypedConstant`) + and variables (`PsArrayShapeVar` and `PsArrayStrideVar`) from the passed values. + """ def __init__( self, name: str, - element_type: PsScalarType, - dim: int, + element_type: PsAbstractType, + shape: tuple[int | EllipsisType, ...], + strides: tuple[int | EllipsisType, ...], index_dtype: PsIntegerType = PsSignedIntegerType(64), ): self._name = name + self._element_type = element_type + self._index_dtype = index_dtype - self._shape = tuple( - PsArrayShapeVar(self, d, constify(index_dtype)) for d in range(dim) - ) - self._strides = tuple( - PsArrayStrideVar(self, d, constify(index_dtype)) for d in range(dim) + if len(shape) != len(strides): + raise ValueError("Shape and stride tuples must have the same length") + + self._shape: tuple[PsArrayShapeVar | PsTypedConstant, ...] = tuple( + ( + PsArrayShapeVar(self, i, index_dtype) + if s == Ellipsis + else PsTypedConstant(s, index_dtype) + ) + for i, s in enumerate(shape) ) - self._element_type = element_type - self._dim = dim - self._index_dtype = index_dtype + self._strides: tuple[PsArrayStrideVar | PsTypedConstant, ...] = tuple( + ( + PsArrayStrideVar(self, i, index_dtype) + if s == Ellipsis + else PsTypedConstant(s, index_dtype) + ) + for i, s in enumerate(strides) + ) @property def name(self): return self._name @property - def shape(self): + def shape(self) -> tuple[PsArrayShapeVar | PsTypedConstant, ...]: return self._shape @property - def strides(self): + def strides(self) -> tuple[PsArrayStrideVar | PsTypedConstant, ...]: return self._strides - @property - def dim(self): - return self._dim - @property def element_type(self): return self._element_type + + def _hashable_contents(self): + """Contents by which to compare two instances of `PsLinearizedArray`. + + Since equality checks on shape and stride variables internally check equality of their associated arrays, + if these variables would occur in here, an infinite recursion would follow. + Hence they are filtered and replaced by the ellipsis. + """ + shape_clean = tuple((s if isinstance(s, PsTypedConstant) else ...) for s in self._shape) + strides_clean = tuple((s if isinstance(s, PsTypedConstant) else ...) for s in self._strides) + return (self._name, self._element_type, self._index_dtype, shape_clean, strides_clean) def __eq__(self, other: object) -> bool: if not isinstance(other, PsLinearizedArray): return False - return ( - self._name, - self._element_type, - self._dim, - self._index_dtype, - ) == ( - other._name, - other._element_type, - other._dim, - other._index_dtype, - ) + return self._hashable_contents() == other._hashable_contents() def __hash__(self) -> int: - return hash( - ( - self._name, - self._element_type, - self._dim, - self._index_dtype, - ) - ) - + return hash(self._hashable_contents()) class PsArrayAssocVar(PsTypedVariable, ABC): """A variable that is associated to an array. @@ -166,6 +184,11 @@ class PsArrayBasePointer(PsArrayAssocVar): class PsArrayShapeVar(PsArrayAssocVar): + """Variable that represents an array's shape in one coordinate. + + Do not instantiate this class yourself, but only use its instances + as provided by `PsLinearizedArray.shape`. + """ init_arg_names: tuple[str, ...] = ("array", "coordinate", "dtype") __match_args__ = ("array", "coordinate", "dtype") @@ -183,6 +206,11 @@ class PsArrayShapeVar(PsArrayAssocVar): class PsArrayStrideVar(PsArrayAssocVar): + """Variable that represents an array's stride in one coordinate. + + Do not instantiate this class yourself, but only use its instances + as provided by `PsLinearizedArray.strides`. + """ init_arg_names: tuple[str, ...] = ("array", "coordinate", "dtype") __match_args__ = ("array", "coordinate", "dtype") @@ -217,45 +245,3 @@ class PsArrayAccess(pb.Subscript): def dtype(self) -> PsAbstractType: """Data type of this expression, i.e. the element type of the underlying array""" return self._base_ptr.array.element_type - - -# class PsIterationDomain: -# """A factory for arrays spanning a given iteration domain.""" - -# def __init__( -# self, -# id: str, -# dim: int | None = None, -# fixed_shape: tuple[int, ...] | None = None, -# index_dtype: PsIntegerType = PsSignedIntegerType(64), -# ): -# if fixed_shape is not None: -# if dim is not None and len(fixed_shape) != dim: -# raise ValueError( -# "If both `dim` and `fixed_shape` are specified, `fixed_shape` must have exactly `dim` entries." -# ) - -# shape = tuple(PsTypedConstant(s, index_dtype) for s in fixed_shape) -# elif dim is not None: -# shape = tuple( -# PsTypedVariable(f"{id}_shape_{d}", index_dtype) for d in range(dim) -# ) -# else: -# raise ValueError("Either `fixed_shape` or `dim` must be specified.") - -# self._domain_shape: tuple[VarOrConstant, ...] = shape -# self._index_dtype = index_dtype - -# self._archetype_array: PsLinearizedArray | None = None - -# self._constraints: list[PsParamConstraint] = [] - -# @property -# def dim(self) -> int: -# return len(self._domain_shape) - -# @property -# def shape(self) -> tuple[VarOrConstant, ...]: -# return self._domain_shape - -# def create_array(self, ghost_layers: int = 0): diff --git a/src/pystencils/nbackend/ast/kernelfunction.py b/src/pystencils/nbackend/ast/kernelfunction.py index 060947a82..9aecb16ea 100644 --- a/src/pystencils/nbackend/ast/kernelfunction.py +++ b/src/pystencils/nbackend/ast/kernelfunction.py @@ -6,7 +6,7 @@ from dataclasses import dataclass from pymbolic.mapper.dependency import DependencyMapper from .nodes import PsAstNode, PsBlock, failing_cast -from .constraints import PsParamConstraint +from ..constraints import PsKernelConstraint from ..typed_expressions import PsTypedVariable from ..arrays import PsLinearizedArray, PsArrayBasePointer, PsArrayAssocVar from ..exceptions import PsInternalCompilerError @@ -26,7 +26,7 @@ class PsKernelParametersSpec: params: tuple[PsTypedVariable, ...] arrays: tuple[PsLinearizedArray, ...] - constraints: tuple[PsParamConstraint, ...] + constraints: tuple[PsKernelConstraint, ...] def params_for_array(self, arr: PsLinearizedArray): def pred(p: PsTypedVariable): @@ -71,7 +71,7 @@ class PsKernelFunction(PsAstNode): self._target = target self._name = name - self._constraints: list[PsParamConstraint] = [] + self._constraints: list[PsKernelConstraint] = [] @property def target(self) -> Target: @@ -120,7 +120,7 @@ class PsKernelFunction(PsAstNode): raise IndexError(f"Child index out of bounds: {idx}") self._body = failing_cast(PsBlock, c) - def add_constraints(self, *constraints: PsParamConstraint): + def add_constraints(self, *constraints: PsKernelConstraint): self._constraints += constraints def get_parameters(self) -> PsKernelParametersSpec: diff --git a/src/pystencils/nbackend/ast/constraints.py b/src/pystencils/nbackend/constraints.py similarity index 88% rename from src/pystencils/nbackend/ast/constraints.py rename to src/pystencils/nbackend/constraints.py index d11fe1195..0cda3f4dc 100644 --- a/src/pystencils/nbackend/ast/constraints.py +++ b/src/pystencils/nbackend/constraints.py @@ -4,11 +4,11 @@ import pymbolic.primitives as pb from pymbolic.mapper.c_code import CCodeMapper from pymbolic.mapper.dependency import DependencyMapper -from ..typed_expressions import PsTypedVariable +from .typed_expressions import PsTypedVariable @dataclass -class PsParamConstraint: +class PsKernelConstraint: condition: pb.Comparison message: str = "" diff --git a/src/pystencils/nbackend/jit/cpu_extension_module.py b/src/pystencils/nbackend/jit/cpu_extension_module.py index f07172e3d..3b67fa45f 100644 --- a/src/pystencils/nbackend/jit/cpu_extension_module.py +++ b/src/pystencils/nbackend/jit/cpu_extension_module.py @@ -11,7 +11,7 @@ import numpy as np from ..exceptions import PsInternalCompilerError from ..ast import PsKernelFunction -from ..ast.constraints import PsParamConstraint +from ..constraints import PsKernelConstraint from ..typed_expressions import PsTypedVariable from ..arrays import ( PsLinearizedArray, @@ -285,7 +285,7 @@ if( !kwargs || !PyDict_Check(kwargs) ) {{ case _: assert False, "Invalid variable encountered." - def check_constraint(self, constraint: PsParamConstraint): + def check_constraint(self, constraint: PsKernelConstraint): variables = constraint.get_variables() for var in variables: diff --git a/src/pystencils/nbackend/translation/__init__.py b/src/pystencils/nbackend/translation/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/pystencils/nbackend/translation/context.py b/src/pystencils/nbackend/translation/context.py index 199315579..cc9cfc0fc 100644 --- a/src/pystencils/nbackend/translation/context.py +++ b/src/pystencils/nbackend/translation/context.py @@ -1,3 +1,9 @@ +from ...field import Field +from ..arrays import PsLinearizedArray, PsArrayBasePointer +from ..types import PsIntegerType +from ..constraints import PsKernelConstraint + +from .iteration_domain import PsIterationDomain class PsTranslationContext: """The `PsTranslationContext` manages the translation process from the SymPy frontend @@ -27,7 +33,6 @@ class PsTranslationContext: Domain fields can only be accessed by relative offsets, and therefore must always be associated with an *iteration domain* that provides a spatial index tuple. All domain fields associated with the same domain must have the same spatial shape, modulo ghost layers. - A field and its array may be associated with multiple iteration domains. - `FieldType.INDEXED` are 1D arrays of index structures. They must be accessed by a single running index. If there is at least one indexed field present there must also exist an index source for that field (loop or device indexing). @@ -36,6 +41,21 @@ class PsTranslationContext: Within a domain, a buffer may be either written to or read from, never both. + In the translator, frontend fields and backend arrays are managed together using the `PsFieldArrayPair` class. + """ + + def __init__(self, index_dtype: PsIntegerType): + self._index_dtype = index_dtype + self._constraints: list[PsKernelConstraint] = [] + + @property + def index_dtype(self) -> PsIntegerType: + return self._index_dtype + def add_constraints(self, *constraints: PsKernelConstraint): + self._constraints += constraints + + @property + def constraints(self) -> tuple[PsKernelConstraint, ...]: + return tuple(self._constraints) - """ diff --git a/src/pystencils/nbackend/translation/field_array_pair.py b/src/pystencils/nbackend/translation/field_array_pair.py new file mode 100644 index 000000000..720b5c1c7 --- /dev/null +++ b/src/pystencils/nbackend/translation/field_array_pair.py @@ -0,0 +1,21 @@ +from dataclasses import dataclass + +from ...field import Field +from ..arrays import PsLinearizedArray, PsArrayBasePointer +from ..types import PsIntegerType +from ..constraints import PsKernelConstraint + +from .iteration_domain import PsIterationDomain + +@dataclass +class PsFieldArrayPair: + field: Field + array: PsLinearizedArray + base_ptr: PsArrayBasePointer + + +@dataclass +class PsDomainFieldArrayPair(PsFieldArrayPair): + ghost_layers: int + interior_base_ptr: PsArrayBasePointer + domain: PsIterationDomain diff --git a/src/pystencils/nbackend/translation/iteration_domain.py b/src/pystencils/nbackend/translation/iteration_domain.py new file mode 100644 index 000000000..990a4ff67 --- /dev/null +++ b/src/pystencils/nbackend/translation/iteration_domain.py @@ -0,0 +1,130 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, cast +from types import EllipsisType + +from ...field import Field +from ...typing import TypedSymbol, BasicType +from ..arrays import PsLinearizedArray, PsArrayBasePointer +from ..types.quick import make_type +from ..typed_expressions import PsTypedVariable, PsTypedConstant, VarOrConstant +from .field_array_pair import PsDomainFieldArrayPair + +if TYPE_CHECKING: + from .context import PsTranslationContext + +class PsIterationDomain: + """Represents the n-dimensonal spatial iteration domain of a pystencils kernel. + + Domain Shape + ------------ + + A domain may have either constant or variable, n-dimensional shape, where n = 1, 2, 3. + If the shape is variable, the domain object manages variables for each shape entry. + + The domain provides index variables for each dimension which may be used to access fields + associated with the domain. + In the kernel, these index variables must be provided by some index source. + Index sources differ between two major types of domains: full and sparse domains. + + In a full domain, it is guaranteed that each interior point is processed by the kernel. + The index source may therefore be a full n-fold loop nest, or a device index calculation. + + In a sparse domain, the iteration is controlled by an index vector, which acts as the index + source. + + Arrays + ------ + + Any number of domain arrays may be associated with each domain. + Each array is annotated with a number of ghost layers for each spatial coordinate. + + ### Shape Compatibility + + When an array is associated with a domain, it must be ensured that the array's shape + is compatible with the domain. + The first n shape entries are considered the array's spatial shape. + These spatial shapes, after subtracting ghost layers, must all be equal, and are further + constrained by a constant domain shape. + For each spatial coordinate, shape compatibility is ensured as described by the following table. + + | | Constant Array Shape | Variable Array Shape | + |---------------------------|-----------------------------|------------------------| + | **Constant Domain Shape** | Compile-Time Equality Check | Kernel Constraints | + | **Variable Domain Shape** | Invalid, Compiler Error | Kernel Constraints | + + ### Base Pointers and Array Accesses + + In the kernel's public interface, each array is represented at least through its base pointer, + which represents the starting address of the array's data in memory. + Since the iteration domain models arrays as being surrounded by ghost layers, it provides for each + array a second, *interior* base pointer, which points to the first interior point after skipping the + ghost layers, e.g. in three dimensions with one index dimension: + + ``` + addr(interior_base_ptr[0, 0, 0, 0]) == addr(base_ptr[gls, gls, gls, 0]) + ``` + + To access domain arrays using the domain's index variables, the interior base pointer should be used, + since the domain index variables always count up from zero. + + """ + + def __init__(self, ctx: PsTranslationContext, shape: tuple[int | EllipsisType, ...]): + self._ctx = ctx + + if len(shape) == 0: + raise ValueError("Domain shape must be at least one-dimensional.") + + if len(shape) > 3: + raise ValueError("Iteration domain can be at most three-dimensional.") + + self._shape: tuple[VarOrConstant, ...] = tuple( + ( + PsTypedVariable(f"domain_size_{i}", self._ctx.index_dtype) + if s == Ellipsis + else PsTypedConstant(s, self._ctx.index_dtype) + ) + for i, s in enumerate(shape) + ) + + self._archetype_field: PsDomainFieldArrayPair | None = None + self._fields: dict[str, PsDomainFieldArrayPair] = dict() + + @property + def shape(self) -> tuple[VarOrConstant, ...]: + return self._shape + + def add_field(self, field: Field, ghost_layers: int) -> PsDomainFieldArrayPair: + arr_shape = tuple( + (Ellipsis if isinstance(s, TypedSymbol) else s) # TODO: Field should also use ellipsis + for s in field.shape + ) + + arr_strides = tuple( + (Ellipsis if isinstance(s, TypedSymbol) else s) # TODO: Field should also use ellipsis + for s in field.strides + ) + + # TODO: frontend should use new type system + element_type = make_type(cast(BasicType, field.dtype).numpy_dtype.type) + + arr = PsLinearizedArray(field.name, element_type, arr_shape, arr_strides, self._ctx.index_dtype) + + fa_pair = PsDomainFieldArrayPair( + field=field, + array=arr, + base_ptr=PsArrayBasePointer("arr_data", arr), + ghost_layers=ghost_layers, + interior_base_ptr=PsArrayBasePointer("arr_interior_data", arr), + domain=self + ) + + # Check shape compatibility + # TODO + for domain_s, field_s in zip(self.shape, field.shape): + if isinstance(domain_s, PsTypedConstant): + pass + + raise NotImplementedError() + diff --git a/src/pystencils/nbackend/typed_expressions.py b/src/pystencils/nbackend/typed_expressions.py index 5bfd0fcb1..b33114426 100644 --- a/src/pystencils/nbackend/typed_expressions.py +++ b/src/pystencils/nbackend/typed_expressions.py @@ -206,8 +206,7 @@ class PsTypedConstant: return PsTypedConstant(rem, self._dtype) def __neg__(self): - minus_one = PsTypedConstant(-1, self._dtype) - return pb.Product((minus_one, self)) + return PsTypedConstant(- self._value, self._dtype) def __bool__(self): return bool(self._value) diff --git a/tests/nbackend/test_basic_printing.py b/tests/nbackend/test_basic_printing.py index ba2f7770d..8d9fc6483 100644 --- a/tests/nbackend/test_basic_printing.py +++ b/tests/nbackend/test_basic_printing.py @@ -10,7 +10,7 @@ from pystencils.nbackend.emission import CPrinter def test_basic_kernel(): - u_arr = PsLinearizedArray("u", Fp(64), 1) + u_arr = PsLinearizedArray("u", Fp(64), (..., ), (1, )) u_size = u_arr.shape[0] u_base = PsArrayBasePointer("u_data", u_arr) @@ -40,5 +40,5 @@ def test_basic_kernel(): assert code.find("(" + params_str + ")") >= 0 - assert code.find("u_data[ctr] = u_data[ctr - 1] + u_data[ctr + 1];") >= 0 + assert code.find("u_data[ctr] = u_data[ctr + 1] + u_data[ctr + -1];") >= 0 diff --git a/tests/nbackend/test_cpujit.py b/tests/nbackend/test_cpujit.py index b77ad6fff..6c2a453c7 100644 --- a/tests/nbackend/test_cpujit.py +++ b/tests/nbackend/test_cpujit.py @@ -3,7 +3,7 @@ import pytest from pystencils import Target from pystencils.nbackend.ast import * -from pystencils.nbackend.ast.constraints import PsParamConstraint +from pystencils.nbackend.constraints import PsKernelConstraint from pystencils.nbackend.typed_expressions import * from pystencils.nbackend.arrays import PsLinearizedArray, PsArrayBasePointer, PsArrayAccess from pystencils.nbackend.types.quick import * @@ -15,8 +15,8 @@ from pystencils.cpu.cpujit import compile_and_load def test_pairwise_addition(): idx_type = SInt(64) - u = PsLinearizedArray("u", Fp(64, const=True), 2, index_dtype=idx_type) - v = PsLinearizedArray("v", Fp(64), 2, index_dtype=idx_type) + u = PsLinearizedArray("u", Fp(64, const=True), (..., ...), (..., ...), index_dtype=idx_type) + v = PsLinearizedArray("v", Fp(64), (..., ...), (..., ...), index_dtype=idx_type) u_data = PsArrayBasePointer("u_data", u) v_data = PsArrayBasePointer("v_data", v) @@ -42,7 +42,7 @@ def test_pairwise_addition(): func = PsKernelFunction(PsBlock([loop]), target=Target.CPU) - sizes_constraint = PsParamConstraint( + sizes_constraint = PsKernelConstraint( u.shape[0].eq(2 * v.shape[0]), "Array `u` must have twice the length of array `v`" ) diff --git a/tests/nbackend/test_expressions.py b/tests/nbackend/test_expressions.py index b3485b267..6c24a6442 100644 --- a/tests/nbackend/test_expressions.py +++ b/tests/nbackend/test_expressions.py @@ -8,15 +8,18 @@ def test_variable_equality(): var2 = PsTypedVariable("x", Fp(32)) assert var1 == var2 - arr = PsLinearizedArray("arr", Fp(64), 3) + shape = (..., ..., ...) + strides = (..., ..., ...) + + arr = PsLinearizedArray("arr", Fp(64), shape, strides) bp1 = PsArrayBasePointer("arr_data", arr) bp2 = PsArrayBasePointer("arr_data", arr) assert bp1 == bp2 - arr1 = PsLinearizedArray("arr", Fp(64), 3) + arr1 = PsLinearizedArray("arr", Fp(64), shape, strides) bp1 = PsArrayBasePointer("arr_data", arr1) - arr2 = PsLinearizedArray("arr", Fp(64), 3) + arr2 = PsLinearizedArray("arr", Fp(64), shape, strides) bp2 = PsArrayBasePointer("arr_data", arr2) assert bp1 == bp2 @@ -28,6 +31,9 @@ def test_variable_equality(): def test_variable_inequality(): + shape = (..., ..., ...) + strides = (..., ..., ...) + var1 = PsTypedVariable("x", Fp(32)) var2 = PsTypedVariable("x", Fp(64)) assert var1 != var2 @@ -37,10 +43,10 @@ def test_variable_inequality(): assert var1 != var2 # Arrays - arr1 = PsLinearizedArray("arr", Fp(64), 3) + arr1 = PsLinearizedArray("arr", Fp(64), shape, strides) bp1 = PsArrayBasePointer("arr_data", arr1) - arr2 = PsLinearizedArray("arr", Fp(32), 3) + arr2 = PsLinearizedArray("arr", Fp(32), shape, strides) bp2 = PsArrayBasePointer("arr_data", arr2) assert bp1 != bp2 -- GitLab