From 19f9750d131910429d932b51d8f9097e1ebd4cbd Mon Sep 17 00:00:00 2001 From: Martin Bauer <martin.bauer@fau.de> Date: Wed, 23 Jan 2019 09:43:16 +0100 Subject: [PATCH] Parallel data handling now also supports checkpointing - update waLBerla package in order to make use of this function! --- datahandling/datahandling_interface.py | 11 +++++++++++ datahandling/parallel_datahandling.py | 15 +++++++++++++++ datahandling/serial_datahandling.py | 4 ++-- 3 files changed, 28 insertions(+), 2 deletions(-) diff --git a/datahandling/datahandling_interface.py b/datahandling/datahandling_interface.py index b8eec4971..99ceddd42 100644 --- a/datahandling/datahandling_interface.py +++ b/datahandling/datahandling_interface.py @@ -300,6 +300,17 @@ class DataHandling(ABC): return self.reduce_float_sequence([result], 'max', all_reduce=True)[0] if reduce else result + def save_all(self, file): + """Saves all field data to disk into a file""" + + def load_all(self, file): + """Loads all field data from disk into a file + + Works only if save_all was called with exactly the same field sizes, layouts etc. + When run in parallel save and load has to be called with the same number of processes. + Use for check pointing only - to store results use VTK output + """ + def __str__(self): result = "" diff --git a/datahandling/parallel_datahandling.py b/datahandling/parallel_datahandling.py index 03b7c2e81..23a6f09c4 100644 --- a/datahandling/parallel_datahandling.py +++ b/datahandling/parallel_datahandling.py @@ -1,3 +1,4 @@ +import os import numpy as np import warnings from pystencils import Field @@ -369,3 +370,17 @@ class ParallelDataHandling(DataHandling): @property def world_rank(self): return wlb.mpi.worldRank() + + def save_all(self, directory): + if not os.path.exists(directory): + os.mkdir(directory) + if os.path.isfile(directory): + raise RuntimeError("Trying to save to {}, but file exists already".format(directory)) + + for field_name, data_name in self._field_name_to_cpu_data_name.items(): + self.blocks.writeBlockData(data_name, os.path.join(directory, field_name + ".dat")) + + def load_all(self, directory): + for field_name, data_name in self._field_name_to_cpu_data_name.items(): + self.blocks.readBlockData(data_name, os.path.join(directory, field_name + ".dat")) + diff --git a/datahandling/serial_datahandling.py b/datahandling/serial_datahandling.py index 5ede29e2a..800434f94 100644 --- a/datahandling/serial_datahandling.py +++ b/datahandling/serial_datahandling.py @@ -392,7 +392,7 @@ class SerialDataHandling(DataHandling): continue if file_contents[arr_name].shape != arr_contents.shape: print("Skipping read data {} because shapes don't match. " - "Read array shape {}, exising array shape {}".format(arr_name, file_contents[arr_name].shape, - arr_contents.shape)) + "Read array shape {}, existing array shape {}".format(arr_name, file_contents[arr_name].shape, + arr_contents.shape)) continue np.copyto(arr_contents, file_contents[arr_name]) -- GitLab