diff --git a/pystencils/fd/derivation.py b/pystencils/fd/derivation.py index af1f4e387b0bdc8bafccd03bad624d986c5c610d..db9460caa666f38bb3ddd8e927b44e2cbd4fc96a 100644 --- a/pystencils/fd/derivation.py +++ b/pystencils/fd/derivation.py @@ -340,11 +340,5 @@ class FiniteDifferenceStaggeredStencilDerivation: from pystencils.stencil import plot plot(pts, data=ws) - def apply(self, field): - if field.index_dimensions == 0: - return sum([field.__getitem__(point) * weight for point, weight in zip(self.points, self.weights)]) - else: - total = field.neighbor_vector(self.points[0]) * self.weights[0] - for point, weight in zip(self.points[1:], self.weights[1:]): - total += field.neighbor_vector(point) * weight - return total + def apply(self, access: Field.Access): + return sum([access.get_shifted(*point) * weight for point, weight in zip(self.points, self.weights)]) diff --git a/pystencils_tests/test_fd_derivation.ipynb b/pystencils_tests/test_fd_derivation.ipynb index 1be42210e774a62513e0d73ce70c57ef51e0240f..0233c293f025c733792eddba60056fa07006da9d 100644 --- a/pystencils_tests/test_fd_derivation.ipynb +++ b/pystencils_tests/test_fd_derivation.ipynb @@ -320,25 +320,25 @@ "c = ps.fields(\"c: [2D]\")\n", "c3 = ps.fields(\"c3: [3D]\")\n", "\n", - "assert FiniteDifferenceStaggeredStencilDerivation(\"E\", 2, (0,)).apply(c) == c[1, 0] - c[0, 0]\n", - "assert FiniteDifferenceStaggeredStencilDerivation(\"W\", 2, (0,)).apply(c) == c[0, 0] - c[-1, 0]\n", - "assert FiniteDifferenceStaggeredStencilDerivation(\"N\", 2, (1,)).apply(c) == c[0, 1] - c[0, 0]\n", - "assert FiniteDifferenceStaggeredStencilDerivation(\"S\", 2, (1,)).apply(c) == c[0, 0] - c[0, -1]\n", + "assert FiniteDifferenceStaggeredStencilDerivation(\"E\", 2, (0,)).apply(c.center) == c[1, 0] - c[0, 0]\n", + "assert FiniteDifferenceStaggeredStencilDerivation(\"W\", 2, (0,)).apply(c.center) == c[0, 0] - c[-1, 0]\n", + "assert FiniteDifferenceStaggeredStencilDerivation(\"N\", 2, (1,)).apply(c.center) == c[0, 1] - c[0, 0]\n", + "assert FiniteDifferenceStaggeredStencilDerivation(\"S\", 2, (1,)).apply(c.center) == c[0, 0] - c[0, -1]\n", "\n", - "assert FiniteDifferenceStaggeredStencilDerivation(\"E\", 3, (0,)).apply(c3) == c3[1, 0, 0] - c3[0, 0, 0]\n", - "assert FiniteDifferenceStaggeredStencilDerivation(\"W\", 3, (0,)).apply(c3) == c3[0, 0, 0] - c3[-1, 0, 0]\n", - "assert FiniteDifferenceStaggeredStencilDerivation(\"N\", 3, (1,)).apply(c3) == c3[0, 1, 0] - c3[0, 0, 0]\n", - "assert FiniteDifferenceStaggeredStencilDerivation(\"S\", 3, (1,)).apply(c3) == c3[0, 0, 0] - c3[0, -1, 0]\n", - "assert FiniteDifferenceStaggeredStencilDerivation(\"T\", 3, (2,)).apply(c3) == c3[0, 0, 1] - c3[0, 0, 0]\n", - "assert FiniteDifferenceStaggeredStencilDerivation(\"B\", 3, (2,)).apply(c3) == c3[0, 0, 0] - c3[0, 0, -1]\n", + "assert FiniteDifferenceStaggeredStencilDerivation(\"E\", 3, (0,)).apply(c3.center) == c3[1, 0, 0] - c3[0, 0, 0]\n", + "assert FiniteDifferenceStaggeredStencilDerivation(\"W\", 3, (0,)).apply(c3.center) == c3[0, 0, 0] - c3[-1, 0, 0]\n", + "assert FiniteDifferenceStaggeredStencilDerivation(\"N\", 3, (1,)).apply(c3.center) == c3[0, 1, 0] - c3[0, 0, 0]\n", + "assert FiniteDifferenceStaggeredStencilDerivation(\"S\", 3, (1,)).apply(c3.center) == c3[0, 0, 0] - c3[0, -1, 0]\n", + "assert FiniteDifferenceStaggeredStencilDerivation(\"T\", 3, (2,)).apply(c3.center) == c3[0, 0, 1] - c3[0, 0, 0]\n", + "assert FiniteDifferenceStaggeredStencilDerivation(\"B\", 3, (2,)).apply(c3.center) == c3[0, 0, 0] - c3[0, 0, -1]\n", "\n", - "assert FiniteDifferenceStaggeredStencilDerivation(\"S\", 2, (0,)).apply(c) == \\\n", + "assert FiniteDifferenceStaggeredStencilDerivation(\"S\", 2, (0,)).apply(c.center) == \\\n", " (c[1, 0] + c[1, -1] - c[-1, 0] - c[-1, -1])/4\n", "\n", - "assert FiniteDifferenceStaggeredStencilDerivation(\"NE\", 2, (0,)).apply(c) + \\\n", - " FiniteDifferenceStaggeredStencilDerivation(\"NE\", 2, (1,)).apply(c) == c[1, 1] - c[0, 0]\n", - "assert FiniteDifferenceStaggeredStencilDerivation(\"NE\", 3, (0,)).apply(c3) + \\\n", - " FiniteDifferenceStaggeredStencilDerivation(\"NE\", 3, (1,)).apply(c3) == c3[1, 1, 0] - c3[0, 0, 0]" + "assert FiniteDifferenceStaggeredStencilDerivation(\"NE\", 2, (0,)).apply(c.center) + \\\n", + " FiniteDifferenceStaggeredStencilDerivation(\"NE\", 2, (1,)).apply(c.center) == c[1, 1] - c[0, 0]\n", + "assert FiniteDifferenceStaggeredStencilDerivation(\"NE\", 3, (0,)).apply(c3.center) + \\\n", + " FiniteDifferenceStaggeredStencilDerivation(\"NE\", 3, (1,)).apply(c3.center) == c3[1, 1, 0] - c3[0, 0, 0]" ] }, { @@ -359,7 +359,7 @@ ], "source": [ "d = FiniteDifferenceStaggeredStencilDerivation(\"NE\", 2, (0, 1))\n", - "assert d.apply(c) == c[0,0] + c[1,1] - c[1,0] - c[0,1]\n", + "assert d.apply(c.center) == c[0,0] + c[1,1] - c[1,0] - c[0,1]\n", "d.visualize()" ] }, @@ -370,8 +370,9 @@ "outputs": [], "source": [ "v3 = ps.fields(\"v(3): [3D]\")\n", - "assert FiniteDifferenceStaggeredStencilDerivation(\"E\", 3, (0,)).apply(v3) == \\\n", - " sp.Matrix([v3[1,0,0](i) - v3[0,0,0](i) for i in range(*v3.index_shape)])" + "for i in range(*v3.index_shape):\n", + " assert FiniteDifferenceStaggeredStencilDerivation(\"E\", 3, (0,)).apply(v3.center_vector[i]) == \\\n", + " v3[1,0,0](i) - v3[0,0,0](i)" ] }, {