cbackend.py 4.92 KB
Newer Older
1
import textwrap
2
3
4
5
from sympy.utilities.codegen import CCodePrinter
from pystencils.ast import Node


6
def generateC(astNode):
    """Render the abstract syntax tree ``astNode`` as a plain C function.

    :param astNode: root node of the AST to print
    :return: the generated C source code as a string
    """
    return CBackend(cuda=False)(astNode)


14
def generateCUDA(astNode):
    """Render the abstract syntax tree ``astNode`` as a CUDA kernel.

    Identical to the plain C output except that kernel functions are
    declared ``__global__``.

    :param astNode: root node of the AST to print
    :return: the generated CUDA source code as a string
    """
    return CBackend(cuda=True)(astNode)

# --------------------------------------- Backend Specific Nodes -------------------------------------------------------


class CustomCppCode(Node):
    """AST node holding a verbatim snippet of C/C++ code.

    The node cannot analyse the snippet itself, so the caller declares which
    symbols the snippet reads and which symbols it defines.

    :param code: C/C++ source text; it is stored with a newline prepended
    :param symbolsRead: iterable of symbols the snippet reads
    :param symbolsDefined: iterable of symbols the snippet defines
    """

    def __init__(self, code, symbolsRead, symbolsDefined):
        self._code = "\n" + code
        self._symbolsRead = set(symbolsRead)
        self._symbolsDefined = set(symbolsDefined)

    @property
    def code(self):
        """The stored code text (the original snippet with a leading newline)."""
        return self._code

    @property
    def args(self):
        """Child nodes: always an empty list, the snippet is opaque."""
        return []

    @property
    def symbolsDefined(self):
        """Set of symbols declared as defined by the snippet."""
        return self._symbolsDefined

    @property
    def undefinedSymbols(self):
        """Symbols the snippet reads but does not define itself.

        Previously this computed ``defined - read`` (symbols defined but never
        read), which is the wrong set; undefined symbols are ``read - defined``.
        """
        return self._symbolsRead - self._symbolsDefined


class PrintNode(CustomCppCode):
    """Custom code node that writes ``<name>  =  <value>`` to ``std::cout``.

    :param symbolToPrint: symbol whose ``name`` is used both as the printed
        label and as the C++ expression whose value is printed
    """

    def __init__(self, symbolToPrint):
        name = symbolToPrint.name
        snippet = '\nstd::cout << "{0}  =  " << {0} << std::endl; \n'.format(name)
        # the snippet only reads the printed symbol and defines nothing
        super(PrintNode, self).__init__(snippet, symbolsRead=[symbolToPrint], symbolsDefined=set())


# ------------------------------------------- Printer ------------------------------------------------------------------


