kernelcreation.py 11 KB
Newer Older
Martin Bauer's avatar
Martin Bauer committed
1
2
from typing import List, Union

Martin Bauer's avatar
Martin Bauer committed
3
import sympy as sp
Martin Bauer's avatar
Martin Bauer committed
4

Martin Bauer's avatar
Martin Bauer committed
5
import pystencils.astnodes as ast
Martin Bauer's avatar
Martin Bauer committed
6
from pystencils.assignment import Assignment
Martin Bauer's avatar
Martin Bauer committed
7
8
9
10
11
12
from pystencils.astnodes import Block, KernelFunction, LoopOverCoordinate, SympyAssignment
from pystencils.cpu.cpujit import make_python_function
from pystencils.data_types import BasicType, StructType, TypedSymbol, create_type
from pystencils.field import Field, FieldType
from pystencils.transformations import (
    add_types, filtered_tree_iteration, get_base_buffer_index, get_optimal_loop_ordering,
13
14
    implement_interpolations, make_loop_over_domain, move_constants_before_loop,
    parse_base_pointer_info, resolve_buffer_accesses, resolve_field_accesses, split_inner_loop)
Martin Bauer's avatar
Martin Bauer committed
15

Martin Bauer's avatar
Martin Bauer committed
16
# Type alias for the kernel input: a list whose entries are either Assignments or AST nodes.
AssignmentOrAstNodeList = List[Union[Assignment, ast.Node]]
Martin Bauer's avatar
Martin Bauer committed
17

Martin Bauer's avatar
Martin Bauer committed
18
19

def create_kernel(assignments: AssignmentOrAstNodeList, function_name: str = "kernel", type_info='double',
                  split_groups=(), iteration_slice=None, ghost_layers=None,
                  skip_independence_check=False) -> KernelFunction:
    """Build the abstract syntax tree of a kernel function from a list of update rules.

    The loop nest is generated from the field accesses that occur in the assignments.

    Args:
        assignments: list of sympy equations containing accesses to :class:`pystencils.field.Field`;
                     these define the update rules of the kernel
        function_name: name of the generated function - only relevant if the generated code is written out
        type_info: a map from symbol name to a C type specifier. If not specified all symbols are assumed
                   to be of type 'double', except symbols on the left hand side of equations whose right
                   hand side is a sympy Boolean, which are assumed to be 'bool'.
        split_groups: specification on how to split up the inner loop into multiple loops. For details see
                      transformation :func:`pystencils.transformation.split_inner_loop`
        iteration_slice: if not None, iteration is done only over this slice of the field
        ghost_layers: a sequence of pairs for each coordinate with lower and upper number of ghost layers;
                      if None, the number of ghost layers is determined automatically and assumed to be
                      equal for all dimensions
        skip_independence_check: don't check that loop iterations are independent. This is needed e.g. for
                                 periodicity kernels, which access the field outside the iteration bounds.
                                 Use with care!

    Returns:
        AST node representing a function, that can be printed as C or CUDA code
    """
    def to_typed_symbol(term):
        # Field accesses and already typed symbols need no further treatment.
        if isinstance(term, (Field.Access, TypedSymbol)):
            return term
        if not isinstance(term, sp.Symbol):
            raise ValueError("Term has to be field access or symbol")

        # A non-string object supporting item access is treated as a per-symbol type map;
        # everything else is a single type specifier applied to every symbol.
        is_type_map = hasattr(type_info, '__getitem__') and not isinstance(type_info, str)
        if is_type_map:
            return TypedSymbol(term.name, type_info[term.name])
        return TypedSymbol(term.name, create_type(type_info))

    fields_read, fields_written, assignments = add_types(assignments, type_info, not skip_independence_check)
    all_fields = fields_read.union(fields_written)
    read_only_fields = {field.name for field in fields_read - fields_written}

    buffers = {field for field in all_fields if FieldType.is_buffer(field)}
    fields_without_buffers = all_fields - buffers

    body = ast.Block(assignments)
    loop_order = get_optimal_loop_ordering(fields_without_buffers)
    loop_node, ghost_layer_info = make_loop_over_domain(body, iteration_slice=iteration_slice,
                                                        ghost_layers=ghost_layers, loop_order=loop_order)
    ast_node = KernelFunction(loop_node, 'cpu', 'c', compile_function=make_python_function,
                              ghost_layers=ghost_layer_info, function_name=function_name)
    implement_interpolations(body)

    if split_groups:
        typed_split_groups = []
        for split_group in split_groups:
            typed_split_groups.append([to_typed_symbol(symbol) for symbol in split_group])
        split_inner_loop(ast_node, typed_split_groups)

    if len(loop_order) >= 2:
        base_pointer_spec = [['spatialInner0'], ['spatialInner1']]
    else:
        base_pointer_spec = [['spatialInner0']]

    # Regular fields are registered first; buffer entries are written afterwards so that they take
    # precedence on a name clash, exactly like a dict.update() with the buffer entries would.
    base_pointer_info = {}
    for field in fields_without_buffers:
        base_pointer_info[field.name] = parse_base_pointer_info(base_pointer_spec, loop_order,
                                                                field.spatial_dimensions, field.index_dimensions)
    for field in buffers:
        base_pointer_info[field.name] = parse_base_pointer_info([['spatialInner0']], [0],
                                                                field.spatial_dimensions, field.index_dimensions)

    # 'buffers' holds exactly the fields of all_fields that are buffers, so it is non-empty
    # precisely when at least one buffer field takes part in the kernel.
    if buffers:
        resolve_buffer_accesses(ast_node, get_base_buffer_index(ast_node), read_only_fields)
    resolve_field_accesses(ast_node, read_only_fields, field_to_base_pointer_info=base_pointer_info)
    move_constants_before_loop(ast_node)
    return ast_node
Martin Bauer's avatar
Martin Bauer committed
90
91


Martin Bauer's avatar
Martin Bauer committed
92
93
def create_indexed_kernel(assignments: AssignmentOrAstNodeList, index_fields, function_name="kernel",
                          type_info=None, coordinate_names=('x', 'y', 'z')) -> KernelFunction:
    """Create a kernel that updates only the cells listed in an index field.

    Similar to :func:`create_kernel`, but instead of iterating over all cells of a field, only cells whose
    coordinates are stored in an index field are visited. This traversal method can e.g. be used for
    boundary handling.

    The coordinates are stored in a separate index field, which is a one dimensional array with struct data
    type. This struct has to contain fields named 'x', 'y' and for 3D fields ('z'). These names are
    configurable with the 'coordinate_names' parameter. The struct can also have other fields that can be
    read and written in the kernel, for example boundary parameters.

    Args:
        assignments: list of assignments
        index_fields: list of index fields, i.e. 1D fields with struct data type
        type_info: see documentation of :func:`create_kernel`
        function_name: see documentation of :func:`create_kernel`
        coordinate_names: name of the coordinate fields in the struct data type

    Returns:
        the kernel function AST with field accesses resolved and constants moved before the loop

    Raises:
        ValueError: if one of the required coordinate names is not an element of any index field struct
    """
    fields_read, fields_written, assignments = add_types(assignments, type_info, check_independence_condition=False)
    all_fields = fields_read.union(fields_written)
    read_only_fields = {field.name for field in fields_read - fields_written}

    # Note that the passed index fields are modified here: their field type is set to INDEXED.
    for index_field in index_fields:
        index_field.field_type = FieldType.INDEXED
        assert FieldType.is_indexed(index_field)
        assert index_field.spatial_dimensions == 1, "Index fields have to be 1D"

    non_index_fields = [field for field in all_fields if field not in index_fields]
    dimension_counts = {field.spatial_dimensions for field in non_index_fields}
    assert len(dimension_counts) == 1, "Non-index fields do not have the same number of spatial coordinates"
    spatial_coordinates = list(dimension_counts)[0]

    def get_coordinate_symbol_assignment(name):
        # The first index field whose struct has an element called 'name' supplies the coordinate.
        for candidate in index_fields:
            assert isinstance(candidate.dtype, StructType), "Index fields have to have a struct data type"
            struct_type = candidate.dtype
            if not struct_type.has_element(name):
                continue
            coordinate_symbol = TypedSymbol(name, BasicType(struct_type.get_element_type(name)))
            return SympyAssignment(coordinate_symbol, candidate[0](name))
        raise ValueError("Index %s not found in any of the passed index fields" % (name,))

    required_names = coordinate_names[:spatial_coordinates]
    coordinate_symbol_assignments = [get_coordinate_symbol_assignment(name) for name in required_names]
    coordinate_typed_symbols = [assignment.lhs for assignment in coordinate_symbol_assignments]
    # Coordinate loads have to come first, since the user assignments depend on them.
    assignments = coordinate_symbol_assignments + assignments

    # make 1D loop over index fields
    loop_body = Block([])
    loop_node = LoopOverCoordinate(loop_body, coordinate_to_loop_over=0, start=0, stop=index_fields[0].shape[0])

    # NOTE(review): this call happens while the loop body is still empty - the assignments are appended
    # only afterwards. Confirm whether that order is intended.
    implement_interpolations(loop_node)

    for assignment in assignments:
        loop_body.append(assignment)

    function_body = Block([loop_node])
    ast_node = KernelFunction(function_body, "cpu", "c", make_python_function,
                              ghost_layers=None, function_name=function_name)

    # Every non-index field is accessed at the coordinates loaded from the index field.
    fixed_coordinate_mapping = {}
    for field in non_index_fields:
        fixed_coordinate_mapping[field.name] = coordinate_typed_symbols

    resolve_field_accesses(ast_node, read_only_fields, field_to_fixed_coordinates=fixed_coordinate_mapping)
    move_constants_before_loop(ast_node)
    return ast_node
