Commit 7f76698d authored by Markus Holzer's avatar Markus Holzer
Browse files

Added save and load test to datahandling tests

parent d0a06963
...@@ -425,14 +425,15 @@ class SerialDataHandling(DataHandling): ...@@ -425,14 +425,15 @@ class SerialDataHandling(DataHandling):
np.savez_compressed(file, **self.cpu_arrays) np.savez_compressed(file, **self.cpu_arrays)
def load_all(self, file): def load_all(self, file):
if '.npz' not in file:
file += '.npz'
file_contents = np.load(file) file_contents = np.load(file)
for arr_name, arr_contents in self.cpu_arrays.items(): for arr_name, arr_contents in self.cpu_arrays.items():
if arr_name not in file_contents: if arr_name not in file_contents:
print(f"Skipping read data {arr_name} because there is no data with this name in data handling") print(f"Skipping read data {arr_name} because there is no data with this name in data handling")
continue continue
if file_contents[arr_name].shape != arr_contents.shape: if file_contents[arr_name].shape != arr_contents.shape:
print("Skipping read data {} because shapes don't match. " print(f"Skipping read data {arr_name} because shapes don't match. "
"Read array shape {}, existing array shape {}".format(arr_name, file_contents[arr_name].shape, f"Read array shape {file_contents[arr_name].shape}, existing array shape {arr_contents.shape}")
arr_contents.shape))
continue continue
np.copyto(arr_contents, file_contents[arr_name]) np.copyto(arr_contents, file_contents[arr_name])
...@@ -310,3 +310,44 @@ def test_log(): ...@@ -310,3 +310,44 @@ def test_log():
dh.log_on_root() dh.log_on_root()
assert dh.is_root assert dh.is_root
assert dh.world_rank == 0 assert dh.world_rank == 0
def test_save_data():
domain_shape = (2, 2)
dh = create_data_handling(domain_size=domain_shape, default_ghost_layers=1)
dh.add_array("src", values_per_cell=9)
dh.fill("src", 1.0, ghost_layers=True)
dh.add_array("dst", values_per_cell=9)
dh.fill("dst", 1.0, ghost_layers=True)
dh.save_all('test_data/datahandling_save_test')
def test_load_data():
domain_shape = (2, 2)
dh = create_data_handling(domain_size=domain_shape, default_ghost_layers=1)
dh.add_array("src", values_per_cell=9)
dh.fill("src", 0.0, ghost_layers=True)
dh.add_array("dst", values_per_cell=9)
dh.fill("dst", 0.0, ghost_layers=True)
dh.load_all('test_data/datahandling_load_test')
assert np.all(dh.cpu_arrays['src']) == 1
assert np.all(dh.cpu_arrays['dst']) == 1
domain_shape = (3, 3)
dh = create_data_handling(domain_size=domain_shape, default_ghost_layers=1)
dh.add_array("src", values_per_cell=9)
dh.fill("src", 0.0, ghost_layers=True)
dh.add_array("dst", values_per_cell=9)
dh.fill("dst", 0.0, ghost_layers=True)
dh.add_array("dst2", values_per_cell=9)
dh.fill("dst2", 0.0, ghost_layers=True)
dh.load_all('test_data/datahandling_load_test')
assert np.all(dh.cpu_arrays['src']) == 0
assert np.all(dh.cpu_arrays['dst']) == 0
assert np.all(dh.cpu_arrays['dst2']) == 0
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment