Skip to content
Snippets Groups Projects
Commit 2bcd290c authored by Frederik Hennig's avatar Frederik Hennig
Browse files

fix test_freeze, and minor refactors

parent 6a934814
Branches
Tags
No related merge requests found
Pipeline #61512 canceled with stages
in 46 seconds
...@@ -15,7 +15,10 @@ from .iteration_space import ( ...@@ -15,7 +15,10 @@ from .iteration_space import (
from .transformations import EraseAnonymousStructTypes from .transformations import EraseAnonymousStructTypes
def create_kernel(assignments: AssignmentCollection, options: KernelCreationOptions): def create_kernel(
assignments: AssignmentCollection,
options: KernelCreationOptions = KernelCreationOptions(),
):
ctx = KernelCreationContext(options) ctx = KernelCreationContext(options)
analysis = KernelAnalysis(ctx) analysis = KernelAnalysis(ctx)
......
#%%
import sympy as sp import sympy as sp
import pymbolic.primitives as pb import pymbolic.primitives as pb
...@@ -40,12 +41,12 @@ def test_freeze_fields(): ...@@ -40,12 +41,12 @@ def test_freeze_fields():
options = KernelCreationOptions() options = KernelCreationOptions()
ctx = KernelCreationContext(options) ctx = KernelCreationContext(options)
start = PsTypedConstant(0, ctx.index_dtype) zero = PsTypedConstant(0, ctx.index_dtype)
stop = PsTypedConstant(42, ctx.index_dtype) forty_two = PsTypedConstant(42, ctx.index_dtype)
step = PsTypedConstant(1, ctx.index_dtype) one = PsTypedConstant(1, ctx.index_dtype)
counter = PsTypedVariable("ctr", ctx.index_dtype) counter = PsTypedVariable("ctr", ctx.index_dtype)
ispace = FullIterationSpace( ispace = FullIterationSpace(
ctx, [FullIterationSpace.Dimension(start, stop, step, counter)] ctx, [FullIterationSpace.Dimension(zero, forty_two, one, counter)]
) )
ctx.set_iteration_space(ispace) ctx.set_iteration_space(ispace)
...@@ -59,9 +60,12 @@ def test_freeze_fields(): ...@@ -59,9 +60,12 @@ def test_freeze_fields():
fasm = freeze(asm) fasm = freeze(asm)
lhs = PsArrayAccess(f_arr.base_pointer, counter * f_arr.strides[0]) lhs = PsArrayAccess(f_arr.base_pointer, pb.Sum((counter * f_arr.strides[0], zero)))
rhs = PsArrayAccess(g_arr.base_pointer, counter * g_arr.strides[0]) rhs = PsArrayAccess(g_arr.base_pointer, pb.Sum((counter * g_arr.strides[0], zero)))
should = PsAssignment(PsLvalueExpr(lhs), PsExpression(rhs)) should = PsAssignment(PsLvalueExpr(lhs), PsExpression(rhs))
assert fasm == should assert fasm == should
#%%
test_freeze_fields()
\ No newline at end of file
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment