vectorization.py 12.8 KB
Newer Older
1
import warnings
Martin Bauer's avatar
Martin Bauer committed
2
3
4
from typing import Container, Union

import sympy as sp
5
from sympy.logic.boolalg import BooleanFunction
Martin Bauer's avatar
Martin Bauer committed
6

7
import pystencils.astnodes as ast
Martin Bauer's avatar
Martin Bauer committed
8
9
from pystencils.backends.simd_instruction_sets import get_vector_instruction_set
from pystencils.data_types import (
10
    PointerType, TypedSymbol, VectorType, cast_func, collate_types, get_type_of_expression, vector_memory_access)
Martin Bauer's avatar
Martin Bauer committed
11
from pystencils.fast_approximation import fast_division, fast_inv_sqrt, fast_sqrt
Martin Bauer's avatar
Martin Bauer committed
12
from pystencils.field import Field
Martin Bauer's avatar
Martin Bauer committed
13
14
from pystencils.integer_functions import modulo_ceil, modulo_floor
from pystencils.sympyextensions import fast_subs
15
from pystencils.transformations import cut_loop, filtered_tree_iteration, replace_inner_stride_with_one
Martin Bauer's avatar
Martin Bauer committed
16
17


18
19
# noinspection PyPep8Naming
class vec_any(sp.Function):
    """Symbolic one-argument marker function wrapped around vectorized conditions.

    ``mask_conditionals`` replaces the condition of a conditional that depends on the
    vectorized loop counter by ``vec_any(condition)``.
    NOTE(review): presumably this reads as "condition holds in at least one vector lane";
    the actual lowering is defined by the code generation backend - confirm there.
    """

    # the function takes exactly one argument
    nargs = (1,)
21
22
23
24


# noinspection PyPep8Naming
class vec_all(sp.Function):
    """Symbolic one-argument marker function for conditions on vector values.

    Conditions already wrapped in ``vec_all`` (or ``vec_any``) are left untouched by
    ``mask_conditionals``, and ``insert_vector_casts`` treats the function like an ordinary
    operation whose arguments may need a scalar-to-vector cast.
    NOTE(review): presumably this reads as "condition holds in every vector lane";
    the actual lowering is defined by the code generation backend - confirm there.
    """

    # the function takes exactly one argument
    nargs = (1,)
26
27


28
def vectorize(kernel_ast: ast.KernelFunction, instruction_set: str = 'avx',
              assume_aligned: bool = False, nontemporal: Union[bool, Container[Union[str, Field]]] = False,
              assume_inner_stride_one: bool = False, assume_sufficient_line_padding: bool = True):
    """Explicit vectorization using SIMD vectorization via intrinsics.

    The kernel AST is modified in place and nothing is returned. If ``instruction_set`` is None the
    function returns immediately without touching the AST.

    Args:
        kernel_ast: abstract syntax tree (KernelFunction node)
        instruction_set: one of the supported vector instruction sets, currently ('sse', 'avx' and 'avx512')
        assume_aligned: assume that the first inner cell of each line is aligned. If false, only unaligned-loads are
                        used. If true, some of the loads are assumed to be from aligned memory addresses.
                        For example if x is the fastest coordinate, the access to center can be fetched via an
                        aligned-load instruction, for the west or east accesses potentially slower unaligned-load
                        instructions have to be used.
        nontemporal: a container of fields or field names for which nontemporal (streaming) stores are used.
                     If true, nontemporal access instructions are used for all fields.
                     None or False disables nontemporal stores.
        assume_inner_stride_one: kernels with non-constant inner loop bound and strides can not be vectorized since
                                 the inner loop stride is a runtime variable and thus might not be always 1.
                                 If this parameter is set to true, the inner stride is assumed to be always one.
                                 This has to be ensured at runtime!
        assume_sufficient_line_padding: if True and assume_aligned is True as well, no tail loop is created but the
                                        loop is extended by at most (vector_width-1) elements.
                                        This assumes that at the end of each line there is enough padding with dummy
                                        data; depending on the access pattern there might be additional padding
                                        required at the end of the array.

    Raises:
        NotImplementedError: if the kernel accesses no floating point field at all, or floating point fields
                             of more than one type
    """
    if instruction_set is None:
        return

    all_fields = kernel_ast.fields_accessed
    if nontemporal is None or nontemporal is False:
        nontemporal = {}
    elif nontemporal is True:
        nontemporal = all_fields

    if assume_inner_stride_one:
        replace_inner_stride_with_one(kernel_ast)

    # The vector instruction set is chosen from the single floating point type used by the fields
    field_float_dtypes = {f.dtype for f in all_fields if f.dtype.is_float()}
    if not field_float_dtypes:
        raise NotImplementedError("Cannot vectorize kernels that do not access any floating point field")
    if len(field_float_dtypes) > 1:
        raise NotImplementedError("Cannot vectorize kernels that contain accesses "
                                  "to differently typed floating point fields")
    float_size = field_float_dtypes.pop().numpy_dtype.itemsize
    assert float_size in (8, 4)
    vector_is = get_vector_instruction_set('double' if float_size == 8 else 'float',
                                           instruction_set=instruction_set)
    vector_width = vector_is['width']
    kernel_ast.instruction_set = vector_is

    vectorize_inner_loops_and_adapt_load_stores(kernel_ast, vector_width, assume_aligned,
                                                nontemporal, assume_sufficient_line_padding)
    insert_vector_casts(kernel_ast)


81
82
def vectorize_inner_loops_and_adapt_load_stores(ast_node, vector_width, assume_aligned, nontemporal_fields,
                                                assume_sufficient_line_padding):
    """Goes over all innermost loops, changes increment to vector width and replaces field accesses by vector type.

    The AST is modified in place. Loops that contain an indexed access in which the loop counter is not a
    plain additive offset are left unvectorized and a warning is emitted.

    Args:
        ast_node: AST node below which the loops are searched
        vector_width: number of elements per vector; becomes the step of the vectorized loops
        assume_aligned: if True, accesses whose index reduces to the loop counter when all loop counters
                        are set to zero are marked as aligned
        nontemporal_fields: container that is queried for the accessed fields and for their names
        assume_sufficient_line_padding: if True together with assume_aligned, the loop end is rounded up to a
                                        multiple of vector_width instead of splitting off a tail loop
    """
    # The result is traversed twice below. Materialize it, otherwise a one-shot iterator would already be
    # exhausted when zero_loop_counters is built, leaving that dict empty.
    all_loops = list(filtered_tree_iteration(ast_node, ast.LoopOverCoordinate, stop_type=ast.SympyAssignment))
    inner_loops = [n for n in all_loops if n.is_innermost_loop]
    zero_loop_counters = {lo.loop_counter_symbol: 0 for lo in all_loops}

    for loop_node in inner_loops:
        loop_range = loop_node.stop - loop_node.start

        if assume_aligned and assume_sufficient_line_padding:
            # no tail loop: extend the loop so that its range is a multiple of the vector width
            loop_node.stop = loop_node.start + modulo_ceil(loop_range, vector_width)
        else:
            # cut off loop tail, that is not a multiple of the vector width
            cutting_point = modulo_floor(loop_range, vector_width) + loop_node.start
            loop_nodes = [lo for lo in cut_loop(loop_node, [cutting_point]).args
                          if isinstance(lo, ast.LoopOverCoordinate)]
            assert len(loop_nodes) in (0, 1, 2)  # 2 for main and tail loop, 1 if loop range divisible by vector width
            if len(loop_nodes) == 0:
                continue
            # only the first loop returned by the cut is vectorized
            loop_node = loop_nodes[0]

        # Find all array accesses (indexed) that depend on the loop counter as offset
        loop_counter_symbol = ast.LoopOverCoordinate.get_loop_counter_symbol(loop_node.coordinate_to_loop_over)
        substitutions = {}
        successful = True
        for indexed in loop_node.atoms(sp.Indexed):
            base, index = indexed.args
            if loop_counter_symbol in index.atoms(sp.Symbol):
                # the counter has to vanish completely after subtracting it once from the index
                loop_counter_is_offset = loop_counter_symbol not in (index - loop_counter_symbol).atoms()
                # aligned if nothing but the loop counter remains once all loop counters are zeroed
                aligned_access = (index - loop_counter_symbol).subs(zero_loop_counters) == 0
                if not loop_counter_is_offset:
                    successful = False
                    break
                typed_symbol = base.label
                assert type(typed_symbol.dtype) is PointerType, \
                    "Type of access is {}, {}".format(typed_symbol.dtype, indexed)

                vec_type = VectorType(typed_symbol.dtype.base_type, vector_width)
                use_aligned_access = aligned_access and assume_aligned
                nontemporal = False
                if hasattr(indexed, 'field'):
                    nontemporal = (indexed.field in nontemporal_fields) or (indexed.field.name in nontemporal_fields)
                # last argument is the access mask, initially True; mask_conditionals refines it later
                substitutions[indexed] = vector_memory_access(indexed, vec_type, use_aligned_access, nontemporal, True)
        if not successful:
            warnings.warn("Could not vectorize loop because of non-consecutive memory access")
            continue

        loop_node.step = vector_width
        loop_node.subs(substitutions)

        # remaining uses of the scalar loop counter become a vector of consecutive counter values
        vector_loop_counter = cast_func(tuple(loop_counter_symbol + i for i in range(vector_width)),
                                        VectorType(loop_counter_symbol.dtype, vector_width))

        fast_subs(loop_node, {loop_counter_symbol: vector_loop_counter},
                  skip=lambda e: isinstance(e, ast.ResolvedFieldAccess) or isinstance(e, vector_memory_access))

        mask_conditionals(loop_node)


def mask_conditionals(loop_body):
    """Turns conditions of conditionals inside a vectorized loop into masks of the vector memory accesses.

    The tree below ``loop_body`` is walked and modified in place. Conditions that depend on the loop counter
    are accumulated (combined with ``sp.And``) while descending into conditionals, and the accumulated mask
    is AND-ed into the fifth argument of every ``vector_memory_access`` found in assignments below.
    Such loop-counter-dependent conditions are finally wrapped in ``vec_any``.

    Args:
        loop_body: vectorized loop node; its ``loop_counter_symbol`` decides which conditions need masking
    """

    def visit_node(node, mask):
        if isinstance(node, ast.Conditional):
            cond = node.condition_expr
            # conditions independent of the loop counter, or already wrapped, do not contribute to the mask
            skip = (loop_body.loop_counter_symbol not in cond.atoms(sp.Symbol)) or cond.func in (vec_all, vec_any)
            cond = True if skip else cond

            true_mask = sp.And(cond, mask)
            visit_node(node.true_block, true_mask)
            if node.false_block:
                false_mask = sp.And(sp.Not(node.condition_expr), mask)
                # descend into the false branch; visiting `node` itself again would recurse without end
                visit_node(node.false_block, false_mask)
            if not skip:
                node.condition_expr = vec_any(node.condition_expr)
        elif isinstance(node, ast.SympyAssignment):
            if mask is not True:
                s = {ma: vector_memory_access(ma.args[0], ma.args[1], ma.args[2], ma.args[3], sp.And(mask, ma.args[4]))
                     for ma in node.atoms(vector_memory_access)}
                node.subs(s)
        else:
            for arg in node.args:
                visit_node(arg, mask)

    visit_node(loop_body, mask=True)

168

Martin Bauer's avatar
Martin Bauer committed
169
170
171
def insert_vector_casts(ast_node):
    """Inserts necessary casts from scalar values to vector values.

    Walks the tree below ``ast_node`` and rewrites assignments and conditions in place: wherever an
    operation mixes vector typed and scalar typed operands, the scalar ones are wrapped in ``cast_func``
    to the collated vector type. Symbols that get a vector valued right-hand side are re-typed and the
    new symbol is substituted in the statements that follow.
    """

    cast_aware_functions = (sp.Add, sp.Mul, fast_division, fast_sqrt, fast_inv_sqrt, vec_any, vec_all)

    def keep_resolved_accesses(e):
        # resolved field accesses are excluded from symbol substitution
        return isinstance(e, ast.ResolvedFieldAccess)

    def cast_to(expressions, dtypes, target):
        # wrap only those expressions whose type differs from the target type
        return [cast_func(e, target) if t != target else e for e, t in zip(expressions, dtypes)]

    def convert(expr):
        if isinstance(expr, vector_memory_access):
            # only the mask (fifth argument) may need casts itself
            return vector_memory_access(*expr.args[:4], convert(expr.args[4]))

        if isinstance(expr, cast_func):
            return expr

        if expr.func is sp.Abs:
            # express the absolute value as a piecewise function and handle that instead
            inner = convert(expr.args[0])
            return convert(sp.Piecewise((-1 * inner, inner < 0), (inner, True)))

        if expr.func in cast_aware_functions or isinstance(expr, (sp.Rel, BooleanFunction)):
            operands = [convert(a) for a in expr.args]
            operand_types = [get_type_of_expression(a) for a in operands]
            if any(type(t) is VectorType for t in operand_types):
                return expr.func(*cast_to(operands, operand_types, collate_types(operand_types)))
            # purely scalar expression: leave untouched
            return expr

        if expr.func is sp.Pow:
            # only the base is processed, the exponent is kept as it is
            return expr.func(convert(expr.args[0]), expr.args[1])

        if expr.func == sp.Piecewise:
            results = [convert(branch[0]) for branch in expr.args]
            conditions = [convert(branch[1]) for branch in expr.args]
            result_types = [get_type_of_expression(r) for r in results]
            condition_types = [get_type_of_expression(c) for c in conditions]

            result_target = get_type_of_expression(expr)
            condition_target = collate_types(condition_types)
            # results and conditions have to be either both vector typed or both scalar
            if type(condition_target) is VectorType and type(result_target) is not VectorType:
                result_target = VectorType(result_target, width=condition_target.width)
            if type(condition_target) is not VectorType and type(result_target) is VectorType:
                condition_target = VectorType(condition_target, width=result_target.width)

            casted_results = cast_to(results, result_types, result_target)
            casted_conditions = [cast_func(c, condition_target)
                                 if t != condition_target and c is not True else c
                                 for c, t in zip(conditions, condition_types)]
            return sp.Piecewise(*zip(casted_results, casted_conditions))

        return expr

    def process(node, substitution_dict):
        # work on a copy so that re-typed symbols stay local to this subtree
        substitution_dict = substitution_dict.copy()
        for child in node.args:
            if isinstance(child, ast.SympyAssignment):
                substituted_rhs = fast_subs(child.rhs, substitution_dict, skip=keep_resolved_accesses)
                child.rhs = convert(substituted_rhs)
                rhs_type = get_type_of_expression(child.rhs)
                lhs = child.lhs
                if isinstance(lhs, TypedSymbol):
                    if type(rhs_type) is VectorType and type(lhs.dtype) is not VectorType:
                        # scalar symbol receives a vector value: replace it by a vector typed symbol
                        vector_lhs = TypedSymbol(lhs.name, VectorType(lhs.dtype, rhs_type.width))
                        substitution_dict[lhs] = vector_lhs
                        child.lhs = vector_lhs
                elif isinstance(lhs, vector_memory_access):
                    child.lhs = convert(lhs)
            elif isinstance(child, ast.Conditional):
                child.condition_expr = fast_subs(child.condition_expr, substitution_dict,
                                                 skip=keep_resolved_accesses)
                child.condition_expr = convert(child.condition_expr)
                process(child, substitution_dict)
            else:
                process(child, substitution_dict)

    process(ast_node, {})