Skip to content
Snippets Groups Projects
Commit f1a3714c authored by Frederik Hennig's avatar Frederik Hennig
Browse files

code format and export

parent 0f0e0dfe
No related merge requests found
Pipeline #63740 failed with stages
in 3 minutes and 46 seconds
...@@ -10,6 +10,7 @@ from .cache import clear_cache ...@@ -10,6 +10,7 @@ from .cache import clear_cache
from .config import CreateKernelConfig from .config import CreateKernelConfig
from .kernel_decorator import kernel, kernel_config from .kernel_decorator import kernel, kernel_config
from .kernelcreation import create_kernel from .kernelcreation import create_kernel
from .backend.kernelfunction import KernelFunction
from .slicing import make_slice from .slicing import make_slice
from .spatial_coordinates import ( from .spatial_coordinates import (
x_, x_,
...@@ -36,6 +37,7 @@ __all__ = [ ...@@ -36,6 +37,7 @@ __all__ = [
"make_slice", "make_slice",
"CreateKernelConfig", "CreateKernelConfig",
"create_kernel", "create_kernel",
"KernelFunction",
"Target", "Target",
"Backend", "Backend",
"show_code", "show_code",
......
...@@ -2,7 +2,13 @@ from typing import cast ...@@ -2,7 +2,13 @@ from typing import cast
from .enums import Target from .enums import Target
from .config import CreateKernelConfig from .config import CreateKernelConfig
from .backend import KernelFunction, KernelParameter, FieldShapeParam, FieldStrideParam, FieldPointerParam from .backend import (
KernelFunction,
KernelParameter,
FieldShapeParam,
FieldStrideParam,
FieldPointerParam,
)
from .backend.symbols import PsSymbol from .backend.symbols import PsSymbol
from .backend.jit import JitBase from .backend.jit import JitBase
from .backend.ast.structural import PsBlock from .backend.ast.structural import PsBlock
...@@ -26,6 +32,7 @@ from .sympyextensions import AssignmentCollection, Assignment ...@@ -26,6 +32,7 @@ from .sympyextensions import AssignmentCollection, Assignment
__all__ = ["create_kernel"] __all__ = ["create_kernel"]
def create_kernel( def create_kernel(
assignments: AssignmentCollection | list[Assignment], assignments: AssignmentCollection | list[Assignment],
config: CreateKernelConfig = CreateKernelConfig(), config: CreateKernelConfig = CreateKernelConfig(),
...@@ -81,10 +88,18 @@ def create_kernel( ...@@ -81,10 +88,18 @@ def create_kernel(
# - Loop Splitting, Tiling, Blocking # - Loop Splitting, Tiling, Blocking
assert config.jit is not None assert config.jit is not None
return create_kernel_function(ctx, kernel_ast, config.function_name, config.target, config.jit) return create_kernel_function(
ctx, kernel_ast, config.function_name, config.target, config.jit
)
def create_kernel_function(ctx: KernelCreationContext, body: PsBlock, name: str, target_spec: Target, jit: JitBase): def create_kernel_function(
ctx: KernelCreationContext,
body: PsBlock,
name: str,
target_spec: Target,
jit: JitBase,
):
undef_symbols = collect_undefined_symbols(body) undef_symbols = collect_undefined_symbols(body)
params = [] params = []
...@@ -101,18 +116,12 @@ def create_kernel_function(ctx: KernelCreationContext, body: PsBlock, name: str, ...@@ -101,18 +116,12 @@ def create_kernel_function(ctx: KernelCreationContext, body: PsBlock, name: str,
params.append(FieldPointerParam(name, symb.get_dtype(), field)) params.append(FieldPointerParam(name, symb.get_dtype(), field))
case PsSymbol(name, _): case PsSymbol(name, _):
params.append(KernelParameter(name, symb.get_dtype())) params.append(KernelParameter(name, symb.get_dtype()))
params.sort(key=lambda p: p.name) params.sort(key=lambda p: p.name)
req_headers = collect_required_headers(body) req_headers = collect_required_headers(body)
req_headers |= ctx.required_headers req_headers |= ctx.required_headers
return KernelFunction( return KernelFunction(
body, body, target_spec, name, params, req_headers, ctx.constraints, jit
target_spec,
name,
params,
req_headers,
ctx.constraints,
jit
) )
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment