kernelcreation.py 8.16 KB
Newer Older
Martin Bauer's avatar
Martin Bauer committed
1
2
3
from pystencils.astnodes import Block, KernelFunction, LoopOverCoordinate, SympyAssignment
from pystencils.data_types import BasicType, StructType, TypedSymbol
from pystencils.field import Field, FieldType
Martin Bauer's avatar
Martin Bauer committed
4
from pystencils.gpucuda.cudajit import make_python_function
Martin Bauer's avatar
Martin Bauer committed
5
6
from pystencils.gpucuda.indexing import BlockIndexing
from pystencils.transformations import (
7
8
9
10
11
12
13
14
15
16
17
18
    add_types, get_base_buffer_index, get_common_shape, implement_interpolations,
    parse_base_pointer_info, resolve_buffer_accesses, resolve_field_accesses, unify_shape_symbols)


def create_cuda_kernel(assignments,
                       function_name="kernel",
                       type_info=None,
                       indexing_creator=BlockIndexing,
                       iteration_slice=None,
                       ghost_layers=None,
                       skip_independence_check=False,
                       use_textures_for_interpolation=True):
    """Create the AST of a CUDA kernel that applies ``assignments`` to the cells of an iteration region.

    Args:
        assignments: non-empty list of assignments forming the kernel body; they are typed via ``add_types``.
        function_name: name of the generated kernel function.
        type_info: type information forwarded to ``add_types``.
        indexing_creator: callable (``BlockIndexing`` by default) invoked as
            ``indexing_creator(field=..., iteration_slice=...)``; the returned object supplies the thread
            coordinates, the guard and the index variables of the kernel.
        iteration_slice: one slice per spatial dimension describing the iteration region. If None, it is
            derived from ``ghost_layers``.
        ghost_layers: only used when ``iteration_slice`` is None. Either an int (same number of ghost layers
            at both ends of every dimension) or a sequence of ``(lower, upper)`` pairs, one per dimension.
            If None, the largest ``required_ghost_layers`` over all relative (non-absolute) field accesses
            is used, or 0 if there are no such accesses.
        skip_independence_check: if True, ``add_types`` is told not to check the independence condition.
        use_textures_for_interpolation: forwarded to ``implement_interpolations`` as
            ``implement_by_texture_accesses``.

    Returns:
        The ``KernelFunction`` AST node; its ``indexing`` attribute holds the indexing object.
    """
    assert assignments, "Assignments must not be empty!"
    fields_read, fields_written, assignments = add_types(assignments, type_info, not skip_independence_check)
    all_fields = fields_read.union(fields_written)
    read_only_fields = {f.name for f in fields_read - fields_written}

    # Buffer and custom fields take no part in the common-shape computation, and none of them
    # drives the indexing.
    buffers = {f for f in all_fields if FieldType.is_buffer(f) or FieldType.is_custom(f)}
    fields_without_buffers = all_fields - buffers

    common_shape = get_common_shape(fields_without_buffers)
    # Field whose layout drives the indexing (and whose shape gives the loop strides for buffers).
    iteration_field = list(fields_without_buffers)[0]

    if iteration_slice is None:
        dim = len(common_shape)
        if ghost_layers is None:
            # Derive the ghost-layer width from the relative (neighbour) field accesses.
            relative_accesses = {access
                                 for assignment in assignments
                                 for access in assignment.atoms(Field.Access)
                                 if not access.is_absolute_access}
            required_ghost_layers = max((fa.required_ghost_layers for fa in relative_accesses), default=0)
            ghost_layers = [(required_ghost_layers, required_ghost_layers)] * dim
        elif isinstance(ghost_layers, int):
            ghost_layers = [(ghost_layers, ghost_layers)] * dim
        # An upper width of 0 must become None: slice(lower, -0) would be an empty range.
        iteration_slice = [slice(ghost_layers[i][0], -ghost_layers[i][1] if ghost_layers[i][1] > 0 else None)
                           for i in range(dim)]

    indexing = indexing_creator(field=iteration_field, iteration_slice=iteration_slice)

    # Bind the loop-counter symbols to the thread coordinates provided by the indexing object and
    # prepend these definitions to the kernel body.
    thread_coordinates = indexing.coordinates
    cell_idx_symbols = [LoopOverCoordinate.get_loop_counter_symbol(i) for i, _ in enumerate(thread_coordinates)]
    cell_idx_assignments = [SympyAssignment(symbol, value)
                            for symbol, value in zip(cell_idx_symbols, thread_coordinates)]

    block = Block(cell_idx_assignments + assignments)
    block = indexing.guard(block, common_shape)
    unify_shape_symbols(block, common_shape=common_shape, fields=fields_without_buffers)

    ast = KernelFunction(block, 'gpu', 'gpucuda', make_python_function, ghost_layers, function_name)
    ast.global_variables.update(indexing.index_variables)

    implement_interpolations(ast, implement_by_texture_accesses=use_textures_for_interpolation)

    base_pointer_spec = [['spatialInner0']]
    base_pointer_info = {f.name: parse_base_pointer_info(base_pointer_spec, [2, 1, 0],
                                                         f.spatial_dimensions, f.index_dimensions)
                         for f in all_fields}
    # Every field is addressed through the same loop-counter symbols.
    field_to_cell_idx = {f.name: cell_idx_symbols for f in all_fields}

    if any(FieldType.is_buffer(f) for f in all_fields):
        loop_strides = iteration_field.shape
        resolve_buffer_accesses(ast, get_base_buffer_index(ast, indexing.coordinates, loop_strides), read_only_fields)

    resolve_field_accesses(ast, read_only_fields, field_to_base_pointer_info=base_pointer_info,
                           field_to_fixed_coordinates=field_to_cell_idx)

    # If loop counter symbols have been explicitly used in the update equations (e.g. for built in periodicity)
    # and are still undefined, define them here from the corresponding thread coordinate.
    undefined_loop_counters = {}
    for symbol in ast.body.undefined_symbols:
        coordinate_idx = LoopOverCoordinate.is_loop_counter_symbol(symbol)
        if coordinate_idx is not None:
            undefined_loop_counters[coordinate_idx] = symbol
    for i, loop_counter in undefined_loop_counters.items():
        ast.body.insert_front(SympyAssignment(loop_counter, indexing.coordinates[i]))

    # Attach the indexing object to the KernelFunction node; the JIT uses it to determine #blocks and #threads.
    ast.indexing = indexing
    return ast


