From 19f9750d131910429d932b51d8f9097e1ebd4cbd Mon Sep 17 00:00:00 2001
From: Martin Bauer <martin.bauer@fau.de>
Date: Wed, 23 Jan 2019 09:43:16 +0100
Subject: [PATCH] Parallel data handling now also supports checkpointing

- update waLBerla package in order to make use of this feature!
---
 datahandling/datahandling_interface.py | 11 +++++++++++
 datahandling/parallel_datahandling.py  | 15 +++++++++++++++
 datahandling/serial_datahandling.py    |  4 ++--
 3 files changed, 28 insertions(+), 2 deletions(-)

diff --git a/datahandling/datahandling_interface.py b/datahandling/datahandling_interface.py
index b8eec4971..99ceddd42 100644
--- a/datahandling/datahandling_interface.py
+++ b/datahandling/datahandling_interface.py
@@ -300,6 +300,17 @@ class DataHandling(ABC):
 
         return self.reduce_float_sequence([result], 'max', all_reduce=True)[0] if reduce else result
 
+    def save_all(self, file):
+        """Saves all field data to disk into a file"""
+
+    def load_all(self, file):
+        """Loads all field data from a file on disk
+
+        Works only if save_all was called with exactly the same field sizes, layouts etc.
+        When run in parallel save and load has to be called with the same number of processes.
+        Use for checkpointing only - to store results use VTK output
+        """
+
     def __str__(self):
         result = ""
 
diff --git a/datahandling/parallel_datahandling.py b/datahandling/parallel_datahandling.py
index 03b7c2e81..23a6f09c4 100644
--- a/datahandling/parallel_datahandling.py
+++ b/datahandling/parallel_datahandling.py
@@ -1,3 +1,4 @@
+import os
 import numpy as np
 import warnings
 from pystencils import Field
@@ -369,3 +370,17 @@ class ParallelDataHandling(DataHandling):
     @property
     def world_rank(self):
         return wlb.mpi.worldRank()
+
+    def save_all(self, directory):
+        if os.path.isfile(directory):
+            raise RuntimeError("Trying to save to {}, but file exists already".format(directory))
+        if not os.path.exists(directory):
+            os.mkdir(directory)
+
+        for field_name, data_name in self._field_name_to_cpu_data_name.items():
+            self.blocks.writeBlockData(data_name, os.path.join(directory, field_name + ".dat"))
+
+    def load_all(self, directory):
+        for field_name, data_name in self._field_name_to_cpu_data_name.items():
+            self.blocks.readBlockData(data_name, os.path.join(directory, field_name + ".dat"))
+
diff --git a/datahandling/serial_datahandling.py b/datahandling/serial_datahandling.py
index 5ede29e2a..800434f94 100644
--- a/datahandling/serial_datahandling.py
+++ b/datahandling/serial_datahandling.py
@@ -392,7 +392,7 @@ class SerialDataHandling(DataHandling):
                 continue
             if file_contents[arr_name].shape != arr_contents.shape:
                 print("Skipping read data {} because shapes don't match. "
-                      "Read array shape {}, exising array shape {}".format(arr_name, file_contents[arr_name].shape,
-                                                                           arr_contents.shape))
+                      "Read array shape {}, existing array shape {}".format(arr_name, file_contents[arr_name].shape,
+                                                                            arr_contents.shape))
                 continue
             np.copyto(arr_contents, file_contents[arr_name])
-- 
GitLab