Skip to content
Snippets Groups Projects
Commit 2bcd290c authored by Frederik Hennig's avatar Frederik Hennig
Browse files

fix test_freeze, and minor refactors

parent 6a934814
No related merge requests found
Pipeline #61512 canceled with stages
in 46 seconds
......@@ -15,7 +15,10 @@ from .iteration_space import (
from .transformations import EraseAnonymousStructTypes
def create_kernel(assignments: AssignmentCollection, options: KernelCreationOptions):
def create_kernel(
assignments: AssignmentCollection,
options: KernelCreationOptions = KernelCreationOptions(),
):
ctx = KernelCreationContext(options)
analysis = KernelAnalysis(ctx)
......
#%%
import sympy as sp
import pymbolic.primitives as pb
......@@ -40,12 +41,12 @@ def test_freeze_fields():
options = KernelCreationOptions()
ctx = KernelCreationContext(options)
start = PsTypedConstant(0, ctx.index_dtype)
stop = PsTypedConstant(42, ctx.index_dtype)
step = PsTypedConstant(1, ctx.index_dtype)
zero = PsTypedConstant(0, ctx.index_dtype)
forty_two = PsTypedConstant(42, ctx.index_dtype)
one = PsTypedConstant(1, ctx.index_dtype)
counter = PsTypedVariable("ctr", ctx.index_dtype)
ispace = FullIterationSpace(
ctx, [FullIterationSpace.Dimension(start, stop, step, counter)]
ctx, [FullIterationSpace.Dimension(zero, forty_two, one, counter)]
)
ctx.set_iteration_space(ispace)
......@@ -59,9 +60,12 @@ def test_freeze_fields():
fasm = freeze(asm)
lhs = PsArrayAccess(f_arr.base_pointer, counter * f_arr.strides[0])
rhs = PsArrayAccess(g_arr.base_pointer, counter * g_arr.strides[0])
lhs = PsArrayAccess(f_arr.base_pointer, pb.Sum((counter * f_arr.strides[0], zero)))
rhs = PsArrayAccess(g_arr.base_pointer, pb.Sum((counter * g_arr.strides[0], zero)))
should = PsAssignment(PsLvalueExpr(lhs), PsExpression(rhs))
assert fasm == should
#%%
test_freeze_fields()
\ No newline at end of file
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment