From d314ef104f62fae0bf56bf00bb8b5b6824323d34 Mon Sep 17 00:00:00 2001
From: Michael Kuron <mkuron@icp.uni-stuttgart.de>
Date: Sun, 2 May 2021 13:32:26 +0000
Subject: [PATCH] Vector scatter/gather support

---
 pystencils/backends/arm_instruction_sets.py   | 13 +++++++++-
 pystencils/backends/cbackend.py               | 12 +++++++--
 pystencils/backends/x86_instruction_sets.py   | 10 +++++++-
 pystencils/cpu/vectorization.py               | 18 +++++++------
 pystencils/data_types.py                      |  4 +--
 .../test_vectorization_specific.py            | 25 +++++++++++++++++++
 6 files changed, 68 insertions(+), 14 deletions(-)

diff --git a/pystencils/backends/arm_instruction_sets.py b/pystencils/backends/arm_instruction_sets.py
index 9f7b4ee22..5318dffeb 100644
--- a/pystencils/backends/arm_instruction_sets.py
+++ b/pystencils/backends/arm_instruction_sets.py
@@ -88,9 +88,16 @@ def get_vector_instruction_set_arm(data_type='double', instruction_set='neon'):
         result['makeVecConstInt'] = f'svdup_s{bits["int"]}' + '({0})'
         result['makeVecIndex'] = f'svindex_s{bits["int"]}' + '({0}, {1})'
 
+        vindex = f'svindex_u{bits[data_type]}(0, {{0}})'
+        result['scatter'] = f'svst1_scatter_u{bits[data_type]}index_f{bits[data_type]}({predicate}, {{0}}, ' + \
+                            vindex.format("{2}") + ', {1})'
+        result['gather'] = f'svld1_gather_u{bits[data_type]}index_f{bits[data_type]}({predicate}, {{0}}, ' + \
+                           vindex.format("{1}") + ')'
+
         result['+int'] = f"svadd_s{bits['int']}_x({int_predicate}, " + "{0}, {1})"
 
-        result[data_type] = f'svfloat{bits[data_type]}_st'
+        result['float'] = 'svfloat32_st'
+        result['double'] = 'svfloat64_st'
         result['int'] = f'svint{bits["int"]}_st'
         result['bool'] = 'svbool_st'
 
@@ -102,6 +109,10 @@ def get_vector_instruction_set_arm(data_type='double', instruction_set='neon'):
         result['any'] = f'svptest_any({predicate}, {{0}})'
         result['all'] = f'svcntp_b{bits[data_type]}({predicate}, {{0}}) == {width}'
 
+        result['maskStoreU'] = result['storeU'].replace(predicate, '{2}')
+        result['maskStoreA'] = result['storeA'].replace(predicate, '{2}')
+        result['maskScatter'] = result['scatter'].replace(predicate, '{3}')
+
         result['compile_flags'] = [f'-msve-vector-bits={bitwidth}']
     else:
         result['makeVecConst'] = f'vdupq_n_f{bits[data_type]}' + '({0})'
diff --git a/pystencils/backends/cbackend.py b/pystencils/backends/cbackend.py
index 670588468..d11723c1d 100644
--- a/pystencils/backends/cbackend.py
+++ b/pystencils/backends/cbackend.py
@@ -256,7 +256,7 @@ class CBackend:
             lhs_type = get_type_of_expression(node.lhs)
             printed_mask = ""
             if type(lhs_type) is VectorType and isinstance(node.lhs, cast_func):
-                arg, data_type, aligned, nontemporal, mask = node.lhs.args
+                arg, data_type, aligned, nontemporal, mask, stride = node.lhs.args
                 instr = 'storeU'
                 if aligned:
                     instr = 'stream' if nontemporal and 'stream' in self._vector_instruction_set else 'storeA'
@@ -285,6 +285,12 @@ class CBackend:
                     rhs = node.rhs
 
                 ptr = "&" + self.sympy_printer.doprint(node.lhs.args[0])
+
+                if stride != 1:
+                    instr = 'maskScatter' if mask != True else 'scatter'  # NOQA
+                    return self._vector_instruction_set[instr].format(ptr, self.sympy_printer.doprint(rhs),
+                                                                      stride, printed_mask) + ';'
+
                 pre_code = ''
                 if nontemporal and 'cachelineZero' in self._vector_instruction_set:
                     pre_code = f"if (((uintptr_t) {ptr} & {CachelineSize.mask_symbol}) == 0) " + "{\n\t" + \
@@ -609,7 +615,9 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
 
     def _print_Function(self, expr):
         if isinstance(expr, vector_memory_access):
-            arg, data_type, aligned, _, mask = expr.args
+            arg, data_type, aligned, _, mask, stride = expr.args
+            if stride != 1:
+                return self.instruction_set['gather'].format("& " + self._print(arg), stride)
             instruction = self.instruction_set['loadA'] if aligned else self.instruction_set['loadU']
             return instruction.format("& " + self._print(arg))
         elif isinstance(expr, cast_func):
diff --git a/pystencils/backends/x86_instruction_sets.py b/pystencils/backends/x86_instruction_sets.py
index 5cf049415..913db542f 100644
--- a/pystencils/backends/x86_instruction_sets.py
+++ b/pystencils/backends/x86_instruction_sets.py
@@ -130,7 +130,7 @@ def get_vector_instruction_set_x86(data_type='double', instruction_set='avx'):
     result['all'] = f"{pre}_movemask_{suf}({{0}}) == {hex(2**result['width']-1)}"
 
     if instruction_set == 'avx512':
-        size = 8 if data_type == 'double' else 16
+        size = result['width']
         result['&'] = f'_kand_mask{size}({{0}}, {{1}})'
         result['|'] = f'_kor_mask{size}({{0}}, {{1}})'
         result['any'] = f'!_ktestz_mask{size}_u8({{0}}, {{0}})'
@@ -145,6 +145,14 @@ def get_vector_instruction_set_x86(data_type='double', instruction_set='avx'):
         params = " | ".join(["({{0}} ? {power} : 0)".format(power=2 ** i) for i in range(8)])
         result['makeVecConstBool'] = f"__mmask8(({params}) )"
 
+        vindex = f'{pre}_set_epi{bit_width//size}(' + ', '.join([str(i) for i in range(result['width'])][::-1]) + ')'
+        vindex = f'{pre}_mullo_epi{bit_width//size}({vindex}, {pre}_set1_epi{bit_width//size}({{0}}))'
+        result['scatter'] = f'{pre}_i{bit_width//size}scatter_{suf}({{0}}, ' + vindex.format("{2}") + \
+                            f', {{1}}, {64//size})'
+        result['maskScatter'] = f'{pre}_mask_i{bit_width//size}scatter_{suf}({{0}}, {{3}}, ' + vindex.format("{2}") + \
+                                f', {{1}}, {64//size})'
+        result['gather'] = f'{pre}_i{bit_width//size}gather_{suf}(' + vindex.format("{1}") + f', {{0}}, {64//size})'
+
     if instruction_set == 'avx' and data_type == 'float':
         result['rsqrt'] = f"{pre}_rsqrt_{suf}({{0}})"
 
diff --git a/pystencils/cpu/vectorization.py b/pystencils/cpu/vectorization.py
index 0de34b40b..16f0a1563 100644
--- a/pystencils/cpu/vectorization.py
+++ b/pystencils/cpu/vectorization.py
@@ -80,8 +80,9 @@ def vectorize(kernel_ast: ast.KernelFunction, instruction_set: str = 'best',
     kernel_ast.instruction_set = vector_is
 
     vectorize_rng(kernel_ast, vector_width)
-    vectorize_inner_loops_and_adapt_load_stores(kernel_ast, vector_width, assume_aligned,
-                                                nontemporal, assume_sufficient_line_padding)
+    scattergather = 'scatter' in vector_is and 'gather' in vector_is
+    vectorize_inner_loops_and_adapt_load_stores(kernel_ast, vector_width, assume_aligned, nontemporal,
+                                                scattergather, assume_sufficient_line_padding)
     insert_vector_casts(kernel_ast)
 
 
@@ -104,7 +105,7 @@ def vectorize_rng(kernel_ast, vector_width):
 
 
 def vectorize_inner_loops_and_adapt_load_stores(ast_node, vector_width, assume_aligned, nontemporal_fields,
-                                                assume_sufficient_line_padding):
+                                                scattergather, assume_sufficient_line_padding):
     """Goes over all innermost loops, changes increment to vector width and replaces field accesses by vector type."""
     all_loops = filtered_tree_iteration(ast_node, ast.LoopOverCoordinate, stop_type=ast.SympyAssignment)
     inner_loops = [n for n in all_loops if n.is_innermost_loop]
