import itertools
import os
import time
from typing import Sequence, Union

import numpy as np

from pystencils.datahandling.blockiteration import SerialBlock
from pystencils.datahandling.datahandling_interface import DataHandling
from pystencils.datahandling.pycuda import PyCudaArrayHandler, PyCudaNotAvailableHandler
from pystencils.datahandling.pyopencl import PyOpenClArrayHandler
from pystencils.field import (
    Field, FieldType, create_numpy_array_with_layout, layout_string_to_tuple,
    spatial_layout_string_to_tuple)
from pystencils.slicing import normalize_slice, remove_ghost_layers
from pystencils.utils import DotDict


class SerialDataHandling(DataHandling):
    """Data handling for simulations running in a single process.

    The whole domain (plus ghost layers) is stored in one block. CPU arrays are numpy arrays; optional GPU
    copies are created and transferred through ``self.array_handler``. As there is only one process,
    ``world_rank`` is always 0 and the reduction methods return their input unchanged.
    """

    def __init__(self,
                 domain_size: Sequence[int],
                 default_ghost_layers: int = 1,
                 default_layout: str = 'SoA',
                 periodicity: Union[bool, Sequence[bool]] = False,
                 default_target: str = 'cpu',
                 opencl_queue=None,
                 opencl_ctx=None,
                 array_handler=None) -> None:
        """
        Creates a data handling for single node simulations.

        Args:
            domain_size: size of the spatial domain as tuple
            default_ghost_layers: default number of ghost layers used, if not overridden in add_array() method
            default_layout: default layout used, if not overridden in add_array() method
            periodicity: a single bool applied to every dimension, or one bool per dimension; None means False
            default_target: 'cpu', 'gpu' or 'opencl'. For a GPU-like target a GPU version of each array is
                            allocated as well (unless overridden in add_array), and synchronization functions
                            are for the GPU by default
            opencl_queue: OpenCL command queue; if given (or default_target is 'opencl') an OpenCL array
                          handler is used
            opencl_ctx: OpenCL context, forwarded to the GPU periodicity kernels
            array_handler: explicit array handler; if given it replaces the automatically selected one
        """
        super(SerialDataHandling, self).__init__()
        self._domainSize = tuple(domain_size)
        self.default_ghost_layers = default_ghost_layers
        self.default_layout = default_layout
        self._fields = DotDict()
        self.cpu_arrays = DotDict()
        self.gpu_arrays = DotDict()
        self.custom_data_cpu = DotDict()
        self.custom_data_gpu = DotDict()
        # name -> (cpu_to_gpu_transfer_func, gpu_to_cpu_transfer_func)
        self._custom_data_transfer_functions = {}

        self._opencl_queue = opencl_queue
        self._opencl_ctx = opencl_ctx

        if not array_handler:
            # Try PyCUDA first; if constructing its handler fails use PyCudaNotAvailableHandler instead
            try:
                self.array_handler = PyCudaArrayHandler()
            except Exception:
                self.array_handler = PyCudaNotAvailableHandler()

            if default_target == 'opencl' or opencl_queue:
                self.array_handler = PyOpenClArrayHandler(opencl_queue)
        else:
            self.array_handler = array_handler

        if periodicity is None or periodicity is False:
            periodicity = [False] * self.dim
        if periodicity is True:
            periodicity = [True] * self.dim

        self._periodicity = periodicity
        # name -> keyword arguments of add_array (used by add_array_like, iterate, synchronization, ...)
        self._field_information = {}
        self.default_target = default_target
        self._start_time = time.perf_counter()

    @property
    def dim(self):
        """Number of spatial dimensions."""
        return len(self._domainSize)

    @property
    def shape(self):
        """Domain size without ghost layers."""
        return self._domainSize

    @property
    def periodicity(self):
        """Per-dimension periodicity flags."""
        return self._periodicity

    @property
    def fields(self):
        """Mapping from array name to its symbolic Field."""
        return self._fields

    def ghost_layers_of_field(self, name):
        """Number of ghost layers of the array ``name``."""
        return self._field_information[name]['ghost_layers']

    def values_per_cell(self, name):
        """Index shape of the array ``name`` as a tuple (empty for scalar fields)."""
        return self._field_information[name]['values_per_cell']

    def add_array(self, name, values_per_cell=1, dtype=np.float64, latex_name=None, ghost_layers=None, layout=None,
                  cpu=True, gpu=None, alignment=False, field_type=FieldType.GENERIC):
        """Allocates a new array (CPU and/or GPU) and registers a symbolic field for it.

        Args:
            name: name of the array and of the created field
            values_per_cell: int or sequence of ints giving the index shape; 1 means a scalar field
                             without index dimension
            dtype: numpy data type of the array
            latex_name: optional LaTeX name assigned to the created field
            ghost_layers: number of ghost layers; defaults to ``default_ghost_layers``
            layout: layout string; defaults to ``default_layout``
            cpu: whether the CPU array is stored in ``cpu_arrays``
            gpu: whether a GPU copy is created; defaults to whether ``default_target`` is GPU-like
            alignment: forwarded to ``create_numpy_array_with_layout``; not supported together with gpu
            field_type: type of the created field

        Returns:
            the created Field

        Raises:
            NotImplementedError: if alignment is requested for a GPU array
            ValueError: if a CPU or GPU array with this name already exists
            AssertionError: if a symbolic field with this name already exists
        """
        if ghost_layers is None:
            ghost_layers = self.default_ghost_layers
        if layout is None:
            layout = self.default_layout
        if gpu is None:
            gpu = self.default_target in self._GPU_LIKE_TARGETS

        # Always store the index shape as a tuple: lists or arrays would break the shape concatenation below
        if not hasattr(values_per_cell, '__len__'):
            values_per_cell = (values_per_cell, )
        values_per_cell = tuple(values_per_cell)
        if values_per_cell == (1, ):
            values_per_cell = ()

        # Validate before touching any state, so that a rejected call leaves existing arrays untouched
        if alignment and gpu:
            raise NotImplementedError("Alignment for GPU fields not supported")
        if cpu and name in self.cpu_arrays:
            raise ValueError("CPU Field with this name already exists")
        if gpu and name in self.gpu_arrays:
            raise ValueError("GPU Field with this name already exists")
        assert all(f.name != name for f in self.fields.values()), "Symbolic field with this name already exists"

        index_dimensions = len(values_per_cell)
        if index_dimensions > 0:
            layout_tuple = layout_string_to_tuple(layout, self.dim + index_dimensions)
        else:
            layout_tuple = spatial_layout_string_to_tuple(layout, self.dim)

        shape = tuple(s + 2 * ghost_layers for s in self._domainSize) + values_per_cell
        # cpu_arr is always created - since there is no create_pycuda_array_with_layout()
        byte_offset = ghost_layers * np.dtype(dtype).itemsize
        cpu_arr = create_numpy_array_with_layout(shape=shape, dtype=dtype, layout=layout_tuple,
                                                 alignment=alignment, byte_offset=byte_offset)
        gpu_arr = self.array_handler.to_gpu(cpu_arr) if gpu else None
        field = Field.create_from_numpy_array(name, cpu_arr, index_dimensions=index_dimensions,
                                              field_type=field_type)
        field.latex_name = latex_name

        # Everything that can fail has succeeded - register the new array
        self._field_information[name] = {
            'ghost_layers': ghost_layers,
            'values_per_cell': values_per_cell,
            'layout': layout,
            'dtype': dtype,
            'alignment': alignment,
            'field_type': field_type,
        }
        if cpu:
            self.cpu_arrays[name] = cpu_arr
        if gpu:
            self.gpu_arrays[name] = gpu_arr
        self.fields[name] = field
        return field

    def add_custom_data(self, name, cpu_creation_function,
                        gpu_creation_function=None, cpu_to_gpu_transfer_func=None, gpu_to_cpu_transfer_func=None):
        """Registers arbitrary user data created by the given factory functions.

        If both creation functions are given, both transfer functions are required; they are called as
        ``func(gpu_data, cpu_data)`` by to_gpu / to_cpu.

        Raises:
            ValueError: if GPU data is requested without both transfer functions
            AssertionError: if the name is already used for custom data or (on the requested device) an array
        """
        both_devices = bool(cpu_creation_function and gpu_creation_function)
        if both_devices and (cpu_to_gpu_transfer_func is None or gpu_to_cpu_transfer_func is None):
            raise ValueError("For GPU data, both transfer functions have to be specified")

        # Validate before creating/registering anything
        assert name not in self.custom_data_cpu
        if cpu_creation_function:
            assert name not in self.cpu_arrays
        if gpu_creation_function:
            assert name not in self.gpu_arrays

        if cpu_creation_function:
            self.custom_data_cpu[name] = cpu_creation_function()
        if gpu_creation_function:
            self.custom_data_gpu[name] = gpu_creation_function()
        if both_devices:
            self._custom_data_transfer_functions[name] = (cpu_to_gpu_transfer_func, gpu_to_cpu_transfer_func)

    def has_data(self, name):
        """True if an array (field) with this name was added."""
        return name in self.fields

    def add_array_like(self, name, name_of_template_field, latex_name=None, cpu=True, gpu=None):
        """Adds an array with the same ghost layers, index shape, layout, dtype, alignment and field type
        as ``name_of_template_field``."""
        return self.add_array(name, latex_name=latex_name, cpu=cpu, gpu=gpu,
                              **self._field_information[name_of_template_field])

    def iterate(self, slice_obj=None, gpu=False, ghost_layers=True, inner_ghost_layers=True):
        """Yields a single SerialBlock covering the (optionally sliced) domain.

        Args:
            slice_obj: slice relative to the domain extended by ``ghost_layers``; defaults to everything
            gpu: iterate over GPU arrays / custom data instead of the CPU ones
            ghost_layers: True for ``default_ghost_layers``, False for 0, an int, or a field name whose ghost
                          layer count is used. Arrays with fewer ghost layers are left out of the block.
            inner_ghost_layers: not used by this implementation
        """
        if ghost_layers is True:
            ghost_layers = self.default_ghost_layers
        elif ghost_layers is False:
            ghost_layers = 0
        elif isinstance(ghost_layers, str):
            ghost_layers = self.ghost_layers_of_field(ghost_layers)

        if slice_obj is None:
            slice_obj = (slice(None, None, None),) * self.dim
        slice_obj = normalize_slice(slice_obj, tuple(s + 2 * ghost_layers for s in self._domainSize))
        # Integer indices become length-one slices so every entry has a .start
        slice_obj = tuple(s if isinstance(s, slice) else slice(s, s + 1, None) for s in slice_obj)

        arrays = self.gpu_arrays if gpu else self.cpu_arrays
        custom_data_dict = self.custom_data_gpu if gpu else self.custom_data_cpu
        iter_dict = custom_data_dict.copy()
        for name, arr in arrays.items():
            field_gls = self._field_information[name]['ghost_layers']
            if field_gls < ghost_layers:
                continue
            # Strip the surplus ghost layers so every array in the block has exactly `ghost_layers`
            iter_dict[name] = remove_ghost_layers(arr, index_dimensions=len(arr.shape) - self.dim,
                                                  ghost_layers=field_gls - ghost_layers)

        offset = tuple(s.start - ghost_layers for s in slice_obj)
        yield SerialBlock(iter_dict, offset, slice_obj)

    def gather_array(self, name, slice_obj=None, ghost_layers=False, **kwargs):
        """Returns a read-only view of the CPU array ``name``.

        Args:
            name: array name
            slice_obj: optional slice; its spatial part is applied to the array after ghost-layer removal,
                       any further entries index the index dimensions
            ghost_layers: True keeps all ghost layers, False none, an int keeps that many
        """
        gl_to_remove = self._field_information[name]['ghost_layers']
        if ghost_layers is True:
            gl_to_remove = 0
        elif isinstance(ghost_layers, int):  # False counts as 0 here
            gl_to_remove -= ghost_layers
        ind_dimensions = self.fields[name].index_dimensions
        spatial_dimensions = self.fields[name].spatial_dimensions

        arr = remove_ghost_layers(self.cpu_arrays[name], index_dimensions=ind_dimensions,
                                  ghost_layers=gl_to_remove)

        if slice_obj is not None:
            normalized_slice = normalize_slice(slice_obj[:spatial_dimensions], arr.shape[:spatial_dimensions])
            normalized_slice = tuple(s if isinstance(s, slice) else slice(s, s + 1, None)
                                     for s in normalized_slice)
            normalized_slice += slice_obj[spatial_dimensions:]
            arr = arr[normalized_slice]
        else:
            arr = arr.view()
        arr.flags.writeable = False
        return arr

    def swap(self, name1, name2, gpu=None):
        """Exchanges the arrays stored under two names (on the GPU if ``gpu``, default: GPU-like target)."""
        if gpu is None:
            gpu = self.default_target in self._GPU_LIKE_TARGETS
        arr = self.gpu_arrays if gpu else self.cpu_arrays
        arr[name1], arr[name2] = arr[name2], arr[name1]

    def all_to_cpu(self):
        """Transfers all arrays present on both devices, and all transferable custom data, to the CPU."""
        for name in (self.cpu_arrays.keys() & self.gpu_arrays.keys()) | self._custom_data_transfer_functions.keys():
            self.to_cpu(name)

    def all_to_gpu(self):
        """Transfers all arrays present on both devices, and all transferable custom data, to the GPU."""
        for name in (self.cpu_arrays.keys() & self.gpu_arrays.keys()) | self._custom_data_transfer_functions.keys():
            self.to_gpu(name)

    def run_kernel(self, kernel_function, **kwargs):
        """Calls the kernel with the arrays of the device matching its backend; ``kwargs`` take precedence."""
        arrays = self.gpu_arrays if kernel_function.ast.backend in self._GPU_LIKE_BACKENDS else self.cpu_arrays
        kernel_function(**{**arrays, **kwargs})

    def get_kernel_kwargs(self, kernel_function, **kwargs):
        """Returns a one-element list with the keyword arguments run_kernel would pass."""
        result = {}
        result.update(self.gpu_arrays if kernel_function.ast.backend in self._GPU_LIKE_BACKENDS else self.cpu_arrays)
        result.update(kwargs)
        return [result]

    def to_cpu(self, name):
        """Copies array or custom data ``name`` from the GPU to the CPU."""
        if name in self._custom_data_transfer_functions:
            transfer_func = self._custom_data_transfer_functions[name][1]
            transfer_func(self.custom_data_gpu[name], self.custom_data_cpu[name])
        else:
            self.array_handler.download(self.gpu_arrays[name], self.cpu_arrays[name])

    def to_gpu(self, name):
        """Copies array or custom data ``name`` from the CPU to the GPU."""
        if name in self._custom_data_transfer_functions:
            transfer_func = self._custom_data_transfer_functions[name][0]
            transfer_func(self.custom_data_gpu[name], self.custom_data_cpu[name])
        else:
            self.array_handler.upload(self.gpu_arrays[name], self.cpu_arrays[name])

    def is_on_gpu(self, name):
        """True if a GPU array with this name exists."""
        return name in self.gpu_arrays

    def synchronization_function_cpu(self, names, stencil_name=None, **_):
        """Shortcut for synchronization_function(..., target='cpu')."""
        return self.synchronization_function(names, stencil_name, target='cpu')

    def synchronization_function_gpu(self, names, stencil_name=None, **_):
        """Shortcut for synchronization_function(..., target='gpu')."""
        return self.synchronization_function(names, stencil_name, target='gpu')

    def synchronization_function(self, names, stencil=None, target=None, **_):
        """Returns a function that fills the ghost layers of the given arrays from the periodic neighbors.

        Args:
            names: array name or sequence of array names
            stencil: stencil name starting with 'D2' or 'D3' selecting the neighbor directions; if None the
                     full neighborhood of the domain dimension is used
            target: 'cpu', 'gpu' or 'opencl' (handled like 'gpu'); defaults to ``default_target``

        Only directions along periodic axes are synchronized; without any such direction the returned
        function does nothing.

        Raises:
            ValueError: if the stencil (or, with stencil=None, the domain dimension) is neither 2D nor 3D
        """
        if target is None:
            target = self.default_target
        if target == 'opencl':
            target = 'gpu'
        assert target in ('cpu', 'gpu')
        if not hasattr(names, '__len__') or type(names) is str:
            names = [names]

        if stencil is None:
            stencil_dim = self.dim
        elif stencil.startswith('D2'):
            stencil_dim = 2
        elif stencil.startswith('D3'):
            stencil_dim = 3
        else:
            stencil_dim = None
        if stencil_dim not in (2, 3):
            raise ValueError("Invalid stencil")

        # All non-zero neighbor offsets that only move along periodic axes
        filtered_stencil = [
            direction for direction in itertools.product((-1, 0, 1), repeat=stencil_dim)
            if any(direction) and all(periodic or c == 0 for c, periodic in zip(direction, self._periodicity))
        ]

        # Backend passed to the GPU periodicity kernels, depending on the array handler in use
        gpu_backend = 'opencl' if isinstance(self.array_handler, PyOpenClArrayHandler) else 'gpu'

        result = []
        for name in names:
            gls = self._field_information[name]['ghost_layers']
            values_per_cell = self._field_information[name]['values_per_cell']
            if values_per_cell == ():
                values_per_cell = (1, )
            if len(values_per_cell) == 1:
                values_per_cell = values_per_cell[0]

            if not filtered_stencil:
                continue
            if target == 'cpu':
                from pystencils.slicing import get_periodic_boundary_functor
                result.append(get_periodic_boundary_functor(filtered_stencil, ghost_layers=gls))
            else:
                from pystencils.gpucuda.periodicity import get_periodic_boundary_functor as boundary_func
                result.append(boundary_func(filtered_stencil, self._domainSize,
                                            index_dimensions=self.fields[name].index_dimensions,
                                            index_dim_shape=values_per_cell,
                                            dtype=self.fields[name].dtype.numpy_dtype,
                                            ghost_layers=gls,
                                            target=gpu_backend,
                                            opencl_queue=self._opencl_queue,
                                            opencl_ctx=self._opencl_ctx))

        def result_functor():
            # Arrays are looked up at call time, so swap() between calls is respected
            arrays = self.cpu_arrays if target == 'cpu' else self.gpu_arrays
            for arr_name, func in zip(names, result):
                func(pdfs=arrays[arr_name])

        return result_functor

    @property
    def array_names(self):
        """Names of all added arrays."""
        return tuple(self.fields.keys())

    @property
    def custom_data_names(self):
        """Names of all custom data entries that exist on the CPU."""
        return tuple(self.custom_data_cpu.keys())

    def reduce_float_sequence(self, sequence, operation, all_reduce=False) -> np.array:
        """Single process: returns the sequence as numpy array without reducing."""
        return np.array(sequence)

    def reduce_int_sequence(self, sequence, operation, all_reduce=False) -> np.array:
        """Single process: returns the sequence as numpy array without reducing."""
        return np.array(sequence)

    def create_vtk_writer(self, file_name, data_names, ghost_layers=False):
        """Returns ``writer(step)`` that writes the given CPU arrays to ``<file_name>_<step:08d>`` via VTK.

        Vector fields with ``dim`` components are written as one vector (2D vectors padded with zeros);
        other single-index fields are written component-wise as ``name[i]``.
        """
        from pystencils.datahandling.vtk import image_to_vtk

        def writer(step):
            full_file_name = "%s_%08d" % (file_name, step,)
            cell_data = {}
            for name in data_names:
                field = self._get_field_with_given_number_of_ghost_layers(name, ghost_layers)
                if self.dim == 2:
                    field = field[:, :, np.newaxis]
                if len(field.shape) == 3:
                    cell_data[name] = np.ascontiguousarray(field)
                elif len(field.shape) == 4:
                    values_per_cell = field.shape[-1]
                    if values_per_cell == self.dim:
                        field = [np.ascontiguousarray(field[..., i]) for i in range(values_per_cell)]
                        if len(field) == 2:
                            field.append(np.zeros_like(field[0]))
                        cell_data[name] = tuple(field)
                    else:
                        for i in range(values_per_cell):
                            cell_data["%s[%d]" % (name, i)] = np.ascontiguousarray(field[..., i])
                else:
                    raise NotImplementedError("VTK export for fields with more than one index "
                                              "coordinate not implemented")
            image_to_vtk(full_file_name, cell_data=cell_data)
        return writer

    def create_vtk_writer_for_flag_array(self, file_name, data_name, masks_to_name, ghost_layers=False):
        """Returns ``writer(step)`` that writes one uint8 (0/1) VTK cell array per bit mask of a flag array.

        Args:
            masks_to_name: mapping from bit mask to the name of the written cell array
        """
        from pystencils.datahandling.vtk import image_to_vtk

        def writer(step):
            full_file_name = "%s_%08d" % (file_name, step,)
            field = self._get_field_with_given_number_of_ghost_layers(data_name, ghost_layers)
            if self.dim == 2:
                field = field[:, :, np.newaxis]
            cell_data = {name: np.ascontiguousarray(np.bitwise_and(field, field.dtype.type(mask)) > 0, dtype=np.uint8)
                         for mask, name in masks_to_name.items()}
            image_to_vtk(full_file_name, cell_data=cell_data)

        return writer

    def _get_field_with_given_number_of_ghost_layers(self, name, ghost_layers):
        """CPU array ``name`` keeping ``ghost_layers`` ghost layers (True: all, False: none)."""
        actual_ghost_layers = self.ghost_layers_of_field(name)
        if ghost_layers is True:
            ghost_layers = actual_ghost_layers

        gl_to_remove = actual_ghost_layers - ghost_layers
        ind_dims = len(self._field_information[name]['values_per_cell'])
        return remove_ghost_layers(self.cpu_arrays[name], ind_dims, gl_to_remove)

    def log(self, *args, level='INFO'):
        """Prints the arguments prefixed with the level and the seconds elapsed since construction."""
        level = level.upper()
        message = " ".join(str(e) for e in args)

        time_running = time.perf_counter() - self._start_time
        spacing = 7 - len(str(int(time_running)))
        message = f"[{level: <8}]{spacing * '-'}({time_running:.3f} sec) {message} "
        print(message, flush=True)

    def log_on_root(self, *args, level='INFO'):
        """Same as log() - the single process is always the root."""
        self.log(*args, level=level)

    @property
    def is_root(self):
        return True

    @property
    def world_rank(self):
        return 0

    def save_all(self, file):
        """Saves all CPU arrays into a compressed .npz archive."""
        np.savez_compressed(file, **self.cpu_arrays)

    def load_all(self, file):
        """Loads CPU arrays from a .npz archive written by save_all.

        Path arguments not ending in '.npz' get that suffix appended, mirroring np.savez_compressed.
        Arrays missing from the file, or with a different shape, are skipped with a printed message.
        """
        if isinstance(file, (str, os.PathLike)):
            file = os.fspath(file)
            if isinstance(file, str) and not file.endswith('.npz'):
                file += '.npz'
        # The NpzFile keeps the archive open until closed; each item access re-reads the member,
        # so every stored array is read exactly once below
        with np.load(file) as file_contents:
            for arr_name, arr_contents in self.cpu_arrays.items():
                if arr_name not in file_contents:
                    print(f"Skipping read data {arr_name} because there is no data with this name in data handling")
                    continue
                stored = file_contents[arr_name]
                if stored.shape != arr_contents.shape:
                    print(f"Skipping read data {arr_name} because shapes don't match. "
                          f"Read array shape {stored.shape}, existing array shape {arr_contents.shape}")
                    continue
                np.copyto(arr_contents, stored)