From 51a476a2d8636de55168ecc39d275360cbedbb0d Mon Sep 17 00:00:00 2001
From: Michael Kuron <m.kuron@gmx.de>
Date: Fri, 15 Nov 2019 12:12:48 +0100
Subject: [PATCH] Staggered fields with face/edge neighbor points

---
 pystencils/field.py                   | 40 +++++++++++++++++++++++----
 pystencils_tests/test_field.py        |  8 ++++++
 pystencils_tests/test_loop_cutting.py |  2 +-
 3 files changed, 44 insertions(+), 6 deletions(-)

diff --git a/pystencils/field.py b/pystencils/field.py
index 085d45675..0c196f4ba 100644
--- a/pystencils/field.py
+++ b/pystencils/field.py
@@ -326,6 +326,8 @@ class Field(AbstractField):
             0 for _ in range(self.spatial_dimensions)
         ))  # type: tuple[float,sp.Symbol]
         self.coordinate_transform = sp.eye(self.spatial_dimensions)
+        if field_type == FieldType.STAGGERED:
+            assert self.staggered_stencil
 
     def new_field_with_different_name(self, new_name):
         if self.has_fixed_shape:
@@ -476,10 +478,13 @@ class Field(AbstractField):
                              "Got %d, expected %d" % (len(offset), self.spatial_dimensions))
 
         offset = list(offset)
+        neighbor = [0] * len(offset)
         for i, o in enumerate(offset):
             if (o + sp.Rational(1, 2)).is_Integer:
                 offset[i] += sp.Rational(1, 2)
-                idx = i
+                neighbor[i] = 1
+        neighbor = offset_to_direction_string(neighbor)
+        idx = self.staggered_stencil.index(neighbor)
         offset = tuple(offset)
 
         if self.index_dimensions == 1:  # this field stores a scalar value at each staggered position
@@ -500,6 +505,25 @@ class Field(AbstractField):
 
             return Field.Access(self, offset, (idx, *index))
 
+    @property
+    def staggered_stencil(self):
+        assert FieldType.is_staggered(self)
+        stencils = {
+            2: {
+                2: ["E", "N"],  # D2Q5
+                4: ["E", "N", "NE", "SE"]  # D2Q9
+            },
+            3: {
+                3: ["E", "N", "T"],  # D3Q7
+                7: ["E", "N", "T", "TNE", "BNE", "TSE", "BSE "],  # D3Q15
+                9: ["E", "N", "T", "NE", "SE", "TE", "BE", "TN", "BN"],  # D3Q19
+                13: ["E", "N", "T", "NE", "SE", "TE", "BE", "TN", "BN", "TNE", "BNE", "TSE", "BSE"]  # D3Q27
+            }
+        }
+        if not self.index_shape[0] in stencils[self.spatial_dimensions]:
+            raise ValueError("No known stencil has {} staggered points".format(self.index_shape[0]))
+        return stencils[self.spatial_dimensions][self.index_shape[0]]
+
     def __call__(self, *args, **kwargs):
         center = tuple([0] * self.spatial_dimensions)
         return Field.Access(self, center)(*args, **kwargs)
@@ -746,12 +770,18 @@ class Field(AbstractField):
             super_class_contents = super(Field.Access, self)._hashable_content()
             return (super_class_contents, self._field.hashable_contents(), *self._index, *self._offsets)
 
+        def _staggered_offset(self, offsets, index):
+            assert FieldType.is_staggered(self._field)
+            neighbor = self._field.staggered_stencil[index]
+            neighbor = direction_string_to_offset(neighbor, self._field.spatial_dimensions)
+            return [(o - sp.Rational(neighbor[i], 2)) for i, o in enumerate(offsets)]
+
         def _latex(self, _):
             n = self._field.latex_name if self._field.latex_name else self._field.name
             offset_str = ",".join([sp.latex(o) for o in self.offsets])
             if FieldType.is_staggered(self._field):
-                offset_str = ",".join([sp.latex(o - sp.Rational(int(i == self.index[0]), 2))
-                                       for i, o in enumerate(self.offsets)])
+                offset_str = ",".join([sp.latex(self._staggered_offset(self.offsets, self.index[0])[i])
+                                       for i in range(len(self.offsets))])
             if self.is_absolute_access:
                 offset_str = "\\mathbf{}".format(offset_str)
             elif self.field.spatial_dimensions > 1:
@@ -773,8 +803,8 @@ class Field(AbstractField):
             n = self._field.latex_name if self._field.latex_name else self._field.name
             offset_str = ",".join([sp.latex(o) for o in self.offsets])
             if FieldType.is_staggered(self._field):
-                offset_str = ",".join([sp.latex(o - sp.Rational(int(i == self.index[0]), 2))
-                                       for i, o in enumerate(self.offsets)])
+                offset_str = ",".join([sp.latex(self._staggered_offset(self.offsets, self.index[0])[i])
+                                       for i in range(len(self.offsets))])
             if self.is_absolute_access:
                 offset_str = "[abs]{}".format(offset_str)
 
diff --git a/pystencils_tests/test_field.py b/pystencils_tests/test_field.py
index 2a75f5a09..227b332cd 100644
--- a/pystencils_tests/test_field.py
+++ b/pystencils_tests/test_field.py
@@ -131,13 +131,21 @@ def test_itemsize():
 
 def test_staggered():
 
+    # D2Q5
     j1, j2, j3 = ps.fields('j1(2), j2(2,2), j3(2,2,2) : double[2D]', field_type=FieldType.STAGGERED)
 
     assert j1[0, 1](1) == j1.staggered_access((0, sp.Rational(1, 2)))
     assert j1[0, 1](1) == j1.staggered_access("N")
+    assert j1[0, 0](1) == j1.staggered_access("S")
 
     assert j2[0, 1](1, 1) == j2.staggered_access((0, sp.Rational(1, 2)), 1)
     assert j2[0, 1](1, 1) == j2.staggered_access("N", 1)
 
     assert j3[0, 1](1, 1, 1) == j3.staggered_access((0, sp.Rational(1, 2)), (1, 1))
     assert j3[0, 1](1, 1, 1) == j3.staggered_access("N", (1, 1))
+
+    # D2Q9
+    k = ps.fields('k(4) : double[2D]', field_type=FieldType.STAGGERED)
+
+    assert k[1, 1](2) == k.staggered_access("NE")
+    assert k[0, 0](2) == k.staggered_access("SW")
diff --git a/pystencils_tests/test_loop_cutting.py b/pystencils_tests/test_loop_cutting.py
index cd8623805..999e7b52a 100644
--- a/pystencils_tests/test_loop_cutting.py
+++ b/pystencils_tests/test_loop_cutting.py
@@ -36,7 +36,7 @@ def test_staggered_iteration():
     fields_fixed = (Field.create_from_numpy_array('f', f_arr),
                     Field.create_from_numpy_array('s', s_arr, index_dimensions=1, field_type=FieldType.STAGGERED))
     fields_var = (Field.create_generic('f', 2),
-                  Field.create_generic('s', 2, index_dimensions=1, field_type=FieldType.STAGGERED))
+                  Field.create_generic('s', 2, index_dimensions=1, index_shape=(dim,), field_type=FieldType.STAGGERED))
 
     for f, s in [fields_var, fields_fixed]:
         # --- Manual
-- 
GitLab