Commit 34732f61 authored by Michael Kuron's avatar Michael Kuron
Browse files

staggered_access: fix access to directions with mixed sign

NW (-1/2, 1/2) and the like were previously mapped to the wrong cell
parent 9888c32f
......@@ -15,7 +15,7 @@ import pystencils
from pystencils.alignedarray import aligned_empty
from pystencils.data_types import StructType, TypedSymbol, create_type
from pystencils.kernelparameters import FieldShapeSymbol, FieldStrideSymbol
from pystencils.stencil import direction_string_to_offset, offset_to_direction_string
from pystencils.stencil import direction_string_to_offset, offset_to_direction_string, inverse_direction
from pystencils.sympyextensions import is_integer_sequence
__all__ = ['Field', 'fields', 'FieldType', 'AbstractField']
......@@ -490,19 +490,21 @@ class Field(AbstractField):
raise ValueError("Wrong number of spatial indices: "
"Got %d, expected %d" % (len(offset), self.spatial_dimensions))
offset = list(offset)
neighbor = [0] * len(offset)
for i, o in enumerate(offset):
if (o + sp.Rational(1, 2)).is_Integer:
offset[i] += sp.Rational(1, 2)
neighbor[i] = -1
neighbor = offset_to_direction_string(neighbor)
try:
idx = self.staggered_stencil.index(neighbor)
except ValueError:
neighbor_vec = [0] * len(offset)
for i in range(self.spatial_dimensions):
if (offset[i] + sp.Rational(1, 2)).is_Integer:
neighbor_vec[i] = sp.sign(offset[i])
neighbor = offset_to_direction_string(neighbor_vec)
if neighbor not in self.staggered_stencil:
neighbor_vec = inverse_direction(neighbor_vec)
neighbor = offset_to_direction_string(neighbor_vec)
if neighbor not in self.staggered_stencil:
raise ValueError("{} is not a valid neighbor for the {} stencil".format(offset_orig,
self.staggered_stencil_name))
offset = tuple(offset)
offset = tuple(sp.Matrix(offset) - sp.Rational(1, 2) * sp.Matrix(neighbor_vec))
idx = self.staggered_stencil.index(neighbor)
if self.index_dimensions == 1: # this field stores a scalar value at each staggered position
if index is not None:
......
......@@ -135,6 +135,8 @@ def test_staggered():
j1, j2, j3 = ps.fields('j1(2), j2(2,2), j3(2,2,2) : double[2D]', field_type=FieldType.STAGGERED)
assert j1[0, 1](1) == j1.staggered_access((0, sp.Rational(1, 2)))
assert j1[1, 1](1) == j1.staggered_access((1, sp.Rational(1, 2)))
assert j1[0, 2](1) == j1.staggered_access((0, sp.Rational(3, 2)))
assert j1[0, 1](1) == j1.staggered_access("N")
assert j1[0, 0](1) == j1.staggered_access("S")
......@@ -149,3 +151,5 @@ def test_staggered():
assert k[1, 1](2) == k.staggered_access("NE")
assert k[0, 0](2) == k.staggered_access("SW")
assert k[0, 0](3) == k.staggered_access("NW")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment