Skip to content
Snippets Groups Projects
Commit a723120e authored by Martin Bauer's avatar Martin Bauer
Browse files

Merge branch 'staggered' into 'master'

fix minor regressions introduced with !86

See merge request pycodegen/pystencils!88
parents 974febd7 c755fd28
No related merge requests found
......@@ -109,11 +109,14 @@ class ParallelDataHandling(DataHandling):
if hasattr(values_per_cell, '__len__'):
raise NotImplementedError("Parallel data handling does not support multiple index dimensions")
self._fieldInformation[name] = {'ghost_layers': ghost_layers,
'values_per_cell': values_per_cell,
'layout': layout,
'dtype': dtype,
'alignment': alignment}
self._fieldInformation[name] = {
'ghost_layers': ghost_layers,
'values_per_cell': values_per_cell,
'layout': layout,
'dtype': dtype,
'alignment': alignment,
'field_type': field_type,
}
layout_map = {'fzyx': wlb.field.Layout.fzyx, 'zyxf': wlb.field.Layout.zyxf,
'f': wlb.field.Layout.fzyx,
......
......@@ -100,6 +100,7 @@ class SerialDataHandling(DataHandling):
'layout': layout,
'dtype': dtype,
'alignment': alignment,
'field_type': field_type,
}
index_dimensions = len(values_per_cell)
......
......@@ -466,6 +466,7 @@ class Field(AbstractField):
"""
assert FieldType.is_staggered(self)
offset_orig = offset
if type(offset) is np.ndarray:
offset = tuple(offset)
if type(offset) is str:
......@@ -482,9 +483,13 @@ class Field(AbstractField):
for i, o in enumerate(offset):
if (o + sp.Rational(1, 2)).is_Integer:
offset[i] += sp.Rational(1, 2)
neighbor[i] = 1
neighbor[i] = -1
neighbor = offset_to_direction_string(neighbor)
idx = self.staggered_stencil.index(neighbor)
try:
idx = self.staggered_stencil.index(neighbor)
except ValueError:
raise ValueError("{} is not a valid neighbor for the {} stencil".format(offset_orig,
self.staggered_stencil_name))
offset = tuple(offset)
if self.index_dimensions == 1: # this field stores a scalar value at each staggered position
......@@ -510,20 +515,25 @@ class Field(AbstractField):
assert FieldType.is_staggered(self)
stencils = {
2: {
2: ["E", "N"], # D2Q5
4: ["E", "N", "NE", "SE"] # D2Q9
2: ["W", "S"], # D2Q5
4: ["W", "S", "SW", "NW"] # D2Q9
},
3: {
3: ["E", "N", "T"], # D3Q7
7: ["E", "N", "T", "TNE", "BNE", "TSE", "BSE "], # D3Q15
9: ["E", "N", "T", "NE", "SE", "TE", "BE", "TN", "BN"], # D3Q19
13: ["E", "N", "T", "NE", "SE", "TE", "BE", "TN", "BN", "TNE", "BNE", "TSE", "BSE"] # D3Q27
3: ["W", "S", "B"], # D3Q7
7: ["W", "S", "B", "BSW", "TSW", "BNW", "TNW"], # D3Q15
9: ["W", "S", "B", "SW", "NW", "BW", "TW", "BS", "TS"], # D3Q19
13: ["W", "S", "B", "SW", "NW", "BW", "TW", "BS", "TS", "BSW", "TSW", "BNW", "TNW"] # D3Q27
}
}
if not self.index_shape[0] in stencils[self.spatial_dimensions]:
raise ValueError("No known stencil has {} staggered points".format(self.index_shape[0]))
return stencils[self.spatial_dimensions][self.index_shape[0]]
@property
def staggered_stencil_name(self):
assert FieldType.is_staggered(self)
return "D%dQ%d" % (self.spatial_dimensions, self.index_shape[0] * 2 + 1)
def __call__(self, *args, **kwargs):
center = tuple([0] * self.spatial_dimensions)
return Field.Access(self, center)(*args, **kwargs)
......@@ -774,7 +784,7 @@ class Field(AbstractField):
assert FieldType.is_staggered(self._field)
neighbor = self._field.staggered_stencil[index]
neighbor = direction_string_to_offset(neighbor, self._field.spatial_dimensions)
return [(o - sp.Rational(neighbor[i], 2)) for i, o in enumerate(offsets)]
return [(o - sp.Rational(int(neighbor[i]), 2)) for i, o in enumerate(offsets)]
def _latex(self, _):
n = self._field.latex_name if self._field.latex_name else self._field.name
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment