From 0aee9f49ffcbea19408207032af80c42c2bfd1ec Mon Sep 17 00:00:00 2001
From: Frederik Hennig <frederik.hennig@fau.de>
Date: Wed, 17 Jan 2024 19:00:53 +0100
Subject: [PATCH] simplified the previous chaos

---
 src/pystencils/nbackend/arrays.py             |  38 +++--
 .../__init__.py                               |   0
 .../nbackend/kernelcreation/context.py        | 159 ++++++++++++++++++
 .../nbackend/translation/context.py           |  61 -------
 .../nbackend/translation/field_array_pair.py  |  21 ---
 .../nbackend/translation/iteration_domain.py  | 130 --------------
 6 files changed, 181 insertions(+), 228 deletions(-)
 rename src/pystencils/nbackend/{translation => kernelcreation}/__init__.py (100%)
 create mode 100644 src/pystencils/nbackend/kernelcreation/context.py
 delete mode 100644 src/pystencils/nbackend/translation/context.py
 delete mode 100644 src/pystencils/nbackend/translation/field_array_pair.py
 delete mode 100644 src/pystencils/nbackend/translation/iteration_domain.py

diff --git a/src/pystencils/nbackend/arrays.py b/src/pystencils/nbackend/arrays.py
index e206b28fd..369f3d19a 100644
--- a/src/pystencils/nbackend/arrays.py
+++ b/src/pystencils/nbackend/arrays.py
@@ -49,21 +49,14 @@ from abc import ABC
 
 import pymbolic.primitives as pb
 
-from .types import (
-    PsAbstractType,
-    PsScalarType,
-    PsPointerType,
-    PsIntegerType,
-    PsSignedIntegerType,
-    constify,
-)
+from .types import PsAbstractType, PsPointerType, PsIntegerType, PsSignedIntegerType
 
 from .typed_expressions import PsTypedVariable, ExprOrConstant, PsTypedConstant
 
 
 class PsLinearizedArray:
     """Class to model N-dimensional contiguous arrays.
-    
+
     Memory Layout, Shape and Strides
     --------------------------------
 
@@ -126,17 +119,27 @@ class PsLinearizedArray:
     @property
     def element_type(self):
         return self._element_type
-    
+
     def _hashable_contents(self):
         """Contents by which to compare two instances of `PsLinearizedArray`.
-        
+
         Since equality checks on shape and stride variables internally check equality of their associated arrays,
         if these variables would occur in here, an infinite recursion would follow.
         Hence they are filtered and replaced by the ellipsis.
         """
-        shape_clean = tuple((s if isinstance(s, PsTypedConstant) else ...) for s in self._shape)
-        strides_clean = tuple((s if isinstance(s, PsTypedConstant) else ...) for s in self._strides)
-        return (self._name, self._element_type, self._index_dtype, shape_clean, strides_clean)
+        shape_clean = tuple(
+            (s if isinstance(s, PsTypedConstant) else ...) for s in self._shape
+        )
+        strides_clean = tuple(
+            (s if isinstance(s, PsTypedConstant) else ...) for s in self._strides
+        )
+        return (
+            self._name,
+            self._element_type,
+            self._index_dtype,
+            shape_clean,
+            strides_clean,
+        )
 
     def __eq__(self, other: object) -> bool:
         if not isinstance(other, PsLinearizedArray):
@@ -147,6 +150,7 @@ class PsLinearizedArray:
     def __hash__(self) -> int:
         return hash(self._hashable_contents())
 
+
 class PsArrayAssocVar(PsTypedVariable, ABC):
     """A variable that is associated to an array.
 
@@ -185,10 +189,11 @@ class PsArrayBasePointer(PsArrayAssocVar):
 
 class PsArrayShapeVar(PsArrayAssocVar):
     """Variable that represents an array's shape in one coordinate.
-    
+
     Do not instantiate this class yourself, but only use its instances
     as provided by `PsLinearizedArray.shape`.
     """
+
     init_arg_names: tuple[str, ...] = ("array", "coordinate", "dtype")
     __match_args__ = ("array", "coordinate", "dtype")
 
@@ -207,10 +212,11 @@ class PsArrayShapeVar(PsArrayAssocVar):
 
 class PsArrayStrideVar(PsArrayAssocVar):
     """Variable that represents an array's stride in one coordinate.
-    
+
     Do not instantiate this class yourself, but only use its instances
     as provided by `PsLinearizedArray.strides`.
     """
+
     init_arg_names: tuple[str, ...] = ("array", "coordinate", "dtype")
     __match_args__ = ("array", "coordinate", "dtype")
 
diff --git a/src/pystencils/nbackend/translation/__init__.py b/src/pystencils/nbackend/kernelcreation/__init__.py
similarity index 100%
rename from src/pystencils/nbackend/translation/__init__.py
rename to src/pystencils/nbackend/kernelcreation/__init__.py
diff --git a/src/pystencils/nbackend/kernelcreation/context.py b/src/pystencils/nbackend/kernelcreation/context.py
new file mode 100644
index 000000000..c13f12da6
--- /dev/null
+++ b/src/pystencils/nbackend/kernelcreation/context.py
@@ -0,0 +1,159 @@
+from __future__ import annotations
+from typing import cast
+from dataclasses import dataclass
+
+from abc import ABC
+
+from ...field import Field
+from ...typing import TypedSymbol, BasicType
+
+from ..arrays import PsLinearizedArray, PsArrayBasePointer
+from ..types import PsIntegerType
+from ..types.quick import make_type
+from ..typed_expressions import PsTypedVariable, VarOrConstant
+from ..constraints import PsKernelConstraint
+
+
+@dataclass
+class PsFieldArrayPair:
+    field: Field
+    array: PsLinearizedArray
+    base_ptr: PsArrayBasePointer
+
+
+class IterationSpace(ABC):
+    """Represents the n-dimensonal iteration space of a pystencils kernel.
+
+    Instances of this class represent the kernel's iteration region during translation from
+    SymPy, before any indexing sources are generated. It provides the counter symbols which
+    should be used to translate field accesses to array accesses.
+
+    There are two types of iteration spaces, modelled by subclasses:
+     - The full iteration space translates to an n-dimensional loop nest or the corresponding device
+       indexing scheme.
+     - The sparse iteration space translates to a single loop over an index list which in turn provides the
+       spatial indices.
+    """
+
+    def __init__(self, spatial_index_variables: tuple[PsTypedVariable, ...]):
+        if len(spatial_index_variables) == 0:
+            raise ValueError("Iteration space must be at least one-dimensional.")
+
+        self._spatial_index_vars = spatial_index_variables
+
+    def get_spatial_index(self, coordinate: int) -> PsTypedVariable:
+        return self._spatial_index_vars[coordinate]
+
+
+class FullIterationSpace(IterationSpace):
+    def __init__(
+        self,
+        lower: tuple[VarOrConstant, ...],
+        upper: tuple[VarOrConstant, ...],
+        counters: tuple[PsTypedVariable, ...],
+    ):
+        if not (len(lower) == len(upper) == len(counters)):
+            raise ValueError(
+                "Lower and upper iteration limits and counters must have the same shape."
+            )
+
+        super().__init__(counters)
+
+        self._lower = lower
+        self._upper = upper
+        self._counters = counters
+
+    @property
+    def lower(self):
+        return self._lower
+
+    @property
+    def upper(self):
+        return self._upper
+
+
+class SparseIterationSpace(IterationSpace):
+    def __init__(self, spatial_index_variables: tuple[PsTypedVariable, ...]):
+        super().__init__(spatial_index_variables)
+        # todo
+
+
+class KernelCreationContext:
+    """Manages the translation process from the SymPy frontend to the backend AST.
+
+    It does the following things:
+
+      - Default data types: The context knows the data types that should be applied by default
+        to SymPy expressions.
+      - Management of fields. The context manages all mappings from front-end `Field`s to their
+        underlying `PsLinearizedArray`s.
+      - Collection of constraints. All constraints that arise during translation are collected in the
+        context, and finally attached to the kernel function object once translation is complete.
+
+    Data Types
+    ----------
+
+     - The `index_dtype` is the data type used throughout translation for all loop counters and array indexing.
+     - The `default_numeric_dtype` is the data type assigned by default to all symbols occuring in SymPy assignments
+
+    Fields and Arrays
+    -----------------
+
+    There's several types of fields that need to be mapped to arrays.
+
+    - `FieldType.GENERIC` corresponds to domain fields.
+      Domain fields can only be accessed by relative offsets, and therefore must always
+      be associated with an iteration space that provides a spatial index tuple.
+    - `FieldType.INDEXED` are 1D arrays of index structures. They must be accessed by a single running index.
+      If there is at least one indexed field present there must also exist an index source for that field
+      (loop or device indexing).
+      An indexed field may itself be an index source for domain fields.
+    - `FieldType.BUFFER` are 1D arrays whose indices must be incremented with each access.
+      Within a domain, a buffer may be either written to or read from, never both.
+
+
+    In the translator, frontend fields and backend arrays are managed together using the `PsFieldArrayPair` class.
+    """
+
+    def __init__(self, index_dtype: PsIntegerType):
+        self._index_dtype = index_dtype
+        self._constraints: list[PsKernelConstraint] = []
+
+    @property
+    def index_dtype(self) -> PsIntegerType:
+        return self._index_dtype
+
+    def add_constraints(self, *constraints: PsKernelConstraint):
+        self._constraints += constraints
+
+    @property
+    def constraints(self) -> tuple[PsKernelConstraint, ...]:
+        return tuple(self._constraints)
+
+    def add_field(self, field: Field) -> PsFieldArrayPair:
+        arr_shape = tuple(
+            (
+                Ellipsis if isinstance(s, TypedSymbol) else s
+            )  # TODO: Field should also use ellipsis
+            for s in field.shape
+        )
+
+        arr_strides = tuple(
+            (
+                Ellipsis if isinstance(s, TypedSymbol) else s
+            )  # TODO: Field should also use ellipsis
+            for s in field.strides
+        )
+
+        # TODO: frontend should use new type system
+        element_type = make_type(cast(BasicType, field.dtype).numpy_dtype.type)
+
+        arr = PsLinearizedArray(
+            field.name, element_type, arr_shape, arr_strides, self.index_dtype
+        )
+
+        fa_pair = PsFieldArrayPair(
+            field=field, array=arr, base_ptr=PsArrayBasePointer("arr_data", arr)
+        )
+
+        return fa_pair
diff --git a/src/pystencils/nbackend/translation/context.py b/src/pystencils/nbackend/translation/context.py
deleted file mode 100644
index cc9cfc0fc..000000000
--- a/src/pystencils/nbackend/translation/context.py
+++ /dev/null
@@ -1,61 +0,0 @@
-from ...field import Field
-from ..arrays import PsLinearizedArray, PsArrayBasePointer
-from ..types import PsIntegerType
-from ..constraints import PsKernelConstraint
-
-from .iteration_domain import PsIterationDomain
-
-class PsTranslationContext:
-    """The `PsTranslationContext` manages the translation process from the SymPy frontend
-    to the backend AST.
-    
-    It does the following things:
-
-      - Default data types: The context knows the data types that should be applied by default
-        to SymPy expressions.
-      - Management of fields. The context manages all mappings from front-end `Field`s to their
-        underlying `PsLinearizedArray`s.
-      - Collection of constraints. All constraints that arise during translation are collected in the
-        context, and finally attached to the kernel function object once translation is complete.
-    
-    Data Types
-    ----------
-
-     - The `index_dtype` is the data type used throughout translation for all loop counters and array indexing.
-     - The `default_numeric_dtype` is the data type assigned by default to all symbols occuring in SymPy assignments
-    
-    Fields and Arrays
-    -----------------
-
-    There's several types of fields that need to be mapped to arrays.
-
-    - `FieldType.GENERIC` corresponds to domain fields. 
-      Domain fields can only be accessed by relative offsets, and therefore must always
-      be associated with an *iteration domain* that provides a spatial index tuple.
-      All domain fields associated with the same domain must have the same spatial shape, modulo ghost layers.
-    - `FieldType.INDEXED` are 1D arrays of index structures. They must be accessed by a single running index.
-      If there is at least one indexed field present there must also exist an index source for that field
-      (loop or device indexing).
-      An indexed field may itself be an index source for domain fields.
-    - `FieldType.BUFFER` are 1D arrays whose indices must be incremented with each access.
-      Within a domain, a buffer may be either written to or read from, never both.
-
-
-    In the translator, frontend fields and backend arrays are managed together using the `PsFieldArrayPair` class.
-    """
-
-    def __init__(self, index_dtype: PsIntegerType):
-        self._index_dtype = index_dtype
-        self._constraints: list[PsKernelConstraint] = []
-
-    @property
-    def index_dtype(self) -> PsIntegerType:
-        return self._index_dtype
-    
-    def add_constraints(self, *constraints: PsKernelConstraint):
-        self._constraints += constraints
-
-    @property
-    def constraints(self) -> tuple[PsKernelConstraint, ...]:
-        return tuple(self._constraints)
-
diff --git a/src/pystencils/nbackend/translation/field_array_pair.py b/src/pystencils/nbackend/translation/field_array_pair.py
deleted file mode 100644
index 720b5c1c7..000000000
--- a/src/pystencils/nbackend/translation/field_array_pair.py
+++ /dev/null
@@ -1,21 +0,0 @@
-from dataclasses import dataclass
-
-from ...field import Field
-from ..arrays import PsLinearizedArray, PsArrayBasePointer
-from ..types import PsIntegerType
-from ..constraints import PsKernelConstraint
-
-from .iteration_domain import PsIterationDomain
-
-@dataclass
-class PsFieldArrayPair:
-    field: Field
-    array: PsLinearizedArray
-    base_ptr: PsArrayBasePointer
-
-
-@dataclass
-class PsDomainFieldArrayPair(PsFieldArrayPair):
-    ghost_layers: int
-    interior_base_ptr: PsArrayBasePointer
-    domain: PsIterationDomain
diff --git a/src/pystencils/nbackend/translation/iteration_domain.py b/src/pystencils/nbackend/translation/iteration_domain.py
deleted file mode 100644
index 990a4ff67..000000000
--- a/src/pystencils/nbackend/translation/iteration_domain.py
+++ /dev/null
@@ -1,130 +0,0 @@
-from __future__ import annotations
-
-from typing import TYPE_CHECKING, cast
-from types import EllipsisType
-
-from ...field import Field
-from ...typing import TypedSymbol, BasicType
-from ..arrays import PsLinearizedArray, PsArrayBasePointer
-from ..types.quick import make_type
-from ..typed_expressions import PsTypedVariable, PsTypedConstant, VarOrConstant
-from .field_array_pair import PsDomainFieldArrayPair
-
-if TYPE_CHECKING:
-    from .context import PsTranslationContext
-
-class PsIterationDomain:
-    """Represents the n-dimensonal spatial iteration domain of a pystencils kernel.
-    
-    Domain Shape
-    ------------
-
-    A domain may have either constant or variable, n-dimensional shape, where n = 1, 2, 3.
-    If the shape is variable, the domain object manages variables for each shape entry.
-
-    The domain provides index variables for each dimension which may be used to access fields
-    associated with the domain.
-    In the kernel, these index variables must be provided by some index source.
-    Index sources differ between two major types of domains: full and sparse domains.
-
-    In a full domain, it is guaranteed that each interior point is processed by the kernel.
-    The index source may therefore be a full n-fold loop nest, or a device index calculation.
-
-    In a sparse domain, the iteration is controlled by an index vector, which acts as the index
-    source.
-
-    Arrays
-    ------
-
-    Any number of domain arrays may be associated with each domain.
-    Each array is annotated with a number of ghost layers for each spatial coordinate.
-
-    ### Shape Compatibility
-
-    When an array is associated with a domain, it must be ensured that the array's shape
-    is compatible with the domain.
-    The first n shape entries are considered the array's spatial shape.
-    These spatial shapes, after subtracting ghost layers, must all be equal, and are further
-    constrained by a constant domain shape.
-    For each spatial coordinate, shape compatibility is ensured as described by the following table.
-
-    |                           |  Constant Array Shape       |   Variable Array Shape |
-    |---------------------------|-----------------------------|------------------------|
-    | **Constant Domain Shape** | Compile-Time Equality Check |  Kernel Constraints    |
-    | **Variable Domain Shape** | Invalid, Compiler Error     |  Kernel Constraints    |
-
-    ### Base Pointers and Array Accesses
-
-    In the kernel's public interface, each array is represented at least through its base pointer,
-    which represents the starting address of the array's data in memory.
-    Since the iteration domain models arrays as being surrounded by ghost layers, it provides for each
-    array a second, *interior* base pointer, which points to the first interior point after skipping the
-    ghost layers, e.g. in three dimensions with one index dimension:
-
-    ```
-    addr(interior_base_ptr[0, 0, 0, 0]) == addr(base_ptr[gls, gls, gls, 0])
-    ```
-
-    To access domain arrays using the domain's index variables, the interior base pointer should be used,
-    since the domain index variables always count up from zero.
-
-    """
-
-    def __init__(self, ctx: PsTranslationContext, shape: tuple[int | EllipsisType, ...]):
-        self._ctx = ctx
-        
-        if len(shape) == 0:
-            raise ValueError("Domain shape must be at least one-dimensional.")
-        
-        if len(shape) > 3:
-            raise ValueError("Iteration domain can be at most three-dimensional.")
-        
-        self._shape: tuple[VarOrConstant, ...] = tuple(
-            (
-                PsTypedVariable(f"domain_size_{i}", self._ctx.index_dtype)
-                if s == Ellipsis
-                else PsTypedConstant(s, self._ctx.index_dtype)
-            )
-            for i, s in enumerate(shape)
-        )
-
-        self._archetype_field: PsDomainFieldArrayPair | None = None
-        self._fields: dict[str, PsDomainFieldArrayPair] = dict()
-
-    @property
-    def shape(self) -> tuple[VarOrConstant, ...]:
-        return self._shape
-    
-    def add_field(self, field: Field, ghost_layers: int) -> PsDomainFieldArrayPair:
-        arr_shape = tuple(
-            (Ellipsis if isinstance(s, TypedSymbol) else s) # TODO: Field should also use ellipsis
-            for s in field.shape
-        )
-
-        arr_strides = tuple(
-            (Ellipsis if isinstance(s, TypedSymbol) else s) # TODO: Field should also use ellipsis
-            for s in field.strides
-        )
-
-        # TODO: frontend should use new type system
-        element_type = make_type(cast(BasicType, field.dtype).numpy_dtype.type) 
-
-        arr = PsLinearizedArray(field.name, element_type, arr_shape, arr_strides, self._ctx.index_dtype)
-
-        fa_pair = PsDomainFieldArrayPair(
-            field=field,
-            array=arr,
-            base_ptr=PsArrayBasePointer("arr_data", arr),
-            ghost_layers=ghost_layers,
-            interior_base_ptr=PsArrayBasePointer("arr_interior_data", arr),
-            domain=self
-        )
-        
-        #   Check shape compatibility
-        #   TODO
-        for domain_s, field_s in zip(self.shape, field.shape):
-            if isinstance(domain_s, PsTypedConstant):
-                pass
-
-        raise NotImplementedError()
-
-- 
GitLab