simd_instruction_sets.py 4.37 KB
Newer Older
1
2


Martin Bauer's avatar
Martin Bauer committed
3
# noinspection SpellCheckingInspection
def get_vector_instruction_set(data_type='double', instruction_set='avx'):
    """Return C intrinsic templates and vector type names for an x86 SIMD instruction set.

    Args:
        data_type: lane element type, ``'double'`` or ``'float'``.
        instruction_set: one of ``'sse'``, ``'avx'`` or ``'avx512'``.

    Returns:
        dict mapping operation ids (``'+'``, ``'-'``, ``'*'``, ``'/'``, ``'&'``, ``'|'``,
        comparison operators, ``'blendv'``, ``'sqrt'``, ``'makeVec'``, ``'makeZero'``,
        ``'loadU'``, ``'loadA'``, ``'storeU'``, ``'storeA'``, ``'stream'``) to C expression
        templates whose operands are ``str.format`` placeholders (``{0}``, ``{1}``, ...).
        Additional keys: ``'width'`` (lanes per vector), ``'double'``/``'float'``/``'int'``/
        ``'bool'`` (C vector type names), ``'headers'`` (include directives), ``'any'``/``'all'``
        (templates reducing a comparison result to a scalar truth value), ``'rsqrt'`` (template,
        or ``None`` when no approximate reciprocal square root is provided) and
        ``'dataTypePrefix'``.

    Raises:
        KeyError: if ``data_type`` or ``instruction_set`` is not a supported value.
    """
    comparisons = {
        '==': '_CMP_EQ_UQ',
        '!=': '_CMP_NEQ_UQ',
        '>=': '_CMP_GE_OQ',
        '<=': '_CMP_LE_OQ',
        '<': '_CMP_NGE_UQ',
        '>': '_CMP_NLE_UQ',
    }
    # Operation id -> "<intrinsic base name>[<arguments>]". Digit arguments become format
    # placeholders ({0}, {1}, ...); any other argument is copied verbatim into the call.
    base_names = {
        '+': 'add[0, 1]',
        '-': 'sub[0, 1]',
        '*': 'mul[0, 1]',
        '/': 'div[0, 1]',

        '&': 'and[0, 1]',
        '|': 'or[0, 1]',

        'blendv': 'blendv[0, 1, 2]',

        'sqrt': 'sqrt[0]',

        'makeVec': 'set[]',
        'makeZero': 'setzero[]',

        'loadU': 'loadu[0]',
        'loadA': 'load[0]',
        'storeU': 'storeu[0,1]',
        'storeA': 'store[0,1]',
        'stream': 'stream[0,1]',
    }
    for comparison_op, constant in comparisons.items():
        base_names[comparison_op] = 'cmp[0, 1, %s]' % (constant,)

    headers = {
        'avx512': ['<immintrin.h>'],
        'avx': ['<immintrin.h>'],
        'sse': ['<immintrin.h>', '<xmmintrin.h>', '<emmintrin.h>', '<pmmintrin.h>',
                '<tmmintrin.h>', '<smmintrin.h>', '<nmmintrin.h>']
    }

    suffix = {
        'double': 'pd',
        'float': 'ps',
    }
    prefix = {
        'sse': '_mm',
        'avx': '_mm256',
        'avx512': '_mm512',
    }

    width = {
        ("double", "sse"): 2,
        ("float", "sse"): 4,
        ("double", "avx"): 4,
        ("float", "avx"): 8,
        ("double", "avx512"): 8,
        ("float", "avx512"): 16,
    }

    result = {
        'width': width[(data_type, instruction_set)],
    }

    pre = prefix[instruction_set]
    suf = suffix[data_type]

    for intrinsic_id, function_shortcut in base_names.items():
        function_shortcut = function_shortcut.strip()
        bracket = function_shortcut.index('[')
        name = function_shortcut[:bracket]

        if intrinsic_id == 'makeVec':
            # The same scalar placeholder {0} is passed once per lane.
            args = ['{0}'] * result['width']
        else:
            args = []
            for arg in function_shortcut[bracket + 1:-1].split(','):
                arg = arg.strip()
                if not arg:
                    continue
                if arg in ('0', '1', '2', '3', '4', '5'):
                    args.append('{' + arg + '}')
                else:
                    args.append(arg)
        # Joining (instead of trimming a trailing comma) keeps argument-less intrinsics
        # such as setzero well-formed: "()" rather than ")".
        arg_string = '(' + ','.join(args) + ')'

        # On AVX-512 the comparison intrinsics are the *_mask variants (e.g. _mm512_cmp_pd_mask).
        mask_suffix = '_mask' if instruction_set == 'avx512' and intrinsic_id in comparisons else ''
        result[intrinsic_id] = pre + "_" + name + "_" + suf + mask_suffix + arg_string

    result['dataTypePrefix'] = {
        'double': "_" + pre + 'd',
        'float': "_" + pre,
    }

    result['rsqrt'] = None
    bit_width = result['width'] * (64 if data_type == 'double' else 32)
    result['double'] = "__m%dd" % (bit_width,)
    result['float'] = "__m%d" % (bit_width,)
    result['int'] = "__m%di" % (bit_width,)
    result['bool'] = "__m%dd" % (bit_width,)

    result['headers'] = headers[instruction_set]

    # movemask packs one bit per lane into an int, so "all lanes true" means the lowest
    # `width` bits are set: 0x3 (sse/double), 0xF (sse/float, avx/double), 0xFF (avx/float).
    all_lanes_mask = (1 << result['width']) - 1
    result['any'] = "%s_movemask_%s({0}) > 0" % (pre, suf)
    result['all'] = "%s_movemask_%s({0}) == 0x%X" % (pre, suf, all_lanes_mask)

    if instruction_set == 'avx512':
        # AVX-512 comparisons yield k-mask registers, so logic, reductions and blends
        # use the mask intrinsics instead of the vector ones.
        size = 8 if data_type == 'double' else 16
        result['&'] = '_kand_mask%d({0}, {1})' % (size,)
        result['|'] = '_kor_mask%d({0}, {1})' % (size,)
        result['any'] = '!_ktestz_mask%d_u8({0}, {0})' % (size, )
        result['all'] = '_kortestc_mask%d_u8({0}, {0})' % (size, )
        result['blendv'] = '%s_mask_blend_%s({2}, {0}, {1})' % (pre, suf)
        result['rsqrt'] = "_mm512_rsqrt14_%s({0})" % (suf,)
        result['bool'] = "__mmask%d" % (size,)

    if instruction_set == 'avx' and data_type == 'float':
        result['rsqrt'] = "_mm256_rsqrt_ps({0})"

    return result
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138


def get_supported_instruction_sets():
    """List the SIMD instruction sets supported by the current hardware.

    Returns:
        list containing any of ``"sse"``, ``"avx"`` and ``"avx512"`` (in that order) whose
        required CPU flags are all present, or ``None`` if the CPU could not be queried:
        either the optional ``cpuinfo`` package is not installed, or it did not report
        a ``'flags'`` entry.
    """
    try:
        from cpuinfo import get_cpu_info
    except ImportError:
        return None

    flags = get_cpu_info().get('flags')
    if flags is None:
        # cpuinfo may omit keys it cannot determine; report a failed query as documented
        # instead of raising KeyError.
        return None
    flags = set(flags)

    # (instruction set name, CPU flags it requires), in the order they are reported.
    required_flags = (
        ("sse", {'sse', 'sse2', 'ssse3', 'sse4_1', 'sse4_2'}),
        ("avx", {'avx'}),
        ("avx512", {'avx512f'}),
    )
    return [name for name, required in required_flags if flags.issuperset(required)]