vectorization.py 16.8 KB
Newer Older
1
import warnings
Martin Bauer's avatar
Martin Bauer committed
2
3
from typing import Container, Union

4
import numpy as np
Martin Bauer's avatar
Martin Bauer committed
5
import sympy as sp
6
from sympy.logic.boolalg import BooleanFunction
Martin Bauer's avatar
Martin Bauer committed
7

8
import pystencils.astnodes as ast
Michael Kuron's avatar
Michael Kuron committed
9
from pystencils.backends.simd_instruction_sets import get_supported_instruction_sets, get_vector_instruction_set
Martin Bauer's avatar
Martin Bauer committed
10
from pystencils.data_types import (
11
    PointerType, TypedSymbol, VectorType, cast_func, collate_types, get_type_of_expression, vector_memory_access)
Martin Bauer's avatar
Martin Bauer committed
12
from pystencils.fast_approximation import fast_division, fast_inv_sqrt, fast_sqrt
Martin Bauer's avatar
Martin Bauer committed
13
from pystencils.field import Field
Martin Bauer's avatar
Martin Bauer committed
14
15
from pystencils.integer_functions import modulo_ceil, modulo_floor
from pystencils.sympyextensions import fast_subs
16
from pystencils.transformations import cut_loop, filtered_tree_iteration, replace_inner_stride_with_one
Martin Bauer's avatar
Martin Bauer committed
17
18


19
20
# noinspection PyPep8Naming
class vec_any(sp.Function):
    """Symbolic function taking exactly one argument.

    `mask_conditionals` wraps the condition of a vectorized `ast.Conditional` in this function.
    """
    nargs = (1,)
22
23
24
25


# noinspection PyPep8Naming
class vec_all(sp.Function):
    """Symbolic function taking exactly one argument.

    Conditions whose top-level function is `vec_all` (or `vec_any`) are not turned into masks by
    `mask_conditionals`.
    """
    nargs = (1,)
27
28


29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
class NontemporalFence(ast.Node):
    """AST node without children that neither defines nor uses any symbol.

    `vectorize_inner_loops_and_adapt_load_stores` inserts one instance after the outermost loop
    whenever a nontemporal store was generated. All instances compare equal, which lets
    `insert_after(..., if_not_exists=True)` detect an already present fence.
    """

    def __init__(self):
        super(NontemporalFence, self).__init__(parent=None)

    @property
    def symbols_defined(self):
        """Empty set: the node introduces no symbols."""
        return set()

    @property
    def undefined_symbols(self):
        """Empty set: the node reads no symbols."""
        return set()

    @property
    def args(self):
        """Empty list: the node has no child nodes."""
        return []

    def __eq__(self, other):
        return isinstance(other, NontemporalFence)

    def __hash__(self):
        # Overriding __eq__ alone sets __hash__ to None in Python 3 and makes instances unhashable.
        # Since every instance is equal to every other one, they all share a single hash value.
        return hash(NontemporalFence)


class CachelineSize(ast.Node):
    """AST node without children that defines the three cache-line related symbols below.

    `vectorize_inner_loops_and_adapt_load_stores` inserts one instance at the front of the kernel
    whenever a nontemporal store was generated. All instances compare equal and hash alike.
    """
    # Symbols reported by `symbols_defined`; shared by all instances.
    symbol = sp.Symbol("_clsize")
    mask_symbol = sp.Symbol("_clsize_mask")
    last_symbol = sp.Symbol("_cl_lastvec")

    def __init__(self):
        super().__init__(parent=None)

    @property
    def symbols_defined(self):
        """The set containing `symbol`, `mask_symbol` and `last_symbol`."""
        return {self.symbol, self.mask_symbol, self.last_symbol}

    @property
    def undefined_symbols(self):
        """Empty set: the node reads no symbols."""
        return set()

    @property
    def args(self):
        """Empty list: the node has no child nodes."""
        return []

    def __eq__(self, other):
        return isinstance(other, CachelineSize)

    def __hash__(self):
        # Consistent with __eq__: `symbol` is a class attribute, so all instances hash alike.
        return hash(self.symbol)

75

Michael Kuron's avatar
Michael Kuron committed
76
def vectorize(kernel_ast: ast.KernelFunction, instruction_set: str = 'best',
              assume_aligned: bool = False, nontemporal: Union[bool, Container[Union[str, Field]]] = False,
              assume_inner_stride_one: bool = False, assume_sufficient_line_padding: bool = True):
    """Explicit vectorization using SIMD vectorization via intrinsics.

    The kernel AST is modified in place; nothing is returned.

    Args:
        kernel_ast: abstract syntax tree (KernelFunction node)
        instruction_set: one of the supported vector instruction sets, currently ('sse', 'avx' and 'avx512').
                         The default 'best' picks the last entry reported by `get_supported_instruction_sets()`
                         and falls back to 'avx' if nothing is reported. If None, the kernel is left untouched.
        assume_aligned: assume that the first inner cell of each line is aligned. If false, only unaligned-loads are
                        used. If true, some of the loads are assumed to be from aligned memory addresses.
                        For example if x is the fastest coordinate, the access to center can be fetched via an
                        aligned-load instruction, for the west or east accesses potentially slower unaligned-load
                        instructions have to be used.
        nontemporal: a container of fields or field names for which nontemporal (streaming) stores are used.
                     If true, nontemporal access instructions are used for all fields.
        assume_inner_stride_one: kernels with non-constant inner loop bound and strides can not be vectorized since
                                 the inner loop stride is a runtime variable and thus might not be always 1.
                                 If this parameter is set to true, the inner stride is assumed to be always one.
                                 This has to be ensured at runtime!
        assume_sufficient_line_padding: if True and assume_inner_stride_one, no tail loop is created but loop is
                                        extended by at most (vector_width-1) elements
                                        assumes that at the end of each line there is enough padding with dummy data
                                        depending on the access pattern there might be additional padding
                                        required at the end of the array

    Raises:
        NotImplementedError: if the floating point fields accessed by the kernel do not share exactly one data type.
    """
    if instruction_set == 'best':
        supported = get_supported_instruction_sets()
        instruction_set = supported[-1] if supported else 'avx'
    if instruction_set is None:
        return

    accessed_fields = kernel_ast.fields_accessed
    if nontemporal is True:
        nontemporal = accessed_fields
    elif nontemporal is None or nontemporal is False:
        nontemporal = {}

    if assume_inner_stride_one:
        replace_inner_stride_with_one(kernel_ast)

    float_dtypes = {field.dtype for field in accessed_fields if field.dtype.is_float()}
    if len(float_dtypes) != 1:
        raise NotImplementedError("Cannot vectorize kernels that contain accesses "
                                  "to differently typed floating point fields")
    (float_dtype,) = float_dtypes
    item_size = float_dtype.numpy_dtype.itemsize
    assert item_size in (8, 4)
    scalar_type_name = 'double' if item_size == 8 else 'float'
    vector_is = get_vector_instruction_set(scalar_type_name, instruction_set=instruction_set)
    vector_width = vector_is['width']
    # The later passes read the chosen instruction set from the AST root.
    kernel_ast.instruction_set = vector_is

    vectorize_rng(kernel_ast, vector_width)
    # Strided accesses are only permitted if the instruction set offers both operations.
    has_scatter_gather = 'scatter' in vector_is and 'gather' in vector_is
    vectorize_inner_loops_and_adapt_load_stores(kernel_ast, vector_width, assume_aligned, nontemporal,
                                                has_scatter_gather, assume_sufficient_line_padding)
    insert_vector_casts(kernel_ast)


136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
def vectorize_rng(kernel_ast, vector_width):
    """Replace scalar result symbols on RNG nodes with vectorial ones.

    Args:
        kernel_ast: root node; its tree is searched for `RNGBase` nodes and the replacement is
                    applied to `kernel_ast.body`
        vector_width: width of the `VectorType` given to the new result symbols
    """
    from pystencils.rng import RNGBase
    scalar_to_vector = {}

    def collect(node):
        for child in node.args:
            if not isinstance(child, RNGBase):
                collect(child)
                continue
            vector_symbols = [TypedSymbol(scalar.name, VectorType(scalar.dtype, width=vector_width))
                              for scalar in child.result_symbols]
            scalar_to_vector.update(dict(zip(child.result_symbols, vector_symbols)))
            child._symbols_defined = set(vector_symbols)

    collect(kernel_ast)
    # RNG nodes themselves are skipped by the substitution; only their users are rewritten here.
    fast_subs(kernel_ast.body, scalar_to_vector, skip=lambda e: isinstance(e, RNGBase))


