diff --git a/src/pystencils/__init__.py b/src/pystencils/__init__.py
index 59d6a32e5a689a61cda7546c5230ec233cbaa5bf..27d3df76ce8740ad7b0ca0885ba95de23c32127f 100644
--- a/src/pystencils/__init__.py
+++ b/src/pystencils/__init__.py
@@ -10,6 +10,7 @@ from .cache import clear_cache
 from .config import CreateKernelConfig
 from .kernel_decorator import kernel, kernel_config
 from .kernelcreation import create_kernel
+from .backend.kernelfunction import KernelFunction
 from .slicing import make_slice
 from .spatial_coordinates import (
     x_,
@@ -36,6 +37,7 @@ __all__ = [
     "make_slice",
     "CreateKernelConfig",
     "create_kernel",
+    "KernelFunction",
     "Target",
     "Backend",
     "show_code",
diff --git a/src/pystencils/kernelcreation.py b/src/pystencils/kernelcreation.py
index 93b2c998dbd6901f9a0f78725712ce9a087e1160..1e365cfed6dea8578e96f962b6fa9faa0e0f8930 100644
--- a/src/pystencils/kernelcreation.py
+++ b/src/pystencils/kernelcreation.py
@@ -2,7 +2,13 @@ from typing import cast
 
 from .enums import Target
 from .config import CreateKernelConfig
-from .backend import KernelFunction, KernelParameter, FieldShapeParam, FieldStrideParam, FieldPointerParam
+from .backend import (
+    KernelFunction,
+    KernelParameter,
+    FieldShapeParam,
+    FieldStrideParam,
+    FieldPointerParam,
+)
 from .backend.symbols import PsSymbol
 from .backend.jit import JitBase
 from .backend.ast.structural import PsBlock
@@ -26,6 +32,7 @@ from .sympyextensions import AssignmentCollection, Assignment
 
 __all__ = ["create_kernel"]
 
+
 def create_kernel(
     assignments: AssignmentCollection | list[Assignment],
     config: CreateKernelConfig = CreateKernelConfig(),
@@ -81,10 +88,18 @@ def create_kernel(
     #     - Loop Splitting, Tiling, Blocking
 
     assert config.jit is not None
-    return create_kernel_function(ctx, kernel_ast, config.function_name, config.target, config.jit)
+    return create_kernel_function(
+        ctx, kernel_ast, config.function_name, config.target, config.jit
+    )
 
 
-def create_kernel_function(ctx: KernelCreationContext, body: PsBlock, name: str, target_spec: Target, jit: JitBase):
+def create_kernel_function(
+    ctx: KernelCreationContext,
+    body: PsBlock,
+    name: str,
+    target_spec: Target,
+    jit: JitBase,
+):
     undef_symbols = collect_undefined_symbols(body)
 
     params = []
@@ -101,18 +116,12 @@ def create_kernel_function(ctx: KernelCreationContext, body: PsBlock, name: str,
                 params.append(FieldPointerParam(name, symb.get_dtype(), field))
             case PsSymbol(name, _):
                 params.append(KernelParameter(name, symb.get_dtype()))
-    
+
     params.sort(key=lambda p: p.name)
 
     req_headers = collect_required_headers(body)
     req_headers |= ctx.required_headers
 
     return KernelFunction(
-        body,
-        target_spec,
-        name,
-        params,
-        req_headers,
-        ctx.constraints,
-        jit
+        body, target_spec, name, params, req_headers, ctx.constraints, jit
     )