From de710402ff0bd94ea13b9f2e3d143bad7d1f7502 Mon Sep 17 00:00:00 2001
From: Stephan Seitz <stephan.seitz@fau.de>
Date: Thu, 14 Nov 2019 16:47:36 +0100
Subject: [PATCH] Allow default_target=='opencl' in SerialDataHandling

---
 pystencils/datahandling/serial_datahandling.py | 17 +++++++++--------
 1 file changed, 9 insertions(+), 8 deletions(-)

diff --git a/pystencils/datahandling/serial_datahandling.py b/pystencils/datahandling/serial_datahandling.py
index 931a7df30..5abc3cc2f 100644
--- a/pystencils/datahandling/serial_datahandling.py
+++ b/pystencils/datahandling/serial_datahandling.py
@@ -16,6 +16,9 @@ from pystencils.utils import DotDict
 
 class SerialDataHandling(DataHandling):
 
+    _GPU_LIKE_TARGETS = ['gpu', 'opencl']
+    _GPU_LIKE_BACKENDS = ['gpucuda', 'opencl']
+
     def __init__(self,
                  domain_size: Sequence[int],
                  default_ghost_layers: int = 1,
@@ -48,17 +51,16 @@ class SerialDataHandling(DataHandling):
         self._opencl_queue = opencl_queue
         self._opencl_ctx = opencl_ctx
 
-        if array_handler:
-            self.array_handler = array_handler
-        else:
+        if not array_handler:
             try:
                 self.array_handler = PyCudaArrayHandler()
             except Exception:
                 self.array_handler = None
 
             if default_target == 'opencl' or opencl_queue:
-                default_target = 'gpu'
                 self.array_handler = PyOpenClArrayHandler(opencl_queue)
+        else:
+            self.array_handler = array_handler
 
         if periodicity is None or periodicity is False:
             periodicity = [False] * self.dim
@@ -99,7 +101,7 @@ class SerialDataHandling(DataHandling):
         if layout is None:
             layout = self.default_layout
         if gpu is None:
-            gpu = self.default_target == 'gpu'
+            gpu = self.default_target in self._GPU_LIKE_TARGETS
 
         kwargs = {
             'shape': tuple(s + 2 * ghost_layers for s in self._domainSize),
@@ -239,13 +241,12 @@ class SerialDataHandling(DataHandling):
             self.to_gpu(name)
 
     def run_kernel(self, kernel_function, **kwargs):
-        arrays = self.gpu_arrays if kernel_function.ast.backend == 'gpucuda'  \
-            or kernel_function.ast.backend == 'opencl' else self.cpu_arrays
+        arrays = self.gpu_arrays if kernel_function.ast.backend in self._GPU_LIKE_BACKENDS else self.cpu_arrays
         kernel_function(**arrays, **kwargs)
 
     def get_kernel_kwargs(self, kernel_function, **kwargs):
         result = {}
-        result.update(self.gpu_arrays if kernel_function.ast.backend == 'gpucuda' else self.cpu_arrays)
+        result.update(self.gpu_arrays if kernel_function.ast.backend in self._GPU_LIKE_BACKENDS else self.cpu_arrays)
         result.update(kwargs)
         return [result]
 
-- 
GitLab