diff --git a/pystencils/backends/arm_instruction_sets.py b/pystencils/backends/arm_instruction_sets.py index a386253a098be738f9a6bc932144edebeaa4a1ea..9f7b4ee2292cf6131322c0544c369abda1704266 100644 --- a/pystencils/backends/arm_instruction_sets.py +++ b/pystencils/backends/arm_instruction_sets.py @@ -1,6 +1,8 @@ -def get_argument_string(function_shortcut): +def get_argument_string(function_shortcut, first=''): args = function_shortcut[function_shortcut.index('[') + 1: -1] arg_string = "(" + if first: + arg_string += first + ', ' for arg in args.split(","): arg = arg.strip() if not arg: @@ -14,8 +16,17 @@ def get_argument_string(function_shortcut): def get_vector_instruction_set_arm(data_type='double', instruction_set='neon'): - if instruction_set != 'neon': + if instruction_set != 'neon' and not instruction_set.startswith('sve'): raise NotImplementedError(instruction_set) + if instruction_set == 'sve': + raise NotImplementedError("sizeless SVE is not implemented") + + if instruction_set.startswith('sve'): + cmp = 'cmp' + bitwidth = int(instruction_set[3:]) + elif instruction_set == 'neon': + cmp = 'c' + bitwidth = 128 base_names = { '+': 'add[0, 1]', @@ -30,58 +41,94 @@ def get_vector_instruction_set_arm(data_type='double', instruction_set='neon'): 'storeA': 'st1[0, 1]', 'abs': 'abs[0]', - '==': 'ceq[0, 1]', - '<=': 'cle[0, 1]', - '<': 'clt[0, 1]', - '>=': 'cge[0, 1]', - '>': 'cgt[0, 1]', + '==': f'{cmp}eq[0, 1]', + '!=': f'{cmp}eq[0, 1]', + '<=': f'{cmp}le[0, 1]', + '<': f'{cmp}lt[0, 1]', + '>=': f'{cmp}ge[0, 1]', + '>': f'{cmp}gt[0, 1]', } bits = {'double': 64, 'float': 32, 'int': 32} - width = 128 // bits[data_type] - intwidth = 128 // bits['int'] - suffix = f'q_f{bits[data_type]}' + width = bitwidth // bits[data_type] + intwidth = bitwidth // bits['int'] + if instruction_set.startswith('sve'): + prefix = 'sv' + suffix = f'_f{bits[data_type]}' + elif instruction_set == 'neon': + prefix = 'v' + suffix = f'q_f{bits[data_type]}' result = dict() - result['bytes'] = 16 + result['bytes'] = bitwidth // 8 + + predicate = f'{prefix}whilelt_b{bits[data_type]}(0, {width})' + int_predicate = f'{prefix}whilelt_b{bits["int"]}(0, {intwidth})' for intrinsic_id, function_shortcut in base_names.items(): function_shortcut = function_shortcut.strip() name = function_shortcut[:function_shortcut.index('[')] - arg_string = get_argument_string(function_shortcut) + arg_string = get_argument_string(function_shortcut, first=predicate if prefix == 'sv' else '') + if prefix == 'sv' and not name.startswith('ld') and not name.startswith('st') and not name.startswith(cmp): + undef = '_x' + else: + undef = '' - result[intrinsic_id] = 'v' + name + suffix + arg_string + result[intrinsic_id] = prefix + name + suffix + undef + arg_string - result['makeVecConst'] = f'vdupq_n_f{bits[data_type]}' + '({0})' - result['makeVec'] = f'makeVec_f{bits[data_type]}' + '(' + ", ".join(['{' + str(i) + '}' for i in range(width)]) + \ - ')' - result['makeVecConstInt'] = f'vdupq_n_s{bits["int"]}' + '({0})' - result['makeVecInt'] = f'makeVec_s{bits["int"]}' + '({0}, {1}, {2}, {3})' + result['width'] = width + result['intwidth'] = intwidth - result['+int'] = f"vaddq_s{bits['int']}" + "({0}, {1})" + if instruction_set.startswith('sve'): + result['makeVecConst'] = f'svdup_f{bits[data_type]}' + '({0})' + result['makeVecConstInt'] = f'svdup_s{bits["int"]}' + '({0})' + result['makeVecIndex'] = f'svindex_s{bits["int"]}' + '({0}, {1})' - result['rsqrt'] = None + result['+int'] = f"svadd_s{bits['int']}_x({int_predicate}, " + "{0}, {1})" - result['width'] = width - result['intwidth'] = intwidth - result[data_type] = f'float{bits[data_type]}x{width}_t' - result['int'] = f'int{bits["int"]}x{bits[data_type]}_t' - result['bool'] = f'uint{bits[data_type]}x{width}_t' - result['headers'] = ['<arm_neon.h>', '"arm_neon_helpers.h"'] + result[data_type] = f'svfloat{bits[data_type]}_st' + result['int'] = f'svint{bits["int"]}_st' + result['bool'] = 'svbool_st' + + result['headers'] = ['<arm_sve.h>', '"arm_neon_helpers.h"'] + + result['&'] = f'svand_b_z({predicate},' + ' {0}, {1})' + result['|'] = f'svorr_b_z({predicate},' + ' {0}, {1})' + result['blendv'] = f'svsel_f{bits[data_type]}' + '({2}, {1}, {0})' + result['any'] = f'svptest_any({predicate}, {{0}})' + result['all'] = f'svcntp_b{bits[data_type]}({predicate}, {{0}}) == {width}' + + result['compile_flags'] = [f'-msve-vector-bits={bitwidth}'] + else: + result['makeVecConst'] = f'vdupq_n_f{bits[data_type]}' + '({0})' + result['makeVec'] = f'makeVec_f{bits[data_type]}' + '(' + ", ".join(['{' + str(i) + '}' for i in + range(width)]) + ')' + result['makeVecConstInt'] = f'vdupq_n_s{bits["int"]}' + '({0})' + result['makeVecInt'] = f'makeVec_s{bits["int"]}' + '({0}, {1}, {2}, {3})' + + result['+int'] = f"vaddq_s{bits['int']}" + "({0}, {1})" + + result[data_type] = f'float{bits[data_type]}x{width}_t' + result['int'] = f'int{bits["int"]}x{intwidth}_t' + result['bool'] = f'uint{bits[data_type]}x{width}_t' + + result['headers'] = ['<arm_neon.h>', '"arm_neon_helpers.h"'] - result['!='] = f'vmvnq_u{bits[data_type]}({result["=="]})' + result['!='] = f'vmvnq_u{bits[data_type]}({result["=="]})' - result['&'] = f'vandq_u{bits[data_type]}' + '({0}, {1})' - result['|'] = f'vorrq_u{bits[data_type]}' + '({0}, {1})' - result['blendv'] = f'vbslq_f{bits[data_type]}' + '({2}, {1}, {0})' - result['any'] = f'vaddlvq_u8(vreinterpretq_u8_u{bits[data_type]}({{0}})) > 0' - result['all'] = f'vaddlvq_u8(vreinterpretq_u8_u{bits[data_type]}({{0}})) == 16*0xff' + result['&'] = f'vandq_u{bits[data_type]}' + '({0}, {1})' + result['|'] = f'vorrq_u{bits[data_type]}' + '({0}, {1})' + result['blendv'] = f'vbslq_f{bits[data_type]}' + '({2}, {1}, {0})' + result['any'] = f'vaddlvq_u8(vreinterpretq_u8_u{bits[data_type]}({{0}})) > 0' + result['all'] = f'vaddlvq_u8(vreinterpretq_u8_u{bits[data_type]}({{0}})) == 16*0xff' - result['cachelineSize'] = 'cachelineSize()' - result['cachelineZero'] = 'cachelineZero((void*) {0})' + if bitwidth & (bitwidth - 1) == 0: + # only power-of-2 vector sizes will evenly divide a cacheline + result['cachelineSize'] = 'cachelineSize()' + result['cachelineZero'] = 'cachelineZero((void*) {0})' return result diff --git a/pystencils/backends/cbackend.py b/pystencils/backends/cbackend.py index 9d7aa3cc30ba465fc97b97b75d7227776366ae51..5aabd83d67a1516d231856818482eae2ba6062dc 100644 --- a/pystencils/backends/cbackend.py +++ b/pystencils/backends/cbackend.py @@ -610,6 +610,10 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): is_integer = get_type_of_expression(arg[0]) == create_type("int") printed_args = [self._print(a) for a in arg] instruction = 'makeVecBool' if is_boolean else 'makeVecInt' if is_integer else 'makeVec' + if instruction == 'makeVecInt' and 'makeVecIndex' in self.instruction_set: + increments = np.array(arg)[1:] - np.array(arg)[:-1] + if len(set(increments)) == 1: + return self.instruction_set['makeVecIndex'].format(printed_args[0], increments[0]) return self.instruction_set[instruction].format(*printed_args) else: is_boolean = get_type_of_expression(arg) == create_type("bool") @@ -628,7 +632,7 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): elif expr.func == fast_inv_sqrt: result = self._scalarFallback('_print_Function', expr) if not result: - if self.instruction_set['rsqrt']: + if 'rsqrt' in self.instruction_set: return self.instruction_set['rsqrt'].format(self._print(expr.args[0])) else: return f"({self._print(1 / sp.sqrt(expr.args[0]))})" diff --git a/pystencils/backends/simd_instruction_sets.py b/pystencils/backends/simd_instruction_sets.py index 9469dc59eb1b4d5d2189c138d42f4c1f233e5259..b552da0e9ac263721dfbf262c40c8cbb00352f7a 100644 --- a/pystencils/backends/simd_instruction_sets.py +++ b/pystencils/backends/simd_instruction_sets.py @@ -1,4 +1,6 @@ +import math import platform +from ctypes import CDLL from pystencils.backends.x86_instruction_sets import get_vector_instruction_set_x86 from pystencils.backends.arm_instruction_sets import get_vector_instruction_set_arm @@ -6,7 +8,7 @@ from pystencils.backends.ppc_instruction_sets import get_vector_instruction_set_ def get_vector_instruction_set(data_type='double', instruction_set='avx'): - if instruction_set in ['neon', 'sve']: + if instruction_set in ['neon'] or instruction_set.startswith('sve'): return get_vector_instruction_set_arm(data_type, instruction_set) elif instruction_set in ['vsx']: return get_vector_instruction_set_ppc(data_type, instruction_set) @@ -47,6 +49,7 @@ def get_supported_instruction_sets(): required_avx_flags = {'avx', 'avx2'} required_avx512_flags = {'avx512f'} required_neon_flags = {'neon'} + required_sve_flags = {'sve'} flags = set(get_cpu_info()['flags']) if flags.issuperset(required_sse_flags): result.append("sse") @@ -56,6 +59,20 @@ def get_supported_instruction_sets(): result.append("avx512") if flags.issuperset(required_neon_flags): result.append("neon") + if flags.issuperset(required_sve_flags): + if platform.system() == 'Linux': + libc = CDLL('libc.so.6') + native_length = 8 * libc.prctl(51, 0, 0, 0, 0) # PR_SVE_GET_VL + if native_length < 0: + raise OSError("SVE length query failed") + pwr2_length = int(2**math.floor(math.log2(native_length))) + if pwr2_length % 256 == 0: + result.append(f"sve{pwr2_length//2}") + if native_length != pwr2_length: + result.append(f"sve{pwr2_length}") + result.append(f"sve{native_length}") + else: + result.append("sve") return result diff --git a/pystencils/backends/x86_instruction_sets.py b/pystencils/backends/x86_instruction_sets.py index 86196aad7047c8f2bcc16a94b487537d3147d36d..78515e1a2d12dc9441fe2f0997cccca2fbd3626d 100644 --- a/pystencils/backends/x86_instruction_sets.py +++ b/pystencils/backends/x86_instruction_sets.py @@ -133,7 +133,6 @@ def get_vector_instruction_set_x86(data_type='double', instruction_set='avx'): 'float': "_" + pre, } - result['rsqrt'] = None bit_width = result['width'] * (64 if data_type == 'double' else 32) result['double'] = f"__m{bit_width}d" result['float'] = f"__m{bit_width}" diff --git a/pystencils/cpu/cpujit.py b/pystencils/cpu/cpujit.py index 68ca259024e26cd2d97aa16bf5698e163c5c8373..fc9b810c30977825c503b560ba17244777bd7930 100644 --- a/pystencils/cpu/cpujit.py +++ b/pystencils/cpu/cpujit.py @@ -558,9 +558,9 @@ class ExtensionModuleCode: print(self._code_string, file=file) -def compile_module(code, code_hash, base_dir): +def compile_module(code, code_hash, base_dir, compile_flags=[]): compiler_config = get_compiler_config() - extra_flags = ['-I' + get_paths()['include'], '-I' + get_pystencils_include_path()] + extra_flags = ['-I' + get_paths()['include'], '-I' + get_pystencils_include_path()] + compile_flags if compiler_config['os'].lower() == 'windows': lib_suffix = '.pyd' @@ -620,12 +620,17 @@ def compile_and_load(ast, custom_backend=None): code.create_code_string(compiler_config['restrict_qualifier'], function_prefix) code_hash_str = code.get_hash_of_code() + compile_flags = [] + if ast.instruction_set and 'compile_flags' in ast.instruction_set: + compile_flags = ast.instruction_set['compile_flags'] + if cache_config['object_cache'] is False: with TemporaryDirectory() as base_dir: - lib_file = compile_module(code, code_hash_str, base_dir) + lib_file = compile_module(code, code_hash_str, base_dir, compile_flags=compile_flags) result = load_kernel_from_file(code_hash_str, ast.function_name, lib_file) else: - lib_file = compile_module(code, code_hash_str, base_dir=cache_config['object_cache']) + lib_file = compile_module(code, code_hash_str, base_dir=cache_config['object_cache'], + compile_flags=compile_flags) result = load_kernel_from_file(code_hash_str, ast.function_name, lib_file) return KernelWrapper(result, ast.get_parameters(), ast) diff --git a/pystencils/include/arm_neon_helpers.h b/pystencils/include/arm_neon_helpers.h index 3d06d69bfc2dd866370e98968dacce3aaef3975c..a900001e793392fea66faf427873ce49eb2594d4 100644 --- a/pystencils/include/arm_neon_helpers.h +++ b/pystencils/include/arm_neon_helpers.h @@ -1,5 +1,17 @@ +#ifdef __ARM_NEON #include <arm_neon.h> +#endif +#if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_SVE_BITS) && __ARM_FEATURE_SVE_BITS > 0 +#include <arm_sve.h> + +typedef svbool_t svbool_st __attribute__((arm_sve_vector_bits(__ARM_FEATURE_SVE_BITS))); +typedef svfloat32_t svfloat32_st __attribute__((arm_sve_vector_bits(__ARM_FEATURE_SVE_BITS))); +typedef svfloat64_t svfloat64_st __attribute__((arm_sve_vector_bits(__ARM_FEATURE_SVE_BITS))); +typedef svint32_t svint32_st __attribute__((arm_sve_vector_bits(__ARM_FEATURE_SVE_BITS))); +#endif + +#ifdef __ARM_NEON inline float32x4_t makeVec_f32(float a, float b, float c, float d) { alignas(16) float data[4] = {a, b, c, d}; @@ -17,6 +29,7 @@ inline int32x4_t makeVec_s32(int a, int b, int c, int d) alignas(16) int data[4] = {a, b, c, d}; return vld1q_s32(data); } +#endif inline void cachelineZero(void * p) { __asm__ volatile("dc zva, %0"::"r"(p)); diff --git a/pystencils/include/philox_rand.h b/pystencils/include/philox_rand.h index 4d81d43e420f716ad3d07f4d58d68dcb127e2f5a..7684a4507f3fc0a532beb15632fb48f871640f21 100644 --- a/pystencils/include/philox_rand.h +++ b/pystencils/include/philox_rand.h @@ -15,6 +15,14 @@ #ifdef __ARM_NEON #include <arm_neon.h> #endif +#if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_SVE_BITS) && __ARM_FEATURE_SVE_BITS > 0 +#include <arm_sve.h> +typedef svfloat32_t svfloat32_st __attribute__((arm_sve_vector_bits(__ARM_FEATURE_SVE_BITS))); +typedef svfloat64_t svfloat64_st __attribute__((arm_sve_vector_bits(__ARM_FEATURE_SVE_BITS))); +typedef svint32_t svint32_st __attribute__((arm_sve_vector_bits(__ARM_FEATURE_SVE_BITS))); +typedef svuint32_t svuint32_st __attribute__((arm_sve_vector_bits(__ARM_FEATURE_SVE_BITS))); +typedef svuint64_t svuint64_st __attribute__((arm_sve_vector_bits(__ARM_FEATURE_SVE_BITS))); +#endif #if defined(__powerpc__) && defined(__GNUC__) && !defined(__clang__) && !defined(__xlC__) #include <ppu_intrinsics.h> @@ -655,6 +663,158 @@ QUALIFIERS void philox_double2(uint32 ctr0, int32x4_t ctr1, uint32 ctr2, uint32 } #endif + +#if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_SVE_BITS) && __ARM_FEATURE_SVE_BITS > 0 +QUALIFIERS void _philox4x32round(svuint32_st* ctr, svuint32_st* key) +{ + svuint32_st lo0 = svmul_u32_x(svptrue_b32(), ctr[0], svdup_u32(PHILOX_M4x32_0)); + svuint32_st lo1 = svmul_u32_x(svptrue_b32(), ctr[2], svdup_u32(PHILOX_M4x32_1)); + svuint32_st hi0 = svmulh_u32_x(svptrue_b32(), ctr[0], svdup_u32(PHILOX_M4x32_0)); + svuint32_st hi1 = svmulh_u32_x(svptrue_b32(), ctr[2], svdup_u32(PHILOX_M4x32_1)); + + ctr[0] = sveor_u32_x(svptrue_b32(), sveor_u32_x(svptrue_b32(), hi1, ctr[1]), key[0]); + ctr[1] = lo1; + ctr[2] = sveor_u32_x(svptrue_b32(), sveor_u32_x(svptrue_b32(), hi0, ctr[3]), key[1]); + ctr[3] = lo0; +} + +QUALIFIERS void _philox4x32bumpkey(svuint32_st* key) +{ + key[0] = svadd_u32_x(svptrue_b32(), key[0], svdup_u32(PHILOX_W32_0)); + key[1] = svadd_u32_x(svptrue_b32(), key[1], svdup_u32(PHILOX_W32_1)); +} + +template<bool high> +QUALIFIERS svfloat64_st _uniform_double_hq(svuint32_st x, svuint32_st y) +{ + // convert 32 to 64 bit + if (high) + { + x = svzip2_u32(x, svdup_u32(0)); + y = svzip2_u32(y, svdup_u32(0)); + } + else + { + x = svzip1_u32(x, svdup_u32(0)); + y = svzip1_u32(y, svdup_u32(0)); + } + + // calculate z = x ^ y << (53 - 32)) + svuint64_st z = svlsl_n_u64_x(svptrue_b64(), svreinterpret_u64_u32(y), 53 - 32); + z = sveor_u64_x(svptrue_b64(), svreinterpret_u64_u32(x), z); + + // convert uint64 to double + svfloat64_st rs = svcvt_f64_u64_x(svptrue_b64(), z); + // calculate rs * TWOPOW53_INV_DOUBLE + (TWOPOW53_INV_DOUBLE/2.0) + rs = svmad_f64_x(svptrue_b64(), rs, svdup_f64(TWOPOW53_INV_DOUBLE), svdup_f64(TWOPOW53_INV_DOUBLE/2.0)); + + return rs; +} + + +QUALIFIERS void philox_float4(svuint32_st ctr0, svuint32_st ctr1, svuint32_st ctr2, svuint32_st ctr3, + uint32 key0, uint32 key1, + svfloat32_st & rnd1, svfloat32_st & rnd2, svfloat32_st & rnd3, svfloat32_st & rnd4) +{ + svuint32_st key[2] = {svdup_u32(key0), svdup_u32(key1)}; + svuint32_st ctr[4] = {ctr0, ctr1, ctr2, ctr3}; + _philox4x32round(ctr, key); // 1 + _philox4x32bumpkey(key); _philox4x32round(ctr, key); // 2 + _philox4x32bumpkey(key); _philox4x32round(ctr, key); // 3 + _philox4x32bumpkey(key); _philox4x32round(ctr, key); // 4 + _philox4x32bumpkey(key); _philox4x32round(ctr, key); // 5 + _philox4x32bumpkey(key); _philox4x32round(ctr, key); // 6 + _philox4x32bumpkey(key); _philox4x32round(ctr, key); // 7 + _philox4x32bumpkey(key); _philox4x32round(ctr, key); // 8 + _philox4x32bumpkey(key); _philox4x32round(ctr, key); // 9 + _philox4x32bumpkey(key); _philox4x32round(ctr, key); // 10 + + // convert uint32 to float + rnd1 = svcvt_f32_u32_x(svptrue_b32(), ctr[0]); + rnd2 = svcvt_f32_u32_x(svptrue_b32(), ctr[1]); + rnd3 = svcvt_f32_u32_x(svptrue_b32(), ctr[2]); + rnd4 = svcvt_f32_u32_x(svptrue_b32(), ctr[3]); + // calculate rnd * TWOPOW32_INV_FLOAT + (TWOPOW32_INV_FLOAT/2.0f) + rnd1 = svmad_f32_x(svptrue_b32(), rnd1, svdup_f32(TWOPOW32_INV_FLOAT), svdup_f32(TWOPOW32_INV_FLOAT/2.0)); + rnd2 = svmad_f32_x(svptrue_b32(), rnd2, svdup_f32(TWOPOW32_INV_FLOAT), svdup_f32(TWOPOW32_INV_FLOAT/2.0)); + rnd3 = svmad_f32_x(svptrue_b32(), rnd3, svdup_f32(TWOPOW32_INV_FLOAT), svdup_f32(TWOPOW32_INV_FLOAT/2.0)); + rnd4 = svmad_f32_x(svptrue_b32(), rnd4, svdup_f32(TWOPOW32_INV_FLOAT), svdup_f32(TWOPOW32_INV_FLOAT/2.0)); +} + + +QUALIFIERS void philox_double2(svuint32_st ctr0, svuint32_st ctr1, svuint32_st ctr2, svuint32_st ctr3, + uint32 key0, uint32 key1, + svfloat64_st & rnd1lo, svfloat64_st & rnd1hi, svfloat64_st & rnd2lo, svfloat64_st & rnd2hi) +{ + svuint32_st key[2] = {svdup_u32(key0), svdup_u32(key1)}; + svuint32_st ctr[4] = {ctr0, ctr1, ctr2, ctr3}; + _philox4x32round(ctr, key); // 1 + _philox4x32bumpkey(key); _philox4x32round(ctr, key); // 2 + _philox4x32bumpkey(key); _philox4x32round(ctr, key); // 3 + _philox4x32bumpkey(key); _philox4x32round(ctr, key); // 4 + _philox4x32bumpkey(key); _philox4x32round(ctr, key); // 5 + _philox4x32bumpkey(key); _philox4x32round(ctr, key); // 6 + _philox4x32bumpkey(key); _philox4x32round(ctr, key); // 7 + _philox4x32bumpkey(key); _philox4x32round(ctr, key); // 8 + _philox4x32bumpkey(key); _philox4x32round(ctr, key); // 9 + _philox4x32bumpkey(key); _philox4x32round(ctr, key); // 10 + + rnd1lo = _uniform_double_hq<false>(ctr[0], ctr[1]); + rnd1hi = _uniform_double_hq<true>(ctr[0], ctr[1]); + rnd2lo = _uniform_double_hq<false>(ctr[2], ctr[3]); + rnd2hi = _uniform_double_hq<true>(ctr[2], ctr[3]); +} + +QUALIFIERS void philox_float4(uint32 ctr0, svuint32_st ctr1, uint32 ctr2, uint32 ctr3, + uint32 key0, uint32 key1, + svfloat32_st & rnd1, svfloat32_st & rnd2, svfloat32_st & rnd3, svfloat32_st & rnd4) +{ + svuint32_st ctr0v = svdup_u32(ctr0); + svuint32_st ctr2v = svdup_u32(ctr2); + svuint32_st ctr3v = svdup_u32(ctr3); + + philox_float4(ctr0v, ctr1, ctr2v, ctr3v, key0, key1, rnd1, rnd2, rnd3, rnd4); +} + +QUALIFIERS void philox_float4(uint32 ctr0, svint32_st ctr1, uint32 ctr2, uint32 ctr3, + uint32 key0, uint32 key1, + svfloat32_st & rnd1, svfloat32_st & rnd2, svfloat32_st & rnd3, svfloat32_st & rnd4) +{ + philox_float4(ctr0, svreinterpret_u32_s32(ctr1), ctr2, ctr3, key0, key1, rnd1, rnd2, rnd3, rnd4); +} + +QUALIFIERS void philox_double2(uint32 ctr0, svuint32_st ctr1, uint32 ctr2, uint32 ctr3, + uint32 key0, uint32 key1, + svfloat64_st & rnd1lo, svfloat64_st & rnd1hi, svfloat64_st & rnd2lo, svfloat64_st & rnd2hi) +{ + svuint32_st ctr0v = svdup_u32(ctr0); + svuint32_st ctr2v = svdup_u32(ctr2); + svuint32_st ctr3v = svdup_u32(ctr3); + + philox_double2(ctr0v, ctr1, ctr2v, ctr3v, key0, key1, rnd1lo, rnd1hi, rnd2lo, rnd2hi); +} + +QUALIFIERS void philox_double2(uint32 ctr0, svuint32_st ctr1, uint32 ctr2, uint32 ctr3, + uint32 key0, uint32 key1, + svfloat64_st & rnd1, svfloat64_st & rnd2) +{ + svuint32_st ctr0v = svdup_u32(ctr0); + svuint32_st ctr2v = svdup_u32(ctr2); + svuint32_st ctr3v = svdup_u32(ctr3); + + svfloat64_st ignore; + philox_double2(ctr0v, ctr1, ctr2v, ctr3v, key0, key1, rnd1, ignore, rnd2, ignore); +} + +QUALIFIERS void philox_double2(uint32 ctr0, svint32_st ctr1, uint32 ctr2, uint32 ctr3, + uint32 key0, uint32 key1, + svfloat64_st & rnd1, svfloat64_st & rnd2) +{ + philox_double2(ctr0, svreinterpret_u32_s32(ctr1), ctr2, ctr3, key0, key1, rnd1, rnd2); +} +#endif + + #ifdef __AVX2__ QUALIFIERS void _philox4x32round(__m256i* ctr, __m256i* key) { diff --git a/pystencils_tests/test_random.py b/pystencils_tests/test_random.py index b5396c9fb92891305b151a69c87e76bc688be778..cba58cfa24dac1b61ed1be5a6dab6ddc2e008483 100644 --- a/pystencils_tests/test_random.py +++ b/pystencils_tests/test_random.py @@ -27,7 +27,7 @@ if get_compiler_config()['os'] == 'windows': def test_rng(target, rng, precision, dtype, t=124, offsets=(0, 0), keys=(0, 0), offset_values=None): if target == 'gpu': pytest.importorskip('pycuda') - if instruction_sets and set(['neon', 'vsx']).intersection(instruction_sets) and rng == 'aesni': + if instruction_sets and (set(['neon', 'vsx']).intersection(instruction_sets) or any([iset.startswith('sve') for iset in instruction_sets])) and rng == 'aesni': pytest.xfail('AES not yet implemented for this architecture') if rng == 'aesni' and len(keys) == 2: keys *= 2 @@ -106,11 +106,11 @@ def test_rng_offsets(kind, vectorized): else: test = test_rng if kind == 'value': - test(instruction_sets[0] if vectorized else 'cpu', 'philox', 'float', 'float', t=8, + test(instruction_sets[-1] if vectorized else 'cpu', 'philox', 'float', 'float', t=8, offsets=(6, 7), keys=(5, 309)) elif kind == 'symbol': offsets = (TypedSymbol("x0", np.uint32), TypedSymbol("y0", np.uint32)) - test(instruction_sets[0] if vectorized else 'cpu', 'philox', 'float', 'float', t=8, + test(instruction_sets[-1] if vectorized else 'cpu', 'philox', 'float', 'float', t=8, offsets=offsets, offset_values=(6, 7), keys=(5, 309)) @@ -118,11 +118,11 @@ def test_rng_offsets(kind, vectorized): @pytest.mark.parametrize('rng', ('philox', 'aesni')) @pytest.mark.parametrize('precision,dtype', (('float', 'float'), ('double', 'double'))) def test_rng_vectorized(target, rng, precision, dtype, t=130, offsets=(1, 3), keys=(0, 0), offset_values=None): - if target in ['neon', 'vsx'] and rng == 'aesni': + if (target in ['neon', 'vsx'] or target.startswith('sve')) and rng == 'aesni': pytest.xfail('AES not yet implemented for this architecture') cpu_vectorize_info = {'assume_inner_stride_one': True, 'assume_aligned': True, 'instruction_set': target} - dh = ps.create_data_handling((17, 17), default_ghost_layers=0, default_target='cpu') + dh = ps.create_data_handling((131, 131), default_ghost_layers=0, default_target='cpu') f = dh.add_array("f", values_per_cell=4 if precision == 'float' else 2, dtype=np.float32 if dtype == 'float' else np.float64, alignment=True) dh.fill(f.name, 42.0) @@ -157,7 +157,7 @@ def test_rng_symbol(vectorized): pytest.skip("cannot detect CPU instruction set") else: cpu_vectorize_info = {'assume_inner_stride_one': True, 'assume_aligned': True, - 'instruction_set': instruction_sets[0]} + 'instruction_set': instruction_sets[-1]} else: cpu_vectorize_info = None @@ -189,7 +189,7 @@ def test_staggered(vectorized): pytest.skip("cannot detect CPU instruction set") pytest.importorskip('islpy') cpu_vectorize_info = {'assume_inner_stride_one': True, 'assume_aligned': False, - 'instruction_set': instruction_sets[0]} + 'instruction_set': instruction_sets[-1]} dh.fill(j.name, 867) dh.run_kernel(kernel, seed=5, time_step=309) diff --git a/pystencils_tests/test_vectorization.py b/pystencils_tests/test_vectorization.py index 880a009a294569f004f7f8a0cba3a02b98ec5bd2..f668c6b81b03c8c2bd08990edddfc285d6bce0f1 100644 --- a/pystencils_tests/test_vectorization.py +++ b/pystencils_tests/test_vectorization.py @@ -216,8 +216,8 @@ def test_logical_operators(): def test_hardware_query(): - instruction_sets = get_supported_instruction_sets() - assert set(['sse', 'neon', 'vsx']).intersection(instruction_sets) + assert set(['sse', 'neon', 'vsx']).intersection(supported_instruction_sets) or \ + any([iset.startswith('sve') for iset in supported_instruction_sets]) def test_vectorised_pow():