cuda_backend.py
from os.path import dirname, join

from pystencils.astnodes import Node
from pystencils.backends.cbackend import CBackend, CustomSympyPrinter, generate_c
from pystencils.fast_approximation import fast_division, fast_inv_sqrt, fast_sqrt
from pystencils.interpolation_astnodes import DiffInterpolatorAccess, InterpolationMode

with open(join(dirname(__file__), 'cuda_known_functions.txt')) as f:
    lines = f.readlines()
    CUDA_KNOWN_FUNCTIONS = {l.strip(): l.strip() for l in lines if l.strip()}


def generate_cuda(astnode: Node, signature_only: bool = False) -> str:
    """Prints an abstract syntax tree node as CUDA code.

    Args:
        astnode: KernelFunction node to generate code for
        signature_only: if True only the signature is printed

    Returns:
        C-like code for the ast node and its descendants
    """
    return generate_c(astnode, signature_only, dialect='cuda')


class CudaBackend(CBackend):

    def __init__(self, sympy_printer=None,
                 signature_only=False):
        if not sympy_printer:
            sympy_printer = CudaSympyPrinter()

        super().__init__(sympy_printer, signature_only, dialect='cuda')

    def _print_SharedMemoryAllocation(self, node):
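        # Emits a statically sized __shared__ array declaration, e.g.
        # "__shared__ double sharedBuffer[16*16];" (name and sizes shown are
        # illustrative; the real ones come from node.symbol and node.shared_mem.shape).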
        dtype = node.symbol.dtype
        name = self.sympy_printer.doprint(node.symbol.name)
        num_elements = '*'.join([str(s) for s in node.shared_mem.shape])
        code = f"__shared__ {dtype} {name}[{num_elements}];"
        return code

    @staticmethod
    def _print_ThreadBlockSynchronization(node):
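        # Emits a block-wide barrier: __syncthreads() waits until every thread of the
        # block has reached this point and its shared/global memory writes are visible.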
        code = "__syncthreads();"
        return code

    def _print_TextureDeclaration(self, node):
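        # Emits a CUDA texture reference declaration for the texture object, e.g.
        # "texture<float, cudaTextureType2D, cudaReadModeElementType> tex_src;"
        # (the identifier "tex_src" is illustrative; the real name comes from node.texture).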
        if node.texture.field.dtype.numpy_dtype.itemsize > 4:
            code = (f"texture<fp_tex_{node.texture.field.dtype}, "
                    f"cudaTextureType{node.texture.field.spatial_dimensions}D, "
                    f"cudaReadModeElementType> {node.texture};")
        else:
            code = (f"texture<{node.texture.field.dtype}, "
                    f"cudaTextureType{node.texture.field.spatial_dimensions}D, "
                    f"cudaReadModeElementType> {node.texture};")
        return code

    def _print_SkipIteration(self, _):
        return "return;"


class CudaSympyPrinter(CustomSympyPrinter):
    language = "CUDA"

    def __init__(self):
        super().__init__()
        self.known_functions.update(CUDA_KNOWN_FUNCTIONS)

    def _print_InterpolatorAccess(self, node):
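        # Maps an interpolator access onto a CUDA texture fetch: plain tex1D/2D/3D
        # lookups, the cubicTex*D helpers for cubic spline interpolation and its
        # derivatives, and PyCuda's fp_tex*D wrappers for element types wider than 32 bit.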
        dtype = node.interpolator.field.dtype.numpy_dtype

        if type(node) == DiffInterpolatorAccess:
            # cubicTex3D_1st_derivative_x(texture tex, float3 coord)
            template = f"cubicTex%iD_1st_derivative_{list(reversed('xyz'[:node.ndim]))[node.diff_coordinate_idx]}(%s, %s)"  # noqa
        elif node.interpolator.interpolation_mode == InterpolationMode.CUBIC_SPLINE:
            template = "cubicTex%iDSimple(%s, %s)"
        else:
            if dtype.itemsize > 4:
                # Use PyCuda hack!
                # https://github.com/inducer/pycuda/blob/master/pycuda/cuda/pycuda-helpers.hpp
                template = "fp_tex%iD(%s, %s)"
            else:
                template = "tex%iD(%s, %s)"

        code = template % (
            node.interpolator.field.spatial_dimensions,
            str(node.interpolator),
            # + 0.5 comes from Nvidia's staggered indexing
            ', '.join(self._print(o + 0.5) for o in reversed(node.offsets))
        )
        return code

    def _print_Function(self, expr):
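        # pystencils' fast approximation nodes map to the CUDA device intrinsics
        # __fdividef, __fsqrt_rn and __frsqrt_rn; all other functions are handled
        # by the parent CustomSympyPrinter.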
        if isinstance(expr, fast_division):
            assert len(expr.args) == 2, f"__fdividef has two arguments, but {len(expr.args)} were given"
            return f"__fdividef({self._print(expr.args[0])}, {self._print(expr.args[1])})"
        elif isinstance(expr, fast_sqrt):
            assert len(expr.args) == 1, f"__fsqrt_rn has one argument, but {len(expr.args)} were given"
            return f"__fsqrt_rn({self._print(expr.args[0])})"
        elif isinstance(expr, fast_inv_sqrt):
            assert len(expr.args) == 1, f"__frsqrt_rn has one argument, but {len(expr.args)} were given"
            return f"__frsqrt_rn({self._print(expr.args[0])})"
        return super()._print_Function(expr)
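

if __name__ == '__main__':
    # Minimal usage sketch, not part of the backend itself: it assumes pystencils'
    # public ``fields``, ``Assignment`` and ``create_kernel(..., target='gpu')`` API
    # (and may additionally need pycuda installed for the GPU target). The field
    # names and the averaging update rule are purely illustrative.
    import pystencils as ps

    src, dst = ps.fields("src, dst: float64[2D]")
    update = ps.Assignment(dst.center, (src[1, 0] + src[-1, 0] + src[0, 1] + src[0, -1]) / 4)
    ast = ps.create_kernel([update], target='gpu')
    print(generate_cuda(ast))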