From 7f76698d227f1b98b8c1bf0f8859a1b9243d1fa3 Mon Sep 17 00:00:00 2001 From: markus holzer <markus.holzer@fau.de> Date: Sun, 9 Aug 2020 09:42:42 +0200 Subject: [PATCH] Added save and load test to datahandling tests --- .../datahandling/serial_datahandling.py | 7 +-- .../test_data/datahandling_load_test.npz | Bin 0 -> 410 bytes .../test_data/datahandling_save_test.npz | Bin 0 -> 410 bytes pystencils_tests/test_datahandling.py | 41 ++++++++++++++++++ 4 files changed, 45 insertions(+), 3 deletions(-) create mode 100644 pystencils_tests/test_data/datahandling_load_test.npz create mode 100644 pystencils_tests/test_data/datahandling_save_test.npz diff --git a/pystencils/datahandling/serial_datahandling.py b/pystencils/datahandling/serial_datahandling.py index 25a4d23ee..ce4629f6a 100644 --- a/pystencils/datahandling/serial_datahandling.py +++ b/pystencils/datahandling/serial_datahandling.py @@ -425,14 +425,15 @@ class SerialDataHandling(DataHandling): np.savez_compressed(file, **self.cpu_arrays) def load_all(self, file): + if '.npz' not in file: + file += '.npz' file_contents = np.load(file) for arr_name, arr_contents in self.cpu_arrays.items(): if arr_name not in file_contents: print(f"Skipping read data {arr_name} because there is no data with this name in data handling") continue if file_contents[arr_name].shape != arr_contents.shape: - print("Skipping read data {} because shapes don't match. " - "Read array shape {}, existing array shape {}".format(arr_name, file_contents[arr_name].shape, - arr_contents.shape)) + print(f"Skipping read data {arr_name} because shapes don't match. " + f"Read array shape {file_contents[arr_name].shape}, existing array shape {arr_contents.shape}") continue np.copyto(arr_contents, file_contents[arr_name]) diff --git a/pystencils_tests/test_data/datahandling_load_test.npz b/pystencils_tests/test_data/datahandling_load_test.npz new file mode 100644 index 0000000000000000000000000000000000000000..d363a8a0aba1bb78a06314a19b887eb4c4975334 GIT binary patch literal 410 zcmWIWW@Zs#U|`??Vnv4TVm_%5Ad7*Ofq|VtgrT@7Sud}kl953GECiAPO9ScIZ^U0o z3!FR=a4cZ$yh%}WVwU7BU6409ZQ;7b3+7FW4+)wwLwtVxlu2Ad{F++6tX$&hDq>5R zc1o#PaXF-{T)8-4wS(G&B!*`GZ;QWZ*n0I}`m&5M0Iy?Gic9G07)B-$W?W$d3JM5l iU<A?7kP7f7R#Puf6Vyim-mGjOGnjxd3rI_WO#}cXWnnh} literal 0 HcmV?d00001 diff --git a/pystencils_tests/test_data/datahandling_save_test.npz b/pystencils_tests/test_data/datahandling_save_test.npz new file mode 100644 index 0000000000000000000000000000000000000000..d363a8a0aba1bb78a06314a19b887eb4c4975334 GIT binary patch literal 410 zcmWIWW@Zs#U|`??Vnv4TVm_%5Ad7*Ofq|VtgrT@7Sud}kl953GECiAPO9ScIZ^U0o z3!FR=a4cZ$yh%}WVwU7BU6409ZQ;7b3+7FW4+)wwLwtVxlu2Ad{F++6tX$&hDq>5R zc1o#PaXF-{T)8-4wS(G&B!*`GZ;QWZ*n0I}`m&5M0Iy?Gic9G07)B-$W?W$d3JM5l iU<A?7kP7f7R#Puf6Vyim-mGjOGnjxd3rI_WO#}cXWnnh} literal 0 HcmV?d00001 diff --git a/pystencils_tests/test_datahandling.py b/pystencils_tests/test_datahandling.py index 6e53d1e8b..c18cfba98 100644 --- a/pystencils_tests/test_datahandling.py +++ b/pystencils_tests/test_datahandling.py @@ -310,3 +310,44 @@ def test_log(): dh.log_on_root() assert dh.is_root assert dh.world_rank == 0 + + +def test_save_data(): + domain_shape = (2, 2) + + dh = create_data_handling(domain_size=domain_shape, default_ghost_layers=1) + dh.add_array("src", values_per_cell=9) + dh.fill("src", 1.0, ghost_layers=True) + dh.add_array("dst", values_per_cell=9) + dh.fill("dst", 1.0, ghost_layers=True) + + dh.save_all('test_data/datahandling_save_test') + + +def test_load_data(): + domain_shape = (2, 2) + + dh = create_data_handling(domain_size=domain_shape, default_ghost_layers=1) + dh.add_array("src", values_per_cell=9) + dh.fill("src", 0.0, ghost_layers=True) + dh.add_array("dst", values_per_cell=9) + dh.fill("dst", 0.0, ghost_layers=True) + + dh.load_all('test_data/datahandling_load_test') + assert np.all(dh.cpu_arrays['src']) == 1 + assert np.all(dh.cpu_arrays['dst']) == 1 + + domain_shape = (3, 3) + + dh = create_data_handling(domain_size=domain_shape, default_ghost_layers=1) + dh.add_array("src", values_per_cell=9) + dh.fill("src", 0.0, ghost_layers=True) + dh.add_array("dst", values_per_cell=9) + dh.fill("dst", 0.0, ghost_layers=True) + dh.add_array("dst2", values_per_cell=9) + dh.fill("dst2", 0.0, ghost_layers=True) + + dh.load_all('test_data/datahandling_load_test') + assert np.all(dh.cpu_arrays['src']) == 0 + assert np.all(dh.cpu_arrays['dst']) == 0 + assert np.all(dh.cpu_arrays['dst2']) == 0 -- GitLab