From 2bcd290c84f5cd8e2c965681a479bf5b4fce76c7 Mon Sep 17 00:00:00 2001 From: Frederik Hennig <frederik.hennig@fau.de> Date: Wed, 31 Jan 2024 22:42:22 +0100 Subject: [PATCH] fix test_freeze, and minor refactors --- .../nbackend/kernelcreation/kernelcreation.py | 5 ++++- tests/nbackend/kernelcreation/test_freeze.py | 16 ++++++++++------ 2 files changed, 14 insertions(+), 7 deletions(-) diff --git a/src/pystencils/nbackend/kernelcreation/kernelcreation.py b/src/pystencils/nbackend/kernelcreation/kernelcreation.py index f95619ed7..705d10e82 100644 --- a/src/pystencils/nbackend/kernelcreation/kernelcreation.py +++ b/src/pystencils/nbackend/kernelcreation/kernelcreation.py @@ -15,7 +15,10 @@ from .iteration_space import ( from .transformations import EraseAnonymousStructTypes -def create_kernel(assignments: AssignmentCollection, options: KernelCreationOptions): +def create_kernel( + assignments: AssignmentCollection, + options: KernelCreationOptions = KernelCreationOptions(), +): ctx = KernelCreationContext(options) analysis = KernelAnalysis(ctx) diff --git a/tests/nbackend/kernelcreation/test_freeze.py b/tests/nbackend/kernelcreation/test_freeze.py index db8f4feb2..7e9f397fd 100644 --- a/tests/nbackend/kernelcreation/test_freeze.py +++ b/tests/nbackend/kernelcreation/test_freeze.py @@ -1,3 +1,4 @@ +#%% import sympy as sp import pymbolic.primitives as pb @@ -40,12 +41,12 @@ def test_freeze_fields(): options = KernelCreationOptions() ctx = KernelCreationContext(options) - start = PsTypedConstant(0, ctx.index_dtype) - stop = PsTypedConstant(42, ctx.index_dtype) - step = PsTypedConstant(1, ctx.index_dtype) + zero = PsTypedConstant(0, ctx.index_dtype) + forty_two = PsTypedConstant(42, ctx.index_dtype) + one = PsTypedConstant(1, ctx.index_dtype) counter = PsTypedVariable("ctr", ctx.index_dtype) ispace = FullIterationSpace( - ctx, [FullIterationSpace.Dimension(start, stop, step, counter)] + ctx, [FullIterationSpace.Dimension(zero, forty_two, one, counter)] ) ctx.set_iteration_space(ispace) @@ -59,9 +60,12 @@ def test_freeze_fields(): fasm = freeze(asm) - lhs = PsArrayAccess(f_arr.base_pointer, counter * f_arr.strides[0]) - rhs = PsArrayAccess(g_arr.base_pointer, counter * g_arr.strides[0]) + lhs = PsArrayAccess(f_arr.base_pointer, pb.Sum((counter * f_arr.strides[0], zero))) + rhs = PsArrayAccess(g_arr.base_pointer, pb.Sum((counter * g_arr.strides[0], zero))) should = PsAssignment(PsLvalueExpr(lhs), PsExpression(rhs)) assert fasm == should + +#%% +test_freeze_fields() \ No newline at end of file -- GitLab