Skip to content
Snippets Groups Projects
Commit 170a7717 authored by Michael Kuron's avatar Michael Kuron :mortar_board:
Browse files

clean up AES-NI RNG

parent fa0a09a5
No related merge requests found
......@@ -2,7 +2,11 @@
#error AES-NI and SSE2 need to be enabled
#endif
#include <x86intrin.h>
#include <emmintrin.h> // SSE2
#include <wmmintrin.h> // AES
#ifdef __AVX512VL__
#include <immintrin.h> // AVX*
#endif
#include <cstdint>
#define QUALIFIERS inline
......@@ -14,22 +18,22 @@ typedef std::uint64_t uint64;
QUALIFIERS __m128i aesni1xm128i(const __m128i & in, const __m128i & k) {
__m128i x = _mm_xor_si128(k, in);
x = _mm_aesenc_si128(x, k);
x = _mm_aesenc_si128(x, k);
x = _mm_aesenc_si128(x, k);
x = _mm_aesenc_si128(x, k);
x = _mm_aesenc_si128(x, k);
x = _mm_aesenc_si128(x, k);
x = _mm_aesenc_si128(x, k);
x = _mm_aesenc_si128(x, k);
x = _mm_aesenc_si128(x, k);
x = _mm_aesenclast_si128(x, k);
x = _mm_aesenc_si128(x, k); // 1
x = _mm_aesenc_si128(x, k); // 2
x = _mm_aesenc_si128(x, k); // 3
x = _mm_aesenc_si128(x, k); // 4
x = _mm_aesenc_si128(x, k); // 5
x = _mm_aesenc_si128(x, k); // 6
x = _mm_aesenc_si128(x, k); // 7
x = _mm_aesenc_si128(x, k); // 8
x = _mm_aesenc_si128(x, k); // 9
x = _mm_aesenclast_si128(x, k); // 10
return x;
}
QUALIFIERS __m128 _my_cvtepu32_ps(const __m128i v)
{
#ifdef __AVX512F__
#ifdef __AVX512VL__
return _mm_cvtepu32_ps(v);
#else
__m128i v2 = _mm_srli_epi32(v, 1);
......@@ -40,13 +44,14 @@ QUALIFIERS __m128 _my_cvtepu32_ps(const __m128i v)
#endif
}
QUALIFIERS __m128d _my_cvtepu64_pd(const __m128i v)
QUALIFIERS __m128d _my_cvtepu64_pd(const __m128i x)
{
#ifdef __AVX512F__
return _mm_cvtepu64_pd(v);
#ifdef __AVX512VL__
return _mm_cvtepu64_pd(x);
#else
#warning need to implement _my_cvtepu64_pd
return (__m128d) v;
uint64 r[2];
_mm_storeu_si128((__m128i*)r, x);
return _mm_set_pd((double)r[1], (double)r[0]);
#endif
}
......@@ -55,25 +60,27 @@ QUALIFIERS void aesni_double2(uint32 ctr0, uint32 ctr1, uint32 ctr2, uint32 ctr3
uint32 key0, uint32 key1, uint32 key2, uint32 key3,
double & rnd1, double & rnd2)
{
// pack input and call AES
__m128i c128 = _mm_set_epi32(ctr3, ctr2, ctr1, ctr0);
__m128i k128 = _mm_set_epi32(key3, key2, key1, key0);
c128 = aesni1xm128i(c128, k128);
uint32 r[4];
_mm_storeu_si128((__m128i*)&r[0], c128);
__m128i x = _mm_set_epi64x((uint64) r[2], (uint64) r[0]);
__m128i y = _mm_set_epi64x((uint64) r[3], (uint64) r[1]);
// convert 32 to 64 bit and put 0th and 2nd element into x, 1st and 3rd element into y
__m128i x = _mm_and_si128(c128, _mm_set_epi32(0, 0xffffffff, 0, 0xffffffff));
__m128i y = _mm_and_si128(c128, _mm_set_epi32(0xffffffff, 0, 0xffffffff, 0));
y = _mm_srli_si128(y, 4);
__m128i cnt = _mm_set_epi64x(53 - 32, 53 - 32);
y = _mm_sll_epi64(y, cnt);
__m128i z = _mm_xor_si128(x, y);
// calculate z = x ^ y << (53 - 32))
__m128i z = _mm_sll_epi64(y, _mm_set_epi64x(53 - 32, 53 - 32));
z = _mm_xor_si128(x, z);
// convert uint64 to double
__m128d rs = _my_cvtepu64_pd(z);
const __m128d tp53 = _mm_set_pd1(TWOPOW53_INV_DOUBLE);
const __m128d tp54 = _mm_set_pd1(TWOPOW53_INV_DOUBLE/2.0);
rs = _mm_mul_pd(rs, tp53);
rs = _mm_add_pd(rs, tp54);
// calculate rs * TWOPOW53_INV_DOUBLE + (TWOPOW53_INV_DOUBLE/2.0)
rs = _mm_mul_pd(rs, _mm_set_pd1(TWOPOW53_INV_DOUBLE));
rs = _mm_add_pd(rs, _mm_set_pd1(TWOPOW53_INV_DOUBLE/2.0));
// store result
double rr[2];
_mm_storeu_pd(rr, rs);
rnd1 = rr[0];
......@@ -85,16 +92,18 @@ QUALIFIERS void aesni_float4(uint32 ctr0, uint32 ctr1, uint32 ctr2, uint32 ctr3,
uint32 key0, uint32 key1, uint32 key2, uint32 key3,
float & rnd1, float & rnd2, float & rnd3, float & rnd4)
{
// pack input and call AES
__m128i c128 = _mm_set_epi32(ctr3, ctr2, ctr1, ctr0);
__m128i k128 = _mm_set_epi32(key3, key2, key1, key0);
c128 = aesni1xm128i(c128, k128);
// convert uint32 to float
__m128 rs = _my_cvtepu32_ps(c128);
const __m128 tp32 = _mm_set_ps1(TWOPOW32_INV_FLOAT);
const __m128 tp33 = _mm_set_ps1(TWOPOW32_INV_FLOAT/2.0f);
rs = _mm_mul_ps(rs, tp32);
rs = _mm_add_ps(rs, tp33);
// calculate rs * TWOPOW32_INV_FLOAT + (TWOPOW32_INV_FLOAT/2.0f)
rs = _mm_mul_ps(rs, _mm_set_ps1(TWOPOW32_INV_FLOAT));
rs = _mm_add_ps(rs, _mm_set_ps1(TWOPOW32_INV_FLOAT/2.0f));
// store result
float r[4];
_mm_storeu_ps(r, rs);
rnd1 = r[0];
......
......@@ -21,7 +21,7 @@ def _get_rng_template(name, data_type, num_vars):
def _get_rng_code(template, dialect, vector_instruction_set, time_step, offsets, keys, dim, result_symbols):
parameters = [time_step] + [LoopOverCoordinate.get_loop_counter_symbol(i) + offsets[i]
for i in range(dim)] + [0] * (3-dim) + list(keys)
for i in range(dim)] + [0] * (3 - dim) + list(keys)
if dialect == 'cuda' or (dialect == 'c' and vector_instruction_set is None):
return template.format(parameters=', '.join(str(p) for p in parameters),
......@@ -67,7 +67,7 @@ class RNGBase(CustomCodeNode):
def get_code(self, dialect, vector_instruction_set):
template = _get_rng_template(self._name, self._data_type, self._num_vars)
return _get_rng_code(template, dialect, vector_instruction_set,
self._time_step, self._offsets, self.keys, self._dim, self.result_symbols)
self._time_step, self._offsets, self.keys, self._dim, self.result_symbols)
def __repr__(self):
return (", ".join(['{}'] * self._num_vars) + " <- {}RNG").format(*self.result_symbols, self._name.capitalize())
......
......@@ -76,12 +76,6 @@ def test_aesni_double():
arr = dh.gather_array('f')
assert np.logical_and(arr <= 1.0, arr >= 0).all()
#x = aesni_reference[:,:,0::2]
#y = aesni_reference[:,:,1::2]
#z = x ^ y << (53 - 32)
#double_reference = z * 2.**-53 + 2.**-54
#assert(np.allclose(arr, double_reference, rtol=0, atol=np.finfo(np.float64).eps))
def test_aesni_float():
dh = ps.create_data_handling((2, 2), default_ghost_layers=0, default_target="cpu")
......@@ -97,7 +91,4 @@ def test_aesni_float():
dh.run_kernel(kernel, time_step=124)
dh.all_to_cpu()
arr = dh.gather_array('f')
assert np.logical_and(arr <= 1.0, arr >= 0).all()
#float_reference = aesni_reference * 2.**-32 + 2.**-33
#assert(np.allclose(arr, float_reference, rtol=0, atol=np.finfo(np.float32).eps))
\ No newline at end of file
assert np.logical_and(arr <= 1.0, arr >= 0).all()
\ No newline at end of file
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment