Commit 896b4192 authored by Michael Kuron
Browse files

Cherry-pick updated SVE Philox

parent 7e6be86e
......@@ -15,13 +15,8 @@
#ifdef __ARM_NEON
#include <arm_neon.h>
#endif
#if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_SVE_BITS) && __ARM_FEATURE_SVE_BITS > 0
#ifdef __ARM_FEATURE_SVE
#include <arm_sve.h>
typedef svfloat32_t svfloat32_st __attribute__((arm_sve_vector_bits(__ARM_FEATURE_SVE_BITS)));
typedef svfloat64_t svfloat64_st __attribute__((arm_sve_vector_bits(__ARM_FEATURE_SVE_BITS)));
typedef svint32_t svint32_st __attribute__((arm_sve_vector_bits(__ARM_FEATURE_SVE_BITS)));
typedef svuint32_t svuint32_st __attribute__((arm_sve_vector_bits(__ARM_FEATURE_SVE_BITS)));
typedef svuint64_t svuint64_st __attribute__((arm_sve_vector_bits(__ARM_FEATURE_SVE_BITS)));
#endif
#if defined(__powerpc__) && defined(__GNUC__) && !defined(__clang__) && !defined(__xlC__)
......@@ -52,6 +47,14 @@ typedef svuint64_t svuint64_st __attribute__((arm_sve_vector_bits(__ARM_FEATURE_
typedef std::uint32_t uint32;
typedef std::uint64_t uint64;
#if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_SVE_BITS) && __ARM_FEATURE_SVE_BITS > 0
typedef svfloat32_t svfloat32_st __attribute__((arm_sve_vector_bits(__ARM_FEATURE_SVE_BITS)));
typedef svfloat64_t svfloat64_st __attribute__((arm_sve_vector_bits(__ARM_FEATURE_SVE_BITS)));
#elif defined(__ARM_FEATURE_SVE)
typedef svfloat32_t svfloat32_st;
typedef svfloat64_t svfloat64_st;
#endif
QUALIFIERS uint32 mulhilo32(uint32 a, uint32 b, uint32* hip)
{
......@@ -664,28 +667,28 @@ QUALIFIERS void philox_double2(uint32 ctr0, int32x4_t ctr1, uint32 ctr2, uint32
#endif
#if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_SVE_BITS) && __ARM_FEATURE_SVE_BITS > 0
// _philox4x32round: updates the four counter vectors in place from the two key vectors.
// NOTE(review): this listing is a diff whose +/- markers were lost, so two versions of the
// function are interleaved. The lines using svuint32_st with ctr[i]/key[i] indexing appear
// to be the removed version; the lines using svuint32x4_t/svuint32x2_t tuples with
// svget*/svset* appear to be the added version -- confirm against the repository.
QUALIFIERS void _philox4x32round(svuint32_st* ctr, svuint32_st* key)
#if defined(__ARM_FEATURE_SVE)
QUALIFIERS void _philox4x32round(svuint32x4_t & ctr, svuint32x2_t & key)
{
// Array-indexed variant: lo0/hi0 are the svmul/svmulh products of ctr[0] with
// PHILOX_M4x32_0, and lo1/hi1 are those of ctr[2] with PHILOX_M4x32_1.
svuint32_st lo0 = svmul_u32_x(svptrue_b32(), ctr[0], svdup_u32(PHILOX_M4x32_0));
svuint32_st lo1 = svmul_u32_x(svptrue_b32(), ctr[2], svdup_u32(PHILOX_M4x32_1));
svuint32_st hi0 = svmulh_u32_x(svptrue_b32(), ctr[0], svdup_u32(PHILOX_M4x32_0));
svuint32_st hi1 = svmulh_u32_x(svptrue_b32(), ctr[2], svdup_u32(PHILOX_M4x32_1));
// ctr[0] <- hi1 ^ ctr[1] ^ key[0] and ctr[2] <- hi0 ^ ctr[3] ^ key[1]; each XOR reads
// ctr[1]/ctr[3] before the following line overwrites it with lo1/lo0.
ctr[0] = sveor_u32_x(svptrue_b32(), sveor_u32_x(svptrue_b32(), hi1, ctr[1]), key[0]);
ctr[1] = lo1;
ctr[2] = sveor_u32_x(svptrue_b32(), sveor_u32_x(svptrue_b32(), hi0, ctr[3]), key[1]);
ctr[3] = lo0;
// Tuple variant of the same steps: elements are read with svget4_u32/svget2_u32 and
// written with svset4_u32, whose result is assigned back to ctr.
svuint32_t lo0 = svmul_u32_x(svptrue_b32(), svget4_u32(ctr, 0), svdup_u32(PHILOX_M4x32_0));
svuint32_t lo1 = svmul_u32_x(svptrue_b32(), svget4_u32(ctr, 2), svdup_u32(PHILOX_M4x32_1));
svuint32_t hi0 = svmulh_u32_x(svptrue_b32(), svget4_u32(ctr, 0), svdup_u32(PHILOX_M4x32_0));
svuint32_t hi1 = svmulh_u32_x(svptrue_b32(), svget4_u32(ctr, 2), svdup_u32(PHILOX_M4x32_1));
ctr = svset4_u32(ctr, 0, sveor_u32_x(svptrue_b32(), sveor_u32_x(svptrue_b32(), hi1, svget4_u32(ctr, 1)), svget2_u32(key, 0)));
ctr = svset4_u32(ctr, 1, lo1);
ctr = svset4_u32(ctr, 2, sveor_u32_x(svptrue_b32(), sveor_u32_x(svptrue_b32(), hi0, svget4_u32(ctr, 3)), svget2_u32(key, 1)));
ctr = svset4_u32(ctr, 3, lo0);
}
// _philox4x32bumpkey: adds the constant PHILOX_W32_0 to key element 0 and PHILOX_W32_1 to
// key element 1, in place.
// NOTE(review): this listing is a diff whose +/- markers were lost. The svuint32_st* /
// key[i] lines appear to be the removed version and the svuint32x2_t / svget2_u32 /
// svset2_u32 lines the added version -- confirm against the repository.
QUALIFIERS void _philox4x32bumpkey(svuint32_st* key)
QUALIFIERS void _philox4x32bumpkey(svuint32x2_t & key)
{
// Array-indexed variant.
key[0] = svadd_u32_x(svptrue_b32(), key[0], svdup_u32(PHILOX_W32_0));
key[1] = svadd_u32_x(svptrue_b32(), key[1], svdup_u32(PHILOX_W32_1));
// Tuple variant: the result of svset2_u32 is assigned back to key.
key = svset2_u32(key, 0, svadd_u32_x(svptrue_b32(), svget2_u32(key, 0), svdup_u32(PHILOX_W32_0)));
key = svset2_u32(key, 1, svadd_u32_x(svptrue_b32(), svget2_u32(key, 1), svdup_u32(PHILOX_W32_1)));
}
template<bool high>
QUALIFIERS svfloat64_st _uniform_double_hq(svuint32_st x, svuint32_st y)
QUALIFIERS svfloat64_t _uniform_double_hq(svuint32_t x, svuint32_t y)
{
// convert 32 to 64 bit
if (high)
......@@ -700,11 +703,11 @@ QUALIFIERS svfloat64_st _uniform_double_hq(svuint32_st x, svuint32_st y)
}
// calculate z = x ^ y << (53 - 32))
svuint64_st z = svlsl_n_u64_x(svptrue_b64(), svreinterpret_u64_u32(y), 53 - 32);
svuint64_t z = svlsl_n_u64_x(svptrue_b64(), svreinterpret_u64_u32(y), 53 - 32);
z = sveor_u64_x(svptrue_b64(), svreinterpret_u64_u32(x), z);
// convert uint64 to double
svfloat64_st rs = svcvt_f64_u64_x(svptrue_b64(), z);
svfloat64_t rs = svcvt_f64_u64_x(svptrue_b64(), z);
// calculate rs * TWOPOW53_INV_DOUBLE + (TWOPOW53_INV_DOUBLE/2.0)
rs = svmad_f64_x(svptrue_b64(), rs, svdup_f64(TWOPOW53_INV_DOUBLE), svdup_f64(TWOPOW53_INV_DOUBLE/2.0));
......@@ -712,12 +715,12 @@ QUALIFIERS svfloat64_st _uniform_double_hq(svuint32_st x, svuint32_st y)
}
QUALIFIERS void philox_float4(svuint32_st ctr0, svuint32_st ctr1, svuint32_st ctr2, svuint32_st ctr3,
QUALIFIERS void philox_float4(svuint32_t ctr0, svuint32_t ctr1, svuint32_t ctr2, svuint32_t ctr3,
uint32 key0, uint32 key1,
svfloat32_st & rnd1, svfloat32_st & rnd2, svfloat32_st & rnd3, svfloat32_st & rnd4)
{
svuint32_st key[2] = {svdup_u32(key0), svdup_u32(key1)};
svuint32_st ctr[4] = {ctr0, ctr1, ctr2, ctr3};
svuint32x2_t key = svcreate2_u32(svdup_u32(key0), svdup_u32(key1));
svuint32x4_t ctr = svcreate4_u32(ctr0, ctr1, ctr2, ctr3);
_philox4x32round(ctr, key); // 1
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 2
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 3
......@@ -730,10 +733,10 @@ QUALIFIERS void philox_float4(svuint32_st ctr0, svuint32_st ctr1, svuint32_st ct
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 10
// convert uint32 to float
rnd1 = svcvt_f32_u32_x(svptrue_b32(), ctr[0]);
rnd2 = svcvt_f32_u32_x(svptrue_b32(), ctr[1]);
rnd3 = svcvt_f32_u32_x(svptrue_b32(), ctr[2]);
rnd4 = svcvt_f32_u32_x(svptrue_b32(), ctr[3]);
rnd1 = svcvt_f32_u32_x(svptrue_b32(), svget4_u32(ctr, 0));
rnd2 = svcvt_f32_u32_x(svptrue_b32(), svget4_u32(ctr, 1));
rnd3 = svcvt_f32_u32_x(svptrue_b32(), svget4_u32(ctr, 2));
rnd4 = svcvt_f32_u32_x(svptrue_b32(), svget4_u32(ctr, 3));
// calculate rnd * TWOPOW32_INV_FLOAT + (TWOPOW32_INV_FLOAT/2.0f)
rnd1 = svmad_f32_x(svptrue_b32(), rnd1, svdup_f32(TWOPOW32_INV_FLOAT), svdup_f32(TWOPOW32_INV_FLOAT/2.0));
rnd2 = svmad_f32_x(svptrue_b32(), rnd2, svdup_f32(TWOPOW32_INV_FLOAT), svdup_f32(TWOPOW32_INV_FLOAT/2.0));
......@@ -742,12 +745,12 @@ QUALIFIERS void philox_float4(svuint32_st ctr0, svuint32_st ctr1, svuint32_st ct
}
QUALIFIERS void philox_double2(svuint32_st ctr0, svuint32_st ctr1, svuint32_st ctr2, svuint32_st ctr3,
QUALIFIERS void philox_double2(svuint32_t ctr0, svuint32_t ctr1, svuint32_t ctr2, svuint32_t ctr3,
uint32 key0, uint32 key1,
svfloat64_st & rnd1lo, svfloat64_st & rnd1hi, svfloat64_st & rnd2lo, svfloat64_st & rnd2hi)
{
svuint32_st key[2] = {svdup_u32(key0), svdup_u32(key1)};
svuint32_st ctr[4] = {ctr0, ctr1, ctr2, ctr3};
svuint32x2_t key = svcreate2_u32(svdup_u32(key0), svdup_u32(key1));
svuint32x4_t ctr = svcreate4_u32(ctr0, ctr1, ctr2, ctr3);
_philox4x32round(ctr, key); // 1
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 2
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 3
......@@ -759,54 +762,54 @@ QUALIFIERS void philox_double2(svuint32_st ctr0, svuint32_st ctr1, svuint32_st c
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 9
_philox4x32bumpkey(key); _philox4x32round(ctr, key); // 10
rnd1lo = _uniform_double_hq<false>(ctr[0], ctr[1]);
rnd1hi = _uniform_double_hq<true>(ctr[0], ctr[1]);
rnd2lo = _uniform_double_hq<false>(ctr[2], ctr[3]);
rnd2hi = _uniform_double_hq<true>(ctr[2], ctr[3]);
rnd1lo = _uniform_double_hq<false>(svget4_u32(ctr, 0), svget4_u32(ctr, 1));
rnd1hi = _uniform_double_hq<true>(svget4_u32(ctr, 0), svget4_u32(ctr, 1));
rnd2lo = _uniform_double_hq<false>(svget4_u32(ctr, 2), svget4_u32(ctr, 3));
rnd2hi = _uniform_double_hq<true>(svget4_u32(ctr, 2), svget4_u32(ctr, 3));
}
// philox_float4 overload for scalar ctr0/ctr2/ctr3 and a vector ctr1: broadcasts the three
// scalars with svdup_u32 and forwards to the overload that takes four vector counters.
// Results are written to the rnd1..rnd4 reference parameters.
// NOTE(review): this listing is a diff whose +/- markers were lost. Each svuint32_st line
// appears to be the removed version of the svuint32_t line that follows it -- confirm
// against the repository.
QUALIFIERS void philox_float4(uint32 ctr0, svuint32_st ctr1, uint32 ctr2, uint32 ctr3,
QUALIFIERS void philox_float4(uint32 ctr0, svuint32_t ctr1, uint32 ctr2, uint32 ctr3,
uint32 key0, uint32 key1,
svfloat32_st & rnd1, svfloat32_st & rnd2, svfloat32_st & rnd3, svfloat32_st & rnd4)
{
svuint32_st ctr0v = svdup_u32(ctr0);
svuint32_st ctr2v = svdup_u32(ctr2);
svuint32_st ctr3v = svdup_u32(ctr3);
svuint32_t ctr0v = svdup_u32(ctr0);
svuint32_t ctr2v = svdup_u32(ctr2);
svuint32_t ctr3v = svdup_u32(ctr3);
philox_float4(ctr0v, ctr1, ctr2v, ctr3v, key0, key1, rnd1, rnd2, rnd3, rnd4);
}
// philox_float4 overload for a signed vector ctr1: reinterprets ctr1 as unsigned with
// svreinterpret_u32_s32 and forwards all other arguments unchanged.
// NOTE(review): this listing is a diff whose +/- markers were lost. The svint32_st signature
// line appears to be the removed version of the svint32_t line that follows it -- confirm
// against the repository.
QUALIFIERS void philox_float4(uint32 ctr0, svint32_st ctr1, uint32 ctr2, uint32 ctr3,
QUALIFIERS void philox_float4(uint32 ctr0, svint32_t ctr1, uint32 ctr2, uint32 ctr3,
uint32 key0, uint32 key1,
svfloat32_st & rnd1, svfloat32_st & rnd2, svfloat32_st & rnd3, svfloat32_st & rnd4)
{
philox_float4(ctr0, svreinterpret_u32_s32(ctr1), ctr2, ctr3, key0, key1, rnd1, rnd2, rnd3, rnd4);
}
// philox_double2 overload for scalar ctr0/ctr2/ctr3 and a vector ctr1, with four outputs:
// broadcasts the three scalars with svdup_u32 and forwards to the overload that takes four
// vector counters. Results are written to rnd1lo, rnd1hi, rnd2lo and rnd2hi.
// NOTE(review): this listing is a diff whose +/- markers were lost. Each svuint32_st line
// appears to be the removed version of the svuint32_t line that follows it -- confirm
// against the repository.
QUALIFIERS void philox_double2(uint32 ctr0, svuint32_st ctr1, uint32 ctr2, uint32 ctr3,
QUALIFIERS void philox_double2(uint32 ctr0, svuint32_t ctr1, uint32 ctr2, uint32 ctr3,
uint32 key0, uint32 key1,
svfloat64_st & rnd1lo, svfloat64_st & rnd1hi, svfloat64_st & rnd2lo, svfloat64_st & rnd2hi)
{
svuint32_st ctr0v = svdup_u32(ctr0);
svuint32_st ctr2v = svdup_u32(ctr2);
svuint32_st ctr3v = svdup_u32(ctr3);
svuint32_t ctr0v = svdup_u32(ctr0);
svuint32_t ctr2v = svdup_u32(ctr2);
svuint32_t ctr3v = svdup_u32(ctr3);
philox_double2(ctr0v, ctr1, ctr2v, ctr3v, key0, key1, rnd1lo, rnd1hi, rnd2lo, rnd2hi);
}
// philox_double2 overload with only two outputs: broadcasts the scalar counters with
// svdup_u32 and calls the four-output overload, passing rnd1 and rnd2 in the rnd1lo and
// rnd2lo positions and a throwaway local in the rnd1hi and rnd2hi positions.
// NOTE(review): this listing is a diff whose +/- markers were lost. Each svuint32_st line
// appears to be the removed version of the svuint32_t line that follows it -- confirm
// against the repository.
QUALIFIERS void philox_double2(uint32 ctr0, svuint32_st ctr1, uint32 ctr2, uint32 ctr3,
QUALIFIERS void philox_double2(uint32 ctr0, svuint32_t ctr1, uint32 ctr2, uint32 ctr3,
uint32 key0, uint32 key1,
svfloat64_st & rnd1, svfloat64_st & rnd2)
{
svuint32_st ctr0v = svdup_u32(ctr0);
svuint32_st ctr2v = svdup_u32(ctr2);
svuint32_st ctr3v = svdup_u32(ctr3);
svuint32_t ctr0v = svdup_u32(ctr0);
svuint32_t ctr2v = svdup_u32(ctr2);
svuint32_t ctr3v = svdup_u32(ctr3);
// Receives both "hi" results, which this overload discards.
svfloat64_st ignore;
philox_double2(ctr0v, ctr1, ctr2v, ctr3v, key0, key1, rnd1, ignore, rnd2, ignore);
}
QUALIFIERS void philox_double2(uint32 ctr0, svint32_st ctr1, uint32 ctr2, uint32 ctr3,
QUALIFIERS void philox_double2(uint32 ctr0, svint32_t ctr1, uint32 ctr2, uint32 ctr3,
uint32 key0, uint32 key1,
svfloat64_st & rnd1, svfloat64_st & rnd2)
{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment