diff --git a/pystencils/backends/cbackend.py b/pystencils/backends/cbackend.py index 495c0a059c77e6723ca340358aeaf71f74e06056..38cba5c354d6ef341fc760b7c2af3060c945a527 100644 --- a/pystencils/backends/cbackend.py +++ b/pystencils/backends/cbackend.py @@ -213,7 +213,11 @@ class CBackend: def _print_SympyAssignment(self, node): if node.is_declaration: - data_type = "const " + self._print(node.lhs.dtype) + " " if node.is_const else self._print(node.lhs.dtype) + " " + if node.is_const: + prefix = 'const ' + else: + prefix = '' + data_type = prefix + self._print(node.lhs.dtype) + " " return "%s%s = %s;" % (data_type, self.sympy_printer.doprint(node.lhs), self.sympy_printer.doprint(node.rhs)) else: diff --git a/pystencils/backends/opencl_backend.py b/pystencils/backends/opencl_backend.py index d44c944ddfe86329b5e31a554ff91ab57cce1568..ac8a7a78fe8aafcd06159c2496b867efab071dcf 100644 --- a/pystencils/backends/opencl_backend.py +++ b/pystencils/backends/opencl_backend.py @@ -1,7 +1,8 @@ -from pystencils.backends.cuda_backend import CudaBackend, CudaSympyPrinter -from pystencils.backends.cbackend import generate_c -from pystencils.astnodes import Node import pystencils.data_types +from pystencils.astnodes import Node +from pystencils.backends.cbackend import generate_c +from pystencils.backends.cuda_backend import CudaBackend, CudaSympyPrinter + def generate_opencl(astnode: Node, signature_only: bool = False) -> str: """Prints an abstract syntax tree node as CUDA code. @@ -27,7 +28,6 @@ class OpenClBackend(CudaBackend): super().__init__(sympy_printer, signature_only) self._dialect = 'opencl' - def _print_Type(self, node): code = super()._print_Type(node) if isinstance(node, pystencils.data_types.PointerType): @@ -57,4 +57,3 @@ class OpenClSympyPrinter(CudaSympyPrinter): dimension = self.DIMENSION_MAPPING[dimension] function_name = self.INDEXING_FUNCTION_MAPPING[function_name] return f"{function_name}({dimension})" - diff --git a/pystencils/opencl/opencljit.py b/pystencils/opencl/opencljit.py index 7f4bdb659f07179e13c04d404a50a6fe17b19a18..dd0660667ba712c285bf13ca7d6fa440ff050bbf 100644 --- a/pystencils/opencl/opencljit.py +++ b/pystencils/opencl/opencljit.py @@ -60,8 +60,8 @@ def make_python_function(kernel_function_node, opencl_queue, opencl_ctx, argumen indexing = kernel_function_node.indexing block_and_thread_numbers = indexing.call_parameters(shape) block_and_thread_numbers['block'] = tuple(int(i) for i in block_and_thread_numbers['block']) - block_and_thread_numbers['grid'] = tuple(int(b*g) for (b, g) in zip(block_and_thread_numbers['block'], - block_and_thread_numbers['grid'])) + block_and_thread_numbers['grid'] = tuple(int(b * g) for (b, g) in zip(block_and_thread_numbers['block'], + block_and_thread_numbers['grid'])) args = _build_numpy_argument_list(parameters, full_arguments) args = [a.data for a in args if hasattr(a, 'data')]