From fcdfc1274e31d65eaad377965722dc6edbc63e8c Mon Sep 17 00:00:00 2001
From: Michael Kuron <m.kuron@gmx.de>
Date: Sun, 24 Nov 2019 09:08:33 +0100
Subject: [PATCH] Remove unneeded ghost layers from staggered kernel

---
 pystencils/kernelcreation.py                 | 23 ++++++++++++++++----
 pystencils_tests/test_staggered_diffusion.py |  5 +++--
 2 files changed, 22 insertions(+), 6 deletions(-)

diff --git a/pystencils/kernelcreation.py b/pystencils/kernelcreation.py
index a87cb58c8..3d9da0ef5 100644
--- a/pystencils/kernelcreation.py
+++ b/pystencils/kernelcreation.py
@@ -333,16 +333,32 @@ def create_staggered_kernel_2(assignments, **kwargs):
 
     final_assignments = []
 
+    # find out whether any of the ghost layers is not needed
+    common_exclusions = set(["E", "W", "N", "S", "T", "B"][:2 * dim])
+    for direction in staggered_field.staggered_stencil:
+        exclusions = set(["E", "W", "N", "S", "T", "B"][:2 * dim])
+        for elementary_direction in direction:
+            exclusions.remove(inverse_direction_string(elementary_direction))
+        common_exclusions.intersection_update(exclusions)
+    ghost_layers = [[1, 1] for d in range(dim)]
+    for direction in common_exclusions:
+        direction = direction_string_to_offset(direction)
+        for d, s in enumerate(direction):
+            if s == 1:
+                ghost_layers[d][1] = 0
+            elif s == -1:
+                ghost_layers[d][0] = 0
+
     def condition(direction):
         """exclude those staggered points that correspond to fluxes between ghost cells"""
-        exclusions = set(["E", "W", "N", "S"])
-        if dim == 3:
-            exclusions.update("T", "B")
+        exclusions = set(["E", "W", "N", "S", "T", "B"][:2 * dim])
 
         for elementary_direction in direction:
             exclusions.remove(inverse_direction_string(elementary_direction))
         conditions = []
         for e in exclusions:
+            if e in common_exclusions:
+                continue
             offset = direction_string_to_offset(e)
             for i, o in enumerate(offset):
                 if o == 1:
@@ -357,6 +373,5 @@ def create_staggered_kernel_2(assignments, **kwargs):
         last_conditional = Conditional(condition(direction), Block(sp_assignments))
         final_assignments.append(last_conditional)
 
-    ghost_layers = [(1, 0)] * dim
     ast = create_kernel(final_assignments, ghost_layers=ghost_layers, **kwargs)
     return ast
diff --git a/pystencils_tests/test_staggered_diffusion.py b/pystencils_tests/test_staggered_diffusion.py
index 0171db15c..4d0fb294e 100644
--- a/pystencils_tests/test_staggered_diffusion.py
+++ b/pystencils_tests/test_staggered_diffusion.py
@@ -1,5 +1,6 @@
 import pystencils as ps
 import numpy as np
+import sympy as sp
 
 
 class TestDiffusion:
@@ -20,8 +21,8 @@ class TestDiffusion:
         xY_staggered = - c[-1, 1] + c[0, 0]
 
         jj = j.staggered_access
-        divergence = -1 * D / (1 + np.sqrt(2) if j.index_shape[0] == 4 else 1) * \
-            sum([jj(d) / np.linalg.norm(ps.stencil.direction_string_to_offset(d)) for d in j.staggered_stencil
+        divergence = -1 * D / (1 + sp.sqrt(2) if j.index_shape[0] == 4 else 1) * \
+            sum([jj(d) / sp.Matrix(ps.stencil.direction_string_to_offset(d)).norm() for d in j.staggered_stencil
                 + [ps.stencil.inverse_direction_string(d) for d in j.staggered_stencil]])
 
         update = [ps.Assignment(c.center, c.center + dt * divergence)]
-- 
GitLab