From 5a88c4729f6c82f8367a962819d35fe3bead2014 Mon Sep 17 00:00:00 2001 From: Michael Kuron <m.kuron@gmx.de> Date: Wed, 31 Mar 2021 21:47:50 +0200 Subject: [PATCH] improve TemporayMemoryAllocation: don't align more than necessary, behave correctly on MSVC --- pystencils/backends/cbackend.py | 26 +++++++++++++++++++++----- 1 file changed, 21 insertions(+), 5 deletions(-) diff --git a/pystencils/backends/cbackend.py b/pystencils/backends/cbackend.py index b469b13f1..58dff94e0 100644 --- a/pystencils/backends/cbackend.py +++ b/pystencils/backends/cbackend.py @@ -296,14 +296,22 @@ class CBackend: return '' def _print_TemporaryMemoryAllocation(self, node): - align = 64 + parent = node.parent + while parent.parent is not None: + parent = parent.parent + instruction_set = parent.instruction_set + if instruction_set: + align = instruction_set['width'] * node.symbol.dtype.base_type.numpy_dtype.type(0).nbytes + else: + align = node.symbol.dtype.base_type.numpy_dtype.type(0).nbytes + np_dtype = node.symbol.dtype.base_type.numpy_dtype required_size = np_dtype.itemsize * node.size + align size = modulo_ceil(required_size, align) - code = "#if __cplusplus >= 201703L || __STDC_VERSION__ >= 201112L\n" - code += "{dtype} {name}=({dtype})aligned_alloc({align}, {size}) + {offset};\n" - code += "#elif defined(_MSC_VER)\n" + code = "#if defined(_MSC_VER)\n" code += "{dtype} {name}=({dtype})_aligned_malloc({size}, {align}) + {offset};\n" + code += "#elif __cplusplus >= 201703L || __STDC_VERSION__ >= 201112L\n" + code += "{dtype} {name}=({dtype})aligned_alloc({align}, {size}) + {offset};\n" code += "#else\n" code += "{dtype} {name};\n" code += "posix_memalign((void**) &{name}, {align}, {size});\n" @@ -316,7 +324,15 @@ class CBackend: align=align) def _print_TemporaryMemoryFree(self, node): - align = 64 + parent = node.parent + while parent.parent is not None: + parent = parent.parent + instruction_set = parent.instruction_set + if instruction_set: + align = instruction_set['width'] * node.symbol.dtype.base_type.numpy_dtype.type(0).nbytes + else: + align = node.symbol.dtype.base_type.numpy_dtype.type(0).nbytes + code = "#if defined(_MSC_VER)\n" code += "_aligned_free(%s - %d);\n" % (self.sympy_printer.doprint(node.symbol.name), node.offset(align)) code += "#else\n" -- GitLab