diff --git a/datahandling/datahandling_interface.py b/datahandling/datahandling_interface.py index b8eec497153981912404f6e61a1baed47d0c3d52..99ceddd42e57139abba8000cb2a345eb16694454 100644 --- a/datahandling/datahandling_interface.py +++ b/datahandling/datahandling_interface.py @@ -300,6 +300,17 @@ class DataHandling(ABC): return self.reduce_float_sequence([result], 'max', all_reduce=True)[0] if reduce else result + def save_all(self, file): + """Saves all field data to disk into a file""" + + def load_all(self, file): + """Loads all field data from disk into a file + + Works only if save_all was called with exactly the same field sizes, layouts etc. + When run in parallel save and load has to be called with the same number of processes. + Use for check pointing only - to store results use VTK output + """ + def __str__(self): result = "" diff --git a/datahandling/parallel_datahandling.py b/datahandling/parallel_datahandling.py index 03b7c2e817dc07541a24c85c52e7643c77b68845..23a6f09c44f613ad57fa1124dd553da45decce80 100644 --- a/datahandling/parallel_datahandling.py +++ b/datahandling/parallel_datahandling.py @@ -1,3 +1,4 @@ +import os import numpy as np import warnings from pystencils import Field @@ -369,3 +370,17 @@ class ParallelDataHandling(DataHandling): @property def world_rank(self): return wlb.mpi.worldRank() + + def save_all(self, directory): + if not os.path.exists(directory): + os.mkdir(directory) + if os.path.isfile(directory): + raise RuntimeError("Trying to save to {}, but file exists already".format(directory)) + + for field_name, data_name in self._field_name_to_cpu_data_name.items(): + self.blocks.writeBlockData(data_name, os.path.join(directory, field_name + ".dat")) + + def load_all(self, directory): + for field_name, data_name in self._field_name_to_cpu_data_name.items(): + self.blocks.readBlockData(data_name, os.path.join(directory, field_name + ".dat")) + diff --git a/datahandling/serial_datahandling.py b/datahandling/serial_datahandling.py index 5ede29e2aad07162779f5d3e8003feb5888790a4..800434f941649775d9d67e2715bd150fbde3b430 100644 --- a/datahandling/serial_datahandling.py +++ b/datahandling/serial_datahandling.py @@ -392,7 +392,7 @@ class SerialDataHandling(DataHandling): continue if file_contents[arr_name].shape != arr_contents.shape: print("Skipping read data {} because shapes don't match. " - "Read array shape {}, exising array shape {}".format(arr_name, file_contents[arr_name].shape, - arr_contents.shape)) + "Read array shape {}, existing array shape {}".format(arr_name, file_contents[arr_name].shape, + arr_contents.shape)) continue np.copyto(arr_contents, file_contents[arr_name])