diff --git a/src/pystencils/nbackend/kernelcreation/kernelcreation.py b/src/pystencils/nbackend/kernelcreation/kernelcreation.py index f95619ed76180180321117772bbe1ede597e1d5b..705d10e8244235f08ac334c269149a7a168adb65 100644 --- a/src/pystencils/nbackend/kernelcreation/kernelcreation.py +++ b/src/pystencils/nbackend/kernelcreation/kernelcreation.py @@ -15,7 +15,10 @@ from .iteration_space import ( from .transformations import EraseAnonymousStructTypes -def create_kernel(assignments: AssignmentCollection, options: KernelCreationOptions): +def create_kernel( + assignments: AssignmentCollection, + options: KernelCreationOptions = KernelCreationOptions(), +): ctx = KernelCreationContext(options) analysis = KernelAnalysis(ctx) diff --git a/tests/nbackend/kernelcreation/test_freeze.py b/tests/nbackend/kernelcreation/test_freeze.py index db8f4feb2f6f2fdcb713149487b39ac3b3059ec8..7e9f397fd4b9f5f7efb38a9295db882bdf97658e 100644 --- a/tests/nbackend/kernelcreation/test_freeze.py +++ b/tests/nbackend/kernelcreation/test_freeze.py @@ -1,3 +1,4 @@ +#%% import sympy as sp import pymbolic.primitives as pb @@ -40,12 +41,12 @@ def test_freeze_fields(): options = KernelCreationOptions() ctx = KernelCreationContext(options) - start = PsTypedConstant(0, ctx.index_dtype) - stop = PsTypedConstant(42, ctx.index_dtype) - step = PsTypedConstant(1, ctx.index_dtype) + zero = PsTypedConstant(0, ctx.index_dtype) + forty_two = PsTypedConstant(42, ctx.index_dtype) + one = PsTypedConstant(1, ctx.index_dtype) counter = PsTypedVariable("ctr", ctx.index_dtype) ispace = FullIterationSpace( - ctx, [FullIterationSpace.Dimension(start, stop, step, counter)] + ctx, [FullIterationSpace.Dimension(zero, forty_two, one, counter)] ) ctx.set_iteration_space(ispace) @@ -59,9 +60,12 @@ def test_freeze_fields(): fasm = freeze(asm) - lhs = PsArrayAccess(f_arr.base_pointer, counter * f_arr.strides[0]) - rhs = PsArrayAccess(g_arr.base_pointer, counter * g_arr.strides[0]) + lhs = PsArrayAccess(f_arr.base_pointer, pb.Sum((counter * f_arr.strides[0], zero))) + rhs = PsArrayAccess(g_arr.base_pointer, pb.Sum((counter * g_arr.strides[0], zero))) should = PsAssignment(PsLvalueExpr(lhs), PsExpression(rhs)) assert fasm == should + +#%% +test_freeze_fields() \ No newline at end of file