diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index e50894bd652481630f85cb899d57c973ea3c014a..0e3c6cf81485717403547fae80e9ba4cf03305b1 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -175,7 +175,8 @@ arm64v9: extends: .multiarch_template image: i10git.cs.fau.de:5005/pycodegen/pycodegen/arm64 variables: - PYSTENCILS_SIMD: "sve256,sve512,sve" + PYSTENCILS_SIMD: "sve128,sve256,sve512,sve1024,sve2048,sve" + QEMU_CPU: "max,sve-default-vector-length=-1" before_script: - *multiarch_before_script - sed -i s/march=native/march=armv8-a+sve/g ~/.config/pystencils/config.json diff --git a/pystencils/backends/arm_instruction_sets.py b/pystencils/backends/arm_instruction_sets.py index 73ea7eb4437eebd5b49d8ec19a0f55184eb12447..9aa8f6c0aeaf72f1c6d4a201d8df7435cce9f533 100644 --- a/pystencils/backends/arm_instruction_sets.py +++ b/pystencils/backends/arm_instruction_sets.py @@ -151,9 +151,7 @@ def get_vector_instruction_set_arm(data_type='double', instruction_set='neon'): result['any'] = f'vaddlvq_u8(vreinterpretq_u8_u{bits[data_type]}({{0}})) > 0' result['all'] = f'vaddlvq_u8(vreinterpretq_u8_u{bits[data_type]}({{0}})) == 16*0xff' - if instruction_set == 'sve' or bitwidth & (bitwidth - 1) == 0: - # only power-of-2 vector sizes will evenly divide a cacheline - result['cachelineSize'] = 'cachelineSize()' - result['cachelineZero'] = 'cachelineZero((void*) {0})' + result['cachelineSize'] = 'cachelineSize()' + result['cachelineZero'] = 'cachelineZero((void*) {0})' return result diff --git a/pystencils/backends/simd_instruction_sets.py b/pystencils/backends/simd_instruction_sets.py index cdb2ee5cf16694be5499718b1d5275ddbc8a87dc..7d0d028c0691e48252a287dd81b46fd0d0a420cc 100644 --- a/pystencils/backends/simd_instruction_sets.py +++ b/pystencils/backends/simd_instruction_sets.py @@ -1,4 +1,3 @@ -import math import os import platform from ctypes import CDLL @@ -86,15 +85,12 @@ def get_supported_instruction_sets(): if flags.issuperset(required_sve_flags): if platform.system() == 'Linux': libc = CDLL('libc.so.6') - native_length = 8 * libc.prctl(51, 0, 0, 0, 0) # PR_SVE_GET_VL - if native_length < 0: + length = 8 * libc.prctl(51, 0, 0, 0, 0) # PR_SVE_GET_VL + if length < 0: raise OSError("SVE length query failed") - pwr2_length = int(2**math.floor(math.log2(native_length))) - if pwr2_length % 256 == 0: - result.append(f"sve{pwr2_length//2}") - if native_length != pwr2_length: - result.append(f"sve{pwr2_length}") - result.append(f"sve{native_length}") + while length > 128: + result.append(f"sve{length}") + length //= 2 result.append("sve") return result