From cdf73d8f0bde764ebf76cf702998a3bb4200b858 Mon Sep 17 00:00:00 2001 From: Michael Kuron <mkuron@icp.uni-stuttgart.de> Date: Fri, 21 May 2021 08:11:43 +0000 Subject: [PATCH] Sizeless vectorization --- .gitlab-ci.yml | 16 +- pystencils/alignedarray.py | 12 +- pystencils/backends/arm_instruction_sets.py | 60 ++++--- pystencils/backends/cbackend.py | 105 +++++++----- pystencils/backends/riscv_instruction_sets.py | 109 +++++++++++++ pystencils/backends/simd_instruction_sets.py | 11 +- pystencils/backends/x86_instruction_sets.py | 10 +- pystencils/cpu/cpujit.py | 6 +- pystencils/cpu/vectorization.py | 18 +- pystencils/data_types.py | 2 +- pystencils/include/philox_rand.h | 154 ++++++++++++++++++ pystencils_tests/test_conditional_vec.py | 31 +++- pystencils_tests/test_random.py | 4 +- pystencils_tests/test_vectorization.py | 3 +- .../test_vectorization_specific.py | 9 +- 15 files changed, 452 insertions(+), 98 deletions(-) create mode 100644 pystencils/backends/riscv_instruction_sets.py diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 2ed64515b..196f25abd 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -178,7 +178,7 @@ arm64v9: extends: .multiarch_template image: i10git.cs.fau.de:5005/pycodegen/pycodegen/arm64 variables: - PYSTENCILS_SIMD: "sve256,sve512" + PYSTENCILS_SIMD: "sve256,sve512,sve" ASAN_OPTIONS: detect_leaks=0 LD_PRELOAD: /usr/lib/aarch64-linux-gnu/libasan.so.6 before_script: @@ -186,6 +186,20 @@ arm64v9: - sed -i s/march=native/march=armv8-a+sve/g ~/.config/pystencils/config.json - sed -i s/g\+\+/clang++/g ~/.config/pystencils/config.json +riscv64: + # The RISC-V vector extension is still experimental and needs special compiler flags. + # Once they are officially released, this job should be cleaned up to match the others. + extends: .multiarch_template + image: i10git.cs.fau.de:5005/pycodegen/pycodegen/riscv64 + variables: + PYSTENCILS_SIMD: "rvv" + QEMU_CPU: "rv64,x-v=true" + before_script: + - *multiarch_before_script + - sed -i 's/march=native/march=rv64imfdv0p10 -menable-experimental-extensions/g' ~/.config/pystencils/config.json + - sed -i s/g\+\+/clang++/g ~/.config/pystencils/config.json + - sed -i 's/fopenmp/fopenmp=libgomp -I\/usr\/include\/riscv64-linux-gnu/g' ~/.config/pystencils/config.json + minimal-conda: stage: test except: diff --git a/pystencils/alignedarray.py b/pystencils/alignedarray.py index eda9fcaeb..da20a778e 100644 --- a/pystencils/alignedarray.py +++ b/pystencils/alignedarray.py @@ -28,13 +28,19 @@ def aligned_empty(shape, byte_alignment=True, dtype=np.float64, byte_offset=0, o elif byte_alignment == 'cacheline': cacheline_sizes = [get_cacheline_size(is_name) for is_name in instruction_sets] if all([s is None for s in cacheline_sizes]): - byte_alignment = max([get_vector_instruction_set(type_name, is_name)['width'] * np.dtype(dtype).itemsize - for is_name in instruction_sets]) + widths = [get_vector_instruction_set(type_name, is_name)['width'] * np.dtype(dtype).itemsize + for is_name in instruction_sets + if type(get_vector_instruction_set(type_name, is_name)['width']) is int] + byte_alignment = 64 if all([s is None for s in widths]) else max(widths) else: byte_alignment = max([s for s in cacheline_sizes if s is not None]) + elif not any([type(get_vector_instruction_set(type_name, is_name)['width']) is int + for is_name in instruction_sets]): + byte_alignment = 64 else: byte_alignment = max([get_vector_instruction_set(type_name, is_name)['width'] * np.dtype(dtype).itemsize - for is_name in instruction_sets]) + for is_name in instruction_sets + if 
type(get_vector_instruction_set(type_name, is_name)['width']) is int]) if (not align_inner_coordinate) or (not hasattr(shape, '__len__')): size = np.prod(shape) d = np.dtype(dtype) diff --git a/pystencils/backends/arm_instruction_sets.py b/pystencils/backends/arm_instruction_sets.py index 5318dffeb..73ea7eb44 100644 --- a/pystencils/backends/arm_instruction_sets.py +++ b/pystencils/backends/arm_instruction_sets.py @@ -19,9 +19,8 @@ def get_vector_instruction_set_arm(data_type='double', instruction_set='neon'): if instruction_set != 'neon' and not instruction_set.startswith('sve'): raise NotImplementedError(instruction_set) if instruction_set == 'sve': - raise NotImplementedError("sizeless SVE is not implemented") - - if instruction_set.startswith('sve'): + cmp = 'cmp' + elif instruction_set.startswith('sve'): cmp = 'cmp' bitwidth = int(instruction_set[3:]) elif instruction_set == 'neon': @@ -53,8 +52,16 @@ def get_vector_instruction_set_arm(data_type='double', instruction_set='neon'): 'float': 32, 'int': 32} - width = bitwidth // bits[data_type] - intwidth = bitwidth // bits['int'] + result = dict() + + if instruction_set == 'sve': + width = 'svcntd()' if data_type == 'double' else 'svcntw()' + intwidth = 'svcntw()' + result['bytes'] = 'svcntb()' + else: + width = bitwidth // bits[data_type] + intwidth = bitwidth // bits['int'] + result['bytes'] = bitwidth // 8 if instruction_set.startswith('sve'): prefix = 'sv' suffix = f'_f{bits[data_type]}' @@ -62,11 +69,12 @@ def get_vector_instruction_set_arm(data_type='double', instruction_set='neon'): prefix = 'v' suffix = f'q_f{bits[data_type]}' - result = dict() - result['bytes'] = bitwidth // 8 - - predicate = f'{prefix}whilelt_b{bits[data_type]}(0, {width})' - int_predicate = f'{prefix}whilelt_b{bits["int"]}(0, {intwidth})' + if instruction_set == 'sve': + predicate = f'{prefix}whilelt_b{bits[data_type]}_u64({{loop_counter}}, {{loop_stop}})' + int_predicate = f'{prefix}whilelt_b{bits["int"]}_u64({{loop_counter}}, {{loop_stop}})' + else: + predicate = f'{prefix}whilelt_b{bits[data_type]}(0, {width})' + int_predicate = f'{prefix}whilelt_b{bits["int"]}(0, {intwidth})' for intrinsic_id, function_shortcut in base_names.items(): function_shortcut = function_shortcut.strip() @@ -80,8 +88,13 @@ def get_vector_instruction_set_arm(data_type='double', instruction_set='neon'): result[intrinsic_id] = prefix + name + suffix + undef + arg_string - result['width'] = width - result['intwidth'] = intwidth + if instruction_set == 'sve': + from pystencils.backends.cbackend import CFunction + result['width'] = CFunction(width, "int") + result['intwidth'] = CFunction(intwidth, "int") + else: + result['width'] = width + result['intwidth'] = intwidth if instruction_set.startswith('sve'): result['makeVecConst'] = f'svdup_f{bits[data_type]}' + '({0})' @@ -89,17 +102,17 @@ def get_vector_instruction_set_arm(data_type='double', instruction_set='neon'): result['makeVecIndex'] = f'svindex_s{bits["int"]}' + '({0}, {1})' vindex = f'svindex_u{bits[data_type]}(0, {{0}})' - result['scatter'] = f'svst1_scatter_u{bits[data_type]}index_f{bits[data_type]}({predicate}, {{0}}, ' + \ - vindex.format("{2}") + ', {1})' - result['gather'] = f'svld1_gather_u{bits[data_type]}index_f{bits[data_type]}({predicate}, {{0}}, ' + \ - vindex.format("{1}") + ')' + result['storeS'] = f'svst1_scatter_u{bits[data_type]}index_f{bits[data_type]}({predicate}, {{0}}, ' + \ + vindex.format("{2}") + ', {1})' + result['loadS'] = f'svld1_gather_u{bits[data_type]}index_f{bits[data_type]}({predicate}, {{0}}, ' + 
\ + vindex.format("{1}") + ')' result['+int'] = f"svadd_s{bits['int']}_x({int_predicate}, " + "{0}, {1})" - result['float'] = 'svfloat32_st' - result['double'] = 'svfloat64_st' - result['int'] = f'svint{bits["int"]}_st' - result['bool'] = 'svbool_st' + result['float'] = f'svfloat{bits["float"]}_{"s" if instruction_set != "sve" else ""}t' + result['double'] = f'svfloat{bits["double"]}_{"s" if instruction_set != "sve" else ""}t' + result['int'] = f'svint{bits["int"]}_{"s" if instruction_set != "sve" else ""}t' + result['bool'] = f'svbool_{"s" if instruction_set != "sve" else ""}t' result['headers'] = ['<arm_sve.h>', '"arm_neon_helpers.h"'] @@ -111,9 +124,10 @@ def get_vector_instruction_set_arm(data_type='double', instruction_set='neon'): result['maskStoreU'] = result['storeU'].replace(predicate, '{2}') result['maskStoreA'] = result['storeA'].replace(predicate, '{2}') - result['maskScatter'] = result['scatter'].replace(predicate, '{3}') + result['maskStoreS'] = result['storeS'].replace(predicate, '{3}') - result['compile_flags'] = [f'-msve-vector-bits={bitwidth}'] + if instruction_set != 'sve': + result['compile_flags'] = [f'-msve-vector-bits={bitwidth}'] else: result['makeVecConst'] = f'vdupq_n_f{bits[data_type]}' + '({0})' result['makeVec'] = f'makeVec_f{bits[data_type]}' + '(' + ", ".join(['{' + str(i) + '}' for i in @@ -137,7 +151,7 @@ def get_vector_instruction_set_arm(data_type='double', instruction_set='neon'): result['any'] = f'vaddlvq_u8(vreinterpretq_u8_u{bits[data_type]}({{0}})) > 0' result['all'] = f'vaddlvq_u8(vreinterpretq_u8_u{bits[data_type]}({{0}})) == 16*0xff' - if bitwidth & (bitwidth - 1) == 0: + if instruction_set == 'sve' or bitwidth & (bitwidth - 1) == 0: # only power-of-2 vector sizes will evenly divide a cacheline result['cachelineSize'] = 'cachelineSize()' result['cachelineZero'] = 'cachelineZero((void*) {0})' diff --git a/pystencils/backends/cbackend.py b/pystencils/backends/cbackend.py index 8b0b13aa7..ba7426a14 100644 --- a/pystencils/backends/cbackend.py +++ b/pystencils/backends/cbackend.py @@ -6,6 +6,7 @@ from typing import Set import numpy as np import sympy as sp from sympy.core import S +from sympy.core.cache import cacheit from sympy.logic.boolalg import BooleanFalse, BooleanTrue from pystencils.astnodes import KernelFunction, LoopOverCoordinate, Node @@ -165,6 +166,23 @@ class PrintNode(CustomCodeNode): self.headers.append("<iostream>") +class CFunction(TypedSymbol): + def __new__(cls, function, dtype): + return CFunction.__xnew_cached_(cls, function, dtype) + + def __new_stage2__(cls, function, dtype): + return super(CFunction, cls).__xnew__(cls, function, dtype) + + __xnew__ = staticmethod(__new_stage2__) + __xnew_cached_ = staticmethod(cacheit(__new_stage2__)) + + def __getnewargs__(self): + return self.name, self.dtype + + def __getnewargs_ex__(self): + return (self.name, self.dtype), {} + + # ------------------------------------------- Printer ------------------------------------------------------------------ @@ -184,6 +202,8 @@ class CBackend: self._indent = " " self._dialect = dialect self._signatureOnly = signature_only + self._kwargs = {} + self.sympy_printer._kwargs = self._kwargs def __call__(self, node): prev_is = VectorType.instruction_set @@ -205,7 +225,8 @@ class CBackend: return str(node) def _print_KernelFunction(self, node): - function_arguments = [f"{self._print(s.symbol.dtype)} {s.symbol.name}" for s in node.get_parameters()] + function_arguments = [f"{self._print(s.symbol.dtype)} {s.symbol.name}" for s in node.get_parameters() + if 
not type(s.symbol) is CFunction] launch_bounds = "" if self._dialect == 'cuda': max_threads = node.indexing.max_threads_per_block() @@ -232,6 +253,8 @@ class CBackend: condition = f"{counter_symbol} < {self.sympy_printer.doprint(node.stop)}" update = f"{counter_symbol} += {self.sympy_printer.doprint(node.step)}" loop_str = f"for ({start}; {condition}; {update})" + self._kwargs['loop_counter'] = counter_symbol + self._kwargs['loop_stop'] = node.stop prefix = "\n".join(node.prefix_lines) if prefix: @@ -265,7 +288,8 @@ class CBackend: if instr not in self._vector_instruction_set: self._vector_instruction_set[instr] = self._vector_instruction_set['store' + instr[-1]].format( '{0}', self._vector_instruction_set['blendv'].format( - self._vector_instruction_set['load' + instr[-1]].format('{0}'), '{1}', '{2}')) + self._vector_instruction_set['load' + instr[-1]].format('{0}', **self._kwargs), + '{1}', '{2}', **self._kwargs), **self._kwargs) printed_mask = self.sympy_printer.doprint(mask) if data_type.base_type.base_name == 'double': if self._vector_instruction_set['double'] == '__m256d': @@ -287,9 +311,9 @@ class CBackend: ptr = "&" + self.sympy_printer.doprint(node.lhs.args[0]) if stride != 1: - instr = 'maskScatter' if mask != True else 'scatter' # NOQA + instr = 'maskStoreS' if mask != True else 'storeS' # NOQA return self._vector_instruction_set[instr].format(ptr, self.sympy_printer.doprint(rhs), - stride, printed_mask) + ';' + stride, printed_mask, **self._kwargs) + ';' pre_code = '' if nontemporal and 'cachelineZero' in self._vector_instruction_set: @@ -301,22 +325,22 @@ class CBackend: element_size = 8 if data_type.base_type.base_name == 'double' else 4 size_cond = f"({offset} + {CachelineSize.symbol/element_size}) < {size}" pre_code = f"if ({first_cond} && {size_cond}) " + "{\n\t" + \ - self._vector_instruction_set['cachelineZero'].format(ptr) + ';\n}\n' + self._vector_instruction_set['cachelineZero'].format(ptr, **self._kwargs) + ';\n}\n' code = self._vector_instruction_set[instr].format(ptr, self.sympy_printer.doprint(rhs), - printed_mask) + ';' + printed_mask, **self._kwargs) + ';' flushcond = f"((uintptr_t) {ptr} & {CachelineSize.mask_symbol}) == {CachelineSize.last_symbol}" if nontemporal and 'flushCacheline' in self._vector_instruction_set: code2 = self._vector_instruction_set['flushCacheline'].format( - ptr, self.sympy_printer.doprint(rhs)) + ';' + ptr, self.sympy_printer.doprint(rhs), **self._kwargs) + ';' code = f"{code}\nif ({flushcond}) {{\n\t{code2}\n}}" elif nontemporal and 'storeAAndFlushCacheline' in self._vector_instruction_set: tmpvar = '_tmp_' + hashlib.sha1(self.sympy_printer.doprint(rhs).encode('ascii')).hexdigest()[:8] code = 'const ' + self._print(node.lhs.dtype).replace(' const', '') + ' ' + tmpvar + ' = ' \ + self.sympy_printer.doprint(rhs) + ';' - code1 = self._vector_instruction_set[instr].format(ptr, tmpvar, printed_mask) + ';' - code2 = self._vector_instruction_set['storeAAndFlushCacheline'].format(ptr, tmpvar, printed_mask) \ - + ';' + code1 = self._vector_instruction_set[instr].format(ptr, tmpvar, printed_mask, **self._kwargs) + ';' + code2 = self._vector_instruction_set['storeAAndFlushCacheline'].format(ptr, tmpvar, printed_mask, + **self._kwargs) + ';' code += f"\nif ({flushcond}) {{\n\t{code2}\n}} else {{\n\t{code1}\n}}" return pre_code + code else: @@ -617,16 +641,16 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): def _print_Abs(self, expr): if 'abs' in self.instruction_set and isinstance(expr.args[0], vector_memory_access): - return 
self.instruction_set['abs'].format(self._print(expr.args[0])) + return self.instruction_set['abs'].format(self._print(expr.args[0]), **self._kwargs) return super()._print_Abs(expr) def _print_Function(self, expr): if isinstance(expr, vector_memory_access): arg, data_type, aligned, _, mask, stride = expr.args if stride != 1: - return self.instruction_set['gather'].format("& " + self._print(arg), stride) + return self.instruction_set['loadS'].format("& " + self._print(arg), stride, **self._kwargs) instruction = self.instruction_set['loadA'] if aligned else self.instruction_set['loadU'] - return instruction.format("& " + self._print(arg)) + return instruction.format("& " + self._print(arg), **self._kwargs) elif isinstance(expr, cast_func): arg, data_type = expr.args if type(data_type) is VectorType: @@ -640,19 +664,21 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): if instruction == 'makeVecInt' and 'makeVecIndex' in self.instruction_set: increments = np.array(arg)[1:] - np.array(arg)[:-1] if len(set(increments)) == 1: - return self.instruction_set['makeVecIndex'].format(printed_args[0], increments[0]) - return self.instruction_set[instruction].format(*printed_args) + return self.instruction_set['makeVecIndex'].format(printed_args[0], increments[0], + **self._kwargs) + return self.instruction_set[instruction].format(*printed_args, **self._kwargs) else: is_boolean = get_type_of_expression(arg) == create_type("bool") is_integer = get_type_of_expression(arg) == create_type("int") or \ (isinstance(arg, TypedSymbol) and arg.dtype.is_int()) instruction = 'makeVecConstBool' if is_boolean else \ 'makeVecConstInt' if is_integer else 'makeVecConst' - return self.instruction_set[instruction].format(self._print(arg)) + return self.instruction_set[instruction].format(self._print(arg), **self._kwargs) elif expr.func == fast_division: result = self._scalarFallback('_print_Function', expr) if not result: - result = self.instruction_set['/'].format(self._print(expr.args[0]), self._print(expr.args[1])) + result = self.instruction_set['/'].format(self._print(expr.args[0]), self._print(expr.args[1]), + **self._kwargs) return result elif expr.func == fast_sqrt: return f"({self._print(sp.sqrt(expr.args[0]))})" @@ -660,7 +686,7 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): result = self._scalarFallback('_print_Function', expr) if not result: if 'rsqrt' in self.instruction_set: - return self.instruction_set['rsqrt'].format(self._print(expr.args[0])) + return self.instruction_set['rsqrt'].format(self._print(expr.args[0]), **self._kwargs) else: return f"({self._print(1 / sp.sqrt(expr.args[0]))})" elif isinstance(expr, vec_any) or isinstance(expr, vec_all): @@ -672,8 +698,9 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): if isinstance(expr.args[0], sp.Rel): op = expr.args[0].rel_op if (instr, op) in self.instruction_set: - return self.instruction_set[(instr, op)].format(*[self._print(a) for a in expr.args[0].args]) - return self.instruction_set[instr].format(self._print(expr.args[0])) + return self.instruction_set[(instr, op)].format(*[self._print(a) for a in expr.args[0].args], + **self._kwargs) + return self.instruction_set[instr].format(self._print(expr.args[0]), **self._kwargs) return super(VectorizedCustomSympyPrinter, self)._print_Function(expr) @@ -686,7 +713,7 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): assert len(arg_strings) > 0 result = arg_strings[0] for item in arg_strings[1:]: - result = self.instruction_set['&'].format(result, item) + result = 
self.instruction_set['&'].format(result, item, **self._kwargs) return result def _print_Or(self, expr): @@ -698,7 +725,7 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): assert len(arg_strings) > 0 result = arg_strings[0] for item in arg_strings[1:]: - result = self.instruction_set['|'].format(result, item) + result = self.instruction_set['|'].format(result, item, **self._kwargs) return result def _print_Add(self, expr, order=None): @@ -739,7 +766,7 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): processed = summands[0].term for summand in summands[1:]: func = self.instruction_set['-' + suffix] if summand.sign == -1 else self.instruction_set['+' + suffix] - processed = func.format(processed, summand.term) + processed = func.format(processed, summand.term, **self._kwargs) return processed def _print_Pow(self, expr): @@ -747,21 +774,22 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): if result: return result - one = self.instruction_set['makeVecConst'].format(1.0) + one = self.instruction_set['makeVecConst'].format(1.0, **self._kwargs) if expr.exp.is_integer and expr.exp.is_number and 0 < expr.exp < 8: return "(" + self._print(sp.Mul(*[expr.base] * expr.exp, evaluate=False)) + ")" elif expr.exp == -1: - one = self.instruction_set['makeVecConst'].format(1.0) - return self.instruction_set['/'].format(one, self._print(expr.base)) + one = self.instruction_set['makeVecConst'].format(1.0, **self._kwargs) + return self.instruction_set['/'].format(one, self._print(expr.base), **self._kwargs) elif expr.exp == 0.5: - return self.instruction_set['sqrt'].format(self._print(expr.base)) + return self.instruction_set['sqrt'].format(self._print(expr.base), **self._kwargs) elif expr.exp == -0.5: - root = self.instruction_set['sqrt'].format(self._print(expr.base)) - return self.instruction_set['/'].format(one, root) + root = self.instruction_set['sqrt'].format(self._print(expr.base), **self._kwargs) + return self.instruction_set['/'].format(one, root, **self._kwargs) elif expr.exp.is_integer and expr.exp.is_number and - 8 < expr.exp < 0: return self.instruction_set['/'].format(one, - self._print(sp.Mul(*[expr.base] * (-expr.exp), evaluate=False))) + self._print(sp.Mul(*[expr.base] * (-expr.exp), evaluate=False)), + **self._kwargs) else: raise ValueError("Generic exponential not supported: " + str(expr)) @@ -800,19 +828,19 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): result = a_str[0] for item in a_str[1:]: - result = self.instruction_set['*'].format(result, item) + result = self.instruction_set['*'].format(result, item, **self._kwargs) if len(b) > 0: denominator_str = b_str[0] for item in b_str[1:]: - denominator_str = self.instruction_set['*'].format(denominator_str, item) - result = self.instruction_set['/'].format(result, denominator_str) + denominator_str = self.instruction_set['*'].format(denominator_str, item, **self._kwargs) + result = self.instruction_set['/'].format(result, denominator_str, **self._kwargs) if inside_add: return sign, result else: if sign < 0: - return self.instruction_set['*'].format(self._print(S.NegativeOne), result) + return self.instruction_set['*'].format(self._print(S.NegativeOne), result, **self._kwargs) else: return result @@ -820,13 +848,13 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): result = self._scalarFallback('_print_Relational', expr) if result: return result - return self.instruction_set[expr.rel_op].format(self._print(expr.lhs), self._print(expr.rhs)) + return 
self.instruction_set[expr.rel_op].format(self._print(expr.lhs), self._print(expr.rhs), **self._kwargs)
 
     def _print_Equality(self, expr):
         result = self._scalarFallback('_print_Equality', expr)
         if result:
             return result
-        return self.instruction_set['=='].format(self._print(expr.lhs), self._print(expr.rhs))
+        return self.instruction_set['=='].format(self._print(expr.lhs), self._print(expr.rhs), **self._kwargs)
 
     def _print_Piecewise(self, expr):
         result = self._scalarFallback('_print_Piecewise', expr)
@@ -847,10 +875,11 @@
             if isinstance(condition, cast_func) and get_type_of_expression(condition.args[0]) == create_type("bool"):
                 if not KERNCRAFT_NO_TERNARY_MODE:
                     result = "(({}) ? ({}) : ({}))".format(self._print(condition.args[0]), self._print(true_expr),
-                                                           result)
+                                                           result, **self._kwargs)
                 else:
                     print("Warning - skipping ternary op")
             else:
                 # noinspection SpellCheckingInspection
-                result = self.instruction_set['blendv'].format(result, self._print(true_expr), self._print(condition))
+                result = self.instruction_set['blendv'].format(result, self._print(true_expr), self._print(condition),
+                                                               **self._kwargs)
         return result
 
diff --git a/pystencils/backends/riscv_instruction_sets.py b/pystencils/backends/riscv_instruction_sets.py
new file mode 100644
index 000000000..d93aee701
--- /dev/null
+++ b/pystencils/backends/riscv_instruction_sets.py
@@ -0,0 +1,109 @@
+def get_argument_string(function_shortcut, last=''):
+    args = function_shortcut[function_shortcut.index('[') + 1: -1]
+    arg_string = "("
+    for arg in args.split(","):
+        arg = arg.strip()
+        if not arg:
+            continue
+        if arg in ('0', '1', '2', '3', '4', '5'):
+            arg_string += "{" + arg + "},"
+        else:
+            arg_string += arg + ","
+    if last:
+        arg_string += last + ','
+    arg_string = arg_string[:-1] + ")"
+    return arg_string
+
+
+def get_vector_instruction_set_riscv(data_type='double', instruction_set='rvv'):
+    assert instruction_set == 'rvv'
+
+    bits = {'double': 64,
+            'float': 32,
+            'int': 32}
+
+    base_names = {
+        '+': 'fadd_vv[0, 1]',
+        '-': 'fsub_vv[0, 1]',
+        '*': 'fmul_vv[0, 1]',
+        '/': 'fdiv_vv[0, 1]',
+        'sqrt': 'fsqrt_v[0]',
+
+        'loadU': f'le{bits[data_type]}_v[0]',
+        'loadA': f'le{bits[data_type]}_v[0]',
+        'storeU': f'se{bits[data_type]}_v[0, 1]',
+        'storeA': f'se{bits[data_type]}_v[0, 1]',
+        'maskStoreU': f'se{bits[data_type]}_v[2, 0, 1]',
+        'maskStoreA': f'se{bits[data_type]}_v[2, 0, 1]',
+        'loadS': f'lse{bits[data_type]}_v[0, 1]',
+        'storeS': f'sse{bits[data_type]}_v[0, 2, 1]',
+        'maskStoreS': f'sse{bits[data_type]}_v[2, 0, 3, 1]',
+
+        'abs': 'fabs_v[0]',
+        '==': 'mfeq_vv[0, 1]',
+        '!=': 'mfne_vv[0, 1]',
+        '<=': 'mfle_vv[0, 1]',
+        '<': 'mflt_vv[0, 1]',
+        '>=': 'mfge_vv[0, 1]',
+        '>': 'mfgt_vv[0, 1]',
+        '&': 'mand_mm[0, 1]',
+        '|': 'mor_mm[0, 1]',
+
+        'blendv': 'merge_vvm[2, 0, 1]',
+        'any': 'popc_m[0]',
+        'all': 'popc_m[0]',
+    }
+
+    result = dict()
+
+    width = f'vsetvlmax_e{bits[data_type]}m1()'
+    intwidth = f'vsetvlmax_e{bits["int"]}m1()'
+    result['bytes'] = 'vsetvlmax_e8m1()'
+    prefix = 'v'
+    suffix = f'_f{bits[data_type]}m1'
+
+    vl = '{loop_stop} - {loop_counter}'
+    int_vl = f'({vl})*{bits[data_type]//bits["int"]}'
+
+    for intrinsic_id, function_shortcut in base_names.items():
+        function_shortcut = function_shortcut.strip()
+        name = function_shortcut[:function_shortcut.index('[')]
+        if name.startswith('mf'):
+            suffix2 = suffix + f'_b{bits[data_type]}'
+        elif name.endswith('_mm') or name.endswith('_m'):
+            suffix2 = f'_b{bits[data_type]}'
+        elif intrinsic_id.startswith('mask'):
+            suffix2 = suffix 
+ '_m' + else: + suffix2 = suffix + + arg_string = get_argument_string(function_shortcut, last=vl) + + result[intrinsic_id] = prefix + name + suffix2 + arg_string + + from pystencils.backends.cbackend import CFunction + result['width'] = CFunction(width, "int") + result['intwidth'] = CFunction(intwidth, "int") + + result['makeVecConst'] = f'vfmv_v_f_f{bits[data_type]}m1({{0}}, {vl})' + result['makeVecConstInt'] = f'vmv_v_x_i{bits["int"]}m1({{0}}, {int_vl})' + result['makeVecIndex'] = f'vmacc_vx_i{bits["int"]}m1({result["makeVecConstInt"]}, {{1}}, ' + \ + f'vid_v_i{bits["int"]}m1({int_vl}), {int_vl})' + + result['storeS'] = result['storeS'].replace('{2}', f'{{2}}*{bits[data_type]//8}') + result['loadS'] = result['loadS'].replace('{1}', f'{{1}}*{bits[data_type]//8}') + result['maskStoreS'] = result['maskStoreS'].replace('{3}', f'{{3}}*{bits[data_type]//8}') + + result['+int'] = f"vadd_vv_i{bits['int']}m1({{0}}, {{1}}, {int_vl})" + + result['float'] = f'vfloat{bits["float"]}m1_t' + result['double'] = f'vfloat{bits["double"]}m1_t' + result['int'] = f'vint{bits["int"]}m1_t' + result['bool'] = f'vbool{bits[data_type]}_t' + + result['headers'] = ['<riscv_vector.h>'] + + result['any'] += ' > 0x0' + result['all'] += f' == vsetvl_e{bits[data_type]}m1({vl})' + + return result diff --git a/pystencils/backends/simd_instruction_sets.py b/pystencils/backends/simd_instruction_sets.py index 4fe147821..8ac0beeb7 100644 --- a/pystencils/backends/simd_instruction_sets.py +++ b/pystencils/backends/simd_instruction_sets.py @@ -6,6 +6,7 @@ from ctypes import CDLL from pystencils.backends.x86_instruction_sets import get_vector_instruction_set_x86 from pystencils.backends.arm_instruction_sets import get_vector_instruction_set_arm from pystencils.backends.ppc_instruction_sets import get_vector_instruction_set_ppc +from pystencils.backends.riscv_instruction_sets import get_vector_instruction_set_riscv def get_vector_instruction_set(data_type='double', instruction_set='avx'): @@ -13,6 +14,8 @@ def get_vector_instruction_set(data_type='double', instruction_set='avx'): return get_vector_instruction_set_arm(data_type, instruction_set) elif instruction_set in ['vsx']: return get_vector_instruction_set_ppc(data_type, instruction_set) + elif instruction_set in ['rvv']: + return get_vector_instruction_set_riscv(data_type, instruction_set) else: return get_vector_instruction_set_x86(data_type, instruction_set) @@ -30,6 +33,11 @@ def get_supported_instruction_sets(): return os.environ['PYSTENCILS_SIMD'].split(',') if platform.system() == 'Darwin' and platform.machine() == 'arm64': # not supported by cpuinfo return ['neon'] + elif platform.system() == 'Linux' and platform.machine().startswith('riscv'): # not supported by cpuinfo + libc = CDLL('libc.so.6') + hwcap = libc.getauxval(16) # AT_HWCAP + hwcap_isa_v = 1 << (ord('V') - ord('A')) # COMPAT_HWCAP_ISA_V + return ['rvv'] if hwcap & hwcap_isa_v else [] elif platform.machine().startswith('ppc64'): # no flags reported by cpuinfo import subprocess import tempfile @@ -74,8 +82,7 @@ def get_supported_instruction_sets(): if native_length != pwr2_length: result.append(f"sve{pwr2_length}") result.append(f"sve{native_length}") - else: - result.append("sve") + result.append("sve") return result diff --git a/pystencils/backends/x86_instruction_sets.py b/pystencils/backends/x86_instruction_sets.py index 913db542f..f72b48266 100644 --- a/pystencils/backends/x86_instruction_sets.py +++ b/pystencils/backends/x86_instruction_sets.py @@ -147,11 +147,11 @@ def 
get_vector_instruction_set_x86(data_type='double', instruction_set='avx'): vindex = f'{pre}_set_epi{bit_width//size}(' + ', '.join([str(i) for i in range(result['width'])][::-1]) + ')' vindex = f'{pre}_mullo_epi{bit_width//size}({vindex}, {pre}_set1_epi{bit_width//size}({{0}}))' - result['scatter'] = f'{pre}_i{bit_width//size}scatter_{suf}({{0}}, ' + vindex.format("{2}") + \ - f', {{1}}, {64//size})' - result['maskScatter'] = f'{pre}_mask_i{bit_width//size}scatter_{suf}({{0}}, {{3}}, ' + vindex.format("{2}") + \ - f', {{1}}, {64//size})' - result['gather'] = f'{pre}_i{bit_width//size}gather_{suf}(' + vindex.format("{1}") + f', {{0}}, {64//size})' + result['storeS'] = f'{pre}_i{bit_width//size}scatter_{suf}({{0}}, ' + vindex.format("{2}") + \ + f', {{1}}, {64//size})' + result['maskStoreS'] = f'{pre}_mask_i{bit_width//size}scatter_{suf}({{0}}, {{3}}, ' + vindex.format("{2}") + \ + f', {{1}}, {64//size})' + result['loadS'] = f'{pre}_i{bit_width//size}gather_{suf}(' + vindex.format("{1}") + f', {{0}}, {64//size})' if instruction_set == 'avx' and data_type == 'float': result['rsqrt'] = f"{pre}_rsqrt_{suf}({{0}})" diff --git a/pystencils/cpu/cpujit.py b/pystencils/cpu/cpujit.py index dfb33e74b..e0b9e5612 100644 --- a/pystencils/cpu/cpujit.py +++ b/pystencils/cpu/cpujit.py @@ -59,7 +59,7 @@ from appdirs import user_cache_dir, user_config_dir from pystencils import FieldType from pystencils.astnodes import LoopOverCoordinate -from pystencils.backends.cbackend import generate_c, get_headers +from pystencils.backends.cbackend import generate_c, get_headers, CFunction from pystencils.data_types import cast_func, VectorType, vector_memory_access from pystencils.include import get_pystencils_include_path from pystencils.kernel_wrapper import KernelWrapper @@ -411,7 +411,7 @@ def create_function_boilerplate_code(parameter_info, name, ast_node, insert_chec if has_openmp and has_nontemporal: byte_width = ast_node.instruction_set['cachelineSize'] offset = max(max(ast_node.ghost_layers)) * item_size - offset_cond = f"(((uintptr_t) buffer_{field.name}.buf) + {offset}) % {byte_width} == 0" + offset_cond = f"(((uintptr_t) buffer_{field.name}.buf) + {offset}) % ({byte_width}) == 0" message = str(offset) + ". This is probably due to a different number of ghost_layers chosen for " \ "the arrays and the kernel creation. 
If the number of ghost layers for " \ @@ -460,6 +460,8 @@ def create_function_boilerplate_code(parameter_info, name, ast_node, insert_chec name=field.name)) elif param.is_field_shape: parameters.append(f"buffer_{param.field_name}.shape[{param.symbol.coordinate}]") + elif type(param.symbol) is CFunction: + continue else: extract_function, target_type = type_mapping[param.symbol.dtype.numpy_dtype.type] if np.issubdtype(param.symbol.dtype.numpy_dtype, np.complexfloating): diff --git a/pystencils/cpu/vectorization.py b/pystencils/cpu/vectorization.py index 6ab821f4e..b9fa2819e 100644 --- a/pystencils/cpu/vectorization.py +++ b/pystencils/cpu/vectorization.py @@ -127,9 +127,10 @@ def vectorize(kernel_ast: ast.KernelFunction, instruction_set: str = 'best', kernel_ast.instruction_set = vector_is vectorize_rng(kernel_ast, vector_width) - scattergather = 'scatter' in vector_is and 'gather' in vector_is + strided = 'storeS' in vector_is and 'loadS' in vector_is + keep_loop_stop = '{loop_stop}' in vector_is['storeA' if assume_aligned else 'storeU'] vectorize_inner_loops_and_adapt_load_stores(kernel_ast, vector_width, assume_aligned, nontemporal, - scattergather, assume_sufficient_line_padding) + strided, keep_loop_stop, assume_sufficient_line_padding) insert_vector_casts(kernel_ast) @@ -152,7 +153,7 @@ def vectorize_rng(kernel_ast, vector_width): def vectorize_inner_loops_and_adapt_load_stores(ast_node, vector_width, assume_aligned, nontemporal_fields, - scattergather, assume_sufficient_line_padding): + strided, keep_loop_stop, assume_sufficient_line_padding): """Goes over all innermost loops, changes increment to vector width and replaces field accesses by vector type.""" all_loops = filtered_tree_iteration(ast_node, ast.LoopOverCoordinate, stop_type=ast.SympyAssignment) inner_loops = [n for n in all_loops if n.is_innermost_loop] @@ -162,7 +163,9 @@ def vectorize_inner_loops_and_adapt_load_stores(ast_node, vector_width, assume_a loop_range = loop_node.stop - loop_node.start # cut off loop tail, that is not a multiple of four - if assume_aligned and assume_sufficient_line_padding: + if keep_loop_stop: + pass + elif assume_aligned and assume_sufficient_line_padding: loop_range = loop_node.stop - loop_node.start new_stop = loop_node.start + modulo_ceil(loop_range, vector_width) loop_node.stop = new_stop @@ -184,7 +187,7 @@ def vectorize_inner_loops_and_adapt_load_stores(ast_node, vector_width, assume_a loop_counter_is_offset = loop_counter_symbol not in (index - loop_counter_symbol).atoms() aligned_access = (index - loop_counter_symbol).subs(zero_loop_counters) == 0 stride = sp.simplify(index.subs({loop_counter_symbol: loop_counter_symbol + 1}) - index) - if not loop_counter_is_offset and (not scattergather or loop_counter_symbol in stride.atoms()): + if not loop_counter_is_offset and (not strided or loop_counter_symbol in stride.atoms()): successful = False break typed_symbol = base.label @@ -197,7 +200,7 @@ def vectorize_inner_loops_and_adapt_load_stores(ast_node, vector_width, assume_a if hasattr(indexed, 'field'): nontemporal = (indexed.field in nontemporal_fields) or (indexed.field.name in nontemporal_fields) substitutions[indexed] = vector_memory_access(indexed, vec_type, use_aligned_access, nontemporal, True, - stride if scattergather else 1) + stride if strided else 1) if nontemporal: # insert NontemporalFence after the outermost loop parent = loop_node.parent @@ -214,7 +217,8 @@ def vectorize_inner_loops_and_adapt_load_stores(ast_node, vector_width, assume_a loop_node.subs(substitutions) 
vector_int_width = ast_node.instruction_set['intwidth'] vector_loop_counter = cast_func(loop_counter_symbol, VectorType(loop_counter_symbol.dtype, vector_int_width)) \ - + cast_func(tuple(range(vector_int_width)), VectorType(loop_counter_symbol.dtype, vector_int_width)) + + cast_func(tuple(range(vector_int_width if type(vector_int_width) is int else 2)), + VectorType(loop_counter_symbol.dtype, vector_int_width)) fast_subs(loop_node, {loop_counter_symbol: vector_loop_counter}, skip=lambda e: isinstance(e, ast.ResolvedFieldAccess) or isinstance(e, vector_memory_access)) diff --git a/pystencils/data_types.py b/pystencils/data_types.py index baf0a9674..f0e10d4a0 100644 --- a/pystencils/data_types.py +++ b/pystencils/data_types.py @@ -712,7 +712,7 @@ class VectorType(Type): def __str__(self): if self.instruction_set is None: - return "%s[%d]" % (self.base_type, self.width) + return "%s[%s]" % (self.base_type, self.width) else: if self.base_type == create_type("int64") or self.base_type == create_type("int32"): return self.instruction_set['int'] diff --git a/pystencils/include/philox_rand.h b/pystencils/include/philox_rand.h index 84f0ba91e..eca71a200 100644 --- a/pystencils/include/philox_rand.h +++ b/pystencils/include/philox_rand.h @@ -30,6 +30,10 @@ #endif #endif +#ifdef __riscv_v +#include <riscv_vector.h> +#endif + #ifndef __CUDA_ARCH__ #define QUALIFIERS inline #include "myintrin.h" @@ -818,6 +822,156 @@ QUALIFIERS void philox_double2(uint32 ctr0, svint32_t ctr1, uint32 ctr2, uint32 #endif +#if defined(__riscv_v) +QUALIFIERS void _philox4x32round(vuint32m1_t & ctr0, vuint32m1_t & ctr1, vuint32m1_t & ctr2, vuint32m1_t & ctr3, + vuint32m1_t key0, vuint32m1_t key1) +{ + vuint32m1_t lo0 = vmul_vv_u32m1(ctr0, vmv_v_x_u32m1(PHILOX_M4x32_0, vsetvlmax_e32m1()), vsetvlmax_e32m1()); + vuint32m1_t lo1 = vmul_vv_u32m1(ctr2, vmv_v_x_u32m1(PHILOX_M4x32_1, vsetvlmax_e32m1()), vsetvlmax_e32m1()); + vuint32m1_t hi0 = vmulhu_vv_u32m1(ctr0, vmv_v_x_u32m1(PHILOX_M4x32_0, vsetvlmax_e32m1()), vsetvlmax_e32m1()); + vuint32m1_t hi1 = vmulhu_vv_u32m1(ctr2, vmv_v_x_u32m1(PHILOX_M4x32_1, vsetvlmax_e32m1()), vsetvlmax_e32m1()); + + ctr0 = vxor_vv_u32m1(vxor_vv_u32m1(hi1, ctr1, vsetvlmax_e32m1()), key0, vsetvlmax_e32m1()); + ctr1 = lo1; + ctr2 = vxor_vv_u32m1(vxor_vv_u32m1(hi0, ctr3, vsetvlmax_e32m1()), key1, vsetvlmax_e32m1()); + ctr3 = lo0; +} + +QUALIFIERS void _philox4x32bumpkey(vuint32m1_t & key0, vuint32m1_t & key1) +{ + key0 = vadd_vv_u32m1(key0, vmv_v_x_u32m1(PHILOX_W32_0, vsetvlmax_e32m1()), vsetvlmax_e32m1()); + key1 = vadd_vv_u32m1(key1, vmv_v_x_u32m1(PHILOX_W32_1, vsetvlmax_e32m1()), vsetvlmax_e32m1()); +} + +template<bool high> +QUALIFIERS vfloat64m1_t _uniform_double_hq(vuint32m1_t x, vuint32m1_t y) +{ + // convert 32 to 64 bit + if (high) + { + size_t s = vsetvlmax_e32m1(); + x = vslidedown_vx_u32m1(vundefined_u32m1(), x, s/2, s); + y = vslidedown_vx_u32m1(vundefined_u32m1(), y, s/2, s); + } + vuint64m1_t x64 = vwcvtu_x_x_v_u64m1(vlmul_trunc_v_u32m1_u32mf2(x), vsetvlmax_e64m1()); + vuint64m1_t y64 = vwcvtu_x_x_v_u64m1(vlmul_trunc_v_u32m1_u32mf2(y), vsetvlmax_e64m1()); + + // calculate z = x ^ y << (53 - 32)) + vuint64m1_t z = vsll_vx_u64m1(y64, 53 - 32, vsetvlmax_e64m1()); + z = vxor_vv_u64m1(x64, z, vsetvlmax_e64m1()); + + // convert uint64 to double + vfloat64m1_t rs = vfcvt_f_xu_v_f64m1(z, vsetvlmax_e64m1()); + // calculate rs * TWOPOW53_INV_DOUBLE + (TWOPOW53_INV_DOUBLE/2.0) + rs = vfmadd_vv_f64m1(rs, vfmv_v_f_f64m1(TWOPOW53_INV_DOUBLE, vsetvlmax_e64m1()), vfmv_v_f_f64m1(TWOPOW53_INV_DOUBLE/2.0, 
vsetvlmax_e64m1()), vsetvlmax_e64m1()); + + return rs; +} + + +QUALIFIERS void philox_float4(vuint32m1_t ctr0, vuint32m1_t ctr1, vuint32m1_t ctr2, vuint32m1_t ctr3, + uint32 key0, uint32 key1, + vfloat32m1_t & rnd1, vfloat32m1_t & rnd2, vfloat32m1_t & rnd3, vfloat32m1_t & rnd4) +{ + vuint32m1_t key0v = vmv_v_x_u32m1(key0, vsetvlmax_e32m1()); + vuint32m1_t key1v = vmv_v_x_u32m1(key1, vsetvlmax_e32m1()); + _philox4x32round(ctr0, ctr1, ctr2, ctr3, key0v, key1v); // 1 + _philox4x32bumpkey(key0v, key1v); _philox4x32round(ctr0, ctr1, ctr2, ctr3, key0v, key1v); // 2 + _philox4x32bumpkey(key0v, key1v); _philox4x32round(ctr0, ctr1, ctr2, ctr3, key0v, key1v); // 3 + _philox4x32bumpkey(key0v, key1v); _philox4x32round(ctr0, ctr1, ctr2, ctr3, key0v, key1v); // 4 + _philox4x32bumpkey(key0v, key1v); _philox4x32round(ctr0, ctr1, ctr2, ctr3, key0v, key1v); // 5 + _philox4x32bumpkey(key0v, key1v); _philox4x32round(ctr0, ctr1, ctr2, ctr3, key0v, key1v); // 6 + _philox4x32bumpkey(key0v, key1v); _philox4x32round(ctr0, ctr1, ctr2, ctr3, key0v, key1v); // 7 + _philox4x32bumpkey(key0v, key1v); _philox4x32round(ctr0, ctr1, ctr2, ctr3, key0v, key1v); // 8 + _philox4x32bumpkey(key0v, key1v); _philox4x32round(ctr0, ctr1, ctr2, ctr3, key0v, key1v); // 9 + _philox4x32bumpkey(key0v, key1v); _philox4x32round(ctr0, ctr1, ctr2, ctr3, key0v, key1v); // 10 + + // convert uint32 to float + rnd1 = vfcvt_f_xu_v_f32m1(ctr0, vsetvlmax_e32m1()); + rnd2 = vfcvt_f_xu_v_f32m1(ctr1, vsetvlmax_e32m1()); + rnd3 = vfcvt_f_xu_v_f32m1(ctr2, vsetvlmax_e32m1()); + rnd4 = vfcvt_f_xu_v_f32m1(ctr3, vsetvlmax_e32m1()); + // calculate rnd * TWOPOW32_INV_FLOAT + (TWOPOW32_INV_FLOAT/2.0f) + rnd1 = vfmadd_vv_f32m1(rnd1, vfmv_v_f_f32m1(TWOPOW32_INV_FLOAT, vsetvlmax_e32m1()), vfmv_v_f_f32m1(TWOPOW32_INV_FLOAT/2.0, vsetvlmax_e32m1()), vsetvlmax_e32m1()); + rnd2 = vfmadd_vv_f32m1(rnd2, vfmv_v_f_f32m1(TWOPOW32_INV_FLOAT, vsetvlmax_e32m1()), vfmv_v_f_f32m1(TWOPOW32_INV_FLOAT/2.0, vsetvlmax_e32m1()), vsetvlmax_e32m1()); + rnd3 = vfmadd_vv_f32m1(rnd3, vfmv_v_f_f32m1(TWOPOW32_INV_FLOAT, vsetvlmax_e32m1()), vfmv_v_f_f32m1(TWOPOW32_INV_FLOAT/2.0, vsetvlmax_e32m1()), vsetvlmax_e32m1()); + rnd4 = vfmadd_vv_f32m1(rnd4, vfmv_v_f_f32m1(TWOPOW32_INV_FLOAT, vsetvlmax_e32m1()), vfmv_v_f_f32m1(TWOPOW32_INV_FLOAT/2.0, vsetvlmax_e32m1()), vsetvlmax_e32m1()); +} + + +QUALIFIERS void philox_double2(vuint32m1_t ctr0, vuint32m1_t ctr1, vuint32m1_t ctr2, vuint32m1_t ctr3, + uint32 key0, uint32 key1, + vfloat64m1_t & rnd1lo, vfloat64m1_t & rnd1hi, vfloat64m1_t & rnd2lo, vfloat64m1_t & rnd2hi) +{ + vuint32m1_t key0v = vmv_v_x_u32m1(key0, vsetvlmax_e32m1()); + vuint32m1_t key1v = vmv_v_x_u32m1(key1, vsetvlmax_e32m1()); + _philox4x32round(ctr0, ctr1, ctr2, ctr3, key0v, key1v); // 1 + _philox4x32bumpkey(key0v, key1v); _philox4x32round(ctr0, ctr1, ctr2, ctr3, key0v, key1v); // 2 + _philox4x32bumpkey(key0v, key1v); _philox4x32round(ctr0, ctr1, ctr2, ctr3, key0v, key1v); // 3 + _philox4x32bumpkey(key0v, key1v); _philox4x32round(ctr0, ctr1, ctr2, ctr3, key0v, key1v); // 4 + _philox4x32bumpkey(key0v, key1v); _philox4x32round(ctr0, ctr1, ctr2, ctr3, key0v, key1v); // 5 + _philox4x32bumpkey(key0v, key1v); _philox4x32round(ctr0, ctr1, ctr2, ctr3, key0v, key1v); // 6 + _philox4x32bumpkey(key0v, key1v); _philox4x32round(ctr0, ctr1, ctr2, ctr3, key0v, key1v); // 7 + _philox4x32bumpkey(key0v, key1v); _philox4x32round(ctr0, ctr1, ctr2, ctr3, key0v, key1v); // 8 + _philox4x32bumpkey(key0v, key1v); _philox4x32round(ctr0, ctr1, ctr2, ctr3, key0v, key1v); // 9 + _philox4x32bumpkey(key0v, key1v); 
_philox4x32round(ctr0, ctr1, ctr2, ctr3, key0v, key1v); // 10 + + rnd1lo = _uniform_double_hq<false>(ctr0, ctr1); + rnd1hi = _uniform_double_hq<true>(ctr0, ctr1); + rnd2lo = _uniform_double_hq<false>(ctr2, ctr3); + rnd2hi = _uniform_double_hq<true>(ctr2, ctr3); +} + +QUALIFIERS void philox_float4(uint32 ctr0, vuint32m1_t ctr1, uint32 ctr2, uint32 ctr3, + uint32 key0, uint32 key1, + vfloat32m1_t & rnd1, vfloat32m1_t & rnd2, vfloat32m1_t & rnd3, vfloat32m1_t & rnd4) +{ + vuint32m1_t ctr0v = vmv_v_x_u32m1(ctr0, vsetvlmax_e32m1()); + vuint32m1_t ctr2v = vmv_v_x_u32m1(ctr2, vsetvlmax_e32m1()); + vuint32m1_t ctr3v = vmv_v_x_u32m1(ctr3, vsetvlmax_e32m1()); + + philox_float4(ctr0v, ctr1, ctr2v, ctr3v, key0, key1, rnd1, rnd2, rnd3, rnd4); +} + +QUALIFIERS void philox_float4(uint32 ctr0, vint32m1_t ctr1, uint32 ctr2, uint32 ctr3, + uint32 key0, uint32 key1, + vfloat32m1_t & rnd1, vfloat32m1_t & rnd2, vfloat32m1_t & rnd3, vfloat32m1_t & rnd4) +{ + philox_float4(ctr0, vreinterpret_v_i32m1_u32m1(ctr1), ctr2, ctr3, key0, key1, rnd1, rnd2, rnd3, rnd4); +} + +QUALIFIERS void philox_double2(uint32 ctr0, vuint32m1_t ctr1, uint32 ctr2, uint32 ctr3, + uint32 key0, uint32 key1, + vfloat64m1_t & rnd1lo, vfloat64m1_t & rnd1hi, vfloat64m1_t & rnd2lo, vfloat64m1_t & rnd2hi) +{ + vuint32m1_t ctr0v = vmv_v_x_u32m1(ctr0, vsetvlmax_e32m1()); + vuint32m1_t ctr2v = vmv_v_x_u32m1(ctr2, vsetvlmax_e32m1()); + vuint32m1_t ctr3v = vmv_v_x_u32m1(ctr3, vsetvlmax_e32m1()); + + philox_double2(ctr0v, ctr1, ctr2v, ctr3v, key0, key1, rnd1lo, rnd1hi, rnd2lo, rnd2hi); +} + +QUALIFIERS void philox_double2(uint32 ctr0, vuint32m1_t ctr1, uint32 ctr2, uint32 ctr3, + uint32 key0, uint32 key1, + vfloat64m1_t & rnd1, vfloat64m1_t & rnd2) +{ + vuint32m1_t ctr0v = vmv_v_x_u32m1(ctr0, vsetvlmax_e32m1()); + vuint32m1_t ctr2v = vmv_v_x_u32m1(ctr2, vsetvlmax_e32m1()); + vuint32m1_t ctr3v = vmv_v_x_u32m1(ctr3, vsetvlmax_e32m1()); + + vfloat64m1_t ignore; + philox_double2(ctr0v, ctr1, ctr2v, ctr3v, key0, key1, rnd1, ignore, rnd2, ignore); +} + +QUALIFIERS void philox_double2(uint32 ctr0, vint32m1_t ctr1, uint32 ctr2, uint32 ctr3, + uint32 key0, uint32 key1, + vfloat64m1_t & rnd1, vfloat64m1_t & rnd2) +{ + philox_double2(ctr0, vreinterpret_v_i32m1_u32m1(ctr1), ctr2, ctr3, key0, key1, rnd1, rnd2); +} +#endif + + #ifdef __AVX2__ QUALIFIERS void _philox4x32round(__m256i* ctr, __m256i* key) { diff --git a/pystencils_tests/test_conditional_vec.py b/pystencils_tests/test_conditional_vec.py index 959e20c2b..c67075418 100644 --- a/pystencils_tests/test_conditional_vec.py +++ b/pystencils_tests/test_conditional_vec.py @@ -12,7 +12,10 @@ supported_instruction_sets = get_supported_instruction_sets() if get_supported_i @pytest.mark.parametrize('instruction_set', supported_instruction_sets) @pytest.mark.parametrize('dtype', ('float', 'double')) def test_vec_any(instruction_set, dtype): - width = get_vector_instruction_set(dtype, instruction_set)['width'] + if instruction_set in ['sve', 'rvv']: + width = 4 # we don't know the actual value + else: + width = get_vector_instruction_set(dtype, instruction_set)['width'] data_arr = np.zeros((4*width, 4*width), dtype=np.float64 if dtype == 'double' else np.float32) data_arr[3:9, 1:3*width-1] = 1.0 @@ -28,13 +31,20 @@ def test_vec_any(instruction_set, dtype): cpu_vectorize_info={'instruction_set': instruction_set}) kernel = ast.compile() kernel(data=data_arr) - np.testing.assert_equal(data_arr[3:9, :3*width], 2.0) + if instruction_set in ['sve', 'rvv']: + # we only know that the first value has changed + 
np.testing.assert_equal(data_arr[3:9, :3*width-1], 2.0) + else: + np.testing.assert_equal(data_arr[3:9, :3*width], 2.0) @pytest.mark.parametrize('instruction_set', supported_instruction_sets) @pytest.mark.parametrize('dtype', ('float', 'double')) def test_vec_all(instruction_set, dtype): - width = get_vector_instruction_set(dtype, instruction_set)['width'] + if instruction_set in ['sve', 'rvv']: + width = 1000 # we don't know the actual value, need something guaranteed larger than vector + else: + width = get_vector_instruction_set(dtype, instruction_set)['width'] data_arr = np.zeros((4*width, 4*width), dtype=np.float64 if dtype == 'double' else np.float32) data_arr[3:9, 1:3*width-1] = 1.0 @@ -49,11 +59,16 @@ def test_vec_all(instruction_set, dtype): cpu_vectorize_info={'instruction_set': instruction_set}) kernel = ast.compile() kernel(data=data_arr) - np.testing.assert_equal(data_arr[3:9, :1], 0.0) - np.testing.assert_equal(data_arr[3:9, 1:width], 1.0) - np.testing.assert_equal(data_arr[3:9, width:2*width], 2.0) - np.testing.assert_equal(data_arr[3:9, 2*width:3*width-1], 1.0) - np.testing.assert_equal(data_arr[3:9, 3*width-1:], 0.0) + if instruction_set in ['sve', 'rvv']: + # we only know that some values in the middle have been replaced + assert np.all(data_arr[3:9, :2] <= 1.0) + assert np.any(data_arr[3:9, 2:] == 2.0) + else: + np.testing.assert_equal(data_arr[3:9, :1], 0.0) + np.testing.assert_equal(data_arr[3:9, 1:width], 1.0) + np.testing.assert_equal(data_arr[3:9, width:2*width], 2.0) + np.testing.assert_equal(data_arr[3:9, 2*width:3*width-1], 1.0) + np.testing.assert_equal(data_arr[3:9, 3*width-1:], 0.0) @pytest.mark.skipif(not supported_instruction_sets, reason='cannot detect CPU instruction set') diff --git a/pystencils_tests/test_random.py b/pystencils_tests/test_random.py index cba58cfa2..18ff23b7b 100644 --- a/pystencils_tests/test_random.py +++ b/pystencils_tests/test_random.py @@ -27,7 +27,7 @@ if get_compiler_config()['os'] == 'windows': def test_rng(target, rng, precision, dtype, t=124, offsets=(0, 0), keys=(0, 0), offset_values=None): if target == 'gpu': pytest.importorskip('pycuda') - if instruction_sets and (set(['neon', 'vsx']).intersection(instruction_sets) or any([iset.startswith('sve') for iset in instruction_sets])) and rng == 'aesni': + if instruction_sets and set(['neon', 'sve', 'vsx', 'rvv']).intersection(instruction_sets) and rng == 'aesni': pytest.xfail('AES not yet implemented for this architecture') if rng == 'aesni' and len(keys) == 2: keys *= 2 @@ -118,7 +118,7 @@ def test_rng_offsets(kind, vectorized): @pytest.mark.parametrize('rng', ('philox', 'aesni')) @pytest.mark.parametrize('precision,dtype', (('float', 'float'), ('double', 'double'))) def test_rng_vectorized(target, rng, precision, dtype, t=130, offsets=(1, 3), keys=(0, 0), offset_values=None): - if (target in ['neon', 'vsx'] or target.startswith('sve')) and rng == 'aesni': + if (target in ['neon', 'vsx', 'rvv'] or target.startswith('sve')) and rng == 'aesni': pytest.xfail('AES not yet implemented for this architecture') cpu_vectorize_info = {'assume_inner_stride_one': True, 'assume_aligned': True, 'instruction_set': target} diff --git a/pystencils_tests/test_vectorization.py b/pystencils_tests/test_vectorization.py index b7ee2e83b..00618a061 100644 --- a/pystencils_tests/test_vectorization.py +++ b/pystencils_tests/test_vectorization.py @@ -216,8 +216,7 @@ def test_logical_operators(instruction_set=instruction_set): def test_hardware_query(): - assert set(['sse', 'neon', 
'vsx']).intersection(supported_instruction_sets) or \ - any([iset.startswith('sve') for iset in supported_instruction_sets]) + assert set(['sse', 'neon', 'sve', 'vsx', 'rvv']).intersection(supported_instruction_sets) def test_vectorised_pow(instruction_set=instruction_set): diff --git a/pystencils_tests/test_vectorization_specific.py b/pystencils_tests/test_vectorization_specific.py index 16780f147..1c0c35e53 100644 --- a/pystencils_tests/test_vectorization_specific.py +++ b/pystencils_tests/test_vectorization_specific.py @@ -55,10 +55,10 @@ def test_vectorized_abs(instruction_set, dtype): @pytest.mark.parametrize('dtype', ('float', 'double')) @pytest.mark.parametrize('instruction_set', supported_instruction_sets) -def test_scatter_gather(instruction_set, dtype): +def test_strided(instruction_set, dtype): f, g = ps.fields(f"f, g : float{64 if dtype == 'double' else 32}[2D]") update_rule = [ps.Assignment(g[0, 0], f[0, 0] + f[-1, 0] + f[1, 0] + f[0, 1] + f[0, -1] + 42.0)] - if 'scatter' not in get_vector_instruction_set(dtype, instruction_set) and not instruction_set == 'avx512' and not instruction_set.startswith('sve'): + if 'storeS' not in get_vector_instruction_set(dtype, instruction_set) and not instruction_set in ['avx512', 'rvv'] and not instruction_set.startswith('sve'): with pytest.warns(UserWarning) as warn: ast = ps.create_kernel(update_rule, cpu_vectorize_info={'instruction_set': instruction_set}) assert 'Could not vectorize loop' in warn[0].message.args[0] @@ -106,12 +106,13 @@ def test_alignment_and_correct_ghost_layers(gl_field, gl_kernel, instruction_set @pytest.mark.parametrize('instruction_set', supported_instruction_sets) def test_cacheline_size(instruction_set): cacheline_size = get_cacheline_size(instruction_set) - if cacheline_size is None: + if cacheline_size is None and instruction_set in ['sse', 'avx', 'avx512', 'rvv']: pytest.skip() instruction_set = get_vector_instruction_set('double', instruction_set) vector_size = instruction_set['bytes'] assert cacheline_size > 8 and cacheline_size < 0x100000, "Cache line size is implausible" - assert cacheline_size % vector_size == 0, "Cache line size should be multiple of vector size" + if type(vector_size) is int: + assert cacheline_size % vector_size == 0, "Cache line size should be multiple of vector size" assert cacheline_size & (cacheline_size - 1) == 0, "Cache line size is not a power of 2" -- GitLab
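
Note on the mechanism, for reviewers: sizeless instruction sets ('sve' without
a bit width, 'rvv') cannot bake the vector length into the generated code, so
their instruction templates carry {loop_counter} and {loop_stop} placeholders.
CBackend records both in self._kwargs while printing the innermost loop, and
every template is rendered with **self._kwargs, which turns each load/store
into a predicated operation covering exactly the remaining iterations; this is
also why vectorize() now computes keep_loop_stop and skips the usual loop-tail
cut for these targets. Below is a minimal Python sketch of that formatting
step; the store template and the names ctr_0 / _size_f_0 are illustrative,
not taken verbatim from the generated code.

    # illustrative sizeless-SVE store template: {0}/{1} are the operands,
    # {loop_counter}/{loop_stop} build the per-iteration predicate
    store_template = 'svst1_f64(svwhilelt_b64_u64({loop_counter}, {loop_stop}), {0}, {1})'

    def print_store(ptr, value, kwargs):
        # mirrors how CBackend renders instruction templates with **self._kwargs
        return store_template.format(ptr, value, **kwargs) + ';'

    print(print_store('&f[ctr_0]', 'vec_result',
                      {'loop_counter': 'ctr_0', 'loop_stop': '_size_f_0'}))
    # svst1_f64(svwhilelt_b64_u64(ctr_0, _size_f_0), &f[ctr_0], vec_result);

On RISC-V the same two placeholders enter through vl = '{loop_stop} -
{loop_counter}', passed as the trailing vl argument of every intrinsic, so the
hardware clamps the requested element count to whatever the implementation
provides.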
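
Relatedly, 'width' and 'intwidth' can no longer be plain Python ints for these
instruction sets. The new CFunction symbol in cbackend.py covers this: it is a
TypedSymbol whose name is a function call, so wherever a width constant would
appear in the generated C (the loop increment in particular) a runtime query
such as svcntd() or vsetvlmax_e64m1() is emitted instead, and
create_function_boilerplate_code skips such symbols when building the kernel
parameter list. A rough stand-in to show the effect; the real class lives in
pystencils.backends.cbackend and additionally uses sympy's symbol caching:

    class CFunction:  # simplified stand-in, not the sympy-based original
        def __init__(self, function, dtype):
            self.name, self.dtype = function, dtype

        def __str__(self):
            # prints as the function call itself, e.g. 'svcntd()'
            return self.name

    width = CFunction('svcntd()', 'int')
    print(f'for (int64_t ctr_0 = 0; ctr_0 < _size_f_0; ctr_0 += {width})')
    # for (int64_t ctr_0 = 0; ctr_0 < _size_f_0; ctr_0 += svcntd())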