From 6dab27553100de0dbcc63af3dc9fe7f22e76d813 Mon Sep 17 00:00:00 2001
From: Michael Kuron <m.kuron@gmx.de>
Date: Sat, 23 Nov 2019 13:03:01 +0100
Subject: [PATCH] staggered diffusion test improvements

---
 pystencils/kernelcreation.py                 | 15 +++++++++++++--
 pystencils_tests/test_staggered_diffusion.py | 11 ++++-------
 2 files changed, 17 insertions(+), 9 deletions(-)

diff --git a/pystencils/kernelcreation.py b/pystencils/kernelcreation.py
index 198926e7f..a87cb58c8 100644
--- a/pystencils/kernelcreation.py
+++ b/pystencils/kernelcreation.py
@@ -6,6 +6,7 @@ import sympy as sp
 from pystencils.assignment import Assignment
 from pystencils.astnodes import Block, Conditional, LoopOverCoordinate, SympyAssignment
 from pystencils.cpu.vectorization import vectorize
+from pystencils.field import Field
 from pystencils.gpucuda.indexing import indexing_creator_from_params
 from pystencils.simp.assignment_collection import AssignmentCollection
 from pystencils.stencil import direction_string_to_offset, inverse_direction_string
@@ -187,8 +188,18 @@ def create_indexed_kernel(assignments,
         raise ValueError("Unknown target %s. Has to be either 'cpu' or 'gpu'" % (target,))
 
 
-def create_staggered_kernel(staggered_field, expressions, subexpressions=(), target='cpu',
-                            gpu_exclusive_conditions=False, **kwargs):
+def create_staggered_kernel(*args, **kwargs):
+    """Kernel that updates a staggered field. Dispatches to either create_staggered_kernel_1 or
+       create_staggered_kernel_2 depending on the argument types.
+    """
+    if 'staggered_field' in kwargs or type(args[0]) is Field:
+        return create_staggered_kernel_1(*args, **kwargs)
+    else:
+        return create_staggered_kernel_2(*args, **kwargs)
+
+
+def create_staggered_kernel_1(staggered_field, expressions, subexpressions=(), target='cpu',
+                              gpu_exclusive_conditions=False, **kwargs):
     """Kernel that updates a staggered field.
 
     .. image:: /img/staggered_grid.svg
diff --git a/pystencils_tests/test_staggered_diffusion.py b/pystencils_tests/test_staggered_diffusion.py
index 458e0fd79..0171db15c 100644
--- a/pystencils_tests/test_staggered_diffusion.py
+++ b/pystencils_tests/test_staggered_diffusion.py
@@ -21,8 +21,8 @@ class TestDiffusion:
 
         jj = j.staggered_access
         divergence = -1 * D / (1 + np.sqrt(2) if j.index_shape[0] == 4 else 1) * \
-            sum([jj(d) for d in j.staggered_stencil +
-                [ps.stencil.inverse_direction_string(d) for d in j.staggered_stencil]])
+            sum([jj(d) / np.linalg.norm(ps.stencil.direction_string_to_offset(d)) for d in j.staggered_stencil
+                + [ps.stencil.inverse_direction_string(d) for d in j.staggered_stencil]])
 
         update = [ps.Assignment(c.center, c.center + dt * divergence)]
         flux = [ps.Assignment(j.staggered_access("W"), x_staggered),
@@ -31,7 +31,7 @@ class TestDiffusion:
             flux += [ps.Assignment(j.staggered_access("SW"), xy_staggered),
                      ps.Assignment(j.staggered_access("NW"), xY_staggered)]
 
-        staggered_kernel = ps.kernelcreation.create_staggered_kernel_2(flux, target=dh.default_target).compile()
+        staggered_kernel = ps.create_staggered_kernel(flux, target=dh.default_target).compile()
         div_kernel = ps.create_kernel(update, target=dh.default_target).compile()
 
         def time_loop(steps):
@@ -57,10 +57,7 @@ class TestDiffusion:
                 r = np.array([x, y]) - L[0] / 2 + 0.5
                 reference[x, y] = (4 * np.pi * D * T)**(-dh.dim / 2) * np.exp(-np.dot(r, r) / (4 * D * T)) * (2**dh.dim)
 
-        if num_neighbors == 2:
-            assert np.abs(dh.gather_array(c.name) - reference).max() < 1e-3
-        else:
-            assert np.abs(dh.gather_array(c.name) - reference).max() < 1e-2
+        assert np.abs(dh.gather_array(c.name) - reference).max() < 5e-4
 
     def test_diffusion_2(self):
         self._run(2)
-- 
GitLab