From 885fc9c7ec4dd46f66f3a2216d5d86be9c689cf7 Mon Sep 17 00:00:00 2001
From: Michael Kuron <m.kuron@gmx.de>
Date: Fri, 15 Nov 2019 17:43:49 +0100
Subject: [PATCH] OpenCL macOS support

---
 pystencils/backends/opencl_backend.py |  2 +-
 pystencils/opencl/opencljit.py        | 10 ++++++++++
 2 files changed, 11 insertions(+), 1 deletion(-)

diff --git a/pystencils/backends/opencl_backend.py b/pystencils/backends/opencl_backend.py
index 4217e9384..b5da806bb 100644
--- a/pystencils/backends/opencl_backend.py
+++ b/pystencils/backends/opencl_backend.py
@@ -73,7 +73,7 @@ class OpenClSympyPrinter(CudaSympyPrinter):
         function_name, dimension = tuple(symbol_name.split("."))
         dimension = self.DIMENSION_MAPPING[dimension]
         function_name = self.INDEXING_FUNCTION_MAPPING[function_name]
-        return f"{function_name}({dimension})"
+        return f"int({function_name}({dimension}))"
 
     def _print_TextureAccess(self, node):
         raise NotImplementedError()
diff --git a/pystencils/opencl/opencljit.py b/pystencils/opencl/opencljit.py
index f1df02936..5526c954a 100644
--- a/pystencils/opencl/opencljit.py
+++ b/pystencils/opencl/opencljit.py
@@ -30,6 +30,16 @@ def make_python_function(kernel_function_node, opencl_queue, opencl_ctx, argumen
     if argument_dict is None:
         argument_dict = {}
 
+    # check if double precision is supported and required
+    if any([d.double_fp_config == 0 for d in opencl_ctx.devices]):
+        for param in kernel_function_node.get_parameters():
+            if param.symbol.dtype.base_type:
+                if param.symbol.dtype.base_type.numpy_dtype == np.float64:
+                    raise ValueError('OpenCL device does not support double precision')
+            else:
+                if param.symbol.dtype.numpy_dtype == np.float64:
+                    raise ValueError('OpenCL device does not support double precision')
+
     # Changing of kernel name necessary since compilation with default name "kernel" is not possible (OpenCL keyword!)
     kernel_function_node.function_name = "opencl_" + kernel_function_node.function_name
     header_list = ['"opencl_stdint.h"'] + list(get_headers(kernel_function_node))
-- 
GitLab