Commit b13958d5 by Frederik Hennig

### Extraction of factor 1/2

parent 37443525
Pipeline #28569 waiting for manual action with stage
in 67 minutes and 47 seconds
 ... ... @@ -208,7 +208,9 @@ class CenteredCumulantBasedLbMethod(AbstractLbMethod): def get_collision_rule(self, pre_simplification=False): """Returns an LbmCollisionRule i.e. an equation collection with a reference to the method. This collision rule defines the collision operator.""" return self._centered_cumulant_collision_rule(self._cumulant_to_relaxation_info_dict, None, pre_simplification, True) ac = self._centered_cumulant_collision_rule(self._cumulant_to_relaxation_info_dict, None, pre_simplification, True) ac = ac.new_without_unused_subexpressions() return LbmCollisionRule(self, ac.main_assignments, ac.subexpressions) # ------------------------------- Internals -------------------------------------------- ... ...
 ... ... @@ -91,6 +91,23 @@ class PdfsToCentralMomentsByMatrix(AbstractCentralMomentTransform): # end class PdfsToCentralMomentsByMatrix class ExtractOneHalf: """ Pseudo-Simplification to instruct the FastCentralMomentTransform to extract the factor 1/2 to a subexpression, to hide it from sympy. Otherwise, sympy will distribute it across the sums, producing unnecessary multiplications. """ def __init__(self, one_half_proxy=sp.Symbol('half')): self._symbol = one_half_proxy @property def symbol(self): return self._symbol def __call__(self, ac): return ac class FastCentralMomentTransform(AbstractCentralMomentTransform): def __init__(self, stencil, moment_exponents, shift_velocity): ... ... @@ -149,13 +166,18 @@ class FastCentralMomentTransform(AbstractCentralMomentTransform): simplification=True, subexpression_base='sub_k_to_f'): if simplification and not isinstance(simplification, SimplificationStrategy): simplification = self._default_simplification simplification.add(ExtractOneHalf()) raw_equations = self.mat_transform.backward_transform( pdf_symbols, moment_symbol_base=POST_COLLISION_CENTRAL_MOMENT, simplification=False) raw_equations = raw_equations.new_without_subexpressions() symbol_gen = SymbolGen(subexpression_base) ac = self._split_backward_equations(raw_equations, symbol_gen) if simplification: extract_one_half = next(filter(lambda x: isinstance(x, ExtractOneHalf), simplification.rules)) else: extract_one_half = None ac = self._split_backward_equations(raw_equations, symbol_gen, extract_one_half=extract_one_half) if simplification: ac = simplification.apply(ac) ... ... @@ -172,7 +194,8 @@ class FastCentralMomentTransform(AbstractCentralMomentTransform): return simplification def _split_backward_equations_recursive(self, assignment, all_subexpressions, stencil_direction, subexp_symgen, known_coeffs_dict, step=0): stencil_direction, subexp_symgen, known_coeffs_dict, one_half, step=0): # Base Case if step == self.dim: return assignment ... ... @@ -181,7 +204,8 @@ class FastCentralMomentTransform(AbstractCentralMomentTransform): u = self.shift_velocity[-1 - step] d = stencil_direction[-1 - step] one = sp.sympify(1) one = sp.Integer(1) two = sp.Integer(2) # Factors to group terms by grouping_factors = { ... ... @@ -192,7 +216,7 @@ class FastCentralMomentTransform(AbstractCentralMomentTransform): factors = grouping_factors[d] # Common Integer factor to extract from all groups common_factor = one if d == 0 else sp.Integer(2) common_factor = one if d == 0 else two # Proxy for factor grouping v = sp.Symbol('v') ... ... @@ -224,23 +248,30 @@ class FastCentralMomentTransform(AbstractCentralMomentTransform): # Recursively split the coefficient term coeff_assignment = self._split_backward_equations_recursive( coeff_assignment, all_subexpressions, stencil_direction, subexp_symgen, known_coeffs_dict, step=step + 1) known_coeffs_dict, one_half, step=step + 1) all_subexpressions.append(coeff_assignment) new_rhs += factors[k] * coeff_symb if common_factor != one: new_rhs = sp.Mul(sp.Rational(1, common_factor), new_rhs, evaluate=False) if common_factor == two: # new_rhs = sp.Mul(sp.Rational(1, common_factor), new_rhs, evaluate=False) new_rhs = one_half * new_rhs return Assignment(assignment.lhs, new_rhs) def _split_backward_equations(self, backward_assignments, subexp_symgen): all_subexpressions = [] def _split_backward_equations(self, backward_assignments, subexp_symgen, extract_one_half=None): if extract_one_half is not None: one_half_proxy = extract_one_half.symbol all_subexpressions = [Assignment(one_half_proxy, sp.Rational(1,2))] else: one_half_proxy = sp.Rational(1,2) all_subexpressions = [] split_main_assignments = [] known_coeffs_dict = dict() for asm, stencil_dir in zip(backward_assignments, self.stencil): split_asm = self._split_backward_equations_recursive( asm, all_subexpressions, stencil_dir, subexp_symgen, known_coeffs_dict) asm, all_subexpressions, stencil_dir, subexp_symgen, known_coeffs_dict, one_half_proxy) split_main_assignments.append(split_asm) ac = AssignmentCollection(split_main_assignments, subexpressions=all_subexpressions, subexpression_symbol_generator=subexp_symgen) ... ...
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!