From de710402ff0bd94ea13b9f2e3d143bad7d1f7502 Mon Sep 17 00:00:00 2001 From: Stephan Seitz <stephan.seitz@fau.de> Date: Thu, 14 Nov 2019 16:47:36 +0100 Subject: [PATCH] Allow default_target=='opencl' in SerialDataHandling --- pystencils/datahandling/serial_datahandling.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/pystencils/datahandling/serial_datahandling.py b/pystencils/datahandling/serial_datahandling.py index 931a7df30..5abc3cc2f 100644 --- a/pystencils/datahandling/serial_datahandling.py +++ b/pystencils/datahandling/serial_datahandling.py @@ -16,6 +16,9 @@ from pystencils.utils import DotDict class SerialDataHandling(DataHandling): + _GPU_LIKE_TARGETS = ['gpu', 'opencl'] + _GPU_LIKE_BACKENDS = ['gpucuda', 'opencl'] + def __init__(self, domain_size: Sequence[int], default_ghost_layers: int = 1, @@ -48,17 +51,16 @@ class SerialDataHandling(DataHandling): self._opencl_queue = opencl_queue self._opencl_ctx = opencl_ctx - if array_handler: - self.array_handler = array_handler - else: + if not array_handler: try: self.array_handler = PyCudaArrayHandler() except Exception: self.array_handler = None if default_target == 'opencl' or opencl_queue: - default_target = 'gpu' self.array_handler = PyOpenClArrayHandler(opencl_queue) + else: + self.array_handler = array_handler if periodicity is None or periodicity is False: periodicity = [False] * self.dim @@ -99,7 +101,7 @@ class SerialDataHandling(DataHandling): if layout is None: layout = self.default_layout if gpu is None: - gpu = self.default_target == 'gpu' + gpu = self.default_target in self._GPU_LIKE_TARGETS kwargs = { 'shape': tuple(s + 2 * ghost_layers for s in self._domainSize), @@ -239,13 +241,12 @@ class SerialDataHandling(DataHandling): self.to_gpu(name) def run_kernel(self, kernel_function, **kwargs): - arrays = self.gpu_arrays if kernel_function.ast.backend == 'gpucuda' \ - or kernel_function.ast.backend == 'opencl' else self.cpu_arrays + arrays = self.gpu_arrays if kernel_function.ast.backend in self._GPU_LIKE_BACKENDS else self.cpu_arrays kernel_function(**arrays, **kwargs) def get_kernel_kwargs(self, kernel_function, **kwargs): result = {} - result.update(self.gpu_arrays if kernel_function.ast.backend == 'gpucuda' else self.cpu_arrays) + result.update(self.gpu_arrays if kernel_function.ast.backend in self._GPU_LIKE_BACKENDS else self.cpu_arrays) result.update(kwargs) return [result] -- GitLab