99
100
101
102
103
104
105
def created_indexed_cuda_kernel(assignments,
                                index_fields,
                                function_name="kernel",
                                type_info=None,
                                coordinate_names=('x', 'y', 'z'),
                                indexing_creator=BlockIndexing,
                                use_textures_for_interpolation=True):
    """Create the AST of a CUDA kernel that iterates over the entries of 1D index fields.

    The iteration runs over the index field(s). The spatial coordinates used to access all other (non-index)
    fields are read from struct members of the index fields, named by ``coordinate_names``.

    Args:
        assignments: assignments forming the kernel body; typed via ``add_types`` without the independence check.
        index_fields: iterable of 1D fields with a struct dtype. Their ``field_type`` is set to
            ``FieldType.INDEXED`` in place. The iterable is materialized into a list first, so generators work.
        function_name: name of the generated kernel function.
        type_info: type information forwarded to ``add_types``.
        coordinate_names: struct member names holding the spatial coordinates, in dimension order; only the
            first ``spatial_dimensions`` (of the non-index fields) entries are used.
        indexing_creator: callable (``BlockIndexing`` by default) invoked as
            ``indexing_creator(field=..., iteration_slice=...)``.
        use_textures_for_interpolation: forwarded to ``implement_interpolations`` as
            ``implement_by_texture_accesses``.

    Returns:
        The ``KernelFunction`` AST node; its ``indexing`` attribute holds the indexing object.

    Raises:
        ValueError: if a required coordinate name is not a member of any index field's struct dtype.
    """
    # index_fields is traversed several times below (field-type update, membership tests, coordinate lookup,
    # get_common_shape); materialize it so a one-shot iterable is not silently exhausted after the first pass.
    index_fields = list(index_fields)

    fields_read, fields_written, assignments = add_types(assignments, type_info, check_independence_condition=False)
    all_fields = fields_read.union(fields_written)
    read_only_fields = {f.name for f in fields_read - fields_written}

    for index_field in index_fields:
        index_field.field_type = FieldType.INDEXED
        assert FieldType.is_indexed(index_field)
        assert index_field.spatial_dimensions == 1, "Index fields have to be 1D"

    non_index_fields = [f for f in all_fields if f not in index_fields]
    spatial_dimension_set = {f.spatial_dimensions for f in non_index_fields}
    assert len(spatial_dimension_set) == 1, "Non-index fields do not have the same number of spatial coordinates"
    spatial_coordinates = next(iter(spatial_dimension_set))

    def get_coordinate_symbol_assignment(name):
        # Read struct member `name` of the current index-field entry into a typed symbol of the same name,
        # using the first index field whose struct dtype has that member.
        for ind_f in index_fields:
            assert isinstance(ind_f.dtype, StructType), "Index fields have to have a struct data type"
            data_type = ind_f.dtype
            if data_type.has_element(name):
                rhs = ind_f[0](name)
                lhs = TypedSymbol(name, BasicType(data_type.get_element_type(name)))
                return SympyAssignment(lhs, rhs)
        raise ValueError("Index %s not found in any of the passed index fields" % (name,))

    coordinate_symbol_assignments = [get_coordinate_symbol_assignment(n)
                                     for n in coordinate_names[:spatial_coordinates]]
    coordinate_typed_symbols = [eq.lhs for eq in coordinate_symbol_assignments]

    idx_field = index_fields[0]
    indexing = indexing_creator(field=idx_field,
                                iteration_slice=[slice(None, None, None)] * len(idx_field.spatial_shape))

    function_body = Block(coordinate_symbol_assignments + assignments)
    function_body = indexing.guard(function_body, get_common_shape(index_fields))
    ast = KernelFunction(function_body, 'gpu', 'gpucuda', make_python_function, None, function_name)
    ast.global_variables.update(indexing.index_variables)

    implement_interpolations(ast, implement_by_texture_accesses=use_textures_for_interpolation)

    thread_coordinates = indexing.coordinates
    base_pointer_spec = [['spatialInner0']]
    base_pointer_info = {f.name: parse_base_pointer_info(base_pointer_spec, [2, 1, 0],
                                                         f.spatial_dimensions, f.index_dimensions)
                         for f in all_fields}

    # Index fields are addressed by the thread coordinates; every other field by the coordinates read
    # from the current index-field entry.
    field_to_fixed_coordinates = {f.name: thread_coordinates for f in index_fields}
    field_to_fixed_coordinates.update({f.name: coordinate_typed_symbols for f in non_index_fields})
    resolve_field_accesses(ast, read_only_fields, field_to_fixed_coordinates=field_to_fixed_coordinates,
                           field_to_base_pointer_info=base_pointer_info)

    # Attach the indexing object to the KernelFunction node; the JIT uses it to determine #blocks and #threads.
    ast.indexing = indexing
    return ast