Commit f19bd423 authored by Michael Kuron's avatar Michael Kuron
Browse files

Merge branch 'Neon_Intrinsics' into 'master'

Neon intrinsics

See merge request pycodegen/pystencils!188
parents 6effd8d3 5e927f6e
def get_argument_string(function_shortcut):
args = function_shortcut[function_shortcut.index('[') + 1: -1]
arg_string = "("
for arg in args.split(","):
arg = arg.strip()
if not arg:
continue
if arg in ('0', '1', '2', '3', '4', '5'):
arg_string += "{" + arg + "},"
else:
arg_string += arg + ","
arg_string = arg_string[:-1] + ")"
return arg_string
def get_vector_instruction_set_arm(data_type='double', instruction_set='neon', q_registers=True):
base_names = {
'+': 'add[0, 1]',
'-': 'sub[0, 1]',
'*': 'mul[0, 1]',
'/': 'div[0, 1]',
'sqrt': 'sqrt[0]',
'loadU': 'ld1[0]',
'loadA': 'ld1[0]',
'storeU': 'st1[0, 1]',
'storeA': 'st1[0, 1]',
'stream': 'st1[0, 1]',
'abs': 'abs[0]',
'==': 'ceq[0, 1]',
'<=': 'cle[0, 1]',
'<': 'clt[0, 1]',
'>=': 'cge[0, 1]',
'>': 'cgt[0, 1]',
# '&': 'and[0, 1]', -> only for integer values available
# '|': 'orr[0, 1]'
}
bits = {'double': 64,
'float': 32}
if q_registers is True:
q_reg = 'q'
width = 128 // bits[data_type]
suffix = f'q_f{bits[data_type]}'
else:
q_reg = ''
width = 64 // bits[data_type]
suffix = f'_f{bits[data_type]}'
result = dict()
for intrinsic_id, function_shortcut in base_names.items():
function_shortcut = function_shortcut.strip()
name = function_shortcut[:function_shortcut.index('[')]
arg_string = get_argument_string(function_shortcut)
result[intrinsic_id] = 'v' + name + suffix + arg_string
result['makeVecConst'] = 'vdup' + q_reg + '_n_f' + str(bits[data_type]) + '({0})'
result['makeVec'] = 'vdup' + q_reg + '_n_f' + str(bits[data_type]) + '({0})'
result['rsqrt'] = None
result['width'] = width
result['double'] = 'float64x' + str(width) + '_t'
result['float'] = 'float32x' + str(width * 2) + '_t'
result['headers'] = ['<arm_neon.h>']
result['!='] = 'vmvnq_u%d(%s)' % (bits[data_type], result['=='])
return result
......@@ -533,6 +533,11 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
assert self.instruction_set['width'] == expr_type.width
return None
def _print_Abs(self, expr):
if 'abs' in self.instruction_set and isinstance(expr.args[0], vector_memory_access):
return self.instruction_set['abs'].format(self._print(expr.args[0]))
return super()._print_Abs(expr)
def _print_Function(self, expr):
if isinstance(expr, vector_memory_access):
arg, data_type, aligned, _, mask = expr.args
......
from pystencils.backends.x86_instruction_sets import get_vector_instruction_set_x86
from pystencils.backends.arm_instruction_sets import get_vector_instruction_set_arm
# noinspection SpellCheckingInspection
def get_vector_instruction_set(data_type='double', instruction_set='avx'):
comparisons = {
'==': '_CMP_EQ_UQ',
'!=': '_CMP_NEQ_UQ',
'>=': '_CMP_GE_OQ',
'<=': '_CMP_LE_OQ',
'<': '_CMP_NGE_UQ',
'>': '_CMP_NLE_UQ',
}
base_names = {
'+': 'add[0, 1]',
'-': 'sub[0, 1]',
'*': 'mul[0, 1]',
'/': 'div[0, 1]',
'&': 'and[0, 1]',
'|': 'or[0, 1]',
'blendv': 'blendv[0, 1, 2]',
'sqrt': 'sqrt[0]',
'makeVecConst': 'set[]',
'makeVec': 'set[]',
'makeVecBool': 'set[]',
'makeVecConstBool': 'set[]',
'makeZero': 'setzero[]',
'loadU': 'loadu[0]',
'loadA': 'load[0]',
'storeU': 'storeu[0,1]',
'storeA': 'store[0,1]',
'stream': 'stream[0,1]',
'maskstore': 'mask_store[0, 2, 1]' if instruction_set == 'avx512' else 'maskstore[0, 2, 1]',
'maskload': 'mask_load[0, 2, 1]' if instruction_set == 'avx512' else 'maskload[0, 2, 1]'
}
if instruction_set == 'avx512':
base_names.update({
'maskStore': 'mask_store[0, 2, 1]',
'maskStoreU': 'mask_storeu[0, 2, 1]',
'maskLoad': 'mask_load[2, 1, 0]',
'maskLoadU': 'mask_loadu[2, 1, 0]'
})
if instruction_set == 'avx':
base_names.update({
'maskStore': 'maskstore[0, 2, 1]',
'maskStoreU': 'maskstore[0, 2, 1]',
'maskLoad': 'maskload[0, 1]',
'maskLoadU': 'maskloadu[0, 1]'
})
for comparison_op, constant in comparisons.items():
base_names[comparison_op] = f'cmp[0, 1, {constant}]'
headers = {
'avx512': ['<immintrin.h>'],
'avx': ['<immintrin.h>'],
'sse': ['<immintrin.h>', '<xmmintrin.h>', '<emmintrin.h>', '<pmmintrin.h>',
'<tmmintrin.h>', '<smmintrin.h>', '<nmmintrin.h>']
}
suffix = {
'double': 'pd',
'float': 'ps',
}
prefix = {
'sse': '_mm',
'avx': '_mm256',
'avx512': '_mm512',
}
width = {
("double", "sse"): 2,
("float", "sse"): 4,
("double", "avx"): 4,
("float", "avx"): 8,
("double", "avx512"): 8,
("float", "avx512"): 16,
}
result = {
'width': width[(data_type, instruction_set)],
}
pre = prefix[instruction_set]
suf = suffix[data_type]
for intrinsic_id, function_shortcut in base_names.items():
function_shortcut = function_shortcut.strip()
name = function_shortcut[:function_shortcut.index('[')]
if intrinsic_id == 'makeVecConst':
arg_string = f"({','.join(['{0}'] * result['width'])})"
elif intrinsic_id == 'makeVec':
params = ["{" + str(i) + "}" for i in reversed(range(result['width']))]
arg_string = f"({','.join(params)})"
elif intrinsic_id == 'makeVecBool':
params = [f"(({{{i}}} ? -1.0 : 0.0)" for i in reversed(range(result['width']))]
arg_string = f"({','.join(params)})"
elif intrinsic_id == 'makeVecConstBool':
params = ["(({0}) ? -1.0 : 0.0)" for _ in range(result['width'])]
arg_string = f"({','.join(params)})"
else:
args = function_shortcut[function_shortcut.index('[') + 1: -1]
arg_string = "("
for arg in args.split(","):
arg = arg.strip()
if not arg:
continue
if arg in ('0', '1', '2', '3', '4', '5'):
arg_string += "{" + arg + "},"
else:
arg_string += arg + ","
arg_string = arg_string[:-1] + ")"
mask_suffix = '_mask' if instruction_set == 'avx512' and intrinsic_id in comparisons.keys() else ''
result[intrinsic_id] = pre + "_" + name + "_" + suf + mask_suffix + arg_string
result['dataTypePrefix'] = {
'double': "_" + pre + 'd',
'float': "_" + pre,
}
result['rsqrt'] = None
bit_width = result['width'] * (64 if data_type == 'double' else 32)
result['double'] = "__m%dd" % (bit_width,)
result['float'] = "__m%d" % (bit_width,)
result['int'] = "__m%di" % (bit_width,)
result['bool'] = "__m%dd" % (bit_width,)
result['headers'] = headers[instruction_set]
result['any'] = "%s_movemask_%s({0}) > 0" % (pre, suf)
result['all'] = "%s_movemask_%s({0}) == 0xF" % (pre, suf)
if instruction_set == 'avx512':
size = 8 if data_type == 'double' else 16
result['&'] = '_kand_mask%d({0}, {1})' % (size,)
result['|'] = '_kor_mask%d({0}, {1})' % (size,)
result['any'] = '!_ktestz_mask%d_u8({0}, {0})' % (size, )
result['all'] = '_kortestc_mask%d_u8({0}, {0})' % (size, )
result['blendv'] = '%s_mask_blend_%s({2}, {0}, {1})' % (pre, suf)
result['rsqrt'] = "_mm512_rsqrt14_%s({0})" % (suf,)
result['bool'] = "__mmask%d" % (size,)
params = " | ".join(["({{{i}}} ? {power} : 0)".format(i=i, power=2 ** i) for i in range(8)])
result['makeVecBool'] = f"__mmask8(({params}) )"
params = " | ".join(["({{0}} ? {power} : 0)".format(power=2 ** i) for i in range(8)])
result['makeVecConstBool'] = f"__mmask8(({params}) )"
if instruction_set == 'avx' and data_type == 'float':
result['rsqrt'] = "_mm256_rsqrt_ps({0})"
return result
def get_vector_instruction_set(data_type='double', instruction_set='avx', q_registers=True):
if instruction_set in ['neon', 'sve']:
return get_vector_instruction_set_arm(data_type, instruction_set, q_registers)
else:
return get_vector_instruction_set_x86(data_type, instruction_set)
def get_supported_instruction_sets():
......@@ -162,6 +20,7 @@ def get_supported_instruction_sets():
required_sse_flags = {'sse', 'sse2', 'ssse3', 'sse4_1', 'sse4_2'}
required_avx_flags = {'avx'}
required_avx512_flags = {'avx512f'}
required_neon_flags = {'neon'}
flags = set(get_cpu_info()['flags'])
if flags.issuperset(required_sse_flags):
result.append("sse")
......@@ -169,4 +28,6 @@ def get_supported_instruction_sets():
result.append("avx")
if flags.issuperset(required_avx512_flags):
result.append("avx512")
if flags.issuperset(required_neon_flags):
result.append("neon")
return result
def get_argument_string(intrinsic_id, width, function_shortcut):
if intrinsic_id == 'makeVecConst':
arg_string = f"({','.join(['{0}'] * width)})"
elif intrinsic_id == 'makeVec':
params = ["{" + str(i) + "}" for i in reversed(range(width))]
arg_string = f"({','.join(params)})"
elif intrinsic_id == 'makeVecBool':
params = [f"(({{{i}}} ? -1.0 : 0.0)" for i in reversed(range(width))]
arg_string = f"({','.join(params)})"
elif intrinsic_id == 'makeVecConstBool':
params = ["(({0}) ? -1.0 : 0.0)" for _ in range(width)]
arg_string = f"({','.join(params)})"
else:
args = function_shortcut[function_shortcut.index('[') + 1: -1]
arg_string = "("
for arg in args.split(","):
arg = arg.strip()
if not arg:
continue
if arg in ('0', '1', '2', '3', '4', '5'):
arg_string += "{" + arg + "},"
else:
arg_string += arg + ","
arg_string = arg_string[:-1] + ")"
return arg_string
def get_vector_instruction_set_x86(data_type='double', instruction_set='avx'):
comparisons = {
'==': '_CMP_EQ_UQ',
'!=': '_CMP_NEQ_UQ',
'>=': '_CMP_GE_OQ',
'<=': '_CMP_LE_OQ',
'<': '_CMP_NGE_UQ',
'>': '_CMP_NLE_UQ',
}
base_names = {
'+': 'add[0, 1]',
'-': 'sub[0, 1]',
'*': 'mul[0, 1]',
'/': 'div[0, 1]',
'&': 'and[0, 1]',
'|': 'or[0, 1]',
'blendv': 'blendv[0, 1, 2]',
'sqrt': 'sqrt[0]',
'makeVecConst': 'set[]',
'makeVec': 'set[]',
'makeVecBool': 'set[]',
'makeVecConstBool': 'set[]',
'loadU': 'loadu[0]',
'loadA': 'load[0]',
'storeU': 'storeu[0,1]',
'storeA': 'store[0,1]',
'stream': 'stream[0,1]',
'maskstore': 'mask_store[0, 2, 1]' if instruction_set == 'avx512' else 'maskstore[0, 2, 1]',
'maskload': 'mask_load[0, 2, 1]' if instruction_set == 'avx512' else 'maskload[0, 2, 1]'
}
if instruction_set == 'avx512':
base_names.update({
'maskStore': 'mask_store[0, 2, 1]',
'maskStoreU': 'mask_storeu[0, 2, 1]',
'maskLoad': 'mask_load[2, 1, 0]',
'maskLoadU': 'mask_loadu[2, 1, 0]'
})
if instruction_set == 'avx':
base_names.update({
'maskStore': 'maskstore[0, 2, 1]',
'maskStoreU': 'maskstore[0, 2, 1]',
'maskLoad': 'maskload[0, 1]',
'maskLoadU': 'maskloadu[0, 1]'
})
for comparison_op, constant in comparisons.items():
base_names[comparison_op] = f'cmp[0, 1, {constant}]'
headers = {
'avx512': ['<immintrin.h>'],
'avx': ['<immintrin.h>'],
'sse': ['<immintrin.h>', '<xmmintrin.h>', '<emmintrin.h>', '<pmmintrin.h>',
'<tmmintrin.h>', '<smmintrin.h>', '<nmmintrin.h>']
}
suffix = {
'double': 'pd',
'float': 'ps',
}
prefix = {
'sse': '_mm',
'avx': '_mm256',
'avx512': '_mm512',
}
width = {
("double", "sse"): 2,
("float", "sse"): 4,
("double", "avx"): 4,
("float", "avx"): 8,
("double", "avx512"): 8,
("float", "avx512"): 16,
}
result = {
'width': width[(data_type, instruction_set)],
}
pre = prefix[instruction_set]
suf = suffix[data_type]
for intrinsic_id, function_shortcut in base_names.items():
function_shortcut = function_shortcut.strip()
name = function_shortcut[:function_shortcut.index('[')]
arg_string = get_argument_string(intrinsic_id, result['width'], function_shortcut)
mask_suffix = '_mask' if instruction_set == 'avx512' and intrinsic_id in comparisons.keys() else ''
result[intrinsic_id] = pre + "_" + name + "_" + suf + mask_suffix + arg_string
result['dataTypePrefix'] = {
'double': "_" + pre + 'd',
'float': "_" + pre,
}
result['rsqrt'] = None
bit_width = result['width'] * (64 if data_type == 'double' else 32)
result['double'] = f"__m{bit_width}d"
result['float'] = f"__m{bit_width}"
result['int'] = f"__m{bit_width}i"
result['bool'] = f"__m{bit_width}d"
result['headers'] = headers[instruction_set]
result['any'] = f"{pre}_movemask_{suf}({{0}}) > 0"
result['all'] = f"{pre}_movemask_{suf}({{0}}) == 0xF"
if instruction_set == 'avx512':
size = 8 if data_type == 'double' else 16
result['&'] = f'_kand_mask{size}({{0}}, {{1}})'
result['|'] = f'_kor_mask{size}({{0}}, {{1}})'
result['any'] = f'!_ktestz_mask{size}_u8({{0}}, {{0}})'
result['all'] = f'_kortestc_mask{size}_u8({{0}}, {{0}})'
result['blendv'] = f'{pre}_mask_blend_{suf}({{2}}, {{0}}, {{1}})'
result['rsqrt'] = f"{pre}_rsqrt14_{suf}({{0}})"
result['abs'] = f"{pre}_abs_{suf}({{0}})"
result['bool'] = f"__mmask{size}"
params = " | ".join(["({{{i}}} ? {power} : 0)".format(i=i, power=2 ** i) for i in range(8)])
result['makeVecBool'] = f"__mmask8(({params}) )"
params = " | ".join(["({{0}} ? {power} : 0)".format(power=2 ** i) for i in range(8)])
result['makeVecConstBool'] = f"__mmask8(({params}) )"
if instruction_set == 'avx' and data_type == 'float':
result['rsqrt'] = f"{pre}_rsqrt_{suf}({{0}})"
return result
......@@ -176,7 +176,7 @@ def insert_vector_casts(ast_node):
visit_expr(expr.args[4]))
elif isinstance(expr, cast_func):
return expr
elif expr.func is sp.Abs:
elif expr.func is sp.Abs and 'abs' not in ast_node.instruction_set:
new_arg = visit_expr(expr.args[0])
pw = sp.Piecewise((-1 * new_arg, new_arg < 0), (new_arg, True))
return visit_expr(pw)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment