opencl_backend.py 3.79 KB
Newer Older
1
2
from os.path import dirname, join

3
import pystencils.data_types
Stephan Seitz's avatar
Lint    
Stephan Seitz committed
4
from pystencils.astnodes import Node
5
from pystencils.backends.cbackend import CustomSympyPrinter, generate_c
Stephan Seitz's avatar
Lint    
Stephan Seitz committed
6
from pystencils.backends.cuda_backend import CudaBackend, CudaSympyPrinter
Jan Hönig's avatar
Jan Hönig committed
7
from pystencils.enums import Backend
8
from pystencils.fast_approximation import fast_division, fast_inv_sqrt, fast_sqrt
Stephan Seitz's avatar
Lint    
Stephan Seitz committed
9

10
11
with open(join(dirname(__file__), 'opencl1.1_known_functions.txt'), encoding='utf-8') as f:
    lines = f.readlines()
    # Identity mapping: every function name listed in opencl1.1_known_functions.txt is
    # printed under its own name. ``readlines`` returns blank lines as "\n" (truthy), so the
    # filter must test the *stripped* text; a plain ``if l`` let an empty-string key in.
    OPENCL_KNOWN_FUNCTIONS = {name: name for name in (line.strip() for line in lines) if name}
13

Stephan Seitz's avatar
Stephan Seitz committed
14

Markus Holzer's avatar
Markus Holzer committed
15
def generate_opencl(ast_node: Node, signature_only: bool = False, custom_backend=None, with_globals=True) -> str:
    """Print an abstract syntax tree node as OpenCL code.

    Thin wrapper around ``generate_c`` with the dialect fixed to ``Backend.OPENCL``.
    The node is expected to have been created for the GPU target.

    Args:
        ast_node: ast representation of the kernel
        signature_only: if True, emit only the function signature without its body
        custom_backend: optional custom printer to use for code generation
        with_globals: enable usage of global variables

    Returns:
        OpenCL code for ``ast_node`` and its descendants
    """
    # TODO: the docs describe the node in terms of ``Target`` 'GPU' -- Backend instead of Target?
    printing_options = {
        'dialect': Backend.OPENCL,
        'custom_backend': custom_backend,
        'with_globals': with_globals,
    }
    return generate_c(ast_node, signature_only, **printing_options)
Stephan Seitz's avatar
Stephan Seitz committed
29
30


31
32
class OpenClBackend(CudaBackend):
    """Code-printing backend for OpenCL.

    Reuses the CUDA backend and only overrides what differs: the expression printer,
    the dialect tag, pointer type qualifiers and a few unsupported node types.
    """

    def __init__(self, sympy_printer=None, signature_only=False):
        # Any falsy printer argument (e.g. None) is replaced by the OpenCL expression printer.
        printer = sympy_printer if sympy_printer else OpenClSympyPrinter()
        super().__init__(printer, signature_only)
        self._dialect = Backend.OPENCL

    def _print_Type(self, node):
        printed = super()._print_Type(node)
        # Pointer types are printed with the OpenCL ``__global`` address-space qualifier.
        is_pointer = isinstance(node, pystencils.data_types.PointerType)
        if not is_pointer:
            return printed
        return "__global " + printed

    def _print_ThreadBlockSynchronization(self, node):
        # Not supported by this backend.
        raise NotImplementedError()

    def _print_TextureDeclaration(self, node):
        # Textures are not supported by this backend.
        raise NotImplementedError()

55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70

class OpenClSympyPrinter(CudaSympyPrinter):
    """Prints sympy expressions as OpenCL code.

    Builds on the CUDA printer, but:
      * maps CUDA thread-indexing symbols (``threadIdx.x`` ...) to OpenCL work-item calls,
      * uses the names from ``OPENCL_KNOWN_FUNCTIONS`` as known functions,
      * prints pystencils' fast approximations as OpenCL ``native_*`` functions.
    """
    language = "OpenCL"

    # CUDA dimension suffix -> argument passed to the OpenCL indexing function.
    DIMENSION_MAPPING = {
        'x': '0',
        'y': '1',
        'z': '2'
    }
    # CUDA indexing builtin -> OpenCL function printed in its place.
    INDEXING_FUNCTION_MAPPING = {
        'blockIdx': 'get_group_id',
        'threadIdx': 'get_local_id',
        'blockDim': 'get_local_size',
        'gridDim': 'get_global_size'
    }

    def __init__(self):
        # Initialises via CustomSympyPrinter directly, bypassing CudaSympyPrinter.__init__;
        # NOTE(review): presumably to avoid CUDA-specific setup -- confirm.
        CustomSympyPrinter.__init__(self)
        # Per-instance copy so that changes to one printer's table do not leak into the
        # module-level OPENCL_KNOWN_FUNCTIONS shared by all printers.
        self.known_functions = dict(OPENCL_KNOWN_FUNCTIONS)

    def _print_Type(self, node):
        code = super()._print_Type(node)
        # Pointer types are printed with the OpenCL ``__global`` address-space qualifier.
        if isinstance(node, pystencils.data_types.PointerType):
            return "__global " + code
        else:
            return code

    def _print_ThreadIndexingSymbol(self, node):
        """Print e.g. ``threadIdx.x`` as ``(int64_t) get_local_id(0)``.

        Raises:
            ValueError: if the symbol name is not of the form ``<builtin>.<dim>``
            KeyError: if the builtin or dimension is not in the mapping tables
        """
        function_name, dimension = node.name.split(".")
        dimension = self.DIMENSION_MAPPING[dimension]
        function_name = self.INDEXING_FUNCTION_MAPPING[function_name]
        return f"(int64_t) {function_name}({dimension})"

    def _print_TextureAccess(self, node):
        raise NotImplementedError()

    # For math functions, OpenCL is more similar to the C++ printer CustomSympyPrinter
    # since built-in math functions are generic.
    # In CUDA, you have to differentiate between `sin` and `sinf`
    try:
        _print_math_func = CustomSympyPrinter._print_math_func
    except AttributeError:
        # Not every CustomSympyPrinter version defines it; keep the inherited one then.
        pass
    _print_Pow = CustomSympyPrinter._print_Pow

    def _format_call_arguments(self, expr):
        """Return the printed arguments of ``expr`` as a comma-separated argument list."""
        return ", ".join(self._print(a) for a in expr.args)

    def _print_Function(self, expr):
        # Fast approximations map onto OpenCL's native_* functions; the argument list must be
        # joined text, not a Python tuple repr (which would print e.g. "native_sqrt(('x',))").
        if isinstance(expr, fast_division):
            return f"native_divide({self._format_call_arguments(expr)})"
        elif isinstance(expr, fast_sqrt):
            return f"native_sqrt({self._format_call_arguments(expr)})"
        elif isinstance(expr, fast_inv_sqrt):
            return f"native_rsqrt({self._format_call_arguments(expr)})"
        return CustomSympyPrinter._print_Function(self, expr)