kernelcreation.py 10.2 KB
Newer Older
Martin Bauer's avatar
Martin Bauer committed
1
import sympy as sp
Martin Bauer's avatar
Martin Bauer committed
2
from functools import partial
3
from pystencils.astnodes import SympyAssignment, Block, LoopOverCoordinate, KernelFunction
Martin Bauer's avatar
Martin Bauer committed
4
from pystencils.transformations import resolve_buffer_accesses, resolve_field_accesses, make_loop_over_domain, \
5
    add_types, get_optimal_loop_ordering, parse_base_pointer_info, move_constants_before_loop, \
Martin Bauer's avatar
Martin Bauer committed
6
    split_inner_loop, substitute_array_accesses_with_constants, get_base_buffer_index
Martin Bauer's avatar
Martin Bauer committed
7
from pystencils.data_types import TypedSymbol, BasicType, StructType, create_type
8
from pystencils.field import Field, FieldType
Martin Bauer's avatar
Martin Bauer committed
9
import pystencils.astnodes as ast
Martin Bauer's avatar
Martin Bauer committed
10
11
12
from pystencils.cpu.cpujit import make_python_function
from pystencils.assignment import Assignment
from typing import List, Union
Martin Bauer's avatar
Martin Bauer committed
13

Martin Bauer's avatar
Martin Bauer committed
14
# Type alias for the kernel-creation inputs below: a list whose items are either Assignment objects
# or pystencils AST nodes.
AssignmentOrAstNodeList = List[Union[Assignment, ast.Node]]
Martin Bauer's avatar
Martin Bauer committed
15

Martin Bauer's avatar
Martin Bauer committed
16
17

def create_kernel(assignments: AssignmentOrAstNodeList, function_name: str = "kernel", type_info='double',
                  split_groups=(), iteration_slice=None, ghost_layers=None,
                  skip_independence_check=False) -> KernelFunction:
    """Creates an abstract syntax tree for a kernel function, by taking a list of update rules.

    Loops are created according to the field accesses in the equations.

    Args:
        assignments: list of sympy equations, containing accesses to :class:`pystencils.field.Field`,
                     defining the update rules of the kernel
        function_name: name of the generated function - only important if generated code is written out
        type_info: either a single C type name used for all symbols (default 'double'), or a map from symbol
                   name to a C type specifier. If not specified all symbols are assumed to be of type 'double'
                   except symbols which occur on the left hand side of equations where the right hand side is
                   a sympy Boolean which are assumed to be 'bool'.
        split_groups: Specification on how to split up inner loop into multiple loops. For details see
                      transformation :func:`pystencils.transformation.split_inner_loop`
        iteration_slice: if not None, iteration is done only over this slice of the field
        ghost_layers: a sequence of pairs for each coordinate with lower and upper nr of ghost layers;
                      if None, the number of ghost layers is determined automatically and assumed to be equal
                      for all dimensions
        skip_independence_check: don't check that loop iterations are independent. This is needed e.g. for
                                 periodicity kernel, that access the field outside the iteration bounds. Use with care!

    Returns:
        AST node representing a function, that can be printed as C or CUDA code. Its ``compile`` attribute
        is bound to :func:`make_python_function`.

    Raises:
        ValueError: if an entry of ``split_groups`` is neither a field access nor a symbol
    """

    def type_symbol(term):
        # Field accesses and already-typed symbols are passed through unchanged
        if isinstance(term, (Field.Access, TypedSymbol)):
            return term
        elif isinstance(term, sp.Symbol):
            # A str also has __getitem__, so it has to be excluded explicitly: type_info='double' means
            # "one type for every symbol", not a lookup table indexed by symbol name.
            if isinstance(type_info, str) or not hasattr(type_info, '__getitem__'):
                return TypedSymbol(term.name, create_type(type_info))
            else:
                return TypedSymbol(term.name, type_info[term.name])
        else:
            raise ValueError("Term has to be field access or symbol")

    fields_read, fields_written, assignments = add_types(assignments, type_info, not skip_independence_check)
    all_fields = fields_read.union(fields_written)
    read_only_fields = {f.name for f in fields_read - fields_written}

    # Buffer fields do not take part in the loop-order computation; they get their own base pointer spec below
    buffers = {f for f in all_fields if FieldType.is_buffer(f)}
    fields_without_buffers = all_fields - buffers

    body = ast.Block(assignments)
    loop_order = get_optimal_loop_ordering(fields_without_buffers)
    ast_node = make_loop_over_domain(body, function_name, iteration_slice=iteration_slice,
                                     ghost_layers=ghost_layers, loop_order=loop_order)
    ast_node.target = 'cpu'

    if split_groups:
        typed_split_groups = [[type_symbol(s) for s in split_group] for split_group in split_groups]
        split_inner_loop(ast_node, typed_split_groups)

    # Base pointers for the innermost loop, and additionally the second-innermost one if there are >= 2 loops
    base_pointer_spec = [['spatialInner0'], ['spatialInner1']] if len(loop_order) >= 2 else [['spatialInner0']]
    base_pointer_info = {field.name: parse_base_pointer_info(base_pointer_spec, loop_order,
                                                             field.spatial_dimensions, field.index_dimensions)
                         for field in fields_without_buffers}

    buffer_base_pointer_info = {field.name: parse_base_pointer_info([['spatialInner0']], [0],
                                                                    field.spatial_dimensions, field.index_dimensions)
                                for field in buffers}
    base_pointer_info.update(buffer_base_pointer_info)

    if buffers:
        resolve_buffer_accesses(ast_node, get_base_buffer_index(ast_node), read_only_fields)
    resolve_field_accesses(ast_node, read_only_fields, field_to_base_pointer_info=base_pointer_info)
    substitute_array_accesses_with_constants(ast_node)
    move_constants_before_loop(ast_node)
    ast_node.compile = partial(make_python_function, ast_node)
    return ast_node
Martin Bauer's avatar
Martin Bauer committed
89
90


Martin Bauer's avatar
Martin Bauer committed
91
92
def create_indexed_kernel(assignments: AssignmentOrAstNodeList, index_fields, function_name="kernel",
                          type_info=None, coordinate_names=('x', 'y', 'z')) -> KernelFunction:
    """Build a CPU kernel that visits only the cells listed in index fields.

    Unlike :func:`create_kernel`, which iterates over all cells of a field, the generated kernel runs one loop
    over the entries of an index field. Each entry is a struct that stores the coordinates of a cell to update;
    this traversal is useful e.g. for boundary handling.

    The index field is a one dimensional array with struct data type. The struct must contain members named
    'x', 'y' and, for 3D fields, 'z'; these names can be changed via ``coordinate_names``. Further struct
    members (for example boundary parameters) can be read and written in the kernel.

    Args:
        assignments: list of assignments
        index_fields: list of index fields, i.e. 1D fields with struct data type. The ``field_type`` of every
                      passed field is set to ``FieldType.INDEXED`` by this function.
        function_name: see documentation of :func:`create_kernel`
        type_info: see documentation of :func:`create_kernel`
        coordinate_names: names of the coordinate members in the struct data type

    Returns:
        AST node of the kernel function; its ``compile`` attribute is bound to :func:`make_python_function`.

    Raises:
        ValueError: if a required coordinate member is found in none of the index fields
    """
    fields_read, fields_written, assignments = add_types(assignments, type_info, check_independence_condition=False)
    all_fields = fields_read.union(fields_written)

    # Tag the passed fields as index fields (this mutates the caller's field objects)
    for candidate_field in index_fields:
        candidate_field.field_type = FieldType.INDEXED
        assert FieldType.is_indexed(candidate_field)
        assert candidate_field.spatial_dimensions == 1, "Index fields have to be 1D"

    regular_fields = [f for f in all_fields if f not in index_fields]
    dimension_counts = {f.spatial_dimensions for f in regular_fields}
    assert len(dimension_counts) == 1, "Non-index fields do not have the same number of spatial coordinates"
    dim = list(dimension_counts)[0]

    def coordinate_assignment(coordinate_name):
        """Load struct member ``coordinate_name`` of the current index entry into a typed symbol."""
        for struct_field in index_fields:
            assert isinstance(struct_field.dtype, StructType), "Index fields have to have a struct data type"
            struct_type = struct_field.dtype
            if not struct_type.has_element(coordinate_name):
                continue
            loaded_value = struct_field[0](coordinate_name)
            target_symbol = TypedSymbol(coordinate_name, BasicType(struct_type.get_element_type(coordinate_name)))
            return SympyAssignment(target_symbol, loaded_value)
        raise ValueError("Index %s not found in any of the passed index fields" % (coordinate_name,))

    coordinate_loads = [coordinate_assignment(c) for c in coordinate_names[:dim]]
    coordinate_symbols = [load.lhs for load in coordinate_loads]
    assignments = coordinate_loads + assignments

    # A single 1D loop over the entries of the first index field
    inner_block = Block([])
    index_loop = LoopOverCoordinate(inner_block, coordinate_to_loop_over=0, start=0, stop=index_fields[0].shape[0])
    for statement in assignments:
        inner_block.append(statement)

    ast_node = KernelFunction(Block([index_loop]), backend="cpu", function_name=function_name)

    # Non-index fields are addressed via the coordinate symbols loaded from the index entry
    fixed_coordinate_mapping = {f.name: coordinate_symbols for f in regular_fields}

    read_only_fields = {f.name for f in fields_read - fields_written}
    resolve_field_accesses(ast_node, read_only_fields, field_to_fixed_coordinates=fixed_coordinate_mapping)
    substitute_array_accesses_with_constants(ast_node)
    move_constants_before_loop(ast_node)
    ast_node.compile = partial(make_python_function, ast_node)
    return ast_node