157

Martin Bauer's avatar
Martin Bauer committed
158

159
def add_openmp(ast_node, schedule="static", num_threads=True, collapse=None, assume_single_outer_loop=True):
    """Parallelize the outer loop with OpenMP.

    The kernel body is wrapped into a ``#pragma omp parallel`` block and every outermost loop gets a
    ``#pragma omp for`` prefix line. The passed AST is modified in place; nothing is returned.

    If the trip count of an outer loop is a known integer that is smaller than the number of threads, and
    the loop directly contains exactly one loop with a larger known trip count, that contained loop is
    parallelized instead (only if ``collapse`` is not set).

    Args:
        ast_node: abstract syntax tree created e.g. by :func:`create_kernel`
        schedule: OpenMP scheduling policy e.g. 'static' or 'dynamic'
        num_threads: explicitly specify number of threads. A falsy value (False, None, 0) disables
                     parallelization, i.e. the AST is left untouched. ``True`` emits no ``num_threads``
                     clause; in that case the CPU count reported by ``multiprocessing.cpu_count()`` is
                     used for the decision which loop to parallelize.
        collapse: number of nested loops to include in parallel region (see OpenMP collapse)
        assume_single_outer_loop: if True an exception is raised if multiple outer loops are detected. For
                                  all but optimized staggered kernels the single-outer-loop assumption
                                  should be true.

    Raises:
        ValueError: if ``assume_single_outer_loop`` is True and more than one outer loop is found
    """
    if not num_threads:
        return

    assert type(ast_node) is ast.KernelFunction
    body = ast_node.body

    # num_threads is truthy from here on, so a bool can only be True: no explicit thread count requested.
    thread_count_unspecified = isinstance(num_threads, bool)
    threads_clause = "" if thread_count_unspecified else " num_threads(%s)" % (num_threads,)
    wrapper_block = ast.PragmaBlock('#pragma omp parallel' + threads_clause, body.take_child_nodes())
    body.append(wrapper_block)

    outer_loops = [loop for loop in filtered_tree_iteration(body, LoopOverCoordinate, stop_type=SympyAssignment)
                   if loop.is_outermost_loop]
    assert outer_loops, "No outer loop found"
    if assume_single_outer_loop and len(outer_loops) > 1:
        raise ValueError("More than one outer loop found, only one outer loop expected")

    # Thread count used to judge whether a loop is too short to be worth parallelizing. Previously the
    # CPU-count fallback was guarded by 'num_threads is None', which can never hold after the early
    # return above, so the default num_threads=True was compared as the integer 1.
    if thread_count_unspecified:
        import multiprocessing
        effective_num_threads = multiprocessing.cpu_count()
    else:
        effective_num_threads = num_threads

    for loop_to_parallelize in outer_loops:
        try:
            loop_range = int(loop_to_parallelize.stop - loop_to_parallelize.start)
        except TypeError:
            # Trip count is not convertible to an integer - treat it as unknown.
            loop_range = None

        if loop_range is not None and loop_range < effective_num_threads and not collapse:
            contained_loops = [node for node in loop_to_parallelize.body.args
                               if isinstance(node, LoopOverCoordinate)]
            if len(contained_loops) == 1:
                contained_loop = contained_loops[0]
                try:
                    contained_loop_range = int(contained_loop.stop - contained_loop.start)
                except TypeError:
                    # Unknown trip count of the contained loop: keep the outer loop.
                    contained_loop_range = None
                if contained_loop_range is not None and contained_loop_range > loop_range:
                    loop_to_parallelize = contained_loop

        prefix = "#pragma omp for schedule(%s)" % (schedule,)
        if collapse:
            prefix += " collapse(%d)" % (collapse, )
        loop_to_parallelize.prefix_lines.append(prefix)