diff --git a/pystencils/include/philox_rand.h b/pystencils/include/philox_rand.h index 7684a4507f3fc0a532beb15632fb48f871640f21..84f0ba91edab6722847bf333d97e787ee07b6ce0 100644 --- a/pystencils/include/philox_rand.h +++ b/pystencils/include/philox_rand.h @@ -15,13 +15,8 @@ #ifdef __ARM_NEON #include <arm_neon.h> #endif -#if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_SVE_BITS) && __ARM_FEATURE_SVE_BITS > 0 +#ifdef __ARM_FEATURE_SVE #include <arm_sve.h> -typedef svfloat32_t svfloat32_st __attribute__((arm_sve_vector_bits(__ARM_FEATURE_SVE_BITS))); -typedef svfloat64_t svfloat64_st __attribute__((arm_sve_vector_bits(__ARM_FEATURE_SVE_BITS))); -typedef svint32_t svint32_st __attribute__((arm_sve_vector_bits(__ARM_FEATURE_SVE_BITS))); -typedef svuint32_t svuint32_st __attribute__((arm_sve_vector_bits(__ARM_FEATURE_SVE_BITS))); -typedef svuint64_t svuint64_st __attribute__((arm_sve_vector_bits(__ARM_FEATURE_SVE_BITS))); #endif #if defined(__powerpc__) && defined(__GNUC__) && !defined(__clang__) && !defined(__xlC__) @@ -52,6 +47,14 @@ typedef svuint64_t svuint64_st __attribute__((arm_sve_vector_bits(__ARM_FEATURE_ typedef std::uint32_t uint32; typedef std::uint64_t uint64; +#if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_SVE_BITS) && __ARM_FEATURE_SVE_BITS > 0 +typedef svfloat32_t svfloat32_st __attribute__((arm_sve_vector_bits(__ARM_FEATURE_SVE_BITS))); +typedef svfloat64_t svfloat64_st __attribute__((arm_sve_vector_bits(__ARM_FEATURE_SVE_BITS))); +#elif defined(__ARM_FEATURE_SVE) +typedef svfloat32_t svfloat32_st; +typedef svfloat64_t svfloat64_st; +#endif + QUALIFIERS uint32 mulhilo32(uint32 a, uint32 b, uint32* hip) { @@ -664,28 +667,28 @@ QUALIFIERS void philox_double2(uint32 ctr0, int32x4_t ctr1, uint32 ctr2, uint32 #endif -#if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_SVE_BITS) && __ARM_FEATURE_SVE_BITS > 0 -QUALIFIERS void _philox4x32round(svuint32_st* ctr, svuint32_st* key) +#if defined(__ARM_FEATURE_SVE) +QUALIFIERS void _philox4x32round(svuint32x4_t & ctr, svuint32x2_t & key) { - svuint32_st lo0 = svmul_u32_x(svptrue_b32(), ctr[0], svdup_u32(PHILOX_M4x32_0)); - svuint32_st lo1 = svmul_u32_x(svptrue_b32(), ctr[2], svdup_u32(PHILOX_M4x32_1)); - svuint32_st hi0 = svmulh_u32_x(svptrue_b32(), ctr[0], svdup_u32(PHILOX_M4x32_0)); - svuint32_st hi1 = svmulh_u32_x(svptrue_b32(), ctr[2], svdup_u32(PHILOX_M4x32_1)); - - ctr[0] = sveor_u32_x(svptrue_b32(), sveor_u32_x(svptrue_b32(), hi1, ctr[1]), key[0]); - ctr[1] = lo1; - ctr[2] = sveor_u32_x(svptrue_b32(), sveor_u32_x(svptrue_b32(), hi0, ctr[3]), key[1]); - ctr[3] = lo0; + svuint32_t lo0 = svmul_u32_x(svptrue_b32(), svget4_u32(ctr, 0), svdup_u32(PHILOX_M4x32_0)); + svuint32_t lo1 = svmul_u32_x(svptrue_b32(), svget4_u32(ctr, 2), svdup_u32(PHILOX_M4x32_1)); + svuint32_t hi0 = svmulh_u32_x(svptrue_b32(), svget4_u32(ctr, 0), svdup_u32(PHILOX_M4x32_0)); + svuint32_t hi1 = svmulh_u32_x(svptrue_b32(), svget4_u32(ctr, 2), svdup_u32(PHILOX_M4x32_1)); + + ctr = svset4_u32(ctr, 0, sveor_u32_x(svptrue_b32(), sveor_u32_x(svptrue_b32(), hi1, svget4_u32(ctr, 1)), svget2_u32(key, 0))); + ctr = svset4_u32(ctr, 1, lo1); + ctr = svset4_u32(ctr, 2, sveor_u32_x(svptrue_b32(), sveor_u32_x(svptrue_b32(), hi0, svget4_u32(ctr, 3)), svget2_u32(key, 1))); + ctr = svset4_u32(ctr, 3, lo0); } -QUALIFIERS void _philox4x32bumpkey(svuint32_st* key) +QUALIFIERS void _philox4x32bumpkey(svuint32x2_t & key) { - key[0] = svadd_u32_x(svptrue_b32(), key[0], svdup_u32(PHILOX_W32_0)); - key[1] = svadd_u32_x(svptrue_b32(), key[1], svdup_u32(PHILOX_W32_1)); + key = svset2_u32(key, 0, svadd_u32_x(svptrue_b32(), svget2_u32(key, 0), svdup_u32(PHILOX_W32_0))); + key = svset2_u32(key, 1, svadd_u32_x(svptrue_b32(), svget2_u32(key, 1), svdup_u32(PHILOX_W32_1))); } template<bool high> -QUALIFIERS svfloat64_st _uniform_double_hq(svuint32_st x, svuint32_st y) +QUALIFIERS svfloat64_t _uniform_double_hq(svuint32_t x, svuint32_t y) { // convert 32 to 64 bit if (high) @@ -700,11 +703,11 @@ QUALIFIERS svfloat64_st _uniform_double_hq(svuint32_st x, svuint32_st y) } // calculate z = x ^ y << (53 - 32)) - svuint64_st z = svlsl_n_u64_x(svptrue_b64(), svreinterpret_u64_u32(y), 53 - 32); + svuint64_t z = svlsl_n_u64_x(svptrue_b64(), svreinterpret_u64_u32(y), 53 - 32); z = sveor_u64_x(svptrue_b64(), svreinterpret_u64_u32(x), z); // convert uint64 to double - svfloat64_st rs = svcvt_f64_u64_x(svptrue_b64(), z); + svfloat64_t rs = svcvt_f64_u64_x(svptrue_b64(), z); // calculate rs * TWOPOW53_INV_DOUBLE + (TWOPOW53_INV_DOUBLE/2.0) rs = svmad_f64_x(svptrue_b64(), rs, svdup_f64(TWOPOW53_INV_DOUBLE), svdup_f64(TWOPOW53_INV_DOUBLE/2.0)); @@ -712,12 +715,12 @@ QUALIFIERS svfloat64_st _uniform_double_hq(svuint32_st x, svuint32_st y) } -QUALIFIERS void philox_float4(svuint32_st ctr0, svuint32_st ctr1, svuint32_st ctr2, svuint32_st ctr3, +QUALIFIERS void philox_float4(svuint32_t ctr0, svuint32_t ctr1, svuint32_t ctr2, svuint32_t ctr3, uint32 key0, uint32 key1, svfloat32_st & rnd1, svfloat32_st & rnd2, svfloat32_st & rnd3, svfloat32_st & rnd4) { - svuint32_st key[2] = {svdup_u32(key0), svdup_u32(key1)}; - svuint32_st ctr[4] = {ctr0, ctr1, ctr2, ctr3}; + svuint32x2_t key = svcreate2_u32(svdup_u32(key0), svdup_u32(key1)); + svuint32x4_t ctr = svcreate4_u32(ctr0, ctr1, ctr2, ctr3); _philox4x32round(ctr, key); // 1 _philox4x32bumpkey(key); _philox4x32round(ctr, key); // 2 _philox4x32bumpkey(key); _philox4x32round(ctr, key); // 3 @@ -730,10 +733,10 @@ QUALIFIERS void philox_float4(svuint32_st ctr0, svuint32_st ctr1, svuint32_st ct _philox4x32bumpkey(key); _philox4x32round(ctr, key); // 10 // convert uint32 to float - rnd1 = svcvt_f32_u32_x(svptrue_b32(), ctr[0]); - rnd2 = svcvt_f32_u32_x(svptrue_b32(), ctr[1]); - rnd3 = svcvt_f32_u32_x(svptrue_b32(), ctr[2]); - rnd4 = svcvt_f32_u32_x(svptrue_b32(), ctr[3]); + rnd1 = svcvt_f32_u32_x(svptrue_b32(), svget4_u32(ctr, 0)); + rnd2 = svcvt_f32_u32_x(svptrue_b32(), svget4_u32(ctr, 1)); + rnd3 = svcvt_f32_u32_x(svptrue_b32(), svget4_u32(ctr, 2)); + rnd4 = svcvt_f32_u32_x(svptrue_b32(), svget4_u32(ctr, 3)); // calculate rnd * TWOPOW32_INV_FLOAT + (TWOPOW32_INV_FLOAT/2.0f) rnd1 = svmad_f32_x(svptrue_b32(), rnd1, svdup_f32(TWOPOW32_INV_FLOAT), svdup_f32(TWOPOW32_INV_FLOAT/2.0)); rnd2 = svmad_f32_x(svptrue_b32(), rnd2, svdup_f32(TWOPOW32_INV_FLOAT), svdup_f32(TWOPOW32_INV_FLOAT/2.0)); @@ -742,12 +745,12 @@ QUALIFIERS void philox_float4(svuint32_st ctr0, svuint32_st ctr1, svuint32_st ct } -QUALIFIERS void philox_double2(svuint32_st ctr0, svuint32_st ctr1, svuint32_st ctr2, svuint32_st ctr3, +QUALIFIERS void philox_double2(svuint32_t ctr0, svuint32_t ctr1, svuint32_t ctr2, svuint32_t ctr3, uint32 key0, uint32 key1, svfloat64_st & rnd1lo, svfloat64_st & rnd1hi, svfloat64_st & rnd2lo, svfloat64_st & rnd2hi) { - svuint32_st key[2] = {svdup_u32(key0), svdup_u32(key1)}; - svuint32_st ctr[4] = {ctr0, ctr1, ctr2, ctr3}; + svuint32x2_t key = svcreate2_u32(svdup_u32(key0), svdup_u32(key1)); + svuint32x4_t ctr = svcreate4_u32(ctr0, ctr1, ctr2, ctr3); _philox4x32round(ctr, key); // 1 _philox4x32bumpkey(key); _philox4x32round(ctr, key); // 2 _philox4x32bumpkey(key); _philox4x32round(ctr, key); // 3 @@ -759,54 +762,54 @@ QUALIFIERS void philox_double2(svuint32_st ctr0, svuint32_st ctr1, svuint32_st c _philox4x32bumpkey(key); _philox4x32round(ctr, key); // 9 _philox4x32bumpkey(key); _philox4x32round(ctr, key); // 10 - rnd1lo = _uniform_double_hq<false>(ctr[0], ctr[1]); - rnd1hi = _uniform_double_hq<true>(ctr[0], ctr[1]); - rnd2lo = _uniform_double_hq<false>(ctr[2], ctr[3]); - rnd2hi = _uniform_double_hq<true>(ctr[2], ctr[3]); + rnd1lo = _uniform_double_hq<false>(svget4_u32(ctr, 0), svget4_u32(ctr, 1)); + rnd1hi = _uniform_double_hq<true>(svget4_u32(ctr, 0), svget4_u32(ctr, 1)); + rnd2lo = _uniform_double_hq<false>(svget4_u32(ctr, 2), svget4_u32(ctr, 3)); + rnd2hi = _uniform_double_hq<true>(svget4_u32(ctr, 2), svget4_u32(ctr, 3)); } -QUALIFIERS void philox_float4(uint32 ctr0, svuint32_st ctr1, uint32 ctr2, uint32 ctr3, +QUALIFIERS void philox_float4(uint32 ctr0, svuint32_t ctr1, uint32 ctr2, uint32 ctr3, uint32 key0, uint32 key1, svfloat32_st & rnd1, svfloat32_st & rnd2, svfloat32_st & rnd3, svfloat32_st & rnd4) { - svuint32_st ctr0v = svdup_u32(ctr0); - svuint32_st ctr2v = svdup_u32(ctr2); - svuint32_st ctr3v = svdup_u32(ctr3); + svuint32_t ctr0v = svdup_u32(ctr0); + svuint32_t ctr2v = svdup_u32(ctr2); + svuint32_t ctr3v = svdup_u32(ctr3); philox_float4(ctr0v, ctr1, ctr2v, ctr3v, key0, key1, rnd1, rnd2, rnd3, rnd4); } -QUALIFIERS void philox_float4(uint32 ctr0, svint32_st ctr1, uint32 ctr2, uint32 ctr3, +QUALIFIERS void philox_float4(uint32 ctr0, svint32_t ctr1, uint32 ctr2, uint32 ctr3, uint32 key0, uint32 key1, svfloat32_st & rnd1, svfloat32_st & rnd2, svfloat32_st & rnd3, svfloat32_st & rnd4) { philox_float4(ctr0, svreinterpret_u32_s32(ctr1), ctr2, ctr3, key0, key1, rnd1, rnd2, rnd3, rnd4); } -QUALIFIERS void philox_double2(uint32 ctr0, svuint32_st ctr1, uint32 ctr2, uint32 ctr3, +QUALIFIERS void philox_double2(uint32 ctr0, svuint32_t ctr1, uint32 ctr2, uint32 ctr3, uint32 key0, uint32 key1, svfloat64_st & rnd1lo, svfloat64_st & rnd1hi, svfloat64_st & rnd2lo, svfloat64_st & rnd2hi) { - svuint32_st ctr0v = svdup_u32(ctr0); - svuint32_st ctr2v = svdup_u32(ctr2); - svuint32_st ctr3v = svdup_u32(ctr3); + svuint32_t ctr0v = svdup_u32(ctr0); + svuint32_t ctr2v = svdup_u32(ctr2); + svuint32_t ctr3v = svdup_u32(ctr3); philox_double2(ctr0v, ctr1, ctr2v, ctr3v, key0, key1, rnd1lo, rnd1hi, rnd2lo, rnd2hi); } -QUALIFIERS void philox_double2(uint32 ctr0, svuint32_st ctr1, uint32 ctr2, uint32 ctr3, +QUALIFIERS void philox_double2(uint32 ctr0, svuint32_t ctr1, uint32 ctr2, uint32 ctr3, uint32 key0, uint32 key1, svfloat64_st & rnd1, svfloat64_st & rnd2) { - svuint32_st ctr0v = svdup_u32(ctr0); - svuint32_st ctr2v = svdup_u32(ctr2); - svuint32_st ctr3v = svdup_u32(ctr3); + svuint32_t ctr0v = svdup_u32(ctr0); + svuint32_t ctr2v = svdup_u32(ctr2); + svuint32_t ctr3v = svdup_u32(ctr3); svfloat64_st ignore; philox_double2(ctr0v, ctr1, ctr2v, ctr3v, key0, key1, rnd1, ignore, rnd2, ignore); } -QUALIFIERS void philox_double2(uint32 ctr0, svint32_st ctr1, uint32 ctr2, uint32 ctr3, +QUALIFIERS void philox_double2(uint32 ctr0, svint32_t ctr1, uint32 ctr2, uint32 ctr3, uint32 key0, uint32 key1, svfloat64_st & rnd1, svfloat64_st & rnd2) {