Commit 10fe510f authored by Michael Kuron's avatar Michael Kuron
Browse files

fix minor regressions introduced with staggered field access

parent 974febd7
...@@ -109,11 +109,14 @@ class ParallelDataHandling(DataHandling): ...@@ -109,11 +109,14 @@ class ParallelDataHandling(DataHandling):
if hasattr(values_per_cell, '__len__'): if hasattr(values_per_cell, '__len__'):
raise NotImplementedError("Parallel data handling does not support multiple index dimensions") raise NotImplementedError("Parallel data handling does not support multiple index dimensions")
self._fieldInformation[name] = {'ghost_layers': ghost_layers, self._fieldInformation[name] = {
'values_per_cell': values_per_cell, 'ghost_layers': ghost_layers,
'layout': layout, 'values_per_cell': values_per_cell,
'dtype': dtype, 'layout': layout,
'alignment': alignment} 'dtype': dtype,
'alignment': alignment,
'field_type': field_type,
}
layout_map = {'fzyx': wlb.field.Layout.fzyx, 'zyxf': wlb.field.Layout.zyxf, layout_map = {'fzyx': wlb.field.Layout.fzyx, 'zyxf': wlb.field.Layout.zyxf,
'f': wlb.field.Layout.fzyx, 'f': wlb.field.Layout.fzyx,
......
...@@ -100,6 +100,7 @@ class SerialDataHandling(DataHandling): ...@@ -100,6 +100,7 @@ class SerialDataHandling(DataHandling):
'layout': layout, 'layout': layout,
'dtype': dtype, 'dtype': dtype,
'alignment': alignment, 'alignment': alignment,
'field_type': field_type,
} }
index_dimensions = len(values_per_cell) index_dimensions = len(values_per_cell)
......
...@@ -466,6 +466,7 @@ class Field(AbstractField): ...@@ -466,6 +466,7 @@ class Field(AbstractField):
""" """
assert FieldType.is_staggered(self) assert FieldType.is_staggered(self)
offset_orig = offset
if type(offset) is np.ndarray: if type(offset) is np.ndarray:
offset = tuple(offset) offset = tuple(offset)
if type(offset) is str: if type(offset) is str:
...@@ -484,7 +485,11 @@ class Field(AbstractField): ...@@ -484,7 +485,11 @@ class Field(AbstractField):
offset[i] += sp.Rational(1, 2) offset[i] += sp.Rational(1, 2)
neighbor[i] = 1 neighbor[i] = 1
neighbor = offset_to_direction_string(neighbor) neighbor = offset_to_direction_string(neighbor)
idx = self.staggered_stencil.index(neighbor) try:
idx = self.staggered_stencil.index(neighbor)
except ValueError:
raise ValueError("{} is not a valid neighbor for the {} stencil".format(offset_orig,
self.staggered_stencil_name))
offset = tuple(offset) offset = tuple(offset)
if self.index_dimensions == 1: # this field stores a scalar value at each staggered position if self.index_dimensions == 1: # this field stores a scalar value at each staggered position
...@@ -524,6 +529,11 @@ class Field(AbstractField): ...@@ -524,6 +529,11 @@ class Field(AbstractField):
raise ValueError("No known stencil has {} staggered points".format(self.index_shape[0])) raise ValueError("No known stencil has {} staggered points".format(self.index_shape[0]))
return stencils[self.spatial_dimensions][self.index_shape[0]] return stencils[self.spatial_dimensions][self.index_shape[0]]
@property
def staggered_stencil_name(self):
assert FieldType.is_staggered(self)
return "D%dQ%d" % (self.spatial_dimensions, self.index_shape[0] * 2 + 1)
def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):
center = tuple([0] * self.spatial_dimensions) center = tuple([0] * self.spatial_dimensions)
return Field.Access(self, center)(*args, **kwargs) return Field.Access(self, center)(*args, **kwargs)
...@@ -774,7 +784,7 @@ class Field(AbstractField): ...@@ -774,7 +784,7 @@ class Field(AbstractField):
assert FieldType.is_staggered(self._field) assert FieldType.is_staggered(self._field)
neighbor = self._field.staggered_stencil[index] neighbor = self._field.staggered_stencil[index]
neighbor = direction_string_to_offset(neighbor, self._field.spatial_dimensions) neighbor = direction_string_to_offset(neighbor, self._field.spatial_dimensions)
return [(o - sp.Rational(neighbor[i], 2)) for i, o in enumerate(offsets)] return [(o - sp.Rational(int(neighbor[i]), 2)) for i, o in enumerate(offsets)]
def _latex(self, _): def _latex(self, _):
n = self._field.latex_name if self._field.latex_name else self._field.name n = self._field.latex_name if self._field.latex_name else self._field.name
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment