From ec0f31e8be3eb5eeadbd75ea1d474e792ca0bee5 Mon Sep 17 00:00:00 2001
From: Frederik Hennig <frederik.hennig@fau.de>
Date: Tue, 16 Jan 2024 13:23:10 +0100
Subject: [PATCH] array refactoring, constraints

---
 pystencils/nbackend/arrays.py             | 209 ++++++++++++++++++++++
 pystencils/nbackend/ast/constraints.py    |  13 ++
 pystencils/nbackend/ast/kernelfunction.py |  73 +++++++-
 pystencils/nbackend/ast/nodes.py          |   9 +-
 pystencils/nbackend/c_printer.py          |   4 +-
 pystencils/nbackend/sympy_mapper.py       |   7 +-
 pystencils/nbackend/typed_expressions.py  |  96 +---------
 7 files changed, 307 insertions(+), 104 deletions(-)
 create mode 100644 pystencils/nbackend/arrays.py
 create mode 100644 pystencils/nbackend/ast/constraints.py

diff --git a/pystencils/nbackend/arrays.py b/pystencils/nbackend/arrays.py
new file mode 100644
index 000000000..23f0ff778
--- /dev/null
+++ b/pystencils/nbackend/arrays.py
@@ -0,0 +1,209 @@
+"""
+Arrays
+======
+
+The pystencils backend models contiguous n-dimensional arrays using a number of classes.
+Arrays themselves are represented through the `PsLinearizedArray` class.
+An array has a fixed name, dimensionality, and element type, as well as a number of associated
+variables.
+
+The associated variables are the *shape* and *strides* of the array, modelled by the
+`PsArrayShapeVar` and `PsArrayStrideVar` classes. They have integer type and are used to
+reason about the array's memory layout.
+
+
+Memory Layout Constraints
+-------------------------
+
+Initially, all memory layout information about an array is symbolic and unconstrained.
+Several scenarios exist where memory layout must be constrained, e.g. certain pointers
+need to be aligned, certain strides must be fixed or fulfill certain alignment properties,
+or even the field shape must be fixed.
+
+The code generation backend models such requirements and assumptions as *constraints*.
+Constraints are external to the arrays themselves. They are created by the AST passes which
+require them and exposed through the `PsKernelFunction` class to the compiled kernel's runtime
+environment. It is the responsibility of the runtime environment to fulfill all constraints.
+
+For example, if an array `arr` should have both a fixed shape and fixed strides,
+an optimization pass will have to add equality constraints like the following before replacing
+all occurrences of the shape and stride variables with their constant values:
+
+```
+constraints = (
+    [PsParamConstraint(s.eq(f)) for s, f in zip(arr.shape, fixed_size)] 
+    + [PsParamConstraint(s.eq(f)) for s, f in zip(arr.strides, fixed_strides)]
+)
+
+kernel_function.add_constraints(*constraints)
+```
+
+"""
+
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+from abc import ABC
+
+import pymbolic.primitives as pb
+
+from .types import (
+    PsAbstractType,
+    PsScalarType,
+    PsPointerType,
+    PsIntegerType,
+    PsSignedIntegerType,
+    constify,
+)
+
+#   Needed at runtime (base class of PsArrayAssocVar, constructor call in PsLinearizedArray)
+from .typed_expressions import PsTypedVariable, PsTypedConstant
+
+
+class PsLinearizedArray:
+    """N-dimensional contiguous array with per-dimension shape/stride variables and constant offsets."""
+
+    def __init__(
+        self,
+        name: str,
+        element_type: PsScalarType,
+        dim: int,
+        offsets: tuple[int, ...] | None = None,
+        index_dtype: PsIntegerType = PsSignedIntegerType(64),
+    ):
+        self._name = name
+
+        if offsets is not None and len(offsets) != dim:  # explicit offsets must match `dim`
+            raise ValueError(f"Must have exactly {dim} offsets.")
+
+        self._shape = tuple(
+            PsArrayShapeVar(self, d, constify(index_dtype)) for d in range(dim)
+        )
+        self._strides = tuple(
+            PsArrayStrideVar(self, d, constify(index_dtype)) for d in range(dim)
+        )
+        self._element_type = element_type
+
+        if offsets is None:
+            offsets = (0,) * dim  # default: zero offset in every dimension
+        # NOTE: PsTypedConstant is called at runtime here, so it must be importable at runtime
+        self._offsets = tuple(PsTypedConstant(o, index_dtype) for o in offsets)
+
+    @property
+    def name(self) -> str:
+        return self._name
+
+    @property
+    def shape(self) -> tuple[PsArrayShapeVar, ...]:
+        return self._shape
+
+    @property
+    def strides(self) -> tuple[PsArrayStrideVar, ...]:
+        return self._strides
+
+    @property
+    def element_type(self) -> PsScalarType:
+        return self._element_type
+
+    @property
+    def offsets(self) -> tuple[PsTypedConstant, ...]:
+        return self._offsets
+
+
+class PsArrayAssocVar(PsTypedVariable, ABC):
+    """A variable that is associated with an array.
+
+    Instances of this class represent pointers and indexing information bound
+    to a particular array; that array is available through the `array` property.
+    """
+
+    def __init__(self, name: str, dtype: PsAbstractType, array: PsLinearizedArray):
+        super().__init__(name, dtype)
+        self._array = array
+
+    @property
+    def array(self) -> PsLinearizedArray:
+        return self._array
+
+
+class PsArrayBasePointer(PsArrayAssocVar):
+    """Array-associated variable whose type is a pointer to the array's element type."""
+
+    def __init__(self, name: str, array: PsLinearizedArray):
+        #   `PsArrayAssocVar.__init__` already stores `array`; no second assignment needed
+        super().__init__(name, PsPointerType(array.element_type), array)
+
+
+class PsArrayShapeVar(PsArrayAssocVar):
+    def __init__(self, array: PsLinearizedArray, dimension: int, dtype: PsIntegerType):
+        #   Build the name from `array.name`; formatting the array object itself yields its repr
+        super().__init__(f"{array.name}_size{dimension}", dtype, array)
+
+
+class PsArrayStrideVar(PsArrayAssocVar):
+    def __init__(self, array: PsLinearizedArray, dimension: int, dtype: PsIntegerType):
+        #   `_stride` suffix keeps stride names distinct from the `_size` names of shape variables
+        super().__init__(f"{array.name}_stride{dimension}", dtype, array)
+
+
+class PsArrayAccess(pb.Subscript):  # subscript expression on an array's base pointer
+    def __init__(self, base_ptr: PsArrayBasePointer, index: pb.Expression):
+        super(PsArrayAccess, self).__init__(base_ptr, index)
+        self._base_ptr = base_ptr
+        self._index = index  # stored privately; this class defines no property for it
+
+    @property
+    def base_ptr(self) -> PsArrayBasePointer:
+        return self._base_ptr
+
+    @property
+    def array(self) -> PsLinearizedArray:
+        return self._base_ptr.array
+
+    @property
+    def dtype(self) -> PsAbstractType:
+        """Data type of this expression, i.e. the element type of the underlying array"""
+        return self._base_ptr.array.element_type
+
+
+# class PsIterationDomain:
+#     """A factory for arrays spanning a given iteration domain."""
+
+#     def __init__(
+#         self,
+#         id: str,
+#         dim: int | None = None,
+#         fixed_shape: tuple[int, ...] | None = None,
+#         index_dtype: PsIntegerType = PsSignedIntegerType(64),
+#     ):
+#         if fixed_shape is not None:
+#             if dim is not None and len(fixed_shape) != dim:
+#                 raise ValueError(
+#                     "If both `dim` and `fixed_shape` are specified, `fixed_shape` must have exactly `dim` entries."
+#                 )
+
+#             shape = tuple(PsTypedConstant(s, index_dtype) for s in fixed_shape)
+#         elif dim is not None:
+#             shape = tuple(
+#                 PsTypedVariable(f"{id}_shape_{d}", index_dtype) for d in range(dim)
+#             )
+#         else:
+#             raise ValueError("Either `fixed_shape` or `dim` must be specified.")
+
+#         self._domain_shape: tuple[VarOrConstant, ...] = shape
+#         self._index_dtype = index_dtype
+
+#         self._archetype_array: PsLinearizedArray | None = None
+
+#         self._constraints: list[PsParamConstraint] = []
+
+#     @property
+#     def dim(self) -> int:
+#         return len(self._domain_shape)
+
+#     @property
+#     def shape(self) -> tuple[VarOrConstant, ...]:
+#         return self._domain_shape
+
+#     def create_array(self, ghost_layers: int = 0):
diff --git a/pystencils/nbackend/ast/constraints.py b/pystencils/nbackend/ast/constraints.py
new file mode 100644
index 000000000..68cbe347a
--- /dev/null
+++ b/pystencils/nbackend/ast/constraints.py
@@ -0,0 +1,13 @@
+from dataclasses import dataclass
+
+import pymbolic.primitives as pb
+from pymbolic.mapper.c_code import CCodeMapper
+
+
+@dataclass
+class PsParamConstraint:
+    condition: pb.Comparison  # comparison expression the kernel's parameters must satisfy
+    message: str = ""  # presumably a human-readable description of the constraint - TODO confirm
+
+    def print(self):  # renders `condition` through pymbolic's CCodeMapper
+        return CCodeMapper()(self.condition)
diff --git a/pystencils/nbackend/ast/kernelfunction.py b/pystencils/nbackend/ast/kernelfunction.py
index aaf1ac5e5..fccd4f12a 100644
--- a/pystencils/nbackend/ast/kernelfunction.py
+++ b/pystencils/nbackend/ast/kernelfunction.py
@@ -1,13 +1,68 @@
-from typing import Sequence
+from __future__ import annotations
 
 from typing import Generator
+from dataclasses import dataclass
+
+from pymbolic.mapper.dependency import DependencyMapper
+
 from .nodes import PsAstNode, PsBlock, failing_cast
+from .constraints import PsParamConstraint
 from ..typed_expressions import PsTypedVariable
+from ..arrays import PsLinearizedArray, PsArrayBasePointer, PsArrayAssocVar
+from ..exceptions import PsInternalCompilerError
 from ...enums import Target
 
 
+@dataclass
+class PsKernelParametersSpec:
+    """Specification of a kernel function's parameters.
+
+    Contains:
+        - Verbatim parameter list, a tuple of `PsTypedVariable`s
+        - List of Arrays used in the kernel, in canonical order
+        - A set of constraints on the kernel parameters, used to e.g. express relations of array
+          shapes, alignment properties, ...
+    """
+
+    params: tuple[PsTypedVariable, ...]
+    arrays: tuple[PsLinearizedArray, ...]
+    constraints: tuple[PsParamConstraint, ...]
+
+    def params_for_array(self, arr: PsLinearizedArray) -> tuple[PsTypedVariable, ...]:
+        """Return those entries of `params` that are array-associated variables of `arr`."""
+        return tuple(
+            p for p in self.params if isinstance(p, PsArrayAssocVar) and p.array == arr
+        )
+
+    def __post_init__(self):
+        """Raise `PsInternalCompilerError` if a constraint uses a variable unknown to this spec."""
+        dep_mapper = DependencyMapper(False, False, False, False)
+
+        for constraint in self.constraints:
+            variables: set[PsTypedVariable] = dep_mapper(constraint.condition)
+            for var in variables:
+                #   Variables of a listed array are accepted even if absent from `params`
+                if isinstance(var, PsArrayAssocVar) and var.array in self.arrays:
+                    continue
+
+                if var in self.params:
+                    continue
+
+                #   Also reached for variables of unlisted arrays, which must not pass silently
+                raise PsInternalCompilerError(
+                    "Constrained parameter was neither contained in kernel parameter list "
+                    "nor associated with a kernel array.\n"
+                    f"    Parameter: {var}\n"
+                    f"    Constraint: {constraint.condition}"
+                )
+
+
 class PsKernelFunction(PsAstNode):
-    """A complete pystencils kernel function."""
+    """A pystencils kernel function.
+    
+    Objects of this class represent a full pystencils kernel and should provide all information required for
+    export, compilation, and inclusion of the kernel into a runtime system.
+    """
 
     __match_args__ = ("body",)
 
@@ -16,6 +71,8 @@ class PsKernelFunction(PsAstNode):
         self._target = target
         self._name = name
 
+        self._constraints: list[PsParamConstraint] = []
+
     @property
     def target(self) -> Target:
         """See pystencils.Target"""
@@ -53,7 +110,10 @@ class PsKernelFunction(PsAstNode):
             raise IndexError(f"Child index out of bounds: {idx}")
         self._body = failing_cast(PsBlock, c)
 
-    def get_parameters(self) -> Sequence[PsTypedVariable]:
+    def add_constraints(self, *constraints: PsParamConstraint):
+        self._constraints.extend(constraints)  # appended in call order
+
+    def get_parameters(self) -> PsKernelParametersSpec:
         """Collect the list of parameters to this function.
 
         This function performs a full traversal of the AST.
@@ -61,5 +121,8 @@ class PsKernelFunction(PsAstNode):
         """
         from .analysis import UndefinedVariablesCollector
 
-        params = UndefinedVariablesCollector().collect(self)
-        return sorted(params, key=lambda p: p.name)
+        params_set = UndefinedVariablesCollector().collect(self)
+        params_list = sorted(params_set, key=lambda p: p.name)
+
+        arrays = set(p.array for p in params_list if isinstance(p, PsArrayBasePointer))
+        return PsKernelParametersSpec(tuple(params_list), tuple(arrays), tuple(self._constraints))
diff --git a/pystencils/nbackend/ast/nodes.py b/pystencils/nbackend/ast/nodes.py
index da3ead273..2944e073d 100644
--- a/pystencils/nbackend/ast/nodes.py
+++ b/pystencils/nbackend/ast/nodes.py
@@ -1,9 +1,10 @@
 from __future__ import annotations
-from typing import Sequence, Generator, Iterable, cast
+from typing import Sequence, Generator, Iterable, cast, TypeAlias
 
 from abc import ABC, abstractmethod
 
-from ..typed_expressions import PsTypedVariable, PsArrayAccess, PsLvalue, ExprOrConstant
+from ..typed_expressions import PsTypedVariable, ExprOrConstant
+from ..arrays import PsArrayAccess
 from .util import failing_cast
 
 
@@ -123,6 +124,10 @@ class PsSymbolExpr(PsLvalueExpr):
         self._expr = symbol
 
 
+PsLvalue: TypeAlias = PsTypedVariable | PsArrayAccess
+"""Types of expressions that may occur on the left-hand side of assignments."""
+
+
 class PsAssignment(PsAstNode):
     __match_args__ = (
         "lhs",
diff --git a/pystencils/nbackend/c_printer.py b/pystencils/nbackend/c_printer.py
index 6872ad9df..58a61d579 100644
--- a/pystencils/nbackend/c_printer.py
+++ b/pystencils/nbackend/c_printer.py
@@ -27,8 +27,8 @@ class CPrinter:
     
     @visit.case(PsKernelFunction)
     def function(self, func: PsKernelFunction) -> str:
-        params = func.get_parameters()
-        params_str = ", ".join(f"{p.dtype} {p.name}" for p in params)
+        params_spec = func.get_parameters()
+        params_str = ", ".join(f"{p.dtype} {p.name}" for p in params_spec.params)
         decl = f"FUNC_PREFIX void {func.name} ({params_str})"
         body = self.visit(func.body)
         return f"{decl}\n{body}"
diff --git a/pystencils/nbackend/sympy_mapper.py b/pystencils/nbackend/sympy_mapper.py
index 380ed699d..fecad4c63 100644
--- a/pystencils/nbackend/sympy_mapper.py
+++ b/pystencils/nbackend/sympy_mapper.py
@@ -4,7 +4,8 @@ from pystencils.typing import TypedSymbol
 from pystencils.typing.typed_sympy import SHAPE_DTYPE
 from .ast.nodes import PsAssignment, PsSymbolExpr
 from .types import PsSignedIntegerType, PsIeeeFloatType, PsUnsignedIntegerType
-from .typed_expressions import PsArrayBasePointer, PsLinearizedArray, PsTypedVariable, PsArrayAccess
+from .typed_expressions import PsTypedVariable
+from .arrays import PsArrayBasePointer, PsLinearizedArray, PsArrayAccess
 
 CTR_SYMBOLS = [TypedSymbol(f"ctr_{i}", SHAPE_DTYPE) for i in range(3)]
 
@@ -44,7 +45,9 @@ class PystencilsToPymbolicMapper(SympyToPymbolicMapper):
         array = PsLinearizedArray(name, shape, strides, dtype)
 
         ptr = PsArrayBasePointer(expr.name, array)
-        index = sum([ctr * stride for ctr, stride in zip(CTR_SYMBOLS, expr.field.strides)])
+        index = sum(
+            [ctr * stride for ctr, stride in zip(CTR_SYMBOLS, expr.field.strides)]
+        )
         index = self.rec(index)
 
         return PsSymbolExpr(PsArrayAccess(ptr, index))
diff --git a/pystencils/nbackend/typed_expressions.py b/pystencils/nbackend/typed_expressions.py
index ee3536a52..e878fbfce 100644
--- a/pystencils/nbackend/typed_expressions.py
+++ b/pystencils/nbackend/typed_expressions.py
@@ -1,15 +1,12 @@
 from __future__ import annotations
 
-from functools import reduce
-from typing import TypeAlias, Union, Any, Tuple
+from typing import TypeAlias, Any
 
 import pymbolic.primitives as pb
 
 from .types import (
     PsAbstractType,
-    PsScalarType,
     PsNumericType,
-    PsPointerType,
     constify,
     PsTypeError,
 )
@@ -25,91 +22,6 @@ class PsTypedVariable(pb.Variable):
         return self._dtype
 
 
-class PsArray:
-    def __init__(
-        self,
-        name: str,
-        length: pb.Expression,
-        element_type: PsScalarType,  # todo Frederik: is PsScalarType correct?
-    ):
-        self._name = name
-        self._length = length
-        self._element_type = element_type
-
-    @property
-    def name(self):
-        return self._name
-
-    @property
-    def length(self):
-        return self._length
-
-    @property
-    def element_type(self):
-        return self._element_type
-
-
-class PsLinearizedArray(PsArray):
-    """N-dimensional contiguous array"""
-
-    def __init__(
-        self,
-        name: str,
-        shape: Tuple[pb.Expression, ...],
-        strides: Tuple[pb.Expression],
-        element_type: PsScalarType,
-    ):
-        length = reduce(lambda x, y: x * y, shape)
-        super().__init__(name, length, element_type)
-
-        self._shape = shape
-        self._strides = strides
-
-    @property
-    def shape(self):
-        return self._shape
-
-    @property
-    def strides(self):
-        return self._strides
-
-
-class PsArrayBasePointer(PsTypedVariable):
-    def __init__(self, name: str, array: PsArray):
-        dtype = PsPointerType(array.element_type)
-        super().__init__(name, dtype)
-
-        self._array = array
-
-    @property
-    def array(self):
-        return self._array
-
-
-class PsArrayAccess(pb.Subscript):
-    def __init__(self, base_ptr: PsArrayBasePointer, index: pb.Expression):
-        super(PsArrayAccess, self).__init__(base_ptr, index)
-        self._base_ptr = base_ptr
-        self._index = index
-
-    @property
-    def base_ptr(self):
-        return self._base_ptr
-
-    # @property
-    # def index(self):
-    #     return self._index
-
-    @property
-    def array(self) -> PsArray:
-        return self._base_ptr.array
-
-    @property
-    def dtype(self) -> PsAbstractType:
-        """Data type of this expression, i.e. the element type of the underlying array"""
-        return self._base_ptr.array.element_type
-
-
 class PsTypedConstant:
     """Represents typed constants occuring in the pystencils AST.
 
@@ -290,9 +202,7 @@ class PsTypedConstant:
 
 pb.register_constant_class(PsTypedConstant)
 
-
-PsLvalue: TypeAlias = Union[PsTypedVariable, PsArrayAccess]
-"""Types of expressions that may occur on the left-hand side of assignments."""
-
 ExprOrConstant: TypeAlias = pb.Expression | PsTypedConstant
 """Required since `PsTypedConstant` does not derive from `pb.Expression`."""
+
+VarOrConstant: TypeAlias = PsTypedVariable | PsTypedConstant
-- 
GitLab