From c21436bcea8617a0e0bfd27eab2364c597020a5e Mon Sep 17 00:00:00 2001
From: Stephan Seitz <stephan.seitz@fau.de>
Date: Wed, 14 Aug 2019 07:57:05 +0200
Subject: [PATCH] Fixup: also accept non-array parameters for OpenCL jit
 arguments

---
 pystencils/opencl/opencljit.py  |  2 +-
 pystencils_tests/test_opencl.py | 53 +++++++++++++++++++++++++++++++++
 2 files changed, 54 insertions(+), 1 deletion(-)

diff --git a/pystencils/opencl/opencljit.py b/pystencils/opencl/opencljit.py
index dd0660667..a68174961 100644
--- a/pystencils/opencl/opencljit.py
+++ b/pystencils/opencl/opencljit.py
@@ -64,7 +64,7 @@ def make_python_function(kernel_function_node, opencl_queue, opencl_ctx, argumen
                                                                                   block_and_thread_numbers['grid']))
 
             args = _build_numpy_argument_list(parameters, full_arguments)
-            args = [a.data for a in args if hasattr(a, 'data')]
+            args = [a.data if hasattr(a, 'data') else a for a in args]
             cache[key] = (args, block_and_thread_numbers)
             cache_values.append(kwargs)  # keep objects alive such that ids remain unique
             func(opencl_queue, block_and_thread_numbers['grid'], block_and_thread_numbers['block'], *args)
diff --git a/pystencils_tests/test_opencl.py b/pystencils_tests/test_opencl.py
index 5aa6185a6..24ef56f9b 100644
--- a/pystencils_tests/test_opencl.py
+++ b/pystencils_tests/test_opencl.py
@@ -142,5 +142,58 @@ def test_opencl_jit():
     assert np.allclose(result_cuda, result_opencl)
 
 
+@pytest.mark.skipif(not HAS_OPENCL, reason="Test requires pyopencl")
+def test_opencl_jit_with_parameter():
+    z, y, x = pystencils.fields("z, y, x: [2d]")
+
+    a = sp.Symbol('a')
+    assignments = pystencils.AssignmentCollection({
+        z[0, 0]: x[0, 0] * sp.log(x[0, 0] * y[0, 0]) + a
+    })
+
+    print(assignments)
+
+    ast = pystencils.create_kernel(assignments, target='gpu')
+
+    print(ast)
+
+    code = pystencils.show_code(ast, custom_backend=CudaBackend())
+    print(code)
+    opencl_code = pystencils.show_code(ast, custom_backend=OpenClBackend())
+    print(opencl_code)
+
+    cuda_kernel = ast.compile()
+    assert cuda_kernel is not None
+
+    import pycuda.gpuarray as gpuarray
+
+    x_cpu = np.random.rand(20, 30)
+    y_cpu = np.random.rand(20, 30)
+    z_cpu = np.random.rand(20, 30)
+
+    x = gpuarray.to_gpu(x_cpu)
+    y = gpuarray.to_gpu(y_cpu)
+    z = gpuarray.to_gpu(z_cpu)
+    cuda_kernel(x=x, y=y, z=z, a=5.)
+
+    result_cuda = z.get()
+
+    import pyopencl.array as array
+    ctx = cl.create_some_context(0)
+    queue = cl.CommandQueue(ctx)
+
+    x = array.to_device(queue, x_cpu)
+    y = array.to_device(queue, y_cpu)
+    z = array.to_device(queue, z_cpu)
+
+    opencl_kernel = make_python_function(ast, queue, ctx)
+    assert opencl_kernel is not None
+    opencl_kernel(x=x, y=y, z=z, a=5.)
+
+    result_opencl = z.get(queue)
+
+    assert np.allclose(result_cuda, result_opencl)
+
+
 if __name__ == '__main__':
     test_opencl_jit()
-- 
GitLab