From 8490e6cb918cdfe00261fcdd4b308da0e566f0d7 Mon Sep 17 00:00:00 2001
From: Frederik Hennig <frederik.hennig@fau.de>
Date: Mon, 19 Feb 2024 12:47:35 +0100
Subject: [PATCH] move CreateKernelConfig to kernelcreation.py

---
 .../backend/kernelcreation/config.py          | 102 -------------
 .../backend/kernelcreation/context.py         |  18 +--
 .../backend/kernelcreation/defaults.py        |   5 +-
 .../backend/kernelcreation/iteration_space.py |  21 +--
 .../backend/kernelcreation/typification.py    |   4 +-
 src/pystencils/kernelcreation.py              | 135 +++++++++++++++++-
 6 files changed, 157 insertions(+), 128 deletions(-)
 delete mode 100644 src/pystencils/backend/kernelcreation/config.py

diff --git a/src/pystencils/backend/kernelcreation/config.py b/src/pystencils/backend/kernelcreation/config.py
deleted file mode 100644
index 608a818e8..000000000
--- a/src/pystencils/backend/kernelcreation/config.py
+++ /dev/null
@@ -1,102 +0,0 @@
-from typing import Sequence
-from dataclasses import dataclass
-
-from ...enums import Target
-from ...field import Field, FieldType
-
-from ..jit import JitBase
-from ..exceptions import PsOptionsError
-from ..types import PsIntegerType, PsNumericType, PsIeeeFloatType
-
-from .defaults import Sympy as SpDefaults
-
-
-@dataclass
-class CreateKernelConfig:
-    """Options for create_kernel."""
-
-    target: Target = Target.CPU
-    """The code generation target.
-    
-    TODO: Enhance `Target` from enum to a larger target spec, e.g. including vectorization architecture, ...
-    """
-
-    jit: JitBase | None = None
-    """Just-in-time compiler used to compile and load the kernel for invocation from the current Python environment.
-    
-    If left at `None`, a default just-in-time compiler will be inferred from the `target` parameter.
-    To explicitly disable JIT compilation, pass `nbackend.jit.no_jit`.
-    """
-
-    function_name: str = "kernel"
-    """Name of the generated function"""
-
-    ghost_layers: None | int | Sequence[int | tuple[int, int]] = None
-    """Specifies the number of ghost layers of the iteration region.
-    
-    Options:
-     - `None`: Required ghost layers are inferred from field accesses
-     - `int`:  A uniform number of ghost layers in each spatial coordinate is applied
-     - `Sequence[int, tuple[int, int]]`: Ghost layers are specified for each spatial coordinate.
-        In each coordinate, a single integer specifies the ghost layers at both the lower and upper iteration limit,
-        while a pair of integers specifies the lower and upper ghost layers separately.
-
-    When manually specifying ghost layers, it is the user's responsibility to avoid out-of-bounds memory accesses.
-    If `ghost_layers=None` is specified, the iteration region may otherwise be set using the `iteration_slice` option.
-    """
-
-    iteration_slice: None | tuple[slice, ...] = None
-    """Specifies the kernel's iteration slice.
-    
-    `iteration_slice` may only be set if `ghost_layers = None`.
-    If it is set, a slice must be specified for each spatial coordinate.
-    TODO: Specification of valid slices and their behaviour
-    """
-
-    index_field: Field | None = None
-    """Index field for a sparse kernel.
-    
-    If this option is set, a sparse kernel with the given field as index field will be generated.
-    """
-
-    """Data Types"""
-
-    index_dtype: PsIntegerType = SpDefaults.index_dtype
-    """Data type used for all index calculations."""
-
-    default_dtype: PsNumericType = PsIeeeFloatType(64)
-    """Default numeric data type.
-    
-    This data type will be applied to all untyped symbols.
-    """
-
-    def __post_init__(self):
-        #   Check iteration space argument consistency
-        if (
-            int(self.iteration_slice is not None)
-            + int(self.ghost_layers is not None)
-            + int(self.index_field is not None)
-            > 1
-        ):
-            raise PsOptionsError(
-                "Parameters `iteration_slice`, `ghost_layers` and 'index_field` are mutually exclusive; "
-                "at most one of them may be set."
-            )
-
-        #   Check index field
-        if (
-            self.index_field is not None
-            and self.index_field.field_type != FieldType.INDEXED
-        ):
-            raise PsOptionsError(
-                "Only fields with `field_type == FieldType.INDEXED` can be specified as `index_field`"
-            )
-        
-        #   Infer JIT
-        if self.jit is None:
-            match self.target:
-                case Target.CPU:
-                    from ..jit import legacy_cpu
-                    self.jit = legacy_cpu
-                case _:
-                    raise NotImplementedError(f"No default JIT compiler implemented yet for target {self.target}")
diff --git a/src/pystencils/backend/kernelcreation/context.py b/src/pystencils/backend/kernelcreation/context.py
index 2080efed3..0982e8455 100644
--- a/src/pystencils/backend/kernelcreation/context.py
+++ b/src/pystencils/backend/kernelcreation/context.py
@@ -3,13 +3,12 @@ from __future__ import annotations
 from ...field import Field, FieldType
 from ...sympyextensions.typed_sympy import TypedSymbol, BasicType, StructType
 from ..arrays import PsLinearizedArray
-from ..types import PsIntegerType
+from ..types import PsIntegerType, PsNumericType
 from ..types.quick import make_type
 from ..constraints import PsKernelConstraint
 from ..exceptions import PsInternalCompilerError, KernelConstraintsError
 
-
-from .config import CreateKernelConfig
+from .defaults import Pymbolic as PbDefaults
 from .iteration_space import IterationSpace, FullIterationSpace, SparseIterationSpace
 
 
@@ -44,8 +43,11 @@ class KernelCreationContext:
     or full iteration space.
     """
 
-    def __init__(self, options: CreateKernelConfig):
-        self._options = options
+    def __init__(self,
+                 default_dtype: PsNumericType = PbDefaults.numeric_dtype,
+                 index_dtype: PsIntegerType = PbDefaults.index_dtype):
+        self._default_dtype = default_dtype
+        self._index_dtype = index_dtype
         self._arrays: dict[Field, PsLinearizedArray] = dict()
         self._constraints: list[PsKernelConstraint] = []
 
@@ -53,12 +55,12 @@ class KernelCreationContext:
         self._ispace: IterationSpace | None = None
 
     @property
-    def options(self) -> CreateKernelConfig:
-        return self._options
+    def default_dtype(self) -> PsNumericType:
+        return self._default_dtype
 
     @property
     def index_dtype(self) -> PsIntegerType:
-        return self._options.index_dtype
+        return self._index_dtype
 
     def add_constraints(self, *constraints: PsKernelConstraint):
         self._constraints += constraints
diff --git a/src/pystencils/backend/kernelcreation/defaults.py b/src/pystencils/backend/kernelcreation/defaults.py
index b1822dc77..c52f6c254 100644
--- a/src/pystencils/backend/kernelcreation/defaults.py
+++ b/src/pystencils/backend/kernelcreation/defaults.py
@@ -17,7 +17,7 @@ A possibly incomplete list of symbols and types that need to be defined:
 """
 
 from typing import TypeVar, Generic, Callable
-from ..types import PsAbstractType, PsSignedIntegerType, PsStructType
+from ..types import PsAbstractType, PsIeeeFloatType, PsSignedIntegerType, PsStructType
 from ..typed_expressions import PsTypedVariable
 
 from pystencils.sympyextensions.typed_sympy import TypedSymbol
@@ -27,6 +27,9 @@ SymbolT = TypeVar("SymbolT")
 
 class PsDefaults(Generic[SymbolT]):
     def __init__(self, symcreate: Callable[[str, PsAbstractType], SymbolT]):
+        self.numeric_dtype = PsIeeeFloatType(64)
+        """Default data type for numerical computations"""
+        
         self.index_dtype = PsSignedIntegerType(64)
         """Default data type for indices."""
 
diff --git a/src/pystencils/backend/kernelcreation/iteration_space.py b/src/pystencils/backend/kernelcreation/iteration_space.py
index 0e5828ef1..147a4749e 100644
--- a/src/pystencils/backend/kernelcreation/iteration_space.py
+++ b/src/pystencils/backend/kernelcreation/iteration_space.py
@@ -224,7 +224,7 @@ def get_archetype_field(
 
 
 def create_sparse_iteration_space(
-    ctx: KernelCreationContext, assignments: AssignmentCollection
+    ctx: KernelCreationContext, assignments: AssignmentCollection, index_field: Field | None = None
 ) -> IterationSpace:
     #   All domain and custom fields must have the same spatial dimensions
     #   TODO: Must all domain fields have the same shape?
@@ -242,9 +242,8 @@ def create_sparse_iteration_space(
     ]
 
     #   Determine index field
-    if ctx.options.index_field is not None:
-        idx_field = ctx.options.index_field
-        idx_arr = ctx.get_array(idx_field)
+    if index_field is not None:
+        idx_arr = ctx.get_array(index_field)
         idx_struct_type: PsStructType = failing_cast(PsStructType, idx_arr.element_type)
 
         for coord in coord_members:
@@ -269,10 +268,16 @@ def create_sparse_iteration_space(
 
 
 def create_full_iteration_space(
-    ctx: KernelCreationContext, assignments: AssignmentCollection
+    ctx: KernelCreationContext,
+    assignments: AssignmentCollection,
+    ghost_layers: None | int | Sequence[int | tuple[int, int]] = None,
+    iteration_slice: None | tuple[slice, ...] = None
 ) -> IterationSpace:
     assert not ctx.fields.index_fields
 
+    if (ghost_layers is not None) and (iteration_slice is not None):
+        raise ValueError("At most one of `ghost_layers` and `iteration_slice` may be specified.")
+
     #   Collect all relative accesses into domain fields
     def access_filter(acc: Field.Access):
         return acc.field.field_type in (
@@ -305,11 +310,11 @@ def create_full_iteration_space(
     # Otherwise, if an iteration slice was specified, use that
     # Otherwise, use the inferred ghost layers
 
-    if ctx.options.ghost_layers is not None:
+    if ghost_layers is not None:
         return FullIterationSpace.create_with_ghost_layers(
-            ctx, archetype_field, ctx.options.ghost_layers
+            ctx, archetype_field, ghost_layers
         )
-    elif ctx.options.iteration_slice is not None:
+    elif iteration_slice is not None:
         raise PsInternalCompilerError("Iteration slices not supported yet")
     else:
         return FullIterationSpace.create_with_ghost_layers(
diff --git a/src/pystencils/backend/kernelcreation/typification.py b/src/pystencils/backend/kernelcreation/typification.py
index 37ed8e8e1..7c953c5c4 100644
--- a/src/pystencils/backend/kernelcreation/typification.py
+++ b/src/pystencils/backend/kernelcreation/typification.py
@@ -158,7 +158,7 @@ class Typifier(Mapper):
         return var
 
     def map_variable(self, var: pb.Variable, tc: TypeContext) -> PsTypedVariable:
-        dtype = self._ctx.options.default_dtype
+        dtype = self._ctx.default_dtype
         typed_var = PsTypedVariable(var.name, dtype)
         self._apply_target_type(typed_var, dtype, tc)
         return typed_var
@@ -175,7 +175,7 @@ class Typifier(Mapper):
     def map_array_access(self, access: PsArrayAccess, tc: TypeContext) -> PsArrayAccess:
         self._apply_target_type(access, access.dtype, tc)
         index = self.rec(
-            access.index_tuple[0], TypeContext(self._ctx.options.index_dtype)
+            access.index_tuple[0], TypeContext(self._ctx.index_dtype)
         )
         return PsArrayAccess(access.base_ptr, index)
 
diff --git a/src/pystencils/kernelcreation.py b/src/pystencils/kernelcreation.py
index fb4d6abfb..e12381068 100644
--- a/src/pystencils/kernelcreation.py
+++ b/src/pystencils/kernelcreation.py
@@ -1,5 +1,22 @@
+from typing import Sequence
+from dataclasses import dataclass
+
+from .enums import Target
+from .field import Field, FieldType
+
+from .backend.jit import JitBase
+from .backend.exceptions import PsOptionsError
+from .backend.types import PsIntegerType, PsNumericType, PsIeeeFloatType
+
+from .backend.kernelcreation.defaults import Sympy as SpDefaults
+
 from .backend.ast import PsKernelFunction
-from .backend.kernelcreation import KernelCreationContext, KernelAnalysis, FreezeExpressions, Typifier
+from .backend.kernelcreation import (
+    KernelCreationContext,
+    KernelAnalysis,
+    FreezeExpressions,
+    Typifier,
+)
 from .backend.kernelcreation.iteration_space import (
     create_sparse_iteration_space,
     create_full_iteration_space,
@@ -7,24 +24,126 @@ from .backend.kernelcreation.iteration_space import (
 from .backend.kernelcreation.transformations import EraseAnonymousStructTypes
 
 from .enums import Target
-from .config import CreateKernelConfig
 from .sympyextensions import AssignmentCollection
 
 
+@dataclass
+class CreateKernelConfig:
+    """Options for create_kernel."""
+
+    target: Target = Target.CPU
+    """The code generation target.
+    
+    TODO: Enhance `Target` from enum to a larger target spec, e.g. including vectorization architecture, ...
+    """
+
+    jit: JitBase | None = None
+    """Just-in-time compiler used to compile and load the kernel for invocation from the current Python environment.
+    
+    If left at `None`, a default just-in-time compiler will be inferred from the `target` parameter.
+    To explicitly disable JIT compilation, pass `nbackend.jit.no_jit`.
+    """
+
+    function_name: str = "kernel"
+    """Name of the generated function"""
+
+    ghost_layers: None | int | Sequence[int | tuple[int, int]] = None
+    """Specifies the number of ghost layers of the iteration region.
+    
+    Options:
+     - `None`: Required ghost layers are inferred from field accesses
+     - `int`:  A uniform number of ghost layers in each spatial coordinate is applied
+     - `Sequence[int, tuple[int, int]]`: Ghost layers are specified for each spatial coordinate.
+        In each coordinate, a single integer specifies the ghost layers at both the lower and upper iteration limit,
+        while a pair of integers specifies the lower and upper ghost layers separately.
+
+    When manually specifying ghost layers, it is the user's responsibility to avoid out-of-bounds memory accesses.
+    If `ghost_layers=None` is specified, the iteration region may otherwise be set using the `iteration_slice` option.
+    """
+
+    iteration_slice: None | tuple[slice, ...] = None
+    """Specifies the kernel's iteration slice.
+    
+    `iteration_slice` may only be set if `ghost_layers = None`.
+    If it is set, a slice must be specified for each spatial coordinate.
+    TODO: Specification of valid slices and their behaviour
+    """
+
+    index_field: Field | None = None
+    """Index field for a sparse kernel.
+    
+    If this option is set, a sparse kernel with the given field as index field will be generated.
+    """
+
+    """Data Types"""
+
+    index_dtype: PsIntegerType = SpDefaults.index_dtype
+    """Data type used for all index calculations."""
+
+    default_dtype: PsNumericType = PsIeeeFloatType(64)
+    """Default numeric data type.
+    
+    This data type will be applied to all untyped symbols.
+    """
+
+    def __post_init__(self):
+        #   Check iteration space argument consistency
+        if (
+            int(self.iteration_slice is not None)
+            + int(self.ghost_layers is not None)
+            + int(self.index_field is not None)
+            > 1
+        ):
+            raise PsOptionsError(
+                "Parameters `iteration_slice`, `ghost_layers` and 'index_field` are mutually exclusive; "
+                "at most one of them may be set."
+            )
+
+        #   Check index field
+        if (
+            self.index_field is not None
+            and self.index_field.field_type != FieldType.INDEXED
+        ):
+            raise PsOptionsError(
+                "Only fields with `field_type == FieldType.INDEXED` can be specified as `index_field`"
+            )
+
+        #   Infer JIT
+        if self.jit is None:
+            match self.target:
+                case Target.CPU:
+                    from .backend.jit import legacy_cpu
+
+                    self.jit = legacy_cpu
+                case _:
+                    raise NotImplementedError(
+                        f"No default JIT compiler implemented yet for target {self.target}"
+                    )
+
+
 def create_kernel(
     assignments: AssignmentCollection,
     config: CreateKernelConfig = CreateKernelConfig(),
 ):
     """Create a kernel AST from an assignment collection."""
-    ctx = KernelCreationContext(config)
+    ctx = KernelCreationContext(
+        default_dtype=config.default_dtype, index_dtype=config.index_dtype
+    )
 
     analysis = KernelAnalysis(ctx)
     analysis(assignments)
 
-    if len(ctx.fields.index_fields) > 0 or ctx.options.index_field is not None:
-        ispace = create_sparse_iteration_space(ctx, assignments)
+    if len(ctx.fields.index_fields) > 0 or config.index_field is not None:
+        ispace = create_sparse_iteration_space(
+            ctx, assignments, index_field=config.index_field
+        )
     else:
-        ispace = create_full_iteration_space(ctx, assignments)
+        ispace = create_full_iteration_space(
+            ctx,
+            assignments,
+            ghost_layers=config.ghost_layers,
+            iteration_slice=config.iteration_slice,
+        )
 
     ctx.set_iteration_space(ispace)
 
@@ -55,7 +174,9 @@ def create_kernel(
     kernel_ast = platform.optimize(kernel_ast)
 
     assert config.jit is not None
-    function = PsKernelFunction(kernel_ast, config.target, name=config.function_name, jit=config.jit)
+    function = PsKernelFunction(
+        kernel_ast, config.target, name=config.function_name, jit=config.jit
+    )
     function.add_constraints(*ctx.constraints)
 
     return function
-- 
GitLab