From bb21c9b145838208d41ff9f8bb83546d95c691e7 Mon Sep 17 00:00:00 2001
From: Stephan Seitz <stephan.seitz@fau.de>
Date: Thu, 14 Nov 2019 19:29:46 +0100
Subject: [PATCH] Add weird construction to handle OpenCL boundary handling
 (don't show this to your children or students)

---
 pystencils/boundaries/boundaryhandling.py     | 28 +++++++++++++++----
 .../datahandling/datahandling_interface.py    |  3 ++
 .../datahandling/serial_datahandling.py       |  3 --
 3 files changed, 26 insertions(+), 8 deletions(-)

diff --git a/pystencils/boundaries/boundaryhandling.py b/pystencils/boundaries/boundaryhandling.py
index d258de4b3..e19e24ecb 100644
--- a/pystencils/boundaries/boundaryhandling.py
+++ b/pystencils/boundaries/boundaryhandling.py
@@ -87,8 +87,26 @@ class BoundaryHandling:
         fi = flag_interface
         self.flag_interface = fi if fi is not None else FlagInterface(data_handling, name + "Flags")
 
-        gpu = self._target == 'gpu'
-        data_handling.add_custom_class(self._index_array_name, self.IndexFieldBlockData, cpu=True, gpu=gpu)
+        gpu = self._target in self._data_handling._GPU_LIKE_TARGETS
+        class_ = self.IndexFieldBlockData
+        if self._target == 'opencl':
+            def opencl_to_device(gpu_version, cpu_version):
+                from pyopencl import array
+                gpu_version = gpu_version.boundary_object_to_index_list
+                cpu_version = cpu_version.boundary_object_to_index_list
+                for obj, cpu_arr in cpu_version.items():
+                    if obj not in gpu_version or gpu_version[obj].shape != cpu_arr.shape:
+                        from pystencils.opencl.opencljit import get_global_cl_queue
+
+                        queue = self._data_handling._opencl_queue or get_global_cl_queue()
+                        gpu_version[obj] = array.to_device(queue, cpu_arr)
+                    else:
+                        gpu_version[obj].set(cpu_arr)
+
+            class_ = type('opencl_class', (self.IndexFieldBlockData,), {
+                'to_gpu': opencl_to_device
+            })
+        data_handling.add_custom_class(self._index_array_name, class_, cpu=True, gpu=gpu)
 
     @property
     def data_handling(self):
@@ -204,7 +222,7 @@ class BoundaryHandling:
         if self._dirty:
             self.prepare()
 
-        for b in self._data_handling.iterate(gpu=self._target == 'gpu'):
+        for b in self._data_handling.iterate(gpu=self._target in self._data_handling._GPU_LIKE_TARGETS):
             for b_obj, idx_arr in b[self._index_array_name].boundary_object_to_index_list.items():
                 kwargs[self._field_name] = b[self._field_name]
                 kwargs['indexField'] = idx_arr
@@ -219,7 +237,7 @@ class BoundaryHandling:
         if self._dirty:
             self.prepare()
 
-        for b in self._data_handling.iterate(gpu=self._target == 'gpu'):
+        for b in self._data_handling.iterate(gpu=self._target in self._data_handling._GPU_LIKE_TARGETS):
             for b_obj, idx_arr in b[self._index_array_name].boundary_object_to_index_list.items():
                 arguments = kwargs.copy()
                 arguments[self._field_name] = b[self._field_name]
@@ -302,7 +320,7 @@ class BoundaryHandling:
     def _boundary_data_initialization(self, boundary_obj, boundary_data_setter, **kwargs):
         if boundary_obj.additional_data_init_callback:
             boundary_obj.additional_data_init_callback(boundary_data_setter, **kwargs)
-        if self._target == 'gpu':
+        if self._target in self._data_handling._GPU_LIKE_TARGETS:
             self._data_handling.to_gpu(self._index_array_name)
 
     class BoundaryInfo(object):
diff --git a/pystencils/datahandling/datahandling_interface.py b/pystencils/datahandling/datahandling_interface.py
index ba960edc1..af1a6ba1f 100644
--- a/pystencils/datahandling/datahandling_interface.py
+++ b/pystencils/datahandling/datahandling_interface.py
@@ -16,6 +16,9 @@ class DataHandling(ABC):
     'gather' function that has collects (parts of the) distributed data on a single process.
     """
 
+    _GPU_LIKE_TARGETS = ['gpu', 'opencl']
+    _GPU_LIKE_BACKENDS = ['gpucuda', 'opencl']
+
     # ---------------------------- Adding and accessing data -----------------------------------------------------------
 
     @property
diff --git a/pystencils/datahandling/serial_datahandling.py b/pystencils/datahandling/serial_datahandling.py
index 5abc3cc2f..12c9b6f3d 100644
--- a/pystencils/datahandling/serial_datahandling.py
+++ b/pystencils/datahandling/serial_datahandling.py
@@ -16,9 +16,6 @@ from pystencils.utils import DotDict
 
 class SerialDataHandling(DataHandling):
 
-    _GPU_LIKE_TARGETS = ['gpu', 'opencl']
-    _GPU_LIKE_BACKENDS = ['gpucuda', 'opencl']
-
     def __init__(self,
                  domain_size: Sequence[int],
                  default_ghost_layers: int = 1,
-- 
GitLab