From 1ebe2216d8abac99bd5f659219e5ddf59a1d71cd Mon Sep 17 00:00:00 2001
From: Michael Kuron <mkuron@icp.uni-stuttgart.de>
Date: Thu, 28 Nov 2019 14:55:24 +0100
Subject: [PATCH] create_mrt_orthogonal: use the same viscous moments as in
 literature

---
 lbmpy/methods/creationfunctions.py            | 10 +++++++-
 lbmpy/moments.py                              |  2 +-
 .../test_momentbased_methods_equilibrium.py   | 23 +++++++++++++++++++
 3 files changed, 33 insertions(+), 2 deletions(-)

diff --git a/lbmpy/methods/creationfunctions.py b/lbmpy/methods/creationfunctions.py
index 3fc7ab08..a2d8f330 100644
--- a/lbmpy/methods/creationfunctions.py
+++ b/lbmpy/methods/creationfunctions.py
@@ -381,8 +381,9 @@ def create_mrt_orthogonal(stencil, relaxation_rate_getter=None, maxwellian_momen
                `mrt_orthogonal_modes_literature`. If this argument is not provided, Gram-Schmidt orthogonalization of
                the default modes is performed.
     """
+    dim = len(stencil[0])
     if relaxation_rate_getter is None:
-        relaxation_rate_getter = default_relaxation_rate_names(dim=len(stencil[0]))
+        relaxation_rate_getter = default_relaxation_rate_names(dim=dim)
     if isinstance(stencil, str):
         stencil = get_stencil(stencil)
 
@@ -396,6 +397,13 @@ def create_mrt_orthogonal(stencil, relaxation_rate_getter=None, maxwellian_momen
     moment_to_relaxation_rate_dict = OrderedDict()
     if not nested_moments:
         moments = get_default_moment_set_for_stencil(stencil)
+        x, y, z = MOMENT_SYMBOLS
+        if dim == 2:
+            diagonal_viscous_moments = [x ** 2 + y ** 2, x ** 2]
+        else:
+            diagonal_viscous_moments = [x ** 2 + y ** 2 + z ** 2, x ** 2, y ** 2 - z ** 2]
+        for i, d in enumerate(MOMENT_SYMBOLS[:dim]):
+            moments[moments.index(d**2)] = diagonal_viscous_moments[i]
         orthogonal_moments = gram_schmidt(moments, stencil, weights)
         orthogonal_moments_scaled = [e * common_denominator(e) for e in orthogonal_moments]
         nested_moments = list(sort_moments_into_groups_of_same_order(orthogonal_moments_scaled).values())
diff --git a/lbmpy/moments.py b/lbmpy/moments.py
index 793b9220..dd111808 100644
--- a/lbmpy/moments.py
+++ b/lbmpy/moments.py
@@ -425,7 +425,7 @@ def get_default_moment_set_for_stencil(stencil):
 
     all27_moments = moments_up_to_component_order(2, dim=3)
     if have_same_entries(stencil, get_stencil("D3Q27")):
-        return to_poly(all27_moments)
+        return sorted(to_poly(all27_moments), key=moment_sort_key)
     if have_same_entries(stencil, get_stencil("D3Q19")):
         non_matched_moments = [(1, 2, 2), (1, 1, 2), (2, 2, 2), (1, 1, 1)]
         moments19 = set(all27_moments) - set(extend_moments_with_permutations(non_matched_moments))
diff --git a/lbmpy_tests/test_momentbased_methods_equilibrium.py b/lbmpy_tests/test_momentbased_methods_equilibrium.py
index 5aae948b..1e4b8cdb 100644
--- a/lbmpy_tests/test_momentbased_methods_equilibrium.py
+++ b/lbmpy_tests/test_momentbased_methods_equilibrium.py
@@ -8,6 +8,7 @@ import sympy as sp
 from lbmpy.creationfunctions import create_lb_method
 from lbmpy.maxwellian_equilibrium import discrete_maxwellian_equilibrium
 from lbmpy.methods import create_mrt_orthogonal, create_srt, create_trt, mrt_orthogonal_modes_literature
+from lbmpy.moments import is_bulk_moment, is_shear_moment
 from lbmpy.relaxationrates import get_shear_relaxation_rate
 from lbmpy.stencils import get_stencil
 
@@ -81,14 +82,36 @@ def test_mrt_orthogonal():
     m = create_mrt_orthogonal(get_stencil("D3Q19"), maxwellian_moments=True, weighted=True)
     assert m.is_weighted_orthogonal
 
+    m_ref = {}
+
     moments = mrt_orthogonal_modes_literature(get_stencil("D3Q15"), True, False)
     m = create_mrt_orthogonal(get_stencil("D3Q15"), maxwellian_moments=True, nested_moments=moments)
     assert m.is_weighted_orthogonal
+    m_ref[("D3Q15", True)] = m
 
     moments = mrt_orthogonal_modes_literature(get_stencil("D3Q19"), True, False)
     m = create_mrt_orthogonal(get_stencil("D3Q19"), maxwellian_moments=True, nested_moments=moments)
     assert m.is_weighted_orthogonal
+    m_ref[("D3Q19", True)] = m
 
     moments = mrt_orthogonal_modes_literature(get_stencil("D3Q27"), False, False)
     m = create_mrt_orthogonal(get_stencil("D3Q27"), maxwellian_moments=True, nested_moments=moments)
     assert m.is_orthogonal
+    m_ref[("D3Q27", False)] = m
+
+    for weighted in [True, False]:
+        for stencil in ["D2Q9", "D3Q15", "D3Q19", "D3Q27"]:
+            m = create_mrt_orthogonal(get_stencil(stencil), maxwellian_moments=True, weighted=weighted)
+            bulk_moments = set([mom for mom in m.moments if is_bulk_moment(mom, m.dim)])
+            shear_moments = set([mom for mom in m.moments if is_shear_moment(mom, m.dim)])
+            assert len(bulk_moments) == 1
+            assert len(shear_moments) == 1 + (m.dim - 2) + m.dim * (m.dim - 1) / 2
+
+            if (stencil, weighted) in m_ref:
+                ref = m_ref[(stencil, weighted)]
+                bulk_moments_lit = set([mom for mom in ref.moments if is_bulk_moment(mom, ref.dim)])
+                shear_moments_lit = set([mom for mom in ref.moments if is_shear_moment(mom, ref.dim)])
+
+                if stencil != "D3Q27":  # this one uses a different linear combination in literature
+                    assert shear_moments == shear_moments_lit
+                assert bulk_moments == bulk_moments_lit
-- 
GitLab