Commit 81c6c262 authored by Martin Bauer's avatar Martin Bauer
Browse files

lbmpy: Simplifications change

- also for default MRT methods its better to use symbolic RR's instead of inserting numeric values for them
- but then the simplification strategy developed for SRT/TRT has to be disabled
- this commit introduces two new options:
     - simplification='auto' or a custom SimplificationStrategy
     - 'keep_rrs_symbolic'=True, by default now RRs are left symbolic
parent 16c96cf6
...@@ -198,6 +198,7 @@ from pystencils import Assignment, AssignmentCollection, create_kernel ...@@ -198,6 +198,7 @@ from pystencils import Assignment, AssignmentCollection, create_kernel
from pystencils.cache import disk_cache_no_fallback from pystencils.cache import disk_cache_no_fallback
from pystencils.data_types import collate_types from pystencils.data_types import collate_types
from pystencils.field import Field, get_layout_of_array from pystencils.field import Field, get_layout_of_array
from pystencils.simp import sympy_cse
from pystencils.stencil import have_same_entries from pystencils.stencil import have_same_entries
...@@ -300,8 +301,6 @@ def create_lb_collision_rule(lb_method=None, optimization={}, **kwargs): ...@@ -300,8 +301,6 @@ def create_lb_collision_rule(lb_method=None, optimization={}, **kwargs):
lb_method = create_lb_method(**params) lb_method = create_lb_method(**params)
split_inner_loop = 'split' in opt_params and opt_params['split'] split_inner_loop = 'split' in opt_params and opt_params['split']
simplification = create_simplification_strategy(lb_method, cse_pdfs=False, cse_global=False,
split_inner_loop=split_inner_loop)
cqc = lb_method.conserved_quantity_computation cqc = lb_method.conserved_quantity_computation
rho_in = params['density_input'] rho_in = params['density_input']
...@@ -312,17 +311,23 @@ def create_lb_collision_rule(lb_method=None, optimization={}, **kwargs): ...@@ -312,17 +311,23 @@ def create_lb_collision_rule(lb_method=None, optimization={}, **kwargs):
if rho_in is not None and isinstance(rho_in, Field): if rho_in is not None and isinstance(rho_in, Field):
rho_in = rho_in.center rho_in = rho_in.center
keep_rrs_symbolic = opt_params['keep_rrs_symbolic']
if u_in is not None: if u_in is not None:
density_rhs = sum(lb_method.pre_collision_pdf_symbols) if rho_in is None else rho_in density_rhs = sum(lb_method.pre_collision_pdf_symbols) if rho_in is None else rho_in
eqs = [Assignment(cqc.zeroth_order_moment_symbol, density_rhs)] eqs = [Assignment(cqc.zeroth_order_moment_symbol, density_rhs)]
eqs += [Assignment(u_sym, u_in[i]) for i, u_sym in enumerate(cqc.first_order_moment_symbols)] eqs += [Assignment(u_sym, u_in[i]) for i, u_sym in enumerate(cqc.first_order_moment_symbols)]
eqs = AssignmentCollection(eqs, []) eqs = AssignmentCollection(eqs, [])
collision_rule = lb_method.get_collision_rule(conserved_quantity_equations=eqs) collision_rule = lb_method.get_collision_rule(conserved_quantity_equations=eqs,
keep_rrs_symbolic=keep_rrs_symbolic)
elif u_in is None and rho_in is not None: elif u_in is None and rho_in is not None:
raise ValueError("When setting 'density_input' parameter, 'velocity_input' has to be specified as well.") raise ValueError("When setting 'density_input' parameter, 'velocity_input' has to be specified as well.")
else: else:
collision_rule = lb_method.get_collision_rule() collision_rule = lb_method.get_collision_rule(keep_rrs_symbolic=keep_rrs_symbolic)
if opt_params['simplification'] == 'auto':
simplification = create_simplification_strategy(lb_method, split_inner_loop=split_inner_loop)
else:
simplification = opt_params['simplification']
collision_rule = simplification(collision_rule) collision_rule = simplification(collision_rule)
if params['fluctuating']: if params['fluctuating']:
...@@ -353,7 +358,6 @@ def create_lb_collision_rule(lb_method=None, optimization={}, **kwargs): ...@@ -353,7 +358,6 @@ def create_lb_collision_rule(lb_method=None, optimization={}, **kwargs):
from lbmpy.methods.momentbasedsimplifications import cse_in_opposing_directions from lbmpy.methods.momentbasedsimplifications import cse_in_opposing_directions
collision_rule = cse_in_opposing_directions(collision_rule) collision_rule = cse_in_opposing_directions(collision_rule)
if cse_global: if cse_global:
from pystencils.simp import sympy_cse
collision_rule = sympy_cse(collision_rule) collision_rule = sympy_cse(collision_rule)
if params['output'] and params['kernel_type'] == 'stream_pull_collide': if params['output'] and params['kernel_type'] == 'stream_pull_collide':
...@@ -550,6 +554,8 @@ def update_with_default_parameters(params, opt_params=None, fail_on_unknown_para ...@@ -550,6 +554,8 @@ def update_with_default_parameters(params, opt_params=None, fail_on_unknown_para
default_optimization_description = { default_optimization_description = {
'cse_pdfs': False, 'cse_pdfs': False,
'cse_global': False, 'cse_global': False,
'simplification': 'auto',
'keep_rrs_symbolic': True,
'split': False, 'split': False,
'field_size': None, 'field_size': None,
......
...@@ -88,9 +88,9 @@ class MomentBasedLbMethod(AbstractLbMethod): ...@@ -88,9 +88,9 @@ class MomentBasedLbMethod(AbstractLbMethod):
equilibrium = self.get_equilibrium() equilibrium = self.get_equilibrium()
return sp.Matrix([eq.rhs for eq in equilibrium.main_assignments]) return sp.Matrix([eq.rhs for eq in equilibrium.main_assignments])
def get_collision_rule(self, conserved_quantity_equations=None): def get_collision_rule(self, conserved_quantity_equations=None, keep_rrs_symbolic=True):
d = sp.diag(*self.relaxation_rates) d = sp.diag(*self.relaxation_rates)
relaxation_rate_sub_expressions, d = self._generate_relaxation_matrix(d) relaxation_rate_sub_expressions, d = self._generate_relaxation_matrix(d, keep_rrs_symbolic)
ac = self._collision_rule_with_relaxation_matrix(d, relaxation_rate_sub_expressions, ac = self._collision_rule_with_relaxation_matrix(d, relaxation_rate_sub_expressions,
True, conserved_quantity_equations) True, conserved_quantity_equations)
return ac return ac
...@@ -216,16 +216,15 @@ class MomentBasedLbMethod(AbstractLbMethod): ...@@ -216,16 +216,15 @@ class MomentBasedLbMethod(AbstractLbMethod):
simplification_hints) simplification_hints)
@staticmethod @staticmethod
def _generate_relaxation_matrix(relaxation_matrix): def _generate_relaxation_matrix(relaxation_matrix, keep_rr_symbolic):
""" """
For SRT and TRT the equations can be easier simplified if the relaxation times are symbols, not numbers. For SRT and TRT the equations can be easier simplified if the relaxation times are symbols, not numbers.
This function replaces the numbers in the relaxation matrix with symbols in this case, and returns also This function replaces the numbers in the relaxation matrix with symbols in this case, and returns also
the subexpressions, that assign the number to the newly introduced symbol the subexpressions, that assign the number to the newly introduced symbol
""" """
rr = [relaxation_matrix[i, i] for i in range(relaxation_matrix.rows)] rr = [relaxation_matrix[i, i] for i in range(relaxation_matrix.rows)]
unique_relaxation_rates = set(rr) if keep_rr_symbolic <= 2:
if len(unique_relaxation_rates) <= 2: unique_relaxation_rates = set(rr)
# special handling for SRT and TRT
subexpressions = {} subexpressions = {}
for rt in unique_relaxation_rates: for rt in unique_relaxation_rates:
rt = sp.sympify(rt) rt = sp.sympify(rt)
......
...@@ -2,42 +2,39 @@ import sympy as sp ...@@ -2,42 +2,39 @@ import sympy as sp
from lbmpy.innerloopsplit import create_lbm_split_groups from lbmpy.innerloopsplit import create_lbm_split_groups
from lbmpy.methods.cumulantbased import CumulantBasedLbMethod from lbmpy.methods.cumulantbased import CumulantBasedLbMethod
from lbmpy.methods.momentbased import MomentBasedLbMethod
from lbmpy.methods.momentbasedsimplifications import (
factor_density_after_factoring_relaxation_times, factor_relaxation_rates,
replace_common_quadratic_and_constant_term, replace_density_and_velocity, replace_second_order_velocity_products)
from pystencils.simp import ( from pystencils.simp import (
add_subexpressions_for_divisions, apply_to_all_assignments, SimplificationStrategy, add_subexpressions_for_divisions, apply_to_all_assignments,
subexpression_substitution_in_main_assignments, sympy_cse) subexpression_substitution_in_main_assignments)
def create_simplification_strategy(lb_method, cse_pdfs=False, cse_global=False, split_inner_loop=False): def create_simplification_strategy(lb_method, split_inner_loop=False):
from pystencils.simp import SimplificationStrategy
from lbmpy.methods.momentbased import MomentBasedLbMethod
from lbmpy.methods.momentbasedsimplifications import replace_second_order_velocity_products, \
factor_density_after_factoring_relaxation_times, factor_relaxation_rates, cse_in_opposing_directions, \
replace_common_quadratic_and_constant_term, replace_density_and_velocity
s = SimplificationStrategy() s = SimplificationStrategy()
expand = apply_to_all_assignments(sp.expand) expand = apply_to_all_assignments(sp.expand)
if isinstance(lb_method, MomentBasedLbMethod): if isinstance(lb_method, MomentBasedLbMethod):
s.add(expand) if len(set(lb_method.relaxation_rates)) <= 2:
s.add(replace_second_order_velocity_products) s.add(expand)
s.add(expand) s.add(replace_second_order_velocity_products)
s.add(factor_relaxation_rates) s.add(expand)
s.add(replace_density_and_velocity) s.add(factor_relaxation_rates)
s.add(replace_common_quadratic_and_constant_term) s.add(replace_density_and_velocity)
s.add(factor_density_after_factoring_relaxation_times) s.add(replace_common_quadratic_and_constant_term)
s.add(subexpression_substitution_in_main_assignments) s.add(factor_density_after_factoring_relaxation_times)
if split_inner_loop: s.add(subexpression_substitution_in_main_assignments)
s.add(create_lbm_split_groups) if split_inner_loop:
s.add(create_lbm_split_groups)
s.add(add_subexpressions_for_divisions)
else:
s.add(subexpression_substitution_in_main_assignments)
if split_inner_loop:
s.add(create_lbm_split_groups)
elif isinstance(lb_method, CumulantBasedLbMethod): elif isinstance(lb_method, CumulantBasedLbMethod):
s.add(expand) s.add(expand)
s.add(factor_relaxation_rates) s.add(factor_relaxation_rates)
s.add(add_subexpressions_for_divisions)
s.add(add_subexpressions_for_divisions)
if cse_pdfs:
s.add(cse_in_opposing_directions)
if cse_global:
s.add(sympy_cse)
return s return s
...@@ -6,14 +6,15 @@ import sympy as sp ...@@ -6,14 +6,15 @@ import sympy as sp
from lbmpy.forcemodels import Guo from lbmpy.forcemodels import Guo
from lbmpy.methods import create_srt, create_trt, create_trt_with_magic_number from lbmpy.methods import create_srt, create_trt, create_trt_with_magic_number
from lbmpy.methods.momentbasedsimplifications import cse_in_opposing_directions
from lbmpy.simplificationfactory import create_simplification_strategy from lbmpy.simplificationfactory import create_simplification_strategy
from lbmpy.stencils import get_stencil from lbmpy.stencils import get_stencil
def check_method(method, limits_default, limits_cse): def check_method(method, limits_default, limits_cse):
strategy = create_simplification_strategy(method, cse_pdfs=False) strategy = create_simplification_strategy(method)
strategy_with_cse = create_simplification_strategy(method, cse_pdfs=True) strategy_with_cse = create_simplification_strategy(method)
strategy_with_cse = cse_in_opposing_directions(strategy_with_cse)
collision_rule = method.get_collision_rule() collision_rule = method.get_collision_rule()
ops_default = strategy(collision_rule).operation_count ops_default = strategy(collision_rule).operation_count
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment