diff --git a/src/pystencils/__init__.py b/src/pystencils/__init__.py index 59d6a32e5a689a61cda7546c5230ec233cbaa5bf..27d3df76ce8740ad7b0ca0885ba95de23c32127f 100644 --- a/src/pystencils/__init__.py +++ b/src/pystencils/__init__.py @@ -10,6 +10,7 @@ from .cache import clear_cache from .config import CreateKernelConfig from .kernel_decorator import kernel, kernel_config from .kernelcreation import create_kernel +from .backend.kernelfunction import KernelFunction from .slicing import make_slice from .spatial_coordinates import ( x_, @@ -36,6 +37,7 @@ __all__ = [ "make_slice", "CreateKernelConfig", "create_kernel", + "KernelFunction", "Target", "Backend", "show_code", diff --git a/src/pystencils/kernelcreation.py b/src/pystencils/kernelcreation.py index 93b2c998dbd6901f9a0f78725712ce9a087e1160..1e365cfed6dea8578e96f962b6fa9faa0e0f8930 100644 --- a/src/pystencils/kernelcreation.py +++ b/src/pystencils/kernelcreation.py @@ -2,7 +2,13 @@ from typing import cast from .enums import Target from .config import CreateKernelConfig -from .backend import KernelFunction, KernelParameter, FieldShapeParam, FieldStrideParam, FieldPointerParam +from .backend import ( + KernelFunction, + KernelParameter, + FieldShapeParam, + FieldStrideParam, + FieldPointerParam, +) from .backend.symbols import PsSymbol from .backend.jit import JitBase from .backend.ast.structural import PsBlock @@ -26,6 +32,7 @@ from .sympyextensions import AssignmentCollection, Assignment __all__ = ["create_kernel"] + def create_kernel( assignments: AssignmentCollection | list[Assignment], config: CreateKernelConfig = CreateKernelConfig(), @@ -81,10 +88,18 @@ def create_kernel( # - Loop Splitting, Tiling, Blocking assert config.jit is not None - return create_kernel_function(ctx, kernel_ast, config.function_name, config.target, config.jit) + return create_kernel_function( + ctx, kernel_ast, config.function_name, config.target, config.jit + ) -def create_kernel_function(ctx: KernelCreationContext, body: PsBlock, name: str, target_spec: Target, jit: JitBase): +def create_kernel_function( + ctx: KernelCreationContext, + body: PsBlock, + name: str, + target_spec: Target, + jit: JitBase, +): undef_symbols = collect_undefined_symbols(body) params = [] @@ -101,18 +116,12 @@ def create_kernel_function(ctx: KernelCreationContext, body: PsBlock, name: str, params.append(FieldPointerParam(name, symb.get_dtype(), field)) case PsSymbol(name, _): params.append(KernelParameter(name, symb.get_dtype())) - + params.sort(key=lambda p: p.name) req_headers = collect_required_headers(body) req_headers |= ctx.required_headers return KernelFunction( - body, - target_spec, - name, - params, - req_headers, - ctx.constraints, - jit + body, target_spec, name, params, req_headers, ctx.constraints, jit )