Commit e607e5bd authored by Martin Bauer's avatar Martin Bauer
Browse files

Formatting and style

parent c391dbb6
......@@ -22,18 +22,11 @@ try:
except ImportError:
from sympy.printing.ccode import CCodePrinter # for sympy versions < 1.1
__all__ = ['generate_c', 'CustomCodeNode', 'PrintNode', 'get_headers', 'CustomSympyPrinter']
KERNCRAFT_NO_TERNARY_MODE = False
class UnsupportedCDialect(Exception):
def __init__(self):
super(UnsupportedCDialect, self).__init__()
def generate_c(ast_node: Node, signature_only: bool = False, dialect='c', custom_backend=None) -> str:
"""Prints an abstract syntax tree node as C or CUDA code.
......@@ -63,7 +56,7 @@ def generate_c(ast_node: Node, signature_only: bool = False, dialect='c', custom
from pystencils.backends.cuda_backend import CudaBackend
printer = CudaBackend(signature_only=signature_only)
else:
raise UnsupportedCDialect
raise ValueError("Unknown dialect: " + str(dialect))
code = printer(ast_node)
if not signature_only and isinstance(ast_node, KernelFunction):
code = "\n" + code
......
from os.path import dirname, join
from pystencils.astnodes import Node
from pystencils.backends.cbackend import (CBackend, CustomSympyPrinter,
generate_c)
from pystencils.fast_approximation import (fast_division, fast_inv_sqrt,
fast_sqrt)
from pystencils.backends.cbackend import CBackend, CustomSympyPrinter, generate_c
from pystencils.fast_approximation import fast_division, fast_inv_sqrt, fast_sqrt
CUDA_KNOWN_FUNCTIONS = None
with open(join(dirname(__file__), 'cuda_known_functions.txt')) as f:
lines = f.readlines()
CUDA_KNOWN_FUNCTIONS = {l.strip(): l.strip() for l in lines if l}
......@@ -17,8 +13,8 @@ def generate_cuda(astnode: Node, signature_only: bool = False) -> str:
"""Prints an abstract syntax tree node as CUDA code.
Args:
ast_node:
signature_only:
astnode: KernelFunction node to generate code for
signature_only: if True only the signature is printed
Returns:
C-like code for the ast node and its descendants
......@@ -41,7 +37,8 @@ class CudaBackend(CBackend):
name=self.sympy_printer.doprint(node.symbol.name),
num_elements='*'.join([str(s) for s in node.shared_mem.shape]))
def _print_ThreadBlockSynchronization(self, node):
@staticmethod
def _print_ThreadBlockSynchronization(node):
code = "__synchtreads();"
return code
......
......@@ -60,8 +60,7 @@ from appdirs import user_cache_dir, user_config_dir
from pystencils import FieldType
from pystencils.backends.cbackend import generate_c, get_headers
from pystencils.include import get_pystencils_include_path
from pystencils.utils import (atomic_file_write, file_handle_for_atomic_write,
recursive_dict_update)
from pystencils.utils import atomic_file_write, file_handle_for_atomic_write, recursive_dict_update
def make_python_function(kernel_function_node, custom_backend=None):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment