diff --git a/pystencils/fd/derivation.py b/pystencils/fd/derivation.py index af65c62b203042b8ab07221a190a6bf34bff87da..4af4f3e3653e78668ad5b84e9ce9f19e7586b63c 100644 --- a/pystencils/fd/derivation.py +++ b/pystencils/fd/derivation.py @@ -340,11 +340,5 @@ class FiniteDifferenceStaggeredStencilDerivation: from pystencils.stencil import plot plot(pts, data=ws) - def apply(self, field): - if field.index_dimensions == 0: - return sum([field.__getitem__(point) * weight for point, weight in zip(self.points, self.weights)]) - else: - total = field.neighbor_vector(self.points[0]) * self.weights[0] - for point, weight in zip(self.points[1:], self.weights[1:]): - total += field.neighbor_vector(point) * weight - return total + def apply(self, access: Field.Access): + return sum([access.get_shifted(*point) * weight for point, weight in zip(self.points, self.weights)]) diff --git a/pystencils/field.py b/pystencils/field.py index 2cd7096b7fa171ab1cb8d9b83bec66ffc339ca5d..507be0c2cad932eed64419b70e1b55783cff4273 100644 --- a/pystencils/field.py +++ b/pystencils/field.py @@ -845,7 +845,7 @@ class Field(AbstractField): assert FieldType.is_staggered(self._field) neighbor = self._field.staggered_stencil[index] neighbor = direction_string_to_offset(neighbor, self._field.spatial_dimensions) - return [(o - sp.Rational(int(neighbor[i]), 2)) for i, o in enumerate(offsets)] + return [(o + sp.Rational(int(neighbor[i]), 2)) for i, o in enumerate(offsets)] def _latex(self, _): n = self._field.latex_name if self._field.latex_name else self._field.name diff --git a/pystencils_tests/test_fd_derivation.ipynb b/pystencils_tests/test_fd_derivation.ipynb index 1be42210e774a62513e0d73ce70c57ef51e0240f..0233c293f025c733792eddba60056fa07006da9d 100644 --- a/pystencils_tests/test_fd_derivation.ipynb +++ b/pystencils_tests/test_fd_derivation.ipynb @@ -320,25 +320,25 @@ "c = ps.fields(\"c: [2D]\")\n", "c3 = ps.fields(\"c3: [3D]\")\n", "\n", - "assert FiniteDifferenceStaggeredStencilDerivation(\"E\", 2, (0,)).apply(c) == c[1, 0] - c[0, 0]\n", - "assert FiniteDifferenceStaggeredStencilDerivation(\"W\", 2, (0,)).apply(c) == c[0, 0] - c[-1, 0]\n", - "assert FiniteDifferenceStaggeredStencilDerivation(\"N\", 2, (1,)).apply(c) == c[0, 1] - c[0, 0]\n", - "assert FiniteDifferenceStaggeredStencilDerivation(\"S\", 2, (1,)).apply(c) == c[0, 0] - c[0, -1]\n", + "assert FiniteDifferenceStaggeredStencilDerivation(\"E\", 2, (0,)).apply(c.center) == c[1, 0] - c[0, 0]\n", + "assert FiniteDifferenceStaggeredStencilDerivation(\"W\", 2, (0,)).apply(c.center) == c[0, 0] - c[-1, 0]\n", + "assert FiniteDifferenceStaggeredStencilDerivation(\"N\", 2, (1,)).apply(c.center) == c[0, 1] - c[0, 0]\n", + "assert FiniteDifferenceStaggeredStencilDerivation(\"S\", 2, (1,)).apply(c.center) == c[0, 0] - c[0, -1]\n", "\n", - "assert FiniteDifferenceStaggeredStencilDerivation(\"E\", 3, (0,)).apply(c3) == c3[1, 0, 0] - c3[0, 0, 0]\n", - "assert FiniteDifferenceStaggeredStencilDerivation(\"W\", 3, (0,)).apply(c3) == c3[0, 0, 0] - c3[-1, 0, 0]\n", - "assert FiniteDifferenceStaggeredStencilDerivation(\"N\", 3, (1,)).apply(c3) == c3[0, 1, 0] - c3[0, 0, 0]\n", - "assert FiniteDifferenceStaggeredStencilDerivation(\"S\", 3, (1,)).apply(c3) == c3[0, 0, 0] - c3[0, -1, 0]\n", - "assert FiniteDifferenceStaggeredStencilDerivation(\"T\", 3, (2,)).apply(c3) == c3[0, 0, 1] - c3[0, 0, 0]\n", - "assert FiniteDifferenceStaggeredStencilDerivation(\"B\", 3, (2,)).apply(c3) == c3[0, 0, 0] - c3[0, 0, -1]\n", + "assert FiniteDifferenceStaggeredStencilDerivation(\"E\", 3, (0,)).apply(c3.center) == c3[1, 0, 0] - c3[0, 0, 0]\n", + "assert FiniteDifferenceStaggeredStencilDerivation(\"W\", 3, (0,)).apply(c3.center) == c3[0, 0, 0] - c3[-1, 0, 0]\n", + "assert FiniteDifferenceStaggeredStencilDerivation(\"N\", 3, (1,)).apply(c3.center) == c3[0, 1, 0] - c3[0, 0, 0]\n", + "assert FiniteDifferenceStaggeredStencilDerivation(\"S\", 3, (1,)).apply(c3.center) == c3[0, 0, 0] - c3[0, -1, 0]\n", + "assert FiniteDifferenceStaggeredStencilDerivation(\"T\", 3, (2,)).apply(c3.center) == c3[0, 0, 1] - c3[0, 0, 0]\n", + "assert FiniteDifferenceStaggeredStencilDerivation(\"B\", 3, (2,)).apply(c3.center) == c3[0, 0, 0] - c3[0, 0, -1]\n", "\n", - "assert FiniteDifferenceStaggeredStencilDerivation(\"S\", 2, (0,)).apply(c) == \\\n", + "assert FiniteDifferenceStaggeredStencilDerivation(\"S\", 2, (0,)).apply(c.center) == \\\n", " (c[1, 0] + c[1, -1] - c[-1, 0] - c[-1, -1])/4\n", "\n", - "assert FiniteDifferenceStaggeredStencilDerivation(\"NE\", 2, (0,)).apply(c) + \\\n", - " FiniteDifferenceStaggeredStencilDerivation(\"NE\", 2, (1,)).apply(c) == c[1, 1] - c[0, 0]\n", - "assert FiniteDifferenceStaggeredStencilDerivation(\"NE\", 3, (0,)).apply(c3) + \\\n", - " FiniteDifferenceStaggeredStencilDerivation(\"NE\", 3, (1,)).apply(c3) == c3[1, 1, 0] - c3[0, 0, 0]" + "assert FiniteDifferenceStaggeredStencilDerivation(\"NE\", 2, (0,)).apply(c.center) + \\\n", + " FiniteDifferenceStaggeredStencilDerivation(\"NE\", 2, (1,)).apply(c.center) == c[1, 1] - c[0, 0]\n", + "assert FiniteDifferenceStaggeredStencilDerivation(\"NE\", 3, (0,)).apply(c3.center) + \\\n", + " FiniteDifferenceStaggeredStencilDerivation(\"NE\", 3, (1,)).apply(c3.center) == c3[1, 1, 0] - c3[0, 0, 0]" ] }, { @@ -359,7 +359,7 @@ ], "source": [ "d = FiniteDifferenceStaggeredStencilDerivation(\"NE\", 2, (0, 1))\n", - "assert d.apply(c) == c[0,0] + c[1,1] - c[1,0] - c[0,1]\n", + "assert d.apply(c.center) == c[0,0] + c[1,1] - c[1,0] - c[0,1]\n", "d.visualize()" ] }, @@ -370,8 +370,9 @@ "outputs": [], "source": [ "v3 = ps.fields(\"v(3): [3D]\")\n", - "assert FiniteDifferenceStaggeredStencilDerivation(\"E\", 3, (0,)).apply(v3) == \\\n", - " sp.Matrix([v3[1,0,0](i) - v3[0,0,0](i) for i in range(*v3.index_shape)])" + "for i in range(*v3.index_shape):\n", + " assert FiniteDifferenceStaggeredStencilDerivation(\"E\", 3, (0,)).apply(v3.center_vector[i]) == \\\n", + " v3[1,0,0](i) - v3[0,0,0](i)" ] }, { diff --git a/pystencils_tests/test_field.py b/pystencils_tests/test_field.py index ad23682a71b0286446dd13b29e5216cc11d2aeef..a2813e34f9c72c8bbcd848d21993c181cf31fc43 100644 --- a/pystencils_tests/test_field.py +++ b/pystencils_tests/test_field.py @@ -154,12 +154,18 @@ def test_staggered(): for j in range(2)] for i in range(2)]) # D2Q9 - k = ps.fields('k(4) : double[2D]', field_type=FieldType.STAGGERED) - - assert k[1, 1](2) == k.staggered_access("NE") - assert k[0, 0](2) == k.staggered_access("SW") - - assert k[0, 0](3) == k.staggered_access("NW") + k1, k2 = ps.fields('k1(4), k2(2) : double[2D]', field_type=FieldType.STAGGERED) + + assert k1[1, 1](2) == k1.staggered_access("NE") + assert k1[0, 0](2) == k1.staggered_access("SW") + assert k1[0, 0](3) == k1.staggered_access("NW") + + a = k1.staggered_access("NE") + assert a._staggered_offset(a.offsets, a.index[0]) == [sp.Rational(1, 2), sp.Rational(1, 2)] + a = k1.staggered_access("SW") + assert a._staggered_offset(a.offsets, a.index[0]) == [sp.Rational(-1, 2), sp.Rational(-1, 2)] + a = k1.staggered_access("NW") + assert a._staggered_offset(a.offsets, a.index[0]) == [sp.Rational(-1, 2), sp.Rational(1, 2)] # sign reversed when using as flux field r = ps.fields('r(2) : double[2D]', field_type=FieldType.STAGGERED_FLUX)