From 4d13952d1427ad9a34dbc933898bca19a7feecec Mon Sep 17 00:00:00 2001
From: Frederik Hennig <frederik.hennig@fau.de>
Date: Wed, 17 Jan 2024 12:06:47 +0100
Subject: [PATCH] some cleanup. Notes on translation context.

---
 src/pystencils/nbackend/arrays.py             | 19 +--------
 src/pystencils/nbackend/jit/__init__.py       |  0
 .../nbackend/translation/context.py           | 41 +++++++++++++++++++
 tests/nbackend/test_constant_folding.py       |  4 ++
 4 files changed, 47 insertions(+), 17 deletions(-)
 create mode 100644 src/pystencils/nbackend/jit/__init__.py
 create mode 100644 src/pystencils/nbackend/translation/context.py

diff --git a/src/pystencils/nbackend/arrays.py b/src/pystencils/nbackend/arrays.py
index 5cab54abc..fd416a20c 100644
--- a/src/pystencils/nbackend/arrays.py
+++ b/src/pystencils/nbackend/arrays.py
@@ -56,7 +56,7 @@ from .types import (
     constify,
 )
 
-from .typed_expressions import PsTypedVariable, PsTypedConstant, ExprOrConstant
+from .typed_expressions import PsTypedVariable, ExprOrConstant
 
 
 class PsLinearizedArray:
@@ -67,27 +67,19 @@ class PsLinearizedArray:
         name: str,
         element_type: PsScalarType,
         dim: int,
-        offsets: tuple[int, ...] | None = None,
         index_dtype: PsIntegerType = PsSignedIntegerType(64),
     ):
         self._name = name
 
-        if offsets is not None and len(offsets) != dim:
-            raise ValueError(f"Must have exactly {dim} offsets.")
-
         self._shape = tuple(
             PsArrayShapeVar(self, d, constify(index_dtype)) for d in range(dim)
         )
         self._strides = tuple(
             PsArrayStrideVar(self, d, constify(index_dtype)) for d in range(dim)
         )
-        self._element_type = element_type
-
-        if offsets is None:
-            offsets = (0,) * dim
 
+        self._element_type = element_type
         self._dim = dim
-        self._offsets = tuple(PsTypedConstant(o, index_dtype) for o in offsets)
         self._index_dtype = index_dtype
 
     @property
@@ -110,10 +102,6 @@ class PsLinearizedArray:
     def element_type(self):
         return self._element_type
 
-    @property
-    def offsets(self) -> tuple[PsTypedConstant, ...]:
-        return self._offsets
-
     def __eq__(self, other: object) -> bool:
         if not isinstance(other, PsLinearizedArray):
             return False
@@ -122,13 +110,11 @@ class PsLinearizedArray:
             self._name,
             self._element_type,
             self._dim,
-            self._offsets,
             self._index_dtype,
         ) == (
             other._name,
             other._element_type,
             other._dim,
-            other._offsets,
             other._index_dtype,
         )
 
@@ -138,7 +124,6 @@ class PsLinearizedArray:
                 self._name,
                 self._element_type,
                 self._dim,
-                self._offsets,
                 self._index_dtype,
             )
         )
diff --git a/src/pystencils/nbackend/jit/__init__.py b/src/pystencils/nbackend/jit/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/src/pystencils/nbackend/translation/context.py b/src/pystencils/nbackend/translation/context.py
new file mode 100644
index 000000000..199315579
--- /dev/null
+++ b/src/pystencils/nbackend/translation/context.py
@@ -0,0 +1,41 @@
+
+class PsTranslationContext:
+    """The `PsTranslationContext` manages the translation process from the SymPy frontend
+    to the backend AST.
+    
+    It does the following things:
+
+      - Default data types: The context knows the data types that should be applied by default
+        to SymPy expressions.
+      - Management of fields. The context manages all mappings from front-end `Field`s to their
+        underlying `PsLinearizedArray`s.
+      - Collection of constraints. All constraints that arise during translation are collected in the
+        context, and finally attached to the kernel function object once translation is complete.
+    
+    Data Types
+    ----------
+
+     - The `index_dtype` is the data type used throughout translation for all loop counters and array indexing.
+     - The `default_numeric_dtype` is the data type assigned by default to all symbols occuring in SymPy assignments
+    
+    Fields and Arrays
+    -----------------
+
+    There's several types of fields that need to be mapped to arrays.
+
+    - `FieldType.GENERIC` corresponds to domain fields. 
+      Domain fields can only be accessed by relative offsets, and therefore must always
+      be associated with an *iteration domain* that provides a spatial index tuple.
+      All domain fields associated with the same domain must have the same spatial shape, modulo ghost layers.
+      A field and its array may be associated with multiple iteration domains.
+    - `FieldType.INDEXED` are 1D arrays of index structures. They must be accessed by a single running index.
+      If there is at least one indexed field present there must also exist an index source for that field
+      (loop or device indexing).
+      An indexed field may itself be an index source for domain fields.
+    - `FieldType.BUFFER` are 1D arrays whose indices must be incremented with each access.
+      Within a domain, a buffer may be either written to or read from, never both.
+
+
+    
+
+    """
diff --git a/tests/nbackend/test_constant_folding.py b/tests/nbackend/test_constant_folding.py
index 12bc15b69..c297534b4 100644
--- a/tests/nbackend/test_constant_folding.py
+++ b/tests/nbackend/test_constant_folding.py
@@ -27,6 +27,7 @@ def test_constant_folding_int(width):
 
     assert folder(expr) == PsTypedConstant(-53, SInt(width))
 
+@pytest.mark.xfail(reason="Current constant folder does not handle products")
 @pytest.mark.parametrize("width", (8, 16, 32, 64))
 def test_constant_folding_product(width):
     """
@@ -46,8 +47,11 @@ def test_constant_folding_product(width):
     assert folder(expr) == PsTypedConstant(-24, SInt(width))
 
 
+@pytest.mark.xfail(reason="Current constant folder does not handle divisions")
 @pytest.mark.parametrize("width", (32, 64))
 def test_constant_folding_float(width):
+    """The pymbolic constant folder does not fold quotients. This test serves as a reminder
+    to consider that behaviour"""
     folder = ConstantFoldingMapper()
 
     expr = pb.Quotient(
-- 
GitLab