diff --git a/pystencils/backends/cbackend.py b/pystencils/backends/cbackend.py index 654b92b82f8c5e225f26abd5a5c14ce0b113a43e..e489f3cc2e55b32418351bc24e1ebcbc5484f1f2 100644 --- a/pystencils/backends/cbackend.py +++ b/pystencils/backends/cbackend.py @@ -11,7 +11,7 @@ from pystencils.astnodes import KernelFunction, Node from pystencils.cpu.vectorization import vec_all, vec_any from pystencils.data_types import ( PointerType, VectorType, address_of, cast_func, create_type, get_type_of_expression, - reinterpret_cast_func, vector_memory_access) + reinterpret_cast_func, vector_memory_access, BasicType, TypedSymbol) from pystencils.fast_approximation import fast_division, fast_inv_sqrt, fast_sqrt from pystencils.integer_functions import ( bit_shift_left, bit_shift_right, bitwise_and, bitwise_or, bitwise_xor, @@ -134,7 +134,7 @@ class CustomCodeNode(Node): self._symbols_defined = set(symbols_defined) self.headers = [] - def get_code(self, dialect, vector_instruction_set): + def get_code(self, dialect, vector_instruction_set, print_arg): return self._code @property @@ -297,7 +297,7 @@ class CBackend: return "continue;" def _print_CustomCodeNode(self, node): - return node.get_code(self._dialect, self._vector_instruction_set) + return node.get_code(self._dialect, self._vector_instruction_set, print_arg=self.sympy_printer._print) def _print_SourceCodeComment(self, node): return f"/* {node.text } */" @@ -548,12 +548,16 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): if type(data_type) is VectorType: if isinstance(arg, sp.Tuple): is_boolean = get_type_of_expression(arg[0]) == create_type("bool") + is_integer = get_type_of_expression(arg[0]) == create_type("int") printed_args = [self._print(a) for a in arg] - instruction = 'makeVecBool' if is_boolean else 'makeVec' + instruction = 'makeVecBool' if is_boolean else 'makeVecInt' if is_integer else 'makeVec' return self.instruction_set[instruction].format(*printed_args) else: is_boolean = get_type_of_expression(arg) == create_type("bool") - instruction = 'makeVecConstBool' if is_boolean else 'makeVecConst' + is_integer = get_type_of_expression(arg) == create_type("int") or \ + (isinstance(arg, TypedSymbol) and arg.dtype.is_int()) + instruction = 'makeVecConstBool' if is_boolean else \ + 'makeVecConstInt' if is_integer else 'makeVecConst' return self.instruction_set[instruction].format(self._print(arg)) elif expr.func == fast_division: result = self._scalarFallback('_print_Function', expr) @@ -609,12 +613,27 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): return result def _print_Add(self, expr, order=None): - result = self._scalarFallback('_print_Add', expr) + try: + result = self._scalarFallback('_print_Add', expr) + except Exception: + result = None if result: return result + args = expr.args + + # special treatment for all-integer args, for loop index arithmetic until we have proper int vectorization + suffix = "" + if all([(type(e) is cast_func and str(e.dtype) == self.instruction_set['int']) or isinstance(e, sp.Integer) + or (type(e) is TypedSymbol and isinstance(e.dtype, BasicType) and e.dtype.is_int()) for e in args]): + dtype = set([e.dtype for e in args if type(e) is cast_func]) + assert len(dtype) == 1 + dtype = dtype.pop() + args = [cast_func(e, dtype) if (isinstance(e, sp.Integer) or isinstance(e, TypedSymbol)) else e + for e in args] + suffix = "int" summands = [] - for term in expr.args: + for term in args: if term.func == sp.Mul: sign, t = self._print_Mul(term, inside_add=True) else: @@ -630,7 +649,7 @@ class 
VectorizedCustomSympyPrinter(CustomSympyPrinter): assert len(summands) >= 2 processed = summands[0].term for summand in summands[1:]: - func = self.instruction_set['-'] if summand.sign == -1 else self.instruction_set['+'] + func = self.instruction_set['-' + suffix] if summand.sign == -1 else self.instruction_set['+' + suffix] processed = func.format(processed, summand.term) return processed diff --git a/pystencils/backends/simd_instruction_sets.py b/pystencils/backends/simd_instruction_sets.py index c6290aa45b2ac0451c671d11ace2fe1d40d86264..44284d6bbbfd3fec99306408c22c5d34493eb508 100644 --- a/pystencils/backends/simd_instruction_sets.py +++ b/pystencils/backends/simd_instruction_sets.py @@ -18,7 +18,7 @@ def get_supported_instruction_sets(): result = [] required_sse_flags = {'sse', 'sse2', 'ssse3', 'sse4_1', 'sse4_2'} - required_avx_flags = {'avx'} + required_avx_flags = {'avx', 'avx2'} required_avx512_flags = {'avx512f'} required_neon_flags = {'neon'} flags = set(get_cpu_info()['flags']) diff --git a/pystencils/backends/x86_instruction_sets.py b/pystencils/backends/x86_instruction_sets.py index 349c190e252f89cba9f04c8f6b338933dfa6b8e1..57164d6789619903ddf133f2dc78848f8fccb112 100644 --- a/pystencils/backends/x86_instruction_sets.py +++ b/pystencils/backends/x86_instruction_sets.py @@ -1,7 +1,7 @@ def get_argument_string(intrinsic_id, width, function_shortcut): - if intrinsic_id == 'makeVecConst': + if intrinsic_id == 'makeVecConst' or intrinsic_id == 'makeVecConstInt': arg_string = f"({','.join(['{0}'] * width)})" - elif intrinsic_id == 'makeVec': + elif intrinsic_id == 'makeVec' or intrinsic_id == 'makeVecInt': params = ["{" + str(i) + "}" for i in reversed(range(width))] arg_string = f"({','.join(params)})" elif intrinsic_id == 'makeVecBool': @@ -49,6 +49,8 @@ def get_vector_instruction_set_x86(data_type='double', instruction_set='avx'): 'makeVec': 'set[]', 'makeVecBool': 'set[]', 'makeVecConstBool': 'set[]', + 'makeVecInt': 'set[]', + 'makeVecConstInt': 'set[]', 'loadU': 'loadu[0]', 'loadA': 'load[0]', @@ -86,6 +88,7 @@ def get_vector_instruction_set_x86(data_type='double', instruction_set='avx'): suffix = { 'double': 'pd', 'float': 'ps', + 'int': 'epi32' } prefix = { 'sse': '_mm', @@ -96,22 +99,30 @@ def get_vector_instruction_set_x86(data_type='double', instruction_set='avx'): width = { ("double", "sse"): 2, ("float", "sse"): 4, + ("int", "sse"): 4, ("double", "avx"): 4, ("float", "avx"): 8, + ("int", "avx"): 8, ("double", "avx512"): 8, ("float", "avx512"): 16, + ("int", "avx512"): 16, } result = { 'width': width[(data_type, instruction_set)], + 'intwidth': width[('int', instruction_set)] } pre = prefix[instruction_set] - suf = suffix[data_type] for intrinsic_id, function_shortcut in base_names.items(): function_shortcut = function_shortcut.strip() name = function_shortcut[:function_shortcut.index('[')] - arg_string = get_argument_string(intrinsic_id, result['width'], function_shortcut) + if 'Int' in intrinsic_id: + suf = suffix['int'] + arg_string = get_argument_string(intrinsic_id, result['intwidth'], function_shortcut) + else: + suf = suffix[data_type] + arg_string = get_argument_string(intrinsic_id, result['width'], function_shortcut) mask_suffix = '_mask' if instruction_set == 'avx512' and intrinsic_id in comparisons.keys() else '' result[intrinsic_id] = pre + "_" + name + "_" + suf + mask_suffix + arg_string @@ -151,4 +162,6 @@ def get_vector_instruction_set_x86(data_type='double', instruction_set='avx'): if instruction_set == 'avx' and data_type == 'float': result['rsqrt'] = 
f"{pre}_rsqrt_{suf}({{0}})" + result['+int'] = f"{pre}_add_{suffix['int']}({{0}}, {{1}})" + return result diff --git a/pystencils/cpu/vectorization.py b/pystencils/cpu/vectorization.py index cf51456569b47e5fea1dfd7698f09e2416fe70b1..6ecc87284ad362264fb5dcb89298dc492d6a12d5 100644 --- a/pystencils/cpu/vectorization.py +++ b/pystencils/cpu/vectorization.py @@ -73,11 +73,30 @@ def vectorize(kernel_ast: ast.KernelFunction, instruction_set: str = 'avx', vector_width = vector_is['width'] kernel_ast.instruction_set = vector_is + vectorize_rng(kernel_ast, vector_width) vectorize_inner_loops_and_adapt_load_stores(kernel_ast, vector_width, assume_aligned, nontemporal, assume_sufficient_line_padding) insert_vector_casts(kernel_ast) +def vectorize_rng(kernel_ast, vector_width): + """Replace scalar result symbols on RNG nodes with vectorial ones""" + from pystencils.rng import RNGBase + subst = {} + + def visit_node(node): + for arg in node.args: + if isinstance(arg, RNGBase): + new_result_symbols = [TypedSymbol(s.name, VectorType(s.dtype, width=vector_width)) + for s in arg.result_symbols] + subst.update({s[0]: s[1] for s in zip(arg.result_symbols, new_result_symbols)}) + arg._symbols_defined = set(new_result_symbols) + else: + visit_node(arg) + visit_node(kernel_ast) + fast_subs(kernel_ast.body, subst, skip=lambda e: isinstance(e, RNGBase)) + + def vectorize_inner_loops_and_adapt_load_stores(ast_node, vector_width, assume_aligned, nontemporal_fields, assume_sufficient_line_padding): """Goes over all innermost loops, changes increment to vector width and replaces field accesses by vector type.""" @@ -129,8 +148,10 @@ def vectorize_inner_loops_and_adapt_load_stores(ast_node, vector_width, assume_a loop_node.step = vector_width loop_node.subs(substitutions) - vector_loop_counter = cast_func(tuple(loop_counter_symbol + i for i in range(vector_width)), - VectorType(loop_counter_symbol.dtype, vector_width)) + vector_int_width = ast_node.instruction_set['intwidth'] + vector_loop_counter = cast_func((loop_counter_symbol,) * vector_int_width, + VectorType(loop_counter_symbol.dtype, vector_int_width)) + \ + cast_func(tuple(range(vector_int_width)), VectorType(loop_counter_symbol.dtype, vector_int_width)) fast_subs(loop_node, {loop_counter_symbol: vector_loop_counter}, skip=lambda e: isinstance(e, ast.ResolvedFieldAccess) or isinstance(e, vector_memory_access)) @@ -178,7 +199,10 @@ def insert_vector_casts(ast_node): return expr elif expr.func is sp.Abs and 'abs' not in ast_node.instruction_set: new_arg = visit_expr(expr.args[0]) - pw = sp.Piecewise((-1 * new_arg, new_arg < 0), (new_arg, True)) + base_type = get_type_of_expression(expr.args[0]).base_type if type(expr.args[0]) is vector_memory_access \ + else get_type_of_expression(expr.args[0]) + pw = sp.Piecewise((base_type.numpy_dtype.type(-1) * new_arg, new_arg < base_type.numpy_dtype.type(0)), + (new_arg, True)) return visit_expr(pw) elif expr.func in handled_functions or isinstance(expr, sp.Rel) or isinstance(expr, BooleanFunction): new_args = [visit_expr(a) for a in expr.args] diff --git a/pystencils/data_types.py b/pystencils/data_types.py index 365ef8aa787733d61b4e8b5367eda1fe6daac51d..7300f01ea47f6693a2915b2cb063f907e46d1c32 100644 --- a/pystencils/data_types.py +++ b/pystencils/data_types.py @@ -119,7 +119,7 @@ class cast_func(sp.Function): # rhs = cast_func(0, 'int') # print( sp.Ne(lhs, rhs) ) # would give true if all cast_funcs are booleans # -> thus a separate class boolean_cast_func is introduced - if isinstance(expr, Boolean): + if 
isinstance(expr, Boolean) and (not isinstance(expr, TypedSymbol) or expr.dtype == BasicType(bool)): cls = boolean_cast_func return sp.Function.__new__(cls, expr, dtype, *other_args, **kwargs) @@ -697,7 +697,7 @@ class VectorType(Type): if self.instruction_set is None: return "%s[%d]" % (self.base_type, self.width) else: - if self.base_type == create_type("int64"): + if self.base_type == create_type("int64") or self.base_type == create_type("int32"): return self.instruction_set['int'] elif self.base_type == create_type("float64"): return self.instruction_set['double'] diff --git a/pystencils/field.py b/pystencils/field.py index 78c420c8c3b9174ea0ec619c5a805659223caf6e..fdc587e60523b7547e15859ed4b0ab17643f307e 100644 --- a/pystencils/field.py +++ b/pystencils/field.py @@ -958,8 +958,6 @@ def create_numpy_array_with_layout(shape, layout, alignment=False, byte_offset=0 if not alignment: res = np.empty(shape, order='c', **kwargs) else: - if alignment is True: - alignment = 8 * 4 res = aligned_empty(shape, alignment, byte_offset=byte_offset, **kwargs) for a, b in reversed(swaps): diff --git a/pystencils/include/aesni_rand.h b/pystencils/include/aesni_rand.h index e91101e4073901fb15081e1e0463b2f5116e40a4..4206e37f634e01586223e1412ce7836e2499f65f 100644 --- a/pystencils/include/aesni_rand.h +++ b/pystencils/include/aesni_rand.h @@ -1,6 +1,6 @@ #include <emmintrin.h> // SSE2 #include <wmmintrin.h> // AES -#ifdef __AVX512VL__ +#ifdef __AVX__ #include <immintrin.h> // AVX* #else #include <smmintrin.h> // SSE4 @@ -9,56 +9,108 @@ #endif #endif #include <cstdint> +#include <array> +#include <map> #define QUALIFIERS inline #define TWOPOW53_INV_DOUBLE (1.1102230246251565e-16) #define TWOPOW32_INV_FLOAT (2.3283064e-10f) +#include "myintrin.h" + typedef std::uint32_t uint32; typedef std::uint64_t uint64; -QUALIFIERS __m128i aesni1xm128i(const __m128i & in, const __m128i & k) { - __m128i x = _mm_xor_si128(k, in); - x = _mm_aesenc_si128(x, k); // 1 - x = _mm_aesenc_si128(x, k); // 2 - x = _mm_aesenc_si128(x, k); // 3 - x = _mm_aesenc_si128(x, k); // 4 - x = _mm_aesenc_si128(x, k); // 5 - x = _mm_aesenc_si128(x, k); // 6 - x = _mm_aesenc_si128(x, k); // 7 - x = _mm_aesenc_si128(x, k); // 8 - x = _mm_aesenc_si128(x, k); // 9 - x = _mm_aesenclast_si128(x, k); // 10 - return x; +#if defined(__AES__) || defined(_MSC_VER) +QUALIFIERS __m128i aesni_keygen_assist(__m128i temp1, __m128i temp2) { + __m128i temp3; + temp2 = _mm_shuffle_epi32(temp2 ,0xff); + temp3 = _mm_slli_si128(temp1, 0x4); + temp1 = _mm_xor_si128(temp1, temp3); + temp3 = _mm_slli_si128(temp3, 0x4); + temp1 = _mm_xor_si128(temp1, temp3); + temp3 = _mm_slli_si128(temp3, 0x4); + temp1 = _mm_xor_si128(temp1, temp3); + temp1 = _mm_xor_si128(temp1, temp2); + return temp1; } -QUALIFIERS __m128 _my_cvtepu32_ps(const __m128i v) -{ -#ifdef __AVX512VL__ - return _mm_cvtepu32_ps(v); -#else - __m128i v2 = _mm_srli_epi32(v, 1); - __m128i v1 = _mm_and_si128(v, _mm_set1_epi32(1)); - __m128 v2f = _mm_cvtepi32_ps(v2); - __m128 v1f = _mm_cvtepi32_ps(v1); - return _mm_add_ps(_mm_add_ps(v2f, v2f), v1f); -#endif +QUALIFIERS std::array<__m128i,11> aesni_keygen(__m128i k) { + std::array<__m128i,11> rk; + __m128i tmp; + + rk[0] = k; + + tmp = _mm_aeskeygenassist_si128(k, 0x1); + k = aesni_keygen_assist(k, tmp); + rk[1] = k; + + tmp = _mm_aeskeygenassist_si128(k, 0x2); + k = aesni_keygen_assist(k, tmp); + rk[2] = k; + + tmp = _mm_aeskeygenassist_si128(k, 0x4); + k = aesni_keygen_assist(k, tmp); + rk[3] = k; + + tmp = _mm_aeskeygenassist_si128(k, 0x8); + k = 
aesni_keygen_assist(k, tmp); + rk[4] = k; + + tmp = _mm_aeskeygenassist_si128(k, 0x10); + k = aesni_keygen_assist(k, tmp); + rk[5] = k; + + tmp = _mm_aeskeygenassist_si128(k, 0x20); + k = aesni_keygen_assist(k, tmp); + rk[6] = k; + + tmp = _mm_aeskeygenassist_si128(k, 0x40); + k = aesni_keygen_assist(k, tmp); + rk[7] = k; + + tmp = _mm_aeskeygenassist_si128(k, 0x80); + k = aesni_keygen_assist(k, tmp); + rk[8] = k; + + tmp = _mm_aeskeygenassist_si128(k, 0x1b); + k = aesni_keygen_assist(k, tmp); + rk[9] = k; + + tmp = _mm_aeskeygenassist_si128(k, 0x36); + k = aesni_keygen_assist(k, tmp); + rk[10] = k; + + return rk; } -#if !defined(__AVX512VL__) && defined(__GNUC__) && __GNUC__ >= 5 -__attribute__((optimize("no-associative-math"))) -#endif -QUALIFIERS __m128d _my_cvtepu64_pd(const __m128i x) -{ -#ifdef __AVX512VL__ - return _mm_cvtepu64_pd(x); -#else - __m128i xH = _mm_srli_epi64(x, 32); - xH = _mm_or_si128(xH, _mm_castpd_si128(_mm_set1_pd(19342813113834066795298816.))); // 2^84 - __m128i xL = _mm_blend_epi16(x, _mm_castpd_si128(_mm_set1_pd(0x0010000000000000)), 0xcc); // 2^52 - __m128d f = _mm_sub_pd(_mm_castsi128_pd(xH), _mm_set1_pd(19342813118337666422669312.)); // 2^84 + 2^52 - return _mm_add_pd(f, _mm_castsi128_pd(xL)); -#endif +QUALIFIERS const std::array<__m128i,11> & aesni_roundkeys(const __m128i & k128) { + std::array<uint32,4> a; + _mm_storeu_si128((__m128i*) a.data(), k128); + + static std::map<std::array<uint32,4>, std::array<__m128i,11>> roundkeys; + + if(roundkeys.find(a) == roundkeys.end()) { + auto rk = aesni_keygen(k128); + roundkeys[a] = rk; + } + return roundkeys[a]; +} + +QUALIFIERS __m128i aesni1xm128i(const __m128i & in, const __m128i & k0) { + auto k = aesni_roundkeys(k0); + __m128i x = _mm_xor_si128(k[0], in); + x = _mm_aesenc_si128(x, k[1]); + x = _mm_aesenc_si128(x, k[2]); + x = _mm_aesenc_si128(x, k[3]); + x = _mm_aesenc_si128(x, k[4]); + x = _mm_aesenc_si128(x, k[5]); + x = _mm_aesenc_si128(x, k[6]); + x = _mm_aesenc_si128(x, k[7]); + x = _mm_aesenc_si128(x, k[8]); + x = _mm_aesenc_si128(x, k[9]); + x = _mm_aesenclast_si128(x, k[10]); + return x; } @@ -126,3 +178,566 @@ QUALIFIERS void aesni_float4(uint32 ctr0, uint32 ctr1, uint32 ctr2, uint32 ctr3, rnd4 = r[3]; } + +template<bool high> +QUALIFIERS __m128d _uniform_double_hq(__m128i x, __m128i y) +{ + // convert 32 to 64 bit + if (high) + { + x = _mm_unpackhi_epi32(x, _mm_setzero_si128()); + y = _mm_unpackhi_epi32(y, _mm_setzero_si128()); + } + else + { + x = _mm_unpacklo_epi32(x, _mm_setzero_si128()); + y = _mm_unpacklo_epi32(y, _mm_setzero_si128()); + } + + // calculate z = x ^ y << (53 - 32)) + __m128i z = _mm_sll_epi64(y, _mm_set1_epi64x(53 - 32)); + z = _mm_xor_si128(x, z); + + // convert uint64 to double + __m128d rs = _my_cvtepu64_pd(z); + // calculate rs * TWOPOW53_INV_DOUBLE + (TWOPOW53_INV_DOUBLE/2.0) +#ifdef __FMA__ + rs = _mm_fmadd_pd(rs, _mm_set1_pd(TWOPOW53_INV_DOUBLE), _mm_set1_pd(TWOPOW53_INV_DOUBLE/2.0)); +#else + rs = _mm_mul_pd(rs, _mm_set1_pd(TWOPOW53_INV_DOUBLE)); + rs = _mm_add_pd(rs, _mm_set1_pd(TWOPOW53_INV_DOUBLE/2.0)); +#endif + + return rs; +} + + +QUALIFIERS void aesni_float4(__m128i ctr0, __m128i ctr1, __m128i ctr2, __m128i ctr3, + uint32 key0, uint32 key1, uint32 key2, uint32 key3, + __m128 & rnd1, __m128 & rnd2, __m128 & rnd3, __m128 & rnd4) +{ + // pack input and call AES + __m128i k128 = _mm_set_epi32(key3, key2, key1, key0); + __m128i ctr[4] = {ctr0, ctr1, ctr2, ctr3}; + _MY_TRANSPOSE4_EPI32(ctr[0], ctr[1], ctr[2], ctr[3]); + for (int i = 0; i < 4; ++i) + { + ctr[i] = 
aesni1xm128i(ctr[i], k128); + } + _MY_TRANSPOSE4_EPI32(ctr[0], ctr[1], ctr[2], ctr[3]); + + // convert uint32 to float + rnd1 = _my_cvtepu32_ps(ctr[0]); + rnd2 = _my_cvtepu32_ps(ctr[1]); + rnd3 = _my_cvtepu32_ps(ctr[2]); + rnd4 = _my_cvtepu32_ps(ctr[3]); + // calculate rnd * TWOPOW32_INV_FLOAT + (TWOPOW32_INV_FLOAT/2.0f) +#ifdef __FMA__ + rnd1 = _mm_fmadd_ps(rnd1, _mm_set1_ps(TWOPOW32_INV_FLOAT), _mm_set1_ps(TWOPOW32_INV_FLOAT/2.0)); + rnd2 = _mm_fmadd_ps(rnd2, _mm_set1_ps(TWOPOW32_INV_FLOAT), _mm_set1_ps(TWOPOW32_INV_FLOAT/2.0)); + rnd3 = _mm_fmadd_ps(rnd3, _mm_set1_ps(TWOPOW32_INV_FLOAT), _mm_set1_ps(TWOPOW32_INV_FLOAT/2.0)); + rnd4 = _mm_fmadd_ps(rnd4, _mm_set1_ps(TWOPOW32_INV_FLOAT), _mm_set1_ps(TWOPOW32_INV_FLOAT/2.0)); +#else + rnd1 = _mm_mul_ps(rnd1, _mm_set1_ps(TWOPOW32_INV_FLOAT)); + rnd1 = _mm_add_ps(rnd1, _mm_set1_ps(TWOPOW32_INV_FLOAT/2.0f)); + rnd2 = _mm_mul_ps(rnd2, _mm_set1_ps(TWOPOW32_INV_FLOAT)); + rnd2 = _mm_add_ps(rnd2, _mm_set1_ps(TWOPOW32_INV_FLOAT/2.0f)); + rnd3 = _mm_mul_ps(rnd3, _mm_set1_ps(TWOPOW32_INV_FLOAT)); + rnd3 = _mm_add_ps(rnd3, _mm_set1_ps(TWOPOW32_INV_FLOAT/2.0f)); + rnd4 = _mm_mul_ps(rnd4, _mm_set1_ps(TWOPOW32_INV_FLOAT)); + rnd4 = _mm_add_ps(rnd4, _mm_set1_ps(TWOPOW32_INV_FLOAT/2.0f)); +#endif +} + + +QUALIFIERS void aesni_double2(__m128i ctr0, __m128i ctr1, __m128i ctr2, __m128i ctr3, + uint32 key0, uint32 key1, uint32 key2, uint32 key3, + __m128d & rnd1lo, __m128d & rnd1hi, __m128d & rnd2lo, __m128d & rnd2hi) +{ + // pack input and call AES + __m128i k128 = _mm_set_epi32(key3, key2, key1, key0); + __m128i ctr[4] = {ctr0, ctr1, ctr2, ctr3}; + _MY_TRANSPOSE4_EPI32(ctr[0], ctr[1], ctr[2], ctr[3]); + for (int i = 0; i < 4; ++i) + { + ctr[i] = aesni1xm128i(ctr[i], k128); + } + _MY_TRANSPOSE4_EPI32(ctr[0], ctr[1], ctr[2], ctr[3]); + + rnd1lo = _uniform_double_hq<false>(ctr[0], ctr[1]); + rnd1hi = _uniform_double_hq<true>(ctr[0], ctr[1]); + rnd2lo = _uniform_double_hq<false>(ctr[2], ctr[3]); + rnd2hi = _uniform_double_hq<true>(ctr[2], ctr[3]); +} + +QUALIFIERS void aesni_float4(uint32 ctr0, __m128i ctr1, uint32 ctr2, uint32 ctr3, + uint32 key0, uint32 key1, uint32 key2, uint32 key3, + __m128 & rnd1, __m128 & rnd2, __m128 & rnd3, __m128 & rnd4) +{ + __m128i ctr0v = _mm_set1_epi32(ctr0); + __m128i ctr2v = _mm_set1_epi32(ctr2); + __m128i ctr3v = _mm_set1_epi32(ctr3); + + aesni_float4(ctr0v, ctr1, ctr2v, ctr3v, key0, key1, key2, key3, rnd1, rnd2, rnd3, rnd4); +} + +QUALIFIERS void aesni_double2(uint32 ctr0, __m128i ctr1, uint32 ctr2, uint32 ctr3, + uint32 key0, uint32 key1, uint32 key2, uint32 key3, + __m128d & rnd1lo, __m128d & rnd1hi, __m128d & rnd2lo, __m128d & rnd2hi) +{ + __m128i ctr0v = _mm_set1_epi32(ctr0); + __m128i ctr2v = _mm_set1_epi32(ctr2); + __m128i ctr3v = _mm_set1_epi32(ctr3); + + aesni_double2(ctr0v, ctr1, ctr2v, ctr3v, key0, key1, key2, key3, rnd1lo, rnd1hi, rnd2lo, rnd2hi); +} + +QUALIFIERS void aesni_double2(uint32 ctr0, __m128i ctr1, uint32 ctr2, uint32 ctr3, + uint32 key0, uint32 key1, uint32 key2, uint32 key3, + __m128d & rnd1, __m128d & rnd2) +{ + __m128i ctr0v = _mm_set1_epi32(ctr0); + __m128i ctr2v = _mm_set1_epi32(ctr2); + __m128i ctr3v = _mm_set1_epi32(ctr3); + + __m128d ignore; + aesni_double2(ctr0v, ctr1, ctr2v, ctr3v, key0, key1, key2, key3, rnd1, ignore, rnd2, ignore); +} +#endif + + +#ifdef __AVX2__ +QUALIFIERS const std::array<__m256i,11> & aesni_roundkeys(const __m256i & k256) { + std::array<uint32,8> a; + _mm256_storeu_si256((__m256i*) a.data(), k256); + + static std::map<std::array<uint32,8>, std::array<__m256i,11>> 
roundkeys; + + if(roundkeys.find(a) == roundkeys.end()) { + auto rk1 = aesni_keygen(_mm256_extractf128_si256(k256, 0)); + auto rk2 = aesni_keygen(_mm256_extractf128_si256(k256, 1)); + for(int i = 0; i < 11; ++i) { + roundkeys[a][i] = _my256_set_m128i(rk2[i], rk1[i]); + } + } + return roundkeys[a]; +} + +QUALIFIERS __m256i aesni1xm128i(const __m256i & in, const __m256i & k0) { +#if defined(__VAES__) && defined(__AVX512VL__) + auto k = aesni_roundkeys(k0); + __m256i x = _mm256_xor_si256(k[0], in); + x = _mm256_aesenc_epi128(x, k[1]); + x = _mm256_aesenc_epi128(x, k[2]); + x = _mm256_aesenc_epi128(x, k[3]); + x = _mm256_aesenc_epi128(x, k[4]); + x = _mm256_aesenc_epi128(x, k[5]); + x = _mm256_aesenc_epi128(x, k[6]); + x = _mm256_aesenc_epi128(x, k[7]); + x = _mm256_aesenc_epi128(x, k[8]); + x = _mm256_aesenc_epi128(x, k[9]); + x = _mm256_aesenclast_epi128(x, k[10]); +#else + __m128i a = aesni1xm128i(_mm256_extractf128_si256(in, 0), _mm256_extractf128_si256(k0, 0)); + __m128i b = aesni1xm128i(_mm256_extractf128_si256(in, 1), _mm256_extractf128_si256(k0, 1)); + __m256i x = _my256_set_m128i(b, a); +#endif + return x; +} + +template<bool high> +QUALIFIERS __m256d _uniform_double_hq(__m256i x, __m256i y) +{ + // convert 32 to 64 bit + if (high) + { + x = _mm256_cvtepu32_epi64(_mm256_extracti128_si256(x, 1)); + y = _mm256_cvtepu32_epi64(_mm256_extracti128_si256(y, 1)); + } + else + { + x = _mm256_cvtepu32_epi64(_mm256_extracti128_si256(x, 0)); + y = _mm256_cvtepu32_epi64(_mm256_extracti128_si256(y, 0)); + } + + // calculate z = x ^ y << (53 - 32)) + __m256i z = _mm256_sll_epi64(y, _mm_set1_epi64x(53 - 32)); + z = _mm256_xor_si256(x, z); + + // convert uint64 to double + __m256d rs = _my256_cvtepu64_pd(z); + // calculate rs * TWOPOW53_INV_DOUBLE + (TWOPOW53_INV_DOUBLE/2.0) +#ifdef __FMA__ + rs = _mm256_fmadd_pd(rs, _mm256_set1_pd(TWOPOW53_INV_DOUBLE), _mm256_set1_pd(TWOPOW53_INV_DOUBLE/2.0)); +#else + rs = _mm256_mul_pd(rs, _mm256_set1_pd(TWOPOW53_INV_DOUBLE)); + rs = _mm256_add_pd(rs, _mm256_set1_pd(TWOPOW53_INV_DOUBLE/2.0)); +#endif + + return rs; +} + + +QUALIFIERS void aesni_float4(__m256i ctr0, __m256i ctr1, __m256i ctr2, __m256i ctr3, + uint32 key0, uint32 key1, uint32 key2, uint32 key3, + __m256 & rnd1, __m256 & rnd2, __m256 & rnd3, __m256 & rnd4) +{ + // pack input and call AES + __m256i k256 = _mm256_set_epi32(key3, key2, key1, key0, key3, key2, key1, key0); + __m256i ctr[4] = {ctr0, ctr1, ctr2, ctr3}; + __m128i a[4], b[4]; + for (int i = 0; i < 4; ++i) + { + a[i] = _mm256_extractf128_si256(ctr[i], 0); + b[i] = _mm256_extractf128_si256(ctr[i], 1); + } + _MY_TRANSPOSE4_EPI32(a[0], a[1], a[2], a[3]); + _MY_TRANSPOSE4_EPI32(b[0], b[1], b[2], b[3]); + for (int i = 0; i < 4; ++i) + { + ctr[i] = _my256_set_m128i(b[i], a[i]); + } + for (int i = 0; i < 4; ++i) + { + ctr[i] = aesni1xm128i(ctr[i], k256); + } + for (int i = 0; i < 4; ++i) + { + a[i] = _mm256_extractf128_si256(ctr[i], 0); + b[i] = _mm256_extractf128_si256(ctr[i], 1); + } + _MY_TRANSPOSE4_EPI32(a[0], a[1], a[2], a[3]); + _MY_TRANSPOSE4_EPI32(b[0], b[1], b[2], b[3]); + for (int i = 0; i < 4; ++i) + { + ctr[i] = _my256_set_m128i(b[i], a[i]); + } + + // convert uint32 to float + rnd1 = _my256_cvtepu32_ps(ctr[0]); + rnd2 = _my256_cvtepu32_ps(ctr[1]); + rnd3 = _my256_cvtepu32_ps(ctr[2]); + rnd4 = _my256_cvtepu32_ps(ctr[3]); + // calculate rnd * TWOPOW32_INV_FLOAT + (TWOPOW32_INV_FLOAT/2.0f) +#ifdef __FMA__ + rnd1 = _mm256_fmadd_ps(rnd1, _mm256_set1_ps(TWOPOW32_INV_FLOAT), _mm256_set1_ps(TWOPOW32_INV_FLOAT/2.0)); + rnd2 = 
_mm256_fmadd_ps(rnd2, _mm256_set1_ps(TWOPOW32_INV_FLOAT), _mm256_set1_ps(TWOPOW32_INV_FLOAT/2.0)); + rnd3 = _mm256_fmadd_ps(rnd3, _mm256_set1_ps(TWOPOW32_INV_FLOAT), _mm256_set1_ps(TWOPOW32_INV_FLOAT/2.0)); + rnd4 = _mm256_fmadd_ps(rnd4, _mm256_set1_ps(TWOPOW32_INV_FLOAT), _mm256_set1_ps(TWOPOW32_INV_FLOAT/2.0)); +#else + rnd1 = _mm256_mul_ps(rnd1, _mm256_set1_ps(TWOPOW32_INV_FLOAT)); + rnd1 = _mm256_add_ps(rnd1, _mm256_set1_ps(TWOPOW32_INV_FLOAT/2.0f)); + rnd2 = _mm256_mul_ps(rnd2, _mm256_set1_ps(TWOPOW32_INV_FLOAT)); + rnd2 = _mm256_add_ps(rnd2, _mm256_set1_ps(TWOPOW32_INV_FLOAT/2.0f)); + rnd3 = _mm256_mul_ps(rnd3, _mm256_set1_ps(TWOPOW32_INV_FLOAT)); + rnd3 = _mm256_add_ps(rnd3, _mm256_set1_ps(TWOPOW32_INV_FLOAT/2.0f)); + rnd4 = _mm256_mul_ps(rnd4, _mm256_set1_ps(TWOPOW32_INV_FLOAT)); + rnd4 = _mm256_add_ps(rnd4, _mm256_set1_ps(TWOPOW32_INV_FLOAT/2.0f)); +#endif +} + + +QUALIFIERS void aesni_double2(__m256i ctr0, __m256i ctr1, __m256i ctr2, __m256i ctr3, + uint32 key0, uint32 key1, uint32 key2, uint32 key3, + __m256d & rnd1lo, __m256d & rnd1hi, __m256d & rnd2lo, __m256d & rnd2hi) +{ + // pack input and call AES + __m256i k256 = _mm256_set_epi32(key3, key2, key1, key0, key3, key2, key1, key0); + __m256i ctr[4] = {ctr0, ctr1, ctr2, ctr3}; + __m128i a[4], b[4]; + for (int i = 0; i < 4; ++i) + { + a[i] = _mm256_extractf128_si256(ctr[i], 0); + b[i] = _mm256_extractf128_si256(ctr[i], 1); + } + _MY_TRANSPOSE4_EPI32(a[0], a[1], a[2], a[3]); + _MY_TRANSPOSE4_EPI32(b[0], b[1], b[2], b[3]); + for (int i = 0; i < 4; ++i) + { + ctr[i] = _my256_set_m128i(b[i], a[i]); + } + for (int i = 0; i < 4; ++i) + { + ctr[i] = aesni1xm128i(ctr[i], k256); + } + for (int i = 0; i < 4; ++i) + { + a[i] = _mm256_extractf128_si256(ctr[i], 0); + b[i] = _mm256_extractf128_si256(ctr[i], 1); + } + _MY_TRANSPOSE4_EPI32(a[0], a[1], a[2], a[3]); + _MY_TRANSPOSE4_EPI32(b[0], b[1], b[2], b[3]); + for (int i = 0; i < 4; ++i) + { + ctr[i] = _my256_set_m128i(b[i], a[i]); + } + + rnd1lo = _uniform_double_hq<false>(ctr[0], ctr[1]); + rnd1hi = _uniform_double_hq<true>(ctr[0], ctr[1]); + rnd2lo = _uniform_double_hq<false>(ctr[2], ctr[3]); + rnd2hi = _uniform_double_hq<true>(ctr[2], ctr[3]); +} + +QUALIFIERS void aesni_float4(uint32 ctr0, __m256i ctr1, uint32 ctr2, uint32 ctr3, + uint32 key0, uint32 key1, uint32 key2, uint32 key3, + __m256 & rnd1, __m256 & rnd2, __m256 & rnd3, __m256 & rnd4) +{ + __m256i ctr0v = _mm256_set1_epi32(ctr0); + __m256i ctr2v = _mm256_set1_epi32(ctr2); + __m256i ctr3v = _mm256_set1_epi32(ctr3); + + aesni_float4(ctr0v, ctr1, ctr2v, ctr3v, key0, key1, key2, key3, rnd1, rnd2, rnd3, rnd4); +} + +QUALIFIERS void aesni_double2(uint32 ctr0, __m256i ctr1, uint32 ctr2, uint32 ctr3, + uint32 key0, uint32 key1, uint32 key2, uint32 key3, + __m256d & rnd1lo, __m256d & rnd1hi, __m256d & rnd2lo, __m256d & rnd2hi) +{ + __m256i ctr0v = _mm256_set1_epi32(ctr0); + __m256i ctr2v = _mm256_set1_epi32(ctr2); + __m256i ctr3v = _mm256_set1_epi32(ctr3); + + aesni_double2(ctr0v, ctr1, ctr2v, ctr3v, key0, key1, key2, key3, rnd1lo, rnd1hi, rnd2lo, rnd2hi); +} + +QUALIFIERS void aesni_double2(uint32 ctr0, __m256i ctr1, uint32 ctr2, uint32 ctr3, + uint32 key0, uint32 key1, uint32 key2, uint32 key3, + __m256d & rnd1, __m256d & rnd2) +{ +#if 0 + __m256i ctr0v = _mm256_set1_epi32(ctr0); + __m256i ctr2v = _mm256_set1_epi32(ctr2); + __m256i ctr3v = _mm256_set1_epi32(ctr3); + + __m256d ignore; + aesni_double2(ctr0v, ctr1, ctr2v, ctr3v, key0, key1, key2, key3, rnd1, ignore, rnd2, ignore); +#else + __m128d rnd1lo, rnd1hi, rnd2lo, rnd2hi; + 
aesni_double2(ctr0, _mm256_extractf128_si256(ctr1, 0), ctr2, ctr3, key0, key1, key2, key3, rnd1lo, rnd1hi, rnd2lo, rnd2hi);
+    rnd1 = _my256_set_m128d(rnd1hi, rnd1lo);
+    rnd2 = _my256_set_m128d(rnd2hi, rnd2lo);
+#endif
+}
+#endif
+
+
+#ifdef __AVX512F__
+QUALIFIERS const std::array<__m512i,11> & aesni_roundkeys(const __m512i & k512) {
+    std::array<uint32,16> a;
+    _mm512_storeu_si512((__m512i*) a.data(), k512);
+
+    static std::map<std::array<uint32,16>, std::array<__m512i,11>> roundkeys;
+
+    if(roundkeys.find(a) == roundkeys.end()) {
+        auto rk1 = aesni_keygen(_mm512_extracti32x4_epi32(k512, 0));
+        auto rk2 = aesni_keygen(_mm512_extracti32x4_epi32(k512, 1));
+        auto rk3 = aesni_keygen(_mm512_extracti32x4_epi32(k512, 2));
+        auto rk4 = aesni_keygen(_mm512_extracti32x4_epi32(k512, 3));
+        for(int i = 0; i < 11; ++i) {
+            roundkeys[a][i] = _my512_set_m128i(rk4[i], rk3[i], rk2[i], rk1[i]);
+        }
+    }
+    return roundkeys[a];
+}
+
+QUALIFIERS __m512i aesni1xm128i(const __m512i & in, const __m512i & k0) {
+#ifdef __VAES__
+    auto k = aesni_roundkeys(k0);
+    __m512i x = _mm512_xor_si512(k[0], in);
+    x = _mm512_aesenc_epi128(x, k[1]);
+    x = _mm512_aesenc_epi128(x, k[2]);
+    x = _mm512_aesenc_epi128(x, k[3]);
+    x = _mm512_aesenc_epi128(x, k[4]);
+    x = _mm512_aesenc_epi128(x, k[5]);
+    x = _mm512_aesenc_epi128(x, k[6]);
+    x = _mm512_aesenc_epi128(x, k[7]);
+    x = _mm512_aesenc_epi128(x, k[8]);
+    x = _mm512_aesenc_epi128(x, k[9]);
+    x = _mm512_aesenclast_epi128(x, k[10]);
+#else
+    __m128i a = aesni1xm128i(_mm512_extracti32x4_epi32(in, 0), _mm512_extracti32x4_epi32(k0, 0));
+    __m128i b = aesni1xm128i(_mm512_extracti32x4_epi32(in, 1), _mm512_extracti32x4_epi32(k0, 1));
+    __m128i c = aesni1xm128i(_mm512_extracti32x4_epi32(in, 2), _mm512_extracti32x4_epi32(k0, 2));
+    __m128i d = aesni1xm128i(_mm512_extracti32x4_epi32(in, 3), _mm512_extracti32x4_epi32(k0, 3));
+    __m512i x = _my512_set_m128i(d, c, b, a);
+#endif
+    return x;
+}
+
+template<bool high>
+QUALIFIERS __m512d _uniform_double_hq(__m512i x, __m512i y)
+{
+    // convert 32 to 64 bit
+    if (high)
+    {
+        x = _mm512_cvtepu32_epi64(_mm512_extracti64x4_epi64(x, 1));
+        y = _mm512_cvtepu32_epi64(_mm512_extracti64x4_epi64(y, 1));
+    }
+    else
+    {
+        x = _mm512_cvtepu32_epi64(_mm512_extracti64x4_epi64(x, 0));
+        y = _mm512_cvtepu32_epi64(_mm512_extracti64x4_epi64(y, 0));
+    }
+
+    // calculate z = x ^ y << (53 - 32))
+    __m512i z = _mm512_sll_epi64(y, _mm_set1_epi64x(53 - 32));
+    z = _mm512_xor_si512(x, z);
+
+    // convert uint64 to double
+    __m512d rs = _mm512_cvtepu64_pd(z);
+    // calculate rs * TWOPOW53_INV_DOUBLE + (TWOPOW53_INV_DOUBLE/2.0)
+    rs = _mm512_fmadd_pd(rs, _mm512_set1_pd(TWOPOW53_INV_DOUBLE), _mm512_set1_pd(TWOPOW53_INV_DOUBLE/2.0));
+
+    return rs;
+}
+
+
+QUALIFIERS void aesni_float4(__m512i ctr0, __m512i ctr1, __m512i ctr2, __m512i ctr3,
+                             uint32 key0, uint32 key1, uint32 key2, uint32 key3,
+                             __m512 & rnd1, __m512 & rnd2, __m512 & rnd3, __m512 & rnd4)
+{
+    // pack input and call AES
+    __m512i k512 = _mm512_set_epi32(key3, key2, key1, key0, key3, key2, key1, key0,
+                                    key3, key2, key1, key0, key3, key2, key1, key0);
+    __m512i ctr[4] = {ctr0, ctr1, ctr2, ctr3};
+    __m128i a[4], b[4], c[4], d[4];
+    for (int i = 0; i < 4; ++i)
+    {
+        a[i] = _mm512_extracti32x4_epi32(ctr[i], 0);
+        b[i] = _mm512_extracti32x4_epi32(ctr[i], 1);
+        c[i] = _mm512_extracti32x4_epi32(ctr[i], 2);
+        d[i] = _mm512_extracti32x4_epi32(ctr[i], 3);
+    }
+    _MY_TRANSPOSE4_EPI32(a[0], a[1], a[2], a[3]);
+    _MY_TRANSPOSE4_EPI32(b[0], b[1], b[2], b[3]);
+    _MY_TRANSPOSE4_EPI32(c[0], c[1], c[2], c[3]);
+
_MY_TRANSPOSE4_EPI32(d[0], d[1], d[2], d[3]); + for (int i = 0; i < 4; ++i) + { + ctr[i] = _my512_set_m128i(d[i], c[i], b[i], a[i]); + } + for (int i = 0; i < 4; ++i) + { + ctr[i] = aesni1xm128i(ctr[i], k512); + } + for (int i = 0; i < 4; ++i) + { + a[i] = _mm512_extracti32x4_epi32(ctr[i], 0); + b[i] = _mm512_extracti32x4_epi32(ctr[i], 1); + c[i] = _mm512_extracti32x4_epi32(ctr[i], 2); + d[i] = _mm512_extracti32x4_epi32(ctr[i], 3); + } + _MY_TRANSPOSE4_EPI32(a[0], a[1], a[2], a[3]); + _MY_TRANSPOSE4_EPI32(b[0], b[1], b[2], b[3]); + _MY_TRANSPOSE4_EPI32(c[0], c[1], c[2], c[3]); + _MY_TRANSPOSE4_EPI32(d[0], d[1], d[2], d[3]); + for (int i = 0; i < 4; ++i) + { + ctr[i] = _my512_set_m128i(d[i], c[i], b[i], a[i]); + } + + // convert uint32 to float + rnd1 = _mm512_cvtepu32_ps(ctr[0]); + rnd2 = _mm512_cvtepu32_ps(ctr[1]); + rnd3 = _mm512_cvtepu32_ps(ctr[2]); + rnd4 = _mm512_cvtepu32_ps(ctr[3]); + // calculate rnd * TWOPOW32_INV_FLOAT + (TWOPOW32_INV_FLOAT/2.0f) + rnd1 = _mm512_fmadd_ps(rnd1, _mm512_set1_ps(TWOPOW32_INV_FLOAT), _mm512_set1_ps(TWOPOW32_INV_FLOAT/2.0)); + rnd2 = _mm512_fmadd_ps(rnd2, _mm512_set1_ps(TWOPOW32_INV_FLOAT), _mm512_set1_ps(TWOPOW32_INV_FLOAT/2.0)); + rnd3 = _mm512_fmadd_ps(rnd3, _mm512_set1_ps(TWOPOW32_INV_FLOAT), _mm512_set1_ps(TWOPOW32_INV_FLOAT/2.0)); + rnd4 = _mm512_fmadd_ps(rnd4, _mm512_set1_ps(TWOPOW32_INV_FLOAT), _mm512_set1_ps(TWOPOW32_INV_FLOAT/2.0)); +} + + +QUALIFIERS void aesni_double2(__m512i ctr0, __m512i ctr1, __m512i ctr2, __m512i ctr3, + uint32 key0, uint32 key1, uint32 key2, uint32 key3, + __m512d & rnd1lo, __m512d & rnd1hi, __m512d & rnd2lo, __m512d & rnd2hi) +{ + // pack input and call AES + __m512i k512 = _mm512_set_epi32(key3, key2, key1, key0, key3, key2, key1, key0, + key3, key2, key1, key0, key3, key2, key1, key0); + __m512i ctr[4] = {ctr0, ctr1, ctr2, ctr3}; + __m128i a[4], b[4], c[4], d[4]; + for (int i = 0; i < 4; ++i) + { + a[i] = _mm512_extracti32x4_epi32(ctr[i], 0); + b[i] = _mm512_extracti32x4_epi32(ctr[i], 1); + c[i] = _mm512_extracti32x4_epi32(ctr[i], 2); + d[i] = _mm512_extracti32x4_epi32(ctr[i], 3); + } + _MY_TRANSPOSE4_EPI32(a[0], a[1], a[2], a[3]); + _MY_TRANSPOSE4_EPI32(b[0], b[1], b[2], b[3]); + _MY_TRANSPOSE4_EPI32(c[0], c[1], c[2], c[3]); + _MY_TRANSPOSE4_EPI32(d[0], d[1], d[2], d[3]); + for (int i = 0; i < 4; ++i) + { + ctr[i] = _my512_set_m128i(d[i], c[i], b[i], a[i]); + } + for (int i = 0; i < 4; ++i) + { + ctr[i] = aesni1xm128i(ctr[i], k512); + } + for (int i = 0; i < 4; ++i) + { + a[i] = _mm512_extracti32x4_epi32(ctr[i], 0); + b[i] = _mm512_extracti32x4_epi32(ctr[i], 1); + c[i] = _mm512_extracti32x4_epi32(ctr[i], 2); + d[i] = _mm512_extracti32x4_epi32(ctr[i], 3); + } + _MY_TRANSPOSE4_EPI32(a[0], a[1], a[2], a[3]); + _MY_TRANSPOSE4_EPI32(b[0], b[1], b[2], b[3]); + _MY_TRANSPOSE4_EPI32(c[0], c[1], c[2], c[3]); + _MY_TRANSPOSE4_EPI32(d[0], d[1], d[2], d[3]); + for (int i = 0; i < 4; ++i) + { + ctr[i] = _my512_set_m128i(d[i], c[i], b[i], a[i]); + } + + rnd1lo = _uniform_double_hq<false>(ctr[0], ctr[1]); + rnd1hi = _uniform_double_hq<true>(ctr[0], ctr[1]); + rnd2lo = _uniform_double_hq<false>(ctr[2], ctr[3]); + rnd2hi = _uniform_double_hq<true>(ctr[2], ctr[3]); +} + +QUALIFIERS void aesni_float4(uint32 ctr0, __m512i ctr1, uint32 ctr2, uint32 ctr3, + uint32 key0, uint32 key1, uint32 key2, uint32 key3, + __m512 & rnd1, __m512 & rnd2, __m512 & rnd3, __m512 & rnd4) +{ + __m512i ctr0v = _mm512_set1_epi32(ctr0); + __m512i ctr2v = _mm512_set1_epi32(ctr2); + __m512i ctr3v = _mm512_set1_epi32(ctr3); + + aesni_float4(ctr0v, ctr1, ctr2v, 
ctr3v, key0, key1, key2, key3, rnd1, rnd2, rnd3, rnd4); +} + +QUALIFIERS void aesni_double2(uint32 ctr0, __m512i ctr1, uint32 ctr2, uint32 ctr3, + uint32 key0, uint32 key1, uint32 key2, uint32 key3, + __m512d & rnd1lo, __m512d & rnd1hi, __m512d & rnd2lo, __m512d & rnd2hi) +{ + __m512i ctr0v = _mm512_set1_epi32(ctr0); + __m512i ctr2v = _mm512_set1_epi32(ctr2); + __m512i ctr3v = _mm512_set1_epi32(ctr3); + + aesni_double2(ctr0v, ctr1, ctr2v, ctr3v, key0, key1, key2, key3, rnd1lo, rnd1hi, rnd2lo, rnd2hi); +} + +QUALIFIERS void aesni_double2(uint32 ctr0, __m512i ctr1, uint32 ctr2, uint32 ctr3, + uint32 key0, uint32 key1, uint32 key2, uint32 key3, + __m512d & rnd1, __m512d & rnd2) +{ +#if 0 + __m512i ctr0v = _mm512_set1_epi32(ctr0); + __m512i ctr2v = _mm512_set1_epi32(ctr2); + __m512i ctr3v = _mm512_set1_epi32(ctr3); + + __m512d ignore; + aesni_double2(ctr0v, ctr1, ctr2v, ctr3v, key0, key1, key2, key3, rnd1, ignore, rnd2, ignore); +#else + __m256d rnd1lo, rnd1hi, rnd2lo, rnd2hi; + aesni_double2(ctr0, _mm512_extracti64x4_epi64(ctr1, 0), ctr2, ctr3, key0, key1, key2, key3, rnd1lo, rnd1hi, rnd2lo, rnd2hi); + rnd1 = _my512_set_m256d(rnd1hi, rnd1lo); + rnd2 = _my512_set_m256d(rnd2hi, rnd2lo); +#endif +} +#endif + diff --git a/pystencils/include/myintrin.h b/pystencils/include/myintrin.h new file mode 100644 index 0000000000000000000000000000000000000000..a94c316c44de7aa420e4a6be807a510ee35687dd --- /dev/null +++ b/pystencils/include/myintrin.h @@ -0,0 +1,109 @@ +#pragma once + +#if defined(__SSE2__) || defined(_MSC_VER) +QUALIFIERS __m128 _my_cvtepu32_ps(const __m128i v) +{ +#ifdef __AVX512VL__ + return _mm_cvtepu32_ps(v); +#else + __m128i v2 = _mm_srli_epi32(v, 1); + __m128i v1 = _mm_and_si128(v, _mm_set1_epi32(1)); + __m128 v2f = _mm_cvtepi32_ps(v2); + __m128 v1f = _mm_cvtepi32_ps(v1); + return _mm_add_ps(_mm_add_ps(v2f, v2f), v1f); +#endif +} + +QUALIFIERS void _MY_TRANSPOSE4_EPI32(__m128i & R0, __m128i & R1, __m128i & R2, __m128i & R3) +{ + __m128i T0, T1, T2, T3; + T0 = _mm_unpacklo_epi32(R0, R1); + T1 = _mm_unpacklo_epi32(R2, R3); + T2 = _mm_unpackhi_epi32(R0, R1); + T3 = _mm_unpackhi_epi32(R2, R3); + R0 = _mm_unpacklo_epi64(T0, T1); + R1 = _mm_unpackhi_epi64(T0, T1); + R2 = _mm_unpacklo_epi64(T2, T3); + R3 = _mm_unpackhi_epi64(T2, T3); +} +#endif + +#if defined(__SSE4_1__) || defined(_MSC_VER) +#if !defined(__AVX512VL__) && defined(__GNUC__) && __GNUC__ >= 5 && !defined(__clang__) +__attribute__((optimize("no-associative-math"))) +#endif +QUALIFIERS __m128d _my_cvtepu64_pd(const __m128i x) +{ +#ifdef __AVX512VL__ + return _mm_cvtepu64_pd(x); +#else + __m128i xH = _mm_srli_epi64(x, 32); + xH = _mm_or_si128(xH, _mm_castpd_si128(_mm_set1_pd(19342813113834066795298816.))); // 2^84 + __m128i xL = _mm_blend_epi16(x, _mm_castpd_si128(_mm_set1_pd(0x0010000000000000)), 0xcc); // 2^52 + __m128d f = _mm_sub_pd(_mm_castsi128_pd(xH), _mm_set1_pd(19342813118337666422669312.)); // 2^84 + 2^52 + return _mm_add_pd(f, _mm_castsi128_pd(xL)); +#endif +} +#endif + +#ifdef __AVX2__ +QUALIFIERS __m256i _my256_set_m128i(__m128i hi, __m128i lo) +{ +#if (!defined(__GNUC__) || __GNUC__ >= 8) || defined(__clang__) + return _mm256_set_m128i(hi, lo); +#else + return _mm256_insertf128_si256(_mm256_castsi128_si256(lo), hi, 1); +#endif +} + +QUALIFIERS __m256d _my256_set_m128d(__m128d hi, __m128d lo) +{ +#if (!defined(__GNUC__) || __GNUC__ >= 8) || defined(__clang__) + return _mm256_set_m128d(hi, lo); +#else + return _mm256_insertf128_pd(_mm256_castpd128_pd256(lo), hi, 1); +#endif +} + +QUALIFIERS __m256 
_my256_cvtepu32_ps(const __m256i v) +{ +#ifdef __AVX512VL__ + return _mm256_cvtepu32_ps(v); +#else + __m256i v2 = _mm256_srli_epi32(v, 1); + __m256i v1 = _mm256_and_si256(v, _mm256_set1_epi32(1)); + __m256 v2f = _mm256_cvtepi32_ps(v2); + __m256 v1f = _mm256_cvtepi32_ps(v1); + return _mm256_add_ps(_mm256_add_ps(v2f, v2f), v1f); +#endif +} + +#if !defined(__AVX512VL__) && defined(__GNUC__) && __GNUC__ >= 5 && !defined(__clang__) +__attribute__((optimize("no-associative-math"))) +#endif +QUALIFIERS __m256d _my256_cvtepu64_pd(const __m256i x) +{ +#ifdef __AVX512VL__ + return _mm256_cvtepu64_pd(x); +#else + __m256i xH = _mm256_srli_epi64(x, 32); + xH = _mm256_or_si256(xH, _mm256_castpd_si256(_mm256_set1_pd(19342813113834066795298816.))); // 2^84 + __m256i xL = _mm256_blend_epi16(x, _mm256_castpd_si256(_mm256_set1_pd(0x0010000000000000)), 0xcc); // 2^52 + __m256d f = _mm256_sub_pd(_mm256_castsi256_pd(xH), _mm256_set1_pd(19342813118337666422669312.)); // 2^84 + 2^52 + return _mm256_add_pd(f, _mm256_castsi256_pd(xL)); +#endif +} +#endif + +#ifdef __AVX512F__ +QUALIFIERS __m512i _my512_set_m128i(__m128i d, __m128i c, __m128i b, __m128i a) +{ + return _mm512_inserti32x4(_mm512_inserti32x4(_mm512_inserti32x4(_mm512_castsi128_si512(a), b, 1), c, 2), d, 3); +} + +QUALIFIERS __m512d _my512_set_m256d(__m256d b, __m256d a) +{ + return _mm512_insertf64x4(_mm512_castpd256_pd512(a), b, 1); +} +#endif + diff --git a/pystencils/include/philox_rand.h b/pystencils/include/philox_rand.h index 283204921079ebfd79e022af53b17963de874cf4..b4c83669d0fda05aee1a7e018904c927a59c7ad1 100644 --- a/pystencils/include/philox_rand.h +++ b/pystencils/include/philox_rand.h @@ -1,7 +1,20 @@ #include <cstdint> +#if defined(__SSE4_1__) || defined(_MSC_VER) +#include <emmintrin.h> // SSE2 +#endif +#ifdef __AVX2__ +#include <immintrin.h> // AVX* +#else +#include <smmintrin.h> // SSE4 +#ifdef __FMA__ +#include <immintrin.h> // FMA +#endif +#endif + #ifndef __CUDA_ARCH__ #define QUALIFIERS inline +#include "myintrin.h" #else #define QUALIFIERS static __forceinline__ __device__ #endif @@ -78,7 +91,6 @@ QUALIFIERS void philox_double2(uint32 ctr0, uint32 ctr1, uint32 ctr2, uint32 ctr } - QUALIFIERS void philox_float4(uint32 ctr0, uint32 ctr1, uint32 ctr2, uint32 ctr3, uint32 key0, uint32 key1, float & rnd1, float & rnd2, float & rnd3, float & rnd4) @@ -100,4 +112,491 @@ QUALIFIERS void philox_float4(uint32 ctr0, uint32 ctr1, uint32 ctr2, uint32 ctr3 rnd2 = ctr[1] * TWOPOW32_INV_FLOAT + (TWOPOW32_INV_FLOAT/2.0f); rnd3 = ctr[2] * TWOPOW32_INV_FLOAT + (TWOPOW32_INV_FLOAT/2.0f); rnd4 = ctr[3] * TWOPOW32_INV_FLOAT + (TWOPOW32_INV_FLOAT/2.0f); -} \ No newline at end of file +} + +#ifndef __CUDA_ARCH__ +#if defined(__SSE4_1__) || defined(_MSC_VER) +QUALIFIERS void _philox4x32round(__m128i* ctr, __m128i* key) +{ + __m128i lohi0a = _mm_mul_epu32(ctr[0], _mm_set1_epi32(PHILOX_M4x32_0)); + __m128i lohi0b = _mm_mul_epu32(_mm_srli_epi64(ctr[0], 32), _mm_set1_epi32(PHILOX_M4x32_0)); + __m128i lohi1a = _mm_mul_epu32(ctr[2], _mm_set1_epi32(PHILOX_M4x32_1)); + __m128i lohi1b = _mm_mul_epu32(_mm_srli_epi64(ctr[2], 32), _mm_set1_epi32(PHILOX_M4x32_1)); + + lohi0a = _mm_shuffle_epi32(lohi0a, 0xD8); + lohi0b = _mm_shuffle_epi32(lohi0b, 0xD8); + lohi1a = _mm_shuffle_epi32(lohi1a, 0xD8); + lohi1b = _mm_shuffle_epi32(lohi1b, 0xD8); + + __m128i lo0 = _mm_unpacklo_epi32(lohi0a, lohi0b); + __m128i hi0 = _mm_unpackhi_epi32(lohi0a, lohi0b); + __m128i lo1 = _mm_unpacklo_epi32(lohi1a, lohi1b); + __m128i hi1 = _mm_unpackhi_epi32(lohi1a, lohi1b); + + ctr[0] = 
_mm_xor_si128(_mm_xor_si128(hi1, ctr[1]), key[0]); + ctr[1] = lo1; + ctr[2] = _mm_xor_si128(_mm_xor_si128(hi0, ctr[3]), key[1]); + ctr[3] = lo0; +} + +QUALIFIERS void _philox4x32bumpkey(__m128i* key) +{ + key[0] = _mm_add_epi32(key[0], _mm_set1_epi32(PHILOX_W32_0)); + key[1] = _mm_add_epi32(key[1], _mm_set1_epi32(PHILOX_W32_1)); +} + +template<bool high> +QUALIFIERS __m128d _uniform_double_hq(__m128i x, __m128i y) +{ + // convert 32 to 64 bit + if (high) + { + x = _mm_unpackhi_epi32(x, _mm_setzero_si128()); + y = _mm_unpackhi_epi32(y, _mm_setzero_si128()); + } + else + { + x = _mm_unpacklo_epi32(x, _mm_setzero_si128()); + y = _mm_unpacklo_epi32(y, _mm_setzero_si128()); + } + + // calculate z = x ^ y << (53 - 32)) + __m128i z = _mm_sll_epi64(y, _mm_set1_epi64x(53 - 32)); + z = _mm_xor_si128(x, z); + + // convert uint64 to double + __m128d rs = _my_cvtepu64_pd(z); + // calculate rs * TWOPOW53_INV_DOUBLE + (TWOPOW53_INV_DOUBLE/2.0) +#ifdef __FMA__ + rs = _mm_fmadd_pd(rs, _mm_set1_pd(TWOPOW53_INV_DOUBLE), _mm_set1_pd(TWOPOW53_INV_DOUBLE/2.0)); +#else + rs = _mm_mul_pd(rs, _mm_set1_pd(TWOPOW53_INV_DOUBLE)); + rs = _mm_add_pd(rs, _mm_set1_pd(TWOPOW53_INV_DOUBLE/2.0)); +#endif + + return rs; +} + + +QUALIFIERS void philox_float4(__m128i ctr0, __m128i ctr1, __m128i ctr2, __m128i ctr3, + uint32 key0, uint32 key1, + __m128 & rnd1, __m128 & rnd2, __m128 & rnd3, __m128 & rnd4) +{ + __m128i key[2] = {_mm_set1_epi32(key0), _mm_set1_epi32(key1)}; + __m128i ctr[4] = {ctr0, ctr1, ctr2, ctr3}; + _philox4x32round(ctr, key); // 1 + _philox4x32bumpkey(key); _philox4x32round(ctr, key); // 2 + _philox4x32bumpkey(key); _philox4x32round(ctr, key); // 3 + _philox4x32bumpkey(key); _philox4x32round(ctr, key); // 4 + _philox4x32bumpkey(key); _philox4x32round(ctr, key); // 5 + _philox4x32bumpkey(key); _philox4x32round(ctr, key); // 6 + _philox4x32bumpkey(key); _philox4x32round(ctr, key); // 7 + _philox4x32bumpkey(key); _philox4x32round(ctr, key); // 8 + _philox4x32bumpkey(key); _philox4x32round(ctr, key); // 9 + _philox4x32bumpkey(key); _philox4x32round(ctr, key); // 10 + + // convert uint32 to float + rnd1 = _my_cvtepu32_ps(ctr[0]); + rnd2 = _my_cvtepu32_ps(ctr[1]); + rnd3 = _my_cvtepu32_ps(ctr[2]); + rnd4 = _my_cvtepu32_ps(ctr[3]); + // calculate rnd * TWOPOW32_INV_FLOAT + (TWOPOW32_INV_FLOAT/2.0f) +#ifdef __FMA__ + rnd1 = _mm_fmadd_ps(rnd1, _mm_set1_ps(TWOPOW32_INV_FLOAT), _mm_set1_ps(TWOPOW32_INV_FLOAT/2.0)); + rnd2 = _mm_fmadd_ps(rnd2, _mm_set1_ps(TWOPOW32_INV_FLOAT), _mm_set1_ps(TWOPOW32_INV_FLOAT/2.0)); + rnd3 = _mm_fmadd_ps(rnd3, _mm_set1_ps(TWOPOW32_INV_FLOAT), _mm_set1_ps(TWOPOW32_INV_FLOAT/2.0)); + rnd4 = _mm_fmadd_ps(rnd4, _mm_set1_ps(TWOPOW32_INV_FLOAT), _mm_set1_ps(TWOPOW32_INV_FLOAT/2.0)); +#else + rnd1 = _mm_mul_ps(rnd1, _mm_set1_ps(TWOPOW32_INV_FLOAT)); + rnd1 = _mm_add_ps(rnd1, _mm_set1_ps(TWOPOW32_INV_FLOAT/2.0f)); + rnd2 = _mm_mul_ps(rnd2, _mm_set1_ps(TWOPOW32_INV_FLOAT)); + rnd2 = _mm_add_ps(rnd2, _mm_set1_ps(TWOPOW32_INV_FLOAT/2.0f)); + rnd3 = _mm_mul_ps(rnd3, _mm_set1_ps(TWOPOW32_INV_FLOAT)); + rnd3 = _mm_add_ps(rnd3, _mm_set1_ps(TWOPOW32_INV_FLOAT/2.0f)); + rnd4 = _mm_mul_ps(rnd4, _mm_set1_ps(TWOPOW32_INV_FLOAT)); + rnd4 = _mm_add_ps(rnd4, _mm_set1_ps(TWOPOW32_INV_FLOAT/2.0f)); +#endif +} + + +QUALIFIERS void philox_double2(__m128i ctr0, __m128i ctr1, __m128i ctr2, __m128i ctr3, + uint32 key0, uint32 key1, + __m128d & rnd1lo, __m128d & rnd1hi, __m128d & rnd2lo, __m128d & rnd2hi) +{ + __m128i key[2] = {_mm_set1_epi32(key0), _mm_set1_epi32(key1)}; + __m128i ctr[4] = {ctr0, ctr1, ctr2, ctr3}; + 
_philox4x32round(ctr, key); // 1 + _philox4x32bumpkey(key); _philox4x32round(ctr, key); // 2 + _philox4x32bumpkey(key); _philox4x32round(ctr, key); // 3 + _philox4x32bumpkey(key); _philox4x32round(ctr, key); // 4 + _philox4x32bumpkey(key); _philox4x32round(ctr, key); // 5 + _philox4x32bumpkey(key); _philox4x32round(ctr, key); // 6 + _philox4x32bumpkey(key); _philox4x32round(ctr, key); // 7 + _philox4x32bumpkey(key); _philox4x32round(ctr, key); // 8 + _philox4x32bumpkey(key); _philox4x32round(ctr, key); // 9 + _philox4x32bumpkey(key); _philox4x32round(ctr, key); // 10 + + rnd1lo = _uniform_double_hq<false>(ctr[0], ctr[1]); + rnd1hi = _uniform_double_hq<true>(ctr[0], ctr[1]); + rnd2lo = _uniform_double_hq<false>(ctr[2], ctr[3]); + rnd2hi = _uniform_double_hq<true>(ctr[2], ctr[3]); +} + +QUALIFIERS void philox_float4(uint32 ctr0, __m128i ctr1, uint32 ctr2, uint32 ctr3, + uint32 key0, uint32 key1, + __m128 & rnd1, __m128 & rnd2, __m128 & rnd3, __m128 & rnd4) +{ + __m128i ctr0v = _mm_set1_epi32(ctr0); + __m128i ctr2v = _mm_set1_epi32(ctr2); + __m128i ctr3v = _mm_set1_epi32(ctr3); + + philox_float4(ctr0v, ctr1, ctr2v, ctr3v, key0, key1, rnd1, rnd2, rnd3, rnd4); +} + +QUALIFIERS void philox_double2(uint32 ctr0, __m128i ctr1, uint32 ctr2, uint32 ctr3, + uint32 key0, uint32 key1, + __m128d & rnd1lo, __m128d & rnd1hi, __m128d & rnd2lo, __m128d & rnd2hi) +{ + __m128i ctr0v = _mm_set1_epi32(ctr0); + __m128i ctr2v = _mm_set1_epi32(ctr2); + __m128i ctr3v = _mm_set1_epi32(ctr3); + + philox_double2(ctr0v, ctr1, ctr2v, ctr3v, key0, key1, rnd1lo, rnd1hi, rnd2lo, rnd2hi); +} + +QUALIFIERS void philox_double2(uint32 ctr0, __m128i ctr1, uint32 ctr2, uint32 ctr3, + uint32 key0, uint32 key1, + __m128d & rnd1, __m128d & rnd2) +{ + __m128i ctr0v = _mm_set1_epi32(ctr0); + __m128i ctr2v = _mm_set1_epi32(ctr2); + __m128i ctr3v = _mm_set1_epi32(ctr3); + + __m128d ignore; + philox_double2(ctr0v, ctr1, ctr2v, ctr3v, key0, key1, rnd1, ignore, rnd2, ignore); +} +#endif + +#ifdef __AVX2__ +QUALIFIERS void _philox4x32round(__m256i* ctr, __m256i* key) +{ + __m256i lohi0a = _mm256_mul_epu32(ctr[0], _mm256_set1_epi32(PHILOX_M4x32_0)); + __m256i lohi0b = _mm256_mul_epu32(_mm256_srli_epi64(ctr[0], 32), _mm256_set1_epi32(PHILOX_M4x32_0)); + __m256i lohi1a = _mm256_mul_epu32(ctr[2], _mm256_set1_epi32(PHILOX_M4x32_1)); + __m256i lohi1b = _mm256_mul_epu32(_mm256_srli_epi64(ctr[2], 32), _mm256_set1_epi32(PHILOX_M4x32_1)); + + lohi0a = _mm256_shuffle_epi32(lohi0a, 0xD8); + lohi0b = _mm256_shuffle_epi32(lohi0b, 0xD8); + lohi1a = _mm256_shuffle_epi32(lohi1a, 0xD8); + lohi1b = _mm256_shuffle_epi32(lohi1b, 0xD8); + + __m256i lo0 = _mm256_unpacklo_epi32(lohi0a, lohi0b); + __m256i hi0 = _mm256_unpackhi_epi32(lohi0a, lohi0b); + __m256i lo1 = _mm256_unpacklo_epi32(lohi1a, lohi1b); + __m256i hi1 = _mm256_unpackhi_epi32(lohi1a, lohi1b); + + ctr[0] = _mm256_xor_si256(_mm256_xor_si256(hi1, ctr[1]), key[0]); + ctr[1] = lo1; + ctr[2] = _mm256_xor_si256(_mm256_xor_si256(hi0, ctr[3]), key[1]); + ctr[3] = lo0; +} + +QUALIFIERS void _philox4x32bumpkey(__m256i* key) +{ + key[0] = _mm256_add_epi32(key[0], _mm256_set1_epi32(PHILOX_W32_0)); + key[1] = _mm256_add_epi32(key[1], _mm256_set1_epi32(PHILOX_W32_1)); +} + +template<bool high> +QUALIFIERS __m256d _uniform_double_hq(__m256i x, __m256i y) +{ + // convert 32 to 64 bit + if (high) + { + x = _mm256_cvtepu32_epi64(_mm256_extracti128_si256(x, 1)); + y = _mm256_cvtepu32_epi64(_mm256_extracti128_si256(y, 1)); + } + else + { + x = _mm256_cvtepu32_epi64(_mm256_extracti128_si256(x, 0)); + y = 
_mm256_cvtepu32_epi64(_mm256_extracti128_si256(y, 0)); + } + + // calculate z = x ^ y << (53 - 32)) + __m256i z = _mm256_sll_epi64(y, _mm_set1_epi64x(53 - 32)); + z = _mm256_xor_si256(x, z); + + // convert uint64 to double + __m256d rs = _my256_cvtepu64_pd(z); + // calculate rs * TWOPOW53_INV_DOUBLE + (TWOPOW53_INV_DOUBLE/2.0) +#ifdef __FMA__ + rs = _mm256_fmadd_pd(rs, _mm256_set1_pd(TWOPOW53_INV_DOUBLE), _mm256_set1_pd(TWOPOW53_INV_DOUBLE/2.0)); +#else + rs = _mm256_mul_pd(rs, _mm256_set1_pd(TWOPOW53_INV_DOUBLE)); + rs = _mm256_add_pd(rs, _mm256_set1_pd(TWOPOW53_INV_DOUBLE/2.0)); +#endif + + return rs; +} + + +QUALIFIERS void philox_float4(__m256i ctr0, __m256i ctr1, __m256i ctr2, __m256i ctr3, + uint32 key0, uint32 key1, + __m256 & rnd1, __m256 & rnd2, __m256 & rnd3, __m256 & rnd4) +{ + __m256i key[2] = {_mm256_set1_epi32(key0), _mm256_set1_epi32(key1)}; + __m256i ctr[4] = {ctr0, ctr1, ctr2, ctr3}; + _philox4x32round(ctr, key); // 1 + _philox4x32bumpkey(key); _philox4x32round(ctr, key); // 2 + _philox4x32bumpkey(key); _philox4x32round(ctr, key); // 3 + _philox4x32bumpkey(key); _philox4x32round(ctr, key); // 4 + _philox4x32bumpkey(key); _philox4x32round(ctr, key); // 5 + _philox4x32bumpkey(key); _philox4x32round(ctr, key); // 6 + _philox4x32bumpkey(key); _philox4x32round(ctr, key); // 7 + _philox4x32bumpkey(key); _philox4x32round(ctr, key); // 8 + _philox4x32bumpkey(key); _philox4x32round(ctr, key); // 9 + _philox4x32bumpkey(key); _philox4x32round(ctr, key); // 10 + + // convert uint32 to float + rnd1 = _my256_cvtepu32_ps(ctr[0]); + rnd2 = _my256_cvtepu32_ps(ctr[1]); + rnd3 = _my256_cvtepu32_ps(ctr[2]); + rnd4 = _my256_cvtepu32_ps(ctr[3]); + // calculate rnd * TWOPOW32_INV_FLOAT + (TWOPOW32_INV_FLOAT/2.0f) +#ifdef __FMA__ + rnd1 = _mm256_fmadd_ps(rnd1, _mm256_set1_ps(TWOPOW32_INV_FLOAT), _mm256_set1_ps(TWOPOW32_INV_FLOAT/2.0)); + rnd2 = _mm256_fmadd_ps(rnd2, _mm256_set1_ps(TWOPOW32_INV_FLOAT), _mm256_set1_ps(TWOPOW32_INV_FLOAT/2.0)); + rnd3 = _mm256_fmadd_ps(rnd3, _mm256_set1_ps(TWOPOW32_INV_FLOAT), _mm256_set1_ps(TWOPOW32_INV_FLOAT/2.0)); + rnd4 = _mm256_fmadd_ps(rnd4, _mm256_set1_ps(TWOPOW32_INV_FLOAT), _mm256_set1_ps(TWOPOW32_INV_FLOAT/2.0)); +#else + rnd1 = _mm256_mul_ps(rnd1, _mm256_set1_ps(TWOPOW32_INV_FLOAT)); + rnd1 = _mm256_add_ps(rnd1, _mm256_set1_ps(TWOPOW32_INV_FLOAT/2.0f)); + rnd2 = _mm256_mul_ps(rnd2, _mm256_set1_ps(TWOPOW32_INV_FLOAT)); + rnd2 = _mm256_add_ps(rnd2, _mm256_set1_ps(TWOPOW32_INV_FLOAT/2.0f)); + rnd3 = _mm256_mul_ps(rnd3, _mm256_set1_ps(TWOPOW32_INV_FLOAT)); + rnd3 = _mm256_add_ps(rnd3, _mm256_set1_ps(TWOPOW32_INV_FLOAT/2.0f)); + rnd4 = _mm256_mul_ps(rnd4, _mm256_set1_ps(TWOPOW32_INV_FLOAT)); + rnd4 = _mm256_add_ps(rnd4, _mm256_set1_ps(TWOPOW32_INV_FLOAT/2.0f)); +#endif +} + + +QUALIFIERS void philox_double2(__m256i ctr0, __m256i ctr1, __m256i ctr2, __m256i ctr3, + uint32 key0, uint32 key1, + __m256d & rnd1lo, __m256d & rnd1hi, __m256d & rnd2lo, __m256d & rnd2hi) +{ + __m256i key[2] = {_mm256_set1_epi32(key0), _mm256_set1_epi32(key1)}; + __m256i ctr[4] = {ctr0, ctr1, ctr2, ctr3}; + _philox4x32round(ctr, key); // 1 + _philox4x32bumpkey(key); _philox4x32round(ctr, key); // 2 + _philox4x32bumpkey(key); _philox4x32round(ctr, key); // 3 + _philox4x32bumpkey(key); _philox4x32round(ctr, key); // 4 + _philox4x32bumpkey(key); _philox4x32round(ctr, key); // 5 + _philox4x32bumpkey(key); _philox4x32round(ctr, key); // 6 + _philox4x32bumpkey(key); _philox4x32round(ctr, key); // 7 + _philox4x32bumpkey(key); _philox4x32round(ctr, key); // 8 + _philox4x32bumpkey(key); 
_philox4x32round(ctr, key); // 9 + _philox4x32bumpkey(key); _philox4x32round(ctr, key); // 10 + + rnd1lo = _uniform_double_hq<false>(ctr[0], ctr[1]); + rnd1hi = _uniform_double_hq<true>(ctr[0], ctr[1]); + rnd2lo = _uniform_double_hq<false>(ctr[2], ctr[3]); + rnd2hi = _uniform_double_hq<true>(ctr[2], ctr[3]); +} + +QUALIFIERS void philox_float4(uint32 ctr0, __m256i ctr1, uint32 ctr2, uint32 ctr3, + uint32 key0, uint32 key1, + __m256 & rnd1, __m256 & rnd2, __m256 & rnd3, __m256 & rnd4) +{ + __m256i ctr0v = _mm256_set1_epi32(ctr0); + __m256i ctr2v = _mm256_set1_epi32(ctr2); + __m256i ctr3v = _mm256_set1_epi32(ctr3); + + philox_float4(ctr0v, ctr1, ctr2v, ctr3v, key0, key1, rnd1, rnd2, rnd3, rnd4); +} + +QUALIFIERS void philox_double2(uint32 ctr0, __m256i ctr1, uint32 ctr2, uint32 ctr3, + uint32 key0, uint32 key1, + __m256d & rnd1lo, __m256d & rnd1hi, __m256d & rnd2lo, __m256d & rnd2hi) +{ + __m256i ctr0v = _mm256_set1_epi32(ctr0); + __m256i ctr2v = _mm256_set1_epi32(ctr2); + __m256i ctr3v = _mm256_set1_epi32(ctr3); + + philox_double2(ctr0v, ctr1, ctr2v, ctr3v, key0, key1, rnd1lo, rnd1hi, rnd2lo, rnd2hi); +} + +QUALIFIERS void philox_double2(uint32 ctr0, __m256i ctr1, uint32 ctr2, uint32 ctr3, + uint32 key0, uint32 key1, + __m256d & rnd1, __m256d & rnd2) +{ +#if 0 + __m256i ctr0v = _mm256_set1_epi32(ctr0); + __m256i ctr2v = _mm256_set1_epi32(ctr2); + __m256i ctr3v = _mm256_set1_epi32(ctr3); + + __m256d ignore; + philox_double2(ctr0v, ctr1, ctr2v, ctr3v, key0, key1, rnd1, ignore, rnd2, ignore); +#else + __m128d rnd1lo, rnd1hi, rnd2lo, rnd2hi; + philox_double2(ctr0, _mm256_extractf128_si256(ctr1, 0), ctr2, ctr3, key0, key1, rnd1lo, rnd1hi, rnd2lo, rnd2hi); + rnd1 = _my256_set_m128d(rnd1hi, rnd1lo); + rnd2 = _my256_set_m128d(rnd2hi, rnd2lo); +#endif +} +#endif + +#ifdef __AVX512F__ +QUALIFIERS void _philox4x32round(__m512i* ctr, __m512i* key) +{ + __m512i lohi0a = _mm512_mul_epu32(ctr[0], _mm512_set1_epi32(PHILOX_M4x32_0)); + __m512i lohi0b = _mm512_mul_epu32(_mm512_srli_epi64(ctr[0], 32), _mm512_set1_epi32(PHILOX_M4x32_0)); + __m512i lohi1a = _mm512_mul_epu32(ctr[2], _mm512_set1_epi32(PHILOX_M4x32_1)); + __m512i lohi1b = _mm512_mul_epu32(_mm512_srli_epi64(ctr[2], 32), _mm512_set1_epi32(PHILOX_M4x32_1)); + + lohi0a = _mm512_shuffle_epi32(lohi0a, _MM_PERM_DBCA); + lohi0b = _mm512_shuffle_epi32(lohi0b, _MM_PERM_DBCA); + lohi1a = _mm512_shuffle_epi32(lohi1a, _MM_PERM_DBCA); + lohi1b = _mm512_shuffle_epi32(lohi1b, _MM_PERM_DBCA); + + __m512i lo0 = _mm512_unpacklo_epi32(lohi0a, lohi0b); + __m512i hi0 = _mm512_unpackhi_epi32(lohi0a, lohi0b); + __m512i lo1 = _mm512_unpacklo_epi32(lohi1a, lohi1b); + __m512i hi1 = _mm512_unpackhi_epi32(lohi1a, lohi1b); + + ctr[0] = _mm512_xor_si512(_mm512_xor_si512(hi1, ctr[1]), key[0]); + ctr[1] = lo1; + ctr[2] = _mm512_xor_si512(_mm512_xor_si512(hi0, ctr[3]), key[1]); + ctr[3] = lo0; +} + +QUALIFIERS void _philox4x32bumpkey(__m512i* key) +{ + key[0] = _mm512_add_epi32(key[0], _mm512_set1_epi32(PHILOX_W32_0)); + key[1] = _mm512_add_epi32(key[1], _mm512_set1_epi32(PHILOX_W32_1)); +} + +template<bool high> +QUALIFIERS __m512d _uniform_double_hq(__m512i x, __m512i y) +{ + // convert 32 to 64 bit + if (high) + { + x = _mm512_cvtepu32_epi64(_mm512_extracti64x4_epi64(x, 1)); + y = _mm512_cvtepu32_epi64(_mm512_extracti64x4_epi64(y, 1)); + } + else + { + x = _mm512_cvtepu32_epi64(_mm512_extracti64x4_epi64(x, 0)); + y = _mm512_cvtepu32_epi64(_mm512_extracti64x4_epi64(y, 0)); + } + + // calculate z = x ^ y << (53 - 32)) + __m512i z = _mm512_sll_epi64(y, _mm_set1_epi64x(53 - 32)); + 
z = _mm512_xor_si512(x, z); + + // convert uint64 to double + __m512d rs = _mm512_cvtepu64_pd(z); + // calculate rs * TWOPOW53_INV_DOUBLE + (TWOPOW53_INV_DOUBLE/2.0) + rs = _mm512_fmadd_pd(rs, _mm512_set1_pd(TWOPOW53_INV_DOUBLE), _mm512_set1_pd(TWOPOW53_INV_DOUBLE/2.0)); + + return rs; +} + + +QUALIFIERS void philox_float4(__m512i ctr0, __m512i ctr1, __m512i ctr2, __m512i ctr3, + uint32 key0, uint32 key1, + __m512 & rnd1, __m512 & rnd2, __m512 & rnd3, __m512 & rnd4) +{ + __m512i key[2] = {_mm512_set1_epi32(key0), _mm512_set1_epi32(key1)}; + __m512i ctr[4] = {ctr0, ctr1, ctr2, ctr3}; + _philox4x32round(ctr, key); // 1 + _philox4x32bumpkey(key); _philox4x32round(ctr, key); // 2 + _philox4x32bumpkey(key); _philox4x32round(ctr, key); // 3 + _philox4x32bumpkey(key); _philox4x32round(ctr, key); // 4 + _philox4x32bumpkey(key); _philox4x32round(ctr, key); // 5 + _philox4x32bumpkey(key); _philox4x32round(ctr, key); // 6 + _philox4x32bumpkey(key); _philox4x32round(ctr, key); // 7 + _philox4x32bumpkey(key); _philox4x32round(ctr, key); // 8 + _philox4x32bumpkey(key); _philox4x32round(ctr, key); // 9 + _philox4x32bumpkey(key); _philox4x32round(ctr, key); // 10 + + // convert uint32 to float + rnd1 = _mm512_cvtepu32_ps(ctr[0]); + rnd2 = _mm512_cvtepu32_ps(ctr[1]); + rnd3 = _mm512_cvtepu32_ps(ctr[2]); + rnd4 = _mm512_cvtepu32_ps(ctr[3]); + // calculate rnd * TWOPOW32_INV_FLOAT + (TWOPOW32_INV_FLOAT/2.0f) + rnd1 = _mm512_fmadd_ps(rnd1, _mm512_set1_ps(TWOPOW32_INV_FLOAT), _mm512_set1_ps(TWOPOW32_INV_FLOAT/2.0)); + rnd2 = _mm512_fmadd_ps(rnd2, _mm512_set1_ps(TWOPOW32_INV_FLOAT), _mm512_set1_ps(TWOPOW32_INV_FLOAT/2.0)); + rnd3 = _mm512_fmadd_ps(rnd3, _mm512_set1_ps(TWOPOW32_INV_FLOAT), _mm512_set1_ps(TWOPOW32_INV_FLOAT/2.0)); + rnd4 = _mm512_fmadd_ps(rnd4, _mm512_set1_ps(TWOPOW32_INV_FLOAT), _mm512_set1_ps(TWOPOW32_INV_FLOAT/2.0)); +} + + +QUALIFIERS void philox_double2(__m512i ctr0, __m512i ctr1, __m512i ctr2, __m512i ctr3, + uint32 key0, uint32 key1, + __m512d & rnd1lo, __m512d & rnd1hi, __m512d & rnd2lo, __m512d & rnd2hi) +{ + __m512i key[2] = {_mm512_set1_epi32(key0), _mm512_set1_epi32(key1)}; + __m512i ctr[4] = {ctr0, ctr1, ctr2, ctr3}; + _philox4x32round(ctr, key); // 1 + _philox4x32bumpkey(key); _philox4x32round(ctr, key); // 2 + _philox4x32bumpkey(key); _philox4x32round(ctr, key); // 3 + _philox4x32bumpkey(key); _philox4x32round(ctr, key); // 4 + _philox4x32bumpkey(key); _philox4x32round(ctr, key); // 5 + _philox4x32bumpkey(key); _philox4x32round(ctr, key); // 6 + _philox4x32bumpkey(key); _philox4x32round(ctr, key); // 7 + _philox4x32bumpkey(key); _philox4x32round(ctr, key); // 8 + _philox4x32bumpkey(key); _philox4x32round(ctr, key); // 9 + _philox4x32bumpkey(key); _philox4x32round(ctr, key); // 10 + + rnd1lo = _uniform_double_hq<false>(ctr[0], ctr[1]); + rnd1hi = _uniform_double_hq<true>(ctr[0], ctr[1]); + rnd2lo = _uniform_double_hq<false>(ctr[2], ctr[3]); + rnd2hi = _uniform_double_hq<true>(ctr[2], ctr[3]); +} + +QUALIFIERS void philox_float4(uint32 ctr0, __m512i ctr1, uint32 ctr2, uint32 ctr3, + uint32 key0, uint32 key1, + __m512 & rnd1, __m512 & rnd2, __m512 & rnd3, __m512 & rnd4) +{ + __m512i ctr0v = _mm512_set1_epi32(ctr0); + __m512i ctr2v = _mm512_set1_epi32(ctr2); + __m512i ctr3v = _mm512_set1_epi32(ctr3); + + philox_float4(ctr0v, ctr1, ctr2v, ctr3v, key0, key1, rnd1, rnd2, rnd3, rnd4); +} + +QUALIFIERS void philox_double2(uint32 ctr0, __m512i ctr1, uint32 ctr2, uint32 ctr3, + uint32 key0, uint32 key1, + __m512d & rnd1lo, __m512d & rnd1hi, __m512d & rnd2lo, __m512d & rnd2hi) +{ + __m512i 
ctr0v = _mm512_set1_epi32(ctr0); + __m512i ctr2v = _mm512_set1_epi32(ctr2); + __m512i ctr3v = _mm512_set1_epi32(ctr3); + + philox_double2(ctr0v, ctr1, ctr2v, ctr3v, key0, key1, rnd1lo, rnd1hi, rnd2lo, rnd2hi); +} + +QUALIFIERS void philox_double2(uint32 ctr0, __m512i ctr1, uint32 ctr2, uint32 ctr3, + uint32 key0, uint32 key1, + __m512d & rnd1, __m512d & rnd2) +{ +#if 0 + __m512i ctr0v = _mm512_set1_epi32(ctr0); + __m512i ctr2v = _mm512_set1_epi32(ctr2); + __m512i ctr3v = _mm512_set1_epi32(ctr3); + + __m512d ignore; + philox_double2(ctr0v, ctr1, ctr2v, ctr3v, key0, key1, rnd1, ignore, rnd2, ignore); +#else + __m256d rnd1lo, rnd1hi, rnd2lo, rnd2hi; + philox_double2(ctr0, _mm512_extracti64x4_epi64(ctr1, 0), ctr2, ctr3, key0, key1, rnd1lo, rnd1hi, rnd2lo, rnd2hi); + rnd1 = _my512_set_m256d(rnd1hi, rnd1lo); + rnd2 = _my512_set_m256d(rnd2hi, rnd2lo); +#endif +} +#endif +#endif + diff --git a/pystencils/rng.py b/pystencils/rng.py index f567e0c1b77b6e07d9107825364b9227cf17b26a..c1daed1d4ba43167a33650c7bcf80a2b167885b5 100644 --- a/pystencils/rng.py +++ b/pystencils/rng.py @@ -1,30 +1,11 @@ +import copy import numpy as np import sympy as sp -from pystencils import TypedSymbol +from pystencils.data_types import TypedSymbol from pystencils.astnodes import LoopOverCoordinate from pystencils.backends.cbackend import CustomCodeNode - - -def _get_rng_template(name, data_type, num_vars): - if data_type is np.float32: - c_type = "float" - elif data_type is np.float64: - c_type = "double" - template = "\n" - for i in range(num_vars): - template += f"{{result_symbols[{i}].dtype}} {{result_symbols[{i}].name}};\n" - template += ("{}_{}{}({{parameters}}, " + ", ".join(["{{result_symbols[{}].name}}"] * num_vars) + ");\n") \ - .format(name, c_type, num_vars, *tuple(range(num_vars))) - return template - - -def _get_rng_code(template, dialect, vector_instruction_set, args, result_symbols): - if dialect == 'cuda' or (dialect == 'c' and vector_instruction_set is None): - return template.format(parameters=', '.join(str(a) for a in args), - result_symbols=result_symbols) - else: - raise NotImplementedError("Not yet implemented for this backend") +from pystencils.sympyextensions import fast_subs class RNGBase(CustomCodeNode): @@ -50,7 +31,7 @@ class RNGBase(CustomCodeNode): symbols_read = set.union(*[s.atoms(sp.Symbol) for s in self.args]) super().__init__("", symbols_read=symbols_read, symbols_defined=self.result_symbols) - self.headers = [f'"{self._name}_rand.h"'] + self.headers = [f'"{self._name.split("_")[0]}_rand.h"'] RNGBase.id += 1 @@ -58,9 +39,23 @@ class RNGBase(CustomCodeNode): def args(self): return self._args - def get_code(self, dialect, vector_instruction_set): - template = _get_rng_template(self._name, self._data_type, self._num_vars) - return _get_rng_code(template, dialect, vector_instruction_set, self.args, self.result_symbols) + def fast_subs(self, subs_dict, skip): + rng = copy.deepcopy(self) + rng._args = [fast_subs(a, subs_dict, skip) for a in rng._args] + return rng + + def get_code(self, dialect, vector_instruction_set, print_arg): + code = "\n" + for r in self.result_symbols: + if vector_instruction_set and isinstance(self.args[1], sp.Integer): + # this vector RNG has become scalar through substitution + code += f"{r.dtype} {r.name};\n" + else: + code += f"{vector_instruction_set[r.dtype.base_name] if vector_instruction_set else r.dtype} " + \ + f"{r.name};\n" + code += (self._name + "(" + ", ".join([print_arg(a) for a in self.args] + + [r.name for r in self.result_symbols]) + ");\n") + return 
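A detail of the rng.py changes worth spelling out: _name now holds the exact C function that gets called (the class renames follow below), while the include is derived from the family prefix, so the generated kernels keep pulling in the same philox_rand.h / aesni_rand.h headers. A quick sketch of that derivation:

    # Header name derivation used by RNGBase above: function name -> family header
    for name in ("philox_double2", "philox_float4", "aesni_double2", "aesni_float4"):
        header = f'"{name.split("_")[0]}_rand.h"'
        assert header in ('"philox_rand.h"', '"aesni_rand.h"')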
code def __repr__(self): return (", ".join(['{}'] * self._num_vars) + " \\leftarrow {}RNG").format(*self.result_symbols, @@ -68,37 +63,53 @@ class RNGBase(CustomCodeNode): class PhiloxTwoDoubles(RNGBase): - _name = "philox" + _name = "philox_double2" _data_type = np.float64 _num_vars = 2 _num_keys = 2 class PhiloxFourFloats(RNGBase): - _name = "philox" + _name = "philox_float4" _data_type = np.float32 _num_vars = 4 _num_keys = 2 class AESNITwoDoubles(RNGBase): - _name = "aesni" + _name = "aesni_double2" _data_type = np.float64 _num_vars = 2 _num_keys = 4 class AESNIFourFloats(RNGBase): - _name = "aesni" + _name = "aesni_float4" _data_type = np.float32 _num_vars = 4 _num_keys = 4 -def random_symbol(assignment_list, seed=TypedSymbol("seed", np.uint32), rng_node=PhiloxTwoDoubles, *args, **kwargs): +def random_symbol(assignment_list, dim, seed=TypedSymbol("seed", np.uint32), rng_node=PhiloxTwoDoubles, + time_step=TypedSymbol("time_step", np.uint32), offsets=None): + """Return a symbol generator for random numbers + + Args: + assignment_list: the subexpressions member of an AssignmentCollection, into which helper variables assignments + will be inserted + dim: 2 or 3 for two or three spatial dimensions + seed: an integer or TypedSymbol(..., np.uint32) to seed the random number generator. If you create multiple + symbol generators, please pass them different seeds so you don't get the same stream of random numbers! + rng_node: which random number generator to use (PhiloxTwoDoubles, PhiloxFourFloats, AESNITwoDoubles, + AESNIFourFloats). + time_step: TypedSymbol(..., np.uint32) that indicates the number of the current time step + offsets: tuple of offsets (constant integers or TypedSymbol(..., np.uint32)) that give the global coordinates + of the local origin + """ counter = 0 while True: - node = rng_node(*args, keys=(counter, seed), **kwargs) + keys = (counter, seed) + (0,) * (rng_node._num_keys - 2) + node = rng_node(dim, keys=keys, time_step=time_step, offsets=offsets) inserted = False for symbol in node.result_symbols: if not inserted: diff --git a/pystencils_tests/test_interpolation.py b/pystencils_tests/test_interpolation.py index 19201c7c9016be156f2ed73828eb65f1b9d6d229..257d4af802995dd43e1ca29651f31c0acd5d85cf 100644 --- a/pystencils_tests/test_interpolation.py +++ b/pystencils_tests/test_interpolation.py @@ -77,14 +77,7 @@ def test_scale_interpolation(): pyconrad.imshow(out, "out " + address_mode) -@pytest.mark.parametrize('address_mode', - ['border', - 'clamp', - pytest.param('warp', marks=pytest.mark.xfail( - reason="requires interpolation-refactoring branch")), - pytest.param('mirror', marks=pytest.mark.xfail( - reason="requires interpolation-refactoring branch")), - ]) +@pytest.mark.parametrize('address_mode', ['border', 'clamp']) def test_rotate_interpolation(address_mode): """ 'wrap', 'mirror' currently fails on new sympy due to conjugate() @@ -144,18 +137,13 @@ def test_rotate_interpolation_gpu(dtype, address_mode, use_textures): f"out {address_mode} texture:{use_textures} {type_map[dtype]}") -@pytest.mark.parametrize('address_mode', ['border', 'wrap', - pytest.param('warp', marks=pytest.mark.xfail( - reason="% printed as fmod on old sympy")), - pytest.param('mirror', marks=pytest.mark.xfail( - reason="% printed as fmod on old sympy")), - ]) +@pytest.mark.parametrize('address_mode', ['border', 'wrap', 'mirror']) @pytest.mark.parametrize('dtype', [np.float64, np.float32, np.int32]) @pytest.mark.parametrize('use_textures', ('use_textures', False,)) def 
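To complement the docstring, here is a minimal usage sketch for random_symbol. It follows the pattern of test_rng_symbol and test_staggered further below; the field and value choices are illustrative only, and the defaults add 'seed' and 'time_step' as kernel parameters.

    # Minimal sketch: draw two Philox doubles per cell and store them in a field
    import pystencils as ps
    from pystencils.rng import random_symbol

    dh = ps.create_data_handling((8, 8), default_ghost_layers=0, default_target="cpu")
    f = dh.add_array("f", values_per_cell=2)
    ac = ps.AssignmentCollection([ps.Assignment(f(i), 0) for i in range(2)])
    rng_gen = random_symbol(ac.subexpressions, dim=dh.dim)   # PhiloxTwoDoubles by default
    for i in range(2):
        ac.main_assignments[i] = ps.Assignment(ac.main_assignments[i].lhs, next(rng_gen))
    kernel = ps.create_kernel(ac, target=dh.default_target).compile()
    dh.run_kernel(kernel, seed=42, time_step=0)              # both symbols became kernel parameters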
test_shift_interpolation_gpu(address_mode, dtype, use_textures): sver = sympy.__version__.split(".") - if (int(sver[0]) == 1 and int(sver[1]) < 2) and address_mode in ['mirror', 'warp']: - pytest.skip() + if (int(sver[0]) == 1 and int(sver[1]) < 2) and address_mode == 'mirror': + pytest.skip("% printed as fmod on old sympy") pytest.importorskip('pycuda') import pycuda.gpuarray as gpuarray diff --git a/pystencils_tests/test_random.py b/pystencils_tests/test_random.py index 322718d1b786d28554863fdd94df9cbbd2ff02fa..cf1de7c27f797b8905e8c03eb4a334d1fecebcda 100644 --- a/pystencils_tests/test_random.py +++ b/pystencils_tests/test_random.py @@ -4,108 +4,197 @@ import pytest import pystencils as ps from pystencils.rng import PhiloxFourFloats, PhiloxTwoDoubles, AESNIFourFloats, AESNITwoDoubles, random_symbol - - -# curand_Philox4x32_10(make_uint4(124, i, j, 0), make_uint2(0, 0)) -philox_reference = np.array([[[3576608082, 1252663339, 1987745383, 348040302], - [1032407765, 970978240, 2217005168, 2424826293]], - [[2958765206, 3725192638, 2623672781, 1373196132], - [ 850605163, 1694561295, 3285694973, 2799652583]]]) - -@pytest.mark.parametrize('target', ('cpu', 'gpu')) -def test_philox_double(target): +from pystencils.backends.simd_instruction_sets import get_supported_instruction_sets +from pystencils.cpu.cpujit import get_compiler_config +from pystencils.data_types import TypedSymbol + +RNGs = {('philox', 'float'): PhiloxFourFloats, ('philox', 'double'): PhiloxTwoDoubles, + ('aesni', 'float'): AESNIFourFloats, ('aesni', 'double'): AESNITwoDoubles} + +instruction_sets = get_supported_instruction_sets() +if get_compiler_config()['os'] == 'windows': + # skip instruction sets supported by CPU but not the compiler + if 'avx' in instruction_sets and ('/arch:avx2' not in get_compiler_config()['flags'].lower() + or '/arch:avx512' not in get_compiler_config()['flags'].lower()): + instruction_sets.remove('avx') + if 'avx512' in instruction_sets and '/arch:avx512' not in get_compiler_config()['flags'].lower(): + instruction_sets.remove('avx512') + + +@pytest.mark.parametrize('target,rng', (('cpu', 'philox'), ('cpu', 'aesni'), ('gpu', 'philox'))) +@pytest.mark.parametrize('precision', ('float', 'double')) +@pytest.mark.parametrize('dtype', ('float', 'double')) +def test_rng(target, rng, precision, dtype, t=124, offsets=(0, 0), keys=(0, 0), offset_values=None): if target == 'gpu': pytest.importorskip('pycuda') + if rng == 'aesni' and len(keys) == 2: + keys *= 2 + if offset_values is None: + offset_values = offsets dh = ps.create_data_handling((2, 2), default_ghost_layers=0, default_target=target) - f = dh.add_array("f", values_per_cell=2) + f = dh.add_array("f", values_per_cell=4 if precision == 'float' else 2, + dtype=np.float32 if dtype == 'float' else np.float64) + dh.fill(f.name, 42.0) - dh.fill('f', 42.0) - - philox_node = PhiloxTwoDoubles(dh.dim) - assignments = [philox_node, - ps.Assignment(f(0), philox_node.result_symbols[0]), - ps.Assignment(f(1), philox_node.result_symbols[1])] + rng_node = RNGs[(rng, precision)](dh.dim, offsets=offsets, keys=keys) + assignments = [rng_node] + [ps.Assignment(f(i), s) for i, s in enumerate(rng_node.result_symbols)] kernel = ps.create_kernel(assignments, target=dh.default_target).compile() dh.all_to_gpu() - dh.run_kernel(kernel, time_step=124) + kwargs = {'time_step': t} + if offset_values != offsets: + kwargs.update({k.name: v for k, v in zip(offsets, offset_values)}) + dh.run_kernel(kernel, **kwargs) dh.all_to_cpu() - - arr = dh.gather_array('f') + arr = 
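The randomgen cross-check that follows packs the per-cell Philox state into wide integers; the packing is plain base-2**32 arithmetic, with the time step in counter word 0, the global cell coordinates in words 1 and 2, and the two key words combined into one integer.

    # How the reference loop below assembles counter and key for a cell at (x, y)
    def philox_counter_key(t, x, y, offsets, keys):
        counter = t + (x + offsets[0]) * 2**32 + (y + offsets[1]) * 2**64
        key = keys[0] + keys[1] * 2**32
        return counter, key

    assert philox_counter_key(124, 1, 0, (0, 0), (0, 0)) == (124 + 2**32, 0)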
dh.gather_array(f.name) assert np.logical_and(arr <= 1.0, arr >= 0).all() - x = philox_reference[:,:,0::2] - y = philox_reference[:,:,1::2] - z = x ^ y << (53 - 32) - double_reference = z * 2.**-53 + 2.**-54 - assert(np.allclose(arr, double_reference, rtol=0, atol=np.finfo(np.float64).eps)) - - -@pytest.mark.parametrize('target', ('cpu', 'gpu')) -def test_philox_float(target): - if target == 'gpu': - pytest.importorskip('pycuda') - - dh = ps.create_data_handling((2, 2), default_ghost_layers=0, default_target=target) - f = dh.add_array("f", values_per_cell=4) - - dh.fill('f', 42.0) - - philox_node = PhiloxFourFloats(dh.dim) - assignments = [philox_node] + [ps.Assignment(f(i), philox_node.result_symbols[i]) for i in range(4)] + if rng == 'philox' and t == 124 and offsets == (0, 0) and keys == (0, 0) and dh.shape == (2, 2): + int_reference = np.array([[[3576608082, 1252663339, 1987745383, 348040302], + [1032407765, 970978240, 2217005168, 2424826293]], + [[2958765206, 3725192638, 2623672781, 1373196132], + [850605163, 1694561295, 3285694973, 2799652583]]]) + else: + pytest.importorskip('randomgen') + if rng == 'aesni': + from randomgen import AESCounter + int_reference = np.empty(dh.shape + (4,), dtype=int) + for x in range(dh.shape[0]): + for y in range(dh.shape[1]): + r = AESCounter(counter=t + (x + offset_values[0]) * 2 ** 32 + (y + offset_values[1]) * 2 ** 64, + key=keys[0] + keys[1] * 2 ** 32 + keys[2] * 2 ** 64 + keys[3] * 2 ** 96, + mode="sequence") + a, b = r.random_raw(size=2) + int_reference[x, y, :] = [a % 2 ** 32, a // 2 ** 32, b % 2 ** 32, b // 2 ** 32] + else: + from randomgen import Philox + int_reference = np.empty(dh.shape + (4,), dtype=int) + for x in range(dh.shape[0]): + for y in range(dh.shape[1]): + r = Philox(counter=t + (x + offset_values[0]) * 2 ** 32 + (y + offset_values[1]) * 2 ** 64, + key=keys[0] + keys[1] * 2 ** 32, number=4, width=32) + r.advance(-4, counter=False) + int_reference[x, y, :] = r.random_raw(size=4) + + if precision == 'float' or dtype == 'float': + eps = np.finfo(np.float32).eps + else: + eps = np.finfo(np.float64).eps + if rng == 'aesni': # precision appears to be slightly worse + eps = max(1e-12, 2 * eps) + + if precision == 'float': + reference = int_reference * 2. ** -32 + 2. ** -33 + else: + x = int_reference[:, :, 0::2] + y = int_reference[:, :, 1::2] + z = x ^ y << (53 - 32) + reference = z * 2. ** -53 + 2. 
** -54 + assert np.allclose(arr, reference, rtol=0, atol=eps) + + +@pytest.mark.parametrize('vectorized', (False, True)) +@pytest.mark.parametrize('kind', ('value', 'symbol')) +def test_rng_offsets(kind, vectorized): + if vectorized: + test = test_rng_vectorized + if not instruction_sets: + pytest.skip("cannot detect CPU instruction set") + else: + test = test_rng + if kind == 'value': + test(instruction_sets[0] if vectorized else 'cpu', 'philox', 'float', 'float', t=8, + offsets=(6, 7), keys=(5, 309)) + elif kind == 'symbol': + offsets = (TypedSymbol("x0", np.uint32), TypedSymbol("y0", np.uint32)) + test(instruction_sets[0] if vectorized else 'cpu', 'philox', 'float', 'float', t=8, + offsets=offsets, offset_values=(6, 7), keys=(5, 309)) + + +@pytest.mark.parametrize('target', instruction_sets) +@pytest.mark.parametrize('rng', ('philox', 'aesni')) +@pytest.mark.parametrize('precision,dtype', (('float', 'float'), ('double', 'double'))) +def test_rng_vectorized(target, rng, precision, dtype, t=130, offsets=(1, 3), keys=(0, 0), offset_values=None): + cpu_vectorize_info = {'assume_inner_stride_one': True, 'assume_aligned': True, 'instruction_set': target} + + dh = ps.create_data_handling((17, 17), default_ghost_layers=0, default_target='cpu') + f = dh.add_array("f", values_per_cell=4 if precision == 'float' else 2, + dtype=np.float32 if dtype == 'float' else np.float64, alignment=True) + dh.fill(f.name, 42.0) + ref = dh.add_array("ref", values_per_cell=4 if precision == 'float' else 2) + + rng_node = RNGs[(rng, precision)](dh.dim, offsets=offsets) + assignments = [rng_node] + [ps.Assignment(ref(i), s) for i, s in enumerate(rng_node.result_symbols)] kernel = ps.create_kernel(assignments, target=dh.default_target).compile() - dh.all_to_gpu() - dh.run_kernel(kernel, time_step=124) - dh.all_to_cpu() - arr = dh.gather_array('f') - assert np.logical_and(arr <= 1.0, arr >= 0).all() + kwargs = {'time_step': t} + if offset_values is not None: + kwargs.update({k.name: v for k, v in zip(offsets, offset_values)}) + dh.run_kernel(kernel, **kwargs) - float_reference = philox_reference * 2.**-32 + 2.**-33 - assert(np.allclose(arr, float_reference, rtol=0, atol=np.finfo(np.float32).eps)) + rng_node = RNGs[(rng, precision)](dh.dim, offsets=offsets) + assignments = [rng_node] + [ps.Assignment(f(i), s) for i, s in enumerate(rng_node.result_symbols)] + kernel = ps.create_kernel(assignments, target=dh.default_target, cpu_vectorize_info=cpu_vectorize_info).compile() -def test_aesni_double(): - dh = ps.create_data_handling((2, 2), default_ghost_layers=0, default_target="cpu") - f = dh.add_array("f", values_per_cell=2) + dh.run_kernel(kernel, **kwargs) - dh.fill('f', 42.0) + ref_data = dh.gather_array(ref.name) + data = dh.gather_array(f.name) - aesni_node = AESNITwoDoubles(dh.dim) - assignments = [aesni_node, - ps.Assignment(f(0), aesni_node.result_symbols[0]), - ps.Assignment(f(1), aesni_node.result_symbols[1])] - kernel = ps.create_kernel(assignments, target=dh.default_target).compile() + assert np.allclose(ref_data, data) - dh.all_to_gpu() - dh.run_kernel(kernel, time_step=124) - dh.all_to_cpu() - - arr = dh.gather_array('f') - assert np.logical_and(arr <= 1.0, arr >= 0).all() - - -def test_aesni_float(): - dh = ps.create_data_handling((2, 2), default_ghost_layers=0, default_target="cpu") - f = dh.add_array("f", values_per_cell=4) - - dh.fill('f', 42.0) - - aesni_node = AESNIFourFloats(dh.dim) - assignments = [aesni_node] + [ps.Assignment(f(i), aesni_node.result_symbols[i]) for i in range(4)] - kernel = 
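In test_rng_offsets, symbolic offsets become ordinary kernel parameters that are passed by name at run time; the kwargs.update lines above rely on exactly this. A small self-contained sketch:

    # Symbolic RNG offsets are addressed by their symbol names when running the kernel
    import numpy as np
    from pystencils.data_types import TypedSymbol

    offsets = (TypedSymbol("x0", np.uint32), TypedSymbol("y0", np.uint32))
    offset_values = (6, 7)
    kwargs = {'time_step': 8}
    kwargs.update({k.name: v for k, v in zip(offsets, offset_values)})
    assert kwargs == {'time_step': 8, 'x0': 6, 'y0': 7}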
ps.create_kernel(assignments, target=dh.default_target).compile() - - dh.all_to_gpu() - dh.run_kernel(kernel, time_step=124) - dh.all_to_cpu() - arr = dh.gather_array('f') - assert np.logical_and(arr <= 1.0, arr >= 0).all() -def test_staggered(): +@pytest.mark.parametrize('vectorized', (False, True)) +def test_rng_symbol(vectorized): + """Make sure that the RNG symbol generator generates symbols and that the resulting code compiles""" + if vectorized: + if not instruction_sets: + pytest.skip("cannot detect CPU instruction set") + else: + cpu_vectorize_info = {'assume_inner_stride_one': True, 'assume_aligned': True, + 'instruction_set': instruction_sets[0]} + else: + cpu_vectorize_info = None + + dh = ps.create_data_handling((8, 8), default_ghost_layers=0, default_target="cpu") + f = dh.add_array("f", values_per_cell=2 * dh.dim, alignment=True) + ac = ps.AssignmentCollection([ps.Assignment(f(i), 0) for i in range(f.shape[-1])]) + rng_symbol_gen = random_symbol(ac.subexpressions, dim=dh.dim) + for i in range(f.shape[-1]): + ac.main_assignments[i] = ps.Assignment(ac.main_assignments[i].lhs, next(rng_symbol_gen)) + symbols = [a.rhs for a in ac.main_assignments] + assert len(symbols) == f.shape[-1] and len(set(symbols)) == f.shape[-1] + ps.create_kernel(ac, target=dh.default_target, cpu_vectorize_info=cpu_vectorize_info).compile() + + +@pytest.mark.parametrize('vectorized', (False, True)) +def test_staggered(vectorized): """Make sure that the RNG counter can be substituted during loop cutting""" + dh = ps.create_data_handling((8, 8), default_ghost_layers=0, default_target="cpu") j = dh.add_array("j", values_per_cell=dh.dim, field_type=ps.FieldType.STAGGERED_FLUX) a = ps.AssignmentCollection([ps.Assignment(j.staggered_access(n), 0) for n in j.staggered_stencil]) - rng_symbol_gen = random_symbol(a.subexpressions, dim=dh.dim) + rng_symbol_gen = random_symbol(a.subexpressions, dim=dh.dim, rng_node=AESNITwoDoubles) a.main_assignments[0] = ps.Assignment(a.main_assignments[0].lhs, next(rng_symbol_gen)) kernel = ps.create_staggered_kernel(a, target=dh.default_target).compile() + + if not vectorized: + return + if not instruction_sets: + pytest.skip("cannot detect CPU instruction set") + pytest.importorskip('islpy') + cpu_vectorize_info = {'assume_inner_stride_one': True, 'assume_aligned': False, + 'instruction_set': instruction_sets[0]} + + dh.fill(j.name, 867) + dh.run_kernel(kernel, seed=5, time_step=309) + ref_data = dh.gather_array(j.name) + + kernel2 = ps.create_staggered_kernel(a, target=dh.default_target, cpu_vectorize_info=cpu_vectorize_info).compile() + + dh.fill(j.name, 867) + dh.run_kernel(kernel2, seed=5, time_step=309) + data = dh.gather_array(j.name) + + assert np.allclose(ref_data, data) diff --git a/setup.py b/setup.py index bce6eed2ddba02c6af29a01d6bb13efb3ad28cdf..0b916e876a253ffa72a6fe229c65573f2b50beb6 100644 --- a/setup.py +++ b/setup.py @@ -124,7 +124,8 @@ setuptools.setup(name='pystencils', 'flake8', 'nbformat', 'nbconvert', - 'ipython'], + 'ipython', + 'randomgen>=1.18'], python_requires=">=3.6", cmdclass={
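The new randomgen>=1.18 test dependency supplies the reference streams. The call pattern below is copied from test_random.py above; it is shown only as a usage example, and the exact raw values depend on randomgen's Philox4x32-10 implementation.

    # Draw the four raw 32-bit reference words for counter word 0 = 124, key = 0
    from randomgen import Philox

    r = Philox(counter=124, key=0, number=4, width=32)
    r.advance(-4, counter=False)   # same adjustment the test applies before drawing raw words
    print(r.random_raw(size=4))    # four raw 32-bit words; the test scales these before comparing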