diff --git a/pystencils/include/aesni_rand.h b/pystencils/include/aesni_rand.h index a480883e5acbc55a8d1bf7bd90719600907abe67..09327f27b8bc0b3cdd0b16cb8a64e3237b555797 100644 --- a/pystencils/include/aesni_rand.h +++ b/pystencils/include/aesni_rand.h @@ -2,7 +2,11 @@ #error AES-NI and SSE2 need to be enabled #endif -#include <x86intrin.h> +#include <emmintrin.h> // SSE2 +#include <wmmintrin.h> // AES +#ifdef __AVX512VL__ +#include <immintrin.h> // AVX* +#endif #include <cstdint> #define QUALIFIERS inline @@ -14,22 +18,22 @@ typedef std::uint64_t uint64; QUALIFIERS __m128i aesni1xm128i(const __m128i & in, const __m128i & k) { __m128i x = _mm_xor_si128(k, in); - x = _mm_aesenc_si128(x, k); - x = _mm_aesenc_si128(x, k); - x = _mm_aesenc_si128(x, k); - x = _mm_aesenc_si128(x, k); - x = _mm_aesenc_si128(x, k); - x = _mm_aesenc_si128(x, k); - x = _mm_aesenc_si128(x, k); - x = _mm_aesenc_si128(x, k); - x = _mm_aesenc_si128(x, k); - x = _mm_aesenclast_si128(x, k); + x = _mm_aesenc_si128(x, k); // 1 + x = _mm_aesenc_si128(x, k); // 2 + x = _mm_aesenc_si128(x, k); // 3 + x = _mm_aesenc_si128(x, k); // 4 + x = _mm_aesenc_si128(x, k); // 5 + x = _mm_aesenc_si128(x, k); // 6 + x = _mm_aesenc_si128(x, k); // 7 + x = _mm_aesenc_si128(x, k); // 8 + x = _mm_aesenc_si128(x, k); // 9 + x = _mm_aesenclast_si128(x, k); // 10 return x; } QUALIFIERS __m128 _my_cvtepu32_ps(const __m128i v) { -#ifdef __AVX512F__ +#ifdef __AVX512VL__ return _mm_cvtepu32_ps(v); #else __m128i v2 = _mm_srli_epi32(v, 1); @@ -40,13 +44,14 @@ QUALIFIERS __m128 _my_cvtepu32_ps(const __m128i v) #endif } -QUALIFIERS __m128d _my_cvtepu64_pd(const __m128i v) +QUALIFIERS __m128d _my_cvtepu64_pd(const __m128i x) { -#ifdef __AVX512F__ - return _mm_cvtepu64_pd(v); +#ifdef __AVX512VL__ + return _mm_cvtepu64_pd(x); #else - #warning need to implement _my_cvtepu64_pd - return (__m128d) v; + uint64 r[2]; + _mm_storeu_si128((__m128i*)r, x); + return _mm_set_pd((double)r[1], (double)r[0]); #endif } @@ -55,25 +60,27 @@ QUALIFIERS 
void aesni_double2(uint32 ctr0, uint32 ctr1, uint32 ctr2, uint32 ctr3 uint32 key0, uint32 key1, uint32 key2, uint32 key3, double & rnd1, double & rnd2) { + // pack input and call AES __m128i c128 = _mm_set_epi32(ctr3, ctr2, ctr1, ctr0); __m128i k128 = _mm_set_epi32(key3, key2, key1, key0); c128 = aesni1xm128i(c128, k128); - uint32 r[4]; - _mm_storeu_si128((__m128i*)&r[0], c128); - __m128i x = _mm_set_epi64x((uint64) r[2], (uint64) r[0]); - __m128i y = _mm_set_epi64x((uint64) r[3], (uint64) r[1]); + // convert 32 to 64 bit and put 0th and 2nd element into x, 1st and 3rd element into y + __m128i x = _mm_and_si128(c128, _mm_set_epi32(0, 0xffffffff, 0, 0xffffffff)); + __m128i y = _mm_and_si128(c128, _mm_set_epi32(0xffffffff, 0, 0xffffffff, 0)); + y = _mm_srli_si128(y, 4); - __m128i cnt = _mm_set_epi64x(53 - 32, 53 - 32); - y = _mm_sll_epi64(y, cnt); - __m128i z = _mm_xor_si128(x, y); + // calculate z = x ^ (y << (53 - 32)) + __m128i z = _mm_sll_epi64(y, _mm_set_epi64x(53 - 32, 53 - 32)); + z = _mm_xor_si128(x, z); + // convert uint64 to double __m128d rs = _my_cvtepu64_pd(z); - const __m128d tp53 = _mm_set_pd1(TWOPOW53_INV_DOUBLE); - const __m128d tp54 = _mm_set_pd1(TWOPOW53_INV_DOUBLE/2.0); - rs = _mm_mul_pd(rs, tp53); - rs = _mm_add_pd(rs, tp54); + // calculate rs * TWOPOW53_INV_DOUBLE + (TWOPOW53_INV_DOUBLE/2.0) + rs = _mm_mul_pd(rs, _mm_set_pd1(TWOPOW53_INV_DOUBLE)); + rs = _mm_add_pd(rs, _mm_set_pd1(TWOPOW53_INV_DOUBLE/2.0)); + // store result double rr[2]; _mm_storeu_pd(rr, rs); rnd1 = rr[0]; @@ -85,16 +92,18 @@ QUALIFIERS void aesni_float4(uint32 ctr0, uint32 ctr1, uint32 ctr2, uint32 ctr3, uint32 key0, uint32 key1, uint32 key2, uint32 key3, float & rnd1, float & rnd2, float & rnd3, float & rnd4) { + // pack input and call AES __m128i c128 = _mm_set_epi32(ctr3, ctr2, ctr1, ctr0); __m128i k128 = _mm_set_epi32(key3, key2, key1, key0); c128 = aesni1xm128i(c128, k128); + // convert uint32 to float __m128 rs = _my_cvtepu32_ps(c128); - const __m128 tp32 = 
_mm_set_ps1(TWOPOW32_INV_FLOAT); - const __m128 tp33 = _mm_set_ps1(TWOPOW32_INV_FLOAT/2.0f); - rs = _mm_mul_ps(rs, tp32); - rs = _mm_add_ps(rs, tp33); + // calculate rs * TWOPOW32_INV_FLOAT + (TWOPOW32_INV_FLOAT/2.0f) + rs = _mm_mul_ps(rs, _mm_set_ps1(TWOPOW32_INV_FLOAT)); + rs = _mm_add_ps(rs, _mm_set_ps1(TWOPOW32_INV_FLOAT/2.0f)); + // store result float r[4]; _mm_storeu_ps(r, rs); rnd1 = r[0]; diff --git a/pystencils/rng.py b/pystencils/rng.py index 81e33419fdc518613ac79cd15175fe923748e378..4341a0b3906cb9505b1ffbb4c6070e9ea09aee60 100644 --- a/pystencils/rng.py +++ b/pystencils/rng.py @@ -21,7 +21,7 @@ def _get_rng_template(name, data_type, num_vars): def _get_rng_code(template, dialect, vector_instruction_set, time_step, offsets, keys, dim, result_symbols): parameters = [time_step] + [LoopOverCoordinate.get_loop_counter_symbol(i) + offsets[i] - for i in range(dim)] + [0] * (3-dim) + list(keys) + for i in range(dim)] + [0] * (3 - dim) + list(keys) if dialect == 'cuda' or (dialect == 'c' and vector_instruction_set is None): return template.format(parameters=', '.join(str(p) for p in parameters), @@ -67,7 +67,7 @@ class RNGBase(CustomCodeNode): def get_code(self, dialect, vector_instruction_set): template = _get_rng_template(self._name, self._data_type, self._num_vars) return _get_rng_code(template, dialect, vector_instruction_set, - self._time_step, self._offsets, self.keys, self._dim, self.result_symbols) + self._time_step, self._offsets, self.keys, self._dim, self.result_symbols) def __repr__(self): return (", ".join(['{}'] * self._num_vars) + " <- {}RNG").format(*self.result_symbols, self._name.capitalize()) diff --git a/pystencils_tests/test_random.py b/pystencils_tests/test_random.py index f55d028d097c66dca0f6da47329270656f8b6ce1..473aa3d1215449f51e70aa3b09f2c605ed13c313 100644 --- a/pystencils_tests/test_random.py +++ b/pystencils_tests/test_random.py @@ -76,12 +76,6 @@ def test_aesni_double(): arr = dh.gather_array('f') assert np.logical_and(arr <= 1.0, 
arr >= 0).all() - #x = aesni_reference[:,:,0::2] - #y = aesni_reference[:,:,1::2] - #z = x ^ y << (53 - 32) - #double_reference = z * 2.**-53 + 2.**-54 - #assert(np.allclose(arr, double_reference, rtol=0, atol=np.finfo(np.float64).eps)) - def test_aesni_float(): dh = ps.create_data_handling((2, 2), default_ghost_layers=0, default_target="cpu") @@ -97,7 +91,4 @@ def test_aesni_float(): dh.run_kernel(kernel, time_step=124) dh.all_to_cpu() arr = dh.gather_array('f') - assert np.logical_and(arr <= 1.0, arr >= 0).all() - - #float_reference = aesni_reference * 2.**-32 + 2.**-33 - #assert(np.allclose(arr, float_reference, rtol=0, atol=np.finfo(np.float32).eps)) \ No newline at end of file + assert np.logical_and(arr <= 1.0, arr >= 0).all() \ No newline at end of file