From afd3462338afa4f86794431a3be484ee9260d76b Mon Sep 17 00:00:00 2001
From: Martin Bauer <martin.bauer@fau.de>
Date: Thu, 1 Feb 2018 10:05:02 +0100
Subject: [PATCH] Ported most of the tutorials/demos/benchmarks

---
 datahandling/datahandling_interface.py |  3 ++-
 datahandling/parallel_datahandling.py  |  8 ++++++--
 datahandling/serial_datahandling.py    | 18 +++++++-----------
 jupytersetup.py                        |  4 ++++
 slicing.py                             |  8 ++++++++
 5 files changed, 27 insertions(+), 14 deletions(-)

diff --git a/datahandling/datahandling_interface.py b/datahandling/datahandling_interface.py
index b2faa156c..2a537206d 100644
--- a/datahandling/datahandling_interface.py
+++ b/datahandling/datahandling_interface.py
@@ -102,7 +102,7 @@ class DataHandling(ABC):
         """
 
     @abstractmethod
-    def gatherArray(self, name, sliceObj=None, allGather=False):
+    def gatherArray(self, name, sliceObj=None, allGather=False, ghostLayers=False):
         """
         Gathers part of the domain on a local process. Whenever possible use 'access' instead, since this method copies
         the distributed data to a single process which is inefficient and may exhaust the available memory
@@ -110,6 +110,7 @@ class DataHandling(ABC):
         :param name: name of the array to gather
         :param sliceObj: slice expression of the rectangular sub-part that should be gathered
         :param allGather: if False only the root process receives the result, if True all processes
+        :param ghostLayers: number of outer ghost layers to include (only available for serial data handlings)
         :return: gathered field that does not include any ghost layers, or None if gathered on another process
         """
 
diff --git a/datahandling/parallel_datahandling.py b/datahandling/parallel_datahandling.py
index d8d558cdc..08000b8b2 100644
--- a/datahandling/parallel_datahandling.py
+++ b/datahandling/parallel_datahandling.py
@@ -4,7 +4,7 @@ from pystencils.datahandling.datahandling_interface import DataHandling
 from pystencils.parallel.blockiteration import slicedBlockIteration, blockIteration
 from pystencils.utils import DotDict
 import waLBerla as wlb
-
+import warnings
 
 class ParallelDataHandling(DataHandling):
     GPU_DATA_PREFIX = "gpu_"
@@ -152,7 +152,11 @@ class ParallelDataHandling(DataHandling):
         else:
             yield from blockIteration(self.blocks, ghostLayers, self.dim, prefix)
 
-    def gatherArray(self, name, sliceObj=None, allGather=False):
+    def gatherArray(self, name, sliceObj=None, allGather=False, ghostLayers=False):
+        if ghostLayers is not False:
+            warnings.warn("gatherArray with ghost layers is only supported in serial datahandling. "
+                          "Array without ghost layers is returned")
+
         if sliceObj is None:
             sliceObj = tuple([slice(None, None, None)] * self.dim)
         if self.dim == 2:
diff --git a/datahandling/serial_datahandling.py b/datahandling/serial_datahandling.py
index 1662462a8..d422e76a6 100644
--- a/datahandling/serial_datahandling.py
+++ b/datahandling/serial_datahandling.py
@@ -16,13 +16,6 @@ except ImportError:
 
 class SerialDataHandling(DataHandling):
 
-    class _PassThroughContextManager:
-        def __init__(self, arr):
-            self.arr = arr
-
-        def __enter__(self, *args, **kwargs):
-            return self.arr
-
     def __init__(self, domainSize, defaultGhostLayers=1, defaultLayout='SoA', periodicity=False):
         """
         Creates a data handling for single node simulations
@@ -98,7 +91,6 @@ class SerialDataHandling(DataHandling):
             indexDimensions = 0
             layoutTuple = spatialLayoutStringToTuple(layout, self.dim)
 
-
         # cpuArr is always created - since there is no createPycudaArrayWithLayout()
         cpuArr = createNumpyArrayWithLayout(layout=layoutTuple, **kwargs)
         if cpu:
@@ -162,11 +154,15 @@ class SerialDataHandling(DataHandling):
         offset = tuple(s.start - ghostLayers for s in sliceObj)
         yield SerialBlock(iterDict, offset, sliceObj)
 
-    def gatherArray(self, name, sliceObj=None, **kwargs):
-        gls = self._fieldInformation[name]['ghostLayers']
+    def gatherArray(self, name, sliceObj=None, ghostLayers=False, **kwargs):
+        glToRemove = self._fieldInformation[name]['ghostLayers']
+        if isinstance(ghostLayers, int):
+            glToRemove -= ghostLayers
+        if ghostLayers is True:
+            glToRemove = 0
         arr = self.cpuArrays[name]
         indDimensions = self.fields[name].indexDimensions
-        arr = removeGhostLayers(arr, indexDimensions=indDimensions, ghostLayers=gls)
+        arr = removeGhostLayers(arr, indexDimensions=indDimensions, ghostLayers=glToRemove)
 
         if sliceObj is not None:
             sliceObj = normalizeSlice(sliceObj, arr.shape[:-indDimensions] if indDimensions > 0 else arr.shape)
diff --git a/jupytersetup.py b/jupytersetup.py
index f390c9c4b..1e46d707c 100644
--- a/jupytersetup.py
+++ b/jupytersetup.py
@@ -4,6 +4,7 @@ from IPython.display import HTML
 from tempfile import NamedTemporaryFile
 import base64
 from IPython import get_ipython
+import sympy as sp
 
 
 def log_progress(sequence, every=None, size=None, name='Items'):
@@ -176,4 +177,7 @@ def setDisplayMode(mode):
 ipython = get_ipython()
 if ipython:
     setDisplayMode('imageupdate')
+    ipython.magic("config InlineBackend.rc = { }")
     ipython.magic("matplotlib inline")
+    plt.rc('figure', figsize=(16, 6))
+    sp.init_printing()
diff --git a/slicing.py b/slicing.py
index f22ad411b..200cb0305 100644
--- a/slicing.py
+++ b/slicing.py
@@ -9,6 +9,14 @@ class SliceMaker(object):
 makeSlice = SliceMaker()
 
 
+class SlicedGetter(object):
+    def __init__(self, functionReturningArray):
+        self._functionReturningArray = functionReturningArray
+
+    def __getitem__(self, item):
+        return self._functionReturningArray(item)
+
+
 def normalizeSlice(slices, sizes):
     """Converts slices with floating point and/or negative entries to integer slices"""
 
-- 
GitLab