From 1fadc0ddb2afecb23b7af84582db0dd1ddc482d4 Mon Sep 17 00:00:00 2001
From: Frederik Hennig <frederik.hennig@fau.de>
Date: Sat, 3 Feb 2024 13:41:03 +0100
Subject: [PATCH] basic JIT integration

---
 doc/sphinx/nbackend/index.rst                 |  2 +-
 doc/sphinx/nbackend/jit.rst                   |  6 +++
 src/pystencils/cpu/cpujit.py                  |  5 +--
 src/pystencils/nbackend/ast/kernelfunction.py | 10 ++++-
 src/pystencils/nbackend/jit/__init__.py       | 43 ++++++++++++++++++
 src/pystencils/nbackend/jit/jit.py            | 45 +++++++++++++++++++
 .../nbackend/kernelcreation/config.py         | 19 ++++++++
 .../nbackend/kernelcreation/kernelcreation.py |  4 +-
 .../kernelcreation/test_domain_kernels.py     |  6 +--
 .../kernelcreation/test_index_kernels.py      |  3 +-
 10 files changed, 131 insertions(+), 12 deletions(-)
 create mode 100644 doc/sphinx/nbackend/jit.rst
 create mode 100644 src/pystencils/nbackend/jit/jit.py

diff --git a/doc/sphinx/nbackend/index.rst b/doc/sphinx/nbackend/index.rst
index 0dfe29f80..752fa9ccb 100644
--- a/doc/sphinx/nbackend/index.rst
+++ b/doc/sphinx/nbackend/index.rst
@@ -16,4 +16,4 @@ all code generation functionality currently implemented in *pystencils* version
     arrays
     ast
     kernelcreation
-
+    jit
diff --git a/doc/sphinx/nbackend/jit.rst b/doc/sphinx/nbackend/jit.rst
new file mode 100644
index 000000000..71541527f
--- /dev/null
+++ b/doc/sphinx/nbackend/jit.rst
@@ -0,0 +1,6 @@
+************************
+Just-In-Time Compilation
+************************
+
+.. automodule:: pystencils.nbackend.jit
+    :members:
diff --git a/src/pystencils/cpu/cpujit.py b/src/pystencils/cpu/cpujit.py
index b3db1c096..deea37c95 100644
--- a/src/pystencils/cpu/cpujit.py
+++ b/src/pystencils/cpu/cpujit.py
@@ -69,9 +69,6 @@ from pystencils.kernel_wrapper import KernelWrapper
 from pystencils.typing import BasicType, CastFunc, VectorType, VectorMemoryAccess
 from pystencils.utils import atomic_file_write, recursive_dict_update
 
-from ..nbackend.ast import PsKernelFunction
-from ..nbackend.jit.cpu_extension_module import PsKernelExtensioNModule
-
 
 def make_python_function(kernel_function_node, custom_backend=None):
     """
@@ -622,7 +619,9 @@ def compile_and_load(ast, custom_backend=None):
     compiler_config = get_compiler_config()
     function_prefix = '__declspec(dllexport)' if compiler_config['os'].lower() == 'windows' else ''
 
+    from ..nbackend.ast import PsKernelFunction
     if isinstance(ast, PsKernelFunction):
+        from ..nbackend.jit.cpu_extension_module import PsKernelExtensioNModule
         code = PsKernelExtensioNModule()
     else:
         code = ExtensionModuleCode(custom_backend=custom_backend)
diff --git a/src/pystencils/nbackend/ast/kernelfunction.py b/src/pystencils/nbackend/ast/kernelfunction.py
index e3c91d83c..d2e66de87 100644
--- a/src/pystencils/nbackend/ast/kernelfunction.py
+++ b/src/pystencils/nbackend/ast/kernelfunction.py
@@ -1,14 +1,18 @@
 from __future__ import annotations
 
+from typing import Callable
 from dataclasses import dataclass
 
 from pymbolic.mapper.dependency import DependencyMapper
 
 from .nodes import PsAstNode, PsBlock, failing_cast
+
 from ..constraints import PsKernelConstraint
 from ..typed_expressions import PsTypedVariable
 from ..arrays import PsLinearizedArray, PsArrayBasePointer, PsArrayAssocVar
+from ..jit import JitBase, no_jit
 from ..exceptions import PsInternalCompilerError
+
 from ...enums import Target
 
 
@@ -64,10 +68,11 @@ class PsKernelFunction(PsAstNode):
 
     __match_args__ = ("body",)
 
-    def __init__(self, body: PsBlock, target: Target, name: str = "kernel"):
+    def __init__(self, body: PsBlock, target: Target, name: str = "kernel", jit: JitBase = no_jit):
         self._body: PsBlock = body
         self._target = target
         self._name = name
+        self._jit = jit
 
         self._constraints: list[PsKernelConstraint] = []
 
@@ -133,3 +138,6 @@ class PsKernelFunction(PsAstNode):
         #   To Do: Headers from target/instruction set/...
         from .collectors import collect_required_headers
         return collect_required_headers(self)
+    
+    def compile(self) -> Callable[..., None]:
+        return self._jit.compile(self)  # delegates to the JIT object passed to `__init__` as ``jit``
diff --git a/src/pystencils/nbackend/jit/__init__.py b/src/pystencils/nbackend/jit/__init__.py
index e69de29bb..6953f8daf 100644
--- a/src/pystencils/nbackend/jit/__init__.py
+++ b/src/pystencils/nbackend/jit/__init__.py
@@ -0,0 +1,43 @@
+"""
+JIT compilation in the ``nbackend`` is managed by subclasses of `JitBase`.
+A JIT compiler may freely be created and configured by the user.
+It can then be passed to `create_kernel` using the ``jit`` argument of
+`CreateKernelConfig`, in which case it is hooked into the `PsKernelFunction.compile` method
+of the generated kernel function::
+
+    my_jit = MyJit()
+    kernel = create_kernel(assignments, CreateKernelConfig(jit=my_jit))
+    func = kernel.compile()
+
+Alternatively, a JIT compiler may be created and used free-standing, with the same effect::
+
+    my_jit = MyJit()
+    kernel = create_kernel(assignments)
+    func = my_jit.compile(kernel)
+
+Currently, only wrappers around the legacy JIT compilers are available.
+
+Legacy Just-In-Time Compilation
+-------------------------------
+
+Historically, pystencils provides two main pathways for just-in-time compilation:
+The ``cpu.cpujit`` module for CPU kernels, and the ``gpu.gpujit`` module for device kernels.
+Both are available here through `LegacyCpuJit` and `LegacyGpuJit`.
+
+"""
+
+from .jit import JitBase, JitError, NoJit, LegacyCpuJit, LegacyGpuJit
+
+#   Module-level instances of the JIT interfaces defined in `.jit`
+no_jit = NoJit()
+legacy_cpu = LegacyCpuJit()
+legacy_gpu = LegacyGpuJit()
+
+#   `JitError` is raised by `NoJit.compile`; it is exported so that callers can catch it
+__all__ = [
+    "JitBase",
+    "JitError",
+    "NoJit", "no_jit",
+    "LegacyCpuJit", "legacy_cpu",
+    "LegacyGpuJit", "legacy_gpu",
+]
diff --git a/src/pystencils/nbackend/jit/jit.py b/src/pystencils/nbackend/jit/jit.py
new file mode 100644
index 000000000..4e9ae46b6
--- /dev/null
+++ b/src/pystencils/nbackend/jit/jit.py
@@ -0,0 +1,45 @@
+from __future__ import annotations
+from typing import Callable, TYPE_CHECKING
+from abc import ABC, abstractmethod
+
+if TYPE_CHECKING:
+    from ..ast import PsKernelFunction
+
+
+class JitError(Exception):
+    """Indicates an error during just-in-time compilation, e.g. when JIT was disabled via `NoJit`."""
+
+
+class JitBase(ABC):
+    """Abstract base class for just-in-time compilation interfaces; subclasses implement `compile`."""
+
+    @abstractmethod
+    def compile(self, kernel: PsKernelFunction) -> Callable[..., None]:
+        """Compile the kernel function `kernel` and return a callable object which invokes the kernel."""
+
+
+class NoJit(JitBase):
+    """Placeholder which is not a JIT compiler: it explicitly disables JIT compilation on an AST."""
+
+    def compile(self, kernel: PsKernelFunction) -> Callable[..., None]:
+        """Unconditionally raise `JitError`; no kernel is ever compiled."""
+        message = "Just-in-time compilation of this kernel was explicitly disabled."
+        raise JitError(message)
+
+
+class LegacyCpuJit(JitBase):
+    """JIT interface which forwards kernels to the legacy ``pystencils.cpu.cpujit`` module."""
+
+    def compile(self, kernel: PsKernelFunction) -> Callable[..., None]:
+        """Pass `kernel` to the legacy `compile_and_load` and return whatever it returns."""
+        from ...cpu.cpujit import compile_and_load as legacy_compile  # imported at call time
+        return legacy_compile(kernel)
+
+
+class LegacyGpuJit(JitBase):
+    """JIT interface which forwards kernels to the legacy ``pystencils.gpu.gpujit`` module."""
+
+    def compile(self, kernel: PsKernelFunction) -> Callable[..., None]:
+        """Pass `kernel` to the legacy `make_python_function` and return whatever it returns."""
+        from ...gpu.gpujit import make_python_function as legacy_make  # imported at call time
+        return legacy_make(kernel)
diff --git a/src/pystencils/nbackend/kernelcreation/config.py b/src/pystencils/nbackend/kernelcreation/config.py
index 53f4d95cc..608a818e8 100644
--- a/src/pystencils/nbackend/kernelcreation/config.py
+++ b/src/pystencils/nbackend/kernelcreation/config.py
@@ -4,6 +4,7 @@ from dataclasses import dataclass
 from ...enums import Target
 from ...field import Field, FieldType
 
+from ..jit import JitBase
 from ..exceptions import PsOptionsError
 from ..types import PsIntegerType, PsNumericType, PsIeeeFloatType
 
@@ -20,6 +21,13 @@ class CreateKernelConfig:
     TODO: Enhance `Target` from enum to a larger target spec, e.g. including vectorization architecture, ...
     """
 
+    jit: JitBase | None = None
+    """Just-in-time compiler used to compile and load the kernel for invocation from the current Python environment.
+    
+    If left at `None`, a default just-in-time compiler is inferred from the `target` parameter; currently
+    this is only implemented for `Target.CPU`. To explicitly disable JIT compilation, pass `nbackend.jit.no_jit`.
+    """
+
     function_name: str = "kernel"
     """Name of the generated function"""
 
@@ -63,6 +71,7 @@ class CreateKernelConfig:
     """
 
     def __post_init__(self):
+        #   Check iteration space argument consistency
         if (
             int(self.iteration_slice is not None)
             + int(self.ghost_layers is not None)
@@ -74,6 +83,7 @@ class CreateKernelConfig:
                 "at most one of them may be set."
             )
 
+        #   Check index field
         if (
             self.index_field is not None
             and self.index_field.field_type != FieldType.INDEXED
@@ -81,3 +91,12 @@ class CreateKernelConfig:
             raise PsOptionsError(
                 "Only fields with `field_type == FieldType.INDEXED` can be specified as `index_field`"
             )
+        
+        #   Infer JIT
+        if self.jit is None:
+            match self.target:
+                case Target.CPU:
+                    from ..jit import legacy_cpu
+                    self.jit = legacy_cpu
+                case _:
+                    raise NotImplementedError(f"No default JIT compiler implemented yet for target {self.target}")
diff --git a/src/pystencils/nbackend/kernelcreation/kernelcreation.py b/src/pystencils/nbackend/kernelcreation/kernelcreation.py
index 390058587..8abf9c7ef 100644
--- a/src/pystencils/nbackend/kernelcreation/kernelcreation.py
+++ b/src/pystencils/nbackend/kernelcreation/kernelcreation.py
@@ -19,6 +19,7 @@ def create_kernel(
     assignments: AssignmentCollection,
     config: CreateKernelConfig = CreateKernelConfig(),
 ):
+    """Create a kernel AST from an assignment collection."""
     ctx = KernelCreationContext(config)
 
     analysis = KernelAnalysis(ctx)
@@ -57,7 +58,8 @@ def create_kernel(
     #     - Loop Splitting, Tiling, Blocking
     kernel_ast = platform.optimize(kernel_ast)
 
-    function = PsKernelFunction(kernel_ast, config.target, name=config.function_name)
+    assert config.jit is not None
+    function = PsKernelFunction(kernel_ast, config.target, name=config.function_name, jit=config.jit)
     function.add_constraints(*ctx.constraints)
 
     return function
diff --git a/tests/nbackend/kernelcreation/test_domain_kernels.py b/tests/nbackend/kernelcreation/test_domain_kernels.py
index 1f75df856..5a3649d68 100644
--- a/tests/nbackend/kernelcreation/test_domain_kernels.py
+++ b/tests/nbackend/kernelcreation/test_domain_kernels.py
@@ -7,7 +7,6 @@ from pystencils import fields, Field, AssignmentCollection
 from pystencils.assignment import assignment_from_stencil
 
 from pystencils.nbackend.kernelcreation import create_kernel
-from pystencils.cpu.cpujit import compile_and_load
 
 def test_filter_kernel():
     weight = sp.Symbol("weight")
@@ -22,8 +21,7 @@ def test_filter_kernel():
     asms = AssignmentCollection([asm])
 
     ast = create_kernel(asms)
-
-    kernel = compile_and_load(ast)
+    kernel = ast.compile()
 
     src_arr = np.ones((42, 42))
     dst_arr = np.zeros_like(src_arr)
@@ -54,7 +52,7 @@ def test_filter_kernel_fixedsize():
     asms = AssignmentCollection([asm])
 
     ast = create_kernel(asms)
-    kernel = compile_and_load(ast)
+    kernel = ast.compile()
 
     kernel(src=src_arr, dst=dst_arr, weight=2.0)
 
diff --git a/tests/nbackend/kernelcreation/test_index_kernels.py b/tests/nbackend/kernelcreation/test_index_kernels.py
index e8a32b6b4..4c4c6b8b8 100644
--- a/tests/nbackend/kernelcreation/test_index_kernels.py
+++ b/tests/nbackend/kernelcreation/test_index_kernels.py
@@ -5,7 +5,6 @@ import numpy as np
 
 from pystencils import Assignment, Field, FieldType, AssignmentCollection
 from pystencils.nbackend.kernelcreation import create_kernel, CreateKernelConfig
-from pystencils.cpu.cpujit import compile_and_load
 
 def test_indexed_kernel():
     arr = np.zeros((3, 4))
@@ -23,7 +22,7 @@ def test_indexed_kernel():
 
     options = CreateKernelConfig(index_field=index_field)
     ast = create_kernel(update_rule, options)
-    kernel = compile_and_load(ast)
+    kernel = ast.compile()
 
     kernel(f=arr, index=index_arr)
 
-- 
GitLab