diff --git a/pystencils/backends/arm_instruction_sets.py b/pystencils/backends/arm_instruction_sets.py
index 6c388f3e4cb798297f1902b06ef82b0909d191a5..26f61e909ee115c5b9e081877a232fdb082d2b7e 100644
--- a/pystencils/backends/arm_instruction_sets.py
+++ b/pystencils/backends/arm_instruction_sets.py
@@ -42,12 +42,12 @@ def get_vector_instruction_set_arm(data_type='double', instruction_set='neon', q
     if q_registers is True:
         q_reg = 'q'
         width = 128 // bits[data_type]
-        intwidth = 128 // bits[data_type]
+        intwidth = 128 // bits['int']
         suffix = f'q_f{bits[data_type]}'
     else:
         q_reg = ''
         width = 64 // bits[data_type]
-        intwidth = 64 // bits[data_type]
+        intwidth = 64 // bits['int']
         suffix = f'_f{bits[data_type]}'
 
     result = dict()
@@ -61,9 +61,10 @@ def get_vector_instruction_set_arm(data_type='double', instruction_set='neon', q
         result[intrinsic_id] = 'v' + name + suffix + arg_string
 
     result['makeVecConst'] = f'vdup{q_reg}_n_f{bits[data_type]}' + '({0})'
-    result['makeVec'] = f'vdup{q_reg}_n_f{bits[data_type]}' + '({0})'
+    result['makeVec'] = f'makeVec_f{bits[data_type]}' + '(' + ", ".join(['{' + str(i) + '}' for i in range(width)]) + \
+        ')'
     result['makeVecConstInt'] = f'vdup{q_reg}_n_s{bits["int"]}' + '({0})'
-    result['makeVecInt'] = f'vdup{q_reg}_n_s{bits["int"]}' + '({0})'
+    result['makeVecInt'] = f'makeVec_s{bits["int"]}' + '({0}, {1}, {2}, {3})'
 
     result['+int'] = f"vaddq_s{bits['int']}" + "({0}, {1})"
 
@@ -74,7 +75,7 @@ def get_vector_instruction_set_arm(data_type='double', instruction_set='neon', q
     result[data_type] = f'float{bits[data_type]}x{width}_t'
     result['int'] = f'int{bits["int"]}x{bits[data_type]}_t'
     result['bool'] = f'uint{bits[data_type]}x{width}_t'
-    result['headers'] = ['<arm_neon.h>']
+    result['headers'] = ['<arm_neon.h>', '"arm_neon_helpers.h"']
 
     result['!='] = f'vmvn{q_reg}_u{bits[data_type]}({result["=="]})'
 
diff --git a/pystencils/cpu/vectorization.py b/pystencils/cpu/vectorization.py
index 0af12adaf81584c93bff037705c1edcf67bbab93..a7d2b76d8981732da019b2bc9f9acfc52a104d2f 100644
--- a/pystencils/cpu/vectorization.py
+++ b/pystencils/cpu/vectorization.py
@@ -155,9 +155,8 @@ def vectorize_inner_loops_and_adapt_load_stores(ast_node, vector_width, assume_a
         loop_node.step = vector_width
         loop_node.subs(substitutions)
         vector_int_width = ast_node.instruction_set['intwidth']
-        vector_loop_counter = cast_func((loop_counter_symbol,) * vector_int_width,
-                                        VectorType(loop_counter_symbol.dtype, vector_int_width)) + \
-            cast_func(tuple(range(vector_int_width)), VectorType(loop_counter_symbol.dtype, vector_int_width))
+        vector_loop_counter = cast_func(loop_counter_symbol, VectorType(loop_counter_symbol.dtype, vector_int_width)) \
+            + cast_func(tuple(range(vector_int_width)), VectorType(loop_counter_symbol.dtype, vector_int_width))
 
         fast_subs(loop_node, {loop_counter_symbol: vector_loop_counter},
                   skip=lambda e: isinstance(e, ast.ResolvedFieldAccess) or isinstance(e, vector_memory_access))
diff --git a/pystencils/include/arm_neon_helpers.h b/pystencils/include/arm_neon_helpers.h
new file mode 100644
index 0000000000000000000000000000000000000000..ba6cbc2d7bae45591bcec580b98394c4f6830339
--- /dev/null
+++ b/pystencils/include/arm_neon_helpers.h
@@ -0,0 +1,19 @@
+#include <arm_neon.h>
+
+inline float32x4_t makeVec_f32(float a, float b, float c, float d)
+{
+    alignas(16) float data[4] = {a, b, c, d};
+    return vld1q_f32(data);
+}
+
+inline float64x2_t makeVec_f64(double a, double b)
+{
+    alignas(16) double data[2] = {a, b};
+    return vld1q_f64(data);
+}
+
+inline int32x4_t makeVec_s32(int a, int b, int c, int d)
+{
+    alignas(16) int data[4] = {a, b, c, d};
+    return vld1q_s32(data);
+}
diff --git a/pystencils/include/philox_rand.h b/pystencils/include/philox_rand.h
index 2950717738d93f19410dfcc17786a687b8ee8226..71da35715d3fb5ef9d24329cde16aa82d28a9070 100644
--- a/pystencils/include/philox_rand.h
+++ b/pystencils/include/philox_rand.h
@@ -12,6 +12,10 @@
 #endif
 #endif
 
+#ifdef __ARM_NEON
+#include <arm_neon.h>
+#endif
+
 #ifndef __CUDA_ARCH__
 #define QUALIFIERS inline
 #include "myintrin.h"
@@ -277,6 +281,161 @@ QUALIFIERS void philox_double2(uint32 ctr0, __m128i ctr1, uint32 ctr2, uint32 ct
 }
 #endif
 
+#if defined(__ARM_NEON)
+QUALIFIERS void _philox4x32round(uint32x4_t* ctr, uint32x4_t* key)
+{
+    uint32x4_t lohi0a = vreinterpretq_u32_u64(vmull_u32(vget_low_u32(ctr[0]), vdup_n_u32(PHILOX_M4x32_0)));
+    uint32x4_t lohi0b = vreinterpretq_u32_u64(vmull_high_u32(ctr[0], vdupq_n_u32(PHILOX_M4x32_0)));
+    uint32x4_t lohi1a = vreinterpretq_u32_u64(vmull_u32(vget_low_u32(ctr[2]), vdup_n_u32(PHILOX_M4x32_1)));
+    uint32x4_t lohi1b = vreinterpretq_u32_u64(vmull_high_u32(ctr[2], vdupq_n_u32(PHILOX_M4x32_1)));
+
+    uint32x4_t lo0 = vuzp1q_u32(lohi0a, lohi0b);
+    uint32x4_t lo1 = vuzp1q_u32(lohi1a, lohi1b);
+    uint32x4_t hi0 = vuzp2q_u32(lohi0a, lohi0b);
+    uint32x4_t hi1 = vuzp2q_u32(lohi1a, lohi1b);
+
+    ctr[0] = veorq_u32(veorq_u32(hi1, ctr[1]), key[0]);
+    ctr[1] = lo1;
+    ctr[2] = veorq_u32(veorq_u32(hi0, ctr[3]), key[1]);
+    ctr[3] = lo0;
+}
+
+QUALIFIERS void _philox4x32bumpkey(uint32x4_t* key)
+{
+    key[0] = vaddq_u32(key[0], vdupq_n_u32(PHILOX_W32_0));
+    key[1] = vaddq_u32(key[1], vdupq_n_u32(PHILOX_W32_1));
+}
+
+template<bool high>
+QUALIFIERS float64x2_t _uniform_double_hq(uint32x4_t x, uint32x4_t y)
+{
+    // convert 32 to 64 bit
+    if (high)
+    {
+        x = vzip2q_u32(x, vdupq_n_u32(0));
+        y = vzip2q_u32(y, vdupq_n_u32(0));
+    }
+    else
+    {
+        x = vzip1q_u32(x, vdupq_n_u32(0));
+        y = vzip1q_u32(y, vdupq_n_u32(0));
+    }
+
+    // calculate z = x ^ y << (53 - 32))
+    uint64x2_t z = vshlq_n_u64(vreinterpretq_u64_u32(y), 53 - 32);
+    z = veorq_u64(vreinterpretq_u64_u32(x), z);
+
+    // convert uint64 to double
+    float64x2_t rs = vcvtq_f64_u64(z);
+    // calculate rs * TWOPOW53_INV_DOUBLE + (TWOPOW53_INV_DOUBLE/2.0)
+    rs = vfmaq_f64(vdupq_n_f64(TWOPOW53_INV_DOUBLE/2.0), vdupq_n_f64(TWOPOW53_INV_DOUBLE), rs);
+
+    return rs;
+}
+
+
+QUALIFIERS void philox_float4(uint32x4_t ctr0, uint32x4_t ctr1, uint32x4_t ctr2, uint32x4_t ctr3,
+                              uint32 key0, uint32 key1,
+                              float32x4_t & rnd1, float32x4_t & rnd2, float32x4_t & rnd3, float32x4_t & rnd4)
+{
+    uint32x4_t key[2] = {vdupq_n_u32(key0), vdupq_n_u32(key1)};
+    uint32x4_t ctr[4] = {ctr0, ctr1, ctr2, ctr3};
+    _philox4x32round(ctr, key);                           // 1
+    _philox4x32bumpkey(key); _philox4x32round(ctr, key);  // 2
+    _philox4x32bumpkey(key); _philox4x32round(ctr, key);  // 3
+    _philox4x32bumpkey(key); _philox4x32round(ctr, key);  // 4
+    _philox4x32bumpkey(key); _philox4x32round(ctr, key);  // 5
+    _philox4x32bumpkey(key); _philox4x32round(ctr, key);  // 6
+    _philox4x32bumpkey(key); _philox4x32round(ctr, key);  // 7
+    _philox4x32bumpkey(key); _philox4x32round(ctr, key);  // 8
+    _philox4x32bumpkey(key); _philox4x32round(ctr, key);  // 9
+    _philox4x32bumpkey(key); _philox4x32round(ctr, key);  // 10
+
+    // convert uint32 to float
+    rnd1 = vcvtq_f32_u32(ctr[0]);
+    rnd2 = vcvtq_f32_u32(ctr[1]);
+    rnd3 = vcvtq_f32_u32(ctr[2]);
+    rnd4 = vcvtq_f32_u32(ctr[3]);
+    // calculate rnd * TWOPOW32_INV_FLOAT + (TWOPOW32_INV_FLOAT/2.0f)
+    rnd1 = vfmaq_f32(vdupq_n_f32(TWOPOW32_INV_FLOAT/2.0), vdupq_n_f32(TWOPOW32_INV_FLOAT), rnd1);
+    rnd2 = vfmaq_f32(vdupq_n_f32(TWOPOW32_INV_FLOAT/2.0), vdupq_n_f32(TWOPOW32_INV_FLOAT), rnd2);
+    rnd3 = vfmaq_f32(vdupq_n_f32(TWOPOW32_INV_FLOAT/2.0), vdupq_n_f32(TWOPOW32_INV_FLOAT), rnd3);
+    rnd4 = vfmaq_f32(vdupq_n_f32(TWOPOW32_INV_FLOAT/2.0), vdupq_n_f32(TWOPOW32_INV_FLOAT), rnd4);
+}
+
+
+QUALIFIERS void philox_double2(uint32x4_t ctr0, uint32x4_t ctr1, uint32x4_t ctr2, uint32x4_t ctr3,
+                               uint32 key0, uint32 key1,
+                               float64x2_t & rnd1lo, float64x2_t & rnd1hi, float64x2_t & rnd2lo, float64x2_t & rnd2hi)
+{
+    uint32x4_t key[2] = {vdupq_n_u32(key0), vdupq_n_u32(key1)};
+    uint32x4_t ctr[4] = {ctr0, ctr1, ctr2, ctr3};
+    _philox4x32round(ctr, key);                           // 1
+    _philox4x32bumpkey(key); _philox4x32round(ctr, key);  // 2
+    _philox4x32bumpkey(key); _philox4x32round(ctr, key);  // 3
+    _philox4x32bumpkey(key); _philox4x32round(ctr, key);  // 4
+    _philox4x32bumpkey(key); _philox4x32round(ctr, key);  // 5
+    _philox4x32bumpkey(key); _philox4x32round(ctr, key);  // 6
+    _philox4x32bumpkey(key); _philox4x32round(ctr, key);  // 7
+    _philox4x32bumpkey(key); _philox4x32round(ctr, key);  // 8
+    _philox4x32bumpkey(key); _philox4x32round(ctr, key);  // 9
+    _philox4x32bumpkey(key); _philox4x32round(ctr, key);  // 10
+
+    rnd1lo = _uniform_double_hq<false>(ctr[0], ctr[1]);
+    rnd1hi = _uniform_double_hq<true>(ctr[0], ctr[1]);
+    rnd2lo = _uniform_double_hq<false>(ctr[2], ctr[3]);
+    rnd2hi = _uniform_double_hq<true>(ctr[2], ctr[3]);
+}
+
+QUALIFIERS void philox_float4(uint32 ctr0, uint32x4_t ctr1, uint32 ctr2, uint32 ctr3,
+                              uint32 key0, uint32 key1,
+                              float32x4_t & rnd1, float32x4_t & rnd2, float32x4_t & rnd3, float32x4_t & rnd4)
+{
+    uint32x4_t ctr0v = vdupq_n_u32(ctr0);
+    uint32x4_t ctr2v = vdupq_n_u32(ctr2);
+    uint32x4_t ctr3v = vdupq_n_u32(ctr3);
+
+    philox_float4(ctr0v, ctr1, ctr2v, ctr3v, key0, key1, rnd1, rnd2, rnd3, rnd4);
+}
+
+QUALIFIERS void philox_float4(uint32 ctr0, int32x4_t ctr1, uint32 ctr2, uint32 ctr3,
+                              uint32 key0, uint32 key1,
+                              float32x4_t & rnd1, float32x4_t & rnd2, float32x4_t & rnd3, float32x4_t & rnd4)
+{
+    philox_float4(ctr0, vreinterpretq_u32_s32(ctr1), ctr2, ctr3, key0, key1, rnd1, rnd2, rnd3, rnd4);
+}
+
+QUALIFIERS void philox_double2(uint32 ctr0, uint32x4_t ctr1, uint32 ctr2, uint32 ctr3,
+                               uint32 key0, uint32 key1,
+                               float64x2_t & rnd1lo, float64x2_t & rnd1hi, float64x2_t & rnd2lo, float64x2_t & rnd2hi)
+{
+    uint32x4_t ctr0v = vdupq_n_u32(ctr0);
+    uint32x4_t ctr2v = vdupq_n_u32(ctr2);
+    uint32x4_t ctr3v = vdupq_n_u32(ctr3);
+
+    philox_double2(ctr0v, ctr1, ctr2v, ctr3v, key0, key1, rnd1lo, rnd1hi, rnd2lo, rnd2hi);
+}
+
+QUALIFIERS void philox_double2(uint32 ctr0, uint32x4_t ctr1, uint32 ctr2, uint32 ctr3,
+                               uint32 key0, uint32 key1,
+                               float64x2_t & rnd1, float64x2_t & rnd2)
+{
+    uint32x4_t ctr0v = vdupq_n_u32(ctr0);
+    uint32x4_t ctr2v = vdupq_n_u32(ctr2);
+    uint32x4_t ctr3v = vdupq_n_u32(ctr3);
+
+    float64x2_t ignore;
+    philox_double2(ctr0v, ctr1, ctr2v, ctr3v, key0, key1, rnd1, ignore, rnd2, ignore);
+}
+
+QUALIFIERS void philox_double2(uint32 ctr0, int32x4_t ctr1, uint32 ctr2, uint32 ctr3,
+                               uint32 key0, uint32 key1,
+                               float64x2_t & rnd1, float64x2_t & rnd2)
+{
+    philox_double2(ctr0, vreinterpretq_u32_s32(ctr1), ctr2, ctr3, key0, key1, rnd1, rnd2);
+}
+#endif
+
 #ifdef __AVX2__
 QUALIFIERS void _philox4x32round(__m256i* ctr, __m256i* key)
 {
diff --git a/pystencils/rng.py b/pystencils/rng.py
index fed90aceff97bb94e82ae8ca054280c0140f203d..f5a970e96c8539487515afa53705cf6cab280961 100644
--- a/pystencils/rng.py
+++ b/pystencils/rng.py
@@ -58,8 +58,8 @@
         return code
 
     def __repr__(self):
-        return (", ".join(['{}'] * self._num_vars) + " \\leftarrow {}RNG").format(*self.result_symbols,
-                                                                                   self._name.capitalize())
+        return ", ".join([str(s) for s in self.result_symbols]) + " \\leftarrow " + \
+            self._name.capitalize() + "_RNG(" + ", ".join([str(a) for a in self.args]) + ")"
 
 
 class PhiloxTwoDoubles(RNGBase):
diff --git a/pystencils_tests/test_random.py b/pystencils_tests/test_random.py
index 30b55a66be082454e48de57284e1ed9678e6e19e..f22aaeb7f338315d354739da7cdca872a728dfe1 100644
--- a/pystencils_tests/test_random.py
+++ b/pystencils_tests/test_random.py
@@ -27,6 +27,8 @@ if get_compiler_config()['os'] == 'windows':
 def test_rng(target, rng, precision, dtype, t=124, offsets=(0, 0), keys=(0, 0), offset_values=None):
     if target == 'gpu':
         pytest.importorskip('pycuda')
+    if instruction_sets and 'neon' in instruction_sets and rng == 'aesni':
+        pytest.xfail('AES not yet implemented for ARM Neon')
     if rng == 'aesni' and len(keys) == 2:
         keys *= 2
     if offset_values is None:
@@ -116,6 +118,8 @@ def test_rng_offsets(kind, vectorized):
 @pytest.mark.parametrize('rng', ('philox', 'aesni'))
 @pytest.mark.parametrize('precision,dtype', (('float', 'float'), ('double', 'double')))
 def test_rng_vectorized(target, rng, precision, dtype, t=130, offsets=(1, 3), keys=(0, 0), offset_values=None):
+    if target == 'neon' and rng == 'aesni':
+        pytest.xfail('AES not yet implemented for ARM Neon')
     cpu_vectorize_info = {'assume_inner_stride_one': True, 'assume_aligned': True, 'instruction_set': target}
 
     dh = ps.create_data_handling((17, 17), default_ghost_layers=0, default_target='cpu')