@@ -135,7 +136,8 @@ def vectorize_inner_loops_and_adapt_load_stores(ast_node, vector_width, assume_a
             if loop_counter_symbol in index.atoms(sp.Symbol):
                 loop_counter_is_offset = loop_counter_symbol not in (index - loop_counter_symbol).atoms()
                 aligned_access = (index - loop_counter_symbol).subs(zero_loop_counters) == 0
-                if not loop_counter_is_offset:
+                stride = sp.simplify(index.subs({loop_counter_symbol: loop_counter_symbol + 1}) - index)
+                if not loop_counter_is_offset and (not scattergather or loop_counter_symbol in stride.atoms()):
                     successful = False
                     break
                 typed_symbol = base.label
@@ -147,7 +149,8 @@ def vectorize_inner_loops_and_adapt_load_stores(ast_node, vector_width, assume_a
                 nontemporal = False
                 if hasattr(indexed, 'field'):
                     nontemporal = (indexed.field in nontemporal_fields) or (indexed.field.name in nontemporal_fields)
-                substitutions[indexed] = vector_memory_access(indexed, vec_type, use_aligned_access, nontemporal, True)
+                substitutions[indexed] = vector_memory_access(indexed, vec_type, use_aligned_access, nontemporal, True,
+                                                              stride if scattergather else 1)
                 if nontemporal:
                     # insert NontemporalFence after the outermost loop
                     parent = loop_node.parent
@@ -188,7 +191,7 @@ def mask_conditionals(loop_body):
                 node.condition_expr = vec_any(node.condition_expr)
         elif isinstance(node, ast.SympyAssignment):
             if mask is not True:
-                s = {ma: vector_memory_access(ma.args[0], ma.args[1], ma.args[2], ma.args[3], sp.And(mask, ma.args[4]))
+                s = {ma: vector_memory_access(*ma.args[0:4], sp.And(mask, ma.args[4]), *ma.args[5:])
                      for ma in node.atoms(vector_memory_access)}
                 node.subs(s)
         else:
@@ -205,8 +208,7 @@ def insert_vector_casts(ast_node):
 
     def visit_expr(expr):
         if isinstance(expr, vector_memory_access):
-            return vector_memory_access(expr.args[0], expr.args[1], expr.args[2], expr.args[3],
-                                        visit_expr(expr.args[4]))
+            return vector_memory_access(*expr.args[0:4], visit_expr(expr.args[4]), *expr.args[5:])
         elif isinstance(expr, cast_func):
             return expr
         elif expr.func is sp.Abs and 'abs' not in ast_node.instruction_set:
diff --git a/pystencils/data_types.py b/pystencils/data_types.py
index 46abd84f3..baf0a9674 100644
--- a/pystencils/data_types.py
+++ b/pystencils/data_types.py
@@ -195,8 +195,8 @@ class boolean_cast_func(cast_func, Boolean):
 
 # noinspection PyPep8Naming
 class vector_memory_access(cast_func):
-    # Arguments are: read/write expression, type, aligned, nontemporal, mask (or none)
-    nargs = (5,)
+    # Arguments are: read/write expression, type, aligned, nontemporal, mask (or none), stride
+    nargs = (6,)
 
 
 # noinspection PyPep8Naming
diff --git a/pystencils_tests/test_vectorization_specific.py b/pystencils_tests/test_vectorization_specific.py
index df0b9d943..f579b4e46 100644
--- a/pystencils_tests/test_vectorization_specific.py
+++ b/pystencils_tests/test_vectorization_specific.py
@@ -53,6 +53,31 @@ def test_vectorized_abs(instruction_set, dtype):
     np.testing.assert_equal(np.sum(dst[1:-1, 1:-1]), 2 ** 2 * 2 ** 3)
 
 
+@pytest.mark.parametrize('dtype', ('float', 'double'))
+@pytest.mark.parametrize('instruction_set', supported_instruction_sets)
+def test_scatter_gather(instruction_set, dtype):
+    f, g = ps.fields(f"f, g : float{64 if dtype == 'double' else 32}[2D]")
+    update_rule = [ps.Assignment(g[0, 0], f[0, 0] + f[-1, 0] + f[1, 0] + f[0, 1] + f[0, -1] + 42.0)]
+    if 'scatter' not in get_vector_instruction_set(dtype, instruction_set) and not instruction_set == 'avx512' and not instruction_set.startswith('sve'):
+        with pytest.warns(UserWarning) as warn:
+            ast = ps.create_kernel(update_rule, cpu_vectorize_info={'instruction_set': instruction_set})
+            assert 'Could not vectorize loop' in warn[0].message.args[0]
+    else:
+        with pytest.warns(None) as warn:
+            ast = ps.create_kernel(update_rule, cpu_vectorize_info={'instruction_set': instruction_set})
+            assert len(warn) == 0
+    func = ast.compile()
+    ref_func = ps.create_kernel(update_rule).compile()
+
+    arr = np.random.random((23 + 2, 17 + 2)).astype(np.float64 if dtype == 'double' else np.float32)
+    dst = np.zeros_like(arr, dtype=np.float64 if dtype == 'double' else np.float32)
+    ref = np.zeros_like(arr, dtype=np.float64 if dtype == 'double' else np.float32)
+
+    func(g=dst, f=arr)
+    ref_func(g=ref, f=arr)
+    np.testing.assert_almost_equal(dst, ref, 13 if dtype == 'double' else 5)
+
+
 @pytest.mark.parametrize('dtype', ('float', 'double'))
 @pytest.mark.parametrize('instruction_set', supported_instruction_sets)
 @pytest.mark.parametrize('gl_field, gl_kernel', [(1, 0), (0, 1), (1, 1)])
-- 
GitLab