import pystencils.data_types
from pystencils.astnodes import Node
from pystencils.backends.cbackend import CustomSympyPrinter, generate_c
from pystencils.backends.cuda_backend import CudaBackend, CudaSympyPrinter


def generate_opencl(astnode: Node, signature_only: bool = False) -> str:
    """Prints an abstract syntax tree node as OpenCL code.

    Thin wrapper around ``generate_c`` that selects the ``'opencl'`` dialect.

    Args:
        astnode: KernelFunction node to generate code for
        signature_only: if True only the signature is printed

    Returns:
        OpenCL C code for the ast node and its descendants
    """
    return generate_c(astnode, signature_only, dialect='opencl')


class OpenClBackend(CudaBackend):
    """AST printer that emits OpenCL C by specializing the CUDA backend.

    CUDA-only constructs that are not translated here (thread-block
    synchronization, texture declarations) raise ``NotImplementedError``.
    """

    def __init__(self,
                 sympy_printer=None,
                 signature_only=False):
        # Fall back to the OpenCL expression printer when none is supplied.
        if sympy_printer is None:
            sympy_printer = OpenClSympyPrinter()

        super().__init__(sympy_printer, signature_only)
        # NOTE(review): presumably overrides the dialect set by CudaBackend;
        # assigned after super().__init__ so it wins.
        self._dialect = 'opencl'

    def _print_Type(self, node):
        """Print a type, prefixing pointer types with ``__global``.

        OpenCL requires pointer kernel arguments to carry an address-space
        qualifier; this backend places every pointer in global memory.
        """
        code = super()._print_Type(node)
        if isinstance(node, pystencils.data_types.PointerType):
            return "__global " + code
        return code

    def _print_ThreadBlockSynchronization(self, node):
        # TODO(review): OpenCL's counterpart would be a barrier(...) call;
        # not translated yet.
        raise NotImplementedError(
            "OpenClBackend does not support ThreadBlockSynchronization")

    def _print_TextureDeclaration(self, node):
        raise NotImplementedError(
            "OpenClBackend does not support TextureDeclaration")


class OpenClSympyPrinter(CudaSympyPrinter):
    """Sympy expression printer for OpenCL C.

    Translates CUDA-style thread-indexing symbols (e.g. ``threadIdx.x``) into
    the equivalent OpenCL work-item function calls (e.g. ``get_local_id(0)``).
    """
    language = "OpenCL"

    # CUDA dimension suffix -> argument of the OpenCL work-item function.
    DIMENSION_MAPPING = {
        'x': '0',
        'y': '1',
        'z': '2'
    }
    # CUDA built-in variable -> equivalent OpenCL work-item function.
    INDEXING_FUNCTION_MAPPING = {
        'blockIdx': 'get_group_id',
        'threadIdx': 'get_local_id',
        'blockDim': 'get_local_size',
        # CUDA's gridDim counts blocks (work-groups), so its OpenCL
        # counterpart is get_num_groups. get_global_size would instead return
        # the total number of work-items (gridDim * blockDim).
        'gridDim': 'get_num_groups'
    }

    def _print_ThreadIndexingSymbol(self, node):
        """Print a symbol such as ``blockIdx.y`` as ``get_group_id(1)``.

        Raises:
            ValueError: if the name is not exactly ``<builtin>.<dimension>``.
            KeyError: if the builtin or the dimension has no OpenCL mapping.
        """
        # NOTE(review): OpenCL work-item functions return size_t; confirm
        # whether a cast to the symbol's declared type is needed.
        function_name, dimension = node.name.split(".")
        dimension_arg = self.DIMENSION_MAPPING[dimension]
        opencl_function = self.INDEXING_FUNCTION_MAPPING[function_name]
        return f"{opencl_function}({dimension_arg})"

    def _print_TextureAccess(self, node):
        raise NotImplementedError(
            "OpenClSympyPrinter does not support TextureAccess")

    # Avoid usage of CUDA intrinsics: use the generic CustomSympyPrinter
    # function printing instead of the one inherited from CudaSympyPrinter.
    _print_Function = CustomSympyPrinter._print_Function