From 7238f9bf7a26d256892e539768025ae99019fe65 Mon Sep 17 00:00:00 2001
From: Michael Kuron <mkuron@icp.uni-stuttgart.de>
Date: Tue, 20 Apr 2021 12:11:44 +0000
Subject: [PATCH] SVE vectorization

---
 pystencils/backends/arm_instruction_sets.py  | 115 +++++++++----
 pystencils/backends/cbackend.py              |   6 +-
 pystencils/backends/simd_instruction_sets.py |  19 ++-
 pystencils/backends/x86_instruction_sets.py  |   1 -
 pystencils/cpu/cpujit.py                     |  13 +-
 pystencils/include/arm_neon_helpers.h        |  13 ++
 pystencils/include/philox_rand.h             | 160 +++++++++++++++++++
 pystencils_tests/test_random.py              |  14 +-
 pystencils_tests/test_vectorization.py       |   4 +-
 9 files changed, 295 insertions(+), 50 deletions(-)

diff --git a/pystencils/backends/arm_instruction_sets.py b/pystencils/backends/arm_instruction_sets.py
index a386253a0..9f7b4ee22 100644
--- a/pystencils/backends/arm_instruction_sets.py
+++ b/pystencils/backends/arm_instruction_sets.py
@@ -1,6 +1,8 @@
-def get_argument_string(function_shortcut):
+def get_argument_string(function_shortcut, first=''):
     args = function_shortcut[function_shortcut.index('[') + 1: -1]
     arg_string = "("
+    if first:
+        arg_string += first + ', '
     for arg in args.split(","):
         arg = arg.strip()
         if not arg:
@@ -14,8 +16,17 @@ def get_argument_string(function_shortcut):
 
 
 def get_vector_instruction_set_arm(data_type='double', instruction_set='neon'):
-    if instruction_set != 'neon':
+    if instruction_set != 'neon' and not instruction_set.startswith('sve'):
         raise NotImplementedError(instruction_set)
+    if instruction_set == 'sve':
+        raise NotImplementedError("sizeless SVE is not implemented")
+
+    if instruction_set.startswith('sve'):
+        cmp = 'cmp'
+        bitwidth = int(instruction_set[3:])
+    elif instruction_set == 'neon':
+        cmp = 'c'
+        bitwidth = 128
 
     base_names = {
         '+': 'add[0, 1]',
@@ -30,58 +41,94 @@ def get_vector_instruction_set_arm(data_type='double', instruction_set='neon'):
         'storeA': 'st1[0, 1]',
 
         'abs': 'abs[0]',
-        '==': 'ceq[0, 1]',
-        '<=': 'cle[0, 1]',
-        '<': 'clt[0, 1]',
-        '>=': 'cge[0, 1]',
-        '>': 'cgt[0, 1]',
+        '==': f'{cmp}eq[0, 1]',
+        '!=': f'{cmp}ne[0, 1]',  # SVE has a native "not equal"; the NEON branch below replaces this entry
+        '<=': f'{cmp}le[0, 1]',
+        '<': f'{cmp}lt[0, 1]',
+        '>=': f'{cmp}ge[0, 1]',
+        '>': f'{cmp}gt[0, 1]',
     }
 
     bits = {'double': 64,
             'float': 32,
             'int': 32}
 
-    width = 128 // bits[data_type]
-    intwidth = 128 // bits['int']
-    suffix = f'q_f{bits[data_type]}'
+    width = bitwidth // bits[data_type]
+    intwidth = bitwidth // bits['int']
+    if instruction_set.startswith('sve'):
+        prefix = 'sv'  # intrinsic names are assembled below as prefix + name + suffix (+ '_x')
+        suffix = f'_f{bits[data_type]}'
+    elif instruction_set == 'neon':
+        prefix = 'v'
+        suffix = f'q_f{bits[data_type]}'
 
     result = dict()
-    result['bytes'] = 16
+    result['bytes'] = bitwidth // 8
+
+    predicate = f'{prefix}whilelt_b{bits[data_type]}(0, {width})'
+    int_predicate = f'{prefix}whilelt_b{bits["int"]}(0, {intwidth})'
 
     for intrinsic_id, function_shortcut in base_names.items():
         function_shortcut = function_shortcut.strip()
         name = function_shortcut[:function_shortcut.index('[')]
 
-        arg_string = get_argument_string(function_shortcut)
+        arg_string = get_argument_string(function_shortcut, first=predicate if prefix == 'sv' else '')
+        if prefix == 'sv' and not name.startswith('ld') and not name.startswith('st') and not name.startswith(cmp):
+            undef = '_x'
+        else:
+            undef = ''
 