class CBackend:
    """Prints an AST as C (or CUDA) source code.

    Dispatch is by class name: for a node the printer walks the node type's
    MRO and calls the first ``_print_<ClassName>`` method it finds, so node
    subclasses fall back to the printer of their nearest known base class.
    Sympy expressions embedded in nodes are printed with ``CustomSympyPrinter``.

    :param cuda: if True, kernel functions are declared ``__global__``
    """

    def __init__(self, cuda=False):
        self.cuda = cuda
        self.sympyPrinter = CustomSympyPrinter()
        self._indent = "   "  # prefix added to every line inside a block

    def __call__(self, node):
        """Return the generated code for ``node`` as a string."""
        return str(self._print(node))

    def _print(self, node):
        """Dispatch ``node`` to the matching ``_print_<ClassName>`` method.

        :raises NotImplementedError: if no class in the node's MRO has a printer
        """
        for cls in type(node).__mro__:
            methodName = "_print_" + cls.__name__
            if hasattr(self, methodName):
                return getattr(self, methodName)(node)
        # Report the node's own type; after the loop ``cls`` is always ``object``.
        raise NotImplementedError("CBackend does not support node of type " + type(node).__name__)

    def _print_KernelFunction(self, node):
        """Function signature line followed by the printed body."""
        functionArguments = ["%s %s" % (s.dtype, s.name) for s in node.parameters]
        prefix = "__global__ void" if self.cuda else "void"
        funcDeclaration = "%s %s(%s)" % (prefix, node.functionName, ", ".join(functionArguments))
        body = self._print(node.body)
        return funcDeclaration + "\n" + body

    def _print_Block(self, node):
        """Children joined by newlines, indented and wrapped in braces."""
        blockContents = "\n".join([self._print(child) for child in node.args])
        return "{\n%s\n}" % (textwrap.indent(blockContents, self._indent))

    def _print_PragmaBlock(self, node):
        """The node's pragma line followed by the block printed as a normal Block."""
        return "%s\n%s" % (node.pragmaLine, self._print_Block(node))

    def _print_LoopOverCoordinate(self, node):
        """Optional prefix lines, then a C ``for`` loop header, then the body."""
        counterVar = node.loopCounterName
        start = "int %s = %s" % (counterVar, self.sympyPrinter.doprint(node.start))
        condition = "%s < %s" % (counterVar, self.sympyPrinter.doprint(node.stop))
        update = "%s += %s" % (counterVar, self.sympyPrinter.doprint(node.step))
        loopStr = "for (%s; %s; %s)" % (start, condition, update)

        prefix = "\n".join(node.prefixLines)
        if prefix:
            prefix += "\n"
        return "%s%s\n%s" % (prefix, loopStr, self._print(node.body))

    def _print_SympyAssignment(self, node):
        """Assignment ``lhs = rhs;``, with ``[const ]<dtype> `` for declarations."""
        dtype = ""
        if node.isDeclaration:
            # "%s" formatting (as in _print_KernelFunction) accepts non-str dtypes
            dtype = "%s " % (node.lhs.dtype,)
            if node.isConst:
                dtype = "const " + dtype
        # dtype already carries its trailing space, so no separator here
        return "%s%s = %s;" % (dtype, self.sympyPrinter.doprint(node.lhs), self.sympyPrinter.doprint(node.rhs))

    def _print_TemporaryMemoryAllocation(self, node):
        """Heap allocation of ``node.size`` elements via ``new[]``."""
        return "%s * %s = new %s[%s];" % (node.symbol.dtype, self.sympyPrinter.doprint(node.symbol),
                                         node.symbol.dtype, self.sympyPrinter.doprint(node.size))

    def _print_TemporaryMemoryFree(self, node):
        """``delete []`` matching a TemporaryMemoryAllocation."""
        return "delete [] %s;" % (self.sympyPrinter.doprint(node.symbol),)

    def _print_CustomCppCode(self, node):
        """Emit the node's code text unchanged."""
        return node.code


# ------------------------------------------ Helper function & classes -------------------------------------------------


class CustomSympyPrinter(CCodePrinter):
    """Sympy C code printer with tweaks for generated kernel code."""

    def _print_Pow(self, expr):
        """Don't use std::pow function, for small integer exponents, write as multiplication"""
        if expr.exp.is_integer and expr.exp.is_number and 0 < expr.exp < 8:
            factor = "(" + self._print(expr.base) + ")"
            # int(): repeat the list by a plain Python int, not a sympy number
            return '(' + '*'.join([factor] * int(expr.exp)) + ')'
        return super(CustomSympyPrinter, self)._print_Pow(expr)

    def _print_Rational(self, expr):
        """Evaluate all rationals i.e. print 0.25 instead of 1.0/4.0"""
        return str(expr.evalf().num)

    def _print_Equality(self, expr):
        """Equality operator is not printable in default printer"""
        return '((' + self._print(expr.lhs) + ") == (" + self._print(expr.rhs) + '))'

    def _print_Piecewise(self, expr):
        """Print piecewise in one line (remove newlines)"""
        result = super(CustomSympyPrinter, self)._print_Piecewise(expr)
        return result.replace("\n", "")