vectorization.py 16.7 KB
Newer Older
1
import warnings
Martin Bauer's avatar
Martin Bauer committed
2
3
from typing import Container, Union

4
import numpy as np
Martin Bauer's avatar
Martin Bauer committed
5
import sympy as sp
6
from sympy.logic.boolalg import BooleanFunction
Martin Bauer's avatar
Martin Bauer committed
7

8
import pystencils.astnodes as ast
Michael Kuron's avatar
Michael Kuron committed
9
from pystencils.backends.simd_instruction_sets import get_supported_instruction_sets, get_vector_instruction_set
Martin Bauer's avatar
Martin Bauer committed
10
from pystencils.data_types import (
11
    PointerType, TypedSymbol, VectorType, cast_func, collate_types, get_type_of_expression, vector_memory_access)
Martin Bauer's avatar
Martin Bauer committed
12
from pystencils.fast_approximation import fast_division, fast_inv_sqrt, fast_sqrt
Martin Bauer's avatar
Martin Bauer committed
13
from pystencils.field import Field
Martin Bauer's avatar
Martin Bauer committed
14
15
from pystencils.integer_functions import modulo_ceil, modulo_floor
from pystencils.sympyextensions import fast_subs
16
from pystencils.transformations import cut_loop, filtered_tree_iteration, replace_inner_stride_with_one
Martin Bauer's avatar
Martin Bauer committed
17
18


19
20
# noinspection PyPep8Naming
class vec_any(sp.Function):
    """Symbolic reduction marker applied to a vectorized boolean condition.

    ``mask_conditionals`` wraps loop-counter dependent ``Conditional`` conditions in it.
    Presumably the backend prints it as an "any lane is true" test — TODO confirm in the printer.
    """

    nargs = (1,)  # exactly one argument: the (vector) condition
22
23
24
25


# noinspection PyPep8Naming
class vec_all(sp.Function):
    """Symbolic reduction marker applied to a vectorized boolean condition.

    Conditions already wrapped in it are left alone by ``mask_conditionals``.
    Presumably the backend prints it as an "all lanes are true" test — TODO confirm in the printer.
    """

    nargs = (1,)  # exactly one argument: the (vector) condition
27
28


29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
class NontemporalFence(ast.Node):
    """AST node marking the point after which non-temporal (streaming) stores must be fenced.

    ``vectorize_inner_loops_and_adapt_load_stores`` inserts one after the outermost loop whenever it
    generates a non-temporal access. The node has no children and defines or reads no symbols.
    All instances compare equal, so inserting it with ``if_not_exists=True`` does not duplicate it.
    """

    def __init__(self):
        super(NontemporalFence, self).__init__(parent=None)

    @property
    def symbols_defined(self):
        return set()

    @property
    def undefined_symbols(self):
        return set()

    @property
    def args(self):
        return []

    def __eq__(self, other):
        return isinstance(other, NontemporalFence)

    def __hash__(self):
        # Overriding __eq__ alone sets __hash__ to None and makes instances unhashable
        # (CachelineSize likewise defines __hash__). All fences are equal, so they share one hash.
        return hash(NontemporalFence)


class CachelineSize(ast.Node):
    """AST node that defines the cache-line helper symbols ``_clsize``, ``_clsize_mask`` and ``_cl_lastvec``.

    It is inserted at the front of the kernel when non-temporal stores are generated.
    The node has no children and reads no symbols. Every instance compares equal to every other one.
    """

    symbol = sp.Symbol("_clsize")
    mask_symbol = sp.Symbol("_clsize_mask")
    last_symbol = sp.Symbol("_cl_lastvec")

    def __init__(self):
        super().__init__(parent=None)

    def __eq__(self, other):
        return isinstance(other, CachelineSize)

    def __hash__(self):
        # consistent with __eq__: all instances share the class-level symbol, hence the same hash
        return hash(self.symbol)

    @property
    def symbols_defined(self):
        return {self.symbol, self.mask_symbol, self.last_symbol}

    @property
    def undefined_symbols(self):
        return set()

    @property
    def args(self):
        return []

75

Michael Kuron's avatar
Michael Kuron committed
76
def vectorize(kernel_ast: ast.KernelFunction, instruction_set: str = 'best',
              assume_aligned: bool = False, nontemporal: Union[bool, Container[Union[str, Field]]] = False,
              assume_inner_stride_one: bool = False, assume_sufficient_line_padding: bool = True):
    """Explicit vectorization using SIMD vectorization via intrinsics.

    The kernel AST is modified in place; nothing is returned.

    Args:
        kernel_ast: abstract syntax tree (KernelFunction node)
        instruction_set: name of a vector instruction set understood by ``get_vector_instruction_set``
                         (e.g. 'sse', 'avx', 'avx512'). 'best' selects the last entry of
                         ``get_supported_instruction_sets()`` and falls back to 'avx' if that list is empty.
                         ``None`` leaves the kernel untouched.
        assume_aligned: assume that the first inner cell of each line is aligned. If false, only unaligned-loads are
                        used. If true, some of the loads are assumed to be from aligned memory addresses.
                        For example if x is the fastest coordinate, the access to center can be fetched via an
                        aligned-load instruction, for the west or east accesses potentially slower unaligned-load
                        instructions have to be used.
        nontemporal: a container of fields or field names for which nontemporal (streaming) stores are used.
                     If true, nontemporal access instructions are used for all fields; False or None disables them.
        assume_inner_stride_one: kernels with non-constant inner loop bound and strides can not be vectorized since
                                 the inner loop stride is a runtime variable and thus might not be always 1.
                                 If this parameter is set to true, the inner stride is assumed to be always one.
                                 This has to be ensured at runtime!
        assume_sufficient_line_padding: if True and assume_aligned (and the store instruction does not take the
                                        loop stop itself), no tail loop is created but the loop is extended by
                                        at most (vector_width-1) elements.
                                        Assumes that at the end of each line there is enough padding with dummy
                                        data; depending on the access pattern there might be additional padding
                                        required at the end of the array.

    Raises:
        NotImplementedError: if the kernel accesses no floating point field, floating point fields of more
                             than one type, or a floating point type that is neither 4 nor 8 bytes wide.
    """
    if instruction_set == 'best':
        supported = get_supported_instruction_sets()
        instruction_set = supported[-1] if supported else 'avx'

    if instruction_set is None:
        return

    all_fields = kernel_ast.fields_accessed
    if nontemporal is None or nontemporal is False:
        nontemporal = {}
    elif nontemporal is True:
        nontemporal = all_fields

    if assume_inner_stride_one:
        replace_inner_stride_with_one(kernel_ast)

    # The vector width is derived from the single floating point type used by all fields.
    field_float_dtypes = {f.dtype for f in all_fields if f.dtype.is_float()}
    if not field_float_dtypes:
        raise NotImplementedError("Cannot vectorize kernels that do not access any floating point field")
    if len(field_float_dtypes) != 1:
        raise NotImplementedError("Cannot vectorize kernels that contain accesses "
                                  "to differently typed floating point fields")
    float_size = field_float_dtypes.pop().numpy_dtype.itemsize
    if float_size not in (8, 4):
        # explicit check instead of `assert`, which is stripped under `python -O`
        raise NotImplementedError(f"Cannot vectorize kernels with {float_size}-byte floating point fields; "
                                  "only 4-byte (float) and 8-byte (double) types are supported")
    vector_is = get_vector_instruction_set('double' if float_size == 8 else 'float',
                                           instruction_set=instruction_set)
    vector_width = vector_is['width']
    kernel_ast.instruction_set = vector_is

    # strided loads/stores are only usable if the instruction set provides both
    strided = 'storeS' in vector_is and 'loadS' in vector_is
    # if the store template consumes the loop stop itself, the loop bounds need no tail handling
    keep_loop_stop = '{loop_stop}' in vector_is['storeA' if assume_aligned else 'storeU']
    vectorize_inner_loops_and_adapt_load_stores(kernel_ast, vector_width, assume_aligned, nontemporal,
                                                strided, keep_loop_stop, assume_sufficient_line_padding)
    insert_vector_casts(kernel_ast)


136
def vectorize_inner_loops_and_adapt_load_stores(ast_node, vector_width, assume_aligned, nontemporal_fields,
                                                strided, keep_loop_stop, assume_sufficient_line_padding):
    """Goes over all innermost loops, changes increment to vector width and replaces field accesses by vector type.

    Args:
        ast_node: kernel AST. Its ``instruction_set`` dict (set by ``vectorize``) is read for ``'intwidth'``.
        vector_width: number of SIMD lanes for the floating point type
        assume_aligned: allow aligned accesses for indices whose offset from the loop counter vanishes when
                        all loop counters are set to zero
        nontemporal_fields: container of fields or field names whose accesses become non-temporal
        strided: whether strided vector loads/stores are available; if so, non-unit but loop-invariant
                 strides are vectorized instead of rejected
        keep_loop_stop: loop bounds are left unchanged (the store instruction handles the loop stop)
        assume_sufficient_line_padding: together with assume_aligned, round the loop stop up to a multiple of
                                        vector_width instead of cutting off a scalar tail loop

    Loops with a non-consecutive access pattern are skipped with a warning.
    """
    from pystencils.rng import RNGBase  # local import kept from the original, hoisted out of the loop

    # filtered_tree_iteration may yield lazily; materialize it because it is traversed twice below
    # (otherwise zero_loop_counters would be built from an exhausted iterator and stay empty).
    all_loops = list(filtered_tree_iteration(ast_node, ast.LoopOverCoordinate, stop_type=ast.SympyAssignment))
    inner_loops = [loop for loop in all_loops if loop.is_innermost_loop]
    zero_loop_counters = {loop.loop_counter_symbol: 0 for loop in all_loops}

    for loop_node in inner_loops:
        loop_range = loop_node.stop - loop_node.start

        # cut off the loop tail that is not a multiple of the vector width
        if keep_loop_stop:
            pass
        elif assume_aligned and assume_sufficient_line_padding:
            loop_node.stop = loop_node.start + modulo_ceil(loop_range, vector_width)
        else:
            cutting_point = modulo_floor(loop_range, vector_width) + loop_node.start
            loop_nodes = [node for node in cut_loop(loop_node, [cutting_point]).args
                          if isinstance(node, ast.LoopOverCoordinate)]
            assert len(loop_nodes) in (0, 1, 2)  # 2 for main and tail loop, 1 if loop range divisible by vector width
            if len(loop_nodes) == 0:
                continue
            loop_node = loop_nodes[0]

        # Find all array accesses (indexed) that depend on the loop counter as offset
        loop_counter_symbol = ast.LoopOverCoordinate.get_loop_counter_symbol(loop_node.coordinate_to_loop_over)
        substitutions = {}
        successful = True
        for indexed in loop_node.atoms(sp.Indexed):
            base, index = indexed.args
            if loop_counter_symbol not in index.atoms(sp.Symbol):
                continue
            loop_counter_is_offset = loop_counter_symbol not in (index - loop_counter_symbol).atoms()
            aligned_access = (index - loop_counter_symbol).subs(zero_loop_counters) == 0
            stride = sp.simplify(index.subs({loop_counter_symbol: loop_counter_symbol + 1}) - index)
            if not loop_counter_is_offset and (not strided or loop_counter_symbol in stride.atoms()):
                successful = False
                break
            typed_symbol = base.label
            assert type(typed_symbol.dtype) is PointerType, \
                f"Type of access is {typed_symbol.dtype}, {indexed}"

            vec_type = VectorType(typed_symbol.dtype.base_type, vector_width)
            use_aligned_access = aligned_access and assume_aligned
            nontemporal = False
            if hasattr(indexed, 'field'):
                nontemporal = (indexed.field in nontemporal_fields) or (indexed.field.name in nontemporal_fields)
            # argument 4 (True) is the initial mask, refined later by mask_conditionals
            substitutions[indexed] = vector_memory_access(indexed, vec_type, use_aligned_access, nontemporal, True,
                                                          stride if strided else 1)
            if nontemporal:
                # insert NontemporalFence after the outermost loop
                parent = loop_node.parent
                while type(parent.parent.parent) is not ast.KernelFunction:
                    parent = parent.parent
                parent.parent.insert_after(NontemporalFence(), parent, if_not_exists=True)
                # insert CachelineSize at the beginning of the kernel
                parent.parent.insert_front(CachelineSize(), if_not_exists=True)

        if not successful:
            warnings.warn("Could not vectorize loop because of non-consecutive memory access")
            continue

        loop_node.step = vector_width
        loop_node.subs(substitutions)

        # Replace the scalar loop counter by a vector (counter + [0, 1, ...]) outside of memory accesses.
        vector_int_width = ast_node.instruction_set['intwidth']
        vector_loop_counter = cast_func(loop_counter_symbol, VectorType(loop_counter_symbol.dtype, vector_int_width)) \
            + cast_func(tuple(range(vector_int_width if type(vector_int_width) is int else 2)),
                        VectorType(loop_counter_symbol.dtype, vector_int_width))

        fast_subs(loop_node, {loop_counter_symbol: vector_loop_counter},
                  skip=lambda e: isinstance(e, ast.ResolvedFieldAccess) or isinstance(e, vector_memory_access))

        mask_conditionals(loop_node)

        # Random number generators now produce vector-typed results.
        rng_substitutions = {}
        for rng in loop_node.atoms(RNGBase):
            new_result_symbols = [TypedSymbol(s.name, VectorType(s.dtype, width=vector_width))
                                  for s in rng.result_symbols]
            rng_substitutions.update(zip(rng.result_symbols, new_result_symbols))
            rng._symbols_defined = set(new_result_symbols)
        fast_subs(loop_node, rng_substitutions, skip=lambda e: isinstance(e, RNGBase))

Martin Bauer's avatar
Martin Bauer committed
218
219
220
221

def mask_conditionals(loop_body):
    """Fold loop-counter dependent ``Conditional`` conditions into vector memory access masks.

    ``loop_body`` is a vectorized loop node; its ``loop_counter_symbol`` decides which conditions vary
    across lanes. A condition is handled if it references that symbol and is not already wrapped in
    ``vec_all``/``vec_any``. For each handled condition:

    * it is AND-ed into the mask (argument 4) of every ``vector_memory_access`` in the assignments of the
      true block, and its negation into those of the false block;
    * the condition itself is replaced by ``vec_any(condition)``.

    Other conditionals contribute nothing to the mask. Nested conditionals accumulate their masks.
    """
    def visit_node(node, mask):
        if isinstance(node, ast.Conditional):
            cond = node.condition_expr
            skip = (loop_body.loop_counter_symbol not in cond.atoms(sp.Symbol)) or cond.func in (vec_all, vec_any)

            # A skipped (lane-uniform or already reduced) condition adds no term to either branch's mask.
            true_mask = sp.And(True if skip else cond, mask)
            visit_node(node.true_block, true_mask)
            if node.false_block:
                # Recurse into the false block (recursing into `node` itself would never terminate).
                false_mask = sp.And(True if skip else sp.Not(cond), mask)
                visit_node(node.false_block, false_mask)
            # NOTE(review): once the condition becomes vec_any(cond), lanes where cond is false in a mixed
            # vector presumably never execute the false block — confirm this is acceptable for else-branches.
            if not skip:
                node.condition_expr = vec_any(cond)
        elif isinstance(node, ast.SympyAssignment):
            if mask is not True:
                s = {ma: vector_memory_access(*ma.args[0:4], sp.And(mask, ma.args[4]), *ma.args[5:])
                     for ma in node.atoms(vector_memory_access)}
                node.subs(s)
        else:
            for arg in node.args:
                visit_node(arg, mask)

    visit_node(loop_body, mask=True)

244

Martin Bauer's avatar
Martin Bauer committed
245
246
247
def insert_vector_casts(ast_node):
    """Inserts necessary casts from scalar values to vector values.

    All ``SympyAssignment`` right-hand sides and ``Conditional`` conditions below ``ast_node`` are rewritten.

    * In Add/Mul/relational/boolean/fast-math/vec_any/vec_all expressions that contain a vector-typed
      operand, the other operands are wrapped in ``cast_func`` to the collated type.
    * ``Piecewise`` results and conditions are cast to a common (vector) type.
    * ``Abs`` is lowered to a ``Piecewise`` if the instruction set has no ``'abs'`` entry.
    * A scalar ``TypedSymbol`` on the left-hand side that receives a vector value is replaced by a
      vector-typed symbol. That replacement is applied to later statements of the same and nested scopes.
    """
    handled_functions = (sp.Add, sp.Mul, fast_division, fast_sqrt, fast_inv_sqrt, vec_any, vec_all)

    def is_field_access(e):
        return isinstance(e, ast.ResolvedFieldAccess)

    def lower_abs(expr):
        """Rewrite ``Abs(x)`` as ``Piecewise((-x, x < 0), (x, True))`` using the visited argument."""
        operand = expr.args[0]
        visited = visit_expr(operand)
        operand_type = get_type_of_expression(operand)
        if type(operand) is vector_memory_access:
            operand_type = operand_type.base_type
        return sp.Piecewise((-visited, visited < operand_type.numpy_dtype.type(0)),
                            (visited, True))

    def unary_minus_float_type(expr):
        """Numpy float type of the float vector operands in ``expr``, or ``int`` if none is found."""
        found = int
        for access in expr.atoms(vector_memory_access):
            if access.dtype.base_type.is_float():
                found = access.dtype.base_type.numpy_dtype.type
        for symbol in expr.atoms(TypedSymbol):
            if type(symbol.dtype) is VectorType and symbol.dtype.base_type.is_float():
                found = symbol.dtype.base_type.numpy_dtype.type
        return found

    def cast_operands(expr):
        """Cast the operands of an arithmetic/boolean expression to a common vector type if needed."""
        float_type_name = 'double'
        if expr.func is sp.Mul and expr.args[0] == -1:
            # unary minus: make sure that the -1 has the same type as the argument
            minus_type = unary_minus_float_type(expr)
            if minus_type is not int:
                if minus_type is np.float32:
                    float_type_name = 'float'
                expr = sp.Mul(minus_type(expr.args[0]), *expr.args[1:])

        visited_args = [visit_expr(a) for a in expr.args]
        visited_types = [get_type_of_expression(a, default_float_type=float_type_name) for a in visited_args]
        if all(type(t) is not VectorType for t in visited_types):
            return expr

        target = collate_types(visited_types)
        return expr.func(*[cast_func(a, target) if t != target and not isinstance(a, vector_memory_access) else a
                           for a, t in zip(visited_args, visited_types)])

    def cast_piecewise(expr):
        """Cast Piecewise results and conditions so both sides agree on vectorization."""
        results = [visit_expr(a[0]) for a in expr.args]
        conditions = [visit_expr(a[1]) for a in expr.args]
        result_types = [get_type_of_expression(a) for a in results]
        condition_types = [get_type_of_expression(a) for a in conditions]

        result_target = get_type_of_expression(expr)
        condition_target = collate_types(condition_types)
        # if one side is vectorized, widen the other one to the same number of lanes
        if type(condition_target) is VectorType and type(result_target) is not VectorType:
            result_target = VectorType(result_target, width=condition_target.width)
        if type(condition_target) is not VectorType and type(result_target) is VectorType:
            condition_target = VectorType(condition_target, width=result_target.width)

        results = [cast_func(a, result_target) if t != result_target else a
                   for a, t in zip(results, result_types)]
        # NOTE(review): `a is not True` only exempts the Python literal True, not sympy's `true`.
        # The backend may rely on the catch-all condition being cast — confirm before changing.
        conditions = [cast_func(a, condition_target) if t != condition_target and a is not True else a
                      for a, t in zip(conditions, condition_types)]
        return sp.Piecewise(*zip(results, conditions))

    def visit_expr(expr):
        if isinstance(expr, vector_memory_access):
            # only the mask argument (index 4) can hold expressions that need casting
            return vector_memory_access(*expr.args[0:4], visit_expr(expr.args[4]), *expr.args[5:])
        if isinstance(expr, cast_func):
            return expr
        if expr.func is sp.Abs and 'abs' not in ast_node.instruction_set:
            return visit_expr(lower_abs(expr))
        if expr.func in handled_functions or isinstance(expr, (sp.Rel, BooleanFunction)):
            return cast_operands(expr)
        if expr.func is sp.Pow:
            # only the base is visited; the exponent is kept as is
            return expr.func(visit_expr(expr.args[0]), expr.args[1])
        if expr.func == sp.Piecewise:
            return cast_piecewise(expr)
        return expr

    def visit_node(node, inherited_subs):
        # copy: symbols re-typed inside this subtree are not propagated back to the caller's scope
        subs = inherited_subs.copy()
        for child in node.args:
            if isinstance(child, ast.SympyAssignment):
                child.rhs = visit_expr(fast_subs(child.rhs, subs, skip=is_field_access))
                rhs_type = get_type_of_expression(child.rhs)
                lhs = child.lhs
                if isinstance(lhs, TypedSymbol):
                    lhs_type = lhs.dtype
                    if type(rhs_type) is VectorType and type(lhs_type) is not VectorType:
                        vector_lhs = TypedSymbol(lhs.name, VectorType(lhs_type, rhs_type.width))
                        subs[lhs] = vector_lhs
                        child.lhs = vector_lhs
                elif isinstance(lhs, vector_memory_access):
                    child.lhs = visit_expr(lhs)
            elif isinstance(child, ast.Conditional):
                child.condition_expr = fast_subs(child.condition_expr, subs, skip=is_field_access)
                child.condition_expr = visit_expr(child.condition_expr)
                visit_node(child, subs)
            else:
                visit_node(child, subs)

    visit_node(ast_node, {})