From b1750b81902657db2d12201f32824b87106d3091 Mon Sep 17 00:00:00 2001
From: Michael Kuron <m.kuron@gmx.de>
Date: Sun, 28 Mar 2021 16:05:06 +0200
Subject: [PATCH] add fence after non-temporal stores

Fixes #25
---
 pystencils/astnodes.py                      | 42 ++++++++++++++++++++-
 pystencils/backends/cbackend.py             |  6 +++
 pystencils/backends/x86_instruction_sets.py |  2 +
 pystencils/cpu/vectorization.py             |  5 +++
 pystencils_tests/test_vectorization.py      |  2 +
 5 files changed, 55 insertions(+), 2 deletions(-)

diff --git a/pystencils/astnodes.py b/pystencils/astnodes.py
index b874db9b0..74ca259da 100644
--- a/pystencils/astnodes.py
+++ b/pystencils/astnodes.py
@@ -324,7 +324,7 @@ class Block(Node):
             node.parent = self
             self._nodes.insert(0, node)
 
-    def insert_before(self, new_node, insert_before):
+    def insert_before(self, new_node, insert_before, if_not_exists=False):
         new_node.parent = self
         assert self._nodes.count(insert_before) == 1
         idx = self._nodes.index(insert_before)
@@ -337,7 +337,25 @@ class Block(Node):
                     idx -= 1
                 else:
                     break
-        self._nodes.insert(idx, new_node)
+        if not if_not_exists or self._nodes[idx] != new_node:
+            self._nodes.insert(idx, new_node)
+
+    def insert_after(self, new_node, insert_after, if_not_exists=False):
+        new_node.parent = self
+        assert self._nodes.count(insert_after) == 1
+        idx = self._nodes.index(insert_after) + 1
+
+        # move all assignment (definitions to the top)
+        if isinstance(new_node, SympyAssignment) and new_node.is_declaration:
+            while idx > 0:
+                pn = self._nodes[idx - 1]
+                if isinstance(pn, LoopOverCoordinate) or isinstance(pn, Conditional):
+                    idx -= 1
+                else:
+                    break
+        if not if_not_exists or not (self._nodes[idx - 1] == new_node
+                                     or (idx < len(self._nodes) and self._nodes[idx] == new_node)):
+            self._nodes.insert(idx, new_node)
 
     def append(self, node):
         if isinstance(node, list) or isinstance(node, tuple):
@@ -816,3 +834,23 @@ class ConditionalFieldAccess(sp.Function):
 
     def __getnewargs__(self):
         return self.access, self.outofbounds_condition, self.outofbounds_value
+
+
+class NontemporalFence(Node):
+    def __init__(self):
+        super(NontemporalFence, self).__init__(parent=None)
+
+    @property
+    def symbols_defined(self):
+        return set()
+
+    @property
+    def undefined_symbols(self):
+        return set()
+
+    @property
+    def args(self):
+        return []
+
+    def __eq__(self, other):
+        return isinstance(other, NontemporalFence)
diff --git a/pystencils/backends/cbackend.py b/pystencils/backends/cbackend.py
index 2a15ef74f..08fcede5e 100644
--- a/pystencils/backends/cbackend.py
+++ b/pystencils/backends/cbackend.py
@@ -277,6 +277,12 @@ class CBackend:
             else:
                 return f"{self.sympy_printer.doprint(node.lhs)} = {self.sympy_printer.doprint(node.rhs)};"
 
+    def _print_NontemporalFence(self, _):
+        if 'stream_fence' in self._vector_instruction_set:
+            return self._vector_instruction_set['stream_fence'] + ';'
+        else:
+            return ''
+
     def _print_TemporaryMemoryAllocation(self, node):
         align = 64
         np_dtype = node.symbol.dtype.base_type.numpy_dtype
diff --git a/pystencils/backends/x86_instruction_sets.py b/pystencils/backends/x86_instruction_sets.py
index 836ffc579..0454621eb 100644
--- a/pystencils/backends/x86_instruction_sets.py
+++ b/pystencils/backends/x86_instruction_sets.py
@@ -164,4 +164,6 @@ def get_vector_instruction_set_x86(data_type='double', instruction_set='avx'):
 
     result['+int'] = f"{pre}_add_{suffix['int']}({{0}}, {{1}})"
 
+    result['stream_fence'] = '_mm_mfence()'
+
     return result
diff --git a/pystencils/cpu/vectorization.py b/pystencils/cpu/vectorization.py
index a7d2b76d8..13d705b36 100644
--- a/pystencils/cpu/vectorization.py
+++ b/pystencils/cpu/vectorization.py
@@ -148,6 +148,11 @@ def vectorize_inner_loops_and_adapt_load_stores(ast_node, vector_width, assume_a
                 if hasattr(indexed, 'field'):
                     nontemporal = (indexed.field in nontemporal_fields) or (indexed.field.name in nontemporal_fields)
                 substitutions[indexed] = vector_memory_access(indexed, vec_type, use_aligned_access, nontemporal, True)
+                if nontemporal:
+                    parent = loop_node.parent
+                    while type(parent.parent.parent) is not ast.KernelFunction:
+                        parent = parent.parent
+                    parent.parent.insert_after(ast.NontemporalFence(), parent, if_not_exists=True)
         if not successful:
             warnings.warn("Could not vectorize loop because of non-consecutive memory access")
             continue
diff --git a/pystencils_tests/test_vectorization.py b/pystencils_tests/test_vectorization.py
index 0889ab468..782ea28df 100644
--- a/pystencils_tests/test_vectorization.py
+++ b/pystencils_tests/test_vectorization.py
@@ -47,6 +47,8 @@ def test_aligned_and_nt_stores():
            'assume_inner_stride_one': True}
     update_rule = [ps.Assignment(f.center(), 0.25 * (g[-1, 0] + g[1, 0] + g[0, -1] + g[0, 1]))]
     ast = ps.create_kernel(update_rule, target=dh.default_target, cpu_vectorize_info=opt)
+    if 'stream_fence' in ast.instruction_set:
+        assert ast.instruction_set['stream_fence'] in ps.get_code_str(ast)
     kernel = ast.compile()
 
     dh.run_kernel(kernel)
-- 
GitLab