Skip to content
Snippets Groups Projects
Commit 5e927f6e authored by Markus Holzer's avatar Markus Holzer Committed by Michael Kuron
Browse files

Neon intrinsics

parent 6effd8d3
Branches
Tags
No related merge requests found
def get_argument_string(function_shortcut):
args = function_shortcut[function_shortcut.index('[') + 1: -1]
arg_string = "("
for arg in args.split(","):
arg = arg.strip()
if not arg:
continue
if arg in ('0', '1', '2', '3', '4', '5'):
arg_string += "{" + arg + "},"
else:
arg_string += arg + ","
arg_string = arg_string[:-1] + ")"
return arg_string
def get_vector_instruction_set_arm(data_type='double', instruction_set='neon', q_registers=True):
base_names = {
'+': 'add[0, 1]',
'-': 'sub[0, 1]',
'*': 'mul[0, 1]',
'/': 'div[0, 1]',
'sqrt': 'sqrt[0]',
'loadU': 'ld1[0]',
'loadA': 'ld1[0]',
'storeU': 'st1[0, 1]',
'storeA': 'st1[0, 1]',
'stream': 'st1[0, 1]',
'abs': 'abs[0]',
'==': 'ceq[0, 1]',
'<=': 'cle[0, 1]',
'<': 'clt[0, 1]',
'>=': 'cge[0, 1]',
'>': 'cgt[0, 1]',
# '&': 'and[0, 1]', -> only for integer values available
# '|': 'orr[0, 1]'
}
bits = {'double': 64,
'float': 32}
if q_registers is True:
q_reg = 'q'
width = 128 // bits[data_type]
suffix = f'q_f{bits[data_type]}'
else:
q_reg = ''
width = 64 // bits[data_type]
suffix = f'_f{bits[data_type]}'
result = dict()
for intrinsic_id, function_shortcut in base_names.items():
function_shortcut = function_shortcut.strip()
name = function_shortcut[:function_shortcut.index('[')]
arg_string = get_argument_string(function_shortcut)
result[intrinsic_id] = 'v' + name + suffix + arg_string
result['makeVecConst'] = 'vdup' + q_reg + '_n_f' + str(bits[data_type]) + '({0})'
result['makeVec'] = 'vdup' + q_reg + '_n_f' + str(bits[data_type]) + '({0})'
result['rsqrt'] = None
result['width'] = width
result['double'] = 'float64x' + str(width) + '_t'
result['float'] = 'float32x' + str(width * 2) + '_t'
result['headers'] = ['<arm_neon.h>']
result['!='] = 'vmvnq_u%d(%s)' % (bits[data_type], result['=='])
return result
......@@ -533,6 +533,11 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
assert self.instruction_set['width'] == expr_type.width
return None
def _print_Abs(self, expr):
if 'abs' in self.instruction_set and isinstance(expr.args[0], vector_memory_access):
return self.instruction_set['abs'].format(self._print(expr.args[0]))
return super()._print_Abs(expr)
def _print_Function(self, expr):
if isinstance(expr, vector_memory_access):
arg, data_type, aligned, _, mask = expr.args
......
from pystencils.backends.x86_instruction_sets import get_vector_instruction_set_x86
from pystencils.backends.arm_instruction_sets import get_vector_instruction_set_arm
# noinspection SpellCheckingInspection
def get_vector_instruction_set(data_type='double', instruction_set='avx'):
comparisons = {
'==': '_CMP_EQ_UQ',
'!=': '_CMP_NEQ_UQ',
'>=': '_CMP_GE_OQ',
'<=': '_CMP_LE_OQ',
'<': '_CMP_NGE_UQ',
'>': '_CMP_NLE_UQ',
}
base_names = {
'+': 'add[0, 1]',
'-': 'sub[0, 1]',
'*': 'mul[0, 1]',
'/': 'div[0, 1]',
'&': 'and[0, 1]',
'|': 'or[0, 1]',
'blendv': 'blendv[0, 1, 2]',
'sqrt': 'sqrt[0]',
'makeVecConst': 'set[]',
'makeVec': 'set[]',
'makeVecBool': 'set[]',
'makeVecConstBool': 'set[]',
'makeZero': 'setzero[]',
'loadU': 'loadu[0]',
'loadA': 'load[0]',
'storeU': 'storeu[0,1]',
'storeA': 'store[0,1]',
'stream': 'stream[0,1]',
'maskstore': 'mask_store[0, 2, 1]' if instruction_set == 'avx512' else 'maskstore[0, 2, 1]',
'maskload': 'mask_load[0, 2, 1]' if instruction_set == 'avx512' else 'maskload[0, 2, 1]'
}
if instruction_set == 'avx512':
base_names.update({
'maskStore': 'mask_store[0, 2, 1]',
'maskStoreU': 'mask_storeu[0, 2, 1]',
'maskLoad': 'mask_load[2, 1, 0]',
'maskLoadU': 'mask_loadu[2, 1, 0]'
})
if instruction_set == 'avx':
base_names.update({
'maskStore': 'maskstore[0, 2, 1]',
'maskStoreU': 'maskstore[0, 2, 1]',
'maskLoad': 'maskload[0, 1]',
'maskLoadU': 'maskloadu[0, 1]'
})
for comparison_op, constant in comparisons.items():
base_names[comparison_op] = f'cmp[0, 1, {constant}]'
headers = {
'avx512': ['<immintrin.h>'],
'avx': ['<immintrin.h>'],
'sse': ['<immintrin.h>', '<xmmintrin.h>', '<emmintrin.h>', '<pmmintrin.h>',
'<tmmintrin.h>', '<smmintrin.h>', '<nmmintrin.h>']
}
suffix = {
'double': 'pd',
'float': 'ps',
}
prefix = {
'sse': '_mm',
'avx': '_mm256',
'avx512': '_mm512',
}
width = {
("double", "sse"): 2,
("float", "sse"): 4,
("double", "avx"): 4,
("float", "avx"): 8,
("double", "avx512"): 8,
("float", "avx512"): 16,
}
result = {
'width': width[(data_type, instruction_set)],
}
pre = prefix[instruction_set]
suf = suffix[data_type]
for intrinsic_id, function_shortcut in base_names.items():
function_shortcut = function_shortcut.strip()
name = function_shortcut[:function_shortcut.index('[')]
if intrinsic_id == 'makeVecConst':
arg_string = f"({','.join(['{0}'] * result['width'])})"
elif intrinsic_id == 'makeVec':
params = ["{" + str(i) + "}" for i in reversed(range(result['width']))]
arg_string = f"({','.join(params)})"
elif intrinsic_id == 'makeVecBool':
params = [f"(({{{i}}} ? -1.0 : 0.0)" for i in reversed(range(result['width']))]
arg_string = f"({','.join(params)})"
elif intrinsic_id == 'makeVecConstBool':
params = ["(({0}) ? -1.0 : 0.0)" for _ in range(result['width'])]
arg_string = f"({','.join(params)})"
else:
args = function_shortcut[function_shortcut.index('[') + 1: -1]
arg_string = "("
for arg in args.split(","):
arg = arg.strip()
if not arg:
continue
if arg in ('0', '1', '2', '3', '4', '5'):
arg_string += "{" + arg + "},"
else:
arg_string += arg + ","
arg_string = arg_string[:-1] + ")"
mask_suffix = '_mask' if instruction_set == 'avx512' and intrinsic_id in comparisons.keys() else ''
result[intrinsic_id] = pre + "_" + name + "_" + suf + mask_suffix + arg_string
result['dataTypePrefix'] = {
'double': "_" + pre + 'd',
'float': "_" + pre,
}
result['rsqrt'] = None
bit_width = result['width'] * (64 if data_type == 'double' else 32)
result['double'] = "__m%dd" % (bit_width,)
result['float'] = "__m%d" % (bit_width,)
result['int'] = "__m%di" % (bit_width,)
result['bool'] = "__m%dd" % (bit_width,)
result['headers'] = headers[instruction_set]
result['any'] = "%s_movemask_%s({0}) > 0" % (pre, suf)
result['all'] = "%s_movemask_%s({0}) == 0xF" % (pre, suf)
if instruction_set == 'avx512':
size = 8 if data_type == 'double' else 16
result['&'] = '_kand_mask%d({0}, {1})' % (size,)
result['|'] = '_kor_mask%d({0}, {1})' % (size,)
result['any'] = '!_ktestz_mask%d_u8({0}, {0})' % (size, )
result['all'] = '_kortestc_mask%d_u8({0}, {0})' % (size, )
result['blendv'] = '%s_mask_blend_%s({2}, {0}, {1})' % (pre, suf)
result['rsqrt'] = "_mm512_rsqrt14_%s({0})" % (suf,)
result['bool'] = "__mmask%d" % (size,)
params = " | ".join(["({{{i}}} ? {power} : 0)".format(i=i, power=2 ** i) for i in range(8)])
result['makeVecBool'] = f"__mmask8(({params}) )"
params = " | ".join(["({{0}} ? {power} : 0)".format(power=2 ** i) for i in range(8)])
result['makeVecConstBool'] = f"__mmask8(({params}) )"
if instruction_set == 'avx' and data_type == 'float':
result['rsqrt'] = "_mm256_rsqrt_ps({0})"
return result
def get_vector_instruction_set(data_type='double', instruction_set='avx', q_registers=True):
if instruction_set in ['neon', 'sve']:
return get_vector_instruction_set_arm(data_type, instruction_set, q_registers)
else:
return get_vector_instruction_set_x86(data_type, instruction_set)
def get_supported_instruction_sets():
......@@ -162,6 +20,7 @@ def get_supported_instruction_sets():
required_sse_flags = {'sse', 'sse2', 'ssse3', 'sse4_1', 'sse4_2'}
required_avx_flags = {'avx'}
required_avx512_flags = {'avx512f'}
required_neon_flags = {'neon'}
flags = set(get_cpu_info()['flags'])
if flags.issuperset(required_sse_flags):
result.append("sse")
......@@ -169,4 +28,6 @@ def get_supported_instruction_sets():
result.append("avx")
if flags.issuperset(required_avx512_flags):
result.append("avx512")
if flags.issuperset(required_neon_flags):
result.append("neon")
return result
def get_argument_string(intrinsic_id, width, function_shortcut):
if intrinsic_id == 'makeVecConst':
arg_string = f"({','.join(['{0}'] * width)})"
elif intrinsic_id == 'makeVec':
params = ["{" + str(i) + "}" for i in reversed(range(width))]
arg_string = f"({','.join(params)})"
elif intrinsic_id == 'makeVecBool':
params = [f"(({{{i}}} ? -1.0 : 0.0)" for i in reversed(range(width))]
arg_string = f"({','.join(params)})"
elif intrinsic_id == 'makeVecConstBool':
params = ["(({0}) ? -1.0 : 0.0)" for _ in range(width)]
arg_string = f"({','.join(params)})"
else:
args = function_shortcut[function_shortcut.index('[') + 1: -1]
arg_string = "("
for arg in args.split(","):
arg = arg.strip()
if not arg:
continue
if arg in ('0', '1', '2', '3', '4', '5'):
arg_string += "{" + arg + "},"
else:
arg_string += arg + ","
arg_string = arg_string[:-1] + ")"
return arg_string
def get_vector_instruction_set_x86(data_type='double', instruction_set='avx'):
comparisons = {
'==': '_CMP_EQ_UQ',
'!=': '_CMP_NEQ_UQ',
'>=': '_CMP_GE_OQ',
'<=': '_CMP_LE_OQ',
'<': '_CMP_NGE_UQ',
'>': '_CMP_NLE_UQ',
}
base_names = {
'+': 'add[0, 1]',
'-': 'sub[0, 1]',
'*': 'mul[0, 1]',
'/': 'div[0, 1]',
'&': 'and[0, 1]',
'|': 'or[0, 1]',
'blendv': 'blendv[0, 1, 2]',
'sqrt': 'sqrt[0]',
'makeVecConst': 'set[]',
'makeVec': 'set[]',
'makeVecBool': 'set[]',
'makeVecConstBool': 'set[]',
'loadU': 'loadu[0]',
'loadA': 'load[0]',
'storeU': 'storeu[0,1]',
'storeA': 'store[0,1]',
'stream': 'stream[0,1]',
'maskstore': 'mask_store[0, 2, 1]' if instruction_set == 'avx512' else 'maskstore[0, 2, 1]',
'maskload': 'mask_load[0, 2, 1]' if instruction_set == 'avx512' else 'maskload[0, 2, 1]'
}
if instruction_set == 'avx512':
base_names.update({
'maskStore': 'mask_store[0, 2, 1]',
'maskStoreU': 'mask_storeu[0, 2, 1]',
'maskLoad': 'mask_load[2, 1, 0]',
'maskLoadU': 'mask_loadu[2, 1, 0]'
})
if instruction_set == 'avx':
base_names.update({
'maskStore': 'maskstore[0, 2, 1]',
'maskStoreU': 'maskstore[0, 2, 1]',
'maskLoad': 'maskload[0, 1]',
'maskLoadU': 'maskloadu[0, 1]'
})
for comparison_op, constant in comparisons.items():
base_names[comparison_op] = f'cmp[0, 1, {constant}]'
headers = {
'avx512': ['<immintrin.h>'],
'avx': ['<immintrin.h>'],
'sse': ['<immintrin.h>', '<xmmintrin.h>', '<emmintrin.h>', '<pmmintrin.h>',
'<tmmintrin.h>', '<smmintrin.h>', '<nmmintrin.h>']
}
suffix = {
'double': 'pd',
'float': 'ps',
}
prefix = {
'sse': '_mm',
'avx': '_mm256',
'avx512': '_mm512',
}
width = {
("double", "sse"): 2,
("float", "sse"): 4,
("double", "avx"): 4,
("float", "avx"): 8,
("double", "avx512"): 8,
("float", "avx512"): 16,
}
result = {
'width': width[(data_type, instruction_set)],
}
pre = prefix[instruction_set]
suf = suffix[data_type]
for intrinsic_id, function_shortcut in base_names.items():
function_shortcut = function_shortcut.strip()
name = function_shortcut[:function_shortcut.index('[')]
arg_string = get_argument_string(intrinsic_id, result['width'], function_shortcut)
mask_suffix = '_mask' if instruction_set == 'avx512' and intrinsic_id in comparisons.keys() else ''
result[intrinsic_id] = pre + "_" + name + "_" + suf + mask_suffix + arg_string
result['dataTypePrefix'] = {
'double': "_" + pre + 'd',
'float': "_" + pre,
}
result['rsqrt'] = None
bit_width = result['width'] * (64 if data_type == 'double' else 32)
result['double'] = f"__m{bit_width}d"
result['float'] = f"__m{bit_width}"
result['int'] = f"__m{bit_width}i"
result['bool'] = f"__m{bit_width}d"
result['headers'] = headers[instruction_set]
result['any'] = f"{pre}_movemask_{suf}({{0}}) > 0"
result['all'] = f"{pre}_movemask_{suf}({{0}}) == 0xF"
if instruction_set == 'avx512':
size = 8 if data_type == 'double' else 16
result['&'] = f'_kand_mask{size}({{0}}, {{1}})'
result['|'] = f'_kor_mask{size}({{0}}, {{1}})'
result['any'] = f'!_ktestz_mask{size}_u8({{0}}, {{0}})'
result['all'] = f'_kortestc_mask{size}_u8({{0}}, {{0}})'
result['blendv'] = f'{pre}_mask_blend_{suf}({{2}}, {{0}}, {{1}})'
result['rsqrt'] = f"{pre}_rsqrt14_{suf}({{0}})"
result['abs'] = f"{pre}_abs_{suf}({{0}})"
result['bool'] = f"__mmask{size}"
params = " | ".join(["({{{i}}} ? {power} : 0)".format(i=i, power=2 ** i) for i in range(8)])
result['makeVecBool'] = f"__mmask8(({params}) )"
params = " | ".join(["({{0}} ? {power} : 0)".format(power=2 ** i) for i in range(8)])
result['makeVecConstBool'] = f"__mmask8(({params}) )"
if instruction_set == 'avx' and data_type == 'float':
result['rsqrt'] = f"{pre}_rsqrt_{suf}({{0}})"
return result
......@@ -176,7 +176,7 @@ def insert_vector_casts(ast_node):
visit_expr(expr.args[4]))
elif isinstance(expr, cast_func):
return expr
elif expr.func is sp.Abs:
elif expr.func is sp.Abs and 'abs' not in ast_node.instruction_set:
new_arg = visit_expr(expr.args[0])
pw = sp.Piecewise((-1 * new_arg, new_arg < 0), (new_arg, True))
return visit_expr(pw)
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment