Skip to content
Snippets Groups Projects
Commit 4cdd0ad7 authored by Martin Bauer
Browse files

Correct handling of AVX512 booleans/masks/blends

parent f7cda45b
Branches
Tags
No related merge requests found
......@@ -2,18 +2,19 @@
# noinspection SpellCheckingInspection
def get_vector_instruction_set(data_type='double', instruction_set='avx'):
comparisons = {
'==': '_CMP_EQ_UQ',
'!=': '_CMP_NEQ_UQ',
'>=': '_CMP_GE_OQ',
'<=': '_CMP_LE_OQ',
'<': '_CMP_NGE_UQ',
'>': '_CMP_NLE_UQ',
}
base_names = {
'+': 'add[0, 1]',
'-': 'sub[0, 1]',
'*': 'mul[0, 1]',
'/': 'div[0, 1]',
'==': 'cmp[0, 1, _CMP_EQ_UQ ]',
'!=': 'cmp[0, 1, _CMP_NEQ_UQ ]',
'>=': 'cmp[0, 1, _CMP_GE_OQ ]',
'<=': 'cmp[0, 1, _CMP_LE_OQ ]',
'<': 'cmp[0, 1, _CMP_NGE_UQ ]',
'>': 'cmp[0, 1, _CMP_NLE_UQ ]',
'&': 'and[0, 1]',
'|': 'or[0, 1]',
'blendv': 'blendv[0, 1, 2]',
......@@ -29,6 +30,8 @@ def get_vector_instruction_set(data_type='double', instruction_set='avx'):
'storeA': 'store[0,1]',
'stream': 'stream[0,1]',
}
for comparison_op, constant in comparisons.items():
base_names[comparison_op] = 'cmp[0, 1, %s]' % (constant,)
headers = {
'avx512': ['<immintrin.h>'],
......@@ -79,13 +82,15 @@ def get_vector_instruction_set(data_type='double', instruction_set='avx'):
else:
arg_string += arg + ","
arg_string = arg_string[:-1] + ")"
result[intrinsic_id] = pre + "_" + name + "_" + suf + arg_string
mask_suffix = '_mask' if instruction_set == 'avx512' and intrinsic_id in comparisons.keys() else ''
result[intrinsic_id] = pre + "_" + name + "_" + suf + mask_suffix + arg_string
result['dataTypePrefix'] = {
'double': "_" + pre + 'd',
'float': "_" + pre,
}
result['rsqrt'] = None
bit_width = result['width'] * (64 if data_type == 'double' else 32)
result['double'] = "__m%dd" % (bit_width,)
result['float'] = "__m%d" % (bit_width,)
......@@ -94,14 +99,16 @@ def get_vector_instruction_set(data_type='double', instruction_set='avx'):
result['headers'] = headers[instruction_set]
if instruction_set == 'avx512' and data_type == 'double':
result['rsqrt'] = "_mm512_rsqrt14_pd({0})"
elif instruction_set == 'avx512' and data_type == 'float':
result['rsqrt'] = "_mm512_rsqrt14_ps({0})"
elif instruction_set == 'avx' and data_type == 'float':
if instruction_set == 'avx512':
size = 8 if data_type == 'double' else 16
result['&'] = '_kand_mask%d({0}, {1})' % (size,)
result['|'] = '_kor_mask%d({0}, {1})' % (size,)
result['blendv'] = '%s_mask_blend_%s({2}, {0}, {1})' % (pre, suf)
result['rsqrt'] = "_mm512_rsqrt14_%s({0})" % (suf,)
result['bool'] = "__mmask%d" % (size,)
if instruction_set == 'avx' and data_type == 'float':
result['rsqrt'] = "_mm256_rsqrt_ps({0})"
else:
result['rsqrt'] = None
return result
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment