diff --git a/tests/lbm/codegen/FluctuatingMRT.py b/tests/lbm/codegen/FluctuatingMRT.py
index 66185e171ec1542a64f11bcb32c9ec7bd573429c..eed48db71b1ae8d56c71382dd8d40fe23d050116 100644
--- a/tests/lbm/codegen/FluctuatingMRT.py
+++ b/tests/lbm/codegen/FluctuatingMRT.py
@@ -13,6 +13,9 @@ with CodeGeneration() as ctx:
     force_field, vel_field = ps.fields("force(3), velocity(3): [3D]", layout='fzyx')
 
     def rr_getter(moment_group):
+        """Maps a group of moments to a relaxation rate (shear, bulk, even, odd)
+        in the 4 relaxation time thermalized LB model
+        """
         is_shear = [is_shear_moment(m, 3) for m in moment_group]
         is_bulk = [is_bulk_moment(m, 3) for m in moment_group]
         order = [get_order(m) for m in moment_group]
@@ -20,24 +23,24 @@ with CodeGeneration() as ctx:
         order = order[0]
 
         if order < 2:
-            return 0.0
+            return [0] * len(moment_group)
         elif any(is_bulk):
             assert all(is_bulk)
-            return sp.Symbol("omega_bulk")
+            return [sp.Symbol("omega_bulk")] * len(moment_group)
         elif any(is_shear):
             assert all(is_shear)
-            return sp.Symbol("omega_shear")
+            return [sp.Symbol("omega_shear")] * len(moment_group)
         elif order % 2 == 0:
             assert order > 2
-            return sp.Symbol("omega_even")
+            return [sp.Symbol("omega_even")] * len(moment_group)
         else:
-            return sp.Symbol("omega_odd")
+            return [sp.Symbol("omega_odd")] * len(moment_group)
 
     method = create_mrt_orthogonal(
         stencil=LBStencil(Stencil.D3Q19),
         compressible=True,
         weighted=True,
-        relaxation_rate_getter=rr_getter,
+        relaxation_rates=rr_getter,
         force_model=Guo(force_field.center_vector)
     )