centeredcumulantmethod.py 20.7 KB
Newer Older
Frederik Hennig's avatar
Frederik Hennig committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
from pystencils.simp.simplifications import sympy_cse
import sympy as sp
from warnings import warn

from pystencils import Assignment, AssignmentCollection
from pystencils.simp.assignment_collection import SymbolGen
from pystencils.stencil import have_same_entries
from pystencils.cache import disk_cache

from lbmpy.stencils import get_stencil
from lbmpy.methods.abstractlbmethod import AbstractLbMethod, LbmCollisionRule, RelaxationInfo
from lbmpy.methods.conservedquantitycomputation import AbstractConservedQuantityComputation

from lbmpy.moments import (
    moments_up_to_order, get_order,
    monomial_to_polynomial_transformation_matrix,
    moment_sort_key, exponent_to_polynomial_representation, extract_monomials, MOMENT_SYMBOLS)

#   Local Imports

from lbmpy.methods.centeredcumulant.centered_cumulants import (
    statistical_quantity_symbol, exponent_tuple_sort_key)

from lbmpy.methods.centeredcumulant.cumulant_transform import (
    PRE_COLLISION_CUMULANT, POST_COLLISION_CUMULANT,
    CentralMomentsToCumulantsByGeneratingFunc)

from lbmpy.methods.momentbased.moment_transforms import (
    PRE_COLLISION_CENTRAL_MOMENT, POST_COLLISION_CENTRAL_MOMENT,
    FastCentralMomentTransform)

from lbmpy.methods.centeredcumulant.simplification import insert_aliases, insert_zeros

from lbmpy.methods.centeredcumulant.force_model import CenteredCumulantForceModel
from lbmpy.methods.centeredcumulant.galilean_correction import (
    contains_corrected_polynomials,
    add_galilean_correction,
    get_galilean_correction_terms)


#   ============================ Cached Transformations ================================================================

@disk_cache
def cached_forward_transform(transform_obj, *args, **kwargs):
    """Evaluate ``transform_obj.forward_transform`` through the on-disk cache.

    Every positional and keyword argument is handed to ``forward_transform`` unchanged;
    its result is returned as-is.
    """
    forward = transform_obj.forward_transform
    return forward(*args, **kwargs)


@disk_cache
def cached_backward_transform(transform_obj, *args, **kwargs):
    """Evaluate ``transform_obj.backward_transform`` through the on-disk cache.

    Every positional and keyword argument is handed to ``backward_transform`` unchanged;
    its result is returned as-is.
    """
    backward = transform_obj.backward_transform
    return backward(*args, **kwargs)


#   ============================ Lower Order Central Moment Collision ==================================================


@disk_cache
def relax_lower_order_central_moments(moment_indices, pre_collision_values,
                                      relaxation_rates, equilibrium_values,
                                      post_collision_base=POST_COLLISION_CENTRAL_MOMENT):
    """Build the collision step for the given (lower-order) central moments.

    Each pre-collision value ``k`` is relaxed towards its equilibrium ``k_eq`` as
    ``k + omega * (k_eq - k)``, where ``omega`` is taken from a diagonal relaxation
    matrix assembled from ``relaxation_rates``.

    Args:
        moment_indices: exponent tuples of the relaxed moments; used to name the
                        post-collision symbols.
        pre_collision_values: pre-collision value of each moment, in the order of ``moment_indices``.
        relaxation_rates: relaxation rate of each moment.
        equilibrium_values: equilibrium value of each moment.
        post_collision_base: base name passed to ``statistical_quantity_symbol`` for the
                             left-hand-side symbols.

    Returns:
        AssignmentCollection whose main assignments define the post-collision symbols.
    """
    lhs_symbols = [statistical_quantity_symbol(post_collision_base, idx) for idx in moment_indices]

    pre_vec = sp.Matrix(pre_collision_values)
    eq_vec = sp.Matrix(equilibrium_values)
    omega = sp.diag(*relaxation_rates)
    relaxed = pre_vec + omega * (eq_vec - pre_vec)

    return AssignmentCollection([Assignment(lhs, rhs) for lhs, rhs in zip(lhs_symbols, relaxed)])


#   ============================ Polynomial Cumulant Collision =========================================================

@disk_cache
def relax_polynomial_cumulants(monomial_exponents, polynomials, relaxation_rates, equilibrium_values,
                               pre_simplification,
                               galilean_correction_terms=None,
                               pre_collision_base=PRE_COLLISION_CUMULANT,
                               post_collision_base=POST_COLLISION_CUMULANT,
                               subexpression_base='sub_col'):
    """Build the collision step for higher-order polynomial cumulants.

    The pre-collision monomial cumulants are mapped to the requested polynomials,
    relaxed there (``p + omega * (p_eq - p)``), optionally corrected by the galilean
    correction, and mapped back to monomial cumulants with the inverse transformation.

    Args:
        monomial_exponents: exponent tuples of the monomial cumulants involved.
        polynomials: polynomial cumulants that are relaxed.
        relaxation_rates: relaxation rate of each polynomial.
        equilibrium_values: equilibrium value of each polynomial.
        pre_simplification: if ``'default_with_cse'``, ``sympy_cse`` is applied to the result.
        galilean_correction_terms: optional correction terms; if given, they are applied via
                                   ``add_galilean_correction`` and their assignments become
                                   the collection's subexpressions.
        pre_collision_base: base name of the pre-collision cumulant symbols.
        post_collision_base: base name of the post-collision cumulant symbols.
        subexpression_base: base name for subexpression symbols generated later.

    Returns:
        AssignmentCollection defining the post-collision monomial cumulants.
    """
    transform = monomial_to_polynomial_transformation_matrix(monomial_exponents, polynomials)
    pre_monomials = sp.Matrix([statistical_quantity_symbol(pre_collision_base, e) for e in monomial_exponents])
    eq_polynomials = sp.Matrix(equilibrium_values)
    omega = sp.diag(*relaxation_rates)

    pre_polynomials = transform * pre_monomials
    post_polynomials = pre_polynomials + omega * (eq_polynomials - pre_polynomials)

    if galilean_correction_terms is None:
        subexpressions = []
    else:
        post_polynomials = add_galilean_correction(post_polynomials, polynomials, galilean_correction_terms)
        subexpressions = galilean_correction_terms.all_assignments

    post_monomials = transform.inv() * post_polynomials

    main_assignments = [
        Assignment(statistical_quantity_symbol(post_collision_base, e), value)
        for e, value in zip(monomial_exponents, post_monomials)
    ]

    collection = AssignmentCollection(main_assignments, subexpressions=subexpressions,
                                      subexpression_symbol_generator=SymbolGen(subexpression_base))
    if pre_simplification != 'default_with_cse':
        return collection
    return sympy_cse(collection)


#   =============================== LB Method Implementation ===========================================================

class CenteredCumulantBasedLbMethod(AbstractLbMethod):
    """
        This class implements cumulant-based lattice Boltzmann methods which relax all the non-conserved quantities
        as either monomial or polynomial cumulants. It is mostly inspired by the work presented in :cite:`geier2015`.

        Conserved quantities are relaxed in central moment space. This method supports an implicit forcing scheme
        through :class:`lbmpy.methods.centeredcumulant.CenteredCumulantForceModel` where forces are applied by
        shifting the central-moment frame of reference by :math:`F/2` and then relaxing the first-order central
        moments with a relaxation rate of two. This corresponds to the change-of-sign described in the paper.
        Classical forcing schemes can still be applied.

        The galilean correction described in :cite:`geier2015` is also available for the D3Q27 lattice.

        This method is implemented modularly: the transformations from populations to central moments and from
        central moments to cumulants are governed by subclasses of
        :class:`lbmpy.methods.momentbased.moment_transforms.AbstractMomentTransform`, which can be specified by
        constructor argument. This allows the selection of the most efficient transformation for a given setup.

        Args:
            stencil: the lattice stencil.
            cumulant_to_relaxation_info_dict: maps each cumulant, in polynomial representation, to a
                :class:`RelaxationInfo` (equilibrium value, relaxation rate). Entries for the zeroth-order and
                all first-order moments are mandatory. The dict is copied on construction, so the
                ``set_*_relaxation_rate`` methods never modify the caller's dict.
            conserved_quantity_computation: instance of :class:`AbstractConservedQuantityComputation`.
            force_model: optional force model. A :class:`CenteredCumulantForceModel` is handled by the implicit
                forcing scheme described above; any other force model contributes explicit force terms to the
                collision rule returned by :meth:`get_collision_rule`.
            galilean_correction: whether to apply the galilean correction (D3Q27 only).
            central_moment_transform_class: transform class between populations and central moments.
            cumulant_transform_class: transform class between central moments and cumulants.

        Raises:
            ValueError: if relaxation info for a conserved moment is missing, or if the galilean correction is
                requested for a stencil other than D3Q27 or without the required polynomial cumulants.
    """

    def __init__(self, stencil, cumulant_to_relaxation_info_dict, conserved_quantity_computation, force_model=None,
                 galilean_correction=False,
                 central_moment_transform_class=FastCentralMomentTransform,
                 cumulant_transform_class=CentralMomentsToCumulantsByGeneratingFunc):
        assert isinstance(conserved_quantity_computation,
                          AbstractConservedQuantityComputation)
        super(CenteredCumulantBasedLbMethod, self).__init__(stencil)

        # The zeroth- and first-order moments are relaxed as central moments (see the collision rule below),
        # so their relaxation info must always be present.
        for m in moments_up_to_order(1, dim=self.dim):
            if exponent_to_polynomial_representation(m) not in cumulant_to_relaxation_info_dict:
                raise ValueError(f'No relaxation info given for conserved cumulant {m}!')

        # Keep a private copy. The relaxation-rate setters, including the force-model override at the end of
        # this constructor, mutate this dict, and those changes must not leak into the caller's dict.
        self._cumulant_to_relaxation_info_dict = dict(cumulant_to_relaxation_info_dict)
        self._conserved_quantity_computation = conserved_quantity_computation
        self._force_model = force_model
        self._weights = None
        self._galilean_correction = galilean_correction

        if galilean_correction:
            if not have_same_entries(stencil, get_stencil("D3Q27")):
                raise ValueError("Galilean Correction only available for D3Q27 stencil")

            if not contains_corrected_polynomials(cumulant_to_relaxation_info_dict):
                raise ValueError("For the galilean correction, all three polynomial cumulants "
                                 "(x^2 - y^2), (x^2 - z^2) and (x^2 + y^2 + z^2) must be present!")

        self._cumulant_transform_class = cumulant_transform_class
        self._central_moment_transform_class = central_moment_transform_class

        # Must be False before calling the setter below, otherwise the setter would emit its override warning.
        self.force_model_rr_override = False
        if isinstance(self._force_model, CenteredCumulantForceModel) and \
                self._force_model.override_momentum_relaxation_rate is not None:
            self.set_first_moment_relaxation_rate(self._force_model.override_momentum_relaxation_rate)
            self.force_model_rr_override = True

    @property
    def force_model(self):
        return self._force_model

    @property
    def relaxation_info_dict(self):
        """The (mutable) internal mapping from cumulants to their :class:`RelaxationInfo`."""
        return self._cumulant_to_relaxation_info_dict

    @property
    def zeroth_order_equilibrium_moment_symbol(self):
        return self._conserved_quantity_computation.zeroth_order_moment_symbol

    @property
    def first_order_equilibrium_moment_symbols(self):
        return self._conserved_quantity_computation.first_order_moment_symbols

    def set_zeroth_moment_relaxation_rate(self, relaxation_rate):
        """Replace the relaxation rate of the zeroth-order moment, keeping its equilibrium value."""
        # The polynomial representation of the zeroth-order moment is the constant 1.
        e = sp.Rational(1, 1)
        prev_entry = self._cumulant_to_relaxation_info_dict[e]
        new_entry = RelaxationInfo(prev_entry[0], relaxation_rate)
        self._cumulant_to_relaxation_info_dict[e] = new_entry

    def set_first_moment_relaxation_rate(self, relaxation_rate):
        """Replace the relaxation rate of all first-order moments, keeping their equilibrium values.

        Warns if the first-order rates were previously set by a :class:`CenteredCumulantForceModel`.
        """
        if self.force_model_rr_override:
            warn("Overwriting first-order relaxation rates governed by CenteredCumulantForceModel "
                 "might break your forcing scheme.")
        for e in MOMENT_SYMBOLS[:self.dim]:
            assert e in self._cumulant_to_relaxation_info_dict, \
                "First cumulants are not relaxed separately by this method"
        for e in MOMENT_SYMBOLS[:self.dim]:
            prev_entry = self._cumulant_to_relaxation_info_dict[e]
            new_entry = RelaxationInfo(prev_entry[0], relaxation_rate)
            self._cumulant_to_relaxation_info_dict[e] = new_entry

    def set_conserved_moments_relaxation_rate(self, relaxation_rate):
        """Set the same relaxation rate for the zeroth- and all first-order moments."""
        self.set_zeroth_moment_relaxation_rate(relaxation_rate)
        self.set_first_moment_relaxation_rate(relaxation_rate)

    def set_force_model(self, force_model):
        self._force_model = force_model

    @property
    def cumulants(self):
        """All relaxed cumulants (polynomial representation), sorted by ``moment_sort_key``."""
        return sorted(self._cumulant_to_relaxation_info_dict.keys(), key=moment_sort_key)

    @property
    def cumulant_equilibrium_values(self):
        """Equilibrium values in the insertion order of the relaxation info dict.

        Note: this order is not necessarily the sorted order of :attr:`cumulants`.
        """
        return tuple([e.equilibrium_value for e in self._cumulant_to_relaxation_info_dict.values()])

    @property
    def relaxation_rates(self):
        """Relaxation rates in the insertion order of the relaxation info dict.

        Note: this order is not necessarily the sorted order of :attr:`cumulants`.
        """
        return tuple([e.relaxation_rate for e in self._cumulant_to_relaxation_info_dict.values()])

    def _repr_html_(self):
        table = """
        <table style="border:none; width: 100%">
            <tr {nb}>
                <th {nb} >Central Moment / Cumulant</th>
                <th {nb} >Eq. Value </th>
                <th {nb} >Relaxation Rate</th>
            </tr>
            {content}
        </table>
        """
        content = ""
        for cumulant, (eq_value, rr) in self._cumulant_to_relaxation_info_dict.items():
            vals = {
                'rr': f"${sp.latex(rr)}$",
                'cumulant': f"${sp.latex(cumulant)}$",
                'eq_value': f"${sp.latex(eq_value)}$",
                'nb': 'style="border:none"',
            }
            order = get_order(cumulant)
            if order <= 1:
                vals['cumulant'] += ' (central moment)'
                if order == 1 and self.force_model_rr_override:
                    vals['rr'] += ' (overridden by force model)'
            content += """<tr {nb}>
                            <td {nb}>{cumulant}</td>
                            <td {nb}>{eq_value}</td>
                            <td {nb}>{rr}</td>
                         </tr>\n""".format(**vals)
        return table.format(content=content, nb='style="border:none"')

    #   ----------------------- Overridden Abstract Members --------------------------

    @property
    def conserved_quantity_computation(self):
        """Returns an instance of class :class:`lbmpy.methods.AbstractConservedQuantityComputation`"""
        return self._conserved_quantity_computation

    @property
    def weights(self):
        """Returns a sequence of weights, one for each lattice direction (computed lazily on first access)."""
        if self._weights is None:
            self._weights = self._compute_weights()
        return self._weights

    def override_weights(self, weights):
        assert len(weights) == len(self.stencil)
        self._weights = weights

    def get_equilibrium(self, conserved_quantity_equations=None, subexpressions=False, pre_simplification=False,
                        keep_cqc_subexpressions=True):
        """Returns an assignment collection that computes the equilibrium values.

        The equilibrium is obtained from the collision rule with every relaxation rate set to one, without the
        galilean correction and without explicit force terms. The equations have the post collision symbols as
        left hand sides and are functions of the conserved quantities.

        Args:
            conserved_quantity_equations: equations to compute conserved quantities.
            subexpressions: if set to false all subexpressions of the equilibrium assignments are plugged
                            into the main assignments
            pre_simplification: with or without pre_simplifications for the calculation of the collision
            keep_cqc_subexpressions: if equilibrium is returned without subexpressions keep_cqc_subexpressions
                                     determines if also subexpressions to calculate conserved quantities should be
                                     plugged into the main assignments
        """
        r_info_dict = {c: RelaxationInfo(info.equilibrium_value, 1)
                       for c, info in self._cumulant_to_relaxation_info_dict.items()}
        ac = self._centered_cumulant_collision_rule(
            r_info_dict, conserved_quantity_equations, pre_simplification, include_galilean_correction=False)
        if subexpressions:
            return ac
        if keep_cqc_subexpressions:
            bs = self._bound_symbols_cqc(conserved_quantity_equations)
            return ac.new_without_subexpressions(subexpressions_to_keep=bs)
        return ac.new_without_subexpressions()

    def get_equilibrium_terms(self):
        """Returns the right-hand sides of the equilibrium main assignments as a column matrix."""
        equilibrium = self.get_equilibrium()
        return sp.Matrix([eq.rhs for eq in equilibrium.main_assignments])

    def get_collision_rule(self, conserved_quantity_equations=None, pre_simplification=False):
        """Returns an LbmCollisionRule i.e. an equation collection with a reference to the method.
        This collision rule defines the collision operator."""
        return self._centered_cumulant_collision_rule(
            self._cumulant_to_relaxation_info_dict, conserved_quantity_equations, pre_simplification,
            include_force_terms=True)

    #   ------------------------------- Internals --------------------------------------------

    def _bound_symbols_cqc(self, conserved_quantity_equations=None):
        """Symbols bound by the conserved-quantity equations (computed from the PDFs if none are given)."""
        f = self.pre_collision_pdf_symbols
        cqe = conserved_quantity_equations

        if cqe is None:
            cqe = self._conserved_quantity_computation.equilibrium_input_equations_from_pdfs(f)

        return cqe.bound_symbols

    def _compute_weights(self):
        """Evaluate the equilibrium at the default values of the conserved quantities.

        Each resulting right-hand side must expand to a symbol-free expression, which is used as the weight
        of the corresponding lattice direction.
        """
        defaults = self._conserved_quantity_computation.default_values
        cqe = AssignmentCollection([Assignment(s, e) for s, e in defaults.items()])
        eq_ac = self.get_equilibrium(cqe, subexpressions=False, keep_cqc_subexpressions=False)

        weights = []
        for eq in eq_ac.main_assignments:
            value = eq.rhs.expand()
            assert len(value.atoms(sp.Symbol)) == 0, "Failed to compute weights"
            weights.append(value)
        return weights

    def _centered_cumulant_collision_rule(self, cumulant_to_relaxation_info_dict,
                                          conserved_quantity_equations=None,
                                          pre_simplification=False,
                                          include_force_terms=False,
                                          include_galilean_correction=True):
        """Assemble the collision rule PDFs -> central moments -> cumulants -> relaxation -> PDFs.

        Args:
            cumulant_to_relaxation_info_dict: relaxation info to use; may differ from the method's own dict
                (e.g. :meth:`get_equilibrium` passes a copy with all relaxation rates set to one).
            conserved_quantity_equations: equations for the conserved quantities; derived from the
                pre-collision PDFs if None.
            pre_simplification: forwarded to the transforms. Any truthy value other than ``'none'`` also
                enables zero/alias insertion and removal of unused subexpressions.
            include_force_terms: add explicit force terms of a force model that is not a
                :class:`CenteredCumulantForceModel`.
            include_galilean_correction: apply the galilean correction if it is enabled for this method.

        Returns:
            LbmCollisionRule referencing this method.
        """
        stencil = self.stencil
        f = self.pre_collision_pdf_symbols
        density = self.zeroth_order_equilibrium_moment_symbol
        velocity = self.first_order_equilibrium_moment_symbols
        cqe = conserved_quantity_equations

        if cqe is None:
            cqe = self._conserved_quantity_computation.equilibrium_input_equations_from_pdfs(f)

        #   1) Extract Monomial Cumulants for the higher-order polynomials
        polynomial_cumulants = sorted(cumulant_to_relaxation_info_dict.keys(), key=moment_sort_key)
        higher_order_polynomials = [p for p in polynomial_cumulants if get_order(p) > 1]
        monomial_cumulants = sorted(list(extract_monomials(
            higher_order_polynomials, dim=self.dim)), key=exponent_tuple_sort_key)

        #   2) Get Forward and Backward Transformations between central moment and cumulant space,
        #      and find required central moments
        k_to_c_transform = self._cumulant_transform_class(stencil, monomial_cumulants, density, velocity)
        k_to_c_eqs = cached_forward_transform(k_to_c_transform, simplification=pre_simplification)
        c_post_to_k_post_eqs = cached_backward_transform(
            k_to_c_transform, simplification=pre_simplification, omit_conserved_moments=True)
        central_moments = k_to_c_transform.required_central_moments
        assert len(central_moments) == len(stencil), 'Number of required central moments must match stencil size.'

        #   3) Get Forward Transformation from PDFs to central moments
        pdfs_to_k_transform = self._central_moment_transform_class(
            stencil, central_moments, density, velocity, conserved_quantity_equations=cqe)
        pdfs_to_k_eqs = cached_forward_transform(pdfs_to_k_transform, f, simplification=pre_simplification)

        #   4) Add relaxation rules for lower order moments
        lower_order_moments = moments_up_to_order(1, dim=self.dim)
        lower_order_moment_symbols = [statistical_quantity_symbol(PRE_COLLISION_CENTRAL_MOMENT, exp)
                                      for exp in lower_order_moments]

        lower_order_relaxation_infos = [cumulant_to_relaxation_info_dict[exponent_to_polynomial_representation(e)]
                                        for e in lower_order_moments]
        lower_order_relaxation_rates = [info.relaxation_rate for info in lower_order_relaxation_infos]
        lower_order_equilibrium = [info.equilibrium_value for info in lower_order_relaxation_infos]

        lower_order_moment_collision_eqs = relax_lower_order_central_moments(
            lower_order_moments, lower_order_moment_symbols,
            lower_order_relaxation_rates, lower_order_equilibrium)

        #   5) Add relaxation rules for higher-order, polynomial cumulants
        poly_relaxation_infos = [cumulant_to_relaxation_info_dict[c] for c in higher_order_polynomials]
        poly_relaxation_rates = [info.relaxation_rate for info in poly_relaxation_infos]
        poly_equilibrium = [info.equilibrium_value for info in poly_relaxation_infos]

        if self._galilean_correction and include_galilean_correction:
            galilean_correction_terms = get_galilean_correction_terms(
                cumulant_to_relaxation_info_dict, density, velocity)
        else:
            galilean_correction_terms = None

        cumulant_collision_eqs = relax_polynomial_cumulants(
            monomial_cumulants, higher_order_polynomials,
            poly_relaxation_rates, poly_equilibrium,
            pre_simplification,
            galilean_correction_terms=galilean_correction_terms)

        #   6) Get backward transformation from central moments to PDFs
        d = self.post_collision_pdf_symbols
        k_post_to_pdfs_eqs = cached_backward_transform(pdfs_to_k_transform, d, simplification=pre_simplification)

        #   7) That's all. Now, put it all together.
        #      The conserved-quantity equations are only included if the central moment transform did not
        #      already absorb them.
        all_acs = [] if pdfs_to_k_transform.absorbs_conserved_quantity_equations else [cqe]
        all_acs += [pdfs_to_k_eqs, k_to_c_eqs, lower_order_moment_collision_eqs,
                    cumulant_collision_eqs, c_post_to_k_post_eqs]
        # Flatten explicitly instead of relying on AssignmentCollection to flatten nested lists.
        subexpressions = [assignment for ac in all_acs for assignment in ac.all_assignments]
        subexpressions += k_post_to_pdfs_eqs.subexpressions
        main_assignments = k_post_to_pdfs_eqs.main_assignments

        #   8) Maybe add forcing terms if CenteredCumulantForceModel was not used
        if include_force_terms and self._force_model is not None and \
                not isinstance(self._force_model, CenteredCumulantForceModel):
            force_model_terms = self._force_model(self)
            force_term_symbols = sp.symbols(f"forceTerm_:{len(force_model_terms)}")
            force_subexpressions = [Assignment(sym, force_model_term)
                                    for sym, force_model_term in zip(force_term_symbols, force_model_terms)]
            subexpressions += force_subexpressions
            main_assignments = [Assignment(eq.lhs, eq.rhs + force_term_symbol)
                                for eq, force_term_symbol in zip(main_assignments, force_term_symbols)]

        #   9) Clean up the subexpression tree
        ac = AssignmentCollection(main_assignments, subexpressions)

        if pre_simplification and pre_simplification != 'none':
            ac = insert_aliases(insert_zeros(ac))
            ac = ac.new_without_unused_subexpressions()

        return LbmCollisionRule(self, ac.main_assignments, ac.subexpressions)