From 1fadc0ddb2afecb23b7af84582db0dd1ddc482d4 Mon Sep 17 00:00:00 2001 From: Frederik Hennig <frederik.hennig@fau.de> Date: Sat, 3 Feb 2024 13:41:03 +0100 Subject: [PATCH] basic JIT integration --- doc/sphinx/nbackend/index.rst | 2 +- doc/sphinx/nbackend/jit.rst | 6 +++ src/pystencils/cpu/cpujit.py | 5 +-- src/pystencils/nbackend/ast/kernelfunction.py | 10 ++++- src/pystencils/nbackend/jit/__init__.py | 43 ++++++++++++++++++ src/pystencils/nbackend/jit/jit.py | 45 +++++++++++++++++++ .../nbackend/kernelcreation/config.py | 19 ++++++++ .../nbackend/kernelcreation/kernelcreation.py | 4 +- .../kernelcreation/test_domain_kernels.py | 6 +-- .../kernelcreation/test_index_kernels.py | 3 +- 10 files changed, 131 insertions(+), 12 deletions(-) create mode 100644 doc/sphinx/nbackend/jit.rst create mode 100644 src/pystencils/nbackend/jit/jit.py diff --git a/doc/sphinx/nbackend/index.rst b/doc/sphinx/nbackend/index.rst index 0dfe29f80..752fa9ccb 100644 --- a/doc/sphinx/nbackend/index.rst +++ b/doc/sphinx/nbackend/index.rst @@ -16,4 +16,4 @@ all code generation functionality currently implemented in *pystencils* version arrays ast kernelcreation - + jit diff --git a/doc/sphinx/nbackend/jit.rst b/doc/sphinx/nbackend/jit.rst new file mode 100644 index 000000000..71541527f --- /dev/null +++ b/doc/sphinx/nbackend/jit.rst @@ -0,0 +1,6 @@ +************************ +Just-In-Time Compilation +************************ + +.. automodule:: pystencils.nbackend.jit + :members: diff --git a/src/pystencils/cpu/cpujit.py b/src/pystencils/cpu/cpujit.py index b3db1c096..deea37c95 100644 --- a/src/pystencils/cpu/cpujit.py +++ b/src/pystencils/cpu/cpujit.py @@ -69,9 +69,6 @@ from pystencils.kernel_wrapper import KernelWrapper from pystencils.typing import BasicType, CastFunc, VectorType, VectorMemoryAccess from pystencils.utils import atomic_file_write, recursive_dict_update -from ..nbackend.ast import PsKernelFunction -from ..nbackend.jit.cpu_extension_module import PsKernelExtensioNModule - def make_python_function(kernel_function_node, custom_backend=None): """ @@ -622,7 +619,9 @@ def compile_and_load(ast, custom_backend=None): compiler_config = get_compiler_config() function_prefix = '__declspec(dllexport)' if compiler_config['os'].lower() == 'windows' else '' + from ..nbackend.ast import PsKernelFunction if isinstance(ast, PsKernelFunction): + from ..nbackend.jit.cpu_extension_module import PsKernelExtensioNModule code = PsKernelExtensioNModule() else: code = ExtensionModuleCode(custom_backend=custom_backend) diff --git a/src/pystencils/nbackend/ast/kernelfunction.py b/src/pystencils/nbackend/ast/kernelfunction.py index e3c91d83c..d2e66de87 100644 --- a/src/pystencils/nbackend/ast/kernelfunction.py +++ b/src/pystencils/nbackend/ast/kernelfunction.py @@ -1,14 +1,18 @@ from __future__ import annotations +from typing import Callable from dataclasses import dataclass from pymbolic.mapper.dependency import DependencyMapper from .nodes import PsAstNode, PsBlock, failing_cast + from ..constraints import PsKernelConstraint from ..typed_expressions import PsTypedVariable from ..arrays import PsLinearizedArray, PsArrayBasePointer, PsArrayAssocVar +from ..jit import JitBase, no_jit from ..exceptions import PsInternalCompilerError + from ...enums import Target @@ -64,10 +68,11 @@ class PsKernelFunction(PsAstNode): __match_args__ = ("body",) - def __init__(self, body: PsBlock, target: Target, name: str = "kernel"): + def __init__(self, body: PsBlock, target: Target, name: str = "kernel", jit: JitBase = no_jit): self._body: PsBlock = body self._target = target self._name = name + self._jit = jit self._constraints: list[PsKernelConstraint] = [] @@ -133,3 +138,6 @@ class PsKernelFunction(PsAstNode): # To Do: Headers from target/instruction set/... from .collectors import collect_required_headers return collect_required_headers(self) + + def compile(self) -> Callable[..., None]: + return self._jit.compile(self) diff --git a/src/pystencils/nbackend/jit/__init__.py b/src/pystencils/nbackend/jit/__init__.py index e69de29bb..6953f8daf 100644 --- a/src/pystencils/nbackend/jit/__init__.py +++ b/src/pystencils/nbackend/jit/__init__.py @@ -0,0 +1,43 @@ +""" +JIT compilation in the ``nbackend`` is managed by subclasses of `JitBase`. +A JIT compiler may freely be created and configured by the user. +It can then be passed to `create_kernel` using the ``jit`` argument of +`CreateKernelConfig`, in which case it is hooked into the `PsKernelFunction.compile` method +of the generated kernel function:: + + my_jit = MyJit() + kernel = create_kernel(ast, CreateKernelConfig(jit=my_jit)) + func = kernel.compile() + +Otherwise, a JIT compiler may also be created free-standing, with the same effect:: + + my_jit = MyJit() + kernel = create_kernel(ast) + func = my_jit.compile(kernel) + +Currently, only wrappers around the legacy JIT compilers are available. + +Legacy Just-In-Time Compilation +------------------------------- + +Historically, pystencils provides two main pathways for just-in-time compilation: +The ``cpu.cpujit`` module for CPU kernels, and the ``gpu.gpujit`` module for device kernels. +Both are available here through `LegacyCpuJit` and `LegacyGpuJit`. + +""" + +from .jit import JitBase, NoJit, LegacyCpuJit, LegacyGpuJit + +no_jit = NoJit() +legacy_cpu = LegacyCpuJit() +legacy_gpu = LegacyGpuJit() + +__all__ = [ + "JitBase", + "LegacyCpuJit", + "legacy_cpu", + "NoJit", + "no_jit", + "LegacyGpuJit", + "legacy_gpu", +] diff --git a/src/pystencils/nbackend/jit/jit.py b/src/pystencils/nbackend/jit/jit.py new file mode 100644 index 000000000..4e9ae46b6 --- /dev/null +++ b/src/pystencils/nbackend/jit/jit.py @@ -0,0 +1,45 @@ +from __future__ import annotations +from typing import Callable, TYPE_CHECKING +from abc import ABC, abstractmethod + +if TYPE_CHECKING: + from ..ast import PsKernelFunction + + +class JitError(Exception): + """Indicates an error during just-in-time compilation""" + + +class JitBase(ABC): + """Base class for just-in-time compilation interfaces implemented in pystencils.""" + + @abstractmethod + def compile(self, kernel: PsKernelFunction) -> Callable[..., None]: + """Compile a kernel function and return a callable object which invokes the kernel.""" + + +class NoJit(JitBase): + """Not a JIT compiler: Used to explicitly disable JIT compilation on an AST.""" + + def compile(self, kernel: PsKernelFunction) -> Callable[..., None]: + raise JitError( + "Just-in-time compilation of this kernel was explicitly disabled." + ) + + +class LegacyCpuJit(JitBase): + """Wrapper around ``pystencils.cpu.cpujit``""" + + def compile(self, kernel: PsKernelFunction) -> Callable[..., None]: + from ...cpu.cpujit import compile_and_load + + return compile_and_load(kernel) + + +class LegacyGpuJit(JitBase): + """Wrapper around ``pystencils.gpu.gpujit``""" + + def compile(self, kernel: PsKernelFunction) -> Callable[..., None]: + from ...gpu.gpujit import make_python_function + + return make_python_function(kernel) diff --git a/src/pystencils/nbackend/kernelcreation/config.py b/src/pystencils/nbackend/kernelcreation/config.py index 53f4d95cc..608a818e8 100644 --- a/src/pystencils/nbackend/kernelcreation/config.py +++ b/src/pystencils/nbackend/kernelcreation/config.py @@ -4,6 +4,7 @@ from dataclasses import dataclass from ...enums import Target from ...field import Field, FieldType +from ..jit import JitBase from ..exceptions import PsOptionsError from ..types import PsIntegerType, PsNumericType, PsIeeeFloatType @@ -20,6 +21,13 @@ class CreateKernelConfig: TODO: Enhance `Target` from enum to a larger target spec, e.g. including vectorization architecture, ... """ + jit: JitBase | None = None + """Just-in-time compiler used to compile and load the kernel for invocation from the current Python environment. + + If left at `None`, a default just-in-time compiler will be inferred from the `target` parameter. + To explicitly disable JIT compilation, pass `nbackend.jit.no_jit`. + """ + function_name: str = "kernel" """Name of the generated function""" @@ -63,6 +71,7 @@ class CreateKernelConfig: """ def __post_init__(self): + # Check iteration space argument consistency if ( int(self.iteration_slice is not None) + int(self.ghost_layers is not None) @@ -74,6 +83,7 @@ class CreateKernelConfig: "at most one of them may be set." ) + # Check index field if ( self.index_field is not None and self.index_field.field_type != FieldType.INDEXED @@ -81,3 +91,12 @@ class CreateKernelConfig: raise PsOptionsError( "Only fields with `field_type == FieldType.INDEXED` can be specified as `index_field`" ) + + # Infer JIT + if self.jit is None: + match self.target: + case Target.CPU: + from ..jit import legacy_cpu + self.jit = legacy_cpu + case _: + raise NotImplementedError(f"No default JIT compiler implemented yet for target {self.target}") diff --git a/src/pystencils/nbackend/kernelcreation/kernelcreation.py b/src/pystencils/nbackend/kernelcreation/kernelcreation.py index 390058587..8abf9c7ef 100644 --- a/src/pystencils/nbackend/kernelcreation/kernelcreation.py +++ b/src/pystencils/nbackend/kernelcreation/kernelcreation.py @@ -19,6 +19,7 @@ def create_kernel( assignments: AssignmentCollection, config: CreateKernelConfig = CreateKernelConfig(), ): + """Create a kernel AST from an assignment collection.""" ctx = KernelCreationContext(config) analysis = KernelAnalysis(ctx) @@ -57,7 +58,8 @@ def create_kernel( # - Loop Splitting, Tiling, Blocking kernel_ast = platform.optimize(kernel_ast) - function = PsKernelFunction(kernel_ast, config.target, name=config.function_name) + assert config.jit is not None + function = PsKernelFunction(kernel_ast, config.target, name=config.function_name, jit=config.jit) function.add_constraints(*ctx.constraints) return function diff --git a/tests/nbackend/kernelcreation/test_domain_kernels.py b/tests/nbackend/kernelcreation/test_domain_kernels.py index 1f75df856..5a3649d68 100644 --- a/tests/nbackend/kernelcreation/test_domain_kernels.py +++ b/tests/nbackend/kernelcreation/test_domain_kernels.py @@ -7,7 +7,6 @@ from pystencils import fields, Field, AssignmentCollection from pystencils.assignment import assignment_from_stencil from pystencils.nbackend.kernelcreation import create_kernel -from pystencils.cpu.cpujit import compile_and_load def test_filter_kernel(): weight = sp.Symbol("weight") @@ -22,8 +21,7 @@ def test_filter_kernel(): asms = AssignmentCollection([asm]) ast = create_kernel(asms) - - kernel = compile_and_load(ast) + kernel = ast.compile() src_arr = np.ones((42, 42)) dst_arr = np.zeros_like(src_arr) @@ -54,7 +52,7 @@ def test_filter_kernel_fixedsize(): asms = AssignmentCollection([asm]) ast = create_kernel(asms) - kernel = compile_and_load(ast) + kernel = ast.compile() kernel(src=src_arr, dst=dst_arr, weight=2.0) diff --git a/tests/nbackend/kernelcreation/test_index_kernels.py b/tests/nbackend/kernelcreation/test_index_kernels.py index e8a32b6b4..4c4c6b8b8 100644 --- a/tests/nbackend/kernelcreation/test_index_kernels.py +++ b/tests/nbackend/kernelcreation/test_index_kernels.py @@ -5,7 +5,6 @@ import numpy as np from pystencils import Assignment, Field, FieldType, AssignmentCollection from pystencils.nbackend.kernelcreation import create_kernel, CreateKernelConfig -from pystencils.cpu.cpujit import compile_and_load def test_indexed_kernel(): arr = np.zeros((3, 4)) @@ -23,7 +22,7 @@ def test_indexed_kernel(): options = CreateKernelConfig(index_field=index_field) ast = create_kernel(update_rule, options) - kernel = compile_and_load(ast) + kernel = ast.compile() kernel(f=arr, index=index_arr) -- GitLab