From 37aff58bf220ad7d2af964ed7959138b484de405 Mon Sep 17 00:00:00 2001
From: Michael Kuron <m.kuron@gmx.de>
Date: Tue, 17 Dec 2019 14:29:26 +0100
Subject: [PATCH] FiniteDifferenceStaggeredStencilDerivation must be applied to
 field access

otherwise the index gets lost
---
 pystencils/fd/derivation.py               | 10 ++----
 pystencils_tests/test_fd_derivation.ipynb | 37 ++++++++++++-----------
 2 files changed, 21 insertions(+), 26 deletions(-)

diff --git a/pystencils/fd/derivation.py b/pystencils/fd/derivation.py
index af1f4e387..db9460caa 100644
--- a/pystencils/fd/derivation.py
+++ b/pystencils/fd/derivation.py
@@ -340,11 +340,5 @@ class FiniteDifferenceStaggeredStencilDerivation:
         from pystencils.stencil import plot
         plot(pts, data=ws)
 
-    def apply(self, field):
-        if field.index_dimensions == 0:
-            return sum([field.__getitem__(point) * weight for point, weight in zip(self.points, self.weights)])
-        else:
-            total = field.neighbor_vector(self.points[0]) * self.weights[0]
-            for point, weight in zip(self.points[1:], self.weights[1:]):
-                total += field.neighbor_vector(point) * weight
-            return total
+    def apply(self, access: Field.Access):
+        return sum([access.get_shifted(*point) * weight for point, weight in zip(self.points, self.weights)])
diff --git a/pystencils_tests/test_fd_derivation.ipynb b/pystencils_tests/test_fd_derivation.ipynb
index 1be42210e..0233c293f 100644
--- a/pystencils_tests/test_fd_derivation.ipynb
+++ b/pystencils_tests/test_fd_derivation.ipynb
@@ -320,25 +320,25 @@
     "c = ps.fields(\"c: [2D]\")\n",
     "c3 = ps.fields(\"c3: [3D]\")\n",
     "\n",
-    "assert FiniteDifferenceStaggeredStencilDerivation(\"E\", 2, (0,)).apply(c) == c[1, 0] - c[0, 0]\n",
-    "assert FiniteDifferenceStaggeredStencilDerivation(\"W\", 2, (0,)).apply(c) == c[0, 0] - c[-1, 0]\n",
-    "assert FiniteDifferenceStaggeredStencilDerivation(\"N\", 2, (1,)).apply(c) == c[0, 1] - c[0, 0]\n",
-    "assert FiniteDifferenceStaggeredStencilDerivation(\"S\", 2, (1,)).apply(c) == c[0, 0] - c[0, -1]\n",
+    "assert FiniteDifferenceStaggeredStencilDerivation(\"E\", 2, (0,)).apply(c.center) == c[1, 0] - c[0, 0]\n",
+    "assert FiniteDifferenceStaggeredStencilDerivation(\"W\", 2, (0,)).apply(c.center) == c[0, 0] - c[-1, 0]\n",
+    "assert FiniteDifferenceStaggeredStencilDerivation(\"N\", 2, (1,)).apply(c.center) == c[0, 1] - c[0, 0]\n",
+    "assert FiniteDifferenceStaggeredStencilDerivation(\"S\", 2, (1,)).apply(c.center) == c[0, 0] - c[0, -1]\n",
     "\n",
-    "assert FiniteDifferenceStaggeredStencilDerivation(\"E\", 3, (0,)).apply(c3) == c3[1, 0, 0] - c3[0, 0, 0]\n",
-    "assert FiniteDifferenceStaggeredStencilDerivation(\"W\", 3, (0,)).apply(c3) == c3[0, 0, 0] - c3[-1, 0, 0]\n",
-    "assert FiniteDifferenceStaggeredStencilDerivation(\"N\", 3, (1,)).apply(c3) == c3[0, 1, 0] - c3[0, 0, 0]\n",
-    "assert FiniteDifferenceStaggeredStencilDerivation(\"S\", 3, (1,)).apply(c3) == c3[0, 0, 0] - c3[0, -1, 0]\n",
-    "assert FiniteDifferenceStaggeredStencilDerivation(\"T\", 3, (2,)).apply(c3) == c3[0, 0, 1] - c3[0, 0, 0]\n",
-    "assert FiniteDifferenceStaggeredStencilDerivation(\"B\", 3, (2,)).apply(c3) == c3[0, 0, 0] - c3[0, 0, -1]\n",
+    "assert FiniteDifferenceStaggeredStencilDerivation(\"E\", 3, (0,)).apply(c3.center) == c3[1, 0, 0] - c3[0, 0, 0]\n",
+    "assert FiniteDifferenceStaggeredStencilDerivation(\"W\", 3, (0,)).apply(c3.center) == c3[0, 0, 0] - c3[-1, 0, 0]\n",
+    "assert FiniteDifferenceStaggeredStencilDerivation(\"N\", 3, (1,)).apply(c3.center) == c3[0, 1, 0] - c3[0, 0, 0]\n",
+    "assert FiniteDifferenceStaggeredStencilDerivation(\"S\", 3, (1,)).apply(c3.center) == c3[0, 0, 0] - c3[0, -1, 0]\n",
+    "assert FiniteDifferenceStaggeredStencilDerivation(\"T\", 3, (2,)).apply(c3.center) == c3[0, 0, 1] - c3[0, 0, 0]\n",
+    "assert FiniteDifferenceStaggeredStencilDerivation(\"B\", 3, (2,)).apply(c3.center) == c3[0, 0, 0] - c3[0, 0, -1]\n",
     "\n",
-    "assert FiniteDifferenceStaggeredStencilDerivation(\"S\", 2, (0,)).apply(c) == \\\n",
+    "assert FiniteDifferenceStaggeredStencilDerivation(\"S\", 2, (0,)).apply(c.center) == \\\n",
     "       (c[1, 0] + c[1, -1] - c[-1, 0] - c[-1, -1])/4\n",
     "\n",
-    "assert FiniteDifferenceStaggeredStencilDerivation(\"NE\", 2, (0,)).apply(c) + \\\n",
-    "       FiniteDifferenceStaggeredStencilDerivation(\"NE\", 2, (1,)).apply(c) == c[1, 1] - c[0, 0]\n",
-    "assert FiniteDifferenceStaggeredStencilDerivation(\"NE\", 3, (0,)).apply(c3) + \\\n",
-    "       FiniteDifferenceStaggeredStencilDerivation(\"NE\", 3, (1,)).apply(c3) == c3[1, 1, 0] - c3[0, 0, 0]"
+    "assert FiniteDifferenceStaggeredStencilDerivation(\"NE\", 2, (0,)).apply(c.center) + \\\n",
+    "       FiniteDifferenceStaggeredStencilDerivation(\"NE\", 2, (1,)).apply(c.center) == c[1, 1] - c[0, 0]\n",
+    "assert FiniteDifferenceStaggeredStencilDerivation(\"NE\", 3, (0,)).apply(c3.center) + \\\n",
+    "       FiniteDifferenceStaggeredStencilDerivation(\"NE\", 3, (1,)).apply(c3.center) == c3[1, 1, 0] - c3[0, 0, 0]"
    ]
   },
   {
@@ -359,7 +359,7 @@
    ],
    "source": [
     "d = FiniteDifferenceStaggeredStencilDerivation(\"NE\", 2, (0, 1))\n",
-    "assert d.apply(c) == c[0,0] + c[1,1] - c[1,0] - c[0,1]\n",
+    "assert d.apply(c.center) == c[0,0] + c[1,1] - c[1,0] - c[0,1]\n",
     "d.visualize()"
    ]
   },
@@ -370,8 +370,9 @@
    "outputs": [],
    "source": [
     "v3 = ps.fields(\"v(3): [3D]\")\n",
-    "assert FiniteDifferenceStaggeredStencilDerivation(\"E\", 3, (0,)).apply(v3) == \\\n",
-    "       sp.Matrix([v3[1,0,0](i) - v3[0,0,0](i) for i in range(*v3.index_shape)])"
+    "for i in range(*v3.index_shape):\n",
+    "    assert FiniteDifferenceStaggeredStencilDerivation(\"E\", 3, (0,)).apply(v3.center_vector[i]) == \\\n",
+    "           v3[1,0,0](i) - v3[0,0,0](i)"
    ]
   },
   {
-- 
GitLab