Commit d314ef10 authored by Michael Kuron, committed by Markus Holzer

Vector scatter/gather support

parent e9bd89c8
@@ -88,9 +88,16 @@ def get_vector_instruction_set_arm(data_type='double', instruction_set='neon'):
         result['makeVecConstInt'] = f'svdup_s{bits["int"]}' + '({0})'
         result['makeVecIndex'] = f'svindex_s{bits["int"]}' + '({0}, {1})'
 
+        vindex = f'svindex_u{bits[data_type]}(0, {{0}})'
+        result['scatter'] = f'svst1_scatter_u{bits[data_type]}index_f{bits[data_type]}({predicate}, {{0}}, ' + \
+                            vindex.format("{2}") + ', {1})'
+        result['gather'] = f'svld1_gather_u{bits[data_type]}index_f{bits[data_type]}({predicate}, {{0}}, ' + \
+                           vindex.format("{1}") + ')'
+
         result['+int'] = f"svadd_s{bits['int']}_x({int_predicate}, " + "{0}, {1})"
 
-        result[data_type] = f'svfloat{bits[data_type]}_st'
+        result['float'] = 'svfloat32_st'
+        result['double'] = 'svfloat64_st'
         result['int'] = f'svint{bits["int"]}_st'
         result['bool'] = 'svbool_st'
@@ -102,6 +109,10 @@ def get_vector_instruction_set_arm(data_type='double', instruction_set='neon'):
         result['any'] = f'svptest_any({predicate}, {{0}})'
         result['all'] = f'svcntp_b{bits[data_type]}({predicate}, {{0}}) == {width}'
 
+        result['maskStoreU'] = result['storeU'].replace(predicate, '{2}')
+        result['maskStoreA'] = result['storeA'].replace(predicate, '{2}')
+        result['maskScatter'] = result['scatter'].replace(predicate, '{3}')
+
         result['compile_flags'] = [f'-msve-vector-bits={bitwidth}']
     else:
         result['makeVecConst'] = f'vdupq_n_f{bits[data_type]}' + '({0})'
...
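Note: the dictionary entries added above are plain Python format strings that the C backend fills in later. As a minimal sketch of the mechanics, here is how the SVE scatter template expands by hand, assuming 64-bit doubles and an illustrative all-true predicate svptrue_b64() (the real predicate string is computed elsewhere in this file):

predicate = 'svptrue_b64()'   # stand-in for the file's computed predicate
vindex = 'svindex_u64(0, {0})'   # index vector 0, s, 2*s, ... for stride s
scatter = f'svst1_scatter_u64index_f64({predicate}, {{0}}, ' + vindex.format("{2}") + ', {1})'

# Placeholder {0} is the base pointer, {1} the value vector, {2} the element stride:
print(scatter.format('ptr', 'vals', 'stride'))
# -> svst1_scatter_u64index_f64(svptrue_b64(), ptr, svindex_u64(0, stride), vals)

Because svindex_u64(0, stride) yields the sequence 0, stride, 2*stride, ..., the store touches every stride-th element behind the base pointer.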
...@@ -256,7 +256,7 @@ class CBackend: ...@@ -256,7 +256,7 @@ class CBackend:
lhs_type = get_type_of_expression(node.lhs) lhs_type = get_type_of_expression(node.lhs)
printed_mask = "" printed_mask = ""
if type(lhs_type) is VectorType and isinstance(node.lhs, cast_func): if type(lhs_type) is VectorType and isinstance(node.lhs, cast_func):
arg, data_type, aligned, nontemporal, mask = node.lhs.args arg, data_type, aligned, nontemporal, mask, stride = node.lhs.args
instr = 'storeU' instr = 'storeU'
if aligned: if aligned:
instr = 'stream' if nontemporal and 'stream' in self._vector_instruction_set else 'storeA' instr = 'stream' if nontemporal and 'stream' in self._vector_instruction_set else 'storeA'
@@ -285,6 +285,12 @@ class CBackend:
             rhs = node.rhs
             ptr = "&" + self.sympy_printer.doprint(node.lhs.args[0])
 
+            if stride != 1:
+                instr = 'maskScatter' if mask != True else 'scatter'  # NOQA
+                return self._vector_instruction_set[instr].format(ptr, self.sympy_printer.doprint(rhs),
+                                                                  stride, printed_mask) + ';'
+
             pre_code = ''
             if nontemporal and 'cachelineZero' in self._vector_instruction_set:
                 pre_code = f"if (((uintptr_t) {ptr} & {CachelineSize.mask_symbol}) == 0) " + "{\n\t" + \
@@ -609,7 +615,9 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
     def _print_Function(self, expr):
         if isinstance(expr, vector_memory_access):
-            arg, data_type, aligned, _, mask = expr.args
+            arg, data_type, aligned, _, mask, stride = expr.args
+            if stride != 1:
+                return self.instruction_set['gather'].format("& " + self._print(arg), stride)
             instruction = self.instruction_set['loadA'] if aligned else self.instruction_set['loadU']
             return instruction.format("& " + self._print(arg))
         elif isinstance(expr, cast_func):
...
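Note: when the printer meets a strided vector_memory_access, it simply hands the printed address and the stride to the 'gather' template. A rough sketch of the resulting C string, reusing the hypothetical SVE template from the earlier sketch (the printed address and the stride value 30 are made up for illustration):

gather = 'svld1_gather_u64index_f64(svptrue_b64(), {0}, svindex_u64(0, {1}))'
print(gather.format('& _data_f[30*ctr_1 + ctr_0]', 30))
# -> svld1_gather_u64index_f64(svptrue_b64(), & _data_f[30*ctr_1 + ctr_0], svindex_u64(0, 30))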
...@@ -130,7 +130,7 @@ def get_vector_instruction_set_x86(data_type='double', instruction_set='avx'): ...@@ -130,7 +130,7 @@ def get_vector_instruction_set_x86(data_type='double', instruction_set='avx'):
result['all'] = f"{pre}_movemask_{suf}({{0}}) == {hex(2**result['width']-1)}" result['all'] = f"{pre}_movemask_{suf}({{0}}) == {hex(2**result['width']-1)}"
if instruction_set == 'avx512': if instruction_set == 'avx512':
size = 8 if data_type == 'double' else 16 size = result['width']
result['&'] = f'_kand_mask{size}({{0}}, {{1}})' result['&'] = f'_kand_mask{size}({{0}}, {{1}})'
result['|'] = f'_kor_mask{size}({{0}}, {{1}})' result['|'] = f'_kor_mask{size}({{0}}, {{1}})'
result['any'] = f'!_ktestz_mask{size}_u8({{0}}, {{0}})' result['any'] = f'!_ktestz_mask{size}_u8({{0}}, {{0}})'
@@ -145,6 +145,14 @@ def get_vector_instruction_set_x86(data_type='double', instruction_set='avx'):
         params = " | ".join(["({{0}} ? {power} : 0)".format(power=2 ** i) for i in range(8)])
         result['makeVecConstBool'] = f"__mmask8(({params}) )"
 
+        vindex = f'{pre}_set_epi{bit_width//size}(' + ', '.join([str(i) for i in range(result['width'])][::-1]) + ')'
+        vindex = f'{pre}_mullo_epi{bit_width//size}({vindex}, {pre}_set1_epi{bit_width//size}({{0}}))'
+        result['scatter'] = f'{pre}_i{bit_width//size}scatter_{suf}({{0}}, ' + vindex.format("{2}") + \
+            f', {{1}}, {64//size})'
+        result['maskScatter'] = f'{pre}_mask_i{bit_width//size}scatter_{suf}({{0}}, {{3}}, ' + vindex.format("{2}") + \
+            f', {{1}}, {64//size})'
+        result['gather'] = f'{pre}_i{bit_width//size}gather_{suf}(' + vindex.format("{1}") + f', {{0}}, {64//size})'
+
     if instruction_set == 'avx' and data_type == 'float':
         result['rsqrt'] = f"{pre}_rsqrt_{suf}({{0}})"
...
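Note: the x86 templates build the index vector in two steps: a compile-time sequence (width-1, ..., 1, 0) that is then multiplied by the runtime stride. A sketch of the fully expanded intrinsic, assuming avx512 with double precision, i.e. pre = '_mm512', suf = 'pd', size = result['width'] = 8 and bit_width = 512 (so the indices are 64-bit and the byte scale 64//size is 8):

pre, suf, size, bit_width = '_mm512', 'pd', 8, 512

# Constant index sequence, then scaled by the runtime stride placeholder {0}:
vindex = f'{pre}_set_epi{bit_width//size}(' + ', '.join([str(i) for i in range(size)][::-1]) + ')'
vindex = f'{pre}_mullo_epi{bit_width//size}({vindex}, {pre}_set1_epi{bit_width//size}({{0}}))'
scatter = f'{pre}_i{bit_width//size}scatter_{suf}({{0}}, ' + vindex.format("{2}") + f', {{1}}, {64//size})'

print(scatter.format('ptr', 'vals', 'stride'))
# -> _mm512_i64scatter_pd(ptr, _mm512_mullo_epi64(_mm512_set_epi64(7, 6, 5, 4, 3, 2, 1, 0),
#                         _mm512_set1_epi64(stride)), vals, 8)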
@@ -80,8 +80,9 @@ def vectorize(kernel_ast: ast.KernelFunction, instruction_set: str = 'best',
     kernel_ast.instruction_set = vector_is
     vectorize_rng(kernel_ast, vector_width)
-    vectorize_inner_loops_and_adapt_load_stores(kernel_ast, vector_width, assume_aligned,
-                                                nontemporal, assume_sufficient_line_padding)
+    scattergather = 'scatter' in vector_is and 'gather' in vector_is
+    vectorize_inner_loops_and_adapt_load_stores(kernel_ast, vector_width, assume_aligned, nontemporal,
+                                                scattergather, assume_sufficient_line_padding)
     insert_vector_casts(kernel_ast)
@@ -104,7 +105,7 @@ def vectorize_rng(kernel_ast, vector_width):
 
 
 def vectorize_inner_loops_and_adapt_load_stores(ast_node, vector_width, assume_aligned, nontemporal_fields,
-                                                assume_sufficient_line_padding):
+                                                scattergather, assume_sufficient_line_padding):
     """Goes over all innermost loops, changes increment to vector width and replaces field accesses by vector type."""
     all_loops = filtered_tree_iteration(ast_node, ast.LoopOverCoordinate, stop_type=ast.SympyAssignment)
     inner_loops = [n for n in all_loops if n.is_innermost_loop]
@@ -135,7 +136,8 @@ def vectorize_inner_loops_and_adapt_load_stores(ast_node, vector_width, assume_a
             if loop_counter_symbol in index.atoms(sp.Symbol):
                 loop_counter_is_offset = loop_counter_symbol not in (index - loop_counter_symbol).atoms()
                 aligned_access = (index - loop_counter_symbol).subs(zero_loop_counters) == 0
-                if not loop_counter_is_offset:
+                stride = sp.simplify(index.subs({loop_counter_symbol: loop_counter_symbol + 1}) - index)
+                if not loop_counter_is_offset and (not scattergather or loop_counter_symbol in stride.atoms()):
                     successful = False
                     break
             typed_symbol = base.label
@@ -147,7 +149,8 @@ def vectorize_inner_loops_and_adapt_load_stores(ast_node, vector_width, assume_a
             nontemporal = False
             if hasattr(indexed, 'field'):
                 nontemporal = (indexed.field in nontemporal_fields) or (indexed.field.name in nontemporal_fields)
-            substitutions[indexed] = vector_memory_access(indexed, vec_type, use_aligned_access, nontemporal, True)
+            substitutions[indexed] = vector_memory_access(indexed, vec_type, use_aligned_access, nontemporal, True,
+                                                          stride if scattergather else 1)
             if nontemporal:
                 # insert NontemporalFence after the outermost loop
                 parent = loop_node.parent
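Note: the stride detection above evaluates each index expression at ctr and at ctr + 1 and takes the simplified difference. A constant difference is a fixed element stride that scatter/gather can handle; any surviving occurrence of the loop counter means the access is not linear and the loop is still rejected. A standalone sympy illustration (plain symbols stand in for pystencils' typed loop counters):

import sympy as sp

ctr = sp.Symbol('ctr')
index = 30 * ctr + 5   # linear, strided access pattern
stride = sp.simplify(index.subs({ctr: ctr + 1}) - index)
print(stride)          # -> 30: constant, usable as a scatter/gather stride

index = ctr ** 2       # non-linear access pattern
stride = sp.simplify(index.subs({ctr: ctr + 1}) - index)
print(stride)          # -> 2*ctr + 1: ctr survives in stride.atoms(),
                       #    so vectorization of this loop fails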
@@ -188,7 +191,7 @@ def mask_conditionals(loop_body):
                 node.condition_expr = vec_any(node.condition_expr)
         elif isinstance(node, ast.SympyAssignment):
             if mask is not True:
-                s = {ma: vector_memory_access(ma.args[0], ma.args[1], ma.args[2], ma.args[3], sp.And(mask, ma.args[4]))
+                s = {ma: vector_memory_access(*ma.args[0:4], sp.And(mask, ma.args[4]), *ma.args[5:])
                      for ma in node.atoms(vector_memory_access)}
                 node.subs(s)
         else:
@@ -205,8 +208,7 @@ def insert_vector_casts(ast_node):
     def visit_expr(expr):
         if isinstance(expr, vector_memory_access):
-            return vector_memory_access(expr.args[0], expr.args[1], expr.args[2], expr.args[3],
-                                        visit_expr(expr.args[4]))
+            return vector_memory_access(*expr.args[0:4], visit_expr(expr.args[4]), *expr.args[5:])
         elif isinstance(expr, cast_func):
             return expr
         elif expr.func is sp.Abs and 'abs' not in ast_node.instruction_set:
...
...@@ -195,8 +195,8 @@ class boolean_cast_func(cast_func, Boolean): ...@@ -195,8 +195,8 @@ class boolean_cast_func(cast_func, Boolean):
# noinspection PyPep8Naming # noinspection PyPep8Naming
class vector_memory_access(cast_func): class vector_memory_access(cast_func):
# Arguments are: read/write expression, type, aligned, nontemporal, mask (or none) # Arguments are: read/write expression, type, aligned, nontemporal, mask (or none), stride
nargs = (5,) nargs = (6,)
# noinspection PyPep8Naming # noinspection PyPep8Naming
......
@@ -53,6 +53,31 @@ def test_vectorized_abs(instruction_set, dtype):
     np.testing.assert_equal(np.sum(dst[1:-1, 1:-1]), 2 ** 2 * 2 ** 3)
 
 
+@pytest.mark.parametrize('dtype', ('float', 'double'))
+@pytest.mark.parametrize('instruction_set', supported_instruction_sets)
+def test_scatter_gather(instruction_set, dtype):
+    f, g = ps.fields(f"f, g : float{64 if dtype == 'double' else 32}[2D]")
+    update_rule = [ps.Assignment(g[0, 0], f[0, 0] + f[-1, 0] + f[1, 0] + f[0, 1] + f[0, -1] + 42.0)]
+    if 'scatter' not in get_vector_instruction_set(dtype, instruction_set) and not instruction_set == 'avx512' \
+            and not instruction_set.startswith('sve'):
+        with pytest.warns(UserWarning) as warn:
+            ast = ps.create_kernel(update_rule, cpu_vectorize_info={'instruction_set': instruction_set})
+        assert 'Could not vectorize loop' in warn[0].message.args[0]
+    else:
+        with pytest.warns(None) as warn:
+            ast = ps.create_kernel(update_rule, cpu_vectorize_info={'instruction_set': instruction_set})
+        assert len(warn) == 0
+    func = ast.compile()
+    ref_func = ps.create_kernel(update_rule).compile()
+
+    arr = np.random.random((23 + 2, 17 + 2)).astype(np.float64 if dtype == 'double' else np.float32)
+    dst = np.zeros_like(arr, dtype=np.float64 if dtype == 'double' else np.float32)
+    ref = np.zeros_like(arr, dtype=np.float64 if dtype == 'double' else np.float32)
+
+    func(g=dst, f=arr)
+    ref_func(g=ref, f=arr)
+
+    np.testing.assert_almost_equal(dst, ref, 13 if dtype == 'double' else 5)
+
+
 @pytest.mark.parametrize('dtype', ('float', 'double'))
 @pytest.mark.parametrize('instruction_set', supported_instruction_sets)
 @pytest.mark.parametrize('gl_field, gl_kernel', [(1, 0), (0, 1), (1, 1)])
...