From a7460abb7344b305190f3c957dc3c189b34a7a80 Mon Sep 17 00:00:00 2001
From: Michael Kuron <mkuron@icp.uni-stuttgart.de>
Date: Fri, 22 Nov 2019 12:12:23 +0100
Subject: [PATCH] staggered_access: optionally reverse sign for access via
 opposite direction

When storing fluxes on a staggered grid, the usual sign convention is that
fluxes point outward from the cell. Previously, we did not respect that as
staggered_access("E") would return the same thing as staggered_access("W")
would when called from the eastern-next cell. Now, when a field is declared
as STAGGERED_FLUX, it returns an accessor with a prefactor of -1 in that
case. The previous behavior where sign is not reversed is still useful when
e.g. storing sums (e.g. mean values) instead of differenes (e.g. finite
difference fluxes) on the staggered grid.
---
 pystencils/field.py            | 16 +++++++++++++---
 pystencils_tests/test_field.py |  5 +++++
 2 files changed, 18 insertions(+), 3 deletions(-)

diff --git a/pystencils/field.py b/pystencils/field.py
index 817fa202f..fc6f93d60 100644
--- a/pystencils/field.py
+++ b/pystencils/field.py
@@ -34,6 +34,8 @@ class FieldType(Enum):
     CUSTOM = 3
     # staggered field
     STAGGERED = 4
+    # staggered field that reverses sign when accessed via opposite direction
+    STAGGERED_FLUX = 5
 
     @staticmethod
     def is_generic(field):
@@ -58,7 +60,12 @@ class FieldType(Enum):
     @staticmethod
     def is_staggered(field):
         assert isinstance(field, Field)
-        return field.field_type == FieldType.STAGGERED
+        return field.field_type == FieldType.STAGGERED or field.field_type == FieldType.STAGGERED_FLUX
+
+    @staticmethod
+    def is_staggered_flux(field):
+        assert isinstance(field, Field)
+        return field.field_type == FieldType.STAGGERED_FLUX
 
 
 def fields(description=None, index_dimensions=0, layout=None, field_type=FieldType.GENERIC, **kwargs):
@@ -490,6 +497,7 @@ class Field(AbstractField):
             raise ValueError("Wrong number of spatial indices: "
                              "Got %d, expected %d" % (len(offset), self.spatial_dimensions))
 
+        prefactor = 1
         neighbor_vec = [0] * len(offset)
         for i in range(self.spatial_dimensions):
             if (offset[i] + sp.Rational(1, 2)).is_Integer:
@@ -498,6 +506,8 @@ class Field(AbstractField):
         if neighbor not in self.staggered_stencil:
             neighbor_vec = inverse_direction(neighbor_vec)
             neighbor = offset_to_direction_string(neighbor_vec)
+            if FieldType.is_staggered_flux(self):
+                prefactor = -1
         if neighbor not in self.staggered_stencil:
             raise ValueError("{} is not a valid neighbor for the {} stencil".format(offset_orig,
                              self.staggered_stencil_name))
@@ -509,7 +519,7 @@ class Field(AbstractField):
         if self.index_dimensions == 1:  # this field stores a scalar value at each staggered position
             if index is not None:
                 raise ValueError("Cannot specify an index for a scalar staggered field")
-            return Field.Access(self, offset, (idx,))
+            return prefactor * Field.Access(self, offset, (idx,))
         else:  # this field stores a vector or tensor at each staggered position
             if index is None:
                 raise ValueError("Wrong number of indices: "
@@ -522,7 +532,7 @@ class Field(AbstractField):
                 raise ValueError("Wrong number of indices: "
                                  "Got %d, expected %d" % (len(index), self.index_dimensions - 1))
 
-            return Field.Access(self, offset, (idx, *index))
+            return prefactor * Field.Access(self, offset, (idx, *index))
 
     @property
     def staggered_stencil(self):
diff --git a/pystencils_tests/test_field.py b/pystencils_tests/test_field.py
index 2552f8e5e..b1d9af430 100644
--- a/pystencils_tests/test_field.py
+++ b/pystencils_tests/test_field.py
@@ -153,3 +153,8 @@ def test_staggered():
     assert k[0, 0](2) == k.staggered_access("SW")
 
     assert k[0, 0](3) == k.staggered_access("NW")
+
+    # sign reversed when using as flux field
+    r = ps.fields('r(2) : double[2D]', field_type=FieldType.STAGGERED_FLUX)
+    assert r[0, 0](0) == r.staggered_access("W")
+    assert -r[1, 0](0) == r.staggered_access("E")
-- 
GitLab