Commit cdf73d8f authored by Michael Kuron, committed by Jan Hönig
Browse files

Sizeless vectorization

parent cc645538
......@@ -178,7 +178,7 @@ arm64v9:
extends: .multiarch_template
image: i10git.cs.fau.de:5005/pycodegen/pycodegen/arm64
variables:
PYSTENCILS_SIMD: "sve256,sve512"
PYSTENCILS_SIMD: "sve256,sve512,sve"
ASAN_OPTIONS: detect_leaks=0
LD_PRELOAD: /usr/lib/aarch64-linux-gnu/libasan.so.6
before_script:
......@@ -186,6 +186,20 @@ arm64v9:
- sed -i s/march=native/march=armv8-a+sve/g ~/.config/pystencils/config.json
- sed -i s/g\+\+/clang++/g ~/.config/pystencils/config.json
riscv64:
  # The RISC-V vector extension is still experimental and needs special compiler flags.
  # Once the extension is officially released, this job should be cleaned up to match the others.
  extends: .multiarch_template
  image: i10git.cs.fau.de:5005/pycodegen/pycodegen/riscv64
  variables:
    # Instruction set(s) the test suite is asked to vectorize for.
    PYSTENCILS_SIMD: "rvv"
    # NOTE(review): presumably enables the experimental vector extension in the QEMU CPU model - confirm.
    QEMU_CPU: "rv64,x-v=true"
  before_script:
    - *multiarch_before_script
    # Patch the generated pystencils compiler configuration:
    # replace -march=native with an explicit RISC-V target plus the experimental-extensions switch.
    - sed -i 's/march=native/march=rv64imfdv0p10 -menable-experimental-extensions/g' ~/.config/pystencils/config.json
    # Use clang++ instead of g++ as the compiler.
    - sed -i s/g\+\+/clang++/g ~/.config/pystencils/config.json
    # Select libgomp as the OpenMP runtime and add the riscv64 system include directory.
    - sed -i 's/fopenmp/fopenmp=libgomp -I\/usr\/include\/riscv64-linux-gnu/g' ~/.config/pystencils/config.json
minimal-conda:
stage: test
except:
......
......@@ -28,13 +28,19 @@ def aligned_empty(shape, byte_alignment=True, dtype=np.float64, byte_offset=0, o
elif byte_alignment == 'cacheline':
cacheline_sizes = [get_cacheline_size(is_name) for is_name in instruction_sets]
if all([s is None for s in cacheline_sizes]):
byte_alignment = max([get_vector_instruction_set(type_name, is_name)['width'] * np.dtype(dtype).itemsize
for is_name in instruction_sets])
widths = [get_vector_instruction_set(type_name, is_name)['width'] * np.dtype(dtype).itemsize
for is_name in instruction_sets
if type(get_vector_instruction_set(type_name, is_name)['width']) is int]
byte_alignment = 64 if all([s is None for s in widths]) else max(widths)
else:
byte_alignment = max([s for s in cacheline_sizes if s is not None])
elif not any([type(get_vector_instruction_set(type_name, is_name)['width']) is int
for is_name in instruction_sets]):
byte_alignment = 64
else:
byte_alignment = max([get_vector_instruction_set(type_name, is_name)['width'] * np.dtype(dtype).itemsize
for is_name in instruction_sets])
for is_name in instruction_sets
if type(get_vector_instruction_set(type_name, is_name)['width']) is int])
if (not align_inner_coordinate) or (not hasattr(shape, '__len__')):
size = np.prod(shape)
d = np.dtype(dtype)
......
......@@ -19,9 +19,8 @@ def get_vector_instruction_set_arm(data_type='double', instruction_set='neon'):
if instruction_set != 'neon' and not instruction_set.startswith('sve'):
raise NotImplementedError(instruction_set)
if instruction_set == 'sve':
raise NotImplementedError("sizeless SVE is not implemented")
if instruction_set.startswith('sve'):
cmp = 'cmp'
elif instruction_set.startswith('sve'):
cmp = 'cmp'
bitwidth = int(instruction_set[3:])
elif instruction_set == 'neon':
......@@ -53,8 +52,16 @@ def get_vector_instruction_set_arm(data_type='double', instruction_set='neon'):
'float': 32,
'int': 32}
width = bitwidth // bits[data_type]
intwidth = bitwidth // bits['int']
result = dict()
if instruction_set == 'sve':
width = 'svcntd()' if data_type == 'double' else 'svcntw()'
intwidth = 'svcntw()'
result['bytes'] = 'svcntb()'
else:
width = bitwidth // bits[data_type]
intwidth = bitwidth // bits['int']
result['bytes'] = bitwidth // 8
if instruction_set.startswith('sve'):
prefix = 'sv'
suffix = f'_f{bits[data_type]}'
......@@ -62,11 +69,12 @@ def get_vector_instruction_set_arm(data_type='double', instruction_set='neon'):
prefix = 'v'
suffix = f'q_f{bits[data_type]}'
result = dict()
result['bytes'] = bitwidth // 8
predicate = f'{prefix}whilelt_b{bits[data_type]}(0, {width})'
int_predicate = f'{prefix}whilelt_b{bits["int"]}(0, {intwidth})'
if instruction_set == 'sve':
predicate = f'{prefix}whilelt_b{bits[data_type]}_u64({{loop_counter}}, {{loop_stop}})'
int_predicate = f'{prefix}whilelt_b{bits["int"]}_u64({{loop_counter}}, {{loop_stop}})'
else:
predicate = f'{prefix}whilelt_b{bits[data_type]}(0, {width})'
int_predicate = f'{prefix}whilelt_b{bits["int"]}(0, {intwidth})'
for intrinsic_id, function_shortcut in base_names.items():
function_shortcut = function_shortcut.strip()
......@@ -80,8 +88,13 @@ def get_vector_instruction_set_arm(data_type='double', instruction_set='neon'):
result[intrinsic_id] = prefix + name + suffix + undef + arg_string
result['width'] = width
result['intwidth'] = intwidth
if instruction_set == 'sve':
from pystencils.backends.cbackend import CFunction
result['width'] = CFunction(width, "int")
result['intwidth'] = CFunction(intwidth, "int")
else:
result['width'] = width
result['intwidth'] = intwidth
if instruction_set.startswith('sve'):
result['makeVecConst'] = f'svdup_f{bits[data_type]}' + '({0})'
......@@ -89,17 +102,17 @@ def get_vector_instruction_set_arm(data_type='double', instruction_set='neon'):
result['makeVecIndex'] = f'svindex_s{bits["int"]}' + '({0}, {1})'
vindex = f'svindex_u{bits[data_type]}(0, {{0}})'
result['scatter'] = f'svst1_scatter_u{bits[data_type]}index_f{bits[data_type]}({predicate}, {{0}}, ' + \
vindex.format("{2}") + ', {1})'
result['gather'] = f'svld1_gather_u{bits[data_type]}index_f{bits[data_type]}({predicate}, {{0}}, ' + \
vindex.format("{1}") + ')'
result['storeS'] = f'svst1_scatter_u{bits[data_type]}index_f{bits[data_type]}({predicate}, {{0}}, ' + \
vindex.format("{2}") + ', {1})'
result['loadS'] = f'svld1_gather_u{bits[data_type]}index_f{bits[data_type]}({predicate}, {{0}}, ' + \
vindex.format("{1}") + ')'
result['+int'] = f"svadd_s{bits['int']}_x({int_predicate}, " + "{0}, {1})"
result['float'] = 'svfloat32_st'
result['double'] = 'svfloat64_st'
result['int'] = f'svint{bits["int"]}_st'
result['bool'] = 'svbool_st'
result['float'] = f'svfloat{bits["float"]}_{"s" if instruction_set != "sve" else ""}t'
result['double'] = f'svfloat{bits["double"]}_{"s" if instruction_set != "sve" else ""}t'
result['int'] = f'svint{bits["int"]}_{"s" if instruction_set != "sve" else ""}t'
result['bool'] = f'svbool_{"s" if instruction_set != "sve" else ""}t'
result['headers'] = ['<arm_sve.h>', '"arm_neon_helpers.h"']
......@@ -111,9 +124,10 @@ def get_vector_instruction_set_arm(data_type='double', instruction_set='neon'):
result['maskStoreU'] = result['storeU'].replace(predicate, '{2}')
result['maskStoreA'] = result['storeA'].replace(predicate, '{2}')
result['maskScatter'] = result['scatter'].replace(predicate, '{3}')
result['maskStoreS'] = result['storeS'].replace(predicate, '{3}')
result['compile_flags'] = [f'-msve-vector-bits={bitwidth}']
if instruction_set != 'sve':
result['compile_flags'] = [f'-msve-vector-bits={bitwidth}']
else:
result['makeVecConst'] = f'vdupq_n_f{bits[data_type]}' + '({0})'
result['makeVec'] = f'makeVec_f{bits[data_type]}' + '(' + ", ".join(['{' + str(i) + '}' for i in
......@@ -137,7 +151,7 @@ def get_vector_instruction_set_arm(data_type='double', instruction_set='neon'):
result['any'] = f'vaddlvq_u8(vreinterpretq_u8_u{bits[data_type]}({{0}})) > 0'
result['all'] = f'vaddlvq_u8(vreinterpretq_u8_u{bits[data_type]}({{0}})) == 16*0xff'
if bitwidth & (bitwidth - 1) == 0:
if instruction_set == 'sve' or bitwidth & (bitwidth - 1) == 0:
# only power-of-2 vector sizes will evenly divide a cacheline
result['cachelineSize'] = 'cachelineSize()'
result['cachelineZero'] = 'cachelineZero((void*) {0})'
......
......@@ -6,6 +6,7 @@ from typing import Set
import numpy as np
import sympy as sp
from sympy.core import S
from sympy.core.cache import cacheit
from sympy.logic.boolalg import BooleanFalse, BooleanTrue
from pystencils.astnodes import KernelFunction, LoopOverCoordinate, Node
......@@ -165,6 +166,23 @@ class PrintNode(CustomCodeNode):
self.headers.append("<iostream>")
class CFunction(TypedSymbol):
    """Typed symbol whose name is a piece of C code (e.g. a function call) instead of a variable name.

    Elsewhere in this file, symbols of exactly this type are filtered out of the
    kernel's argument list in ``CBackend._print_KernelFunction``, so they are not
    passed in as parameters.

    Construction mirrors the cached two-stage ``__new__`` pattern: ``__new__``
    forwards to a ``cacheit``-wrapped constructor, so repeated construction with
    the same arguments can be served from SymPy's cache.
    """

    def __new__(cls, function, dtype):
        # Always go through the cached constructor.
        return CFunction.__xnew_cached_(cls, function, dtype)

    def __new_stage2__(cls, function, dtype):
        # Actual object creation is delegated to the parent's uncached constructor.
        return super(CFunction, cls).__xnew__(cls, function, dtype)

    # Uncached and cached variants of the stage-2 constructor; both must be
    # defined after __new_stage2__ because they wrap that function object.
    __xnew__ = staticmethod(__new_stage2__)
    __xnew_cached_ = staticmethod(cacheit(__new_stage2__))

    def __getnewargs__(self):
        # Arguments needed to re-create the object (pickling/copy protocol).
        return self.name, self.dtype

    def __getnewargs_ex__(self):
        # Same as __getnewargs__, in the (args, kwargs) form.
        return (self.name, self.dtype), {}
# ------------------------------------------- Printer ------------------------------------------------------------------
......@@ -184,6 +202,8 @@ class CBackend:
self._indent = " "
self._dialect = dialect
self._signatureOnly = signature_only
self._kwargs = {}
self.sympy_printer._kwargs = self._kwargs
def __call__(self, node):
prev_is = VectorType.instruction_set
......@@ -205,7 +225,8 @@ class CBackend:
return str(node)
def _print_KernelFunction(self, node):
function_arguments = [f"{self._print(s.symbol.dtype)} {s.symbol.name}" for s in node.get_parameters()]
function_arguments = [f"{self._print(s.symbol.dtype)} {s.symbol.name}" for s in node.get_parameters()
if not type(s.symbol) is CFunction]
launch_bounds = ""
if self._dialect == 'cuda':
max_threads = node.indexing.max_threads_per_block()
......@@ -232,6 +253,8 @@ class CBackend:
condition = f"{counter_symbol} < {self.sympy_printer.doprint(node.stop)}"
update = f"{counter_symbol} += {self.sympy_printer.doprint(node.step)}"
loop_str = f"for ({start}; {condition}; {update})"
self._kwargs['loop_counter'] = counter_symbol
self._kwargs['loop_stop'] = node.stop
prefix = "\n".join(node.prefix_lines)
if prefix:
......@@ -265,7 +288,8 @@ class CBackend:
if instr not in self._vector_instruction_set:
self._vector_instruction_set[instr] = self._vector_instruction_set['store' + instr[-1]].format(
'{0}', self._vector_instruction_set['blendv'].format(
self._vector_instruction_set['load' + instr[-1]].format('{0}'), '{1}', '{2}'))
self._vector_instruction_set['load' + instr[-1]].format('{0}', **self._kwargs),
'{1}', '{2}', **self._kwargs), **self._kwargs)
printed_mask = self.sympy_printer.doprint(mask)
if data_type.base_type.base_name == 'double':
if self._vector_instruction_set['double'] == '__m256d':
......@@ -287,9 +311,9 @@ class CBackend:
ptr = "&" + self.sympy_printer.doprint(node.lhs.args[0])
if stride != 1:
instr = 'maskScatter' if mask != True else 'scatter' # NOQA
instr = 'maskStoreS' if mask != True else 'storeS' # NOQA
return self._vector_instruction_set[instr].format(ptr, self.sympy_printer.doprint(rhs),
stride, printed_mask) + ';'
stride, printed_mask, **self._kwargs) + ';'
pre_code = ''
if nontemporal and 'cachelineZero' in self._vector_instruction_set:
......@@ -301,22 +325,22 @@ class CBackend:
element_size = 8 if data_type.base_type.base_name == 'double' else 4
size_cond = f"({offset} + {CachelineSize.symbol/element_size}) < {size}"
pre_code = f"if ({first_cond} && {size_cond}) " + "{\n\t" + \
self._vector_instruction_set['cachelineZero'].format(ptr) + ';\n}\n'
self._vector_instruction_set['cachelineZero'].format(ptr, **self._kwargs) + ';\n}\n'
code = self._vector_instruction_set[instr].format(ptr, self.sympy_printer.doprint(rhs),
printed_mask) + ';'
printed_mask, **self._kwargs) + ';'
flushcond = f"((uintptr_t) {ptr} & {CachelineSize.mask_symbol}) == {CachelineSize.last_symbol}"
if nontemporal and 'flushCacheline' in self._vector_instruction_set:
code2 = self._vector_instruction_set['flushCacheline'].format(
ptr, self.sympy_printer.doprint(rhs)) + ';'
ptr, self.sympy_printer.doprint(rhs), **self._kwargs) + ';'
code = f"{code}\nif ({flushcond}) {{\n\t{code2}\n}}"
elif nontemporal and 'storeAAndFlushCacheline' in self._vector_instruction_set:
tmpvar = '_tmp_' + hashlib.sha1(self.sympy_printer.doprint(rhs).encode('ascii')).hexdigest()[:8]
code = 'const ' + self._print(node.lhs.dtype).replace(' const', '') + ' ' + tmpvar + ' = ' \
+ self.sympy_printer.doprint(rhs) + ';'
code1 = self._vector_instruction_set[instr].format(ptr, tmpvar, printed_mask) + ';'
code2 = self._vector_instruction_set['storeAAndFlushCacheline'].format(ptr, tmpvar, printed_mask) \
+ ';'
code1 = self._vector_instruction_set[instr].format(ptr, tmpvar, printed_mask, **self._kwargs) + ';'
code2 = self._vector_instruction_set['storeAAndFlushCacheline'].format(ptr, tmpvar, printed_mask,
**self._kwargs) + ';'
code += f"\nif ({flushcond}) {{\n\t{code2}\n}} else {{\n\t{code1}\n}}"
return pre_code + code
else:
......@@ -617,16 +641,16 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
def _print_Abs(self, expr):
if 'abs' in self.instruction_set and isinstance(expr.args[0], vector_memory_access):
return self.instruction_set['abs'].format(self._print(expr.args[0]))
return self.instruction_set['abs'].format(self._print(expr.args[0]), **self._kwargs)
return super()._print_Abs(expr)
def _print_Function(self, expr):
if isinstance(expr, vector_memory_access):
arg, data_type, aligned, _, mask, stride = expr.args
if stride != 1:
return self.instruction_set['gather'].format("& " + self._print(arg), stride)
return self.instruction_set['loadS'].format("& " + self._print(arg), stride, **self._kwargs)
instruction = self.instruction_set['loadA'] if aligned else self.instruction_set['loadU']
return instruction.format("& " + self._print(arg))
return instruction.format("& " + self._print(arg), **self._kwargs)
elif isinstance(expr, cast_func):
arg, data_type = expr.args
if type(data_type) is VectorType:
......@@ -640,19 +664,21 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
if instruction == 'makeVecInt' and 'makeVecIndex' in self.instruction_set:
increments = np.array(arg)[1:] - np.array(arg)[:-1]
if len(set(increments)) == 1:
return self.instruction_set['makeVecIndex'].format(printed_args[0], increments[0])
return self.instruction_set[instruction].format(*printed_args)
return self.instruction_set['makeVecIndex'].format(printed_args[0], increments[0],
**self._kwargs)
return self.instruction_set[instruction].format(*printed_args, **self._kwargs)
else:
is_boolean = get_type_of_expression(arg) == create_type("bool")
is_integer = get_type_of_expression(arg) == create_type("int") or \
(isinstance(arg, TypedSymbol) and arg.dtype.is_int())
instruction = 'makeVecConstBool' if is_boolean else \
'makeVecConstInt' if is_integer else 'makeVecConst'
return self.instruction_set[instruction].format(self._print(arg))
return self.instruction_set[instruction].format(self._print(arg), **self._kwargs)
elif expr.func == fast_division:
result = self._scalarFallback('_print_Function', expr)
if not result:
result = self.instruction_set['/'].format(self._print(expr.args[0]), self._print(expr.args[1]))
result = self.instruction_set['/'].format(self._print(expr.args[0]), self._print(expr.args[1]),
**self._kwargs)
return result
elif expr.func == fast_sqrt:
return f"({self._print(sp.sqrt(expr.args[0]))})"
......@@ -660,7 +686,7 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
result = self._scalarFallback('_print_Function', expr)
if not result:
if 'rsqrt' in self.instruction_set:
return self.instruction_set['rsqrt'].format(self._print(expr.args[0]))
return self.instruction_set['rsqrt'].format(self._print(expr.args[0]), **self._kwargs)
else:
return f"({self._print(1 / sp.sqrt(expr.args[0]))})"
elif isinstance(expr, vec_any) or isinstance(expr, vec_all):
......@@ -672,8 +698,9 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
if isinstance(expr.args[0], sp.Rel):
op = expr.args[0].rel_op
if (instr, op) in self.instruction_set:
return self.instruction_set[(instr, op)].format(*[self._print(a) for a in expr.args[0].args])
return self.instruction_set[instr].format(self._print(expr.args[0]))
return self.instruction_set[(instr, op)].format(*[self._print(a) for a in expr.args[0].args],
**self._kwargs)
return self.instruction_set[instr].format(self._print(expr.args[0]), **self._kwargs)
return super(VectorizedCustomSympyPrinter, self)._print_Function(expr)
......@@ -686,7 +713,7 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
assert len(arg_strings) > 0
result = arg_strings[0]
for item in arg_strings[1:]:
result = self.instruction_set['&'].format(result, item)
result = self.instruction_set['&'].format(result, item, **self._kwargs)
return result
def _print_Or(self, expr):
......@@ -698,7 +725,7 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
assert len(arg_strings) > 0
result = arg_strings[0]
for item in arg_strings[1:]:
result = self.instruction_set['|'].format(result, item)
result = self.instruction_set['|'].format(result, item, **self._kwargs)
return result
def _print_Add(self, expr, order=None):
......@@ -739,7 +766,7 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
processed = summands[0].term
for summand in summands[1:]:
func = self.instruction_set['-' + suffix] if summand.sign == -1 else self.instruction_set['+' + suffix]
processed = func.format(processed, summand.term)
processed = func.format(processed, summand.term, **self._kwargs)
return processed
def _print_Pow(self, expr):
......@@ -747,21 +774,22 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
if result:
return result
one = self.instruction_set['makeVecConst'].format(1.0)
one = self.instruction_set['makeVecConst'].format(1.0, **self._kwargs)
if expr.exp.is_integer and expr.exp.is_number and 0 < expr.exp < 8:
return "(" + self._print(sp.Mul(*[expr.base] * expr.exp, evaluate=False)) + ")"
elif expr.exp == -1:
one = self.instruction_set['makeVecConst'].format(1.0)
return self.instruction_set['/'].format(one, self._print(expr.base))
one = self.instruction_set['makeVecConst'].format(1.0, **self._kwargs)
return self.instruction_set['/'].format(one, self._print(expr.base), **self._kwargs)
elif expr.exp == 0.5:
return self.instruction_set['sqrt'].format(self._print(expr.base))
return self.instruction_set['sqrt'].format(self._print(expr.base), **self._kwargs)
elif expr.exp == -0.5:
root = self.instruction_set['sqrt'].format(self._print(expr.base))
return self.instruction_set['/'].format(one, root)
root = self.instruction_set['sqrt'].format(self._print(expr.base), **self._kwargs)
return self.instruction_set['/'].format(one, root, **self._kwargs)
elif expr.exp.is_integer and expr.exp.is_number and - 8 < expr.exp < 0:
return self.instruction_set['/'].format(one,
self._print(sp.Mul(*[expr.base] * (-expr.exp), evaluate=False)))
self._print(sp.Mul(*[expr.base] * (-expr.exp), evaluate=False)),
**self._kwargs)
else:
raise ValueError("Generic exponential not supported: " + str(expr))
......@@ -800,19 +828,19 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
result = a_str[0]
for item in a_str[1:]:
result = self.instruction_set['*'].format(result, item)
result = self.instruction_set['*'].format(result, item, **self._kwargs)
if len(b) > 0:
denominator_str = b_str[0]
for item in b_str[1:]:
denominator_str = self.instruction_set['*'].format(denominator_str, item)
result = self.instruction_set['/'].format(result, denominator_str)
denominator_str = self.instruction_set['*'].format(denominator_str, item, **self._kwargs)
result = self.instruction_set['/'].format(result, denominator_str, **self._kwargs)
if inside_add:
return sign, result
else:
if sign < 0:
return self.instruction_set['*'].format(self._print(S.NegativeOne), result)
return self.instruction_set['*'].format(self._print(S.NegativeOne), result, **self._kwargs)
else:
return result
......@@ -820,13 +848,13 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
result = self._scalarFallback('_print_Relational', expr)
if result:
return result
return self.instruction_set[expr.rel_op].format(self._print(expr.lhs), self._print(expr.rhs))
return self.instruction_set[expr.rel_op].format(self._print(expr.lhs), self._print(expr.rhs), **self._kwargs)
def _print_Equality(self, expr):
result = self._scalarFallback('_print_Equality', expr)
if result:
return result
return self.instruction_set['=='].format(self._print(expr.lhs), self._print(expr.rhs))
return self.instruction_set['=='].format(self._print(expr.lhs), self._print(expr.rhs), **self._kwargs)
def _print_Piecewise(self, expr):
result = self._scalarFallback('_print_Piecewise', expr)
......@@ -847,10 +875,11 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
if isinstance(condition, cast_func) and get_type_of_expression(condition.args[0]) == create_type("bool"):
if not KERNCRAFT_NO_TERNARY_MODE:
result = "(({}) ? ({}) : ({}))".format(self._print(condition.args[0]), self._print(true_expr),
result)
result, **self._kwargs)
else:
print("Warning - skipping ternary op")
else:
# noinspection SpellCheckingInspection
result = self.instruction_set['blendv'].format(result, self._print(true_expr), self._print(condition))
result = self.instruction_set['blendv'].format(result, self._print(true_expr), self._print(condition),
**self._kwargs)
return result
def get_argument_string(function_shortcut, last=''):
    """Turn an intrinsic shortcut such as ``'fadd_vv[0, 1]'`` into a call-argument template.

    The text between the first ``'['`` and the final character of
    ``function_shortcut`` is split on commas. Entries that are one of the
    digits ``'0'``..``'5'`` become ``str.format`` placeholders (``'{0}'`` ...);
    any other non-empty entry is copied verbatim. Empty entries are ignored.

    Args:
        function_shortcut: String of the form ``name[arg, arg, ...]``.
        last: Optional extra argument appended after all others when non-empty.

    Returns:
        The parenthesised, comma-separated argument string, e.g. ``'({0},{1})'``.
        An empty argument list yields ``'()'``.

    Raises:
        ValueError: If ``function_shortcut`` contains no ``'['``.
    """
    args = function_shortcut[function_shortcut.index('[') + 1: -1]
    parts = []
    for arg in args.split(","):
        arg = arg.strip()
        if not arg:
            continue
        if arg in ('0', '1', '2', '3', '4', '5'):
            parts.append("{" + arg + "}")
        else:
            parts.append(arg)
    if last:
        parts.append(last)
    # Joining (instead of trimming a trailing comma) keeps the opening
    # parenthesis intact when there are no arguments at all.
    return "(" + ",".join(parts) + ")"
def get_vector_instruction_set_riscv(data_type='double', instruction_set='rvv'):
    """Build the table of intrinsic templates for the RISC-V vector instruction set.

    Every entry is a ``str.format`` template. Positional placeholders
    (``{0}``, ``{1}``, ...) stand for the operands; the named placeholders
    ``{loop_counter}`` and ``{loop_stop}`` are left in the templates and have
    to be supplied by whoever formats them.

    Args:
        data_type: Key into the bit-width table below; selects the element
            width used in intrinsic names (expected: ``'double'`` or ``'float'``).
        instruction_set: Must be ``'rvv'``.

    Returns:
        dict mapping intrinsic ids (``'+'``, ``'loadA'``, ``'blendv'``, ...) and
        metadata keys (``'width'``, ``'intwidth'``, ``'bytes'``, ``'headers'``,
        type names) to their values. ``'width'`` and ``'intwidth'`` are
        ``CFunction`` objects because they are C expressions, not Python ints.
    """
    assert instruction_set == 'rvv'

    bits = {'double': 64,
            'float': 32,
            'int': 32}

    # Shortcut notation: intrinsic base name followed by the operand order in brackets.
    base_names = {
        '+': 'fadd_vv[0, 1]',
        '-': 'fsub_vv[0, 1]',
        '*': 'fmul_vv[0, 1]',
        '/': 'fdiv_vv[0, 1]',
        'sqrt': 'fsqrt_v[0]',

        'loadU': f'le{bits[data_type]}_v[0]',
        'loadA': f'le{bits[data_type]}_v[0]',
        'storeU': f'se{bits[data_type]}_v[0, 1]',
        'storeA': f'se{bits[data_type]}_v[0, 1]',
        'maskStoreU': f'se{bits[data_type]}_v[2, 0, 1]',
        'maskStoreA': f'se{bits[data_type]}_v[2, 0, 1]',
        'loadS': f'lse{bits[data_type]}_v[0, 1]',
        'storeS': f'sse{bits[data_type]}_v[0, 2, 1]',
        'maskStoreS': f'sse{bits[data_type]}_v[2, 0, 3, 1]',

        'abs': 'fabs_v[0]',
        '==': 'mfeq_vv[0, 1]',
        '!=': 'mfne_vv[0, 1]',
        '<=': 'mfle_vv[0, 1]',
        '<': 'mflt_vv[0, 1]',
        '>=': 'mfge_vv[0, 1]',
        '>': 'mfgt_vv[0, 1]',
        '&': 'mand_mm[0, 1]',
        '|': 'mor_mm[0, 1]',

        'blendv': 'merge_vvm[2, 0, 1]',
        'any': 'popc_m[0]',
        'all': 'popc_m[0]',
    }

    result = dict()

    # The vector width is not known at code-generation time; it is expressed as C calls.
    width = f'vsetvlmax_e{bits[data_type]}m1()'
    # Must be an f-string as well, otherwise the literal text '{bits["int"]}'
    # ends up in the generated C code.
    intwidth = f'vsetvlmax_e{bits["int"]}m1()'
    result['bytes'] = 'vsetvlmax_e8m1()'

    prefix = 'v'
    suffix = f'_f{bits[data_type]}m1'

    # Vector-length argument appended to every intrinsic: remaining loop iterations.
    vl = '{loop_stop} - {loop_counter}'
    # Same length scaled by the ratio of data-type width to int width.
    int_vl = f'({vl})*{bits[data_type]//bits["int"]}'

    for intrinsic_id, function_shortcut in base_names.items():
        function_shortcut = function_shortcut.strip()
        name = function_shortcut[:function_shortcut.index('[')]

        # The intrinsic name suffix depends on the kind of operation.
        if name.startswith('mf'):
            suffix2 = suffix + f'_b{bits[data_type]}'
        elif name.endswith('_mm') or name.endswith('_m'):
            suffix2 = f'_b{bits[data_type]}'
        elif intrinsic_id.startswith('mask'):
            suffix2 = suffix + '_m'
        else:
            suffix2 = suffix

        arg_string = get_argument_string(function_shortcut, last=vl)

        result[intrinsic_id] = prefix + name + suffix2 + arg_string

    # Imported here rather than at module level.
    # NOTE(review): presumably to avoid a circular import with cbackend - confirm.
    from pystencils.backends.cbackend import CFunction
    result['width'] = CFunction(width, "int")
    result['intwidth'] = CFunction(intwidth, "int")

    result['makeVecConst'] = f'vfmv_v_f_f{bits[data_type]}m1({{0}}, {vl})'
    result['makeVecConstInt'] = f'vmv_v_x_i{bits["int"]}m1({{0}}, {int_vl})'
    result['makeVecIndex'] = f'vmacc_vx_i{bits["int"]}m1({result["makeVecConstInt"]}, {{1}}, ' + \
                             f'vid_v_i{bits["int"]}m1({int_vl}), {int_vl})'

    # The stride operand of the strided load/store templates is multiplied by
    # the element size in bytes.
    result['storeS'] = result['storeS'].replace('{2}', f'{{2}}*{bits[data_type]//8}')
    result['loadS'] = result['loadS'].replace('{1}', f'{{1}}*{bits[data_type]//8}')
    result['maskStoreS'] = result['maskStoreS'].replace('{3}', f'{{3}}*{bits[data_type]//8}')

    result['+int'] = f"vadd_vv_i{bits['int']}m1({{0}}, {{1}}, {int_vl})"

    result['float'] = f'vfloat{bits["float"]}m1_t'
    result['double'] = f'vfloat{bits["double"]}m1_t'
    result['int'] = f'vint{bits["int"]}m1_t'
    result['bool'] = f'vbool{bits[data_type]}_t'

    result['headers'] = ['<riscv_vector.h>']

    # 'any' / 'all' compare the popc_m result against zero / the active vector length.
    result['any'] += ' > 0x0'
    result['all'] += f' == vsetvl_e{bits[data_type]}m1({vl})'

    return result
......@@ -6,6 +6,7 @@ from ctypes import CDLL
from pystencils.backends.x86_instruction_sets import get_vector_instruction_set_x86
from pystencils.backends.arm_instruction_sets import get_vector_instruction_set_arm
from pystencils.backends.ppc_instruction_sets import get_vector_instruction_set_ppc
from pystencils.backends.riscv_instruction_sets import get_vector_instruction_set_riscv
def get_vector_instruction_set(data_type='double', instruction_set='avx'):
......@@ -13,6 +14,8 @@ def get_vector_instruction_set(data_type='double', instruction_set='avx'):