Skip to content
Snippets Groups Projects
Commit 5a88c472 authored by Michael Kuron's avatar Michael Kuron :mortar_board:
Browse files

improve TemporayMemoryAllocation: don't align more than necessary, behave correctly on MSVC

parent 9c784763
Branches
Tags
No related merge requests found
...@@ -296,14 +296,22 @@ class CBackend: ...@@ -296,14 +296,22 @@ class CBackend:
return '' return ''
def _print_TemporaryMemoryAllocation(self, node): def _print_TemporaryMemoryAllocation(self, node):
align = 64 parent = node.parent
while parent.parent is not None:
parent = parent.parent
instruction_set = parent.instruction_set
if instruction_set:
align = instruction_set['width'] * node.symbol.dtype.base_type.numpy_dtype.type(0).nbytes
else:
align = node.symbol.dtype.base_type.numpy_dtype.type(0).nbytes
np_dtype = node.symbol.dtype.base_type.numpy_dtype np_dtype = node.symbol.dtype.base_type.numpy_dtype
required_size = np_dtype.itemsize * node.size + align required_size = np_dtype.itemsize * node.size + align
size = modulo_ceil(required_size, align) size = modulo_ceil(required_size, align)
code = "#if __cplusplus >= 201703L || __STDC_VERSION__ >= 201112L\n" code = "#if defined(_MSC_VER)\n"
code += "{dtype} {name}=({dtype})aligned_alloc({align}, {size}) + {offset};\n"
code += "#elif defined(_MSC_VER)\n"
code += "{dtype} {name}=({dtype})_aligned_malloc({size}, {align}) + {offset};\n" code += "{dtype} {name}=({dtype})_aligned_malloc({size}, {align}) + {offset};\n"
code += "#elif __cplusplus >= 201703L || __STDC_VERSION__ >= 201112L\n"
code += "{dtype} {name}=({dtype})aligned_alloc({align}, {size}) + {offset};\n"
code += "#else\n" code += "#else\n"
code += "{dtype} {name};\n" code += "{dtype} {name};\n"
code += "posix_memalign((void**) &{name}, {align}, {size});\n" code += "posix_memalign((void**) &{name}, {align}, {size});\n"
...@@ -316,7 +324,15 @@ class CBackend: ...@@ -316,7 +324,15 @@ class CBackend:
align=align) align=align)
def _print_TemporaryMemoryFree(self, node): def _print_TemporaryMemoryFree(self, node):
align = 64 parent = node.parent
while parent.parent is not None:
parent = parent.parent
instruction_set = parent.instruction_set
if instruction_set:
align = instruction_set['width'] * node.symbol.dtype.base_type.numpy_dtype.type(0).nbytes
else:
align = node.symbol.dtype.base_type.numpy_dtype.type(0).nbytes
code = "#if defined(_MSC_VER)\n" code = "#if defined(_MSC_VER)\n"
code += "_aligned_free(%s - %d);\n" % (self.sympy_printer.doprint(node.symbol.name), node.offset(align)) code += "_aligned_free(%s - %d);\n" % (self.sympy_printer.doprint(node.symbol.name), node.offset(align))
code += "#else\n" code += "#else\n"
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment