From f1a3714cbaffd5fff689b76a3b0af9336b44b8b3 Mon Sep 17 00:00:00 2001 From: Frederik Hennig <frederik.hennig@fau.de> Date: Wed, 6 Mar 2024 11:46:29 +0100 Subject: [PATCH] code format and export --- src/pystencils/__init__.py | 2 ++ src/pystencils/kernelcreation.py | 31 ++++++++++++++++++++----------- 2 files changed, 22 insertions(+), 11 deletions(-) diff --git a/src/pystencils/__init__.py b/src/pystencils/__init__.py index 59d6a32e5..27d3df76c 100644 --- a/src/pystencils/__init__.py +++ b/src/pystencils/__init__.py @@ -10,6 +10,7 @@ from .cache import clear_cache from .config import CreateKernelConfig from .kernel_decorator import kernel, kernel_config from .kernelcreation import create_kernel +from .backend.kernelfunction import KernelFunction from .slicing import make_slice from .spatial_coordinates import ( x_, @@ -36,6 +37,7 @@ __all__ = [ "make_slice", "CreateKernelConfig", "create_kernel", + "KernelFunction", "Target", "Backend", "show_code", diff --git a/src/pystencils/kernelcreation.py b/src/pystencils/kernelcreation.py index 93b2c998d..1e365cfed 100644 --- a/src/pystencils/kernelcreation.py +++ b/src/pystencils/kernelcreation.py @@ -2,7 +2,13 @@ from typing import cast from .enums import Target from .config import CreateKernelConfig -from .backend import KernelFunction, KernelParameter, FieldShapeParam, FieldStrideParam, FieldPointerParam +from .backend import ( + KernelFunction, + KernelParameter, + FieldShapeParam, + FieldStrideParam, + FieldPointerParam, +) from .backend.symbols import PsSymbol from .backend.jit import JitBase from .backend.ast.structural import PsBlock @@ -26,6 +32,7 @@ from .sympyextensions import AssignmentCollection, Assignment __all__ = ["create_kernel"] + def create_kernel( assignments: AssignmentCollection | list[Assignment], config: CreateKernelConfig = CreateKernelConfig(), @@ -81,10 +88,18 @@ def create_kernel( # - Loop Splitting, Tiling, Blocking assert config.jit is not None - return create_kernel_function(ctx, kernel_ast, config.function_name, config.target, config.jit) + return create_kernel_function( + ctx, kernel_ast, config.function_name, config.target, config.jit + ) -def create_kernel_function(ctx: KernelCreationContext, body: PsBlock, name: str, target_spec: Target, jit: JitBase): +def create_kernel_function( + ctx: KernelCreationContext, + body: PsBlock, + name: str, + target_spec: Target, + jit: JitBase, +): undef_symbols = collect_undefined_symbols(body) params = [] @@ -101,18 +116,12 @@ def create_kernel_function(ctx: KernelCreationContext, body: PsBlock, name: str, params.append(FieldPointerParam(name, symb.get_dtype(), field)) case PsSymbol(name, _): params.append(KernelParameter(name, symb.get_dtype())) - + params.sort(key=lambda p: p.name) req_headers = collect_required_headers(body) req_headers |= ctx.required_headers return KernelFunction( - body, - target_spec, - name, - params, - req_headers, - ctx.constraints, - jit + body, target_spec, name, params, req_headers, ctx.constraints, jit ) -- GitLab