155

Martin Bauer's avatar
Martin Bauer committed
156

Martin Bauer's avatar
Martin Bauer committed
157
158
159
160
161
162
163
def add_openmp(ast_node, schedule="static", num_threads=True):
    """Parallelize the outer loop with OpenMP.

    The kernel body is wrapped into a ``#pragma omp parallel`` region and a ``#pragma omp for`` is attached
    to the outermost loop. If the outermost loop has a constant range smaller than the (expected) number of
    threads and directly contains exactly one loop with a larger constant range, that inner loop is
    parallelized instead.

    Args:
        ast_node: abstract syntax tree created e.g. by :func:`create_kernel`
        schedule: OpenMP scheduling policy e.g. 'static' or 'dynamic'
        num_threads: explicitly specify number of threads. ``True`` (default) emits no ``num_threads`` clause
                     and assumes ``multiprocessing.cpu_count()`` threads for the loop-selection heuristic.
                     Any other truthy value is emitted verbatim in the ``num_threads(...)`` clause.
                     A falsy value (``False``, ``0``, ``None``) leaves the AST unmodified.
    """
    if not num_threads:
        return

    assert type(ast_node) is ast.KernelFunction
    body = ast_node.body
    # num_threads=True -> let OpenMP pick the thread count; otherwise pass the given value through
    threads_clause = "" if isinstance(num_threads, bool) else " num_threads(%s)" % (num_threads,)
    wrapper_block = ast.PragmaBlock('#pragma omp parallel' + threads_clause, body.take_child_nodes())
    body.append(wrapper_block)

    outer_loops = [loop for loop in body.atoms(ast.LoopOverCoordinate) if loop.is_outermost_loop]
    assert outer_loops, "No outer loop found"
    assert len(outer_loops) <= 1, "More than one outer loop found. Not clear where to put OpenMP pragma."
    loop_to_parallelize = outer_loops[0]
    try:
        loop_range = int(loop_to_parallelize.stop - loop_to_parallelize.start)
    except TypeError:
        # bounds are symbolic - range not known at code generation time
        loop_range = None

    # Numeric thread count for the heuristic below. Comparing against num_threads directly would compare
    # against True (== 1) in the default case, or fail for a non-numeric clause value such as a C expression.
    if isinstance(num_threads, bool):
        import multiprocessing
        expected_threads = multiprocessing.cpu_count()
    else:
        try:
            expected_threads = int(num_threads)
        except (TypeError, ValueError):
            expected_threads = None

    if loop_range is not None and expected_threads is not None and loop_range < expected_threads:
        contained_loops = [loop for loop in loop_to_parallelize.body.args if isinstance(loop, LoopOverCoordinate)]
        if len(contained_loops) == 1:
            contained_loop = contained_loops[0]
            try:
                contained_loop_range = int(contained_loop.stop - contained_loop.start)
            except TypeError:
                contained_loop_range = None
            # Parallelize the inner loop only if it offers more iterations to distribute
            if contained_loop_range is not None and contained_loop_range > loop_range:
                loop_to_parallelize = contained_loop

    loop_to_parallelize.prefix_lines.append("#pragma omp for schedule(%s)" % (schedule,))