from os.path import dirname, join

from pystencils.astnodes import Node
from pystencils.backends.cbackend import CBackend, CustomSympyPrinter, generate_c
from pystencils.fast_approximation import fast_division, fast_inv_sqrt, fast_sqrt
from pystencils.interpolation_astnodes import DiffInterpolatorAccess, InterpolationMode

# Load the names of CUDA library/builtin functions (one per line in
# cuda_known_functions.txt) that the printer may emit verbatim.
with open(join(dirname(__file__), 'cuda_known_functions.txt')) as f:
    lines = f.readlines()

# Map each known function name to itself, skipping blank lines.
# Bug fix: the previous filter ``if l`` was always true, because lines from
# readlines() keep their trailing '\n'; a blank line therefore produced a
# bogus {'': ''} entry.
CUDA_KNOWN_FUNCTIONS = {name.strip(): name.strip() for name in lines if name.strip()}
11
12
13
14
15
16


def generate_cuda(astnode: Node, signature_only: bool = False) -> str:
    """Return CUDA source code for an abstract syntax tree node.

    Thin convenience wrapper around :func:`generate_c` that selects the
    ``'cuda'`` dialect.

    Args:
        astnode: KernelFunction node to generate code for.
        signature_only: when True, emit only the kernel signature.

    Returns:
        C-like CUDA code for the AST node and its descendants.
    """
    return generate_c(astnode, signature_only, dialect='cuda')


class CudaBackend(CBackend):
    """C-code printer emitting CUDA-specific constructs.

    Extends :class:`CBackend` with printers for shared-memory allocations,
    thread-block synchronization, texture declarations and iteration
    skipping, all rendered in the ``'cuda'`` dialect.
    """

    def __init__(self, sympy_printer=None, signature_only=False):
        # Default to the CUDA-aware expression printer so CUDA intrinsics
        # and known CUDA library functions are printed correctly.
        if not sympy_printer:
            sympy_printer = CudaSympyPrinter()

        super().__init__(sympy_printer, signature_only, dialect='cuda')

    def _print_SharedMemoryAllocation(self, node):
        # Declare a statically sized __shared__ array whose element count is
        # the product of the shared-memory field's shape.
        code = "__shared__ {dtype} {name}[{num_elements}];"
        return code.format(dtype=node.symbol.dtype,
                           name=self.sympy_printer.doprint(node.symbol.name),
                           num_elements='*'.join(str(s) for s in node.shared_mem.shape))

    @staticmethod
    def _print_ThreadBlockSynchronization(node):
        # Bug fix: the CUDA barrier intrinsic is __syncthreads();
        # the previous spelling "__synchtreads" would not compile.
        return "__syncthreads();"

    def _print_TextureDeclaration(self, node):
        # For element types wider than 4 bytes use the fp_tex_* wrapper
        # types — presumably PyCUDA's workaround for missing 64-bit texture
        # support (see pycuda-helpers.hpp); TODO confirm against the
        # texture-binding code.
        if node.texture.field.dtype.numpy_dtype.itemsize > 4:
            code = "texture<fp_tex_%s, cudaTextureType%iD, cudaReadModeElementType> %s;" % (
                str(node.texture.field.dtype),
                node.texture.field.spatial_dimensions,
                node.texture
            )
        else:
            code = "texture<%s, cudaTextureType%iD, cudaReadModeElementType> %s;" % (
                str(node.texture.field.dtype),
                node.texture.field.spatial_dimensions,
                node.texture
            )
        return code

    def _print_SkipIteration(self, _):
        # In a CUDA kernel, skipping the current iteration means returning
        # from the kernel for this thread.
        return "return;"


class CudaSympyPrinter(CustomSympyPrinter):
    """Sympy expression printer for the CUDA dialect.

    Knows CUDA library functions, texture/interpolator accesses and the
    fast-math intrinsics (``__fdividef``, ``__fsqrt_rn``, ``__frsqrt_rn``).
    """

    language = "CUDA"

    def __init__(self):
        super(CudaSympyPrinter, self).__init__()
        # Allow CUDA library functions (from cuda_known_functions.txt) to be
        # printed by name instead of being treated as unknown.
        self.known_functions.update(CUDA_KNOWN_FUNCTIONS)

    def _print_InterpolatorAccess(self, node):
        """Print a texture/interpolator access as the matching CUDA texture call."""
        dtype = node.interpolator.field.dtype.numpy_dtype

        if isinstance(node, DiffInterpolatorAccess):
            # cubicTex3D_1st_derivative_x(texture tex, float3 coord)
            # Coordinate order is reversed (see offsets below), so the
            # derivative axis letter is picked from the reversed axis list.
            template = f"cubicTex%iD_1st_derivative_{list(reversed('xyz'[:node.ndim]))[node.diff_coordinate_idx]}(%s, %s)"  # noqa
        elif node.interpolator.interpolation_mode == InterpolationMode.CUBIC_SPLINE:
            template = "cubicTex%iDSimple(%s, %s)"
        else:
            if dtype.itemsize > 4:
                # Use PyCuda hack!
                # https://github.com/inducer/pycuda/blob/master/pycuda/cuda/pycuda-helpers.hpp
                template = "fp_tex%iD(%s, %s)"
            else:
                template = "tex%iD(%s, %s)"

        code = template % (
            node.interpolator.field.spatial_dimensions,
            str(node.interpolator),
            # + 0.5 comes from Nvidia's staggered indexing
            ', '.join(self._print(o + 0.5) for o in reversed(node.offsets))
        )
        return code

    def _print_Function(self, expr):
        """Print fast-math approximation nodes as CUDA intrinsics."""
        if isinstance(expr, fast_division):
            return "__fdividef(%s, %s)" % tuple(self._print(a) for a in expr.args)
        elif isinstance(expr, fast_sqrt):
            # Bug fix: the previous f-string embedded the repr of a tuple,
            # emitting e.g. __fsqrt_rn(('x',)) instead of __fsqrt_rn(x).
            return f"__fsqrt_rn({self._print(expr.args[0])})"
        elif isinstance(expr, fast_inv_sqrt):
            # Same tuple-repr bug fixed here.
            return f"__frsqrt_rn({self._print(expr.args[0])})"
        return super()._print_Function(expr)