From 31be359b704797b5e82863aecd56ed93005f1a04 Mon Sep 17 00:00:00 2001
From: Michael Kuron <m.kuron@gmx.de>
Date: Wed, 31 Mar 2021 22:26:05 +0200
Subject: [PATCH] nontemporal cleanup

---
 pystencils/backends/arm_instruction_sets.py |  1 +
 pystencils/backends/cbackend.py             | 30 ++++++++++-----------
 pystencils/backends/ppc_instruction_sets.py |  1 +
 pystencils/backends/x86_instruction_sets.py |  3 ++-
 pystencils_tests/test_vectorization.py      |  6 ++---
 5 files changed, 21 insertions(+), 20 deletions(-)

diff --git a/pystencils/backends/arm_instruction_sets.py b/pystencils/backends/arm_instruction_sets.py
index c6da59dfd..4660ced01 100644
--- a/pystencils/backends/arm_instruction_sets.py
+++ b/pystencils/backends/arm_instruction_sets.py
@@ -47,6 +47,7 @@ def get_vector_instruction_set_arm(data_type='double', instruction_set='neon'):
     suffix = f'q_f{bits[data_type]}'
 
     result = dict()
+    result['bytes'] = 16
 
     for intrinsic_id, function_shortcut in base_names.items():
         function_shortcut = function_shortcut.strip()
diff --git a/pystencils/backends/cbackend.py b/pystencils/backends/cbackend.py
index c14adca46..a924b67fd 100644
--- a/pystencils/backends/cbackend.py
+++ b/pystencils/backends/cbackend.py
@@ -1,5 +1,6 @@
 import re
 from collections import namedtuple
+import hashlib
 from typing import Set
 
 import numpy as np
@@ -285,9 +286,13 @@ class CBackend:
                         ptr, self.sympy_printer.doprint(rhs)) + ';'
                     code = f"{code}\nif ({flushcond}) {{\n\t{code2}\n}}"
                 elif instr == 'stream' and 'streamAndFlushCacheline' in self._vector_instruction_set:
-                    code2 = self._vector_instruction_set['streamAndFlushCacheline'].format(
-                        ptr, self.sympy_printer.doprint(rhs), printed_mask) + ';'
-                    code = f"if ({flushcond}) {{\n\t{code}\n}} else {{\n\t{code2}\n}}"
+                    tmpvar = '_tmp_' + hashlib.sha1(self.sympy_printer.doprint(rhs).encode('ascii')).hexdigest()[:8]
+                    code = 'const ' + self._print(node.lhs.dtype).replace(' const', '') + ' ' + tmpvar + ' = ' \
+                        + self.sympy_printer.doprint(rhs) + ';'
+                    code1 = self._vector_instruction_set['stream'].format(ptr, tmpvar, printed_mask) + ';'
+                    code2 = self._vector_instruction_set['streamAndFlushCacheline'].format(ptr, tmpvar, printed_mask) \
+                        + ';'
+                    code += f"\nif ({flushcond}) {{\n\t{code1}\n}} else {{\n\t{code2}\n}}"
                 return pre_code + code
             else:
                 return f"{self.sympy_printer.doprint(node.lhs)} = {self.sympy_printer.doprint(node.rhs)};"
@@ -302,18 +307,15 @@ class CBackend:
         if 'cachelineSize' in self._vector_instruction_set:
             code = f'const size_t {node.symbol} = {self._vector_instruction_set["cachelineSize"]};\n'
             code += f'const size_t {node.mask_symbol} = {node.symbol} - 1;\n'
-            code += f'const size_t {node.last_symbol} = {node.symbol} - 16;\n'  # TODO: determine size from instruction set
+            vectorsize = self._vector_instruction_set['bytes']
+            code += f'const size_t {node.last_symbol} = {node.symbol} - {vectorsize};\n'
             return code
         else:
             return ''
 
     def _print_TemporaryMemoryAllocation(self, node):
-        parent = node.parent
-        while parent.parent is not None:
-            parent = parent.parent
-        instruction_set = parent.instruction_set
-        if instruction_set:
-            align = instruction_set['width'] * node.symbol.dtype.base_type.numpy_dtype.type(0).nbytes
+        if self._vector_instruction_set:
+            align = self._vector_instruction_set['bytes']
         else:
             align = node.symbol.dtype.base_type.numpy_dtype.type(0).nbytes
 
@@ -336,12 +338,8 @@ class CBackend:
                            align=align)
 
     def _print_TemporaryMemoryFree(self, node):
-        parent = node.parent
-        while parent.parent is not None:
-            parent = parent.parent
-        instruction_set = parent.instruction_set
-        if instruction_set:
-            align = instruction_set['width'] * node.symbol.dtype.base_type.numpy_dtype.type(0).nbytes
+        if self._vector_instruction_set:
+            align = self._vector_instruction_set['bytes']
         else:
             align = node.symbol.dtype.base_type.numpy_dtype.type(0).nbytes
 
diff --git a/pystencils/backends/ppc_instruction_sets.py b/pystencils/backends/ppc_instruction_sets.py
index ff4209aac..b8116fd6a 100644
--- a/pystencils/backends/ppc_instruction_sets.py
+++ b/pystencils/backends/ppc_instruction_sets.py
@@ -65,6 +65,7 @@ def get_vector_instruction_set_ppc(data_type='double', instruction_set='vsx'):
     intwidth = 128 // bits['int']
 
     result = dict()
+    result['bytes'] = 16
 
     for intrinsic_id, function_shortcut in base_names.items():
         function_shortcut = function_shortcut.strip()
diff --git a/pystencils/backends/x86_instruction_sets.py b/pystencils/backends/x86_instruction_sets.py
index 7809a8979..86196aad7 100644
--- a/pystencils/backends/x86_instruction_sets.py
+++ b/pystencils/backends/x86_instruction_sets.py
@@ -110,7 +110,8 @@ def get_vector_instruction_set_x86(data_type='double', instruction_set='avx'):
 
     result = {
         'width': width[(data_type, instruction_set)],
-        'intwidth': width[('int', instruction_set)]
+        'intwidth': width[('int', instruction_set)],
+        'bytes': 4 * width[("float", instruction_set)]
     }
     pre = prefix[instruction_set]
     for intrinsic_id, function_shortcut in base_names.items():
diff --git a/pystencils_tests/test_vectorization.py b/pystencils_tests/test_vectorization.py
index c7ffa2f3e..6852490bf 100644
--- a/pystencils_tests/test_vectorization.py
+++ b/pystencils_tests/test_vectorization.py
@@ -55,11 +55,11 @@ def test_aligned_and_nt_stores(openmp=False):
     if 'streamFence' in ast.instruction_set:
         assert ast.instruction_set['streamFence'] in ps.get_code_str(ast)
     if 'cachelineZero' in ast.instruction_set:
-        assert ast.instruction_set['cachelineZero'].split('{0}')[0] in ps.get_code_str(ast)
+        assert ast.instruction_set['cachelineZero'].split('{')[0] in ps.get_code_str(ast)
     if 'streamAndFlushCacheline' in ast.instruction_set:
-        assert ast.instruction_set['streamAndFlushCacheline'].split('{0}')[0] in ps.get_code_str(ast)
+        assert ast.instruction_set['streamAndFlushCacheline'].split('{')[0] in ps.get_code_str(ast)
     if 'flushCacheline' in ast.instruction_set:
-        assert ast.instruction_set['flushCacheline'].split('{0}')[0] in ps.get_code_str(ast)
+        assert ast.instruction_set['flushCacheline'].split('{')[0] in ps.get_code_str(ast)
     kernel = ast.compile()
 
     dh.run_kernel(kernel)
-- 
GitLab