Commit b13958d5 authored by Frederik Hennig's avatar Frederik Hennig
Browse files

Extraction of factor 1/2

parent 37443525
Pipeline #28569 waiting for manual action with stage
in 67 minutes and 47 seconds
......@@ -208,7 +208,9 @@ class CenteredCumulantBasedLbMethod(AbstractLbMethod):
def get_collision_rule(self, pre_simplification=False):
"""Returns an LbmCollisionRule i.e. an equation collection with a reference to the method.
This collision rule defines the collision operator."""
return self._centered_cumulant_collision_rule(self._cumulant_to_relaxation_info_dict, None, pre_simplification, True)
ac = self._centered_cumulant_collision_rule(self._cumulant_to_relaxation_info_dict, None, pre_simplification, True)
ac = ac.new_without_unused_subexpressions()
return LbmCollisionRule(self, ac.main_assignments, ac.subexpressions)
# ------------------------------- Internals --------------------------------------------
......
......@@ -91,6 +91,23 @@ class PdfsToCentralMomentsByMatrix(AbstractCentralMomentTransform):
# end class PdfsToCentralMomentsByMatrix
class ExtractOneHalf:
"""
Pseudo-Simplification to instruct the FastCentralMomentTransform to extract
the factor 1/2 to a subexpression, to hide it from sympy. Otherwise, sympy will
distribute it across the sums, producing unnecessary multiplications.
"""
def __init__(self, one_half_proxy=sp.Symbol('half')):
self._symbol = one_half_proxy
@property
def symbol(self):
return self._symbol
def __call__(self, ac):
return ac
class FastCentralMomentTransform(AbstractCentralMomentTransform):
def __init__(self, stencil, moment_exponents, shift_velocity):
......@@ -149,13 +166,18 @@ class FastCentralMomentTransform(AbstractCentralMomentTransform):
simplification=True, subexpression_base='sub_k_to_f'):
if simplification and not isinstance(simplification, SimplificationStrategy):
simplification = self._default_simplification
simplification.add(ExtractOneHalf())
raw_equations = self.mat_transform.backward_transform(
pdf_symbols, moment_symbol_base=POST_COLLISION_CENTRAL_MOMENT, simplification=False)
raw_equations = raw_equations.new_without_subexpressions()
symbol_gen = SymbolGen(subexpression_base)
ac = self._split_backward_equations(raw_equations, symbol_gen)
if simplification:
extract_one_half = next(filter(lambda x: isinstance(x, ExtractOneHalf), simplification.rules))
else:
extract_one_half = None
ac = self._split_backward_equations(raw_equations, symbol_gen, extract_one_half=extract_one_half)
if simplification:
ac = simplification.apply(ac)
......@@ -172,7 +194,8 @@ class FastCentralMomentTransform(AbstractCentralMomentTransform):
return simplification
def _split_backward_equations_recursive(self, assignment, all_subexpressions,
stencil_direction, subexp_symgen, known_coeffs_dict, step=0):
stencil_direction, subexp_symgen, known_coeffs_dict,
one_half, step=0):
# Base Case
if step == self.dim:
return assignment
......@@ -181,7 +204,8 @@ class FastCentralMomentTransform(AbstractCentralMomentTransform):
u = self.shift_velocity[-1 - step]
d = stencil_direction[-1 - step]
one = sp.sympify(1)
one = sp.Integer(1)
two = sp.Integer(2)
# Factors to group terms by
grouping_factors = {
......@@ -192,7 +216,7 @@ class FastCentralMomentTransform(AbstractCentralMomentTransform):
factors = grouping_factors[d]
# Common Integer factor to extract from all groups
common_factor = one if d == 0 else sp.Integer(2)
common_factor = one if d == 0 else two
# Proxy for factor grouping
v = sp.Symbol('v')
......@@ -224,23 +248,30 @@ class FastCentralMomentTransform(AbstractCentralMomentTransform):
# Recursively split the coefficient term
coeff_assignment = self._split_backward_equations_recursive(
coeff_assignment, all_subexpressions, stencil_direction, subexp_symgen,
known_coeffs_dict, step=step + 1)
known_coeffs_dict, one_half, step=step + 1)
all_subexpressions.append(coeff_assignment)
new_rhs += factors[k] * coeff_symb
if common_factor != one:
new_rhs = sp.Mul(sp.Rational(1, common_factor), new_rhs, evaluate=False)
if common_factor == two:
# new_rhs = sp.Mul(sp.Rational(1, common_factor), new_rhs, evaluate=False)
new_rhs = one_half * new_rhs
return Assignment(assignment.lhs, new_rhs)
def _split_backward_equations(self, backward_assignments, subexp_symgen):
all_subexpressions = []
def _split_backward_equations(self, backward_assignments, subexp_symgen, extract_one_half=None):
if extract_one_half is not None:
one_half_proxy = extract_one_half.symbol
all_subexpressions = [Assignment(one_half_proxy, sp.Rational(1,2))]
else:
one_half_proxy = sp.Rational(1,2)
all_subexpressions = []
split_main_assignments = []
known_coeffs_dict = dict()
for asm, stencil_dir in zip(backward_assignments, self.stencil):
split_asm = self._split_backward_equations_recursive(
asm, all_subexpressions, stencil_dir, subexp_symgen, known_coeffs_dict)
asm, all_subexpressions, stencil_dir, subexp_symgen, known_coeffs_dict, one_half_proxy)
split_main_assignments.append(split_asm)
ac = AssignmentCollection(split_main_assignments, subexpressions=all_subexpressions,
subexpression_symbol_generator=subexp_symgen)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment