kernelcreation.py 10.8 KB
Newer Older
Martin Bauer's avatar
Martin Bauer committed
1
import sympy as sp
2
from pystencils.astnodes import SympyAssignment, Block, LoopOverCoordinate, KernelFunction
Martin Bauer's avatar
Martin Bauer committed
3
from pystencils.transformations import resolve_buffer_accesses, resolve_field_accesses, make_loop_over_domain, \
4
    add_types, get_optimal_loop_ordering, parse_base_pointer_info, move_constants_before_loop, \
5
    split_inner_loop, get_base_buffer_index, filtered_tree_iteration
Martin Bauer's avatar
Martin Bauer committed
6
from pystencils.data_types import TypedSymbol, BasicType, StructType, create_type
7
from pystencils.field import Field, FieldType
Martin Bauer's avatar
Martin Bauer committed
8
import pystencils.astnodes as ast
Martin Bauer's avatar
Martin Bauer committed
9
10
11
from pystencils.cpu.cpujit import make_python_function
from pystencils.assignment import Assignment
from typing import List, Union
Martin Bauer's avatar
Martin Bauer committed
12

Martin Bauer's avatar
Martin Bauer committed
13
# Type alias for the ``assignments`` argument of the kernel creation functions below:
# a list whose entries are either Assignment objects or AST nodes.
AssignmentOrAstNodeList = List[Union[Assignment, ast.Node]]
Martin Bauer's avatar
Martin Bauer committed
14

Martin Bauer's avatar
Martin Bauer committed
15
16

def create_kernel(assignments: AssignmentOrAstNodeList, function_name: str = "kernel", type_info='double',
                  split_groups=(), iteration_slice=None, ghost_layers=None,
                  skip_independence_check=False) -> KernelFunction:
    """Creates an abstract syntax tree for a kernel function, by taking a list of update rules.

    Loops are created according to the field accesses in the equations.

    Args:
        assignments: list of sympy equations, containing accesses to :class:`pystencils.field.Field`.
                     Defining the update rules of the kernel
        function_name: name of the generated function - only important if generated code is written out
        type_info: either a single type specifier (e.g. the string 'double') that is used for all symbols,
                   or a map from symbol name to a C type specifier. If not specified all symbols are assumed to
                   be of type 'double' except symbols which occur on the left hand side of equations where the
                   right hand side is a sympy Boolean which are assumed to be 'bool' .
        split_groups: Specification on how to split up inner loop into multiple loops. For details see
                      transformation :func:`pystencils.transformation.split_inner_loop`
        iteration_slice: if not None, iteration is done only over this slice of the field
        ghost_layers: a sequence of pairs for each coordinate with lower and upper nr of ghost layers
                      if None, the number of ghost layers is determined automatically and assumed to be equal
                      for all dimensions
        skip_independence_check: don't check that loop iterations are independent. This is needed e.g. for
                                 periodicity kernel, that access the field outside the iteration bounds.
                                 Use with care!

    Returns:
        AST node representing a function, that can be printed as C or CUDA code

    Raises:
        ValueError: if an entry of ``split_groups`` is neither a field access nor a symbol
    """

    def type_symbol(term):
        """Return ``term`` as a typed symbol; field accesses and typed symbols pass through unchanged."""
        if isinstance(term, (Field.Access, TypedSymbol)):
            return term
        if isinstance(term, sp.Symbol):
            # A plain string such as 'double' also provides __getitem__, so it has to be
            # excluded explicitly before type_info is treated as a name -> type mapping.
            if isinstance(type_info, str) or not hasattr(type_info, '__getitem__'):
                return TypedSymbol(term.name, create_type(type_info))
            return TypedSymbol(term.name, type_info[term.name])
        raise ValueError("Term has to be field access or symbol")

    fields_read, fields_written, assignments = add_types(assignments, type_info, not skip_independence_check)
    all_fields = fields_read.union(fields_written)
    read_only_fields = {f.name for f in fields_read - fields_written}

    # Buffer fields are excluded from loop ordering and get their own base pointer handling below.
    buffers = {f for f in all_fields if FieldType.is_buffer(f)}
    fields_without_buffers = all_fields - buffers

    body = ast.Block(assignments)
    loop_order = get_optimal_loop_ordering(fields_without_buffers)
    loop_node, ghost_layer_info = make_loop_over_domain(body, iteration_slice=iteration_slice,
                                                        ghost_layers=ghost_layers, loop_order=loop_order)
    ast_node = KernelFunction(loop_node, 'cpu', 'c', compile_function=make_python_function,
                              ghost_layers=ghost_layer_info, function_name=function_name)

    if split_groups:
        typed_split_groups = [[type_symbol(s) for s in split_group] for split_group in split_groups]
        split_inner_loop(ast_node, typed_split_groups)

    base_pointer_spec = [['spatialInner0'], ['spatialInner1']] if len(loop_order) >= 2 else [['spatialInner0']]
    base_pointer_info = {field.name: parse_base_pointer_info(base_pointer_spec, loop_order,
                                                             field.spatial_dimensions, field.index_dimensions)
                         for field in fields_without_buffers}

    buffer_base_pointer_info = {field.name: parse_base_pointer_info([['spatialInner0']], [0],
                                                                    field.spatial_dimensions, field.index_dimensions)
                                for field in buffers}
    base_pointer_info.update(buffer_base_pointer_info)

    if buffers:
        resolve_buffer_accesses(ast_node, get_base_buffer_index(ast_node), read_only_fields)
    resolve_field_accesses(ast_node, read_only_fields, field_to_base_pointer_info=base_pointer_info)
    move_constants_before_loop(ast_node)
    return ast_node
