diff --git a/tests/lbm/codegen/FluctuatingMRT.py b/tests/lbm/codegen/FluctuatingMRT.py index 66185e171ec1542a64f11bcb32c9ec7bd573429c..eed48db71b1ae8d56c71382dd8d40fe23d050116 100644 --- a/tests/lbm/codegen/FluctuatingMRT.py +++ b/tests/lbm/codegen/FluctuatingMRT.py @@ -13,6 +13,9 @@ with CodeGeneration() as ctx: force_field, vel_field = ps.fields("force(3), velocity(3): [3D]", layout='fzyx') def rr_getter(moment_group): + """Maps a group of moments to a relaxation rate (shear, bulk, even, odd) + in the 4 relaxation time thermalized LB model + """ is_shear = [is_shear_moment(m, 3) for m in moment_group] is_bulk = [is_bulk_moment(m, 3) for m in moment_group] order = [get_order(m) for m in moment_group] @@ -20,24 +23,24 @@ with CodeGeneration() as ctx: order = order[0] if order < 2: - return 0.0 + return [0] * len(moment_group) elif any(is_bulk): assert all(is_bulk) - return sp.Symbol("omega_bulk") + return [sp.Symbol("omega_bulk")] * len(moment_group) elif any(is_shear): assert all(is_shear) - return sp.Symbol("omega_shear") + return [sp.Symbol("omega_shear")] * len(moment_group) elif order % 2 == 0: assert order > 2 - return sp.Symbol("omega_even") + return [sp.Symbol("omega_even")] * len(moment_group) else: - return sp.Symbol("omega_odd") + return [sp.Symbol("omega_odd")] * len(moment_group) method = create_mrt_orthogonal( stencil=LBStencil(Stencil.D3Q19), compressible=True, weighted=True, - relaxation_rate_getter=rr_getter, + relaxation_rates=rr_getter, force_model=Guo(force_field.center_vector) )