Commit e40ca9a1 authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Lint

parent b93d6992
......@@ -213,7 +213,11 @@ class CBackend:
def _print_SympyAssignment(self, node):
if node.is_declaration:
data_type = "const " + self._print(node.lhs.dtype) + " " if node.is_const else self._print(node.lhs.dtype) + " "
if node.is_const:
prefix = 'const '
else:
prefix = ''
data_type = prefix + self._print(node.lhs.dtype) + " "
return "%s%s = %s;" % (data_type, self.sympy_printer.doprint(node.lhs),
self.sympy_printer.doprint(node.rhs))
else:
......
from pystencils.backends.cuda_backend import CudaBackend, CudaSympyPrinter
from pystencils.backends.cbackend import generate_c
from pystencils.astnodes import Node
import pystencils.data_types
from pystencils.astnodes import Node
from pystencils.backends.cbackend import generate_c
from pystencils.backends.cuda_backend import CudaBackend, CudaSympyPrinter
def generate_opencl(astnode: Node, signature_only: bool = False) -> str:
"""Prints an abstract syntax tree node as CUDA code.
......@@ -27,7 +28,6 @@ class OpenClBackend(CudaBackend):
super().__init__(sympy_printer, signature_only)
self._dialect = 'opencl'
def _print_Type(self, node):
code = super()._print_Type(node)
if isinstance(node, pystencils.data_types.PointerType):
......@@ -57,4 +57,3 @@ class OpenClSympyPrinter(CudaSympyPrinter):
dimension = self.DIMENSION_MAPPING[dimension]
function_name = self.INDEXING_FUNCTION_MAPPING[function_name]
return f"{function_name}({dimension})"
......@@ -60,8 +60,8 @@ def make_python_function(kernel_function_node, opencl_queue, opencl_ctx, argumen
indexing = kernel_function_node.indexing
block_and_thread_numbers = indexing.call_parameters(shape)
block_and_thread_numbers['block'] = tuple(int(i) for i in block_and_thread_numbers['block'])
block_and_thread_numbers['grid'] = tuple(int(b*g) for (b, g) in zip(block_and_thread_numbers['block'],
block_and_thread_numbers['grid']))
block_and_thread_numbers['grid'] = tuple(int(b * g) for (b, g) in zip(block_and_thread_numbers['block'],
block_and_thread_numbers['grid']))
args = _build_numpy_argument_list(parameters, full_arguments)
args = [a.data for a in args if hasattr(a, 'data')]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment