From 10fe510f86e761a9870887db88073da7908a7d7e Mon Sep 17 00:00:00 2001 From: Michael Kuron <mkuron@icp.uni-stuttgart.de> Date: Tue, 19 Nov 2019 15:11:24 +0100 Subject: [PATCH] fix minor regressions introduced with staggered field access --- pystencils/datahandling/parallel_datahandling.py | 13 ++++++++----- pystencils/datahandling/serial_datahandling.py | 1 + pystencils/field.py | 14 ++++++++++++-- 3 files changed, 21 insertions(+), 7 deletions(-) diff --git a/pystencils/datahandling/parallel_datahandling.py b/pystencils/datahandling/parallel_datahandling.py index 54f26806b..535933300 100644 --- a/pystencils/datahandling/parallel_datahandling.py +++ b/pystencils/datahandling/parallel_datahandling.py @@ -109,11 +109,14 @@ class ParallelDataHandling(DataHandling): if hasattr(values_per_cell, '__len__'): raise NotImplementedError("Parallel data handling does not support multiple index dimensions") - self._fieldInformation[name] = {'ghost_layers': ghost_layers, - 'values_per_cell': values_per_cell, - 'layout': layout, - 'dtype': dtype, - 'alignment': alignment} + self._fieldInformation[name] = { + 'ghost_layers': ghost_layers, + 'values_per_cell': values_per_cell, + 'layout': layout, + 'dtype': dtype, + 'alignment': alignment, + 'field_type': field_type, + } layout_map = {'fzyx': wlb.field.Layout.fzyx, 'zyxf': wlb.field.Layout.zyxf, 'f': wlb.field.Layout.fzyx, diff --git a/pystencils/datahandling/serial_datahandling.py b/pystencils/datahandling/serial_datahandling.py index f8b0a4a1d..697b6b667 100644 --- a/pystencils/datahandling/serial_datahandling.py +++ b/pystencils/datahandling/serial_datahandling.py @@ -100,6 +100,7 @@ class SerialDataHandling(DataHandling): 'layout': layout, 'dtype': dtype, 'alignment': alignment, + 'field_type': field_type, } index_dimensions = len(values_per_cell) diff --git a/pystencils/field.py b/pystencils/field.py index 0c196f4ba..76bd2be40 100644 --- a/pystencils/field.py +++ b/pystencils/field.py @@ -466,6 +466,7 @@ class Field(AbstractField): """ assert FieldType.is_staggered(self) + offset_orig = offset if type(offset) is np.ndarray: offset = tuple(offset) if type(offset) is str: @@ -484,7 +485,11 @@ class Field(AbstractField): offset[i] += sp.Rational(1, 2) neighbor[i] = 1 neighbor = offset_to_direction_string(neighbor) - idx = self.staggered_stencil.index(neighbor) + try: + idx = self.staggered_stencil.index(neighbor) + except ValueError: + raise ValueError("{} is not a valid neighbor for the {} stencil".format(offset_orig, + self.staggered_stencil_name)) offset = tuple(offset) if self.index_dimensions == 1: # this field stores a scalar value at each staggered position @@ -524,6 +529,11 @@ class Field(AbstractField): raise ValueError("No known stencil has {} staggered points".format(self.index_shape[0])) return stencils[self.spatial_dimensions][self.index_shape[0]] + @property + def staggered_stencil_name(self): + assert FieldType.is_staggered(self) + return "D%dQ%d" % (self.spatial_dimensions, self.index_shape[0] * 2 + 1) + def __call__(self, *args, **kwargs): center = tuple([0] * self.spatial_dimensions) return Field.Access(self, center)(*args, **kwargs) @@ -774,7 +784,7 @@ class Field(AbstractField): assert FieldType.is_staggered(self._field) neighbor = self._field.staggered_stencil[index] neighbor = direction_string_to_offset(neighbor, self._field.spatial_dimensions) - return [(o - sp.Rational(neighbor[i], 2)) for i, o in enumerate(offsets)] + return [(o - sp.Rational(int(neighbor[i]), 2)) for i, o in enumerate(offsets)] def _latex(self, _): n = self._field.latex_name if self._field.latex_name else self._field.name -- GitLab