Skip to content
Snippets Groups Projects

fix minor regressions introduced with !86

Merged Michael Kuron requested to merge staggered into master
Compare and Show latest version
1 file
+ 18
8
Preferences
Compare changes
+ 18
8
@@ -466,6 +466,7 @@ class Field(AbstractField):
"""
assert FieldType.is_staggered(self)
offset_orig = offset
if type(offset) is np.ndarray:
offset = tuple(offset)
if type(offset) is str:
@@ -482,9 +483,13 @@ class Field(AbstractField):
for i, o in enumerate(offset):
if (o + sp.Rational(1, 2)).is_Integer:
offset[i] += sp.Rational(1, 2)
neighbor[i] = 1
neighbor[i] = -1
neighbor = offset_to_direction_string(neighbor)
idx = self.staggered_stencil.index(neighbor)
try:
idx = self.staggered_stencil.index(neighbor)
except ValueError:
raise ValueError("{} is not a valid neighbor for the {} stencil".format(offset_orig,
self.staggered_stencil_name))
offset = tuple(offset)
if self.index_dimensions == 1: # this field stores a scalar value at each staggered position
@@ -510,20 +515,25 @@ class Field(AbstractField):
assert FieldType.is_staggered(self)
stencils = {
2: {
2: ["E", "N"], # D2Q5
4: ["E", "N", "NE", "SE"] # D2Q9
2: ["W", "S"], # D2Q5
4: ["W", "S", "SW", "NW"] # D2Q9
},
3: {
3: ["E", "N", "T"], # D3Q7
7: ["E", "N", "T", "TNE", "BNE", "TSE", "BSE "], # D3Q15
9: ["E", "N", "T", "NE", "SE", "TE", "BE", "TN", "BN"], # D3Q19
13: ["E", "N", "T", "NE", "SE", "TE", "BE", "TN", "BN", "TNE", "BNE", "TSE", "BSE"] # D3Q27
3: ["W", "S", "B"], # D3Q7
7: ["W", "S", "B", "BSW", "TSW", "BNW", "TNW"], # D3Q15
9: ["W", "S", "B", "SW", "NW", "BW", "TW", "BS", "TS"], # D3Q19
13: ["W", "S", "B", "SW", "NW", "BW", "TW", "BS", "TS", "BSW", "TSW", "BNW", "TNW"] # D3Q27
}
}
if not self.index_shape[0] in stencils[self.spatial_dimensions]:
raise ValueError("No known stencil has {} staggered points".format(self.index_shape[0]))
return stencils[self.spatial_dimensions][self.index_shape[0]]
@property
def staggered_stencil_name(self):
assert FieldType.is_staggered(self)
return "D%dQ%d" % (self.spatial_dimensions, self.index_shape[0] * 2 + 1)
def __call__(self, *args, **kwargs):
center = tuple([0] * self.spatial_dimensions)
return Field.Access(self, center)(*args, **kwargs)