opencl_backend.py 3.21 KB
Newer Older
1
2
from os.path import dirname, join

3
import pystencils.data_types
Stephan Seitz's avatar
Lint    
Stephan Seitz committed
4
from pystencils.astnodes import Node
5
from pystencils.backends.cbackend import CustomSympyPrinter, generate_c
Stephan Seitz's avatar
Lint    
Stephan Seitz committed
6
from pystencils.backends.cuda_backend import CudaBackend, CudaSympyPrinter
7
from pystencils.fast_approximation import fast_division, fast_inv_sqrt, fast_sqrt
Stephan Seitz's avatar
Lint    
Stephan Seitz committed
8

9
10
11
12
# Load the names of the OpenCL built-in functions shipped next to this module.
# The file is expected to hold one function name per line; each name is mapped
# to itself so the printer emits it unchanged.
with open(join(dirname(__file__), 'opencl1.1_known_functions.txt')) as f:
    lines = f.readlines()
    # Filter on the *stripped* name: raw lines from readlines() are never empty
    # (a blank line is "\n"), so testing the raw line would let blank lines
    # register an empty-string function name.
    OPENCL_KNOWN_FUNCTIONS = {
        name: name
        for name in (line.strip() for line in lines)
        if name
    }

Stephan Seitz's avatar
Stephan Seitz committed
13
14

def generate_opencl(astnode: Node, signature_only: bool = False) -> str:
    """Render an abstract syntax tree node (built for target 'gpu') as OpenCL source.

    This delegates to ``generate_c`` with the ``'opencl'`` dialect selected.

    Args:
        astnode: KernelFunction node to generate code for
        signature_only: if True only the signature is printed

    Returns:
        C-like code for the ast node and its descendants
    """
    opencl_code = generate_c(astnode, signature_only, dialect='opencl')
    return opencl_code


27
28
class OpenClBackend(CudaBackend):
    """Code-printing backend for the 'opencl' dialect, derived from the CUDA backend.

    Pointer types are printed with a leading ``__global`` qualifier. Thread-block
    synchronization and texture declarations are not implemented.
    """

    def __init__(self,
                 sympy_printer=None,
                 signature_only=False):
        """Set up the backend.

        Args:
            sympy_printer: expression printer to use; a fresh ``OpenClSympyPrinter``
                is created when a falsy value (e.g. ``None``) is passed
            signature_only: forwarded unchanged to the parent backend
        """
        printer = sympy_printer or OpenClSympyPrinter()
        super().__init__(printer, signature_only)
        # Set after the parent constructor so its value is not overwritten.
        self._dialect = 'opencl'

    def _print_Type(self, node):
        """Print a type as the parent does, prefixing pointer types with ``__global``."""
        printed = super()._print_Type(node)
        if not isinstance(node, pystencils.data_types.PointerType):
            return printed
        return "__global " + printed

    def _print_ThreadBlockSynchronization(self, node):
        """Not implemented for OpenCL."""
        raise NotImplementedError()

    def _print_TextureDeclaration(self, node):
        """Not implemented for OpenCL."""
        raise NotImplementedError()

51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66

class OpenClSympyPrinter(CudaSympyPrinter):
    """Sympy expression printer that emits OpenCL C.

    CUDA-style thread indexing symbols (e.g. ``blockIdx.x``) are translated to
    the corresponding OpenCL work-item function calls (e.g. ``get_group_id(0)``).
    Math functions and powers are printed like the generic C printer, and the
    fast-approximation functions are printed as OpenCL ``native_*`` built-ins.
    Texture accesses are not implemented.
    """
    language = "OpenCL"

    # CUDA component name -> OpenCL dimension index passed to the work-item function.
    DIMENSION_MAPPING = {
        'x': '0',
        'y': '1',
        'z': '2'
    }
    # CUDA built-in variable -> OpenCL work-item function.
    # ``gridDim`` is the number of blocks in the grid, i.e. the number of
    # work-groups (``get_num_groups``). ``get_global_size`` would instead return
    # the total number of work-items (gridDim * blockDim).
    INDEXING_FUNCTION_MAPPING = {
        'blockIdx': 'get_group_id',
        'threadIdx': 'get_local_id',
        'blockDim': 'get_local_size',
        'gridDim': 'get_num_groups'
    }

    def __init__(self):
        # NOTE(review): calls CustomSympyPrinter.__init__ directly, bypassing
        # CudaSympyPrinter.__init__ -- presumably to skip CUDA-specific setup; confirm.
        CustomSympyPrinter.__init__(self)
        self.known_functions = OPENCL_KNOWN_FUNCTIONS

    def _print_ThreadIndexingSymbol(self, node):
        """Translate a ``<cuda_variable>.<x|y|z>`` symbol into an OpenCL call.

        Args:
            node: symbol whose ``name`` has the form ``"blockIdx.x"``

        Returns:
            The OpenCL call string, e.g. ``"get_group_id(0)"``

        Raises:
            KeyError: if the variable or the dimension is not in the mapping tables
            ValueError: if the name does not consist of exactly two dot-separated parts
        """
        symbol_name: str = node.name
        function_name, dimension = tuple(symbol_name.split("."))
        dimension = self.DIMENSION_MAPPING[dimension]
        function_name = self.INDEXING_FUNCTION_MAPPING[function_name]
        return f"{function_name}({dimension})"

    def _print_TextureAccess(self, node):
        """Not implemented for OpenCL."""
        raise NotImplementedError()

    # For math functions, OpenCL is more similar to the C++ printer CustomSympyPrinter
    # since built-in math functions are generic.
    # In CUDA, you have to differentiate between `sin` and `sinf`
    _print_math_func = CustomSympyPrinter._print_math_func
    _print_Pow = CustomSympyPrinter._print_Pow

    def _print_Function(self, expr):
        """Print fast-approximation functions as OpenCL ``native_*`` built-ins.

        Any other function is delegated to ``CustomSympyPrinter._print_Function``.
        """
        if isinstance(expr, fast_division):
            return "native_divide(%s, %s)" % tuple(self._print(a) for a in expr.args)
        elif isinstance(expr, fast_sqrt):
            return "native_sqrt(%s)" % tuple(self._print(a) for a in expr.args)
        elif isinstance(expr, fast_inv_sqrt):
            return "native_rsqrt(%s)" % tuple(self._print(a) for a in expr.args)
        return CustomSympyPrinter._print_Function(self, expr)