Commit a723120e authored by Martin Bauer's avatar Martin Bauer
Browse files

Merge branch 'staggered' into 'master'

fix minor regressions introduced with !86

See merge request pycodegen/pystencils!88
parents 974febd7 c755fd28
......@@ -109,11 +109,14 @@ class ParallelDataHandling(DataHandling):
if hasattr(values_per_cell, '__len__'):
raise NotImplementedError("Parallel data handling does not support multiple index dimensions")
self._fieldInformation[name] = {'ghost_layers': ghost_layers,
'values_per_cell': values_per_cell,
'layout': layout,
'dtype': dtype,
'alignment': alignment}
self._fieldInformation[name] = {
'ghost_layers': ghost_layers,
'values_per_cell': values_per_cell,
'layout': layout,
'dtype': dtype,
'alignment': alignment,
'field_type': field_type,
}
layout_map = {'fzyx': wlb.field.Layout.fzyx, 'zyxf': wlb.field.Layout.zyxf,
'f': wlb.field.Layout.fzyx,
......
......@@ -100,6 +100,7 @@ class SerialDataHandling(DataHandling):
'layout': layout,
'dtype': dtype,
'alignment': alignment,
'field_type': field_type,
}
index_dimensions = len(values_per_cell)
......
......@@ -466,6 +466,7 @@ class Field(AbstractField):
"""
assert FieldType.is_staggered(self)
offset_orig = offset
if type(offset) is np.ndarray:
offset = tuple(offset)
if type(offset) is str:
......@@ -482,9 +483,13 @@ class Field(AbstractField):
for i, o in enumerate(offset):
if (o + sp.Rational(1, 2)).is_Integer:
offset[i] += sp.Rational(1, 2)
neighbor[i] = 1
neighbor[i] = -1
neighbor = offset_to_direction_string(neighbor)
idx = self.staggered_stencil.index(neighbor)
try:
idx = self.staggered_stencil.index(neighbor)
except ValueError:
raise ValueError("{} is not a valid neighbor for the {} stencil".format(offset_orig,
self.staggered_stencil_name))
offset = tuple(offset)
if self.index_dimensions == 1: # this field stores a scalar value at each staggered position
......@@ -510,20 +515,25 @@ class Field(AbstractField):
assert FieldType.is_staggered(self)
stencils = {
2: {
2: ["E", "N"], # D2Q5
4: ["E", "N", "NE", "SE"] # D2Q9
2: ["W", "S"], # D2Q5
4: ["W", "S", "SW", "NW"] # D2Q9
},
3: {
3: ["E", "N", "T"], # D3Q7
7: ["E", "N", "T", "TNE", "BNE", "TSE", "BSE "], # D3Q15
9: ["E", "N", "T", "NE", "SE", "TE", "BE", "TN", "BN"], # D3Q19
13: ["E", "N", "T", "NE", "SE", "TE", "BE", "TN", "BN", "TNE", "BNE", "TSE", "BSE"] # D3Q27
3: ["W", "S", "B"], # D3Q7
7: ["W", "S", "B", "BSW", "TSW", "BNW", "TNW"], # D3Q15
9: ["W", "S", "B", "SW", "NW", "BW", "TW", "BS", "TS"], # D3Q19
13: ["W", "S", "B", "SW", "NW", "BW", "TW", "BS", "TS", "BSW", "TSW", "BNW", "TNW"] # D3Q27
}
}
if not self.index_shape[0] in stencils[self.spatial_dimensions]:
raise ValueError("No known stencil has {} staggered points".format(self.index_shape[0]))
return stencils[self.spatial_dimensions][self.index_shape[0]]
@property
def staggered_stencil_name(self):
assert FieldType.is_staggered(self)
return "D%dQ%d" % (self.spatial_dimensions, self.index_shape[0] * 2 + 1)
def __call__(self, *args, **kwargs):
center = tuple([0] * self.spatial_dimensions)
return Field.Access(self, center)(*args, **kwargs)
......@@ -774,7 +784,7 @@ class Field(AbstractField):
assert FieldType.is_staggered(self._field)
neighbor = self._field.staggered_stencil[index]
neighbor = direction_string_to_offset(neighbor, self._field.spatial_dimensions)
return [(o - sp.Rational(neighbor[i], 2)) for i, o in enumerate(offsets)]
return [(o - sp.Rational(int(neighbor[i]), 2)) for i, o in enumerate(offsets)]
def _latex(self, _):
n = self._field.latex_name if self._field.latex_name else self._field.name
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment