From efadec2e68e08c4091a41012320b28583e87039c Mon Sep 17 00:00:00 2001 From: Michael Kuron <mkuron@icp.uni-stuttgart.de> Date: Fri, 6 Dec 2019 17:20:55 +0100 Subject: [PATCH] Minor improvements to FiniteDifferenceStaggeredStencilDerivation --- pystencils/fd/derivation.py | 9 +++++---- pystencils_tests/test_fd_derivation.ipynb | 3 +++ 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/pystencils/fd/derivation.py b/pystencils/fd/derivation.py index 5a9d79e8f..af1f4e387 100644 --- a/pystencils/fd/derivation.py +++ b/pystencils/fd/derivation.py @@ -303,11 +303,12 @@ class FiniteDifferenceStaggeredStencilDerivation: zero_counts[zero_count].append(weights) best = zero_counts[max(zero_counts.keys())] if len(best) > 1: # if there are multiple, pick the one that contains a nonzero center weight - center = [tuple(p + pos) for p in points].index((0, 0, 0)) + center = [tuple(p + pos) for p in points].index((0, 0, 0)[:dim]) best = [b for b in best if b[center] != 0] - if len(best) > 1: - raise NotImplementedError("more than one suitable set of weights found, don't know how to proceed") - weights = best[0] + if len(best) > 1: # if there are still multiple, they are equivalent, so we average + weights = sp.Add(*[sp.Matrix(b) for b in best]) / len(best) + else: + weights = best[0] assert weights points_tuple = tuple([tuple(p + pos) for p in points]) diff --git a/pystencils_tests/test_fd_derivation.ipynb b/pystencils_tests/test_fd_derivation.ipynb index f992c8ac5..1be42210e 100644 --- a/pystencils_tests/test_fd_derivation.ipynb +++ b/pystencils_tests/test_fd_derivation.ipynb @@ -332,6 +332,9 @@ "assert FiniteDifferenceStaggeredStencilDerivation(\"T\", 3, (2,)).apply(c3) == c3[0, 0, 1] - c3[0, 0, 0]\n", "assert FiniteDifferenceStaggeredStencilDerivation(\"B\", 3, (2,)).apply(c3) == c3[0, 0, 0] - c3[0, 0, -1]\n", "\n", + "assert FiniteDifferenceStaggeredStencilDerivation(\"S\", 2, (0,)).apply(c) == \\\n", + " (c[1, 0] + c[1, -1] - c[-1, 0] - c[-1, -1])/4\n", + "\n", "assert FiniteDifferenceStaggeredStencilDerivation(\"NE\", 2, (0,)).apply(c) + \\\n", " FiniteDifferenceStaggeredStencilDerivation(\"NE\", 2, (1,)).apply(c) == c[1, 1] - c[0, 0]\n", "assert FiniteDifferenceStaggeredStencilDerivation(\"NE\", 3, (0,)).apply(c3) + \\\n", -- GitLab