diff --git a/pystencils/field.py b/pystencils/field.py index 817fa202f60f882799c78eee6d91cfcc3f288146..fc6f93d60b3e6d20e9c5fbb4ca08bfc57b684610 100644 --- a/pystencils/field.py +++ b/pystencils/field.py @@ -34,6 +34,8 @@ class FieldType(Enum): CUSTOM = 3 # staggered field STAGGERED = 4 + # staggered field that reverses sign when accessed via opposite direction + STAGGERED_FLUX = 5 @staticmethod def is_generic(field): @@ -58,7 +60,12 @@ class FieldType(Enum): @staticmethod def is_staggered(field): assert isinstance(field, Field) - return field.field_type == FieldType.STAGGERED + return field.field_type == FieldType.STAGGERED or field.field_type == FieldType.STAGGERED_FLUX + + @staticmethod + def is_staggered_flux(field): + assert isinstance(field, Field) + return field.field_type == FieldType.STAGGERED_FLUX def fields(description=None, index_dimensions=0, layout=None, field_type=FieldType.GENERIC, **kwargs): @@ -490,6 +497,7 @@ class Field(AbstractField): raise ValueError("Wrong number of spatial indices: " "Got %d, expected %d" % (len(offset), self.spatial_dimensions)) + prefactor = 1 neighbor_vec = [0] * len(offset) for i in range(self.spatial_dimensions): if (offset[i] + sp.Rational(1, 2)).is_Integer: @@ -498,6 +506,8 @@ class Field(AbstractField): if neighbor not in self.staggered_stencil: neighbor_vec = inverse_direction(neighbor_vec) neighbor = offset_to_direction_string(neighbor_vec) + if FieldType.is_staggered_flux(self): + prefactor = -1 if neighbor not in self.staggered_stencil: raise ValueError("{} is not a valid neighbor for the {} stencil".format(offset_orig, self.staggered_stencil_name)) @@ -509,7 +519,7 @@ class Field(AbstractField): if self.index_dimensions == 1: # this field stores a scalar value at each staggered position if index is not None: raise ValueError("Cannot specify an index for a scalar staggered field") - return Field.Access(self, offset, (idx,)) + return prefactor * Field.Access(self, offset, (idx,)) else: # this field stores a vector or tensor at each staggered position if index is None: raise ValueError("Wrong number of indices: " @@ -522,7 +532,7 @@ class Field(AbstractField): raise ValueError("Wrong number of indices: " "Got %d, expected %d" % (len(index), self.index_dimensions - 1)) - return Field.Access(self, offset, (idx, *index)) + return prefactor * Field.Access(self, offset, (idx, *index)) @property def staggered_stencil(self): diff --git a/pystencils_tests/test_field.py b/pystencils_tests/test_field.py index 2552f8e5ec3643c9b356c2380449b9de701bb982..b1d9af430d05c3f3bfdc6c177cf96de9cdd7183b 100644 --- a/pystencils_tests/test_field.py +++ b/pystencils_tests/test_field.py @@ -153,3 +153,8 @@ def test_staggered(): assert k[0, 0](2) == k.staggered_access("SW") assert k[0, 0](3) == k.staggered_access("NW") + + # sign reversed when using as flux field + r = ps.fields('r(2) : double[2D]', field_type=FieldType.STAGGERED_FLUX) + assert r[0, 0](0) == r.staggered_access("W") + assert -r[1, 0](0) == r.staggered_access("E")