Commit 71b8767b authored by Martin Bauer's avatar Martin Bauer
Browse files

Merge branch 'philox' into 'master'

AES-NI vectorization improvements

See merge request pycodegen/pystencils!46
parents d0cd169d 69f562d8
Pipeline #18204 passed with stage
in 5 minutes and 31 seconds
#if !defined(__AES__) || !defined(__SSE2__) #if !defined(__AES__) || !defined(__SSE4_1__)
#error AES-NI and SSE2 need to be enabled #error AES-NI and SSE4.1 need to be enabled
#endif #endif
#include <emmintrin.h> // SSE2 #include <emmintrin.h> // SSE2
#include <wmmintrin.h> // AES #include <wmmintrin.h> // AES
#ifdef __AVX512VL__ #ifdef __AVX512VL__
#include <immintrin.h> // AVX* #include <immintrin.h> // AVX*
#else
#include <smmintrin.h> // SSE4
#ifdef __FMA__
#include <immintrin.h> // FMA
#endif
#endif #endif
#include <cstdint> #include <cstdint>
...@@ -44,14 +49,19 @@ QUALIFIERS __m128 _my_cvtepu32_ps(const __m128i v) ...@@ -44,14 +49,19 @@ QUALIFIERS __m128 _my_cvtepu32_ps(const __m128i v)
#endif #endif
} }
#if !defined(__AVX512VL__) && defined(__GNUC__) && __GNUC__ >= 5
__attribute__((optimize("no-associative-math")))
#endif
QUALIFIERS __m128d _my_cvtepu64_pd(const __m128i x) QUALIFIERS __m128d _my_cvtepu64_pd(const __m128i x)
{ {
#ifdef __AVX512VL__ #ifdef __AVX512VL__
return _mm_cvtepu64_pd(x); return _mm_cvtepu64_pd(x);
#else #else
uint64 r[2]; __m128i xH = _mm_srli_epi64(x, 32);
_mm_storeu_si128((__m128i*)r, x); xH = _mm_or_si128(xH, _mm_castpd_si128(_mm_set1_pd(19342813113834066795298816.))); // 2^84
return _mm_set_pd((double)r[1], (double)r[0]); __m128i xL = _mm_blend_epi16(x, _mm_castpd_si128(_mm_set1_pd(0x0010000000000000)), 0xcc); // 2^52
__m128d f = _mm_sub_pd(_mm_castsi128_pd(xH), _mm_set1_pd(19342813118337666422669312.)); // 2^84 + 2^52
return _mm_add_pd(f, _mm_castsi128_pd(xL));
#endif #endif
} }
...@@ -71,18 +81,22 @@ QUALIFIERS void aesni_double2(uint32 ctr0, uint32 ctr1, uint32 ctr2, uint32 ctr3 ...@@ -71,18 +81,22 @@ QUALIFIERS void aesni_double2(uint32 ctr0, uint32 ctr1, uint32 ctr2, uint32 ctr3
y = _mm_srli_si128(y, 4); y = _mm_srli_si128(y, 4);
// calculate z = x ^ y << (53 - 32)) // calculate z = x ^ y << (53 - 32))
__m128i z = _mm_sll_epi64(y, _mm_set_epi64x(53 - 32, 53 - 32)); __m128i z = _mm_sll_epi64(y, _mm_set1_epi64x(53 - 32));
z = _mm_xor_si128(x, z); z = _mm_xor_si128(x, z);
// convert uint64 to double // convert uint64 to double
__m128d rs = _my_cvtepu64_pd(z); __m128d rs = _my_cvtepu64_pd(z);
// calculate rs * TWOPOW53_INV_DOUBLE + (TWOPOW53_INV_DOUBLE/2.0) // calculate rs * TWOPOW53_INV_DOUBLE + (TWOPOW53_INV_DOUBLE/2.0)
rs = _mm_mul_pd(rs, _mm_set_pd1(TWOPOW53_INV_DOUBLE)); #ifdef __FMA__
rs = _mm_add_pd(rs, _mm_set_pd1(TWOPOW53_INV_DOUBLE/2.0)); rs = _mm_fmadd_pd(rs, _mm_set1_pd(TWOPOW53_INV_DOUBLE), _mm_set1_pd(TWOPOW53_INV_DOUBLE/2.0));
#else
rs = _mm_mul_pd(rs, _mm_set1_pd(TWOPOW53_INV_DOUBLE));
rs = _mm_add_pd(rs, _mm_set1_pd(TWOPOW53_INV_DOUBLE/2.0));
#endif
// store result // store result
double rr[2]; alignas(16) double rr[2];
_mm_storeu_pd(rr, rs); _mm_store_pd(rr, rs);
rnd1 = rr[0]; rnd1 = rr[0];
rnd2 = rr[1]; rnd2 = rr[1];
} }
...@@ -100,14 +114,19 @@ QUALIFIERS void aesni_float4(uint32 ctr0, uint32 ctr1, uint32 ctr2, uint32 ctr3, ...@@ -100,14 +114,19 @@ QUALIFIERS void aesni_float4(uint32 ctr0, uint32 ctr1, uint32 ctr2, uint32 ctr3,
// convert uint32 to float // convert uint32 to float
__m128 rs = _my_cvtepu32_ps(c128); __m128 rs = _my_cvtepu32_ps(c128);
// calculate rs * TWOPOW32_INV_FLOAT + (TWOPOW32_INV_FLOAT/2.0f) // calculate rs * TWOPOW32_INV_FLOAT + (TWOPOW32_INV_FLOAT/2.0f)
rs = _mm_mul_ps(rs, _mm_set_ps1(TWOPOW32_INV_FLOAT)); #ifdef __FMA__
rs = _mm_add_ps(rs, _mm_set_ps1(TWOPOW32_INV_FLOAT/2.0f)); rs = _mm_fmadd_ps(rs, _mm_set1_ps(TWOPOW32_INV_FLOAT), _mm_set1_ps(TWOPOW32_INV_FLOAT/2.0f));
#else
rs = _mm_mul_ps(rs, _mm_set1_ps(TWOPOW32_INV_FLOAT));
rs = _mm_add_ps(rs, _mm_set1_ps(TWOPOW32_INV_FLOAT/2.0f));
#endif
// store result // store result
float r[4]; alignas(16) float r[4];
_mm_storeu_ps(r, rs); _mm_store_ps(r, rs);
rnd1 = r[0]; rnd1 = r[0];
rnd2 = r[1]; rnd2 = r[1];
rnd3 = r[2]; rnd3 = r[2];
rnd4 = r[3]; rnd4 = r[3];
} }
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment