kernelcreation.py 11.1 KB
Newer Older
Martin Bauer's avatar
Martin Bauer committed
1
2
from typing import List, Union

Martin Bauer's avatar
Martin Bauer committed
3
import sympy as sp
Martin Bauer's avatar
Martin Bauer committed
4

Martin Bauer's avatar
Martin Bauer committed
5
import pystencils.astnodes as ast
Martin Bauer's avatar
Martin Bauer committed
6
from pystencils.assignment import Assignment
Martin Bauer's avatar
Martin Bauer committed
7
8
9
10
11
12
from pystencils.astnodes import Block, KernelFunction, LoopOverCoordinate, SympyAssignment
from pystencils.cpu.cpujit import make_python_function
from pystencils.data_types import BasicType, StructType, TypedSymbol, create_type
from pystencils.field import Field, FieldType
from pystencils.transformations import (
    add_types, filtered_tree_iteration, get_base_buffer_index, get_optimal_loop_ordering,
13
14
    implement_interpolations, make_loop_over_domain, move_constants_before_loop,
    parse_base_pointer_info, resolve_buffer_accesses, resolve_field_accesses, split_inner_loop)
Martin Bauer's avatar
Martin Bauer committed
15

Martin Bauer's avatar
Martin Bauer committed
16
# Type alias for the kernel-creation inputs: a list whose items are either
# pystencils ``Assignment`` objects or AST nodes (``ast.Node``).
AssignmentOrAstNodeList = List[Union[Assignment, ast.Node]]
Martin Bauer's avatar
Martin Bauer committed
17

Martin Bauer's avatar
Martin Bauer committed
18
19

def create_kernel(assignments: AssignmentOrAstNodeList, function_name: str = "kernel", type_info='double',
                  split_groups=(), iteration_slice=None, ghost_layers=None,
                  skip_independence_check=False) -> KernelFunction:
    """Build the abstract syntax tree of a CPU kernel function from a list of update rules.

    Loops are created according to the field accesses in the equations.

    Args:
        assignments: list of sympy equations, containing accesses to :class:`pystencils.field.Field`,
                     defining the update rules of the kernel
        function_name: name of the generated function - only important if generated code is written out
        type_info: a map from symbol name to a C type specifier. If not specified all symbols are assumed to
                   be of type 'double' except symbols which occur on the left hand side of equations where the
                   right hand side is a sympy Boolean which are assumed to be 'bool'.
        split_groups: specification on how to split up the inner loop into multiple loops. For details see
                      transformation :func:`pystencils.transformations.split_inner_loop`
        iteration_slice: if not None, iteration is done only over this slice of the field
        ghost_layers: a sequence of pairs for each coordinate with lower and upper number of ghost layers
                      that should be excluded from the iteration.
                      If None, the number of ghost layers is determined automatically and assumed to be
                      equal for all dimensions
        skip_independence_check: don't check that loop iterations are independent. This is needed e.g. for
                                 periodicity kernels, that access the field outside the iteration bounds.
                                 Use with care!

    Returns:
        AST node representing a function, that can be printed as C or CUDA code
    """
    def to_typed_symbol(symbol):
        # Split-group entries that already carry a type are passed through untouched.
        if isinstance(symbol, (Field.Access, TypedSymbol)):
            return symbol
        if not isinstance(symbol, sp.Symbol):
            raise ValueError("Term has to be field access or symbol")
        # A plain string (or anything that is not subscriptable) is one type for every symbol;
        # otherwise type_info is looked up per symbol name.
        single_type_for_all = isinstance(type_info, str) or not hasattr(type_info, '__getitem__')
        dtype = create_type(type_info) if single_type_for_all else type_info[symbol.name]
        return TypedSymbol(symbol.name, dtype)

    check_independence = not skip_independence_check
    fields_read, fields_written, assignments = add_types(assignments, type_info, check_independence)
    all_fields = fields_read.union(fields_written)
    read_only_fields = {f.name for f in fields_read - fields_written}

    # Buffer fields are kept out of the loop-order computation and get their own base pointer spec below.
    buffers = {f for f in all_fields if FieldType.is_buffer(f)}
    fields_without_buffers = all_fields - buffers

    body = ast.Block(assignments)
    loop_order = get_optimal_loop_ordering(fields_without_buffers)
    loop_node, ghost_layer_info = make_loop_over_domain(body, iteration_slice=iteration_slice,
                                                        ghost_layers=ghost_layers, loop_order=loop_order)
    ast_node = KernelFunction(loop_node, 'cpu', 'c', compile_function=make_python_function,
                              ghost_layers=ghost_layer_info, function_name=function_name, assignments=assignments)
    implement_interpolations(body)

    if split_groups:
        typed_split_groups = [[to_typed_symbol(s) for s in group] for group in split_groups]
        split_inner_loop(ast_node, typed_split_groups)

    if len(loop_order) >= 2:
        base_pointer_spec = [['spatialInner0'], ['spatialInner1']]
    else:
        base_pointer_spec = [['spatialInner0']]

    # Regular fields first, buffers afterwards, so that a buffer entry wins on a name clash.
    base_pointer_info = {}
    for field in fields_without_buffers:
        base_pointer_info[field.name] = parse_base_pointer_info(base_pointer_spec, loop_order,
                                                                field.spatial_dimensions, field.index_dimensions)
    for field in buffers:
        base_pointer_info[field.name] = parse_base_pointer_info([['spatialInner0']], [0],
                                                                field.spatial_dimensions, field.index_dimensions)

    if buffers:
        resolve_buffer_accesses(ast_node, get_base_buffer_index(ast_node), read_only_fields)
    resolve_field_accesses(ast_node, read_only_fields, field_to_base_pointer_info=base_pointer_info)
    move_constants_before_loop(ast_node)
    return ast_node
Martin Bauer's avatar
Martin Bauer committed
91
92


Martin Bauer's avatar
Martin Bauer committed
93
94
def create_indexed_kernel(assignments: AssignmentOrAstNodeList, index_fields, function_name="kernel",
                          type_info=None, coordinate_names=('x', 'y', 'z')) -> KernelFunction:
    """Build a kernel that only updates the cells listed in an index field.

    Similar to :func:`create_kernel`, but here not all cells of a field are updated but only cells with
    coordinates which are stored in an index field. This traversal method can e.g. be used for boundary handling.

    The coordinates are stored in a separate index_field, which is a one dimensional array with struct data type.
    This struct has to contain fields named 'x', 'y' and for 3D fields ('z'). These names are configurable with the
    'coordinate_names' parameter. The struct can have also other fields that can be read and written in the kernel,
    for example boundary parameters.

    Args:
        assignments: list of assignments
        index_fields: list of index fields, i.e. 1D fields with struct data type. Their ``field_type`` is
                      set to ``FieldType.INDEXED`` by this function.
        function_name: see documentation of :func:`create_kernel`
        type_info: see documentation of :func:`create_kernel`
        coordinate_names: name of the coordinate fields in the struct data type

    Returns:
        AST node of the generated kernel function

    Raises:
        ValueError: if a required coordinate name is not an element of any index field's struct type
    """
    fields_read, fields_written, assignments = add_types(assignments, type_info, check_independence_condition=False)
    all_fields = fields_read.union(fields_written)

    for idx_field in index_fields:
        idx_field.field_type = FieldType.INDEXED
        assert FieldType.is_indexed(idx_field)
        assert idx_field.spatial_dimensions == 1, "Index fields have to be 1D"

    non_index_fields = [f for f in all_fields if f not in index_fields]
    distinct_dimensions = {f.spatial_dimensions for f in non_index_fields}
    assert len(distinct_dimensions) == 1, "Non-index fields do not have the same number of spatial coordinates"
    num_spatial_coordinates = tuple(distinct_dimensions)[0]

    def load_coordinate_from_index_field(name):
        # The first index field whose struct type has an element called `name` supplies the coordinate.
        for idx_field in index_fields:
            struct_type = idx_field.dtype
            assert isinstance(struct_type, StructType), "Index fields have to have a struct data type"
            if not struct_type.has_element(name):
                continue
            struct_member_access = idx_field[0](name)
            typed_coordinate = TypedSymbol(name, BasicType(struct_type.get_element_type(name)))
            return SympyAssignment(typed_coordinate, struct_member_access)
        raise ValueError(f"Index {name} not found in any of the passed index fields")

    coordinate_loads = [load_coordinate_from_index_field(coordinate_name)
                        for coordinate_name in coordinate_names[:num_spatial_coordinates]]
    coordinate_typed_symbols = [load.lhs for load in coordinate_loads]
    # The coordinate loads go first so the symbols are defined before the user assignments use them.
    assignments = coordinate_loads + assignments

    # make 1D loop over index fields
    loop_body = Block([])
    loop_node = LoopOverCoordinate(loop_body, coordinate_to_loop_over=0, start=0, stop=index_fields[0].shape[0])

    # NOTE(review): this runs while the loop body is still empty (assignments are appended below) -
    # confirm whether it is meant to run after the body has been filled.
    implement_interpolations(loop_node)

    for entry in assignments:
        loop_body.append(entry)

    function_body = Block([loop_node])
    ast_node = KernelFunction(function_body, "cpu", "c", make_python_function,
                              ghost_layers=None, function_name=function_name, assignments=assignments)

    # Every non-index field is addressed through the same list of coordinate symbols.
    fixed_coordinate_mapping = dict.fromkeys((f.name for f in non_index_fields), coordinate_typed_symbols)

    read_only_fields = {f.name for f in fields_read - fields_written}
    resolve_field_accesses(ast_node, read_only_fields, field_to_fixed_coordinates=fixed_coordinate_mapping)
    move_constants_before_loop(ast_node)
    return ast_node
158

Martin Bauer's avatar
Martin Bauer committed
159

160
def add_openmp(ast_node, schedule="static", num_threads=True, collapse=None, assume_single_outer_loop=True):
    """Parallelize the outer loop with OpenMP.

    The children of the kernel body are moved into a ``#pragma omp parallel`` block and every outermost
    loop gets a ``#pragma omp for`` prefix line. If the constant iteration range of an outer loop is smaller
    than the number of threads, and that loop directly contains exactly one loop with a larger constant
    range, the contained loop is parallelized instead (only when ``collapse`` is not used).

    Args:
        ast_node: abstract syntax tree created e.g. by :func:`create_kernel`
        schedule: OpenMP scheduling policy e.g. 'static' or 'dynamic'
        num_threads: explicitly specify number of threads. ``True`` emits no ``num_threads`` clause and
                     leaves the thread count to the OpenMP runtime; any falsy value (``False``, ``None``,
                     ``0``) makes this function return without touching the AST.
        collapse: number of nested loops to include in parallel region (see OpenMP collapse)
        assume_single_outer_loop: if True an exception is raised if multiple outer loops are detected for all but
                                  optimized staggered kernels the single-outer-loop assumption should be true

    Raises:
        ValueError: if ``assume_single_outer_loop`` is set and more than one outermost loop is found
    """
    if not num_threads:
        return

    assert type(ast_node) is ast.KernelFunction
    body = ast_node.body

    # num_threads is truthy here, so a bool can only be True: "let the OpenMP runtime decide".
    runtime_chooses_threads = isinstance(num_threads, bool)
    threads_clause = "" if runtime_chooses_threads else f" num_threads({num_threads})"
    wrapper_block = ast.PragmaBlock('#pragma omp parallel' + threads_clause, body.take_child_nodes())
    body.append(wrapper_block)

    outer_loops = [loop for loop in filtered_tree_iteration(body, LoopOverCoordinate, stop_type=SympyAssignment)
                   if loop.is_outermost_loop]
    assert outer_loops, "No outer loop found"
    if assume_single_outer_loop and len(outer_loops) > 1:
        raise ValueError("More than one outer loop found, only one outer loop expected")

    # Thread count used by the "outer loop too short" heuristic below. Comparing against the literal
    # `True` (== 1) would disable the heuristic for the default argument, so the CPU count serves as
    # an estimate of what the runtime will use.
    if runtime_chooses_threads:
        import multiprocessing
        expected_threads = multiprocessing.cpu_count()
    else:
        expected_threads = num_threads

    for loop_to_parallelize in outer_loops:
        # Loop bounds may be symbolic; int() then raises TypeError and the range is treated as unknown.
        try:
            loop_range = int(loop_to_parallelize.stop - loop_to_parallelize.start)
        except TypeError:
            loop_range = None

        if loop_range is not None and loop_range < expected_threads and not collapse:
            contained_loops = [node for node in loop_to_parallelize.body.args
                               if isinstance(node, LoopOverCoordinate)]
            if len(contained_loops) == 1:
                contained_loop = contained_loops[0]
                try:
                    contained_loop_range = int(contained_loop.stop - contained_loop.start)
                except TypeError:
                    # Symbolic inner range: keep parallelizing the outer loop.
                    contained_loop_range = None
                if contained_loop_range is not None and contained_loop_range > loop_range:
                    loop_to_parallelize = contained_loop

        prefix = f"#pragma omp for schedule({schedule})"
        if collapse:
            prefix += " collapse(%d)" % (collapse, )
        loop_to_parallelize.prefix_lines.append(prefix)