From ff68674d04c2f4ba0f90d94fabea7a8ab3d14477 Mon Sep 17 00:00:00 2001
From: Markus Holzer <markus.holzer@fau.de>
Date: Fri, 5 Nov 2021 12:43:38 +0000
Subject: [PATCH] Relaxation rates should be floats

---
 lbmpy/methods/abstractlbmethod.py  | 5 +++++
 lbmpy/methods/creationfunctions.py | 6 +++---
 2 files changed, 8 insertions(+), 3 deletions(-)

diff --git a/lbmpy/methods/abstractlbmethod.py b/lbmpy/methods/abstractlbmethod.py
index aa6e6cc9..07b46b74 100644
--- a/lbmpy/methods/abstractlbmethod.py
+++ b/lbmpy/methods/abstractlbmethod.py
@@ -2,6 +2,7 @@ import abc
 from collections import namedtuple
 
 import sympy as sp
+from sympy.core.numbers import Zero
 
 from pystencils import Assignment, AssignmentCollection
 
@@ -52,6 +53,7 @@ class AbstractLbMethod(abc.ABC):
         """Returns a qxq diagonal matrix which contains the relaxation rate for each moment on the diagonal"""
         d = sp.zeros(len(self.relaxation_rates))
         for i in range(0, len(self.relaxation_rates)):
+            # note that 0.0 is converted to sp.Zero here. It is not possible to prevent this.
             d[i, i] = self.relaxation_rates[i]
         return d
 
@@ -104,6 +106,9 @@ class AbstractLbMethod(abc.ABC):
         for relaxation_rate in rr:
             if relaxation_rate not in unique_relaxation_rates:
                 relaxation_rate = sp.sympify(relaxation_rate)
+                # special treatment for zero, sp.Zero would be an integer ..
+                if isinstance(relaxation_rate, Zero):
+                    relaxation_rate = 0.0
                 if not isinstance(relaxation_rate, sp.Symbol):
                     rt_symbol = sp.Symbol(f"rr_{len(subexpressions)}")
                     subexpressions[relaxation_rate] = rt_symbol
diff --git a/lbmpy/methods/creationfunctions.py b/lbmpy/methods/creationfunctions.py
index 32b23624..f05bef51 100644
--- a/lbmpy/methods/creationfunctions.py
+++ b/lbmpy/methods/creationfunctions.py
@@ -605,11 +605,11 @@ def _get_relaxation_info_dict(relaxation_rates, nested_moments, dim):
         for group in nested_moments:
             for moment in group:
                 if get_order(moment) <= 1:
-                    result[moment] = 0
+                    result[moment] = 0.0
                 elif is_shear_moment(moment, dim):
                     result[moment] = relaxation_rates[0]
                 else:
-                    result[moment] = 1
+                    result[moment] = 1.0
 
     # if relaxation rate for each moment is specified they are all used
     if len(relaxation_rates) == number_of_moments:
@@ -634,7 +634,7 @@ def _get_relaxation_info_dict(relaxation_rates, nested_moments, dim):
                 next_rr = False
                 for moment in group:
                     if get_order(moment) <= 1:
-                        result[moment] = 0
+                        result[moment] = 0.0
                     elif is_shear_moment(moment, dim):
                         result[moment] = shear_rr
                     elif is_bulk_moment(moment, dim):
-- 
GitLab