Commit 6f4189b0 authored by Frederik Hennig's avatar Frederik Hennig
Browse files

Unified moment transform API.

parent f68761dc
......@@ -408,7 +408,7 @@ def create_lb_method(**params):
no_force_model = 'force_model' not in params or params['force_model'] == 'none' or params['force_model'] is None
if not force_is_zero and no_force_model:
params['force_model'] = 'cumulant' if method_name.lower().endswith('cumulant') else 'schiller'
params['force_model'] = 'cumulant' if method_name.lower().endswith('cumulant') else 'guo'
if 'force_model' in params:
force_model = force_model_from_string(params['force_model'], params['force'][:dim])
......
......@@ -371,13 +371,13 @@ class CenteredCumulantBasedLbMethod(AbstractLbMethod):
k_to_c_eqs = cached_forward_transform(k_to_c_transform, simplification=pre_simplification)
c_post_to_k_post_eqs = cached_backward_transform(
k_to_c_transform, simplification=pre_simplification, omit_conserved_moments=True)
central_moments = exponents_to_polynomial_representations(k_to_c_transform.required_central_moments)
central_moments = k_to_c_transform.required_central_moments
assert len(central_moments) == len(stencil), 'Number of required central moments must match stencil size.'
# 3) Get Forward Transformation from PDFs to central moments
pdfs_to_k_transform = self._central_moment_transform_class(
stencil, central_moments, density, velocity, conserved_quantity_equations=cqe)
pdfs_to_k_eqs = cached_forward_transform(pdfs_to_k_transform, f, simplification=pre_simplification)
stencil, None, density, velocity, moment_exponents=central_moments, conserved_quantity_equations=cqe)
pdfs_to_k_eqs = cached_forward_transform(pdfs_to_k_transform, f, simplification=pre_simplification, return_monomials=True)
# 4) Add relaxation rules for lower order moments
lower_order_moments = moments_up_to_order(1, dim=self.dim)
......@@ -412,7 +412,7 @@ class CenteredCumulantBasedLbMethod(AbstractLbMethod):
# 6) Get backward transformation from central moments to PDFs
d = self.post_collision_pdf_symbols
k_post_to_pdfs_eqs = cached_backward_transform(pdfs_to_k_transform, d, simplification=pre_simplification)
k_post_to_pdfs_eqs = cached_backward_transform(pdfs_to_k_transform, d, simplification=pre_simplification, start_from_monomials=True)
# 7) That's all. Now, put it all together.
all_acs = [] if pdfs_to_k_transform.absorbs_conserved_quantity_equations else [cqe]
......
......@@ -13,13 +13,12 @@ from lbmpy.moments import (MOMENT_SYMBOLS, moment_matrix, set_up_shift_matrix,
statistical_quantity_symbol)
def relax_central_moments(moment_indices, pre_collision_values,
def relax_central_moments(pre_collision_symbols, post_collision_symbols,
relaxation_rates, equilibrium_values,
force_terms,
post_collision_base=POST_COLLISION_MONOMIAL_CENTRAL_MOMENT):
post_collision_symbols = [sp.Symbol(f'{post_collision_base}_{"".join(str(i) for i in m)}') for m in moment_indices]
equilibrium_vec = sp.Matrix(equilibrium_values)
moment_vec = sp.Matrix(pre_collision_values)
moment_vec = sp.Matrix(pre_collision_symbols)
relaxation_matrix = sp.diag(*relaxation_rates)
moment_vec = moment_vec + relaxation_matrix * (equilibrium_vec - moment_vec) + force_terms
main_assignments = [Assignment(s, eq) for s, eq in zip(post_collision_symbols, moment_vec)]
......@@ -268,11 +267,9 @@ class CentralMomentBasedLbMethod(AbstractLbMethod):
stencil, self.moments, density, velocity, conserved_quantity_equations=cqe)
pdfs_to_c_eqs = pdfs_to_c_transform.forward_transform(f, simplification=pre_simplification)
moments_as_exponents = pdfs_to_c_transform.moment_exponents
# 2) Collision
moment_symbols = [statistical_quantity_symbol(PRE_COLLISION_MONOMIAL_CENTRAL_MOMENT, exp)
for exp in moments_as_exponents]
k_pre = pdfs_to_c_transform.pre_collision_symbols
k_post = pdfs_to_c_transform.post_collision_symbols
relaxation_infos = [moment_to_relaxation_info_dict[m] for m in self.moments]
relaxation_rates = [info.relaxation_rate for info in relaxation_infos]
......@@ -283,9 +280,8 @@ class CentralMomentBasedLbMethod(AbstractLbMethod):
else:
force_model_terms = sp.Matrix([0] * len(stencil))
collision_eqs = relax_central_moments(moments_as_exponents, tuple(moment_symbols),
tuple(relaxation_rates), tuple(equilibrium_value),
force_terms=force_model_terms)
collision_eqs = relax_central_moments(k_pre, k_post, tuple(relaxation_rates),
tuple(equilibrium_value), force_terms=force_model_terms)
# 3) Get backward transformation from central moments to PDFs
post_collision_values = self.post_collision_pdf_symbols
......
......@@ -255,8 +255,8 @@ class MomentBasedLbMethod(AbstractLbMethod):
pdf_to_m_transform = self._moment_transform_class(self.stencil, moment_polynomials, rho, u,
conserved_quantity_equations=cqe)
m_pre = pdf_to_m_transform.pre_collision_moment_symbols
m_post = pdf_to_m_transform.post_collision_moment_symbols
m_pre = pdf_to_m_transform.pre_collision_symbols
m_post = pdf_to_m_transform.post_collision_symbols
pdf_to_m_eqs = pdf_to_m_transform.forward_transform(self.pre_collision_pdf_symbols,
simplification=pre_simplification)
......
from .abstractmomenttransform import (
PRE_COLLISION_MONOMIAL_RAW_MOMENT, POST_COLLISION_MONOMIAL_RAW_MOMENT,
PRE_COLLISION_MOMENT, POST_COLLISION_MOMENT,
PRE_COLLISION_RAW_MOMENT, POST_COLLISION_RAW_MOMENT,
PRE_COLLISION_MONOMIAL_CENTRAL_MOMENT, POST_COLLISION_MONOMIAL_CENTRAL_MOMENT
)
from .abstractmomenttransform import AbstractMomentTransform
from .momenttransforms import (
from .rawmomenttransforms import (
PdfsToMomentsByMatrixTransform, PdfsToMomentsByChimeraTransform
)
......@@ -23,6 +23,6 @@ __all__ = [
"PdfsToCentralMomentsByShiftMatrix",
"FastCentralMomentTransform",
"PRE_COLLISION_MONOMIAL_RAW_MOMENT", "POST_COLLISION_MONOMIAL_RAW_MOMENT",
"PRE_COLLISION_MOMENT", "POST_COLLISION_MOMENT",
"PRE_COLLISION_RAW_MOMENT", "POST_COLLISION_RAW_MOMENT",
"PRE_COLLISION_MONOMIAL_CENTRAL_MOMENT", "POST_COLLISION_MONOMIAL_CENTRAL_MOMENT"
]
from abc import abstractmethod
import sympy as sp
from pystencils.simp import (SimplificationStrategy, sympy_cse)
from lbmpy.moments import (
exponents_to_polynomial_representations, extract_monomials, exponent_tuple_sort_key)
from lbmpy.moments import statistical_quantity_symbol as sq_sym
PRE_COLLISION_MONOMIAL_RAW_MOMENT = 'm'
POST_COLLISION_MONOMIAL_RAW_MOMENT = 'm_post'
PRE_COLLISION_MOMENT = 'M'
POST_COLLISION_MOMENT = 'M_post'
PRE_COLLISION_RAW_MOMENT = 'M'
POST_COLLISION_RAW_MOMENT = 'M_post'
PRE_COLLISION_MONOMIAL_CENTRAL_MOMENT = 'kappa'
POST_COLLISION_MONOMIAL_CENTRAL_MOMENT = 'kappa_post'
PRE_COLLISION_CENTRAL_MOMENT = 'K'
POST_COLLISION_CENTRAL_MOMENT = 'K_post'
class AbstractMomentTransform:
r"""Abstract Base Class for classes providing transformations between moment spaces.
......@@ -69,7 +74,10 @@ class AbstractMomentTransform:
moment_exponents=None,
moment_polynomials=None,
conserved_quantity_equations=None,
**kwargs):
pre_collision_symbol_base=None,
post_collision_symbol_base=None,
pre_collision_monomial_symbol_base=None,
post_collision_monomial_symbol_base=None):
"""Abstract Base Class constructor.
Args:
......@@ -102,6 +110,38 @@ class AbstractMomentTransform:
self.equilibrium_density = equilibrium_density
self.equilibrium_velocity = equilibrium_velocity
self.base_pre = pre_collision_symbol_base
self.base_post = post_collision_symbol_base
self.mono_base_pre = pre_collision_monomial_symbol_base
self.mono_base_post = post_collision_monomial_symbol_base
@property
def pre_collision_symbols(self):
"""List of symbols corresponding to the pre-collision quantities
that will be the left-hand sides of assignments returned by :func:`forward_transform`."""
return sp.symbols(f'{self.base_pre}_:{self.q}')
@property
def post_collision_symbols(self):
"""List of symbols corresponding to the post-collision quantities
that are input to the right-hand sides of assignments returned by:func:`backward_transform`."""
return sp.symbols(f'{self.base_post}_:{self.q}')
@property
def pre_collision_monomial_symbols(self):
"""List of symbols corresponding to the pre-collision monomial quantities
that might exist as left-hand sides of subexpressions in the assignment collection
returned by :func:`forward_transform`."""
return tuple(sq_sym(self.mono_base_pre, e) for e in self.moment_exponents)
@property
def post_collision_monomial_symbols(self):
"""List of symbols corresponding to the post-collision monomial quantities
that might exist as left-hand sides of subexpressions in the assignment collection
returned by :func:`backward_transform`."""
return tuple(sq_sym(self.mono_base_post, e) for e in self.moment_exponents)
@abstractmethod
def forward_transform(self, *args, **kwargs):
"""Implemented in a subclass, will return the forward transform equations."""
......
......@@ -11,12 +11,37 @@ from lbmpy.moments import statistical_quantity_symbol as sq_sym
from .abstractmomenttransform import (
AbstractMomentTransform,
PRE_COLLISION_MOMENT, POST_COLLISION_MOMENT,
PRE_COLLISION_RAW_MOMENT, POST_COLLISION_RAW_MOMENT,
PRE_COLLISION_MONOMIAL_RAW_MOMENT, POST_COLLISION_MONOMIAL_RAW_MOMENT
)
class PdfsToMomentsByMatrixTransform(AbstractMomentTransform):
class AbstractRawMomentTransform(AbstractMomentTransform):
"""Abstract base class for all transformations between population space
and raw-moment space."""
def __init__(self, stencil, moment_polynomials,
equilibrium_density,
equilibrium_velocity,
pre_collision_symbol_base=PRE_COLLISION_RAW_MOMENT,
post_collision_symbol_base=POST_COLLISION_RAW_MOMENT,
pre_collision_monomial_symbol_base=PRE_COLLISION_MONOMIAL_RAW_MOMENT,
post_collision_monomial_symbol_base=POST_COLLISION_MONOMIAL_RAW_MOMENT,
**kwargs):
super(AbstractRawMomentTransform, self).__init__(
stencil, equilibrium_density, equilibrium_velocity,
moment_polynomials=moment_polynomials,
pre_collision_symbol_base=pre_collision_symbol_base,
post_collision_symbol_base=post_collision_symbol_base,
pre_collision_monomial_symbol_base=pre_collision_monomial_symbol_base,
post_collision_monomial_symbol_base=post_collision_monomial_symbol_base,
**kwargs
)
# end class AbstractRawMomentTransform
class PdfsToMomentsByMatrixTransform(AbstractRawMomentTransform):
"""Transform between populations and moment space spanned by a polynomial
basis, using matrix-vector multiplication."""
......@@ -24,19 +49,12 @@ class PdfsToMomentsByMatrixTransform(AbstractMomentTransform):
equilibrium_density,
equilibrium_velocity,
conserved_quantity_equations=None,
pre_collision_moment_base=PRE_COLLISION_MOMENT,
post_collision_moment_base=POST_COLLISION_MOMENT,
**kwargs):
assert len(moment_polynomials) == len(stencil), 'Number of moments must match stencil'
super(PdfsToMomentsByMatrixTransform, self).__init__(
stencil, equilibrium_density, equilibrium_velocity,
conserved_quantity_equations=conserved_quantity_equations,
moment_polynomials=moment_polynomials,
**kwargs)
self.m_pre = pre_collision_moment_base
self.m_post = post_collision_moment_base
stencil, moment_polynomials, equilibrium_density, equilibrium_velocity,
conserved_quantity_equations=conserved_quantity_equations, **kwargs)
self.moment_matrix = moment_matrix(self.moment_polynomials, stencil)
self.inv_moment_matrix = self.moment_matrix.inv()
......@@ -45,19 +63,8 @@ class PdfsToMomentsByMatrixTransform(AbstractMomentTransform):
def absorbs_conserved_quantity_equations(self):
return False
@property
def pre_collision_moment_symbols(self):
"""List of symbols corresponding to the pre-collision moments
that will be the left-hand sides of assignments returned by :func:`forward_transform`."""
return sp.symbols(f'{self.m_pre}_:{self.q}')
@property
def post_collision_moment_symbols(self):
"""List of symbols corresponding to the post-collision moments
that are input to the right-hand sides of assignments returned by:func:`backward_transform`."""
return sp.symbols(f'{self.m_post}_:{self.q}')
def forward_transform(self, pdf_symbols, simplification=True, subexpression_base='sub_f_to_M'):
def forward_transform(self, pdf_symbols, simplification=True, subexpression_base='sub_f_to_M',
return_monomials=False):
r"""Returns an assignment collection containing equations for pre-collision polynomial
moments, expressed in terms of the pre-collision populations by matrix-multiplication.
......@@ -72,8 +79,14 @@ class PdfsToMomentsByMatrixTransform(AbstractMomentTransform):
"""
simplification = self._get_simp_strategy(simplification, 'forward')
f_to_m_vec = self.moment_matrix * sp.Matrix(pdf_symbols)
pre_collision_moments = self.pre_collision_moment_symbols
if return_monomials:
mm = moment_matrix(self.moment_exponents, self.stencil)
pre_collision_moments = self.pre_collision_monomial_symbols
else:
mm = self.moment_matrix
pre_collision_moments = self.pre_collision_symbols
f_to_m_vec = mm * sp.Matrix(pdf_symbols)
main_assignments = [Assignment(m, eq) for m, eq in zip(pre_collision_moments, f_to_m_vec)]
symbol_gen = SymbolGen(symbol=subexpression_base)
......@@ -83,7 +96,8 @@ class PdfsToMomentsByMatrixTransform(AbstractMomentTransform):
ac = simplification.apply(ac)
return ac
def backward_transform(self, pdf_symbols, simplification=True, subexpression_base='sub_k_to_f'):
def backward_transform(self, pdf_symbols, simplification=True, subexpression_base='sub_k_to_f',
start_from_monomials=False):
r"""Returns an assignment collection containing equations for post-collision populations,
expressed in terms of the post-collision polynomial moments by matrix-multiplication.
......@@ -112,12 +126,19 @@ class PdfsToMomentsByMatrixTransform(AbstractMomentTransform):
"""
simplification = self._get_simp_strategy(simplification, 'backward')
post_collision_moments = self.post_collision_moment_symbols
m_to_f_vec = self.inv_moment_matrix * sp.Matrix(post_collision_moments)
if start_from_monomials:
mm_inv = moment_matrix(self.moment_exponents, self.stencil).inv()
post_collision_moments = self.post_collision_monomial_symbols
else:
mm_inv = self.inv_moment_matrix
post_collision_moments = self.post_collision_symbols
m_to_f_vec = mm_inv * sp.Matrix(post_collision_moments)
main_assignments = [Assignment(f, eq) for f, eq in zip(pdf_symbols, m_to_f_vec)]
symbol_gen = SymbolGen(subexpression_base)
ac = AssignmentCollection(main_assignments, subexpression_symbol_generator=symbol_gen)
ac.add_simplification_hint('stencil', self.stencil)
ac.add_simplification_hint('post_collision_pdf_symbols', pdf_symbols)
if simplification:
......@@ -146,7 +167,7 @@ class PdfsToMomentsByMatrixTransform(AbstractMomentTransform):
# end class PdfsToMomentsByMatrixTransform
class PdfsToMomentsByChimeraTransform(AbstractMomentTransform):
class PdfsToMomentsByChimeraTransform(AbstractRawMomentTransform):
"""Transform between populations and moment space spanned by a polynomial
basis, using the raw-moment chimera transform in the forward direction and
matrix-vector multiplication in the backward direction."""
......@@ -155,10 +176,6 @@ class PdfsToMomentsByChimeraTransform(AbstractMomentTransform):
equilibrium_density,
equilibrium_velocity,
conserved_quantity_equations=None,
pre_collision_moment_base=PRE_COLLISION_MOMENT,
post_collision_moment_base=POST_COLLISION_MOMENT,
pre_collision_raw_moment_base=PRE_COLLISION_MONOMIAL_RAW_MOMENT,
post_collision_raw_moment_base=POST_COLLISION_MONOMIAL_RAW_MOMENT,
**kwargs):
if moment_polynomials:
......@@ -166,18 +183,11 @@ class PdfsToMomentsByChimeraTransform(AbstractMomentTransform):
moment_polynomials = non_aliased_polynomial_moments(moment_polynomials, stencil)
super(PdfsToMomentsByChimeraTransform, self).__init__(
stencil, equilibrium_density, equilibrium_velocity,
conserved_quantity_equations=conserved_quantity_equations,
moment_polynomials=moment_polynomials,
**kwargs)
stencil, moment_polynomials, equilibrium_density, equilibrium_velocity,
conserved_quantity_equations=conserved_quantity_equations, **kwargs)
assert len(self.moment_polynomials) == len(stencil), 'Number of moments must match stencil'
self.m_pre = pre_collision_moment_base
self.m_post = post_collision_moment_base
self.rm_pre = pre_collision_raw_moment_base
self.rm_post = post_collision_raw_moment_base
self.inv_moment_matrix = moment_matrix(self.moment_exponents, self.stencil).inv()
self.mono_to_poly_matrix = monomial_to_polynomial_transformation_matrix(self.moment_exponents,
self.moment_polynomials)
......@@ -187,32 +197,6 @@ class PdfsToMomentsByChimeraTransform(AbstractMomentTransform):
def absorbs_conserved_quantity_equations(self):
return True
@property
def pre_collision_moment_symbols(self):
"""List of symbols corresponding to the pre-collision moments
that will be the left-hand sides of assignments returned by :func:`forward_transform`."""
return sp.symbols(f'{self.m_pre}_:{self.q}')
@property
def post_collision_moment_symbols(self):
"""List of symbols corresponding to the post-collision moments
that are input to the right-hand sides of assignments returned by:func:`backward_transform`."""
return sp.symbols(f'{self.m_post}_:{self.q}')
@property
def pre_collision_raw_moment_symbols(self):
"""List of symbols corresponding to the pre-collision raw (monomial) moments
that exist as left-hand sides of subexpressions in the assignment collection
returned by :func:`forward_transform`."""
return tuple(sq_sym(self.rm_pre, e) for e in self.moment_exponents)
@property
def post_collision_raw_moment_symbols(self):
"""List of symbols corresponding to the post-collision raw (monomial) moments
that exist as left-hand sides of subexpressions in the assignment collection
returned by :func:`backward_transform`."""
return tuple(sq_sym(self.rm_post, e) for e in self.moment_exponents)
def get_cq_to_moment_symbols_dict(self, moment_symbol_base):
"""Returns a dictionary mapping the density and velocity symbols to the correspondig
zeroth- and first-order raw moment symbols"""
......@@ -232,7 +216,7 @@ class PdfsToMomentsByChimeraTransform(AbstractMomentTransform):
def forward_transform(self, pdf_symbols, simplification=True,
subexpression_base='sub_f_to_m',
return_raw_moments=False):
return_monomials=False):
r"""Returns an assignment collection containing equations for pre-collision polynomial
moments, expressed in terms of the pre-collision populations, using the raw moment chimera transform.
......@@ -267,15 +251,15 @@ class PdfsToMomentsByChimeraTransform(AbstractMomentTransform):
"""
simplification = self._get_simp_strategy(simplification, 'forward')
raw_moment_symbol_base = self.rm_pre
monomial_symbol_base = self.mono_base_pre
def _partial_kappa_symbol(fixed_directions, remaining_exponents):
fixed_str = '_'.join(str(direction) for direction in fixed_directions).replace('-', 'm')
exp_str = '_'.join(str(exp) for exp in remaining_exponents).replace('-', 'm')
return sp.Symbol(f"partial_{raw_moment_symbol_base}_{fixed_str}_e_{exp_str}")
return sp.Symbol(f"partial_{monomial_symbol_base}_{fixed_str}_e_{exp_str}")
partial_sums_dict = dict()
raw_moment_eqs = []
monomial_eqs = []
def collect_partial_sums(exponents, dimension=0, fixed_directions=tuple()):
if dimension == self.dim:
......@@ -293,8 +277,8 @@ class PdfsToMomentsByChimeraTransform(AbstractMomentTransform):
summation += next_partial * d ** exponents[dimension]
if dimension == 0:
lhs_symbol = sq_sym(raw_moment_symbol_base, exponents)
raw_moment_eqs.append(Assignment(lhs_symbol, summation))
lhs_symbol = sq_sym(monomial_symbol_base, exponents)
monomial_eqs.append(Assignment(lhs_symbol, summation))
else:
lhs_symbol = _partial_kappa_symbol(fixed_directions, exponents[dimension:])
partial_sums_dict[lhs_symbol] = summation
......@@ -307,17 +291,17 @@ class PdfsToMomentsByChimeraTransform(AbstractMomentTransform):
subexpressions = self.cqe.subexpressions.copy() if self.cqe is not None else []
subexpressions += [Assignment(lhs, rhs) for lhs, rhs in partial_sums_dict.items()]
if return_raw_moments:
main_assignments += raw_moment_eqs
if return_monomials:
main_assignments += monomial_eqs
else:
subexpressions += raw_moment_eqs
moment_eqs = self.mono_to_poly_matrix * sp.Matrix(self.pre_collision_raw_moment_symbols)
main_assignments += [Assignment(m, v) for m, v in zip(self.pre_collision_moment_symbols, moment_eqs)]
subexpressions += monomial_eqs
moment_eqs = self.mono_to_poly_matrix * sp.Matrix(self.pre_collision_monomial_symbols)
main_assignments += [Assignment(m, v) for m, v in zip(self.pre_collision_symbols, moment_eqs)]
symbol_gen = SymbolGen(subexpression_base)
ac = AssignmentCollection(main_assignments, subexpressions=subexpressions,
subexpression_symbol_generator=symbol_gen)
ac.add_simplification_hint('cq_symbols_to_moments', self.get_cq_to_moment_symbols_dict(raw_moment_symbol_base))
ac.add_simplification_hint('cq_symbols_to_moments', self.get_cq_to_moment_symbols_dict(monomial_symbol_base))
if simplification:
ac = simplification.apply(ac)
......@@ -325,7 +309,7 @@ class PdfsToMomentsByChimeraTransform(AbstractMomentTransform):
def backward_transform(self, pdf_symbols, simplification=True,
subexpression_base='sub_k_to_f',
start_from_raw_moments=False):
start_from_monomials=False):
r"""Returns an assignment collection containing equations for post-collision populations,
expressed in terms of the post-collision polynomial moments by matrix-multiplication.
......@@ -352,20 +336,21 @@ class PdfsToMomentsByChimeraTransform(AbstractMomentTransform):
pdf_symbols: List of symbols that represent the post-collision populations
simplification: Simplification specification. See :class:`AbstractMomentTransform`
subexpression_base: The base name used for any subexpressions of the transformation.
start_from_raw_moments: If set to True the equations are not converted to monomials
start_from_monomials: If ``True``, the generated equations expect monomial moment symbols
as input.
"""
simplification = self._get_simp_strategy(simplification, 'backward')
post_collision_moments = self.post_collision_moment_symbols
post_collision_raw_moments = self.post_collision_raw_moment_symbols
post_collision_moments = self.post_collision_symbols
post_collision_monomial_moments = self.post_collision_monomial_symbols
subexpressions = []
if not start_from_raw_moments:
if not start_from_monomials:
raw_moment_eqs = self.poly_to_mono_matrix * sp.Matrix(post_collision_moments)
subexpressions += [Assignment(rm, v) for rm, v in zip(post_collision_raw_moments, raw_moment_eqs)]
subexpressions += [Assignment(rm, v) for rm, v in zip(post_collision_monomial_moments, raw_moment_eqs)]
rm_to_f_vec = self.inv_moment_matrix * sp.Matrix(post_collision_raw_moments)
rm_to_f_vec = self.inv_moment_matrix * sp.Matrix(post_collision_monomial_moments)
main_assignments = [Assignment(f, eq) for f, eq in zip(pdf_symbols, rm_to_f_vec)]
symbol_gen = SymbolGen(subexpression_base)
......
......@@ -34,11 +34,9 @@ def test_forward_transform(type, stencil):
assert shift_transform.moment_exponents == fast_transform.moment_exponents
if type == 'monomial' and not have_same_entries(stencil, get_stencil('D3Q15')):
assert matrix_transform.mono_to_poly_matrix == sp.eye(q)
assert fast_transform.mono_to_poly_matrix == sp.eye(q)
assert shift_transform.mono_to_poly_matrix == sp.eye(q)
else:
assert not matrix_transform.mono_to_poly_matrix == sp.eye(q)
assert not fast_transform.mono_to_poly_matrix == sp.eye(q)
assert not shift_transform.mono_to_poly_matrix == sp.eye(q)
......@@ -53,8 +51,9 @@ def test_forward_transform(type, stencil):
f_to_k_shift = shift_transform.forward_transform(pdfs, simplification=False)
f_to_k_shift = f_to_k_shift.new_without_subexpressions().main_assignments_dict
for e in moment_exponents:
moment_symbol = statistical_quantity_symbol(PRE_COLLISION_MONOMIAL_CENTRAL_MOMENT, e)
cm_symbols = matrix_transform.pre_collision_symbols
for moment_symbol in cm_symbols:
rhs_matrix = f_to_k_matrix[moment_symbol].expand()
rhs_fast = f_to_k_fast[moment_symbol].expand()
rhs_shift = f_to_k_shift[moment_symbol].expand()
......@@ -83,15 +82,6 @@ def test_backward_transform(type, stencil):
assert matrix_transform.moment_exponents == fast_transform.moment_exponents
assert shift_transform.moment_exponents == fast_transform.moment_exponents
if type == 'monomial' and not have_same_entries(stencil, get_stencil('D3Q15')):
assert matrix_transform.mono_to_poly_matrix == sp.eye(q)
assert fast_transform.mono_to_poly_matrix == sp.eye(q)
assert shift_transform.mono_to_poly_matrix == sp.eye(q)
else:
assert not matrix_transform.mono_to_poly_matrix == sp.eye(q)
assert not fast_transform.mono_to_poly_matrix == sp.eye(q)
assert not shift_transform.mono_to_poly_matrix == sp.eye(q)
k_to_f_matrix = matrix_transform.backward_transform(pdfs)
k_to_f_matrix = k_to_f_matrix.new_without_subexpressions().main_assignments_dict
......
......@@ -82,7 +82,7 @@ def get_fluctuating_lb(size=None, kT=None, omega_shear=None, omega_bulk=None, om
compressible=True,
weighted=True,
relaxation_rate_getter=rr_getter,
force_model=force_model_from_string('schiller', force_field.center_vector))
force_model=force_model_from_string('guo', force_field.center_vector))
collision_rule = create_lb_collision_rule(
method,
fluctuating={
......
......@@ -7,7 +7,7 @@ def test_gpu_block_size_limiting():
too_large = 2048*2048
opt = {'target': 'gpu', 'gpu_indexing_params': {'block_size': (too_large, too_large, too_large)}}
ast = create_lb_ast(method='cumulant', stencil='D3Q19', relaxation_rate=1.8, optimization=opt,
compressible=True, force_model='schiller')
compressible=True, force_model='guo')
limited_block_size = ast.indexing.call_parameters((1024, 1024, 1024))
kernel = ast.compile()
assert all(b < too_large for b in limited_block_size['block'])
......
import pytest
import sympy as sp
from lbmpy.stencils import get_stencil
from lbmpy.moments import get_default_moment_set_for_stencil
from lbmpy.moment_transforms import (
PdfsToMomentsByMatrixTransform, PdfsToMomentsByChimeraTransform,
PdfsToCentralMomentsByShiftMatrix, PdfsToCentralMomentsByMatrix, FastCentralMomentTransform
)
transforms = [
PdfsToMomentsByMatrixTransform, PdfsToMomentsByChimeraTransform,
PdfsToCentralMomentsByShiftMatrix, PdfsToCentralMomentsByMatrix, FastCentralMomentTransform
]
@pytest.mark.parametrize('stencil', ['D2Q9'])
@pytest.mark.parametrize('transform_class', transforms)
def test_monomial_equations(stencil, transform_class):
stencil = get_stencil(stencil)
rho = sp.symbols("rho")
u = sp.symbols(f"u_:{len(stencil[0])}")
moment_polynomials = get_default_moment_set_for_stencil(stencil)
transform = transform_class(stencil, moment_polynomials, rho, u)
pdfs = sp.symbols(f"f_:{len(stencil)}")
fw_eqs = transform.forward_transform(pdfs, return_monomials=True)
bw_eqs = transform.backward_transform(pdfs, start_from_monomials=True)
mono_symbols_pre = set(transform.pre_collision_monomial_symbols)
mono_symbols_post = set(transform.post_collision_monomial_symbols)
assert mono_symbols_pre <= set(fw_eqs.defined_symbols)
assert mono_symbols_post <= set(bw_eqs.free_symbols)