from os.path import dirname, join

import pystencils.data_types
from pystencils.astnodes import Node
from pystencils.backends.cbackend import CustomSympyPrinter, generate_c
from pystencils.backends.cuda_backend import CudaBackend, CudaSympyPrinter
from pystencils.fast_approximation import fast_division, fast_inv_sqrt, fast_sqrt

# Names of the OpenCL 1.1 built-in functions, read one per line from a text
# file that lives next to this module. Each name maps to itself, i.e. the
# function is printed under the same identifier it has on the sympy side.
with open(join(dirname(__file__), 'opencl1.1_known_functions.txt'), encoding='utf-8') as f:
    lines = f.readlines()
    # Strip *before* filtering: readlines() keeps the trailing newline, so a
    # blank line is truthy and would otherwise end up as an '' entry.
    OPENCL_KNOWN_FUNCTIONS = {name: name for name in (line.strip() for line in lines) if name}


def generate_opencl(ast_node: Node, signature_only: bool = False, custom_backend=None,
                    with_globals: bool = True) -> str:
    """Prints an abstract syntax tree node (made for target 'gpu') as OpenCL code.

    Thin wrapper around ``generate_c`` that selects the ``'opencl'`` dialect.

    Args:
        ast_node: ast representation of kernel
        signature_only: generate signature without function body
        custom_backend: use own custom printer for code generation
        with_globals: enable usage of global variables

    Returns:
        OpenCL code for the ast node and its descendants
    """
    return generate_c(ast_node, signature_only, dialect='opencl',
                      custom_backend=custom_backend, with_globals=with_globals)


class OpenClBackend(CudaBackend):
    """Backend that prints pystencils AST nodes as OpenCL C.

    Builds on the CUDA backend:
    - uses ``OpenClSympyPrinter`` for expressions unless another printer is given;
    - switches the dialect to ``'opencl'``;
    - qualifies pointer types with ``__global``;
    - rejects CUDA-only node types it cannot translate.
    """

    def __init__(self,
                 sympy_printer=None,
                 signature_only=False):
        # Only fall back to the default printer when none was passed; an
        # explicitly supplied printer object is always used as-is.
        if sympy_printer is None:
            sympy_printer = OpenClSympyPrinter()

        super().__init__(sympy_printer, signature_only)
        self._dialect = 'opencl'

    def _print_Type(self, node):
        """Print a type, prefixing pointer types with the ``__global`` qualifier."""
        code = super()._print_Type(node)
        # OpenCL requires an address-space qualifier on kernel pointer
        # arguments; every pointer is placed in global memory here.
        if isinstance(node, pystencils.data_types.PointerType):
            return "__global " + code
        return code

    def _print_ThreadBlockSynchronization(self, node):
        raise NotImplementedError("ThreadBlockSynchronization is not supported by the OpenCL backend")

    def _print_TextureDeclaration(self, node):
        raise NotImplementedError("TextureDeclaration is not supported by the OpenCL backend")


class OpenClSympyPrinter(CudaSympyPrinter):
    """Sympy expression printer emitting OpenCL C.

    CUDA thread-indexing symbols (``threadIdx.x``, ``blockDim.y``, ...) are
    translated to OpenCL work-item functions. Math functions are printed like
    the generic C printer does. The ``fast_*`` approximations map to OpenCL's
    ``native_*`` built-ins.
    """
    language = "OpenCL"

    # CUDA axis suffix -> OpenCL dimension index argument.
    DIMENSION_MAPPING = {
        'x': '0',
        'y': '1',
        'z': '2'
    }
    # CUDA built-in variable -> OpenCL work-item function.
    INDEXING_FUNCTION_MAPPING = {
        'blockIdx': 'get_group_id',
        'threadIdx': 'get_local_id',
        'blockDim': 'get_local_size',
        # CUDA's gridDim is the number of blocks per dimension, which is
        # get_num_groups in OpenCL. get_global_size would instead return the
        # total number of work-items (gridDim * blockDim).
        'gridDim': 'get_num_groups'
    }

    def __init__(self):
        # Calls CustomSympyPrinter.__init__ directly, bypassing
        # CudaSympyPrinter.__init__ (presumably to skip its CUDA-specific setup),
        # then installs the OpenCL built-in function table.
        CustomSympyPrinter.__init__(self)
        # Copy so that per-printer modifications cannot leak into the shared
        # module-level table.
        self.known_functions = dict(OPENCL_KNOWN_FUNCTIONS)

    def _print_Type(self, node):
        """Print a type, prefixing pointer types with the ``__global`` qualifier."""
        code = super()._print_Type(node)
        if isinstance(node, pystencils.data_types.PointerType):
            return "__global " + code
        return code

    def _print_ThreadIndexingSymbol(self, node):
        """Translate a CUDA indexing symbol such as ``threadIdx.x`` to ``get_local_id(0)``.

        The result is cast to ``int64_t``.
        """
        function_name, dimension = node.name.split(".")
        function_name = self.INDEXING_FUNCTION_MAPPING[function_name]
        dimension = self.DIMENSION_MAPPING[dimension]
        return f"(int64_t) {function_name}({dimension})"

    def _print_TextureAccess(self, node):
        raise NotImplementedError("TextureAccess is not supported by the OpenCL backend")

    # For math functions, OpenCL is more similar to the C++ printer CustomSympyPrinter
    # since built-in math functions are generic.
    # In CUDA, you have to differentiate between `sin` and `sinf`
    try:
        # Reuse the generic printer's math-function printing only if that
        # printer defines it; otherwise keep the inherited behaviour.
        _print_math_func = CustomSympyPrinter._print_math_func
    except AttributeError:
        pass
    _print_Pow = CustomSympyPrinter._print_Pow

    def _print_Function(self, expr):
        """Print ``fast_*`` approximations as OpenCL ``native_*`` built-ins.

        All other functions are delegated to the generic C printer.
        """
        printed_args = tuple(self._print(a) for a in expr.args)
        if isinstance(expr, fast_division):
            return "native_divide(%s, %s)" % printed_args
        elif isinstance(expr, fast_sqrt):
            # %-formatting inserts the single printed argument. The former
            # f-string embedded the tuple repr, e.g. "native_sqrt(('x',))".
            return "native_sqrt(%s)" % printed_args
        elif isinstance(expr, fast_inv_sqrt):
            return "native_rsqrt(%s)" % printed_args
        return CustomSympyPrinter._print_Function(self, expr)