Add a pluggable just-in-time compilation interface to the pystencils nbackend.

Summary of the changes in this patch:
  * doc/sphinx/nbackend: list a new `jit` page in index.rst and add jit.rst,
    which documents `pystencils.nbackend.jit` via automodule.
  * cpu/cpujit.py: the nbackend imports move from module level into
    `compile_and_load` (presumably to avoid a circular import -- confirm).
  * nbackend/ast/kernelfunction.py: `PsKernelFunction` takes a `jit` argument
    (default `no_jit`) and gains `compile()`, which delegates to that JIT.
  * nbackend/jit: new `JitBase` interface with `NoJit` (raises `JitError`),
    `LegacyCpuJit` and `LegacyGpuJit`, plus the module-level instances
    `no_jit`, `legacy_cpu` and `legacy_gpu`.
  * nbackend/kernelcreation: `CreateKernelConfig` gets a `jit` option; if left
    at `None` it becomes `legacy_cpu` for `Target.CPU`, and any other target
    raises `NotImplementedError`. `create_kernel` forwards it to the kernel.
  * tests: call `ast.compile()` instead of `compile_and_load(ast)`.

Review notes:
  * `JitError` is additionally imported into and listed in `__all__` of
    `pystencils.nbackend.jit`, so that callers can catch the exception raised
    by `NoJit.compile` from the public package. The hunk header of that file
    was adjusted to 44 lines; the blob id on its `index` line still refers to
    the version without this amendment.
  * Line breaks, indentation and the position of blank context lines were
    reconstructed from a copy of this patch whose whitespace had been
    collapsed. Verify the context lines against the target tree before
    applying (in particular the .rst indentation and the test files).

diff --git a/doc/sphinx/nbackend/index.rst b/doc/sphinx/nbackend/index.rst
index 0dfe29f80ab51375c7a7e5629dfa139cf5a7cc74..752fa9ccb6a3a1413ec6c6aad2be9752bbd19893 100644
--- a/doc/sphinx/nbackend/index.rst
+++ b/doc/sphinx/nbackend/index.rst
@@ -16,4 +16,4 @@ all code generation functionality currently implemented in *pystencils* version
     arrays
     ast
     kernelcreation
-
+    jit
diff --git a/doc/sphinx/nbackend/jit.rst b/doc/sphinx/nbackend/jit.rst
new file mode 100644
index 0000000000000000000000000000000000000000..71541527f7389eece9a7ef94826bbf1e566cbad4
--- /dev/null
+++ b/doc/sphinx/nbackend/jit.rst
@@ -0,0 +1,6 @@
+************************
+Just-In-Time Compilation
+************************
+
+.. automodule:: pystencils.nbackend.jit
+    :members:
diff --git a/src/pystencils/cpu/cpujit.py b/src/pystencils/cpu/cpujit.py
index b3db1c0969185eefe3d20afb5aae5a68eca684ce..deea37c9536fea1b388273f41d7c9f8ac02c1f5e 100644
--- a/src/pystencils/cpu/cpujit.py
+++ b/src/pystencils/cpu/cpujit.py
@@ -69,9 +69,6 @@ from pystencils.kernel_wrapper import KernelWrapper
 from pystencils.typing import BasicType, CastFunc, VectorType, VectorMemoryAccess
 from pystencils.utils import atomic_file_write, recursive_dict_update

-from ..nbackend.ast import PsKernelFunction
-from ..nbackend.jit.cpu_extension_module import PsKernelExtensioNModule
-

 def make_python_function(kernel_function_node, custom_backend=None):
     """
@@ -622,7 +619,9 @@ def compile_and_load(ast, custom_backend=None):
     compiler_config = get_compiler_config()
     function_prefix = '__declspec(dllexport)' if compiler_config['os'].lower() == 'windows' else ''

+    from ..nbackend.ast import PsKernelFunction
     if isinstance(ast, PsKernelFunction):
+        from ..nbackend.jit.cpu_extension_module import PsKernelExtensioNModule
         code = PsKernelExtensioNModule()
     else:
         code = ExtensionModuleCode(custom_backend=custom_backend)
diff --git a/src/pystencils/nbackend/ast/kernelfunction.py b/src/pystencils/nbackend/ast/kernelfunction.py
index e3c91d83c28b15749d4365d8ed5c4125746b74a9..d2e66de87afda88f41c2cf17f72579b5036a846a 100644
--- a/src/pystencils/nbackend/ast/kernelfunction.py
+++ b/src/pystencils/nbackend/ast/kernelfunction.py
@@ -1,14 +1,18 @@
 from __future__ import annotations

+from typing import Callable
 from dataclasses import dataclass

 from pymbolic.mapper.dependency import DependencyMapper

 from .nodes import PsAstNode, PsBlock, failing_cast
+
 from ..constraints import PsKernelConstraint
 from ..typed_expressions import PsTypedVariable
 from ..arrays import PsLinearizedArray, PsArrayBasePointer, PsArrayAssocVar
+from ..jit import JitBase, no_jit
 from ..exceptions import PsInternalCompilerError
+
 from ...enums import Target


@@ -64,10 +68,11 @@ class PsKernelFunction(PsAstNode):

     __match_args__ = ("body",)

-    def __init__(self, body: PsBlock, target: Target, name: str = "kernel"):
+    def __init__(self, body: PsBlock, target: Target, name: str = "kernel", jit: JitBase = no_jit):
         self._body: PsBlock = body
         self._target = target
         self._name = name
+        self._jit = jit

         self._constraints: list[PsKernelConstraint] = []

@@ -133,3 +138,6 @@ class PsKernelFunction(PsAstNode):
         # To Do: Headers from target/instruction set/...
         from .collectors import collect_required_headers
         return collect_required_headers(self)
+
+    def compile(self) -> Callable[..., None]:
+        return self._jit.compile(self)
diff --git a/src/pystencils/nbackend/jit/__init__.py b/src/pystencils/nbackend/jit/__init__.py
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..6953f8daffe1d7cfb12e10e9e22bda7b4838f038 100644
--- a/src/pystencils/nbackend/jit/__init__.py
+++ b/src/pystencils/nbackend/jit/__init__.py
@@ -0,0 +1,44 @@
+"""
+JIT compilation in the ``nbackend`` is managed by subclasses of `JitBase`.
+A JIT compiler may freely be created and configured by the user.
+It can then be passed to `create_kernel` using the ``jit`` argument of
+`CreateKernelConfig`, in which case it is hooked into the `PsKernelFunction.compile` method
+of the generated kernel function::
+
+    my_jit = MyJit()
+    kernel = create_kernel(ast, CreateKernelConfig(jit=my_jit))
+    func = kernel.compile()
+
+Otherwise, a JIT compiler may also be created free-standing, with the same effect::
+
+    my_jit = MyJit()
+    kernel = create_kernel(ast)
+    func = my_jit.compile(kernel)
+
+Currently, only wrappers around the legacy JIT compilers are available.
+
+Legacy Just-In-Time Compilation
+-------------------------------
+
+Historically, pystencils provides two main pathways for just-in-time compilation:
+The ``cpu.cpujit`` module for CPU kernels, and the ``gpu.gpujit`` module for device kernels.
+Both are available here through `LegacyCpuJit` and `LegacyGpuJit`.
+
+"""
+
+from .jit import JitBase, JitError, NoJit, LegacyCpuJit, LegacyGpuJit
+
+no_jit = NoJit()
+legacy_cpu = LegacyCpuJit()
+legacy_gpu = LegacyGpuJit()
+
+__all__ = [
+    "JitBase",
+    "JitError",
+    "LegacyCpuJit",
+    "legacy_cpu",
+    "NoJit",
+    "no_jit",
+    "LegacyGpuJit",
+    "legacy_gpu",
+]
diff --git a/src/pystencils/nbackend/jit/jit.py b/src/pystencils/nbackend/jit/jit.py
new file mode 100644
index 0000000000000000000000000000000000000000..4e9ae46b61f7018485976f3cfebcbce6e46641c3
--- /dev/null
+++ b/src/pystencils/nbackend/jit/jit.py
@@ -0,0 +1,45 @@
+from __future__ import annotations
+from typing import Callable, TYPE_CHECKING
+from abc import ABC, abstractmethod
+
+if TYPE_CHECKING:
+    from ..ast import PsKernelFunction
+
+
+class JitError(Exception):
+    """Indicates an error during just-in-time compilation"""
+
+
+class JitBase(ABC):
+    """Base class for just-in-time compilation interfaces implemented in pystencils."""
+
+    @abstractmethod
+    def compile(self, kernel: PsKernelFunction) -> Callable[..., None]:
+        """Compile a kernel function and return a callable object which invokes the kernel."""
+
+
+class NoJit(JitBase):
+    """Not a JIT compiler: Used to explicitly disable JIT compilation on an AST."""
+
+    def compile(self, kernel: PsKernelFunction) -> Callable[..., None]:
+        raise JitError(
+            "Just-in-time compilation of this kernel was explicitly disabled."
+        )
+
+
+class LegacyCpuJit(JitBase):
+    """Wrapper around ``pystencils.cpu.cpujit``"""
+
+    def compile(self, kernel: PsKernelFunction) -> Callable[..., None]:
+        from ...cpu.cpujit import compile_and_load
+
+        return compile_and_load(kernel)
+
+
+class LegacyGpuJit(JitBase):
+    """Wrapper around ``pystencils.gpu.gpujit``"""
+
+    def compile(self, kernel: PsKernelFunction) -> Callable[..., None]:
+        from ...gpu.gpujit import make_python_function
+
+        return make_python_function(kernel)
diff --git a/src/pystencils/nbackend/kernelcreation/config.py b/src/pystencils/nbackend/kernelcreation/config.py
index 53f4d95cc237cfb1b5585ce00719abbee5c1881c..608a818e83c95ff1c39864aedab87ec4cfeb11a8 100644
--- a/src/pystencils/nbackend/kernelcreation/config.py
+++ b/src/pystencils/nbackend/kernelcreation/config.py
@@ -4,6 +4,7 @@ from dataclasses import dataclass

 from ...enums import Target
 from ...field import Field, FieldType
+from ..jit import JitBase
 from ..exceptions import PsOptionsError
 from ..types import PsIntegerType, PsNumericType, PsIeeeFloatType

@@ -20,6 +21,13 @@ class CreateKernelConfig:
     TODO: Enhance `Target` from enum to a larger target spec, e.g. including vectorization architecture, ...
     """