Martin Bauer's avatar
Martin Bauer committed
87
88


Martin Bauer's avatar
Martin Bauer committed
89
90
def create_indexed_kernel(assignments: AssignmentOrAstNodeList, index_fields, function_name="kernel",
                          type_info=None, coordinate_names=('x', 'y', 'z')) -> KernelFunction:
    """
    Variant of :func:`create_kernel` that updates only those cells whose coordinates are listed in an
    index field, instead of all cells of a field. This traversal method can e.g. be used for boundary handling.

    The coordinates come from a separate index field: a one dimensional array with struct data type.
    The struct has to contain fields named 'x', 'y' and for 3D fields ('z'). These names are configurable with
    the 'coordinate_names' parameter. The struct may hold further fields that can be read and written in the
    kernel, for example boundary parameters.

    Args:
        assignments: list of assignments
        index_fields: list of index fields, i.e. 1D fields with struct data type
        type_info: see documentation of :func:`create_kernel`
        function_name: see documentation of :func:`create_kernel`
        coordinate_names: name of the coordinate fields in the struct data type
    """
    fields_read, fields_written, assignments = add_types(assignments, type_info, check_independence_condition=False)
    all_fields = fields_read.union(fields_written)

    # Every passed index field is marked as INDEXED and has to be one dimensional.
    for idx in index_fields:
        idx.field_type = FieldType.INDEXED
        assert FieldType.is_indexed(idx)
        assert idx.spatial_dimensions == 1, "Index fields have to be 1D"

    non_index_fields = [field for field in all_fields if field not in index_fields]
    dimension_set = {field.spatial_dimensions for field in non_index_fields}
    assert len(dimension_set) == 1, "Non-index fields do not have the same number of spatial coordinates"
    spatial_coordinates = list(dimension_set)[0]

    def coordinate_assignment_for(name):
        """Build the assignment that loads coordinate ``name`` from the first index field providing it."""
        for candidate in index_fields:
            assert isinstance(candidate.dtype, StructType), "Index fields have to have a struct data type"
            struct_type = candidate.dtype
            if not struct_type.has_element(name):
                continue
            rhs = candidate[0](name)
            lhs = TypedSymbol(name, BasicType(struct_type.get_element_type(name)))
            return SympyAssignment(lhs, rhs)
        raise ValueError("Index %s not found in any of the passed index fields" % (name,))

    coordinate_symbol_assignments = []
    for coordinate_name in coordinate_names[:spatial_coordinates]:
        coordinate_symbol_assignments.append(coordinate_assignment_for(coordinate_name))
    coordinate_typed_symbols = [assignment.lhs for assignment in coordinate_symbol_assignments]
    assignments = coordinate_symbol_assignments + assignments

    # A single 1D loop running over the entries of the first index field.
    loop_body = Block([])
    loop_node = LoopOverCoordinate(loop_body, coordinate_to_loop_over=0, start=0, stop=index_fields[0].shape[0])
    for entry in assignments:
        loop_body.append(entry)

    ast_node = KernelFunction(Block([loop_node]), "cpu", "c", make_python_function,
                              ghost_layers=None, function_name=function_name)

    # All non-index fields are accessed at the coordinates loaded from the index field.
    fixed_coordinate_mapping = {}
    for field in non_index_fields:
        fixed_coordinate_mapping[field.name] = coordinate_typed_symbols

    read_only_fields = {field.name for field in fields_read - fields_written}
    resolve_field_accesses(ast_node, read_only_fields, field_to_fixed_coordinates=fixed_coordinate_mapping)
    move_constants_before_loop(ast_node)
    return ast_node
152

Martin Bauer's avatar
Martin Bauer committed
153

154
def add_openmp(ast_node, schedule="static", num_threads=True, collapse=None, assume_single_outer_loop=True):
    """Parallelize the outer loop with OpenMP.

    The children of the kernel body are wrapped in a ``#pragma omp parallel`` block and every outermost
    loop gets a ``#pragma omp for`` prefix line. If the outer loop has a known range that is smaller than
    the number of threads and it directly contains exactly one loop with a larger known range, that
    contained loop is parallelized instead (only when ``collapse`` is not set).

    Args:
        ast_node: abstract syntax tree created e.g. by :func:`create_kernel`
        schedule: OpenMP scheduling policy e.g. 'static' or 'dynamic'
        num_threads: explicitly specify number of threads. A false value (False, None, 0) leaves the AST
                     untouched. True parallelizes without a ``num_threads`` clause; in this case the
                     CPU count of the current machine is used for the short-outer-loop heuristic.
                     Any other value is inserted into a ``num_threads(...)`` clause.
        collapse: number of nested loops to include in parallel region (see OpenMP collapse)
        assume_single_outer_loop: if True an exception is raised if multiple outer loops are detected for all but
                                  optimized staggered kernels the single-outer-loop assumption should be true

    Raises:
        ValueError: if ``assume_single_outer_loop`` is True and more than one outer loop is found
    """
    if not num_threads:
        return

    assert type(ast_node) is ast.KernelFunction
    body = ast_node.body

    if isinstance(num_threads, bool):
        # No explicit thread count was given: emit no num_threads clause and estimate the
        # number of threads from the machine for the heuristic below.
        threads_clause = ""
        import multiprocessing
        try:
            thread_count = multiprocessing.cpu_count()
        except NotImplementedError:
            thread_count = 1
    else:
        threads_clause = " num_threads(%s)" % (num_threads,)
        thread_count = num_threads

    wrapper_block = ast.PragmaBlock('#pragma omp parallel' + threads_clause, body.take_child_nodes())
    body.append(wrapper_block)

    outer_loops = [loop for loop in filtered_tree_iteration(body, LoopOverCoordinate, stop_type=SympyAssignment)
                   if loop.is_outermost_loop]
    assert outer_loops, "No outer loop found"
    if assume_single_outer_loop and len(outer_loops) > 1:
        raise ValueError("More than one outer loop found, only one outer loop expected")

    for loop_to_parallelize in outer_loops:
        # Loop bounds may be symbolic, in which case the range is unknown.
        try:
            loop_range = int(loop_to_parallelize.stop - loop_to_parallelize.start)
        except TypeError:
            loop_range = None

        if loop_range is not None and loop_range < thread_count and not collapse:
            # Outer loop is too short to keep all threads busy: prefer the single directly
            # contained loop if its range is known and larger.
            contained_loops = [node for node in loop_to_parallelize.body.args
                               if isinstance(node, LoopOverCoordinate)]
            if len(contained_loops) == 1:
                contained_loop = contained_loops[0]
                try:
                    contained_loop_range = int(contained_loop.stop - contained_loop.start)
                except TypeError:
                    pass
                else:
                    if contained_loop_range > loop_range:
                        loop_to_parallelize = contained_loop

        prefix = "#pragma omp for schedule(%s)" % (schedule,)
        if collapse:
            prefix += " collapse(%d)" % (collapse, )
        loop_to_parallelize.prefix_lines.append(prefix)