diff --git a/lbmpy/moment_transforms/centralmomenttransforms.py b/lbmpy/moment_transforms/centralmomenttransforms.py
index 3d6ec0f485ad2dd1f808983494a031a0dde57302..700445dd9a6afba5d7aa7c37cbaa90e7b4295f9f 100644
--- a/lbmpy/moment_transforms/centralmomenttransforms.py
+++ b/lbmpy/moment_transforms/centralmomenttransforms.py
@@ -268,7 +268,7 @@ class FastCentralMomentTransform(AbstractCentralMomentTransform):
         ac = AssignmentCollection(main_assignments, subexpressions=subexpressions,
                                   subexpression_symbol_generator=symbol_gen)
         if simplification:
-            ac = self._simplify_lower_order_moments(ac, monomial_symbol_base)
+            ac = self._simplify_lower_order_moments(ac, monomial_symbol_base, return_monomials)
             ac = simplification.apply(ac)
         return ac
 
@@ -335,14 +335,19 @@ class FastCentralMomentTransform(AbstractCentralMomentTransform):
             'backward': backward_simp
         }
 
-    def _simplify_lower_order_moments(self, ac, moment_base):
+    def _simplify_lower_order_moments(self, ac, moment_base, search_in_main_assignments):
         if self.cqe is None:
             return ac
 
-        f_to_cm_dict = ac.main_assignments_dict
-        f_to_cm_dict_reduced = ac.new_without_subexpressions().main_assignments_dict
-
         moment_symbols = [sq_sym(moment_base, e) for e in moments_up_to_order(1, dim=self.dim)]
+
+        if search_in_main_assignments:
+            f_to_cm_dict = ac.main_assignments_dict
+            f_to_cm_dict_reduced = ac.new_without_subexpressions().main_assignments_dict
+        else:
+            f_to_cm_dict = ac.subexpressions_dict
+            f_to_cm_dict_reduced = ac.new_without_subexpressions(moment_symbols).subexpressions_dict
+
         cqe_subs = self.cqe.new_without_subexpressions().main_assignments_dict
         for m in moment_symbols:
             m_eq = fast_subs(fast_subs(f_to_cm_dict_reduced[m], cqe_subs), cqe_subs)
@@ -351,8 +356,12 @@ class FastCentralMomentTransform(AbstractCentralMomentTransform):
                 m_eq = subs_additive(m_eq, cqe_sym, cqe_exp)
             f_to_cm_dict[m] = m_eq
 
-        main_assignments = [Assignment(lhs, rhs) for lhs, rhs in f_to_cm_dict.items()]
-        return ac.copy(main_assignments=main_assignments)
+        if search_in_main_assignments:
+            main_assignments = [Assignment(lhs, rhs) for lhs, rhs in f_to_cm_dict.items()]
+            return ac.copy(main_assignments=main_assignments)
+        else:
+            subexpressions = [Assignment(lhs, rhs) for lhs, rhs in f_to_cm_dict.items()]
+            return ac.copy(subexpressions=subexpressions)
 
     def _split_backward_equations_recursive(self, assignment, all_subexpressions,
                                             stencil_direction, subexp_symgen, known_coeffs_dict,
diff --git a/lbmpy/simplificationfactory.py b/lbmpy/simplificationfactory.py
index cbc58565bef3f5cc51921899afe85623bbbcec8c..ad89fe60b99faf946e3d36171412dbfed800a14f 100644
--- a/lbmpy/simplificationfactory.py
+++ b/lbmpy/simplificationfactory.py
@@ -2,6 +2,7 @@ import sympy as sp
 
 from lbmpy.innerloopsplit import create_lbm_split_groups
 from lbmpy.methods.momentbased.momentbasedmethod import MomentBasedLbMethod
+from lbmpy.methods.momentbased.centralmomentbasedmethod import CentralMomentBasedLbMethod
 from lbmpy.methods.centeredcumulant import CenteredCumulantBasedLbMethod
 from lbmpy.methods.momentbased.momentbasedsimplifications import (
     factor_density_after_factoring_relaxation_times, factor_relaxation_rates,
@@ -22,6 +23,8 @@ def create_simplification_strategy(lb_method, split_inner_loop=False):
             else:
                 # General MRT methods with population-space collision
                 return _mrt_population_space_simplification(split_inner_loop)
+    elif isinstance(lb_method, CentralMomentBasedLbMethod):
+        return _moment_space_simplification(split_inner_loop)
     elif isinstance(lb_method, CenteredCumulantBasedLbMethod):
         return _moment_space_simplification(split_inner_loop)
     else: