diff --git a/pystencils/field.py b/pystencils/field.py index d03f292579fca9bccc58c97dcd4225248d188cb7..817fa202f60f882799c78eee6d91cfcc3f288146 100644 --- a/pystencils/field.py +++ b/pystencils/field.py @@ -15,7 +15,7 @@ import pystencils from pystencils.alignedarray import aligned_empty from pystencils.data_types import StructType, TypedSymbol, create_type from pystencils.kernelparameters import FieldShapeSymbol, FieldStrideSymbol -from pystencils.stencil import direction_string_to_offset, offset_to_direction_string +from pystencils.stencil import direction_string_to_offset, offset_to_direction_string, inverse_direction from pystencils.sympyextensions import is_integer_sequence __all__ = ['Field', 'fields', 'FieldType', 'AbstractField'] @@ -490,19 +490,21 @@ class Field(AbstractField): raise ValueError("Wrong number of spatial indices: " "Got %d, expected %d" % (len(offset), self.spatial_dimensions)) - offset = list(offset) - neighbor = [0] * len(offset) - for i, o in enumerate(offset): - if (o + sp.Rational(1, 2)).is_Integer: - offset[i] += sp.Rational(1, 2) - neighbor[i] = -1 - neighbor = offset_to_direction_string(neighbor) - try: - idx = self.staggered_stencil.index(neighbor) - except ValueError: + neighbor_vec = [0] * len(offset) + for i in range(self.spatial_dimensions): + if (offset[i] + sp.Rational(1, 2)).is_Integer: + neighbor_vec[i] = sp.sign(offset[i]) + neighbor = offset_to_direction_string(neighbor_vec) + if neighbor not in self.staggered_stencil: + neighbor_vec = inverse_direction(neighbor_vec) + neighbor = offset_to_direction_string(neighbor_vec) + if neighbor not in self.staggered_stencil: raise ValueError("{} is not a valid neighbor for the {} stencil".format(offset_orig, self.staggered_stencil_name)) - offset = tuple(offset) + + offset = tuple(sp.Matrix(offset) - sp.Rational(1, 2) * sp.Matrix(neighbor_vec)) + + idx = self.staggered_stencil.index(neighbor) if self.index_dimensions == 1: # this field stores a scalar value at each staggered position if index is not None: diff --git a/pystencils_tests/test_field.py b/pystencils_tests/test_field.py index 227b332cd1dbf1f68bab43d1b7042b89cc44b8ad..2552f8e5ec3643c9b356c2380449b9de701bb982 100644 --- a/pystencils_tests/test_field.py +++ b/pystencils_tests/test_field.py @@ -135,6 +135,8 @@ def test_staggered(): j1, j2, j3 = ps.fields('j1(2), j2(2,2), j3(2,2,2) : double[2D]', field_type=FieldType.STAGGERED) assert j1[0, 1](1) == j1.staggered_access((0, sp.Rational(1, 2))) + assert j1[1, 1](1) == j1.staggered_access((1, sp.Rational(1, 2))) + assert j1[0, 2](1) == j1.staggered_access((0, sp.Rational(3, 2))) assert j1[0, 1](1) == j1.staggered_access("N") assert j1[0, 0](1) == j1.staggered_access("S") @@ -149,3 +151,5 @@ def test_staggered(): assert k[1, 1](2) == k.staggered_access("NE") assert k[0, 0](2) == k.staggered_access("SW") + + assert k[0, 0](3) == k.staggered_access("NW")