-        result[intrinsic_id] = 'v' + name + suffix + arg_string
+        result[intrinsic_id] = prefix + name + suffix + undef + arg_string
 
-    result['makeVecConst'] = f'vdupq_n_f{bits[data_type]}' + '({0})'
-    result['makeVec'] = f'makeVec_f{bits[data_type]}' + '(' + ", ".join(['{' + str(i) + '}' for i in range(width)]) + \
-        ')'
-    result['makeVecConstInt'] = f'vdupq_n_s{bits["int"]}' + '({0})'
-    result['makeVecInt'] = f'makeVec_s{bits["int"]}' + '({0}, {1}, {2}, {3})'
+    result['width'] = width
+    result['intwidth'] = intwidth
 
-    result['+int'] = f"vaddq_s{bits['int']}" + "({0}, {1})"
+    if instruction_set.startswith('sve'):
+        result['makeVecConst'] = f'svdup_f{bits[data_type]}' + '({0})'
+        result['makeVecConstInt'] = f'svdup_s{bits["int"]}' + '({0})'
+        result['makeVecIndex'] = f'svindex_s{bits["int"]}' + '({0}, {1})'
 
-    result['rsqrt'] = None
+        result['+int'] = f"svadd_s{bits['int']}_x({int_predicate}, " + "{0}, {1})"
 
-    result['width'] = width
-    result['intwidth'] = intwidth
-    result[data_type] = f'float{bits[data_type]}x{width}_t'
-    result['int'] = f'int{bits["int"]}x{bits[data_type]}_t'
-    result['bool'] = f'uint{bits[data_type]}x{width}_t'
-    result['headers'] = ['<arm_neon.h>', '"arm_neon_helpers.h"']
+        result[data_type] = f'svfloat{bits[data_type]}_st'
+        result['int'] = f'svint{bits["int"]}_st'
+        result['bool'] = 'svbool_st'
+
+        result['headers'] = ['<arm_sve.h>', '"arm_neon_helpers.h"']
+
+        result['&'] = f'svand_b_z({predicate},' + ' {0}, {1})'
+        result['|'] = f'svorr_b_z({predicate},' + ' {0}, {1})'
+        result['blendv'] = f'svsel_f{bits[data_type]}' + '({2}, {1}, {0})'
+        result['any'] = f'svptest_any({predicate}, {{0}})'
+        result['all'] = f'svcntp_b{bits[data_type]}({predicate}, {{0}}) == {width}'
+
+        result['compile_flags'] = [f'-msve-vector-bits={bitwidth}']
+    else:
+        result['makeVecConst'] = f'vdupq_n_f{bits[data_type]}' + '({0})'
+        result['makeVec'] = f'makeVec_f{bits[data_type]}' + '(' + ", ".join(['{' + str(i) + '}' for i in
+                                                                             range(width)]) + ')'
+        result['makeVecConstInt'] = f'vdupq_n_s{bits["int"]}' + '({0})'
+        result['makeVecInt'] = f'makeVec_s{bits["int"]}' + '({0}, {1}, {2}, {3})'
+
+        result['+int'] = f"vaddq_s{bits['int']}" + "({0}, {1})"
+
+        result[data_type] = f'float{bits[data_type]}x{width}_t'
+        result['int'] = f'int{bits["int"]}x{intwidth}_t'
+        result['bool'] = f'uint{bits[data_type]}x{width}_t'
+
+        result['headers'] = ['<arm_neon.h>', '"arm_neon_helpers.h"']
 
-    result['!='] = f'vmvnq_u{bits[data_type]}({result["=="]})'
+        result['!='] = f'vmvnq_u{bits[data_type]}({result["=="]})'
 
-    result['&'] = f'vandq_u{bits[data_type]}' + '({0}, {1})'
-    result['|'] = f'vorrq_u{bits[data_type]}' + '({0}, {1})'
-    result['blendv'] = f'vbslq_f{bits[data_type]}' + '({2}, {1}, {0})'
-    result['any'] = f'vaddlvq_u8(vreinterpretq_u8_u{bits[data_type]}({{0}})) > 0'
-    result['all'] = f'vaddlvq_u8(vreinterpretq_u8_u{bits[data_type]}({{0}})) == 16*0xff'
+        result['&'] = f'vandq_u{bits[data_type]}' + '({0}, {1})'
+        result['|'] = f'vorrq_u{bits[data_type]}' + '({0}, {1})'
+        result['blendv'] = f'vbslq_f{bits[data_type]}' + '({2}, {1}, {0})'
+        result['any'] = f'vaddlvq_u8(vreinterpretq_u8_u{bits[data_type]}({{0}})) > 0'
+        result['all'] = f'vaddlvq_u8(vreinterpretq_u8_u{bits[data_type]}({{0}})) == 16*0xff'
 
-    result['cachelineSize'] = 'cachelineSize()'
-    result['cachelineZero'] = 'cachelineZero((void*) {0})'
+    if bitwidth & (bitwidth - 1) == 0:
+        # only power-of-2 vector sizes will evenly divide a cacheline
+        result['cachelineSize'] = 'cachelineSize()'
+        result['cachelineZero'] = 'cachelineZero((void*) {0})'
 
     return result
diff --git a/pystencils/backends/cbackend.py b/pystencils/backends/cbackend.py
index 9d7aa3cc3..5aabd83d6 100644
--- a/pystencils/backends/cbackend.py
+++ b/pystencils/backends/cbackend.py
@@ -610,6 +610,10 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
                     is_integer = get_type_of_expression(arg[0]) == create_type("int")
                     printed_args = [self._print(a) for a in arg]
                     instruction = 'makeVecBool' if is_boolean else 'makeVecInt' if is_integer else 'makeVec'
+                    if instruction == 'makeVecInt' and 'makeVecIndex' in self.instruction_set:
+                        increments = np.array(arg)[1:] - np.array(arg)[:-1]
+                        if len(set(increments)) == 1:
+                            return self.instruction_set['makeVecIndex'].format(printed_args[0], increments[0])
                     return self.instruction_set[instruction].format(*printed_args)
                 else:
                     is_boolean = get_type_of_expression(arg) == create_type("bool")
@@ -628,7 +632,7 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
         elif expr.func == fast_inv_sqrt:
             result = self._scalarFallback('_print_Function', expr)
             if not result:
-                if self.instruction_set['rsqrt']:
+                if 'rsqrt' in self.instruction_set:
                     return self.instruction_set['rsqrt'].format(self._print(expr.args[0]))
                 else:
                     return f"({self._print(1 / sp.sqrt(expr.args[0]))})"
diff --git a/pystencils/backends/simd_instruction_sets.py b/pystencils/backends/simd_instruction_sets.py
index 9469dc59e..b552da0e9 100644
--- a/pystencils/backends/simd_instruction_sets.py
+++ b/pystencils/backends/simd_instruction_sets.py
@@ -1,4 +1,6 @@
+import math
 import platform
+from ctypes import CDLL
 
 from pystencils.backends.x86_instruction_sets import get_vector_instruction_set_x86
 from pystencils.backends.arm_instruction_sets import get_vector_instruction_set_arm
@@ -6,7 +8,7 @@ from pystencils.backends.ppc_instruction_sets import get_vector_instruction_set_
 
 
 def get_vector_instruction_set(data_type='double', instruction_set='avx'):
-    if instruction_set in ['neon', 'sve']:
+    if instruction_set in ['neon'] or instruction_set.startswith('sve'):
         return get_vector_instruction_set_arm(data_type, instruction_set)
     elif instruction_set in ['vsx']:
         return get_vector_instruction_set_ppc(data_type, instruction_set)
@@ -47,6 +49,7 @@ def get_supported_instruction_sets():
     required_avx_flags = {'avx', 'avx2'}
     required_avx512_flags = {'avx512f'}
     required_neon_flags = {'neon'}
+    required_sve_flags = {'sve'}
     flags = set(get_cpu_info()['flags'])
     if flags.issuperset(required_sse_flags):
         result.append("sse")
@@ -56,6 +59,20 @@ def get_supported_instruction_sets():
         result.append("avx512")
     if flags.issuperset(required_neon_flags):
         result.append("neon")
+    if flags.issuperset(required_sve_flags):
+        if platform.system() == 'Linux':
+            vl_config = CDLL('libc.so.6').prctl(51, 0, 0, 0, 0)  # PR_SVE_GET_VL
+            if vl_config < 0:
+                raise OSError("SVE length query failed")
+            native_length = 8 * (vl_config & 0xffff)  # PR_SVE_VL_LEN_MASK: bytes -> bits, flag bits dropped
+            pwr2_length = int(2**math.floor(math.log2(native_length)))
+            if pwr2_length % 256 == 0:
+                result.append(f"sve{pwr2_length//2}")
+            if native_length != pwr2_length:
+                result.append(f"sve{pwr2_length}")
+            result.append(f"sve{native_length}")  # native length goes last
+        else:
+            result.append("sve")  # length unknown off Linux; plain "sve" is rejected by the ARM generator
     return result
 
 
diff --git a/pystencils/backends/x86_instruction_sets.py b/pystencils/backends/x86_instruction_sets.py
index 86196aad7..78515e1a2 100644
--- a/pystencils/backends/x86_instruction_sets.py
+++ b/pystencils/backends/x86_instruction_sets.py
@@ -133,7 +133,6 @@ def get_vector_instruction_set_x86(data_type='double', instruction_set='avx'):
         'float': "_" + pre,
     }
 
-    result['rsqrt'] = None
     bit_width = result['width'] * (64 if data_type == 'double' else 32)
     result['double'] = f"__m{bit_width}d"
     result['float'] = f"__m{bit_width}"
diff --git a/pystencils/cpu/cpujit.py b/pystencils/cpu/cpujit.py
index 68ca25902..fc9b810c3 100644
--- a/pystencils/cpu/cpujit.py
+++ b/pystencils/cpu/cpujit.py
@@ -558,9 +558,9 @@ class ExtensionModuleCode:
         print(self._code_string, file=file)
 
 
-def compile_module(code, code_hash, base_dir):
+def compile_module(code, code_hash, base_dir, compile_flags=None):  # compile_flags: extra compiler arguments
     compiler_config = get_compiler_config()
-    extra_flags = ['-I' + get_paths()['include'], '-I' + get_pystencils_include_path()]
+    extra_flags = ['-I' + get_paths()['include'], '-I' + get_pystencils_include_path()] + list(compile_flags or [])
 
     if compiler_config['os'].lower() == 'windows':
         lib_suffix = '.pyd'
@@ -620,12 +620,17 @@ def compile_and_load(ast, custom_backend=None):
     code.create_code_string(compiler_config['restrict_qualifier'], function_prefix)
     code_hash_str = code.get_hash_of_code()
 
+    compile_flags = []
+    if ast.instruction_set and 'compile_flags' in ast.instruction_set:
+        compile_flags = ast.instruction_set['compile_flags']
+
     if cache_config['object_cache'] is False:
         with TemporaryDirectory() as base_dir:
-            lib_file = compile_module(code, code_hash_str, base_dir)
+            lib_file = compile_module(code, code_hash_str, base_dir, compile_flags=compile_flags)
             result = load_kernel_from_file(code_hash_str, ast.function_name, lib_file)
     else:
-        lib_file = compile_module(code, code_hash_str, base_dir=cache_config['object_cache'])
+        lib_file = compile_module(code, code_hash_str, base_dir=cache_config['object_cache'],
+                                  compile_flags=compile_flags)
         result = load_kernel_from_file(code_hash_str, ast.function_name, lib_file)
 
     return KernelWrapper(result, ast.get_parameters(), ast)
diff --git a/pystencils/include/arm_neon_helpers.h b/pystencils/include/arm_neon_helpers.h
index 3d06d69bf..a900001e7 100644
--- a/pystencils/include/arm_neon_helpers.h
+++ b/pystencils/include/arm_neon_helpers.h
@@ -1,5 +1,17 @@
+#ifdef __ARM_NEON
 #include <arm_neon.h>
+#endif
 
+#if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_SVE_BITS) && __ARM_FEATURE_SVE_BITS > 0
+#include <arm_sve.h>
+// "_st" aliases: SVE types given a fixed size of __ARM_FEATURE_SVE_BITS via the arm_sve_vector_bits attribute
+typedef svbool_t svbool_st __attribute__((arm_sve_vector_bits(__ARM_FEATURE_SVE_BITS)));
+typedef svfloat32_t svfloat32_st __attribute__((arm_sve_vector_bits(__ARM_FEATURE_SVE_BITS)));
+typedef svfloat64_t svfloat64_st __attribute__((arm_sve_vector_bits(__ARM_FEATURE_SVE_BITS)));
+typedef svint32_t svint32_st __attribute__((arm_sve_vector_bits(__ARM_FEATURE_SVE_BITS)));
+#endif  // __ARM_FEATURE_SVE_BITS > 0
+
+#ifdef __ARM_NEON
 inline float32x4_t makeVec_f32(float a, float b, float c, float d)
 {
     alignas(16) float data[4] = {a, b, c, d};
@@ -17,6 +29,7 @@ inline int32x4_t makeVec_s32(int a, int b, int c, int d)
     alignas(16) int data[4] = {a, b, c, d};
     return vld1q_s32(data);
 }
+#endif
 
 inline void cachelineZero(void * p) {
 	__asm__ volatile("dc zva, %0"::"r"(p));
diff --git a/pystencils/include/philox_rand.h b/pystencils/include/philox_rand.h
index 4d81d43e4..7684a4507 100644
--- a/pystencils/include/philox_rand.h
+++ b/pystencils/include/philox_rand.h
@@ -15,6 +15,14 @@
 #ifdef __ARM_NEON
 #include <arm_neon.h>
 #endif
+#if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_SVE_BITS) && __ARM_FEATURE_SVE_BITS > 0
+#include <arm_sve.h>  // "_st" typedefs below: SVE types with a fixed size of __ARM_FEATURE_SVE_BITS
+typedef svfloat32_t svfloat32_st __attribute__((arm_sve_vector_bits(__ARM_FEATURE_SVE_BITS)));
+typedef svfloat64_t svfloat64_st __attribute__((arm_sve_vector_bits(__ARM_FEATURE_SVE_BITS)));
+typedef svint32_t svint32_st __attribute__((arm_sve_vector_bits(__ARM_FEATURE_SVE_BITS)));
+typedef svuint32_t svuint32_st __attribute__((arm_sve_vector_bits(__ARM_FEATURE_SVE_BITS)));
+typedef svuint64_t svuint64_st __attribute__((arm_sve_vector_bits(__ARM_FEATURE_SVE_BITS)));
+#endif  // __ARM_FEATURE_SVE_BITS > 0
 
 #if defined(__powerpc__) && defined(__GNUC__) && !defined(__clang__) && !defined(__xlC__)
 #include <ppu_intrinsics.h>
@@ -655,6 +663,158 @@ QUALIFIERS void philox_double2(uint32 ctr0, int32x4_t ctr1, uint32 ctr2, uint32
 }
 #endif
 
+
+#if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_SVE_BITS) && __ARM_FEATURE_SVE_BITS > 0
+QUALIFIERS void _philox4x32round(svuint32_st* ctr, svuint32_st* key)
+{   // one Philox-4x32 round: rewrites ctr[0..3] in place from key[0..1]; key is only read
+    svuint32_st lo0 = svmul_u32_x(svptrue_b32(), ctr[0], svdup_u32(PHILOX_M4x32_0));
+    svuint32_st lo1 = svmul_u32_x(svptrue_b32(), ctr[2], svdup_u32(PHILOX_M4x32_1));
+    svuint32_st hi0 = svmulh_u32_x(svptrue_b32(), ctr[0], svdup_u32(PHILOX_M4x32_0));
+    svuint32_st hi1 = svmulh_u32_x(svptrue_b32(), ctr[2], svdup_u32(PHILOX_M4x32_1));
+    // new ctr = {hi1 ^ ctr[1] ^ key[0], lo1, hi0 ^ ctr[3] ^ key[1], lo0}
+    ctr[0] = sveor_u32_x(svptrue_b32(), sveor_u32_x(svptrue_b32(), hi1, ctr[1]), key[0]);
+    ctr[1] = lo1;
+    ctr[2] = sveor_u32_x(svptrue_b32(), sveor_u32_x(svptrue_b32(), hi0, ctr[3]), key[1]);
+    ctr[3] = lo0;
+}
+
+QUALIFIERS void _philox4x32bumpkey(svuint32_st* key)
+{   // adds PHILOX_W32_0 / PHILOX_W32_1 to key[0] / key[1] in place, in every lane
+    key[0] = svadd_u32_x(svptrue_b32(), key[0], svdup_u32(PHILOX_W32_0));
+    key[1] = svadd_u32_x(svptrue_b32(), key[1], svdup_u32(PHILOX_W32_1));
+}
+
+template<bool high>
+QUALIFIERS svfloat64_st _uniform_double_hq(svuint32_st x, svuint32_st y)
+{   // combines 32-bit words of x and y into doubles; `high` selects the zip2 instead of the zip1 half
+    // convert 32 to 64 bit
+    if (high)
+    {
+        x = svzip2_u32(x, svdup_u32(0));  // interleave one half of the words with zero words
+        y = svzip2_u32(y, svdup_u32(0));
+    }
+    else
+    {
+        x = svzip1_u32(x, svdup_u32(0));  // same for the other half
+        y = svzip1_u32(y, svdup_u32(0));
+    }
+
+    // calculate z = x ^ (y << (53 - 32))
+    svuint64_st z = svlsl_n_u64_x(svptrue_b64(), svreinterpret_u64_u32(y), 53 - 32);
+    z = sveor_u64_x(svptrue_b64(), svreinterpret_u64_u32(x), z);
+
+    // convert uint64 to double
+    svfloat64_st rs = svcvt_f64_u64_x(svptrue_b64(), z);
+    // calculate rs * TWOPOW53_INV_DOUBLE + (TWOPOW53_INV_DOUBLE/2.0)
+    rs = svmad_f64_x(svptrue_b64(), rs, svdup_f64(TWOPOW53_INV_DOUBLE), svdup_f64(TWOPOW53_INV_DOUBLE/2.0));
+
+    return rs;
+}
+
+
+QUALIFIERS void philox_float4(svuint32_st ctr0, svuint32_st ctr1, svuint32_st ctr2, svuint32_st ctr3,
+                              uint32 key0, uint32 key1,
+                              svfloat32_st & rnd1, svfloat32_st & rnd2, svfloat32_st & rnd3, svfloat32_st & rnd4)
+{   // runs 10 Philox rounds on the vector counters and writes four float vectors to rnd1..rnd4
+    svuint32_st key[2] = {svdup_u32(key0), svdup_u32(key1)};
+    svuint32_st ctr[4] = {ctr0, ctr1, ctr2, ctr3};
+    _philox4x32round(ctr, key);                           // 1
+    _philox4x32bumpkey(key); _philox4x32round(ctr, key);  // 2
+    _philox4x32bumpkey(key); _philox4x32round(ctr, key);  // 3
+    _philox4x32bumpkey(key); _philox4x32round(ctr, key);  // 4
+    _philox4x32bumpkey(key); _philox4x32round(ctr, key);  // 5
+    _philox4x32bumpkey(key); _philox4x32round(ctr, key);  // 6
+    _philox4x32bumpkey(key); _philox4x32round(ctr, key);  // 7
+    _philox4x32bumpkey(key); _philox4x32round(ctr, key);  // 8
+    _philox4x32bumpkey(key); _philox4x32round(ctr, key);  // 9
+    _philox4x32bumpkey(key); _philox4x32round(ctr, key);  // 10
+    // ctr[0..3] now hold the final 32-bit words; one output vector is derived from each
+    // convert uint32 to float
+    rnd1 = svcvt_f32_u32_x(svptrue_b32(), ctr[0]);
+    rnd2 = svcvt_f32_u32_x(svptrue_b32(), ctr[1]);
+    rnd3 = svcvt_f32_u32_x(svptrue_b32(), ctr[2]);
+    rnd4 = svcvt_f32_u32_x(svptrue_b32(), ctr[3]);
+    // calculate rnd * TWOPOW32_INV_FLOAT + (TWOPOW32_INV_FLOAT/2.0f)
+    rnd1 = svmad_f32_x(svptrue_b32(), rnd1, svdup_f32(TWOPOW32_INV_FLOAT), svdup_f32(TWOPOW32_INV_FLOAT/2.0));
+    rnd2 = svmad_f32_x(svptrue_b32(), rnd2, svdup_f32(TWOPOW32_INV_FLOAT), svdup_f32(TWOPOW32_INV_FLOAT/2.0));
+    rnd3 = svmad_f32_x(svptrue_b32(), rnd3, svdup_f32(TWOPOW32_INV_FLOAT), svdup_f32(TWOPOW32_INV_FLOAT/2.0));
+    rnd4 = svmad_f32_x(svptrue_b32(), rnd4, svdup_f32(TWOPOW32_INV_FLOAT), svdup_f32(TWOPOW32_INV_FLOAT/2.0));
+}
+
+
+QUALIFIERS void philox_double2(svuint32_st ctr0, svuint32_st ctr1, svuint32_st ctr2, svuint32_st ctr3,
+                               uint32 key0, uint32 key1,
+                               svfloat64_st & rnd1lo, svfloat64_st & rnd1hi, svfloat64_st & rnd2lo, svfloat64_st & rnd2hi)
+{   // runs 10 Philox rounds on the vector counters and writes four double vectors
+    svuint32_st key[2] = {svdup_u32(key0), svdup_u32(key1)};
+    svuint32_st ctr[4] = {ctr0, ctr1, ctr2, ctr3};
+    _philox4x32round(ctr, key);                           // 1
+    _philox4x32bumpkey(key); _philox4x32round(ctr, key);  // 2
+    _philox4x32bumpkey(key); _philox4x32round(ctr, key);  // 3
+    _philox4x32bumpkey(key); _philox4x32round(ctr, key);  // 4
+    _philox4x32bumpkey(key); _philox4x32round(ctr, key);  // 5
+    _philox4x32bumpkey(key); _philox4x32round(ctr, key);  // 6
+    _philox4x32bumpkey(key); _philox4x32round(ctr, key);  // 7
+    _philox4x32bumpkey(key); _philox4x32round(ctr, key);  // 8
+    _philox4x32bumpkey(key); _philox4x32round(ctr, key);  // 9
+    _philox4x32bumpkey(key); _philox4x32round(ctr, key);  // 10
+    // rnd1* come from (ctr[0], ctr[1]), rnd2* from (ctr[2], ctr[3]); lo/hi = <false>/<true> half selection
+    rnd1lo = _uniform_double_hq<false>(ctr[0], ctr[1]);
+    rnd1hi = _uniform_double_hq<true>(ctr[0], ctr[1]);
+    rnd2lo = _uniform_double_hq<false>(ctr[2], ctr[3]);
+    rnd2hi = _uniform_double_hq<true>(ctr[2], ctr[3]);
+}
+
+QUALIFIERS void philox_float4(uint32 ctr0, svuint32_st ctr1, uint32 ctr2, uint32 ctr3,
+                              uint32 key0, uint32 key1,
+                              svfloat32_st & rnd1, svfloat32_st & rnd2, svfloat32_st & rnd3, svfloat32_st & rnd4)
+{   // overload with scalar ctr0/ctr2/ctr3: they are broadcast to all lanes, ctr1 is passed through as a vector
+    svuint32_st ctr0v = svdup_u32(ctr0);
+    svuint32_st ctr2v = svdup_u32(ctr2);
+    svuint32_st ctr3v = svdup_u32(ctr3);
+    // delegate to the overload that takes four vector counters
+    philox_float4(ctr0v, ctr1, ctr2v, ctr3v, key0, key1, rnd1, rnd2, rnd3, rnd4);
+}
+
+QUALIFIERS void philox_float4(uint32 ctr0, svint32_st ctr1, uint32 ctr2, uint32 ctr3,
+                              uint32 key0, uint32 key1,
+                              svfloat32_st & rnd1, svfloat32_st & rnd2, svfloat32_st & rnd3, svfloat32_st & rnd4)
+{   // overload for a signed ctr1: reinterprets it as unsigned and forwards
+    philox_float4(ctr0, svreinterpret_u32_s32(ctr1), ctr2, ctr3, key0, key1, rnd1, rnd2, rnd3, rnd4);
+}
+
+QUALIFIERS void philox_double2(uint32 ctr0, svuint32_st ctr1, uint32 ctr2, uint32 ctr3,
+                               uint32 key0, uint32 key1,
+                               svfloat64_st & rnd1lo, svfloat64_st & rnd1hi, svfloat64_st & rnd2lo, svfloat64_st & rnd2hi)
+{   // overload with scalar ctr0/ctr2/ctr3: they are broadcast to all lanes, ctr1 is passed through as a vector
+    svuint32_st ctr0v = svdup_u32(ctr0);
+    svuint32_st ctr2v = svdup_u32(ctr2);
+    svuint32_st ctr3v = svdup_u32(ctr3);
+    // delegate to the overload that takes four vector counters
+    philox_double2(ctr0v, ctr1, ctr2v, ctr3v, key0, key1, rnd1lo, rnd1hi, rnd2lo, rnd2hi);
+}
+
+QUALIFIERS void philox_double2(uint32 ctr0, svuint32_st ctr1, uint32 ctr2, uint32 ctr3,
+                               uint32 key0, uint32 key1,
+                               svfloat64_st & rnd1, svfloat64_st & rnd2)
+{   // two-output overload: keeps only the "lo" results of the four-output vector version
+    svuint32_st ctr0v = svdup_u32(ctr0);
+    svuint32_st ctr2v = svdup_u32(ctr2);
+    svuint32_st ctr3v = svdup_u32(ctr3);
+    // scalar counters broadcast above; ctr1 stays a vector
+    svfloat64_st ignore;  // receives both "hi" outputs, which are discarded
+    philox_double2(ctr0v, ctr1, ctr2v, ctr3v, key0, key1, rnd1, ignore, rnd2, ignore);
+}
+
+QUALIFIERS void philox_double2(uint32 ctr0, svint32_st ctr1, uint32 ctr2, uint32 ctr3,
+                               uint32 key0, uint32 key1,
+                               svfloat64_st & rnd1, svfloat64_st & rnd2)
+{   // overload for a signed ctr1: reinterprets it as unsigned and forwards
+    philox_double2(ctr0, svreinterpret_u32_s32(ctr1), ctr2, ctr3, key0, key1, rnd1, rnd2);
+}
+#endif
+
+
 #ifdef __AVX2__
 QUALIFIERS void _philox4x32round(__m256i* ctr, __m256i* key)
 {
diff --git a/pystencils_tests/test_random.py b/pystencils_tests/test_random.py
index b5396c9fb..cba58cfa2 100644
--- a/pystencils_tests/test_random.py
+++ b/pystencils_tests/test_random.py
@@ -27,7 +27,7 @@ if get_compiler_config()['os'] == 'windows':
 def test_rng(target, rng, precision, dtype, t=124, offsets=(0, 0), keys=(0, 0), offset_values=None):
     if target == 'gpu':
         pytest.importorskip('pycuda')
-    if instruction_sets and set(['neon', 'vsx']).intersection(instruction_sets) and rng == 'aesni':
+    if instruction_sets and (set(['neon', 'vsx']).intersection(instruction_sets) or any([iset.startswith('sve') for iset in instruction_sets])) and rng == 'aesni':
         pytest.xfail('AES not yet implemented for this architecture')
     if rng == 'aesni' and len(keys) == 2:
         keys *= 2
@@ -106,11 +106,11 @@ def test_rng_offsets(kind, vectorized):
     else:
         test = test_rng
     if kind == 'value':
-        test(instruction_sets[0] if vectorized else 'cpu', 'philox', 'float', 'float', t=8,
+        test(instruction_sets[-1] if vectorized else 'cpu', 'philox', 'float', 'float', t=8,
              offsets=(6, 7), keys=(5, 309))
     elif kind == 'symbol':
         offsets = (TypedSymbol("x0", np.uint32), TypedSymbol("y0", np.uint32))
-        test(instruction_sets[0] if vectorized else 'cpu', 'philox', 'float', 'float', t=8,
+        test(instruction_sets[-1] if vectorized else 'cpu', 'philox', 'float', 'float', t=8,
              offsets=offsets, offset_values=(6, 7), keys=(5, 309))
 
 
@@ -118,11 +118,11 @@ def test_rng_offsets(kind, vectorized):
 @pytest.mark.parametrize('rng', ('philox', 'aesni'))
 @pytest.mark.parametrize('precision,dtype', (('float', 'float'), ('double', 'double')))
 def test_rng_vectorized(target, rng, precision, dtype, t=130, offsets=(1, 3), keys=(0, 0), offset_values=None):
-    if target in ['neon', 'vsx'] and rng == 'aesni':
+    if (target in ['neon', 'vsx'] or target.startswith('sve')) and rng == 'aesni':
         pytest.xfail('AES not yet implemented for this architecture')
     cpu_vectorize_info = {'assume_inner_stride_one': True, 'assume_aligned': True, 'instruction_set': target}
 
-    dh = ps.create_data_handling((17, 17), default_ghost_layers=0, default_target='cpu')
+    dh = ps.create_data_handling((131, 131), default_ghost_layers=0, default_target='cpu')
     f = dh.add_array("f", values_per_cell=4 if precision == 'float' else 2,
                      dtype=np.float32 if dtype == 'float' else np.float64, alignment=True)
     dh.fill(f.name, 42.0)
@@ -157,7 +157,7 @@ def test_rng_symbol(vectorized):
             pytest.skip("cannot detect CPU instruction set")
         else:
             cpu_vectorize_info = {'assume_inner_stride_one': True, 'assume_aligned': True,
-                                  'instruction_set': instruction_sets[0]}
+                                  'instruction_set': instruction_sets[-1]}
     else:
         cpu_vectorize_info = None
     
@@ -189,7 +189,7 @@ def test_staggered(vectorized):
         pytest.skip("cannot detect CPU instruction set")
     pytest.importorskip('islpy')
     cpu_vectorize_info = {'assume_inner_stride_one': True, 'assume_aligned': False,
-                          'instruction_set': instruction_sets[0]}
+                          'instruction_set': instruction_sets[-1]}
     
     dh.fill(j.name, 867)
     dh.run_kernel(kernel, seed=5, time_step=309)
diff --git a/pystencils_tests/test_vectorization.py b/pystencils_tests/test_vectorization.py
index 880a009a2..f668c6b81 100644
--- a/pystencils_tests/test_vectorization.py
+++ b/pystencils_tests/test_vectorization.py
@@ -216,8 +216,8 @@ def test_logical_operators():
 
 
 def test_hardware_query():
-    instruction_sets = get_supported_instruction_sets()
-    assert set(['sse', 'neon', 'vsx']).intersection(instruction_sets)
+    assert set(['sse', 'neon', 'vsx']).intersection(supported_instruction_sets) or \
+           any([iset.startswith('sve') for iset in supported_instruction_sets])  # e.g. 'sve512'
 
 
 def test_vectorised_pow():
-- 
GitLab