From a2f07b9b4594525a9a6f8a923082154b0d9811c6 Mon Sep 17 00:00:00 2001
From: Michael Kuron <mkuron@icp.uni-stuttgart.de>
Date: Thu, 28 Nov 2019 13:56:15 +0100
Subject: [PATCH] lbmpy.moments.is_{bulk,shear}_moment

shear previously ignored the diagonal moments and bulk didn't exist
---
 lbmpy/methods/creationfunctions.py |  2 +-
 lbmpy/moments.py                   | 39 ++++++++++++++++++++++++------
 lbmpy/relaxationrates.py           |  6 ++---
 lbmpy_tests/test_moments.py        | 27 +++++++++++++++++++++
 4 files changed, 62 insertions(+), 12 deletions(-)

diff --git a/lbmpy/methods/creationfunctions.py b/lbmpy/methods/creationfunctions.py
index 755a68cb..3fc7ab08 100644
--- a/lbmpy/methods/creationfunctions.py
+++ b/lbmpy/methods/creationfunctions.py
@@ -382,7 +382,7 @@ def create_mrt_orthogonal(stencil, relaxation_rate_getter=None, maxwellian_momen
                the default modes is performed.
     """
     if relaxation_rate_getter is None:
-        relaxation_rate_getter = default_relaxation_rate_names()
+        relaxation_rate_getter = default_relaxation_rate_names(dim=len(stencil[0]))
     if isinstance(stencil, str):
         stencil = get_stencil(stencil)
 
diff --git a/lbmpy/moments.py b/lbmpy/moments.py
index f6cf1897..793b9220 100644
--- a/lbmpy/moments.py
+++ b/lbmpy/moments.py
@@ -273,14 +273,37 @@ def non_aliased_moment(moment_tuple: Sequence[int]) -> Tuple[int, ...]:
     return tuple(result)
 
 
-def is_shear_moment(moment):
-    """Shear moments in 3D are: x*y, x*z and y*z - in 2D its only x*y"""
-    if type(moment) is tuple:
-        moment = exponent_to_polynomial_representation(moment)
-    return moment in is_shear_moment.shear_moments
-
-
-is_shear_moment.shear_moments = set([c[0] * c[1] for c in itertools.combinations(MOMENT_SYMBOLS, 2)])
+def is_bulk_moment(moment, dim):
+    """The bulk moment is x**2+y**2+z**2"""
+    if type(moment) is not tuple:
+        moment = polynomial_to_exponent_representation(moment)
+    quadratic = False
+    found = [0 for _ in range(dim)]
+    for prefactor, monomial in moment:
+        if sum(monomial) == 2:
+            quadratic = True
+            for i, exponent in enumerate(monomial):
+                if exponent == 2:
+                    found[i] += prefactor
+        elif sum(monomial) > 2:
+            return False
+    return quadratic and found != [0] * dim and len(set(found)) == 1
+
+
+def is_shear_moment(moment, dim):
+    """Shear moments are the quadratic polynomials except for the bulk moment.
+       Linear combinations with lower-order polynomials don't harm because these correspond to conserved moments."""
+    if is_bulk_moment(moment, dim):
+        return False
+    if type(moment) is not tuple:
+        moment = polynomial_to_exponent_representation(moment)
+    quadratic = False
+    for prefactor, monomial in moment:
+        if sum(monomial) == 2:
+            quadratic = True
+        elif sum(monomial) > 2:
+            return False
+    return quadratic
 
 
 @memorycache(maxsize=512)
diff --git a/lbmpy/relaxationrates.py b/lbmpy/relaxationrates.py
index 2e365162..d6baa278 100644
--- a/lbmpy/relaxationrates.py
+++ b/lbmpy/relaxationrates.py
@@ -31,7 +31,7 @@ def get_shear_relaxation_rate(method):
 
     relaxation_rates = set()
     for moment, relax_info in method.relaxation_info_dict.items():
-        if is_shear_moment(moment):
+        if is_shear_moment(moment, method.dim):
             relaxation_rates.add(relax_info.relaxation_rate)
     if len(relaxation_rates) == 1:
         return relaxation_rates.pop()
@@ -59,14 +59,14 @@ def relaxation_rate_scaling(omega, level_scale_factor):
     return omega / (omega / 2 + level_scale_factor * (1 - omega / 2))
 
 
-def default_relaxation_rate_names():
+def default_relaxation_rate_names(dim):
     next_index = [0]
 
     def result(moment_list):
         shear_moment_inside = False
         all_conserved_moments = True
         for m in moment_list:
-            if is_shear_moment(m):
+            if is_shear_moment(m, dim):
                 shear_moment_inside = True
             if not (get_order(m) == 0 or get_order(m) == 1):
                 all_conserved_moments = False
diff --git a/lbmpy_tests/test_moments.py b/lbmpy_tests/test_moments.py
index 79d96ef4..f6ae3cf2 100644
--- a/lbmpy_tests/test_moments.py
+++ b/lbmpy_tests/test_moments.py
@@ -73,3 +73,30 @@ def test_gram_schmidt_orthogonalization():
     orthogonal_moments = gram_schmidt(moments, stencil)
     pdfs_to_moments = moment_matrix(orthogonal_moments, stencil)
     assert (pdfs_to_moments * pdfs_to_moments.T).is_diagonal()
+
+
+def test_is_bulk_moment():
+    x, y, z = MOMENT_SYMBOLS
+    assert not is_bulk_moment(x, 2)
+    assert not is_bulk_moment(x ** 3, 2)
+    assert not is_bulk_moment(x * y, 2)
+    assert not is_bulk_moment(x ** 2, 2)
+    assert not is_bulk_moment(x ** 2 + y ** 2, 3)
+    assert is_bulk_moment(x ** 2 + y ** 2, 2)
+    assert is_bulk_moment(x ** 2 + y ** 2 + z ** 2, 3)
+    assert is_bulk_moment(x ** 2 + y ** 2 + x, 2)
+    assert is_bulk_moment(x ** 2 + y ** 2 + 1, 2)
+
+
+def test_is_shear_moment():
+    x, y, z = MOMENT_SYMBOLS
+    assert not is_shear_moment(x ** 3, 2)
+    assert not is_shear_moment(x, 2)
+    assert not is_shear_moment(x ** 2 + y ** 2, 2)
+    assert not is_shear_moment(x ** 2 + y ** 2 + z ** 2, 3)
+    assert is_shear_moment(x ** 2, 2)
+    assert is_shear_moment(x ** 2 - 1, 2)
+    assert is_shear_moment(x ** 2 - x, 2)
+    assert is_shear_moment(x * y, 2)
+    assert is_shear_moment(x * y - 1, 2)
+    assert is_shear_moment(x * y - x, 2)
-- 
GitLab