diff --git a/pystencils/backends/cbackend.py b/pystencils/backends/cbackend.py index b469b13f15a9e14032514d473b3ce2e07e037f5d..58dff94e05ef141aeae50db79e62c99f9bcddd33 100644 --- a/pystencils/backends/cbackend.py +++ b/pystencils/backends/cbackend.py @@ -296,14 +296,22 @@ class CBackend: return '' def _print_TemporaryMemoryAllocation(self, node): - align = 64 + parent = node.parent + while parent.parent is not None: + parent = parent.parent + instruction_set = parent.instruction_set + if instruction_set: + align = instruction_set['width'] * node.symbol.dtype.base_type.numpy_dtype.type(0).nbytes + else: + align = node.symbol.dtype.base_type.numpy_dtype.type(0).nbytes + np_dtype = node.symbol.dtype.base_type.numpy_dtype required_size = np_dtype.itemsize * node.size + align size = modulo_ceil(required_size, align) - code = "#if __cplusplus >= 201703L || __STDC_VERSION__ >= 201112L\n" - code += "{dtype} {name}=({dtype})aligned_alloc({align}, {size}) + {offset};\n" - code += "#elif defined(_MSC_VER)\n" + code = "#if defined(_MSC_VER)\n" code += "{dtype} {name}=({dtype})_aligned_malloc({size}, {align}) + {offset};\n" + code += "#elif __cplusplus >= 201703L || __STDC_VERSION__ >= 201112L\n" + code += "{dtype} {name}=({dtype})aligned_alloc({align}, {size}) + {offset};\n" code += "#else\n" code += "{dtype} {name};\n" code += "posix_memalign((void**) &{name}, {align}, {size});\n" @@ -316,7 +324,15 @@ class CBackend: align=align) def _print_TemporaryMemoryFree(self, node): - align = 64 + parent = node.parent + while parent.parent is not None: + parent = parent.parent + instruction_set = parent.instruction_set + if instruction_set: + align = instruction_set['width'] * node.symbol.dtype.base_type.numpy_dtype.type(0).nbytes + else: + align = node.symbol.dtype.base_type.numpy_dtype.type(0).nbytes + code = "#if defined(_MSC_VER)\n" code += "_aligned_free(%s - %d);\n" % (self.sympy_printer.doprint(node.symbol.name), node.offset(align)) code += "#else\n"