Commit c21436bc authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Fixup: also accept non-array parameters for OpenCL jit arguments

parent b03c498d
......@@ -64,7 +64,7 @@ def make_python_function(kernel_function_node, opencl_queue, opencl_ctx, argumen
block_and_thread_numbers['grid']))
args = _build_numpy_argument_list(parameters, full_arguments)
args = [a.data for a in args if hasattr(a, 'data')]
args = [a.data if hasattr(a, 'data') else a for a in args]
cache[key] = (args, block_and_thread_numbers)
cache_values.append(kwargs) # keep objects alive such that ids remain unique
func(opencl_queue, block_and_thread_numbers['grid'], block_and_thread_numbers['block'], *args)
......
......@@ -142,5 +142,58 @@ def test_opencl_jit():
assert np.allclose(result_cuda, result_opencl)
@pytest.mark.skipif(not HAS_OPENCL, reason="Test requires pyopencl")
def test_opencl_jit_with_parameter():
z, y, x = pystencils.fields("z, y, x: [2d]")
a = sp.Symbol('a')
assignments = pystencils.AssignmentCollection({
z[0, 0]: x[0, 0] * sp.log(x[0, 0] * y[0, 0]) + a
})
print(assignments)
ast = pystencils.create_kernel(assignments, target='gpu')
print(ast)
code = pystencils.show_code(ast, custom_backend=CudaBackend())
print(code)
opencl_code = pystencils.show_code(ast, custom_backend=OpenClBackend())
print(opencl_code)
cuda_kernel = ast.compile()
assert cuda_kernel is not None
import pycuda.gpuarray as gpuarray
x_cpu = np.random.rand(20, 30)
y_cpu = np.random.rand(20, 30)
z_cpu = np.random.rand(20, 30)
x = gpuarray.to_gpu(x_cpu)
y = gpuarray.to_gpu(y_cpu)
z = gpuarray.to_gpu(z_cpu)
cuda_kernel(x=x, y=y, z=z, a=5.)
result_cuda = z.get()
import pyopencl.array as array
ctx = cl.create_some_context(0)
queue = cl.CommandQueue(ctx)
x = array.to_device(queue, x_cpu)
y = array.to_device(queue, y_cpu)
z = array.to_device(queue, z_cpu)
opencl_kernel = make_python_function(ast, queue, ctx)
assert opencl_kernel is not None
opencl_kernel(x=x, y=y, z=z, a=5.)
result_opencl = z.get(queue)
assert np.allclose(result_cuda, result_opencl)
if __name__ == '__main__':
test_opencl_jit()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment