diff --git a/pystencils/datahandling/serial_datahandling.py b/pystencils/datahandling/serial_datahandling.py index ce4629f6ab7f35c96cec437ab76b778dfec83e04..9e18acf4a240932bcdb0aad80a899cbe6d411a3f 100644 --- a/pystencils/datahandling/serial_datahandling.py +++ b/pystencils/datahandling/serial_datahandling.py @@ -271,7 +271,7 @@ class SerialDataHandling(DataHandling): def synchronization_function_gpu(self, names, stencil_name=None, **_): return self.synchronization_function(names, stencil_name, target='gpu') - def synchronization_function(self, names, stencil=None, target=None, **_): + def synchronization_function(self, names, stencil=None, target=None, functor=None, **_): if target is None: target = self.default_target if target == 'opencl': @@ -311,19 +311,22 @@ class SerialDataHandling(DataHandling): if len(filtered_stencil) > 0: if target == 'cpu': - from pystencils.slicing import get_periodic_boundary_functor - result.append(get_periodic_boundary_functor(filtered_stencil, ghost_layers=gls)) + if functor is None: + from pystencils.slicing import get_periodic_boundary_functor + functor = get_periodic_boundary_functor + result.append(functor(filtered_stencil, ghost_layers=gls)) else: - from pystencils.gpucuda.periodicity import get_periodic_boundary_functor as boundary_func - target = 'gpu' if not isinstance(self.array_handler, PyOpenClArrayHandler) else 'opencl' - result.append(boundary_func(filtered_stencil, self._domainSize, - index_dimensions=self.fields[name].index_dimensions, - index_dim_shape=values_per_cell, - dtype=self.fields[name].dtype.numpy_dtype, - ghost_layers=gls, - target=target, - opencl_queue=self._opencl_queue, - opencl_ctx=self._opencl_ctx)) + if functor is None: + from pystencils.gpucuda.periodicity import get_periodic_boundary_functor as functor + target = 'gpu' if not isinstance(self.array_handler, PyOpenClArrayHandler) else 'opencl' + result.append(functor(filtered_stencil, self._domainSize, + index_dimensions=self.fields[name].index_dimensions, + index_dim_shape=values_per_cell, + dtype=self.fields[name].dtype.numpy_dtype, + ghost_layers=gls, + target=target, + opencl_queue=self._opencl_queue, + opencl_ctx=self._opencl_ctx)) if target == 'cpu': def result_functor():