From 5a88c4729f6c82f8367a962819d35fe3bead2014 Mon Sep 17 00:00:00 2001
From: Michael Kuron <m.kuron@gmx.de>
Date: Wed, 31 Mar 2021 21:47:50 +0200
Subject: [PATCH] improve TemporayMemoryAllocation: don't align more than
 necessary, behave correctly on MSVC

---
 pystencils/backends/cbackend.py | 26 +++++++++++++++++++++-----
 1 file changed, 21 insertions(+), 5 deletions(-)

diff --git a/pystencils/backends/cbackend.py b/pystencils/backends/cbackend.py
index b469b13f1..58dff94e0 100644
--- a/pystencils/backends/cbackend.py
+++ b/pystencils/backends/cbackend.py
@@ -296,14 +296,22 @@ class CBackend:
             return ''
 
     def _print_TemporaryMemoryAllocation(self, node):
-        align = 64
+        parent = node.parent
+        while parent.parent is not None:
+            parent = parent.parent
+        instruction_set = parent.instruction_set
+        if instruction_set:
+            align = instruction_set['width'] * node.symbol.dtype.base_type.numpy_dtype.type(0).nbytes
+        else:
+            align = node.symbol.dtype.base_type.numpy_dtype.type(0).nbytes
+
         np_dtype = node.symbol.dtype.base_type.numpy_dtype
         required_size = np_dtype.itemsize * node.size + align
         size = modulo_ceil(required_size, align)
-        code = "#if __cplusplus >= 201703L || __STDC_VERSION__ >= 201112L\n"
-        code += "{dtype} {name}=({dtype})aligned_alloc({align}, {size}) + {offset};\n"
-        code += "#elif defined(_MSC_VER)\n"
+        code = "#if defined(_MSC_VER)\n"
         code += "{dtype} {name}=({dtype})_aligned_malloc({size}, {align}) + {offset};\n"
+        code += "#elif __cplusplus >= 201703L || __STDC_VERSION__ >= 201112L\n"
+        code += "{dtype} {name}=({dtype})aligned_alloc({align}, {size}) + {offset};\n"
         code += "#else\n"
         code += "{dtype} {name};\n"
         code += "posix_memalign((void**) &{name}, {align}, {size});\n"
@@ -316,7 +324,15 @@ class CBackend:
                            align=align)
 
     def _print_TemporaryMemoryFree(self, node):
-        align = 64
+        parent = node.parent
+        while parent.parent is not None:
+            parent = parent.parent
+        instruction_set = parent.instruction_set
+        if instruction_set:
+            align = instruction_set['width'] * node.symbol.dtype.base_type.numpy_dtype.type(0).nbytes
+        else:
+            align = node.symbol.dtype.base_type.numpy_dtype.type(0).nbytes
+
         code = "#if defined(_MSC_VER)\n"
         code += "_aligned_free(%s - %d);\n" % (self.sympy_printer.doprint(node.symbol.name), node.offset(align))
         code += "#else\n"
-- 
GitLab