Commit 05157881 authored by Markus Holzer, committed by Michael Kuron

Update parallel datahandling

parent ff05f9d8
@@ -15,3 +15,6 @@ RELEASE-VERSION
 test-report
 pystencils/boundaries/createindexlistcython.c
 pystencils/boundaries/createindexlistcython.*.so
+pystencils_tests/tmp
+pystencils_tests/kerncraft_inputs/.2d-5pt.c_kerncraft/
+pystencils_tests/kerncraft_inputs/.3d-7pt.c_kerncraft/
\ No newline at end of file
@@ -111,15 +111,15 @@ class ParallelBlock(Block):
     def __getitem__(self, data_name):
         result = self._block[self._name_prefix + data_name]
         type_name = type(result).__name__
-        if type_name == 'GhostLayerField':
-            result = wlb.field.toArray(result, withGhostLayers=self._gls)
+        if 'GhostLayerField' in type_name:
+            result = wlb.field.toArray(result, with_ghost_layers=self._gls)
             result = self._normalize_array_shape(result)
-        elif type_name == 'GpuField':
-            result = wlb.cuda.toGpuArray(result, withGhostLayers=self._gls)
+        elif 'GpuField' in type_name:
+            result = wlb.cuda.toGpuArray(result, with_ghost_layers=self._gls)
             result = self._normalize_array_shape(result)
         return result

     def _normalize_array_shape(self, arr):
-        if arr.shape[-1] == 1:
+        if arr.shape[-1] == 1 and len(arr.shape) == 4:
             arr = arr[..., 0]
         return arr[self._localSlice]
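This hunk makes two kinds of change: the type check now matches by substring, which tolerates specialized type names the waLBerla bindings may report (anything containing 'GhostLayerField' or 'GpuField'), and the keyword argument follows the snake_case spelling with_ghost_layers used by newer bindings. The extra len(arr.shape) == 4 guard stops _normalize_array_shape from squeezing a genuine spatial axis whose extent happens to be 1. A minimal NumPy sketch of that guard; the standalone helper and the shapes are illustrative only, not part of the commit:

import numpy as np

def normalize_array_shape(arr, local_slice=slice(None)):
    # Drop the trailing index axis only when the array actually has one,
    # i.e. it is 4-dimensional (x, y, z, f) with a single component.
    if arr.shape[-1] == 1 and len(arr.shape) == 4:
        arr = arr[..., 0]
    return arr[local_slice]

scalar_field = np.zeros((16, 16, 16, 1))   # one component per cell
assert normalize_array_shape(scalar_field).shape == (16, 16, 16)

thin_field = np.zeros((16, 16, 1))         # 3D array whose last extent is 1
assert normalize_array_shape(thin_field).shape == (16, 16, 1)  # left intact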
@@ -101,7 +101,7 @@ class ParallelDataHandling(DataHandling):
             raise ValueError("Data handling expects that each process has at least one block")
         if hasattr(dtype, 'type'):
             dtype = dtype.type
-        if name in self.blocks[0] or self.GPU_DATA_PREFIX + name in self.blocks[0]:
+        if name in self.blocks[0].fieldNames or self.GPU_DATA_PREFIX + name in self.blocks[0].fieldNames:
             raise ValueError("Data with this name has already been added")
         if alignment is False or alignment is None:
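The duplicate-name guard now consults the block's fieldNames list explicitly rather than relying on the block's own membership test. A self-contained sketch of the check; FakeBlock is a stand-in invented for illustration, and only the fieldNames attribute and the gpu_ prefix behavior come from the diff:

GPU_DATA_PREFIX = 'gpu_'

class FakeBlock:
    """Stand-in for a waLBerla block that lists its registered fields."""
    def __init__(self, field_names):
        self.fieldNames = field_names

block = FakeBlock(['rho', 'gpu_vel'])

def name_taken(name):
    # A name is taken if either the CPU field or its GPU counterpart exists.
    return name in block.fieldNames or GPU_DATA_PREFIX + name in block.fieldNames

assert name_taken('rho')        # CPU field already registered
assert name_taken('vel')        # GPU field exists under the gpu_ prefix
assert not name_taken('force')  # still free to add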
@@ -215,15 +215,13 @@ class ParallelDataHandling(DataHandling):
             array = array[:, :, 0]
         if last_element and self.fields[name].index_dimensions > 0:
             array = array[..., last_element[0]]
-        if self.fields[name].index_dimensions == 0:
-            array = array[..., 0]
         return array

     def _normalize_arr_shape(self, arr, index_dimensions):
-        if index_dimensions == 0:
+        if index_dimensions == 0 and len(arr.shape) > 3:
             arr = arr[..., 0]
-        if self.dim == 2:
+        if self.dim == 2 and len(arr.shape) > 2:
             arr = arr[:, :, 0]
         return arr
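The removed branch squeezed scalar fields a second time, and the new len(arr.shape) guards make _normalize_arr_shape safe to call on arrays that have already been reduced. A hedged NumPy sketch of the revised logic, with the signature simplified so dim is passed explicitly:

import numpy as np

def normalize_arr_shape(arr, index_dimensions, dim):
    if index_dimensions == 0 and len(arr.shape) > 3:
        arr = arr[..., 0]       # drop the index axis of a scalar field
    if dim == 2 and len(arr.shape) > 2:
        arr = arr[:, :, 0]      # drop the dummy z axis in 2D setups
    return arr

arr_4d = np.zeros((8, 8, 1, 1))
assert normalize_arr_shape(arr_4d, index_dimensions=0, dim=2).shape == (8, 8)

# Re-normalizing an already reduced array is now a no-op instead of
# squeezing a real axis or raising an IndexError.
arr_2d = np.zeros((8, 8))
assert normalize_arr_shape(arr_2d, index_dimensions=0, dim=2).shape == (8, 8)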
@@ -246,7 +244,7 @@ class ParallelDataHandling(DataHandling):
         for block in self.blocks:
             field_args = {}
             for data_name, f in data_used_in_kernel:
-                arr = to_array(block[data_name], withGhostLayers=[True, True, self.dim == 3])
+                arr = to_array(block[data_name], with_ghost_layers=[True, True, self.dim == 3])
                 arr = self._normalize_arr_shape(arr, f.index_dimensions)
                 field_args[f.name] = arr
             field_args.update(kwargs)
...
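This hunk only tracks the keyword rename from withGhostLayers to with_ghost_layers in waLBerla's field-to-array conversion. If code must run against both old and new bindings, one option is a small compatibility wrapper; this is a sketch under the assumption that only the keyword spelling changed between releases, and it is not part of the commit:

from waLBerla.field import toArray

def to_array_compat(field, ghost_layers):
    # Try the snake_case keyword of newer bindings first, then fall back
    # to the camelCase spelling used by older releases.
    try:
        return toArray(field, with_ghost_layers=ghost_layers)
    except TypeError:
        return toArray(field, withGhostLayers=ghost_layers)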