From b6dfd5d35f468be49ec77a6537b04b55b27d5a8c Mon Sep 17 00:00:00 2001
From: Martin Bauer <martin.bauer@fau.de>
Date: Fri, 16 Mar 2018 12:18:50 +0100
Subject: [PATCH] Simpler GPU setup: data handling has a default target now

-> one place to switch between cpu and gpu
---
 datahandling/datahandling_interface.py |  6 +++---
 datahandling/parallel_datahandling.py  | 14 +++++++++++---
 datahandling/serial_datahandling.py    | 15 +++++++++++----
 3 files changed, 25 insertions(+), 10 deletions(-)

diff --git a/datahandling/datahandling_interface.py b/datahandling/datahandling_interface.py
index 25914e990..849b9d6e9 100644
--- a/datahandling/datahandling_interface.py
+++ b/datahandling/datahandling_interface.py
@@ -30,7 +30,7 @@ class DataHandling(ABC):
         """Returns tuple of booleans for x,y,(z) directions with True if domain is periodic in that direction"""
 
     @abstractmethod
-    def addArray(self, name, fSize=1, dtype=np.float64, latexName=None, ghostLayers=None, layout=None, cpu=True, gpu=False):
+    def addArray(self, name, fSize=1, dtype=np.float64, latexName=None, ghostLayers=None, layout=None, cpu=True, gpu=None):
         """
         Adds a (possibly distributed) array to the handling that can be accessed using the given name.
         For each array a symbolic field is available via the 'fields' dictionary
@@ -46,7 +46,7 @@ class DataHandling(ABC):
         :param layout: memory layout of array, either structure of arrays 'SoA' or array of structures 'AoS'.
                        this is only important if fSize > 1
         :param cpu: allocate field on the CPU
-        :param gpu: allocate field on the GPU
+        :param gpu: allocate field on the GPU, if None, a GPU field is allocated if defaultTarget is 'gpu'
         """
 
     @abstractmethod
@@ -56,7 +56,7 @@ class DataHandling(ABC):
         """
 
     @abstractmethod
-    def addArrayLike(self, name, nameOfTemplateField, latexName=None, cpu=True, gpu=False):
+    def addArrayLike(self, name, nameOfTemplateField, latexName=None, cpu=True, gpu=None):
         """
         Adds an array with the same parameters (number of ghost layers, fSize, dtype) as existing array
         :param name: name of new array
diff --git a/datahandling/parallel_datahandling.py b/datahandling/parallel_datahandling.py
index c0ba80cd6..88c3c837e 100644
--- a/datahandling/parallel_datahandling.py
+++ b/datahandling/parallel_datahandling.py
@@ -11,7 +11,7 @@ class ParallelDataHandling(DataHandling):
     GPU_DATA_PREFIX = "gpu_"
     VTK_COUNTER = 0
 
-    def __init__(self, blocks, defaultGhostLayers=1, defaultLayout='SoA', dim=3):
+    def __init__(self, blocks, defaultGhostLayers=1, defaultLayout='SoA', dim=3, defaultTarget='cpu'):
         """
         Creates data handling based on waLBerla block storage
 
@@ -21,6 +21,8 @@ class ParallelDataHandling(DataHandling):
         :param dim: dimension of scenario,
                     waLBerla always uses three dimensions, so if dim=2 the extend of the
                     z coordinate of blocks has to be 1
+        :param defaultTarget: either 'cpu' or 'gpu' . If set to 'gpu' for each array also a GPU version is allocated
+                              if not overwritten in addArray, and synchronization functions are for the GPU by default
         """
         super(ParallelDataHandling, self).__init__()
         assert dim in (2, 3)
@@ -44,6 +46,7 @@ class ParallelDataHandling(DataHandling):
 
         if self._dim == 2:
             assert self.blocks.getDomainCellBB().size[2] == 1
+        self.defaultTarget = defaultTarget
 
     @property
     def dim(self):
@@ -81,9 +84,11 @@ class ParallelDataHandling(DataHandling):
         self._customDataNames.append(name)
 
     def addArray(self, name, fSize=1, dtype=np.float64, latexName=None, ghostLayers=None,
-                 layout=None, cpu=True, gpu=False):
+                 layout=None, cpu=True, gpu=None):
         if ghostLayers is None:
             ghostLayers = self.defaultGhostLayers
+        if gpu is None:
+            gpu = self.defaultTarget == 'gpu'
         if layout is None:
             layout = self.defaultLayout
         if len(self.blocks) == 0:
@@ -139,7 +144,7 @@ class ParallelDataHandling(DataHandling):
     def customDataNames(self):
         return tuple(self._customDataNames)
 
-    def addArrayLike(self, name, nameOfTemplateField, latexName=None, cpu=True, gpu=False):
+    def addArrayLike(self, name, nameOfTemplateField, latexName=None, cpu=True, gpu=None):
         return self.addArray(name, latexName=latexName, cpu=cpu, gpu=gpu, **self._fieldInformation[nameOfTemplateField])
 
     def swap(self, name1, name2, gpu=False):
@@ -260,6 +265,9 @@ class ParallelDataHandling(DataHandling):
         return self.synchronizationFunction(names, stencil, 'gpu', buffered)
 
     def synchronizationFunction(self, names, stencil=None, target='cpu', buffered=True):
+        if target is None:
+            target = self.defaultTarget
+
         if stencil is None:
             stencil = 'D3Q27' if self.dim == 3 else 'D2Q9'
 
diff --git a/datahandling/serial_datahandling.py b/datahandling/serial_datahandling.py
index e0ce53770..f31320aef 100644
--- a/datahandling/serial_datahandling.py
+++ b/datahandling/serial_datahandling.py
@@ -18,13 +18,15 @@ except ImportError:
 
 class SerialDataHandling(DataHandling):
 
-    def __init__(self, domainSize, defaultGhostLayers=1, defaultLayout='SoA', periodicity=False):
+    def __init__(self, domainSize, defaultGhostLayers=1, defaultLayout='SoA', periodicity=False,  defaultTarget='cpu'):
         """
         Creates a data handling for single node simulations
 
         :param domainSize: size of the spatial domain as tuple
         :param defaultGhostLayers: nr of ghost layers used if not specified in add() method
         :param defaultLayout: layout used if no layout is given to add() method
+        :param defaultTarget: either 'cpu' or 'gpu' . If set to 'gpu' for each array also a GPU version is allocated
+                              if not overwritten in addArray, and synchronization functions are for the GPU by default
         """
         super(SerialDataHandling, self).__init__()
         self._domainSize = tuple(domainSize)
@@ -44,6 +46,7 @@ class SerialDataHandling(DataHandling):
 
         self._periodicity = periodicity
         self._fieldInformation = {}
+        self.defaultTarget = defaultTarget
 
     @property
     def dim(self):
@@ -68,11 +71,13 @@ class SerialDataHandling(DataHandling):
         return self._fieldInformation[name]['fSize']
 
     def addArray(self, name, fSize=1, dtype=np.float64, latexName=None, ghostLayers=None, layout=None,
-                 cpu=True, gpu=False):
+                 cpu=True, gpu=None):
         if ghostLayers is None:
             ghostLayers = self.defaultGhostLayers
         if layout is None:
             layout = self.defaultLayout
+        if gpu is None:
+            gpu = self.defaultTarget == 'gpu'
 
         kwargs = {
             'shape': tuple(s + 2 * ghostLayers for s in self._domainSize),
@@ -132,7 +137,7 @@ class SerialDataHandling(DataHandling):
     def hasData(self, name):
         return name in self.fields
 
-    def addArrayLike(self, name, nameOfTemplateField, latexName=None, cpu=True, gpu=False):
+    def addArrayLike(self, name, nameOfTemplateField, latexName=None, cpu=True, gpu=None):
         return self.addArray(name, latexName=latexName, cpu=cpu, gpu=gpu, **self._fieldInformation[nameOfTemplateField])
 
     def iterate(self, sliceObj=None, gpu=False, ghostLayers=True, innerGhostLayers=True):
@@ -228,7 +233,9 @@ class SerialDataHandling(DataHandling):
     def synchronizationFunctionGPU(self, names, stencilName=None, **kwargs):
         return self.synchronizationFunction(names, stencilName, 'gpu')
 
-    def synchronizationFunction(self, names, stencil=None, target='cpu'):
+    def synchronizationFunction(self, names, stencil=None, target=None):
+        if target is None:
+            target = self.defaultTarget
         assert target in ('cpu', 'gpu')
         if not hasattr(names, '__len__') or type(names) is str:
             names = [names]
-- 
GitLab