diff --git a/pystencils/backends/arm_instruction_sets.py b/pystencils/backends/arm_instruction_sets.py new file mode 100644 index 0000000000000000000000000000000000000000..1f2f51cd6154a63beab7c4b6939e8f41b6ad976d --- /dev/null +++ b/pystencils/backends/arm_instruction_sets.py @@ -0,0 +1,75 @@ +def get_argument_string(function_shortcut): + args = function_shortcut[function_shortcut.index('[') + 1: -1] + arg_string = "(" + for arg in args.split(","): + arg = arg.strip() + if not arg: + continue + if arg in ('0', '1', '2', '3', '4', '5'): + arg_string += "{" + arg + "}," + else: + arg_string += arg + "," + arg_string = arg_string[:-1] + ")" + return arg_string + + +def get_vector_instruction_set_arm(data_type='double', instruction_set='neon', q_registers=True): + base_names = { + '+': 'add[0, 1]', + '-': 'sub[0, 1]', + '*': 'mul[0, 1]', + '/': 'div[0, 1]', + 'sqrt': 'sqrt[0]', + + 'loadU': 'ld1[0]', + 'loadA': 'ld1[0]', + 'storeU': 'st1[0, 1]', + 'storeA': 'st1[0, 1]', + 'stream': 'st1[0, 1]', + + 'abs': 'abs[0]', + '==': 'ceq[0, 1]', + '<=': 'cle[0, 1]', + '<': 'clt[0, 1]', + '>=': 'cge[0, 1]', + '>': 'cgt[0, 1]', + # '&': 'and[0, 1]', -> only for integer values available + # '|': 'orr[0, 1]' + + } + + bits = {'double': 64, + 'float': 32} + + if q_registers is True: + q_reg = 'q' + width = 128 // bits[data_type] + suffix = f'q_f{bits[data_type]}' + else: + q_reg = '' + width = 64 // bits[data_type] + suffix = f'_f{bits[data_type]}' + + result = dict() + + for intrinsic_id, function_shortcut in base_names.items(): + function_shortcut = function_shortcut.strip() + name = function_shortcut[:function_shortcut.index('[')] + + arg_string = get_argument_string(function_shortcut) + + result[intrinsic_id] = 'v' + name + suffix + arg_string + + result['makeVecConst'] = 'vdup' + q_reg + '_n_f' + str(bits[data_type]) + '({0})' + result['makeVec'] = 'vdup' + q_reg + '_n_f' + str(bits[data_type]) + '({0})' + + result['rsqrt'] = None + + result['width'] = width + result['double'] = 'float64x' + str(width) + '_t' + result['float'] = 'float32x' + str(width * 2) + '_t' + result['headers'] = ['<arm_neon.h>'] + + result['!='] = 'vmvnq_u%d(%s)' % (bits[data_type], result['==']) + + return result diff --git a/pystencils/backends/cbackend.py b/pystencils/backends/cbackend.py index ae29ee1da7fcda7640c669d1c155d4d86aa37fbe..27d6480124ee8b64d6caac19bac65aed20e57d84 100644 --- a/pystencils/backends/cbackend.py +++ b/pystencils/backends/cbackend.py @@ -533,6 +533,11 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): assert self.instruction_set['width'] == expr_type.width return None + def _print_Abs(self, expr): + if 'abs' in self.instruction_set and isinstance(expr.args[0], vector_memory_access): + return self.instruction_set['abs'].format(self._print(expr.args[0])) + return super()._print_Abs(expr) + def _print_Function(self, expr): if isinstance(expr, vector_memory_access): arg, data_type, aligned, _, mask = expr.args diff --git a/pystencils/backends/simd_instruction_sets.py b/pystencils/backends/simd_instruction_sets.py index 7e2be4dee3d1d6a0203c91dad91369ae1a9891a6..c6290aa45b2ac0451c671d11ace2fe1d40d86264 100644 --- a/pystencils/backends/simd_instruction_sets.py +++ b/pystencils/backends/simd_instruction_sets.py @@ -1,154 +1,12 @@ +from pystencils.backends.x86_instruction_sets import get_vector_instruction_set_x86 +from pystencils.backends.arm_instruction_sets import get_vector_instruction_set_arm -# noinspection SpellCheckingInspection -def get_vector_instruction_set(data_type='double', instruction_set='avx'): - comparisons = { - '==': '_CMP_EQ_UQ', - '!=': '_CMP_NEQ_UQ', - '>=': '_CMP_GE_OQ', - '<=': '_CMP_LE_OQ', - '<': '_CMP_NGE_UQ', - '>': '_CMP_NLE_UQ', - } - base_names = { - '+': 'add[0, 1]', - '-': 'sub[0, 1]', - '*': 'mul[0, 1]', - '/': 'div[0, 1]', - '&': 'and[0, 1]', - '|': 'or[0, 1]', - 'blendv': 'blendv[0, 1, 2]', - - 'sqrt': 'sqrt[0]', - - 'makeVecConst': 'set[]', - 'makeVec': 'set[]', - 'makeVecBool': 'set[]', - 'makeVecConstBool': 'set[]', - 'makeZero': 'setzero[]', - - 'loadU': 'loadu[0]', - 'loadA': 'load[0]', - 'storeU': 'storeu[0,1]', - 'storeA': 'store[0,1]', - 'stream': 'stream[0,1]', - 'maskstore': 'mask_store[0, 2, 1]' if instruction_set == 'avx512' else 'maskstore[0, 2, 1]', - 'maskload': 'mask_load[0, 2, 1]' if instruction_set == 'avx512' else 'maskload[0, 2, 1]' - } - if instruction_set == 'avx512': - base_names.update({ - 'maskStore': 'mask_store[0, 2, 1]', - 'maskStoreU': 'mask_storeu[0, 2, 1]', - 'maskLoad': 'mask_load[2, 1, 0]', - 'maskLoadU': 'mask_loadu[2, 1, 0]' - }) - if instruction_set == 'avx': - base_names.update({ - 'maskStore': 'maskstore[0, 2, 1]', - 'maskStoreU': 'maskstore[0, 2, 1]', - 'maskLoad': 'maskload[0, 1]', - 'maskLoadU': 'maskloadu[0, 1]' - }) - - for comparison_op, constant in comparisons.items(): - base_names[comparison_op] = f'cmp[0, 1, {constant}]' - - headers = { - 'avx512': ['<immintrin.h>'], - 'avx': ['<immintrin.h>'], - 'sse': ['<immintrin.h>', '<xmmintrin.h>', '<emmintrin.h>', '<pmmintrin.h>', - '<tmmintrin.h>', '<smmintrin.h>', '<nmmintrin.h>'] - } - - suffix = { - 'double': 'pd', - 'float': 'ps', - } - prefix = { - 'sse': '_mm', - 'avx': '_mm256', - 'avx512': '_mm512', - } - - width = { - ("double", "sse"): 2, - ("float", "sse"): 4, - ("double", "avx"): 4, - ("float", "avx"): 8, - ("double", "avx512"): 8, - ("float", "avx512"): 16, - } - - result = { - 'width': width[(data_type, instruction_set)], - } - pre = prefix[instruction_set] - suf = suffix[data_type] - for intrinsic_id, function_shortcut in base_names.items(): - function_shortcut = function_shortcut.strip() - name = function_shortcut[:function_shortcut.index('[')] - - if intrinsic_id == 'makeVecConst': - arg_string = f"({','.join(['{0}'] * result['width'])})" - elif intrinsic_id == 'makeVec': - params = ["{" + str(i) + "}" for i in reversed(range(result['width']))] - arg_string = f"({','.join(params)})" - elif intrinsic_id == 'makeVecBool': - params = [f"(({{{i}}} ? -1.0 : 0.0)" for i in reversed(range(result['width']))] - arg_string = f"({','.join(params)})" - elif intrinsic_id == 'makeVecConstBool': - params = ["(({0}) ? -1.0 : 0.0)" for _ in range(result['width'])] - arg_string = f"({','.join(params)})" - else: - args = function_shortcut[function_shortcut.index('[') + 1: -1] - arg_string = "(" - for arg in args.split(","): - arg = arg.strip() - if not arg: - continue - if arg in ('0', '1', '2', '3', '4', '5'): - arg_string += "{" + arg + "}," - else: - arg_string += arg + "," - arg_string = arg_string[:-1] + ")" - mask_suffix = '_mask' if instruction_set == 'avx512' and intrinsic_id in comparisons.keys() else '' - result[intrinsic_id] = pre + "_" + name + "_" + suf + mask_suffix + arg_string - - result['dataTypePrefix'] = { - 'double': "_" + pre + 'd', - 'float': "_" + pre, - } - - result['rsqrt'] = None - bit_width = result['width'] * (64 if data_type == 'double' else 32) - result['double'] = "__m%dd" % (bit_width,) - result['float'] = "__m%d" % (bit_width,) - result['int'] = "__m%di" % (bit_width,) - result['bool'] = "__m%dd" % (bit_width,) - - result['headers'] = headers[instruction_set] - result['any'] = "%s_movemask_%s({0}) > 0" % (pre, suf) - result['all'] = "%s_movemask_%s({0}) == 0xF" % (pre, suf) - - if instruction_set == 'avx512': - size = 8 if data_type == 'double' else 16 - result['&'] = '_kand_mask%d({0}, {1})' % (size,) - result['|'] = '_kor_mask%d({0}, {1})' % (size,) - result['any'] = '!_ktestz_mask%d_u8({0}, {0})' % (size, ) - result['all'] = '_kortestc_mask%d_u8({0}, {0})' % (size, ) - result['blendv'] = '%s_mask_blend_%s({2}, {0}, {1})' % (pre, suf) - result['rsqrt'] = "_mm512_rsqrt14_%s({0})" % (suf,) - result['bool'] = "__mmask%d" % (size,) - - params = " | ".join(["({{{i}}} ? {power} : 0)".format(i=i, power=2 ** i) for i in range(8)]) - result['makeVecBool'] = f"__mmask8(({params}) )" - params = " | ".join(["({{0}} ? {power} : 0)".format(power=2 ** i) for i in range(8)]) - result['makeVecConstBool'] = f"__mmask8(({params}) )" - - if instruction_set == 'avx' and data_type == 'float': - result['rsqrt'] = "_mm256_rsqrt_ps({0})" - - return result +def get_vector_instruction_set(data_type='double', instruction_set='avx', q_registers=True): + if instruction_set in ['neon', 'sve']: + return get_vector_instruction_set_arm(data_type, instruction_set, q_registers) + else: + return get_vector_instruction_set_x86(data_type, instruction_set) def get_supported_instruction_sets(): @@ -162,6 +20,7 @@ def get_supported_instruction_sets(): required_sse_flags = {'sse', 'sse2', 'ssse3', 'sse4_1', 'sse4_2'} required_avx_flags = {'avx'} required_avx512_flags = {'avx512f'} + required_neon_flags = {'neon'} flags = set(get_cpu_info()['flags']) if flags.issuperset(required_sse_flags): result.append("sse") @@ -169,4 +28,6 @@ def get_supported_instruction_sets(): result.append("avx") if flags.issuperset(required_avx512_flags): result.append("avx512") + if flags.issuperset(required_neon_flags): + result.append("neon") return result diff --git a/pystencils/backends/x86_instruction_sets.py b/pystencils/backends/x86_instruction_sets.py new file mode 100644 index 0000000000000000000000000000000000000000..349c190e252f89cba9f04c8f6b338933dfa6b8e1 --- /dev/null +++ b/pystencils/backends/x86_instruction_sets.py @@ -0,0 +1,154 @@ +def get_argument_string(intrinsic_id, width, function_shortcut): + if intrinsic_id == 'makeVecConst': + arg_string = f"({','.join(['{0}'] * width)})" + elif intrinsic_id == 'makeVec': + params = ["{" + str(i) + "}" for i in reversed(range(width))] + arg_string = f"({','.join(params)})" + elif intrinsic_id == 'makeVecBool': + params = [f"(({{{i}}} ? -1.0 : 0.0)" for i in reversed(range(width))] + arg_string = f"({','.join(params)})" + elif intrinsic_id == 'makeVecConstBool': + params = ["(({0}) ? -1.0 : 0.0)" for _ in range(width)] + arg_string = f"({','.join(params)})" + else: + args = function_shortcut[function_shortcut.index('[') + 1: -1] + arg_string = "(" + for arg in args.split(","): + arg = arg.strip() + if not arg: + continue + if arg in ('0', '1', '2', '3', '4', '5'): + arg_string += "{" + arg + "}," + else: + arg_string += arg + "," + arg_string = arg_string[:-1] + ")" + return arg_string + + +def get_vector_instruction_set_x86(data_type='double', instruction_set='avx'): + comparisons = { + '==': '_CMP_EQ_UQ', + '!=': '_CMP_NEQ_UQ', + '>=': '_CMP_GE_OQ', + '<=': '_CMP_LE_OQ', + '<': '_CMP_NGE_UQ', + '>': '_CMP_NLE_UQ', + } + base_names = { + '+': 'add[0, 1]', + '-': 'sub[0, 1]', + '*': 'mul[0, 1]', + '/': 'div[0, 1]', + '&': 'and[0, 1]', + '|': 'or[0, 1]', + 'blendv': 'blendv[0, 1, 2]', + + 'sqrt': 'sqrt[0]', + + 'makeVecConst': 'set[]', + 'makeVec': 'set[]', + 'makeVecBool': 'set[]', + 'makeVecConstBool': 'set[]', + + 'loadU': 'loadu[0]', + 'loadA': 'load[0]', + 'storeU': 'storeu[0,1]', + 'storeA': 'store[0,1]', + 'stream': 'stream[0,1]', + 'maskstore': 'mask_store[0, 2, 1]' if instruction_set == 'avx512' else 'maskstore[0, 2, 1]', + 'maskload': 'mask_load[0, 2, 1]' if instruction_set == 'avx512' else 'maskload[0, 2, 1]' + } + if instruction_set == 'avx512': + base_names.update({ + 'maskStore': 'mask_store[0, 2, 1]', + 'maskStoreU': 'mask_storeu[0, 2, 1]', + 'maskLoad': 'mask_load[2, 1, 0]', + 'maskLoadU': 'mask_loadu[2, 1, 0]' + }) + if instruction_set == 'avx': + base_names.update({ + 'maskStore': 'maskstore[0, 2, 1]', + 'maskStoreU': 'maskstore[0, 2, 1]', + 'maskLoad': 'maskload[0, 1]', + 'maskLoadU': 'maskloadu[0, 1]' + }) + + for comparison_op, constant in comparisons.items(): + base_names[comparison_op] = f'cmp[0, 1, {constant}]' + + headers = { + 'avx512': ['<immintrin.h>'], + 'avx': ['<immintrin.h>'], + 'sse': ['<immintrin.h>', '<xmmintrin.h>', '<emmintrin.h>', '<pmmintrin.h>', + '<tmmintrin.h>', '<smmintrin.h>', '<nmmintrin.h>'] + } + + suffix = { + 'double': 'pd', + 'float': 'ps', + } + prefix = { + 'sse': '_mm', + 'avx': '_mm256', + 'avx512': '_mm512', + } + + width = { + ("double", "sse"): 2, + ("float", "sse"): 4, + ("double", "avx"): 4, + ("float", "avx"): 8, + ("double", "avx512"): 8, + ("float", "avx512"): 16, + } + + result = { + 'width': width[(data_type, instruction_set)], + } + pre = prefix[instruction_set] + suf = suffix[data_type] + for intrinsic_id, function_shortcut in base_names.items(): + function_shortcut = function_shortcut.strip() + name = function_shortcut[:function_shortcut.index('[')] + + arg_string = get_argument_string(intrinsic_id, result['width'], function_shortcut) + + mask_suffix = '_mask' if instruction_set == 'avx512' and intrinsic_id in comparisons.keys() else '' + result[intrinsic_id] = pre + "_" + name + "_" + suf + mask_suffix + arg_string + + result['dataTypePrefix'] = { + 'double': "_" + pre + 'd', + 'float': "_" + pre, + } + + result['rsqrt'] = None + bit_width = result['width'] * (64 if data_type == 'double' else 32) + result['double'] = f"__m{bit_width}d" + result['float'] = f"__m{bit_width}" + result['int'] = f"__m{bit_width}i" + result['bool'] = f"__m{bit_width}d" + + result['headers'] = headers[instruction_set] + result['any'] = f"{pre}_movemask_{suf}({{0}}) > 0" + result['all'] = f"{pre}_movemask_{suf}({{0}}) == 0xF" + + if instruction_set == 'avx512': + size = 8 if data_type == 'double' else 16 + result['&'] = f'_kand_mask{size}({{0}}, {{1}})' + result['|'] = f'_kor_mask{size}({{0}}, {{1}})' + result['any'] = f'!_ktestz_mask{size}_u8({{0}}, {{0}})' + result['all'] = f'_kortestc_mask{size}_u8({{0}}, {{0}})' + result['blendv'] = f'{pre}_mask_blend_{suf}({{2}}, {{0}}, {{1}})' + result['rsqrt'] = f"{pre}_rsqrt14_{suf}({{0}})" + result['abs'] = f"{pre}_abs_{suf}({{0}})" + result['bool'] = f"__mmask{size}" + + params = " | ".join(["({{{i}}} ? {power} : 0)".format(i=i, power=2 ** i) for i in range(8)]) + result['makeVecBool'] = f"__mmask8(({params}) )" + params = " | ".join(["({{0}} ? {power} : 0)".format(power=2 ** i) for i in range(8)]) + result['makeVecConstBool'] = f"__mmask8(({params}) )" + + if instruction_set == 'avx' and data_type == 'float': + result['rsqrt'] = f"{pre}_rsqrt_{suf}({{0}})" + + return result diff --git a/pystencils/cpu/vectorization.py b/pystencils/cpu/vectorization.py index 0ee5200a069bdc00e1bb751ba5fb914b0707a62b..cf51456569b47e5fea1dfd7698f09e2416fe70b1 100644 --- a/pystencils/cpu/vectorization.py +++ b/pystencils/cpu/vectorization.py @@ -176,7 +176,7 @@ def insert_vector_casts(ast_node): visit_expr(expr.args[4])) elif isinstance(expr, cast_func): return expr - elif expr.func is sp.Abs: + elif expr.func is sp.Abs and 'abs' not in ast_node.instruction_set: new_arg = visit_expr(expr.args[0]) pw = sp.Piecewise((-1 * new_arg, new_arg < 0), (new_arg, True)) return visit_expr(pw)