Skip to content
Snippets Groups Projects
Commit 4cdd0ad7 authored by Martin Bauer
Browse files

Correct handling of AVX512 booleans/masks/blends

parent f7cda45b
Branches
Tags
No related merge requests found
...@@ -2,18 +2,19 @@ ...@@ -2,18 +2,19 @@
# noinspection SpellCheckingInspection # noinspection SpellCheckingInspection
def get_vector_instruction_set(data_type='double', instruction_set='avx'): def get_vector_instruction_set(data_type='double', instruction_set='avx'):
comparisons = {
'==': '_CMP_EQ_UQ',
'!=': '_CMP_NEQ_UQ',
'>=': '_CMP_GE_OQ',
'<=': '_CMP_LE_OQ',
'<': '_CMP_NGE_UQ',
'>': '_CMP_NLE_UQ',
}
base_names = { base_names = {
'+': 'add[0, 1]', '+': 'add[0, 1]',
'-': 'sub[0, 1]', '-': 'sub[0, 1]',
'*': 'mul[0, 1]', '*': 'mul[0, 1]',
'/': 'div[0, 1]', '/': 'div[0, 1]',
'==': 'cmp[0, 1, _CMP_EQ_UQ ]',
'!=': 'cmp[0, 1, _CMP_NEQ_UQ ]',
'>=': 'cmp[0, 1, _CMP_GE_OQ ]',
'<=': 'cmp[0, 1, _CMP_LE_OQ ]',
'<': 'cmp[0, 1, _CMP_NGE_UQ ]',
'>': 'cmp[0, 1, _CMP_NLE_UQ ]',
'&': 'and[0, 1]', '&': 'and[0, 1]',
'|': 'or[0, 1]', '|': 'or[0, 1]',
'blendv': 'blendv[0, 1, 2]', 'blendv': 'blendv[0, 1, 2]',
...@@ -29,6 +30,8 @@ def get_vector_instruction_set(data_type='double', instruction_set='avx'): ...@@ -29,6 +30,8 @@ def get_vector_instruction_set(data_type='double', instruction_set='avx'):
'storeA': 'store[0,1]', 'storeA': 'store[0,1]',
'stream': 'stream[0,1]', 'stream': 'stream[0,1]',
} }
for comparison_op, constant in comparisons.items():
base_names[comparison_op] = 'cmp[0, 1, %s]' % (constant,)
headers = { headers = {
'avx512': ['<immintrin.h>'], 'avx512': ['<immintrin.h>'],
...@@ -79,13 +82,15 @@ def get_vector_instruction_set(data_type='double', instruction_set='avx'): ...@@ -79,13 +82,15 @@ def get_vector_instruction_set(data_type='double', instruction_set='avx'):
else: else:
arg_string += arg + "," arg_string += arg + ","
arg_string = arg_string[:-1] + ")" arg_string = arg_string[:-1] + ")"
result[intrinsic_id] = pre + "_" + name + "_" + suf + arg_string mask_suffix = '_mask' if instruction_set == 'avx512' and intrinsic_id in comparisons.keys() else ''
result[intrinsic_id] = pre + "_" + name + "_" + suf + mask_suffix + arg_string
result['dataTypePrefix'] = { result['dataTypePrefix'] = {
'double': "_" + pre + 'd', 'double': "_" + pre + 'd',
'float': "_" + pre, 'float': "_" + pre,
} }
result['rsqrt'] = None
bit_width = result['width'] * (64 if data_type == 'double' else 32) bit_width = result['width'] * (64 if data_type == 'double' else 32)
result['double'] = "__m%dd" % (bit_width,) result['double'] = "__m%dd" % (bit_width,)
result['float'] = "__m%d" % (bit_width,) result['float'] = "__m%d" % (bit_width,)
...@@ -94,14 +99,16 @@ def get_vector_instruction_set(data_type='double', instruction_set='avx'): ...@@ -94,14 +99,16 @@ def get_vector_instruction_set(data_type='double', instruction_set='avx'):
result['headers'] = headers[instruction_set] result['headers'] = headers[instruction_set]
if instruction_set == 'avx512' and data_type == 'double': if instruction_set == 'avx512':
result['rsqrt'] = "_mm512_rsqrt14_pd({0})" size = 8 if data_type == 'double' else 16
elif instruction_set == 'avx512' and data_type == 'float': result['&'] = '_kand_mask%d({0}, {1})' % (size,)
result['rsqrt'] = "_mm512_rsqrt14_ps({0})" result['|'] = '_kor_mask%d({0}, {1})' % (size,)
elif instruction_set == 'avx' and data_type == 'float': result['blendv'] = '%s_mask_blend_%s({2}, {0}, {1})' % (pre, suf)
result['rsqrt'] = "_mm512_rsqrt14_%s({0})" % (suf,)
result['bool'] = "__mmask%d" % (size,)
if instruction_set == 'avx' and data_type == 'float':
result['rsqrt'] = "_mm256_rsqrt_ps({0})" result['rsqrt'] = "_mm256_rsqrt_ps({0})"
else:
result['rsqrt'] = None
return result return result
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment