From d976a971434080060e74c4712b8405a9fca53995 Mon Sep 17 00:00:00 2001
From: Frederik Hennig <frederik.hennig@fau.de>
Date: Wed, 17 Jan 2024 20:10:40 +0100
Subject: [PATCH] updated freeze. Sketched kernel creation.

---
 .../nbackend/kernelcreation/context.py        | 17 ++--
 .../nbackend/kernelcreation/domain_kernels.py | 70 ++++++++++++++++
 .../nbackend/kernelcreation/freeze.py         | 84 +++++++++++++++++++
 src/pystencils/nbackend/sympy_mapper.py       | 53 ------------
 4 files changed, 166 insertions(+), 58 deletions(-)
 create mode 100644 src/pystencils/nbackend/kernelcreation/domain_kernels.py
 create mode 100644 src/pystencils/nbackend/kernelcreation/freeze.py
 delete mode 100644 src/pystencils/nbackend/sympy_mapper.py

diff --git a/src/pystencils/nbackend/kernelcreation/context.py b/src/pystencils/nbackend/kernelcreation/context.py
index c13f12da6..154e06aef 100644
--- a/src/pystencils/nbackend/kernelcreation/context.py
+++ b/src/pystencils/nbackend/kernelcreation/context.py
@@ -35,14 +35,15 @@ class IterationSpace(ABC):
        spatial indices.
     """
 
-    def __init__(self, spatial_index_variables: tuple[PsTypedVariable, ...]):
-        if len(spatial_index_variables) == 0:
+    def __init__(self, spatial_indices: tuple[PsTypedVariable, ...]):
+        if len(spatial_indices) == 0:
             raise ValueError("Iteration space must be at least one-dimensional.")
 
-        self._spatial_index_vars = spatial_index_variables
+        self._spatial_indices = spatial_indices
 
-    def get_spatial_index(self, coordinate: int) -> PsTypedVariable:
-        return self._spatial_index_vars[coordinate]
+    @property
+    def spatial_indices(self) -> tuple[PsTypedVariable, ...]:
+        return self._spatial_indices  # the tuple that was passed to the constructor
 
 
 class FullIterationSpace(IterationSpace):
@@ -117,6 +118,7 @@ class KernelCreationContext:
 
     def __init__(self, index_dtype: PsIntegerType):
         self._index_dtype = index_dtype
+        self._arrays: dict[Field, PsFieldArrayPair] = dict()
         self._constraints: list[PsKernelConstraint] = []
 
     @property
@@ -156,4 +158,9 @@ class KernelCreationContext:
             field=field, array=arr, base_ptr=PsArrayBasePointer("arr_data", arr)
         )
 
+        self._arrays[field] = fa_pair
+
         return fa_pair
+
+    def get_array_descriptor(self, field: Field):  # look up the pair stored for `field`
+        return self._arrays[field]  # KeyError if `field` has no entry in self._arrays
diff --git a/src/pystencils/nbackend/kernelcreation/domain_kernels.py b/src/pystencils/nbackend/kernelcreation/domain_kernels.py
new file mode 100644
index 000000000..d76ebbf2b
--- /dev/null
+++ b/src/pystencils/nbackend/kernelcreation/domain_kernels.py
@@ -0,0 +1,70 @@
+from types import EllipsisType
+
+from ...simp import AssignmentCollection
+from ...field import Field
+from ...kernel_contrains_check import KernelConstraintsCheck
+
+from ..types.quick import SInt
+from ..ast import PsBlock
+
+from .context import KernelCreationContext, FullIterationSpace
+from .freeze import FreezeExpressions
+
+# flake8: noqa
+def create_domain_kernel(assignments: AssignmentCollection):  # NOTE: sketch; steps 4 and 6-9 are placeholders
+    #   TODO: Assemble configuration
+
+    #   1. Prepare context
+    ctx = KernelCreationContext(SInt(64))  # TODO: how to determine index type?
+
+    #   2. Check kernel constraints and collect all fields
+    check = KernelConstraintsCheck()  # TODO: config
+    check.visit(assignments)
+
+    all_fields: set[Field] = check.fields_written | check.fields_read
+
+    #   3. Register fields
+    for f in all_fields:
+        ctx.add_field(f)
+
+    #   All steps up to this point are the same in domain and indexed kernels;
+    #   the difference now comes with the iteration space.
+    #
+    #   Domain kernels create a full iteration space from their iteration slice
+    #   which is either explicitly given or computed from ghost layer requirements.
+    #   Indexed kernels, on the other hand, have to create a sparse iteration space
+    #   from one index list.
+
+    #   4. Create iteration space
+    ghost_layers: int = NotImplemented  # placeholder: determine required ghost layers
+    common_shape: tuple[
+        int | EllipsisType, ...
+    ] = NotImplemented  # placeholder: unify field shapes, add parameter constraints
+    #   don't forget custom iteration slice
+    ispace: FullIterationSpace = (
+        NotImplemented  # placeholder: create from ghost layers and the common shape
+    )
+
+    #   5. Freeze assignments
+    #   This call is the same for both domain and indexed kernels
+    freeze = FreezeExpressions(ctx, ispace)
+    kernel_body: PsBlock = freeze(assignments)
+
+    #   6. Typify
+    #   Also the same for both types of kernels
+    #   determine_types(kernel_body)
+
+    #   Up to this point, all was target-agnostic, but now the target becomes relevant.
+    #   Here we might hand off the compilation to a target-specific part of the compiler
+    #   (CPU/CUDA/...), since these will likely also apply very different optimizations.
+
+    #   7. Add loops or device indexing
+    #   This step translates the iteration space to actual index calculation code and is once again
+    #   different in indexed and domain kernels.
+
+    #   8. Apply optimizations
+    #     - Vectorization
+    #     - OpenMP
+    #     - Loop Splitting, Tiling, Blocking
+
+    #   9. Create and return kernel function. (Not implemented: nothing is returned yet.)
diff --git a/src/pystencils/nbackend/kernelcreation/freeze.py b/src/pystencils/nbackend/kernelcreation/freeze.py
new file mode 100644
index 000000000..7ae387993
--- /dev/null
+++ b/src/pystencils/nbackend/kernelcreation/freeze.py
@@ -0,0 +1,84 @@
+import pymbolic.primitives as pb
+from pymbolic.interop.sympy import SympyToPymbolicMapper
+
+from ...field import Field, FieldType
+
+from .context import KernelCreationContext, IterationSpace, SparseIterationSpace
+
+from ..ast.nodes import PsAssignment
+from ..types import PsSignedIntegerType, PsIeeeFloatType, PsUnsignedIntegerType
+from ..typed_expressions import PsTypedVariable
+from ..arrays import PsArrayAccess
+from ..exceptions import PsInternalCompilerError
+
+
+class FreezeExpressions(SympyToPymbolicMapper):  # sympy -> nbackend translation
+    def __init__(self, ctx: KernelCreationContext, ispace: IterationSpace):
+        self._ctx = ctx  # queried for array descriptors in map_Access
+        self._ispace = ispace  # provides the spatial indices used in map_Access
+
+    def map_Assignment(self, expr):  # noqa
+        lhs = self.rec(expr.lhs)
+        rhs = self.rec(expr.rhs)
+        return PsAssignment(lhs, rhs)
+
+    def map_BasicType(self, expr):
+        width = expr.numpy_dtype.itemsize * 8  # presumably bytes -> bits; TODO confirm
+        const = expr.const
+        if expr.is_float():
+            return PsIeeeFloatType(width, const)
+        elif expr.is_uint():  # NOTE(review): keep before is_int() in case that also matches unsigned
+            return PsUnsignedIntegerType(width, const)
+        elif expr.is_int():
+            return PsSignedIntegerType(width, const)
+        else:
+            raise NotImplementedError("Data type not supported.")
+
+    def map_FieldShapeSymbol(self, expr):  # handled exactly like map_TypedSymbol
+        dtype = self.rec(expr.dtype)
+        return PsTypedVariable(expr.name, dtype)
+
+    def map_TypedSymbol(self, expr):  # typed symbol -> PsTypedVariable of the same name
+        dtype = self.rec(expr.dtype)  # recursively translate the dtype
+        return PsTypedVariable(expr.name, dtype)
+
+    def map_Access(self, access: Field.Access):  # field access -> PsArrayAccess with linearized index
+        field = access.field
+        array_desc = self._ctx.get_array_descriptor(field)
+        array = array_desc.array
+        ptr = array_desc.base_ptr
+
+        offsets: list[pb.Expression] = [self.rec(o) for o in access.offsets]
+        indices: list[pb.Expression] = [self.rec(o) for o in access.index]
+
+        if not access.is_absolute_access:  # only non-absolute accesses depend on the field type
+            match field.field_type:
+                case FieldType.GENERIC:
+                    #   Add the iteration counters (zip stops at the shorter sequence)
+                    offsets = [
+                        i + o for i, o in zip(self._ispace.spatial_indices, offsets)
+                    ]
+                case FieldType.INDEXED:
+                    if isinstance(self._ispace, SparseIterationSpace):
+                        #   TODO: make sure index (and all offsets?) are zero
+                        #   TODO: Add sparse iteration counter
+                        raise NotImplementedError()  # sparse index accesses are not implemented yet
+                    else:
+                        raise PsInternalCompilerError(
+                            "Cannot translate index field access without a sparse iteration space."
+                        )
+                case FieldType.CUSTOM:
+                    raise ValueError("Custom fields support only absolute accesses.")
+                case unknown:  # capture pattern: any other field type
+                    raise NotImplementedError(
+                        f"Cannot translate accesses to field type {unknown} yet."
+                    )
+
+        index = pb.Sum(  # linearized index: sum of idx * stride
+            tuple(
+                idx * stride
+                for idx, stride in zip(offsets + indices, array.strides, strict=True)
+            )  # strict=True: ValueError on length mismatch
+        )
+
+        return PsArrayAccess(ptr, index)
diff --git a/src/pystencils/nbackend/sympy_mapper.py b/src/pystencils/nbackend/sympy_mapper.py
deleted file mode 100644
index fecad4c63..000000000
--- a/src/pystencils/nbackend/sympy_mapper.py
+++ /dev/null
@@ -1,53 +0,0 @@
-from pymbolic.interop.sympy import SympyToPymbolicMapper
-
-from pystencils.typing import TypedSymbol
-from pystencils.typing.typed_sympy import SHAPE_DTYPE
-from .ast.nodes import PsAssignment, PsSymbolExpr
-from .types import PsSignedIntegerType, PsIeeeFloatType, PsUnsignedIntegerType
-from .typed_expressions import PsTypedVariable
-from .arrays import PsArrayBasePointer, PsLinearizedArray, PsArrayAccess
-
-CTR_SYMBOLS = [TypedSymbol(f"ctr_{i}", SHAPE_DTYPE) for i in range(3)]
-
-
-class PystencilsToPymbolicMapper(SympyToPymbolicMapper):
-    def map_Assignment(self, expr):  # noqa
-        lhs = self.rec(expr.lhs)
-        rhs = self.rec(expr.rhs)
-        return PsAssignment(lhs, rhs)
-
-    def map_BasicType(self, expr):
-        width = expr.numpy_dtype.itemsize * 8
-        const = expr.const
-        if expr.is_float():
-            return PsIeeeFloatType(width, const)
-        elif expr.is_uint():
-            return PsUnsignedIntegerType(width, const)
-        elif expr.is_int():
-            return PsSignedIntegerType(width, const)
-        else:
-            raise (NotImplementedError, "Not supported dtype")
-
-    def map_FieldShapeSymbol(self, expr):
-        dtype = self.rec(expr.dtype)
-        return PsTypedVariable(expr.name, dtype)
-
-    def map_TypedSymbol(self, expr):
-        dtype = self.rec(expr.dtype)
-        return PsTypedVariable(expr.name, dtype)
-
-    def map_Access(self, expr):
-        name = expr.field.name
-        shape = tuple([self.rec(s) for s in expr.field.shape])
-        strides = tuple([self.rec(s) for s in expr.field.strides])
-        dtype = self.rec(expr.dtype)
-
-        array = PsLinearizedArray(name, shape, strides, dtype)
-
-        ptr = PsArrayBasePointer(expr.name, array)
-        index = sum(
-            [ctr * stride for ctr, stride in zip(CTR_SYMBOLS, expr.field.strides)]
-        )
-        index = self.rec(index)
-
-        return PsSymbolExpr(PsArrayAccess(ptr, index))
-- 
GitLab