From 7e6d2a21754b5026e5abca2016e83daabe6cb942 Mon Sep 17 00:00:00 2001
From: Michael Kuron <m.kuron@gmx.de>
Date: Wed, 31 Mar 2021 21:50:46 +0200
Subject: [PATCH] nontemporal stores: flush cacheline if available

---
 pystencils/astnodes.py                      |  4 +++-
 pystencils/backends/cbackend.py             | 18 +++++++++++++++---
 pystencils/backends/ppc_instruction_sets.py |  3 ++-
 pystencils_tests/test_vectorization.py      |  4 ++++
 4 files changed, 24 insertions(+), 5 deletions(-)

diff --git a/pystencils/astnodes.py b/pystencils/astnodes.py
index a1f282b9d..537316457 100644
--- a/pystencils/astnodes.py
+++ b/pystencils/astnodes.py
@@ -859,14 +859,16 @@ class NontemporalFence(Node):
 
 
 class CachelineSize(Node):
+    symbol = sp.Symbol("_clsize")
     mask_symbol = sp.Symbol("_clsize_mask")
+    last_symbol = sp.Symbol("_cl_lastvec")
     
     def __init__(self):
         super(CachelineSize, self).__init__(parent=None)
 
     @property
     def symbols_defined(self):
-        return set([self.mask_symbol])
+        return set([self.symbol, self.mask_symbol, self.last_symbol])
 
     @property
     def undefined_symbols(self):
diff --git a/pystencils/backends/cbackend.py b/pystencils/backends/cbackend.py
index 58dff94e0..c14adca46 100644
--- a/pystencils/backends/cbackend.py
+++ b/pystencils/backends/cbackend.py
@@ -274,11 +274,20 @@ class CBackend:
                 ptr = "&" + self.sympy_printer.doprint(node.lhs.args[0])
                 pre_code = ''
                 if instr == 'stream' and 'cachelineZero' in self._vector_instruction_set:
-                    pre_code = f"if (((uintptr_t) {ptr} & {CachelineSize.mask_symbol}) == 0) " + "\n\t" + \
-                        self._vector_instruction_set['cachelineZero'].format(ptr) + ';\n'
+                    pre_code = f"if (((uintptr_t) {ptr} & {CachelineSize.mask_symbol}) == 0) " + "{\n\t" + \
+                        self._vector_instruction_set['cachelineZero'].format(ptr) + ';\n}\n'
 
                 code = self._vector_instruction_set[instr].format(ptr, self.sympy_printer.doprint(rhs),
                                                                   printed_mask) + ';'
+                cl_offset = f"((uintptr_t) {ptr} & {CachelineSize.mask_symbol})"
+                if instr == 'stream' and 'flushCacheline' in self._vector_instruction_set:
+                    code2 = self._vector_instruction_set['flushCacheline'].format(
+                        ptr, self.sympy_printer.doprint(rhs)) + ';'
+                    code = f"{code}\nif ({cl_offset} == {CachelineSize.last_symbol}) {{\n\t{code2}\n}}"
+                elif instr == 'stream' and 'streamAndFlushCacheline' in self._vector_instruction_set:
+                    code2 = self._vector_instruction_set['streamAndFlushCacheline'].format(
+                        ptr, self.sympy_printer.doprint(rhs), printed_mask) + ';'
+                    code = f"if ({cl_offset} != {CachelineSize.last_symbol}) {{\n\t{code}\n}} else {{\n\t{code2}\n}}"
                 return pre_code + code
             else:
                 return f"{self.sympy_printer.doprint(node.lhs)} = {self.sympy_printer.doprint(node.rhs)};"
@@ -291,7 +300,10 @@ class CBackend:
 
     def _print_CachelineSize(self, node):
         if 'cachelineSize' in self._vector_instruction_set:
-            return f'const size_t {node.mask_symbol} = {self._vector_instruction_set["cachelineSize"]} - 1;'
+            # TODO: determine size from instruction set (the 16 below is hard-coded)
+            return (f'const size_t {node.symbol} = {self._vector_instruction_set["cachelineSize"]};\n'
+                    f'const size_t {node.mask_symbol} = {node.symbol} - 1;\n'
+                    f'const size_t {node.last_symbol} = {node.symbol} - 16;\n')
         else:
             return ''
 
diff --git a/pystencils/backends/ppc_instruction_sets.py b/pystencils/backends/ppc_instruction_sets.py
index a1c481ae4..ff4209aac 100644
--- a/pystencils/backends/ppc_instruction_sets.py
+++ b/pystencils/backends/ppc_instruction_sets.py
@@ -29,7 +29,8 @@ def get_vector_instruction_set_ppc(data_type='double', instruction_set='vsx'):
         'loadA': 'ld[0x0, 0]',
         'storeU': 'xst[1, 0x0, 0]',
         'storeA': 'st[1, 0x0, 0]',
-        'stream': 'st[1, 0x0, 0]',  # stl would flush the cacheline, which only makes sense for the last item
+        'stream': 'st[1, 0x0, 0]',
+        'streamAndFlushCacheline': 'stl[1, 0x0, 0]',
 
         'abs': 'abs[0]',
         '==': 'cmpeq[0, 1]',
diff --git a/pystencils_tests/test_vectorization.py b/pystencils_tests/test_vectorization.py
index d05c37c6b..c7ffa2f3e 100644
--- a/pystencils_tests/test_vectorization.py
+++ b/pystencils_tests/test_vectorization.py
@@ -56,6 +56,10 @@ def test_aligned_and_nt_stores(openmp=False):
         assert ast.instruction_set['streamFence'] in ps.get_code_str(ast)
     if 'cachelineZero' in ast.instruction_set:
         assert ast.instruction_set['cachelineZero'].split('{0}')[0] in ps.get_code_str(ast)
+    if 'streamAndFlushCacheline' in ast.instruction_set:
+        assert ast.instruction_set['streamAndFlushCacheline'].split('{0}')[0] in ps.get_code_str(ast)
+    if 'flushCacheline' in ast.instruction_set:
+        assert ast.instruction_set['flushCacheline'].split('{0}')[0] in ps.get_code_str(ast)
     kernel = ast.compile()
 
     dh.run_kernel(kernel)
-- 
GitLab