154
def vectorize_inner_loops_and_adapt_load_stores(ast_node, vector_width, assume_aligned, nontemporal_fields,
                                                scattergather, assume_sufficient_line_padding):
    """Goes over all innermost loops, changes increment to vector width and replaces field accesses by vector type.

    The AST is modified in place. A loop containing an access that cannot be expressed as a vector
    access is left with its scalar accesses and a warning is emitted.

    Args:
        ast_node: root node whose tree is searched for loops; its `instruction_set['intwidth']` is read
        vector_width: new step of the vectorized loops and width of the vector types of the accesses
        assume_aligned: if True, accesses whose index equals the loop counter (with all loop counters
                        of the tree set to zero in the remainder) are marked as aligned
        nontemporal_fields: container of fields or field names whose accesses are marked nontemporal
        scattergather: if True, accesses where the loop counter is not a plain offset are accepted as
                       long as their stride does not depend on the loop counter
        assume_sufficient_line_padding: if True together with `assume_aligned`, the loop end is rounded
                                        up to a multiple of `vector_width` instead of cutting off a tail loop
    """
    # Materialize the result: it is iterated twice below, which would leave `zero_loop_counters`
    # empty if `filtered_tree_iteration` hands back a one-shot iterator.
    all_loops = list(filtered_tree_iteration(ast_node, ast.LoopOverCoordinate, stop_type=ast.SympyAssignment))
    inner_loops = [loop for loop in all_loops if loop.is_innermost_loop]
    zero_loop_counters = {loop.loop_counter_symbol: 0 for loop in all_loops}

    for loop_node in inner_loops:
        loop_range = loop_node.stop - loop_node.start

        # make the iteration count of the vectorized loop a multiple of the vector width
        if assume_aligned and assume_sufficient_line_padding:
            # extend the loop into the padding instead of creating a tail loop
            loop_node.stop = loop_node.start + modulo_ceil(loop_range, vector_width)
        else:
            # cut off the loop tail that is not a multiple of the vector width
            cutting_point = modulo_floor(loop_range, vector_width) + loop_node.start
            loop_nodes = [node for node in cut_loop(loop_node, [cutting_point]).args
                          if isinstance(node, ast.LoopOverCoordinate)]
            assert len(loop_nodes) in (0, 1, 2)  # 2 for main and tail loop, 1 if loop range divisible by vector width
            if len(loop_nodes) == 0:
                continue
            loop_node = loop_nodes[0]

        # Find all array accesses (indexed) that depend on the loop counter as offset
        loop_counter_symbol = ast.LoopOverCoordinate.get_loop_counter_symbol(loop_node.coordinate_to_loop_over)
        substitutions = {}
        successful = True
        for indexed in loop_node.atoms(sp.Indexed):
            base, index = indexed.args
            if loop_counter_symbol in index.atoms(sp.Symbol):
                loop_counter_is_offset = loop_counter_symbol not in (index - loop_counter_symbol).atoms()
                aligned_access = (index - loop_counter_symbol).subs(zero_loop_counters) == 0
                # index change caused by advancing the loop counter by one
                stride = sp.simplify(index.subs({loop_counter_symbol: loop_counter_symbol + 1}) - index)
                if not loop_counter_is_offset and (not scattergather or loop_counter_symbol in stride.atoms()):
                    successful = False
                    break
                typed_symbol = base.label
                assert type(typed_symbol.dtype) is PointerType, \
                    f"Type of access is {typed_symbol.dtype}, {indexed}"

                vec_type = VectorType(typed_symbol.dtype.base_type, vector_width)
                use_aligned_access = aligned_access and assume_aligned
                nontemporal = False
                if hasattr(indexed, 'field'):
                    nontemporal = (indexed.field in nontemporal_fields) or (indexed.field.name in nontemporal_fields)
                substitutions[indexed] = vector_memory_access(indexed, vec_type, use_aligned_access, nontemporal, True,
                                                              stride if scattergather else 1)
                if nontemporal:
                    # insert NontemporalFence after the outermost loop
                    parent = loop_node.parent
                    while type(parent.parent.parent) is not ast.KernelFunction:
                        parent = parent.parent
                    parent.parent.insert_after(NontemporalFence(), parent, if_not_exists=True)
                    # insert CachelineSize at the beginning of the kernel
                    parent.parent.insert_front(CachelineSize(), if_not_exists=True)
        if not successful:
            warnings.warn("Could not vectorize loop because of non-consecutive memory access")
            continue

        loop_node.step = vector_width
        loop_node.subs(substitutions)
        vector_int_width = ast_node.instruction_set['intwidth']
        # remaining uses of the loop counter become "counter + (0, 1, ..., intwidth - 1)"
        vector_loop_counter = cast_func(loop_counter_symbol, VectorType(loop_counter_symbol.dtype, vector_int_width)) \
            + cast_func(tuple(range(vector_int_width)), VectorType(loop_counter_symbol.dtype, vector_int_width))

        # memory accesses keep the scalar loop counter in their index
        fast_subs(loop_node, {loop_counter_symbol: vector_loop_counter},
                  skip=lambda e: isinstance(e, ast.ResolvedFieldAccess) or isinstance(e, vector_memory_access))

        mask_conditionals(loop_node)


def mask_conditionals(loop_body):
    """Turn conditionals inside a vectorized loop into masks on the contained vector memory accesses.

    The tree below `loop_body` is modified in place:

    - For every `ast.Conditional` whose condition contains the loop counter and is not already
      wrapped in `vec_any` / `vec_all`, the condition is combined (`And`) into the mask for the
      true block and its negation into the mask for the false block. Afterwards the condition is
      replaced by `vec_any(condition)`.
    - Other conditionals contribute `True` to the true-block mask.
    - In every `ast.SympyAssignment` reached with a mask other than the Python constant `True`,
      each `vector_memory_access` gets the mask and-ed onto its fifth argument.

    Args:
        loop_body: loop node to process; its `loop_counter_symbol` decides which conditions are masked
    """
    def visit_node(node, mask):
        if isinstance(node, ast.Conditional):
            cond = node.condition_expr
            skip = (loop_body.loop_counter_symbol not in cond.atoms(sp.Symbol)) or cond.func in (vec_all, vec_any)
            cond = True if skip else cond

            true_mask = sp.And(cond, mask)
            visit_node(node.true_block, true_mask)
            if node.false_block:
                false_mask = sp.And(sp.Not(node.condition_expr), mask)
                # Descend into the else-branch; visiting `node` itself again would never terminate.
                visit_node(node.false_block, false_mask)
            if not skip:
                node.condition_expr = vec_any(node.condition_expr)
        elif isinstance(node, ast.SympyAssignment):
            if mask is not True:
                s = {ma: vector_memory_access(*ma.args[0:4], sp.And(mask, ma.args[4]), *ma.args[5:])
                     for ma in node.atoms(vector_memory_access)}
                node.subs(s)
        else:
            for arg in node.args:
                visit_node(arg, mask)

    visit_node(loop_body, mask=True)

250

Martin Bauer's avatar
Martin Bauer committed
251
252
253
def insert_vector_casts(ast_node):
    """Inserts necessary casts from scalar values to vector values.

    Walks the tree below `ast_node` and rewrites, in place, the right-hand sides of assignments and
    the conditions of conditionals. A `TypedSymbol` on a left-hand side that receives a vector-typed
    value is replaced by a vector-typed symbol of the same name, and that replacement is applied to
    the expressions visited later in the same scope and in nested scopes.

    Args:
        ast_node: root node; `ast_node.instruction_set` is consulted to decide whether `Abs` has to be
                  rewritten as a `Piecewise`
    """

    handled_functions = (sp.Add, sp.Mul, fast_division, fast_sqrt, fast_inv_sqrt, vec_any, vec_all)

    def skip_resolved_access(e):
        return isinstance(e, ast.ResolvedFieldAccess)

    def visit_expr(expr):
        if isinstance(expr, vector_memory_access):
            # only the fifth argument is processed recursively
            return vector_memory_access(*expr.args[0:4], visit_expr(expr.args[4]), *expr.args[5:])

        if isinstance(expr, cast_func):
            return expr

        if expr.func is sp.Abs and 'abs' not in ast_node.instruction_set:
            # express Abs as a Piecewise if the instruction set has no 'abs' entry
            operand = expr.args[0]
            new_arg = visit_expr(operand)
            base_type = get_type_of_expression(operand)
            if type(operand) is vector_memory_access:
                base_type = base_type.base_type
            zero = base_type.numpy_dtype.type(0)
            pw = sp.Piecewise((-new_arg, new_arg < zero),
                              (new_arg, True))
            return visit_expr(pw)

        if expr.func in handled_functions or isinstance(expr, sp.Rel) or isinstance(expr, BooleanFunction):
            default_type = 'double'
            if expr.func is sp.Mul and expr.args[0] == -1:
                # special treatment for the unary minus: make sure that the -1 has the same type as the argument
                minus_one_type = int
                for access in expr.atoms(vector_memory_access):
                    if access.dtype.base_type.is_float():
                        minus_one_type = access.dtype.base_type.numpy_dtype.type
                for symbol in expr.atoms(TypedSymbol):
                    if type(symbol.dtype) is VectorType and symbol.dtype.base_type.is_float():
                        minus_one_type = symbol.dtype.base_type.numpy_dtype.type
                if minus_one_type is not int:
                    if minus_one_type is np.float32:
                        default_type = 'float'
                    expr = sp.Mul(minus_one_type(expr.args[0]), *expr.args[1:])

            new_args = [visit_expr(a) for a in expr.args]
            arg_types = [get_type_of_expression(a, default_float_type=default_type) for a in new_args]
            if not any(type(t) is VectorType for t in arg_types):
                # purely scalar expression: nothing to cast
                return expr

            target_type = collate_types(arg_types)
            casted_args = []
            for a, t in zip(new_args, arg_types):
                needs_cast = t != target_type and not isinstance(a, vector_memory_access)
                casted_args.append(cast_func(a, target_type) if needs_cast else a)
            return expr.func(*casted_args)

        if expr.func is sp.Pow:
            # only the base is processed, the exponent is kept as it is
            return expr.func(visit_expr(expr.args[0]), expr.args[1])

        if expr.func == sp.Piecewise:
            new_results = [visit_expr(branch[0]) for branch in expr.args]
            new_conditions = [visit_expr(branch[1]) for branch in expr.args]
            types_of_results = [get_type_of_expression(a) for a in new_results]
            types_of_conditions = [get_type_of_expression(a) for a in new_conditions]

            result_target_type = get_type_of_expression(expr)
            condition_target_type = collate_types(types_of_conditions)
            # if only one of results / conditions is vector typed, widen the other one to match
            if type(condition_target_type) is VectorType and type(result_target_type) is not VectorType:
                result_target_type = VectorType(result_target_type, width=condition_target_type.width)
            if type(condition_target_type) is not VectorType and type(result_target_type) is VectorType:
                condition_target_type = VectorType(condition_target_type, width=result_target_type.width)

            casted_results = [a if t == result_target_type else cast_func(a, result_target_type)
                              for a, t in zip(new_results, types_of_results)]

            casted_conditions = [cast_func(a, condition_target_type)
                                 if t != condition_target_type and a is not True else a
                                 for a, t in zip(new_conditions, types_of_conditions)]

            return sp.Piecewise(*zip(casted_results, casted_conditions))

        return expr

    def visit_node(node, substitution_dict):
        # work on a copy so that replacements made in this scope do not leak into the caller's scope
        substitution_dict = substitution_dict.copy()
        for child in node.args:
            if isinstance(child, ast.SympyAssignment):
                substituted_rhs = fast_subs(child.rhs, substitution_dict, skip=skip_resolved_access)
                child.rhs = visit_expr(substituted_rhs)
                rhs_type = get_type_of_expression(child.rhs)
                if isinstance(child.lhs, TypedSymbol):
                    lhs_type = child.lhs.dtype
                    if type(rhs_type) is VectorType and type(lhs_type) is not VectorType:
                        # scalar symbol receives a vector value: promote the symbol and remember the replacement
                        new_lhs = TypedSymbol(child.lhs.name, VectorType(lhs_type, rhs_type.width))
                        substitution_dict[child.lhs] = new_lhs
                        child.lhs = new_lhs
                elif isinstance(child.lhs, vector_memory_access):
                    child.lhs = visit_expr(child.lhs)
                continue

            if isinstance(child, ast.Conditional):
                child.condition_expr = fast_subs(child.condition_expr, substitution_dict,
                                                 skip=skip_resolved_access)
                child.condition_expr = visit_expr(child.condition_expr)
            visit_node(child, substitution_dict)

    visit_node(ast_node, {})