From 34732f61d1ee545ae5a8a19f23f7357a8f3d87a2 Mon Sep 17 00:00:00 2001 From: Michael Kuron <mkuron@icp.uni-stuttgart.de> Date: Fri, 22 Nov 2019 11:57:34 +0100 Subject: [PATCH] staggered_access: fix access to directions with mixed sign NW (-1/2, 1/2) and the like were previously mapped to the wrong cell --- pystencils/field.py | 26 ++++++++++++++------------ pystencils_tests/test_field.py | 4 ++++ 2 files changed, 18 insertions(+), 12 deletions(-) diff --git a/pystencils/field.py b/pystencils/field.py index d03f29257..817fa202f 100644 --- a/pystencils/field.py +++ b/pystencils/field.py @@ -15,7 +15,7 @@ import pystencils from pystencils.alignedarray import aligned_empty from pystencils.data_types import StructType, TypedSymbol, create_type from pystencils.kernelparameters import FieldShapeSymbol, FieldStrideSymbol -from pystencils.stencil import direction_string_to_offset, offset_to_direction_string +from pystencils.stencil import direction_string_to_offset, offset_to_direction_string, inverse_direction from pystencils.sympyextensions import is_integer_sequence __all__ = ['Field', 'fields', 'FieldType', 'AbstractField'] @@ -490,19 +490,21 @@ class Field(AbstractField): raise ValueError("Wrong number of spatial indices: " "Got %d, expected %d" % (len(offset), self.spatial_dimensions)) - offset = list(offset) - neighbor = [0] * len(offset) - for i, o in enumerate(offset): - if (o + sp.Rational(1, 2)).is_Integer: - offset[i] += sp.Rational(1, 2) - neighbor[i] = -1 - neighbor = offset_to_direction_string(neighbor) - try: - idx = self.staggered_stencil.index(neighbor) - except ValueError: + neighbor_vec = [0] * len(offset) + for i in range(self.spatial_dimensions): + if (offset[i] + sp.Rational(1, 2)).is_Integer: + neighbor_vec[i] = sp.sign(offset[i]) + neighbor = offset_to_direction_string(neighbor_vec) + if neighbor not in self.staggered_stencil: + neighbor_vec = inverse_direction(neighbor_vec) + neighbor = offset_to_direction_string(neighbor_vec) + if neighbor not in self.staggered_stencil: raise ValueError("{} is not a valid neighbor for the {} stencil".format(offset_orig, self.staggered_stencil_name)) - offset = tuple(offset) + + offset = tuple(sp.Matrix(offset) - sp.Rational(1, 2) * sp.Matrix(neighbor_vec)) + + idx = self.staggered_stencil.index(neighbor) if self.index_dimensions == 1: # this field stores a scalar value at each staggered position if index is not None: diff --git a/pystencils_tests/test_field.py b/pystencils_tests/test_field.py index 227b332cd..2552f8e5e 100644 --- a/pystencils_tests/test_field.py +++ b/pystencils_tests/test_field.py @@ -135,6 +135,8 @@ def test_staggered(): j1, j2, j3 = ps.fields('j1(2), j2(2,2), j3(2,2,2) : double[2D]', field_type=FieldType.STAGGERED) assert j1[0, 1](1) == j1.staggered_access((0, sp.Rational(1, 2))) + assert j1[1, 1](1) == j1.staggered_access((1, sp.Rational(1, 2))) + assert j1[0, 2](1) == j1.staggered_access((0, sp.Rational(3, 2))) assert j1[0, 1](1) == j1.staggered_access("N") assert j1[0, 0](1) == j1.staggered_access("S") @@ -149,3 +151,5 @@ def test_staggered(): assert k[1, 1](2) == k.staggered_access("NE") assert k[0, 0](2) == k.staggered_access("SW") + + assert k[0, 0](3) == k.staggered_access("NW") -- GitLab