Skip to content
Snippets Groups Projects
Commit 10fe510f authored by Michael Kuron's avatar Michael Kuron :mortar_board:
Browse files

fix minor regressions introduced with staggered field access

parent 974febd7
Branches
Tags
No related merge requests found
......@@ -109,11 +109,14 @@ class ParallelDataHandling(DataHandling):
if hasattr(values_per_cell, '__len__'):
raise NotImplementedError("Parallel data handling does not support multiple index dimensions")
self._fieldInformation[name] = {'ghost_layers': ghost_layers,
'values_per_cell': values_per_cell,
'layout': layout,
'dtype': dtype,
'alignment': alignment}
self._fieldInformation[name] = {
'ghost_layers': ghost_layers,
'values_per_cell': values_per_cell,
'layout': layout,
'dtype': dtype,
'alignment': alignment,
'field_type': field_type,
}
layout_map = {'fzyx': wlb.field.Layout.fzyx, 'zyxf': wlb.field.Layout.zyxf,
'f': wlb.field.Layout.fzyx,
......
......@@ -100,6 +100,7 @@ class SerialDataHandling(DataHandling):
'layout': layout,
'dtype': dtype,
'alignment': alignment,
'field_type': field_type,
}
index_dimensions = len(values_per_cell)
......
......@@ -466,6 +466,7 @@ class Field(AbstractField):
"""
assert FieldType.is_staggered(self)
offset_orig = offset
if type(offset) is np.ndarray:
offset = tuple(offset)
if type(offset) is str:
......@@ -484,7 +485,11 @@ class Field(AbstractField):
offset[i] += sp.Rational(1, 2)
neighbor[i] = 1
neighbor = offset_to_direction_string(neighbor)
idx = self.staggered_stencil.index(neighbor)
try:
idx = self.staggered_stencil.index(neighbor)
except ValueError:
raise ValueError("{} is not a valid neighbor for the {} stencil".format(offset_orig,
self.staggered_stencil_name))
offset = tuple(offset)
if self.index_dimensions == 1: # this field stores a scalar value at each staggered position
......@@ -524,6 +529,11 @@ class Field(AbstractField):
raise ValueError("No known stencil has {} staggered points".format(self.index_shape[0]))
return stencils[self.spatial_dimensions][self.index_shape[0]]
@property
def staggered_stencil_name(self):
assert FieldType.is_staggered(self)
return "D%dQ%d" % (self.spatial_dimensions, self.index_shape[0] * 2 + 1)
def __call__(self, *args, **kwargs):
center = tuple([0] * self.spatial_dimensions)
return Field.Access(self, center)(*args, **kwargs)
......@@ -774,7 +784,7 @@ class Field(AbstractField):
assert FieldType.is_staggered(self._field)
neighbor = self._field.staggered_stencil[index]
neighbor = direction_string_to_offset(neighbor, self._field.spatial_dimensions)
return [(o - sp.Rational(neighbor[i], 2)) for i, o in enumerate(offsets)]
return [(o - sp.Rational(int(neighbor[i]), 2)) for i, o in enumerate(offsets)]
def _latex(self, _):
n = self._field.latex_name if self._field.latex_name else self._field.name
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment