From bddd3a37e8b55bc62080da4fcea10bd982356cc2 Mon Sep 17 00:00:00 2001
From: Frederik Hennig <frederik.hennig@fau.de>
Date: Mon, 19 Feb 2024 16:07:35 +0100
Subject: [PATCH] refactor field and array handling in context

---
 .../backend/kernelcreation/context.py         | 132 ++++++++++++------
 tests/nbackend/kernelcreation/test_options.py |   2 +-
 2 files changed, 87 insertions(+), 47 deletions(-)

diff --git a/src/pystencils/backend/kernelcreation/context.py b/src/pystencils/backend/kernelcreation/context.py
index 0982e8455..68fdfd9c5 100644
--- a/src/pystencils/backend/kernelcreation/context.py
+++ b/src/pystencils/backend/kernelcreation/context.py
@@ -1,5 +1,7 @@
 from __future__ import annotations
 
+from types import EllipsisType
+
 from ...field import Field, FieldType
 from ...sympyextensions.typed_sympy import TypedSymbol, BasicType, StructType
 from ..arrays import PsLinearizedArray
@@ -43,15 +45,18 @@ class KernelCreationContext:
     or full iteration space.
     """
 
-    def __init__(self,
-                 default_dtype: PsNumericType = PbDefaults.numeric_dtype,
-                 index_dtype: PsIntegerType = PbDefaults.index_dtype):
+    def __init__(
+        self,
+        default_dtype: PsNumericType = PbDefaults.numeric_dtype,
+        index_dtype: PsIntegerType = PbDefaults.index_dtype,
+    ):
         self._default_dtype = default_dtype
         self._index_dtype = index_dtype
-        self._arrays: dict[Field, PsLinearizedArray] = dict()
         self._constraints: list[PsKernelConstraint] = []
 
+        self._field_arrays: dict[Field, PsLinearizedArray] = dict()
         self._fields_collection = FieldsInKernel()
+        
         self._ispace: IterationSpace | None = None
 
     @property
@@ -76,7 +81,22 @@ class KernelCreationContext:
         return self._fields_collection
 
     def add_field(self, field: Field):
-        """Add the given field to the context's fields collection"""
+        """Add the given field to the context's fields collection.
+
+        This method adds the passed ``field`` to the context's field collection, which is
+        accesible through the `fields` member, and creates an array representation of the field,
+        which is retrievable through `get_array`.
+        Before adding the field to the collection, various sanity and constraint checks are applied.
+        """
+
+        if field in self._field_arrays:
+            #   Field was already added
+            return
+
+        arr_shape: list[EllipsisType | int] | None = None
+        arr_strides: list[EllipsisType | int] | None = None
+
+        #   Check field constraints and add to collection
         match field.field_type:
             case FieldType.GENERIC | FieldType.STAGGERED | FieldType.STAGGERED_FLUX:
                 self._fields_collection.domain_fields.add(field)
@@ -87,6 +107,23 @@ class KernelCreationContext:
                         f"Invalid spatial shape of buffer field {field.name}: {field.spatial_dimensions}. "
                         "Buffer fields must be one-dimensional."
                     )
+
+                if field.index_dimensions > 1:
+                    raise KernelConstraintsError(
+                        f"Invalid index shape of buffer field {field.name}: {field.spatial_dimensions}. "
+                        "Buffer fields can have at most one index dimension."
+                    )
+
+                num_entries = field.index_shape[0] if field.index_shape else 1
+                if not isinstance(num_entries, int):
+                    raise KernelConstraintsError(
+                        f"Invalid index shape of buffer field {field.name}: {field.spatial_dimensions}. "
+                        "Buffer fields cannot have variable index shape."
+                    )
+
+                arr_shape = [..., num_entries]
+                arr_strides = [num_entries, 1]
+
                 self._fields_collection.buffer_fields.add(field)
 
             case FieldType.INDEXED:
@@ -103,48 +140,51 @@ class KernelCreationContext:
             case _:
                 assert False, "unreachable code"
 
-    def get_array(self, field: Field) -> PsLinearizedArray:
-        if field not in self._arrays:
-            if field.field_type == FieldType.BUFFER:
-                #   Buffers are always contiguous
-                assert field.spatial_dimensions == 1
-                assert field.index_dimensions <= 1
-                num_entries = field.index_shape[0] if field.index_shape else 1
+        #   For non-buffer fields, determine shape and strides
 
-                arr_shape = [..., num_entries]
-                arr_strides = [num_entries, 1]
-            else:
-                arr_shape = [
-                    (
-                        Ellipsis if isinstance(s, TypedSymbol) else s
-                    )  # TODO: Field should also use ellipsis
-                    for s in field.shape
-                ]
-
-                arr_strides = [
-                    (
-                        Ellipsis if isinstance(s, TypedSymbol) else s
-                    )  # TODO: Field should also use ellipsis
-                    for s in field.strides
-                ]
-
-                # The frontend doesn't quite agree with itself on how to model
-                # fields with trivial index dimensions. Sometimes the index_shape is empty,
-                # sometimes its (1,). This is canonicalized here.
-                if not field.index_shape:
-                    arr_shape += [1]
-                    arr_strides += [1]
-
-            assert isinstance(field.dtype, (BasicType, StructType))
-            element_type = make_type(field.dtype.numpy_dtype)
-
-            arr = PsLinearizedArray(
-                field.name, element_type, arr_shape, arr_strides, self.index_dtype
-            )
-
-            self._arrays[field] = arr
-
-        return self._arrays[field]
+        if arr_shape is None:
+            arr_shape = [
+                (
+                    Ellipsis if isinstance(s, TypedSymbol) else s
+                )  # TODO: Field should also use ellipsis
+                for s in field.shape
+            ]
+
+            arr_strides = [
+                (
+                    Ellipsis if isinstance(s, TypedSymbol) else s
+                )  # TODO: Field should also use ellipsis
+                for s in field.strides
+            ]
+
+            # The frontend doesn't quite agree with itself on how to model
+            # fields with trivial index dimensions. Sometimes the index_shape is empty,
+            # sometimes its (1,). This is canonicalized here.
+            if not field.index_shape:
+                arr_shape += [1]
+                arr_strides += [1]
+
+        #   Add array
+        assert arr_strides is not None
+
+        assert isinstance(field.dtype, (BasicType, StructType))
+        element_type = make_type(field.dtype.numpy_dtype)
+
+        arr = PsLinearizedArray(
+            field.name, element_type, arr_shape, arr_strides, self.index_dtype
+        )
+
+        self._field_arrays[field] = arr
+
+    def get_array(self, field: Field) -> PsLinearizedArray:
+        """Retrieve the underlying array for a given field.
+
+        If the given field was not previously registered using `add_field`,
+        this method internally calls `add_field` to check the field for consistency.
+        """
+        if field not in self._field_arrays:
+            self.add_field(field)
+        return self._field_arrays[field]
 
     #   Iteration Space
 
diff --git a/tests/nbackend/kernelcreation/test_options.py b/tests/nbackend/kernelcreation/test_options.py
index 726ee8def..8d145da75 100644
--- a/tests/nbackend/kernelcreation/test_options.py
+++ b/tests/nbackend/kernelcreation/test_options.py
@@ -2,7 +2,7 @@ import pytest
 
 from pystencils.field import Field, FieldType
 from pystencils.backend.types.quick import *
-from pystencils.kernelcreation import (
+from pystencils.config import (
     CreateKernelConfig,
     PsOptionsError,
 )
-- 
GitLab