diff --git a/lbmpy/moment_transforms/centralmomenttransforms.py b/lbmpy/moment_transforms/centralmomenttransforms.py index 3d6ec0f485ad2dd1f808983494a031a0dde57302..700445dd9a6afba5d7aa7c37cbaa90e7b4295f9f 100644 --- a/lbmpy/moment_transforms/centralmomenttransforms.py +++ b/lbmpy/moment_transforms/centralmomenttransforms.py @@ -268,7 +268,7 @@ class FastCentralMomentTransform(AbstractCentralMomentTransform): ac = AssignmentCollection(main_assignments, subexpressions=subexpressions, subexpression_symbol_generator=symbol_gen) if simplification: - ac = self._simplify_lower_order_moments(ac, monomial_symbol_base) + ac = self._simplify_lower_order_moments(ac, monomial_symbol_base, return_monomials) ac = simplification.apply(ac) return ac @@ -335,14 +335,19 @@ class FastCentralMomentTransform(AbstractCentralMomentTransform): 'backward': backward_simp } - def _simplify_lower_order_moments(self, ac, moment_base): + def _simplify_lower_order_moments(self, ac, moment_base, search_in_main_assignments): if self.cqe is None: return ac - f_to_cm_dict = ac.main_assignments_dict - f_to_cm_dict_reduced = ac.new_without_subexpressions().main_assignments_dict - moment_symbols = [sq_sym(moment_base, e) for e in moments_up_to_order(1, dim=self.dim)] + + if search_in_main_assignments: + f_to_cm_dict = ac.main_assignments_dict + f_to_cm_dict_reduced = ac.new_without_subexpressions().main_assignments_dict + else: + f_to_cm_dict = ac.subexpressions_dict + f_to_cm_dict_reduced = ac.new_without_subexpressions(moment_symbols).subexpressions_dict + cqe_subs = self.cqe.new_without_subexpressions().main_assignments_dict for m in moment_symbols: m_eq = fast_subs(fast_subs(f_to_cm_dict_reduced[m], cqe_subs), cqe_subs) @@ -351,8 +356,12 @@ class FastCentralMomentTransform(AbstractCentralMomentTransform): m_eq = subs_additive(m_eq, cqe_sym, cqe_exp) f_to_cm_dict[m] = m_eq - main_assignments = [Assignment(lhs, rhs) for lhs, rhs in f_to_cm_dict.items()] - return ac.copy(main_assignments=main_assignments) + if search_in_main_assignments: + main_assignments = [Assignment(lhs, rhs) for lhs, rhs in f_to_cm_dict.items()] + return ac.copy(main_assignments=main_assignments) + else: + subexpressions = [Assignment(lhs, rhs) for lhs, rhs in f_to_cm_dict.items()] + return ac.copy(subexpressions=subexpressions) def _split_backward_equations_recursive(self, assignment, all_subexpressions, stencil_direction, subexp_symgen, known_coeffs_dict, diff --git a/lbmpy/simplificationfactory.py b/lbmpy/simplificationfactory.py index cbc58565bef3f5cc51921899afe85623bbbcec8c..ad89fe60b99faf946e3d36171412dbfed800a14f 100644 --- a/lbmpy/simplificationfactory.py +++ b/lbmpy/simplificationfactory.py @@ -2,6 +2,7 @@ import sympy as sp from lbmpy.innerloopsplit import create_lbm_split_groups from lbmpy.methods.momentbased.momentbasedmethod import MomentBasedLbMethod +from lbmpy.methods.momentbased.centralmomentbasedmethod import CentralMomentBasedLbMethod from lbmpy.methods.centeredcumulant import CenteredCumulantBasedLbMethod from lbmpy.methods.momentbased.momentbasedsimplifications import ( factor_density_after_factoring_relaxation_times, factor_relaxation_rates, @@ -22,6 +23,8 @@ def create_simplification_strategy(lb_method, split_inner_loop=False): else: # General MRT methods with population-space collision return _mrt_population_space_simplification(split_inner_loop) + elif isinstance(lb_method, CentralMomentBasedLbMethod): + return _moment_space_simplification(split_inner_loop) elif isinstance(lb_method, CenteredCumulantBasedLbMethod): return _moment_space_simplification(split_inner_loop) else: