Commit a4e8ce0c authored by Sebastian Bindgen's avatar Sebastian Bindgen
Browse files

Usage of custom boundary functor if given

parent e20f82a3
Pipeline #28607 passed with stage
in 5 minutes and 44 seconds
...@@ -271,7 +271,7 @@ class SerialDataHandling(DataHandling): ...@@ -271,7 +271,7 @@ class SerialDataHandling(DataHandling):
def synchronization_function_gpu(self, names, stencil_name=None, **_): def synchronization_function_gpu(self, names, stencil_name=None, **_):
return self.synchronization_function(names, stencil_name, target='gpu') return self.synchronization_function(names, stencil_name, target='gpu')
def synchronization_function(self, names, stencil=None, target=None, **_): def synchronization_function(self, names, stencil=None, target=None, functor=None, **_):
if target is None: if target is None:
target = self.default_target target = self.default_target
if target == 'opencl': if target == 'opencl':
...@@ -311,19 +311,22 @@ class SerialDataHandling(DataHandling): ...@@ -311,19 +311,22 @@ class SerialDataHandling(DataHandling):
if len(filtered_stencil) > 0: if len(filtered_stencil) > 0:
if target == 'cpu': if target == 'cpu':
from pystencils.slicing import get_periodic_boundary_functor if functor is None:
result.append(get_periodic_boundary_functor(filtered_stencil, ghost_layers=gls)) from pystencils.slicing import get_periodic_boundary_functor
functor = get_periodic_boundary_functor
result.append(functor(filtered_stencil, ghost_layers=gls))
else: else:
from pystencils.gpucuda.periodicity import get_periodic_boundary_functor as boundary_func if functor is None:
target = 'gpu' if not isinstance(self.array_handler, PyOpenClArrayHandler) else 'opencl' from pystencils.gpucuda.periodicity import get_periodic_boundary_functor as functor
result.append(boundary_func(filtered_stencil, self._domainSize, target = 'gpu' if not isinstance(self.array_handler, PyOpenClArrayHandler) else 'opencl'
index_dimensions=self.fields[name].index_dimensions, result.append(functor(filtered_stencil, self._domainSize,
index_dim_shape=values_per_cell, index_dimensions=self.fields[name].index_dimensions,
dtype=self.fields[name].dtype.numpy_dtype, index_dim_shape=values_per_cell,
ghost_layers=gls, dtype=self.fields[name].dtype.numpy_dtype,
target=target, ghost_layers=gls,
opencl_queue=self._opencl_queue, target=target,
opencl_ctx=self._opencl_ctx)) opencl_queue=self._opencl_queue,
opencl_ctx=self._opencl_ctx))
if target == 'cpu': if target == 'cpu':
def result_functor(): def result_functor():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment