From d2e3c3a6c83b3548275c90f6cfe31f7c0ca24cb2 Mon Sep 17 00:00:00 2001
From: Michael Kuron <mkuron@icp.uni-stuttgart.de>
Date: Tue, 15 Jun 2021 05:26:28 +0000
Subject: [PATCH] FVM: Choose better stencil for derivative in flux for D3Q27

---
 pystencils/fd/derivation.py    |  8 +++++--
 pystencils/fd/finitevolumes.py | 22 ++++++++++++++++--
 pystencils_tests/test_fvm.py   | 42 ++++++++++++++++++++++++++++++++++
 3 files changed, 68 insertions(+), 4 deletions(-)

diff --git a/pystencils/fd/derivation.py b/pystencils/fd/derivation.py
index 868a6872e..e945d7610 100644
--- a/pystencils/fd/derivation.py
+++ b/pystencils/fd/derivation.py
@@ -228,9 +228,10 @@ class FiniteDifferenceStaggeredStencilDerivation:
         neighbor: the neighbor direction string or vector at whose staggered position to calculate the derivative
         dim: how many dimensions (2 or 3)
         derivative: a tuple of directions over which to perform derivatives
+        free_weights_prefix: a string to prefix to free weight symbols. If None, do not return free weights
     """
 
-    def __init__(self, neighbor, dim, derivative=tuple()):
+    def __init__(self, neighbor, dim, derivative=tuple(), free_weights_prefix=None):
         if type(neighbor) is str:
             neighbor = direction_string_to_offset(neighbor)
         if dim == 2:
@@ -281,7 +282,10 @@ class FiniteDifferenceStaggeredStencilDerivation:
 
             # if the weights are underdefined, we can choose the free symbols to find the sparsest stencil
             free_weights = set(itertools.chain(*[w.free_symbols for w in weights]))
-            if len(free_weights) > 0:
+            if free_weights_prefix is not None:
+                weights = [w.subs({fw: sp.Symbol(f"{free_weights_prefix}_{i}") for i, fw in enumerate(free_weights)})
+                           for w in weights]
+            elif len(free_weights) > 0:
                 zero_counts = defaultdict(list)
                 for values in itertools.product([-1, -sp.Rational(1, 2), 0, 1, sp.Rational(1, 2)],
                                                 repeat=len(free_weights)):
diff --git a/pystencils/fd/finitevolumes.py b/pystencils/fd/finitevolumes.py
index d8f46f394..25992dd8c 100644
--- a/pystencils/fd/finitevolumes.py
+++ b/pystencils/fd/finitevolumes.py
@@ -59,7 +59,10 @@ class FVM1stOrder:
 
         assert ps.FieldType.is_staggered(flux_field)
 
+        num = 0
+
         def discretize(term, neighbor):
+            nonlocal num
             if isinstance(term, sp.Matrix):
                 nw = term.applyfunc(lambda t: discretize(t, neighbor))
                 return nw
@@ -69,7 +72,9 @@ class FVM1stOrder:
             elif isinstance(term, ps.fd.Diff):
                 access, direction = get_access_and_direction(term)
 
-                fds = FDS(neighbor, access.field.spatial_dimensions, direction)
+                fds = FDS(neighbor, access.field.spatial_dimensions, direction,
+                          free_weights_prefix=f'fvm_free_{num}' if sp.Matrix(neighbor).dot(neighbor) > 2 else None)
+                num += 1
                 return fds.apply(access)
 
             if term.args:
@@ -91,7 +96,20 @@ class FVM1stOrder:
             directional_flux = fluxes[0] * int(neighbor[0])
             for i in range(1, self.dim):
                 directional_flux += fluxes[i] * int(neighbor[i])
-            discrete_flux = discretize(directional_flux, neighbor)
+            discrete_flux = sp.simplify(discretize(directional_flux, neighbor))
+            free_weights = [s for s in discrete_flux.atoms(sp.Symbol) if s.name.startswith('fvm_free_')]
+
+            if len(free_weights) > 0:
+                discrete_flux = discrete_flux.collect(discrete_flux.atoms(ps.field.Field.Access))
+                access_counts = defaultdict(list)
+                for values in itertools.product([-1, 0, 1],
+                                                repeat=len(free_weights)):
+                    subs = {free_weight: value for free_weight, value in zip(free_weights, values)}
+                    simp = discrete_flux.subs(subs)
+                    access_count = len(simp.atoms(ps.field.Field.Access))
+                    access_counts[access_count].append(simp)
+                best_count = min(access_counts.keys())
+                discrete_flux = sum(access_counts[best_count]) / len(access_counts[best_count])
             discrete_fluxes.append(discrete_flux / sp.Matrix(neighbor).norm())
 
         if flux_field.index_dimensions > 1:
diff --git a/pystencils_tests/test_fvm.py b/pystencils_tests/test_fvm.py
index c863c8f45..42127cc76 100644
--- a/pystencils_tests/test_fvm.py
+++ b/pystencils_tests/test_fvm.py
@@ -282,3 +282,45 @@ def test_ek(stencil):
         assert a.rhs == b.rhs
 
 # TODO: test source
+
+
+@pytest.mark.parametrize("stencil", ["D2Q5", "D2Q9", "D3Q7", "D3Q19", "D3Q27"])
+@pytest.mark.parametrize("derivative", [0, 1])
+def test_flux_stencil(stencil, derivative):
+    L = (40, ) * int(stencil[1])
+    dh = ps.create_data_handling(L, periodicity=True, default_target='cpu')
+    c = dh.add_array('c', values_per_cell=1)
+    j = dh.add_array('j', values_per_cell=int(stencil[3:]) // 2, field_type=ps.FieldType.STAGGERED_FLUX)
+
+    def Gradient(f):
+        return sp.Matrix([ps.fd.diff(f, i) for i in range(dh.dim)])
+
+    eq = [sp.Matrix([sp.Symbol(f"a_{i}") * c.center for i in range(dh.dim)]), Gradient(c)][derivative]
+    disc = ps.fd.FVM1stOrder(c, flux=eq)
+
+    # check the continuity
+    continuity_assignments = disc.discrete_continuity(j)
+    assert [len(a.rhs.atoms(ps.field.Field.Access)) for a in continuity_assignments] == \
+           [int(stencil[3:])] * len(continuity_assignments)
+
+    # check the flux
+    flux_assignments = disc.discrete_flux(j)
+    assert [len(a.rhs.atoms(ps.field.Field.Access)) for a in flux_assignments] == [2] * len(flux_assignments)
+
+
+@pytest.mark.parametrize("stencil", ["D2Q5", "D2Q9", "D3Q7", "D3Q19", "D3Q27"])
+def test_source_stencil(stencil):
+    L = (40, ) * int(stencil[1])
+    dh = ps.create_data_handling(L, periodicity=True, default_target='cpu')
+    c = dh.add_array('c', values_per_cell=1)
+    j = dh.add_array('j', values_per_cell=int(stencil[3:]) // 2, field_type=ps.FieldType.STAGGERED_FLUX)
+
+    continuity_ref = ps.fd.FVM1stOrder(c).discrete_continuity(j)
+
+    for eq in [c.center] + [ps.fd.diff(c, i) for i in range(dh.dim)]:
+        disc = ps.fd.FVM1stOrder(c, source=eq)
+        diff = sp.simplify(disc.discrete_continuity(j)[0].rhs - continuity_ref[0].rhs)
+        if type(eq) is ps.field.Field.Access:
+            assert len(diff.atoms(ps.field.Field.Access)) == 1
+        else:
+            assert len(diff.atoms(ps.field.Field.Access)) == 2
-- 
GitLab