walberla_lbm_generation.py 16.1 KB
Newer Older
1
# import warnings
2
3
4
5
6
7
8
9
10
11
12
13
14
15

import numpy as np
import sympy as sp
from jinja2 import Environment, PackageLoader, StrictUndefined, Template
from sympy.tensor import IndexedBase

import pystencils as ps
from lbmpy.fieldaccess import CollideOnlyInplaceAccessor, StreamPullTwoFieldsAccessor
from lbmpy.relaxationrates import relaxation_rate_scaling
from lbmpy.stencils import get_stencil
from lbmpy.updatekernels import create_lbm_kernel, create_stream_pull_only_kernel
from pystencils import AssignmentCollection, create_kernel
from pystencils.astnodes import SympyAssignment
from pystencils.backends.cbackend import CBackend, CustomSympyPrinter, get_headers
Markus Holzer's avatar
Markus Holzer committed
16
from pystencils.data_types import TypedSymbol, type_all_numbers, cast_func
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
from pystencils.field import Field
from pystencils.stencil import have_same_entries, offset_to_direction_string
from pystencils.sympyextensions import get_symmetric_part
from pystencils.transformations import add_types
from pystencils_walberla.codegen import KernelInfo, default_create_kernel_parameters
from pystencils_walberla.jinja_filters import add_pystencils_filters_to_jinja_env

# Shared sympy -> C++ printer; used below to print equilibria, switch-statement values and converted expressions.
cpp_printer = CustomSympyPrinter()
# Symbol for the refinement level scale factor (same name as RefinementScaling.level_scale_factor).
# NOTE(review): not referenced in the visible part of this module - presumably used by importers; verify.
REFINEMENT_SCALE_FACTOR = sp.Symbol("level_scale_factor")


def __lattice_model(generation_context, class_name, lb_method, stream_collide_ast, collide_ast, stream_ast,
                    refinement_scaling):
    """
    Render the waLBerla lattice model class for ``lb_method`` and write ``<class_name>.h`` / ``<class_name>.cpp``.

    :param generation_context: code generation context; ``double_accuracy`` selects float64 (otherwise float32)
                               and ``write_file`` receives the rendered header and source
    :param class_name: name of the generated C++ class, also used as file name stem
    :param lb_method: lbmpy method providing stencil, weights, equilibrium, force model and the conserved
                      quantity computation
    :param stream_collide_ast: kernel AST of the fused stream-collide step
    :param collide_ast: kernel AST of the collide-only step
    :param stream_ast: kernel AST of the stream-only step
    :param refinement_scaling: optional :class:`RefinementScaling`; a falsy value (e.g. ``None``) disables it
    :raises ValueError: if the stencil is not one of D2Q9, D3Q15, D3Q19, D3Q27
    """
    stencil_name = get_stencil_name(lb_method.stencil)
    if not stencil_name:
        raise ValueError("lb_method uses a stencil that is not supported in waLBerla")

    # D3Q15 ghost layers are communicated with D3Q27, whose directions are a superset of D3Q15's
    communication_stencil_name = stencil_name if stencil_name != "D3Q15" else "D3Q27"
    is_float = not generation_context.double_accuracy
    dtype_string = "float32" if is_float else "float64"
    constant_suffix = "f" if is_float else ""

    cqc = lb_method.conserved_quantity_computation
    vel_symbols = cqc.first_order_moment_symbols
    rho_sym = sp.Symbol('rho')
    pdfs_sym = sp.symbols(f'f_:{len(lb_method.stencil)}')
    vel_arr_symbols = [IndexedBase(sp.Symbol('u'), shape=(1,))[i] for i in range(len(vel_symbols))]
    momentum_density_symbols = sp.symbols(f'md_:{len(vel_symbols)}')

    # equilibrium expressed in terms of the velocity array u[i], one entry per stencil direction
    equilibrium = lb_method.get_equilibrium()
    equilibrium = equilibrium.new_with_substitutions(dict(zip(vel_symbols, vel_arr_symbols)))
    _, _, equilibrium = add_types(equilibrium.main_assignments, dtype_string, False)
    equilibrium = sp.Matrix([e.rhs for e in equilibrium])

    # 1/2 is replaced by an explicitly cast constant, presumably to keep it in the target precision
    half = sp.Rational(1, 2)
    typed_half = cast_func(half, dtype_string)
    symmetric_equilibrium = get_symmetric_part(equilibrium, vel_arr_symbols).subs(half, typed_half)
    asymmetric_equilibrium = sp.expand(equilibrium - symmetric_equilibrium)

    macroscopic_velocity_shift = None
    force_model = lb_method.force_model
    if force_model and hasattr(force_model, 'macroscopic_velocity_shift'):
        macroscopic_velocity_shift = [expression_to_code(e.subs(half, typed_half), "lm.", ['rho'], dtype=dtype_string)
                                      for e in force_model.macroscopic_velocity_shift(rho_sym)]

    eq_input_from_input_eqs = cqc.equilibrium_input_equations_from_init_values(sp.Symbol('rho_in'), vel_arr_symbols)
    density_velocity_setter_macroscopic_values = equations_to_code(eq_input_from_input_eqs, dtype=dtype_string,
                                                                   variables_without_prefix=['rho_in', 'u'])
    momentum_density_getter = cqc.output_equations_from_pdfs(pdfs_sym, {'density': rho_sym,
                                                                        'momentum_density': momentum_density_symbols})

    required_headers = get_headers(stream_collide_ast)

    refinement_scaling_info = (_member_scaling_info(refinement_scaling, dtype_string)
                               if refinement_scaling else None)

    kernel_parameter_names = {param.symbol.name for param in stream_collide_ast.get_parameters()}
    need_block_offsets = [f'block_offset_{i}' in kernel_parameter_names for i in range(3)]

    jinja_context = {
        'class_name': class_name,
        'stencil_name': stencil_name,
        'communication_stencil_name': communication_stencil_name,
        'D': lb_method.dim,
        'Q': len(lb_method.stencil),
        'compressible': cqc.compressible,
        'weights': ",".join(str(w.evalf()) + constant_suffix for w in lb_method.weights),
        'inverse_weights': ",".join(str((1 / w).evalf()) + constant_suffix for w in lb_method.weights),

        'equilibrium_from_direction': stencil_switch_statement(lb_method.stencil, equilibrium),
        'symmetric_equilibrium_from_direction': stencil_switch_statement(lb_method.stencil, symmetric_equilibrium),
        'asymmetric_equilibrium_from_direction': stencil_switch_statement(lb_method.stencil, asymmetric_equilibrium),
        'equilibrium': [cpp_printer.doprint(e) for e in equilibrium],

        'macroscopic_velocity_shift': macroscopic_velocity_shift,
        'density_getters': equations_to_code(cqc.output_equations_from_pdfs(pdfs_sym, {"density": rho_sym}),
                                             variables_without_prefix=[e.name for e in pdfs_sym], dtype=dtype_string),
        'momentum_density_getter': equations_to_code(momentum_density_getter, variables_without_prefix=pdfs_sym,
                                                     dtype=dtype_string),
        'density_velocity_setter_macroscopic_values': density_velocity_setter_macroscopic_values,

        'refinement_scaling_info': refinement_scaling_info,

        'stream_collide_kernel': KernelInfo(stream_collide_ast, ['pdfs_tmp'], [('pdfs', 'pdfs_tmp')], []),
        'collide_kernel': KernelInfo(collide_ast, [], [], []),
        'stream_kernel': KernelInfo(stream_ast, ['pdfs_tmp'], [('pdfs', 'pdfs_tmp')], []),
        'target': 'cpu',
        'namespace': 'lbm',
        'headers': required_headers,
        'need_block_offsets': need_block_offsets,
    }

    env = Environment(loader=PackageLoader('lbmpy_walberla'), undefined=StrictUndefined)
    add_pystencils_filters_to_jinja_env(env)

    header = env.get_template('LatticeModel.tmpl.h').render(**jinja_context)
    source = env.get_template('LatticeModel.tmpl.cpp').render(**jinja_context)

    generation_context.write_file(f"{class_name}.h", header)
    generation_context.write_file(f"{class_name}.cpp", source)


def _member_scaling_info(refinement_scaling, dtype_string):
    """
    Convert ``refinement_scaling.scaling_info`` into ``(scaling_type, member_name, code)`` tuples for the template.

    The scaled parameters are used as class members with a trailing '_', so the name and every occurrence
    of it in the generated C++ expression are renamed ``<name>`` -> ``<name>_``.
    """
    import re

    member_info = []
    for scaling_type, name, expr in refinement_scaling.scaling_info:
        member_name = name + '_'
        code = expression_to_code(expr, '', dtype=dtype_string)
        # Rename whole identifiers only: a plain str.replace also rewrote substrings of other identifiers,
        # e.g. a parameter called 'a' turned 'level_scale_factor' into 'level_sca_le_fa_ctor'.
        pattern = rf'(?<!\w){re.escape(name)}(?!\w)'
        member_info.append((scaling_type, member_name, re.sub(pattern, lambda _match: member_name, code)))
    return member_info


def generate_lattice_model(generation_context, class_name, collision_rule, field_layout='zyxf', refinement_scaling=None,
                           **create_kernel_params):
    """
    Generate a waLBerla lattice model C++ class from an lbmpy collision rule.

    Three CPU kernels are built from ``collision_rule``: fused stream-collide (``kernel_streamCollide``),
    collide-only (``kernel_collide``) and stream-only (``kernel_stream``). They are passed on to the
    template rendering together with ``collision_rule.method``.

    :param generation_context: waLBerla code generation context (precision, kernel defaults, file output)
    :param class_name: name of the generated C++ class / files
    :param collision_rule: lbmpy collision rule whose ``method`` defines stencil and dimension
    :param field_layout: layout of the pdf fields. The default 'zyxf' is deliberate: the usual numpy layout
                         (xyzf) is bad for waLBerla, where at least the spatial coordinates should be ordered
                         in reverse direction (zyx).
    :param refinement_scaling: optional :class:`RefinementScaling` for parameters on refined blocks
    :param create_kernel_params: further keyword arguments for ``create_kernel``, completed with waLBerla defaults
    :raises ValueError: if the target is 'gpu' (LBM on GPUs has to be generated as sweeps)
    """
    lb_method = collision_rule.method
    dtype = np.float64 if generation_context.double_accuracy else np.float32

    create_kernel_params = default_create_kernel_parameters(generation_context, create_kernel_params)
    if create_kernel_params['target'] == 'gpu':
        raise ValueError("Lattice Models can only be generated for CPUs. To generate LBM on GPUs use sweeps directly")

    vectorize_info = create_kernel_params['cpu_vectorize_info']
    # 'fzyx' -> inner stride one, 'zyxf' -> not; any other layout keeps the value set by the defaults
    if field_layout == 'fzyx':
        vectorize_info['assume_inner_stride_one'] = True
    elif field_layout == 'zyxf':
        vectorize_info['assume_inner_stride_one'] = False

    q = len(lb_method.stencil)
    src_field, dst_field = (ps.Field.create_generic(name, lb_method.dim, dtype, index_dimensions=1,
                                                    layout=field_layout, index_shape=(q,))
                            for name in ('pdfs', 'pdfs_tmp'))

    def to_kernel(update_rule, function_name):
        # compile one update rule and tag the AST with its C++ name and the stride assumption used
        kernel_ast = create_kernel(update_rule, **create_kernel_params)
        kernel_ast.function_name = function_name
        kernel_ast.assumed_inner_stride_one = vectorize_info['assume_inner_stride_one']
        return kernel_ast

    stream_collide_ast = to_kernel(
        create_lbm_kernel(collision_rule, src_field, dst_field, StreamPullTwoFieldsAccessor()),
        'kernel_streamCollide')
    collide_ast = to_kernel(
        create_lbm_kernel(collision_rule, src_field, dst_field, CollideOnlyInplaceAccessor()),
        'kernel_collide')
    stream_ast = to_kernel(
        create_stream_pull_only_kernel(lb_method.stencil, None, 'pdfs', 'pdfs_tmp', field_layout, dtype),
        'kernel_stream')

    __lattice_model(generation_context, class_name, lb_method, stream_collide_ast, collide_ast, stream_ast,
                    refinement_scaling)


class RefinementScaling:
    """
    Collects rules describing how parameters are modified on refined blocks.

    Every rule is appended to ``scaling_info`` as a tuple ``(scaling_type, name, expr)``:

    * ``scaling_type`` is ``'normal'`` for a plain symbol, ``'field_xyz'`` for a scalar field or a single
      field access, and ``'field_with_f'`` for a field with index dimensions (accessed with the symbol ``f``)
    * ``name`` is the symbol or field name
    * ``expr`` is the scaled sympy expression, written in terms of ``level_scale_factor``
    """
    # symbol standing for how much finer the current block is than the coarsest resolution
    level_scale_factor = sp.Symbol("level_scale_factor")

    def __init__(self):
        self.scaling_info = []

    def add_standard_relaxation_rate_scaling(self, viscosity_relaxation_rate):
        """Scale ``viscosity_relaxation_rate`` with lbmpy's ``relaxation_rate_scaling`` rule."""
        self.add_scaling(viscosity_relaxation_rate, relaxation_rate_scaling)

    def add_force_scaling(self, force_parameter):
        """Scale ``force_parameter`` linearly, i.e. multiply it by the level scale factor."""
        self.add_scaling(force_parameter, lambda param, factor: param * factor)

    def add_scaling(self, parameter, scaling_rule):
        """
        Adds a scaling rule, how parameters on refined blocks are modified

        :param parameter: parameter to modify: may either be a Field, Field.Access or a Symbol,
                          or a list/tuple of those (each element is registered with the same rule)
        :param scaling_rule: function taking the parameter to be scaled as symbol and the scaling factor i.e.
                            how much finer the current block is compared to coarsest resolution
        :raises ValueError: if ``parameter`` (or a list/tuple element) has none of the supported types
        """
        if isinstance(parameter, Field):
            field = parameter
            name = field.name
            if field.index_dimensions > 0:
                scaling_type = 'field_with_f'
                field_access = field(sp.Symbol("f"))
            else:
                scaling_type = 'field_xyz'
                field_access = field.center
            expr = scaling_rule(field_access, self.level_scale_factor)
            self.scaling_info.append((scaling_type, name, expr))
        # Field.Access must be tested before sp.Symbol because field accesses are symbols as well
        elif isinstance(parameter, Field.Access):
            field_access = parameter
            expr = scaling_rule(field_access, self.level_scale_factor)
            name = field_access.field.name
            self.scaling_info.append(('field_xyz', name, expr))
        elif isinstance(parameter, sp.Symbol):
            expr = scaling_rule(parameter, self.level_scale_factor)
            self.scaling_info.append(('normal', parameter.name, expr))
        elif isinstance(parameter, (list, tuple)):
            for p in parameter:
                self.add_scaling(p, scaling_rule)
        else:
            raise ValueError(f"Invalid value for parameter to scale: {parameter!r} "
                             f"(type {type(parameter).__name__}); expected a Field, Field.Access, "
                             "Symbol or a list/tuple of those")


# ------------------------------------------ Internal ------------------------------------------------------------------


def stencil_switch_statement(stencil, values):
    """
    Build a C++ ``switch`` over waLBerla stencil directions that returns the value for ``direction``.

    :param stencil: sequence of direction offsets, in the same order as ``values``
    :param values: one sympy expression per stencil direction
    :return: C++ code string; directions without a case hit ``WALBERLA_ABORT("Invalid Direction")``
    """
    # StrictUndefined has to be configured on the Template itself; passing it to render() (as before)
    # merely defined an unused template variable named 'undefined'.
    template = Template("""
    using namespace stencil;
    switch( direction ) {
        {% for direction_name, value in dir_to_value_dict.items() -%}
            case {{direction_name}}: return {{value}};
        {% endfor -%}
        default:
            WALBERLA_ABORT("Invalid Direction");
    }
    """, undefined=StrictUndefined)

    dir_to_value_dict = {offset_to_direction_string(d): cpp_printer.doprint(v) for d, v in zip(stencil, values)}
    return template.render(dir_to_value_dict=dir_to_value_dict)


def field_and_symbol_substitute(expr, variable_prefix="lm.", variables_without_prefix=None):
    """
    Rename the free symbols of ``expr`` so that they refer to waLBerla variables.

    * A field access becomes ``<prefix><field>[_]->get(x,y,z)`` for fields without index dimensions, or
      ``<prefix><field>[_]->get(x,y,z,<index>)`` for fields with one index dimension. The access offset is
      not used; the current cell ``x,y,z`` is always addressed.
    * Any other symbol becomes ``<prefix><name>[_]``.

    The trailing '_' is appended whenever the applied prefix ends with '.', i.e. for member accesses.
    Names in ``variables_without_prefix`` get no prefix; plain symbols with such a name stay unchanged.

    :param expr: sympy expression to rewrite
    :param variable_prefix: prefix applied to symbols and fields, e.g. "lm." for lattice model members
    :param variables_without_prefix: names or sympy symbols that must not be prefixed
    :return: ``expr`` with the substitutions applied
    :raises AssertionError: for field accesses with more than one index dimension
    """
    if variables_without_prefix is None:
        variables_without_prefix = []
    exempt_names = [v.name if isinstance(v, sp.Symbol) else v for v in variables_without_prefix]

    def member_postfix(prefix):
        return '_' if prefix.endswith('.') else ''

    symbol_postfix = member_postfix(variable_prefix)
    replacements = {}
    for sym in expr.atoms(sp.Symbol):
        if not isinstance(sym, Field.Access):
            if sym.name not in exempt_names:
                replacements[sym] = sp.Symbol(variable_prefix + sym.name + symbol_postfix)
            continue

        field = sym.field
        access_prefix = "" if field.name in exempt_names else variable_prefix
        accessor = f"{access_prefix}{field.name}{member_postfix(access_prefix)}"
        if field.index_dimensions == 0:
            replacements[sym] = sp.Symbol(f"{accessor}->get(x,y,z)")
        else:
            assert field.index_dimensions == 1, "walberla supports only 0 or 1 index dimensions"
            replacements[sym] = sp.Symbol(f"{accessor}->get(x,y,z,{sym.index[0]})")
    return expr.subs(replacements)


def expression_to_code(expr, variable_prefix="lm.", variables_without_prefix=None, dtype="double"):
    """
    Convert a sympy expression into a C++ code string for waLBerla.

    Field accesses are replaced by waLBerla field reads at the current cell (``field->get(x,y,z)`` or
    ``field->get(x,y,z,<index>)``) and other symbols get ``variable_prefix``, see
    ``field_and_symbol_substitute``. The result is passed through ``type_expr`` with ``dtype`` and printed.

    :param expr: sympy expression
    :param variable_prefix: all variables (and fields) are prefixed with this string;
                            this is used for member variables mostly
    :param variables_without_prefix: these variables are not prefixed
    :param dtype: data type handed to ``type_expr``
    :return: code string
    """
    substituted = field_and_symbol_substitute(expr, variable_prefix, variables_without_prefix)
    return cpp_printer.doprint(type_expr(substituted, dtype=dtype))

def type_expr(eq, dtype):
    """
    Return ``eq`` with every free ``sympy.Symbol`` replaced by a ``TypedSymbol`` of type ``dtype``.

    Number literals are left untouched. The recursive pass that used to precede this was meant to type them,
    but it compared each argument with the classes ``sp.Rational`` / ``sp.Float`` (never equal to an
    expression instance) and would have assigned into the immutable ``args`` tuple. It therefore never
    changed anything and has been removed.

    :param eq: sympy expression (or a single symbol)
    :param dtype: data type of the typed symbols, e.g. "double" or "float32"
    :return: the expression with typed symbols
    """
    return eq.subs({s: TypedSymbol(s.name, dtype) for s in eq.atoms(sp.Symbol)})

def equations_to_code(equations, variable_prefix="lm.", variables_without_prefix=None, dtype="double"):
    """
    Render assignments as C++ statements (via pystencils' ``CBackend``), one per line.

    Each left-hand side is typed with ``dtype``. Each right-hand side is renamed with
    ``field_and_symbol_substitute`` and then typed. The left-hand-side names of all equations are exempt
    from prefixing, so an assignment can use the locals defined by the others.

    :param equations: sequence of assignments or an ``AssignmentCollection`` (all its assignments are used)
    :param variable_prefix: prefix for the remaining symbols and fields
    :param variables_without_prefix: names or symbols that must not be prefixed
    :param dtype: data type used for typing symbols
    :return: newline-separated C++ code
    """
    if isinstance(equations, AssignmentCollection):
        equations = equations.all_assignments

    exempt = [] if variables_without_prefix is None else list(variables_without_prefix)
    local_names = [eq.lhs.name for eq in equations]
    unprefixed = exempt + local_names

    backend = CBackend()
    lines = []
    for eq in equations:
        typed_lhs = type_expr(eq.lhs, dtype=dtype)
        typed_rhs = type_expr(field_and_symbol_substitute(eq.rhs, variable_prefix, unprefixed), dtype=dtype)
        lines.append(backend(SympyAssignment(typed_lhs, typed_rhs)))
    return "\n".join(lines)


def get_stencil_name(stencil):
    """
    Return the waLBerla name ('D2Q9', 'D3Q15', 'D3Q19' or 'D3Q27') of the first candidate whose
    'walberla'-ordered lbmpy stencil has the same entries as ``stencil`` (compared with ``have_same_entries``).

    NOTE(review): when no candidate matches, the loop finishes without returning a name (presumably None);
    callers treat a falsy result as "stencil not supported".
    """
    for name in ('D2Q9', 'D3Q15', 'D3Q19', 'D3Q27'):
        if have_same_entries(stencil, get_stencil(name, 'walberla')):
            return name