+    jit: JitBase | None = None
+    """Just-in-time compiler used to compile and load the kernel for invocation from the current Python environment.
+
+    If left at `None`, a default just-in-time compiler will be inferred from the `target` parameter.
+    To explicitly disable JIT compilation, pass `nbackend.jit.no_jit`.
+    """
+
     function_name: str = "kernel"
     """Name of the generated function"""

@@ -63,6 +71,7 @@ class CreateKernelConfig:
     """

     def __post_init__(self):
+        # Check iteration space argument consistency
         if (
             int(self.iteration_slice is not None)
             + int(self.ghost_layers is not None)
@@ -74,6 +83,7 @@ class CreateKernelConfig:
                 "at most one of them may be set."
             )

+        # Check index field
         if (
             self.index_field is not None
             and self.index_field.field_type != FieldType.INDEXED
@@ -81,3 +91,12 @@ class CreateKernelConfig:
             raise PsOptionsError(
                 "Only fields with `field_type == FieldType.INDEXED` can be specified as `index_field`"
             )
+
+        # Infer JIT
+        if self.jit is None:
+            match self.target:
+                case Target.CPU:
+                    from ..jit import legacy_cpu
+                    self.jit = legacy_cpu
+                case _:
+                    raise NotImplementedError(f"No default JIT compiler implemented yet for target {self.target}")
diff --git a/src/pystencils/nbackend/kernelcreation/kernelcreation.py b/src/pystencils/nbackend/kernelcreation/kernelcreation.py
index 3900585872fab80883c1c18bae00c89b2b706299..8abf9c7ef0e81601ef15ee76ce82602de9557f63 100644
--- a/src/pystencils/nbackend/kernelcreation/kernelcreation.py
+++ b/src/pystencils/nbackend/kernelcreation/kernelcreation.py
@@ -19,6 +19,7 @@ def create_kernel(
     assignments: AssignmentCollection,
     config: CreateKernelConfig = CreateKernelConfig(),
 ):
+    """Create a kernel AST from an assignment collection."""
     ctx = KernelCreationContext(config)

     analysis = KernelAnalysis(ctx)
@@ -57,7 +58,8 @@ def create_kernel(
     # - Loop Splitting, Tiling, Blocking
     kernel_ast = platform.optimize(kernel_ast)

-    function = PsKernelFunction(kernel_ast, config.target, name=config.function_name)
+    assert config.jit is not None
+    function = PsKernelFunction(kernel_ast, config.target, name=config.function_name, jit=config.jit)
     function.add_constraints(*ctx.constraints)

     return function
diff --git a/tests/nbackend/kernelcreation/test_domain_kernels.py b/tests/nbackend/kernelcreation/test_domain_kernels.py
index 1f75df8568d2260bbd37f056c9211a2480dfb163..5a3649d68ad1b2f225501990dcb6d21af0bd9956 100644
--- a/tests/nbackend/kernelcreation/test_domain_kernels.py
+++ b/tests/nbackend/kernelcreation/test_domain_kernels.py
@@ -7,7 +7,6 @@ from pystencils import fields, Field, AssignmentCollection
 from pystencils.assignment import assignment_from_stencil

 from pystencils.nbackend.kernelcreation import create_kernel
-from pystencils.cpu.cpujit import compile_and_load

 def test_filter_kernel():
     weight = sp.Symbol("weight")
@@ -22,8 +21,7 @@ def test_filter_kernel():
     asms = AssignmentCollection([asm])

     ast = create_kernel(asms)
-
-    kernel = compile_and_load(ast)
+    kernel = ast.compile()

     src_arr = np.ones((42, 42))
     dst_arr = np.zeros_like(src_arr)
@@ -54,7 +52,7 @@ def test_filter_kernel_fixedsize():
     asms = AssignmentCollection([asm])

     ast = create_kernel(asms)
-    kernel = compile_and_load(ast)
+    kernel = ast.compile()

     kernel(src=src_arr, dst=dst_arr, weight=2.0)

diff --git a/tests/nbackend/kernelcreation/test_index_kernels.py b/tests/nbackend/kernelcreation/test_index_kernels.py
index e8a32b6b48ebda1a12cf7e4261178ef58e755eee..4c4c6b8b816114964ae0ba7f3c77e39add292f7e 100644
--- a/tests/nbackend/kernelcreation/test_index_kernels.py
+++ b/tests/nbackend/kernelcreation/test_index_kernels.py
@@ -5,7 +5,6 @@ import numpy as np
 from pystencils import Assignment, Field, FieldType, AssignmentCollection

 from pystencils.nbackend.kernelcreation import create_kernel, CreateKernelConfig
-from pystencils.cpu.cpujit import compile_and_load

 def test_indexed_kernel():
     arr = np.zeros((3, 4))
@@ -23,7 +22,7 @@ def test_indexed_kernel():

     options = CreateKernelConfig(index_field=index_field)
     ast = create_kernel(update_rule, options)
-    kernel = compile_and_load(ast)
+    kernel = ast.compile()

     kernel(f=arr, index=index_arr)
