diff --git a/pystencils/datahandling/parallel_datahandling.py b/pystencils/datahandling/parallel_datahandling.py index 54f26806be318f6ef91a5ca11a9888a59524fb0c..535933300f73977d3ede92e714f2a381be92fa50 100644 --- a/pystencils/datahandling/parallel_datahandling.py +++ b/pystencils/datahandling/parallel_datahandling.py @@ -109,11 +109,14 @@ class ParallelDataHandling(DataHandling): if hasattr(values_per_cell, '__len__'): raise NotImplementedError("Parallel data handling does not support multiple index dimensions") - self._fieldInformation[name] = {'ghost_layers': ghost_layers, - 'values_per_cell': values_per_cell, - 'layout': layout, - 'dtype': dtype, - 'alignment': alignment} + self._fieldInformation[name] = { + 'ghost_layers': ghost_layers, + 'values_per_cell': values_per_cell, + 'layout': layout, + 'dtype': dtype, + 'alignment': alignment, + 'field_type': field_type, + } layout_map = {'fzyx': wlb.field.Layout.fzyx, 'zyxf': wlb.field.Layout.zyxf, 'f': wlb.field.Layout.fzyx, diff --git a/pystencils/datahandling/serial_datahandling.py b/pystencils/datahandling/serial_datahandling.py index f8b0a4a1d8b56e8533eb99c6c7f915c00f307419..697b6b6674a83b6786c8a6a3f687df3e49cd8a01 100644 --- a/pystencils/datahandling/serial_datahandling.py +++ b/pystencils/datahandling/serial_datahandling.py @@ -100,6 +100,7 @@ class SerialDataHandling(DataHandling): 'layout': layout, 'dtype': dtype, 'alignment': alignment, + 'field_type': field_type, } index_dimensions = len(values_per_cell) diff --git a/pystencils/field.py b/pystencils/field.py index 0c196f4ba61858f4aa03e8cde0260601d1e68c0d..e57ce95c982797847aba7403d2d25894387a1b75 100644 --- a/pystencils/field.py +++ b/pystencils/field.py @@ -466,6 +466,7 @@ class Field(AbstractField): """ assert FieldType.is_staggered(self) + offset_orig = offset if type(offset) is np.ndarray: offset = tuple(offset) if type(offset) is str: @@ -482,9 +483,13 @@ class Field(AbstractField): for i, o in enumerate(offset): if (o + sp.Rational(1, 2)).is_Integer: offset[i] += sp.Rational(1, 2) - neighbor[i] = 1 + neighbor[i] = -1 neighbor = offset_to_direction_string(neighbor) - idx = self.staggered_stencil.index(neighbor) + try: + idx = self.staggered_stencil.index(neighbor) + except ValueError: + raise ValueError("{} is not a valid neighbor for the {} stencil".format(offset_orig, + self.staggered_stencil_name)) offset = tuple(offset) if self.index_dimensions == 1: # this field stores a scalar value at each staggered position @@ -510,20 +515,25 @@ class Field(AbstractField): assert FieldType.is_staggered(self) stencils = { 2: { - 2: ["E", "N"], # D2Q5 - 4: ["E", "N", "NE", "SE"] # D2Q9 + 2: ["W", "S"], # D2Q5 + 4: ["W", "S", "SW", "NW"] # D2Q9 }, 3: { - 3: ["E", "N", "T"], # D3Q7 - 7: ["E", "N", "T", "TNE", "BNE", "TSE", "BSE "], # D3Q15 - 9: ["E", "N", "T", "NE", "SE", "TE", "BE", "TN", "BN"], # D3Q19 - 13: ["E", "N", "T", "NE", "SE", "TE", "BE", "TN", "BN", "TNE", "BNE", "TSE", "BSE"] # D3Q27 + 3: ["W", "S", "B"], # D3Q7 + 7: ["W", "S", "B", "BSW", "TSW", "BNW", "TNW"], # D3Q15 + 9: ["W", "S", "B", "SW", "NW", "BW", "TW", "BS", "TS"], # D3Q19 + 13: ["W", "S", "B", "SW", "NW", "BW", "TW", "BS", "TS", "BSW", "TSW", "BNW", "TNW"] # D3Q27 } } if not self.index_shape[0] in stencils[self.spatial_dimensions]: raise ValueError("No known stencil has {} staggered points".format(self.index_shape[0])) return stencils[self.spatial_dimensions][self.index_shape[0]] + @property + def staggered_stencil_name(self): + assert FieldType.is_staggered(self) + return "D%dQ%d" % (self.spatial_dimensions, self.index_shape[0] * 2 + 1) + def __call__(self, *args, **kwargs): center = tuple([0] * self.spatial_dimensions) return Field.Access(self, center)(*args, **kwargs) @@ -774,7 +784,7 @@ class Field(AbstractField): assert FieldType.is_staggered(self._field) neighbor = self._field.staggered_stencil[index] neighbor = direction_string_to_offset(neighbor, self._field.spatial_dimensions) - return [(o - sp.Rational(neighbor[i], 2)) for i, o in enumerate(offsets)] + return [(o - sp.Rational(int(neighbor[i]), 2)) for i, o in enumerate(offsets)] def _latex(self, _): n = self._field.latex_name if self._field.latex_name else self._field.name