Commit 27e3929f authored by Martin Bauer

Use optimized walberla pack info

parent d01eef7d
@@ -176,7 +176,7 @@ class CBackend:
         return "%s = %s;" % (self.sympy_printer.doprint(node.lhs), self.sympy_printer.doprint(node.rhs))
 
     def _print_TemporaryMemoryAllocation(self, node):
-        align = 128
+        align = 64
         np_dtype = node.symbol.dtype.base_type.numpy_dtype
         required_size = np_dtype.itemsize * node.size + align
         size = modulo_ceil(required_size, align)
@@ -188,7 +188,7 @@ class CBackend:
                            align=align)
 
     def _print_TemporaryMemoryFree(self, node):
-        align = 128
+        align = 64
        return "free(%s - %d);" % (self.sympy_printer.doprint(node.symbol.name), node.offset(align))
 
     @staticmethod
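The two hunks above change the alignment of generated temporary buffers from 128 to 64 bytes (a cache line on x86, and sufficient for AVX-512 loads). For context, here is a minimal Python sketch of the over-allocate-and-shift scheme that the printed C code implements; modulo_ceil mirrors the pystencils helper, while allocate_aligned and the numpy-based emulation of malloc are illustrative, not part of the commit:

import numpy as np

def modulo_ceil(x, m):
    # Round x up to the next multiple of m (mirrors the pystencils helper of the same name).
    return x if x % m == 0 else x + m - (x % m)

def allocate_aligned(num_elements, dtype=np.float64, align=64):
    # Over-allocate by `align` bytes, as in: required_size = itemsize * node.size + align,
    # then round the allocation size up to a multiple of `align`.
    itemsize = np.dtype(dtype).itemsize
    size = modulo_ceil(itemsize * num_elements + align, align)
    raw = np.empty(size, dtype=np.uint8)        # stands in for malloc(size)
    offset = -raw.ctypes.data % align           # bytes up to the next align-byte boundary
    aligned = raw[offset:offset + itemsize * num_elements].view(dtype)
    assert aligned.ctypes.data % align == 0
    # The generated C releases the buffer with free(ptr - offset), which is why
    # _print_TemporaryMemoryFree needs the same `align` value; here, keeping a
    # reference to `raw` plays that role.
    return aligned, raw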
@@ -274,13 +274,13 @@ class ParallelDataHandling(DataHandling):
         for name in self._custom_data_transfer_functions.keys():
             self.to_gpu(name)
 
-    def synchronization_function_cpu(self, names, stencil=None, buffered=True, **_):
-        return self.synchronization_function(names, stencil, 'cpu', buffered)
+    def synchronization_function_cpu(self, names, stencil=None, buffered=True, stencil_restricted=False, **_):
+        return self.synchronization_function(names, stencil, 'cpu', buffered, stencil_restricted)
 
-    def synchronization_function_gpu(self, names, stencil=None, buffered=True, **_):
-        return self.synchronization_function(names, stencil, 'gpu', buffered)
+    def synchronization_function_gpu(self, names, stencil=None, buffered=True, stencil_restricted=False, **_):
+        return self.synchronization_function(names, stencil, 'gpu', buffered, stencil_restricted)
 
-    def synchronization_function(self, names, stencil=None, target='cpu', buffered=True):
+    def synchronization_function(self, names, stencil=None, target='cpu', buffered=True, stencil_restricted=False):
         if target is None:
             target = self.default_target
@@ -293,6 +293,8 @@ class ParallelDataHandling(DataHandling):
         create_scheme = wlb.createUniformBufferedScheme if buffered else wlb.createUniformDirectScheme
         if target == 'cpu':
             create_packing = wlb.field.createPackInfo if buffered else wlb.field.createMPIDatatypeInfo
+            if not buffered and stencil_restricted:
+                create_packing = wlb.field.createStencilRestrictedPackInfo
         else:
             assert target == 'gpu'
             create_packing = wlb.cuda.createPackInfo if buffered else wlb.cuda.createMPIDatatypeInfo
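A hedged usage sketch of the new flag: with buffered=False and stencil_restricted=True, the CPU branch now selects wlb.field.createStencilRestrictedPackInfo, which packs only the ghost-layer slices the given stencil actually communicates rather than full ghost layers. The field name 'pdfs', the stencil string, the grid-setup call, and the import path are assumptions for illustration, not from the commit:

import waLBerla as wlb
from pystencils.datahandling import ParallelDataHandling  # import path may vary by version

# Assumed setup: a uniform waLBerla block forest and one field.
blocks = wlb.createUniformBlockGrid(cells=(64, 64, 64), periodic=(1, 1, 1))
dh = ParallelDataHandling(blocks)
dh.add_array('pdfs', values_per_cell=19)

# Direct (unbuffered) scheme with stencil-restricted packing on the CPU path.
sync = dh.synchronization_function(['pdfs'], stencil='D3Q19',
                                   target='cpu', buffered=False,
                                   stencil_restricted=True)
sync()  # exchange ghost layers before each sweep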