simd_instruction_sets.py 6.06 KB
Newer Older
1
2


Martin Bauer's avatar
Martin Bauer committed
3
# noinspection SpellCheckingInspection
Martin Bauer's avatar
Martin Bauer committed
4
def get_vector_instruction_set(data_type='double', instruction_set='avx'):
    """Build a table of C intrinsic format strings for a SIMD instruction set.

    Args:
        data_type: element type of the vectors, ``'double'`` or ``'float'``.
        instruction_set: one of ``'sse'``, ``'avx'`` or ``'avx512'``.

    Returns:
        A dict mapping operation ids (``'+'``, ``'=='``, ``'loadU'``,
        ``'makeVec'``, ...) to ``str.format`` templates such as
        ``'_mm256_add_pd({0},{1})'`` whose positional placeholders stand for
        the operands.  Additional entries:

        * ``'width'``: number of lanes per vector register,
        * ``'double'``, ``'float'``, ``'int'``, ``'bool'``: C vector type names,
        * ``'dataTypePrefix'``: dict with a per-data-type prefix string,
        * ``'headers'``: list of headers to include,
        * ``'any'`` / ``'all'``: templates reducing a comparison result,
        * ``'rsqrt'``: template for the reciprocal square root, or ``None``
          when no entry is provided for the chosen combination.

    Raises:
        KeyError: if ``data_type`` or ``instruction_set`` is not one of the
            values listed above.
    """
    comparisons = {
        '==': '_CMP_EQ_UQ',
        '!=': '_CMP_NEQ_UQ',
        '>=': '_CMP_GE_OQ',
        '<=': '_CMP_LE_OQ',
        '<': '_CMP_NGE_UQ',
        '>': '_CMP_NLE_UQ',
    }

    # Shortcut syntax: "<intrinsic base name>[<arguments>]".  Numeric
    # arguments become positional format placeholders, everything else is
    # emitted verbatim.
    base_names = {
        '+': 'add[0, 1]',
        '-': 'sub[0, 1]',
        '*': 'mul[0, 1]',
        '/': 'div[0, 1]',

        '&': 'and[0, 1]',
        '|': 'or[0, 1]',
        'blendv': 'blendv[0, 1, 2]',

        'sqrt': 'sqrt[0]',

        'makeVecConst': 'set[]',
        'makeVec': 'set[]',
        'makeVecBool': 'set[]',
        'makeVecConstBool': 'set[]',
        'makeZero': 'setzero[]',

        'loadU': 'loadu[0]',
        'loadA': 'load[0]',
        'storeU': 'storeu[0,1]',
        'storeA': 'store[0,1]',
        'stream': 'stream[0,1]',
        'maskstore': 'mask_store[0, 2, 1]' if instruction_set == 'avx512' else 'maskstore[0, 2, 1]',
        'maskload': 'mask_load[0, 2, 1]' if instruction_set == 'avx512' else 'maskload[0, 2, 1]',
    }

    if instruction_set == 'avx512':
        base_names.update({
            'maskStore': 'mask_store[0, 2, 1]',
            'maskStoreU': 'mask_storeu[0, 2, 1]',
            'maskLoad': 'mask_load[2, 1, 0]',
            'maskLoadU': 'mask_loadu[2, 1, 0]',
        })
    if instruction_set == 'avx':
        base_names.update({
            'maskStore': 'maskstore[0, 2, 1]',
            'maskStoreU': 'maskstore[0, 2, 1]',
            'maskLoad': 'maskload[0, 1]',
            # AVX has no separate "maskloadu" intrinsic; like 'maskStoreU'
            # above, the unaligned variant maps onto the plain masked op.
            'maskLoadU': 'maskload[0, 1]',
        })

    for comparison_op, constant in comparisons.items():
        base_names[comparison_op] = 'cmp[0, 1, %s]' % (constant,)

    headers = {
        'avx512': ['<immintrin.h>'],
        'avx': ['<immintrin.h>'],
        'sse': ['<immintrin.h>', '<xmmintrin.h>', '<emmintrin.h>', '<pmmintrin.h>',
                '<tmmintrin.h>', '<smmintrin.h>', '<nmmintrin.h>'],
    }

    suffix = {
        'double': 'pd',
        'float': 'ps',
    }
    prefix = {
        'sse': '_mm',
        'avx': '_mm256',
        'avx512': '_mm512',
    }

    width = {
        ("double", "sse"): 2,
        ("float", "sse"): 4,
        ("double", "avx"): 4,
        ("float", "avx"): 8,
        ("double", "avx512"): 8,
        ("float", "avx512"): 16,
    }

    lanes = width[(data_type, instruction_set)]
    result = {
        'width': lanes,
    }
    pre = prefix[instruction_set]
    suf = suffix[data_type]

    for intrinsic_id, function_shortcut in base_names.items():
        function_shortcut = function_shortcut.strip()
        bracket = function_shortcut.index('[')
        name = function_shortcut[:bracket]

        if intrinsic_id == 'makeVecConst':
            # Broadcast: the same operand in every lane.
            params = ["{0}"] * lanes
        elif intrinsic_id == 'makeVec':
            # Operands are listed in reversed order for the "set" intrinsic.
            params = ["{" + str(i) + "}" for i in reversed(range(lanes))]
        elif intrinsic_id == 'makeVecBool':
            params = ["(({{{i}}}) ? -1.0 : 0.0)".format(i=i) for i in reversed(range(lanes))]
        elif intrinsic_id == 'makeVecConstBool':
            params = ["(({0}) ? -1.0 : 0.0)"] * lanes
        else:
            params = []
            for arg in function_shortcut[bracket + 1: -1].split(","):
                arg = arg.strip()
                if not arg:
                    continue
                if arg in ('0', '1', '2', '3', '4', '5'):
                    params.append("{" + arg + "}")
                else:
                    params.append(arg)
        # Joining (instead of trimming a trailing comma) keeps the opening
        # parenthesis for argument-less intrinsics such as 'makeZero'.
        arg_string = "(" + ",".join(params) + ")"

        mask_suffix = '_mask' if instruction_set == 'avx512' and intrinsic_id in comparisons else ''
        result[intrinsic_id] = pre + "_" + name + "_" + suf + mask_suffix + arg_string

    result['dataTypePrefix'] = {
        'double': "_" + pre + 'd',
        'float': "_" + pre,
    }

    result['rsqrt'] = None
    bit_width = lanes * (64 if data_type == 'double' else 32)
    result['double'] = "__m%dd" % (bit_width,)
    result['float'] = "__m%d" % (bit_width,)
    result['int'] = "__m%di" % (bit_width,)
    result['bool'] = "__m%dd" % (bit_width,)

    result['headers'] = headers[instruction_set]
    result['any'] = "%s_movemask_%s({0}) > 0" % (pre, suf)
    # movemask yields one bit per lane, so "all lanes set" depends on the
    # lane count (0x3, 0xF or 0xFF) rather than being a fixed constant.
    result['all'] = "%s_movemask_%s({0}) == 0x%X" % (pre, suf, (1 << lanes) - 1)

    if instruction_set == 'avx512':
        # AVX-512 comparisons produce mask registers, so boolean handling is
        # expressed with mask intrinsics sized to the lane count.
        size = 8 if data_type == 'double' else 16
        result['&'] = '_kand_mask%d({0}, {1})' % (size,)
        result['|'] = '_kor_mask%d({0}, {1})' % (size,)
        result['any'] = '!_ktestz_mask%d_u8({0}, {0})' % (size, )
        result['all'] = '_kortestc_mask%d_u8({0}, {0})' % (size, )
        result['blendv'] = '%s_mask_blend_%s({2}, {0}, {1})' % (pre, suf)
        result['rsqrt'] = "_mm512_rsqrt14_%s({0})" % (suf,)
        result['bool'] = "__mmask%d" % (size,)

        # One bit per lane: lane i contributes 2**i to the mask value.
        params = " | ".join("({{{i}}} ? {power} : 0)".format(i=i, power=2 ** i) for i in range(size))
        result['makeVecBool'] = "__mmask%d((%s) )" % (size, params)
        params = " | ".join("({{0}} ? {power} : 0)".format(power=2 ** i) for i in range(size))
        result['makeVecConstBool'] = "__mmask%d((%s) )" % (size, params)

    if instruction_set == 'avx' and data_type == 'float':
        result['rsqrt'] = "_mm256_rsqrt_ps({0})"

    return result
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172


def get_supported_instruction_sets():
    """Return the instruction sets the current CPU supports.

    The CPU feature flags are read through the optional ``cpuinfo`` package.

    Returns:
        A list containing any of ``"sse"``, ``"avx"`` and ``"avx512"`` (in
        that order) whose required flags are all reported, or ``None`` when
        ``cpuinfo`` cannot be imported.
    """
    try:
        from cpuinfo import get_cpu_info
    except ImportError:
        return None

    # Each instruction set is listed with the CPU flags that must all be present.
    requirements = (
        ("sse", {'sse', 'sse2', 'ssse3', 'sse4_1', 'sse4_2'}),
        ("avx", {'avx'}),
        ("avx512", {'avx512f'}),
    )
    cpu_flags = set(get_cpu_info()['flags'])
    return [name for name, needed in requirements if needed <= cpu_flags]