diff --git a/pystencils/field.py b/pystencils/field.py index d03f292579fca9bccc58c97dcd4225248d188cb7..6b3fa9f8c4c5fa21cd799fc9ff37d090224794af 100644 --- a/pystencils/field.py +++ b/pystencils/field.py @@ -15,7 +15,7 @@ import pystencils from pystencils.alignedarray import aligned_empty from pystencils.data_types import StructType, TypedSymbol, create_type from pystencils.kernelparameters import FieldShapeSymbol, FieldStrideSymbol -from pystencils.stencil import direction_string_to_offset, offset_to_direction_string +from pystencils.stencil import direction_string_to_offset, offset_to_direction_string, inverse_direction from pystencils.sympyextensions import is_integer_sequence __all__ = ['Field', 'fields', 'FieldType', 'AbstractField'] @@ -34,6 +34,8 @@ class FieldType(Enum): CUSTOM = 3 # staggered field STAGGERED = 4 + # staggered field that reverses sign when accessed via opposite direction + STAGGERED_FLUX = 5 @staticmethod def is_generic(field): @@ -58,7 +60,12 @@ class FieldType(Enum): @staticmethod def is_staggered(field): assert isinstance(field, Field) - return field.field_type == FieldType.STAGGERED + return field.field_type == FieldType.STAGGERED or field.field_type == FieldType.STAGGERED_FLUX + + @staticmethod + def is_staggered_flux(field): + assert isinstance(field, Field) + return field.field_type == FieldType.STAGGERED_FLUX def fields(description=None, index_dimensions=0, layout=None, field_type=FieldType.GENERIC, **kwargs): @@ -419,13 +426,15 @@ class Field(AbstractField): index_shape = self.index_shape if len(index_shape) == 0: return sp.Matrix([self.center]) - if len(index_shape) == 1: + elif len(index_shape) == 1: return sp.Matrix([self(i) for i in range(index_shape[0])]) elif len(index_shape) == 2: - def cb(*args): - r = self.__call__(*args) - return r - return sp.Matrix(*index_shape, cb) + return sp.Matrix([[self(i, j) for j in range(index_shape[1])] for i in range(index_shape[0])]) + elif len(index_shape) == 3: + return sp.Matrix([[[self(i, j, k) for k in range(index_shape[2])] + for j in range(index_shape[1])] for i in range(index_shape[0])]) + else: + raise NotImplementedError("center_vector is not implemented for more than 3 index dimensions") @property def center(self): @@ -490,24 +499,29 @@ class Field(AbstractField): raise ValueError("Wrong number of spatial indices: " "Got %d, expected %d" % (len(offset), self.spatial_dimensions)) - offset = list(offset) - neighbor = [0] * len(offset) - for i, o in enumerate(offset): - if (o + sp.Rational(1, 2)).is_Integer: - offset[i] += sp.Rational(1, 2) - neighbor[i] = -1 - neighbor = offset_to_direction_string(neighbor) - try: - idx = self.staggered_stencil.index(neighbor) - except ValueError: + prefactor = 1 + neighbor_vec = [0] * len(offset) + for i in range(self.spatial_dimensions): + if (offset[i] + sp.Rational(1, 2)).is_Integer: + neighbor_vec[i] = sp.sign(offset[i]) + neighbor = offset_to_direction_string(neighbor_vec) + if neighbor not in self.staggered_stencil: + neighbor_vec = inverse_direction(neighbor_vec) + neighbor = offset_to_direction_string(neighbor_vec) + if FieldType.is_staggered_flux(self): + prefactor = -1 + if neighbor not in self.staggered_stencil: raise ValueError("{} is not a valid neighbor for the {} stencil".format(offset_orig, self.staggered_stencil_name)) - offset = tuple(offset) + + offset = tuple(sp.Matrix(offset) - sp.Rational(1, 2) * sp.Matrix(neighbor_vec)) + + idx = self.staggered_stencil.index(neighbor) if self.index_dimensions == 1: # this field stores a scalar value at each staggered position if index is not None: raise ValueError("Cannot specify an index for a scalar staggered field") - return Field.Access(self, offset, (idx,)) + return prefactor * Field.Access(self, offset, (idx,)) else: # this field stores a vector or tensor at each staggered position if index is None: raise ValueError("Wrong number of indices: " @@ -520,7 +534,21 @@ class Field(AbstractField): raise ValueError("Wrong number of indices: " "Got %d, expected %d" % (len(index), self.index_dimensions - 1)) - return Field.Access(self, offset, (idx, *index)) + return prefactor * Field.Access(self, offset, (idx, *index)) + + def staggered_vector_access(self, offset): + """Like staggered_access, but returns the entire vector/tensor stored at offset.""" + assert FieldType.is_staggered(self) + + if self.index_dimensions == 1: + return sp.Matrix([self.staggered_access(offset)]) + elif self.index_dimensions == 2: + return sp.Matrix([self.staggered_access(offset, i) for i in range(self.index_shape[1])]) + elif self.index_dimensions == 3: + return sp.Matrix([[self.staggered_access(offset, (i, k)) for k in range(self.index_shape[2])] + for i in range(self.index_shape[1])]) + else: + raise NotImplementedError("staggered_vector_access is not implemented for more than 3 index dimensions") @property def staggered_stencil(self): diff --git a/pystencils_tests/test_field.py b/pystencils_tests/test_field.py index 227b332cd1dbf1f68bab43d1b7042b89cc44b8ad..ad23682a71b0286446dd13b29e5216cc11d2aeef 100644 --- a/pystencils_tests/test_field.py +++ b/pystencils_tests/test_field.py @@ -28,6 +28,9 @@ def test_field_basic(): assert neighbor.offsets == (-1, 1) assert '_' in neighbor._latex('dummy') + f = Field.create_fixed_size('f', (8, 8, 2, 2, 2), index_dimensions=3) + assert f.center_vector == sp.Matrix([[[f(i, j, k) for k in range(2)] for j in range(2)] for i in range(2)]) + f = Field.create_generic('f', spatial_dimensions=5, index_dimensions=2) field_access = f[1, -1, 2, -3, 0](1, 0) assert field_access.offsets == (1, -1, 2, -3, 0) @@ -135,17 +138,30 @@ def test_staggered(): j1, j2, j3 = ps.fields('j1(2), j2(2,2), j3(2,2,2) : double[2D]', field_type=FieldType.STAGGERED) assert j1[0, 1](1) == j1.staggered_access((0, sp.Rational(1, 2))) + assert j1[1, 1](1) == j1.staggered_access((1, sp.Rational(1, 2))) + assert j1[0, 2](1) == j1.staggered_access((0, sp.Rational(3, 2))) assert j1[0, 1](1) == j1.staggered_access("N") assert j1[0, 0](1) == j1.staggered_access("S") + assert j1.staggered_vector_access("N") == sp.Matrix([j1.staggered_access("N")]) assert j2[0, 1](1, 1) == j2.staggered_access((0, sp.Rational(1, 2)), 1) assert j2[0, 1](1, 1) == j2.staggered_access("N", 1) + assert j2.staggered_vector_access("N") == sp.Matrix([j2.staggered_access("N", 0), j2.staggered_access("N", 1)]) assert j3[0, 1](1, 1, 1) == j3.staggered_access((0, sp.Rational(1, 2)), (1, 1)) assert j3[0, 1](1, 1, 1) == j3.staggered_access("N", (1, 1)) + assert j3.staggered_vector_access("N") == sp.Matrix([[j3.staggered_access("N", (i, j)) + for j in range(2)] for i in range(2)]) # D2Q9 k = ps.fields('k(4) : double[2D]', field_type=FieldType.STAGGERED) assert k[1, 1](2) == k.staggered_access("NE") assert k[0, 0](2) == k.staggered_access("SW") + + assert k[0, 0](3) == k.staggered_access("NW") + + # sign reversed when using as flux field + r = ps.fields('r(2) : double[2D]', field_type=FieldType.STAGGERED_FLUX) + assert r[0, 0](0) == r.staggered_access("W") + assert -r[1, 0](0) == r.staggered_access("E")