diff --git a/backends/simd_instruction_sets.py b/backends/simd_instruction_sets.py index 2d88352bb0833338e69f83e7b1cc3a4accde6d49..02207502dd6d6389be588deb7f135dee7dfa6c4f 100644 --- a/backends/simd_instruction_sets.py +++ b/backends/simd_instruction_sets.py @@ -2,18 +2,19 @@ # noinspection SpellCheckingInspection def get_vector_instruction_set(data_type='double', instruction_set='avx'): + comparisons = { + '==': '_CMP_EQ_UQ', + '!=': '_CMP_NEQ_UQ', + '>=': '_CMP_GE_OQ', + '<=': '_CMP_LE_OQ', + '<': '_CMP_NGE_UQ', + '>': '_CMP_NLE_UQ', + } base_names = { '+': 'add[0, 1]', '-': 'sub[0, 1]', '*': 'mul[0, 1]', '/': 'div[0, 1]', - - '==': 'cmp[0, 1, _CMP_EQ_UQ ]', - '!=': 'cmp[0, 1, _CMP_NEQ_UQ ]', - '>=': 'cmp[0, 1, _CMP_GE_OQ ]', - '<=': 'cmp[0, 1, _CMP_LE_OQ ]', - '<': 'cmp[0, 1, _CMP_NGE_UQ ]', - '>': 'cmp[0, 1, _CMP_NLE_UQ ]', '&': 'and[0, 1]', '|': 'or[0, 1]', 'blendv': 'blendv[0, 1, 2]', @@ -29,6 +30,8 @@ def get_vector_instruction_set(data_type='double', instruction_set='avx'): 'storeA': 'store[0,1]', 'stream': 'stream[0,1]', } + for comparison_op, constant in comparisons.items(): + base_names[comparison_op] = 'cmp[0, 1, %s]' % (constant,) headers = { 'avx512': ['<immintrin.h>'], @@ -79,13 +82,15 @@ def get_vector_instruction_set(data_type='double', instruction_set='avx'): else: arg_string += arg + "," arg_string = arg_string[:-1] + ")" - result[intrinsic_id] = pre + "_" + name + "_" + suf + arg_string + mask_suffix = '_mask' if instruction_set == 'avx512' and intrinsic_id in comparisons.keys() else '' + result[intrinsic_id] = pre + "_" + name + "_" + suf + mask_suffix + arg_string result['dataTypePrefix'] = { 'double': "_" + pre + 'd', 'float': "_" + pre, } + result['rsqrt'] = None bit_width = result['width'] * (64 if data_type == 'double' else 32) result['double'] = "__m%dd" % (bit_width,) result['float'] = "__m%d" % (bit_width,) @@ -94,14 +99,16 @@ def get_vector_instruction_set(data_type='double', instruction_set='avx'): result['headers'] = headers[instruction_set] - if instruction_set == 'avx512' and data_type == 'double': - result['rsqrt'] = "_mm512_rsqrt14_pd({0})" - elif instruction_set == 'avx512' and data_type == 'float': - result['rsqrt'] = "_mm512_rsqrt14_ps({0})" - elif instruction_set == 'avx' and data_type == 'float': + if instruction_set == 'avx512': + size = 8 if data_type == 'double' else 16 + result['&'] = '_kand_mask%d({0}, {1})' % (size,) + result['|'] = '_kor_mask%d({0}, {1})' % (size,) + result['blendv'] = '%s_mask_blend_%s({2}, {0}, {1})' % (pre, suf) + result['rsqrt'] = "_mm512_rsqrt14_%s({0})" % (suf,) + result['bool'] = "__mmask%d" % (size,) + + if instruction_set == 'avx' and data_type == 'float': result['rsqrt'] = "_mm256_rsqrt_ps({0})" - else: - result['rsqrt'] = None return result