cbackend.py 6.35 KB
Newer Older
Martin Bauer's avatar
Martin Bauer committed
1
import sympy as sp
2
from sympy.utilities.codegen import CCodePrinter
Martin Bauer's avatar
Martin Bauer committed
3
from pystencils.astnodes import Node
4
from pystencils.types import createType, PointerType
5
6


7
def generateC(astNode, signatureOnly=False):
    """
    Prints the abstract syntax tree as C function

    :param astNode: kernel AST to print; the dtypes of its ``fieldsAccessed`` decide the precision of printed literals
    :param signatureOnly: if True, kernel functions are printed as declaration only, without body
    :return: the generated C code as string
    """
    accessedDtypes = {field.dtype for field in astNode.fieldsAccessed}
    # constants are printed as single precision literals ('f' suffix) unless some accessed field is double
    floatLiterals = createType("double") not in accessedDtypes
    return CBackend(constantsAsFloats=floatLiterals, signatureOnly=signatureOnly)(astNode)


# --------------------------------------- Backend Specific Nodes -------------------------------------------------------


class CustomCppCode(Node):
    """
    AST node containing a verbatim snippet of C/C++ code.

    The snippet is opaque to the AST, so the symbols it reads and defines have to be passed in explicitly.
    """

    def __init__(self, code, symbolsRead, symbolsDefined):
        """
        :param code: C/C++ source text, stored with a leading newline and printed as-is by the backend
        :param symbolsRead: iterable of symbols the snippet reads
        :param symbolsDefined: iterable of symbols the snippet defines
        """
        self._code = "\n" + code
        self._symbolsRead = set(symbolsRead)
        self._symbolsDefined = set(symbolsDefined)

    @property
    def code(self):
        return self._code

    @property
    def args(self):
        """Custom code has no child nodes"""
        return []

    @property
    def symbolsDefined(self):
        return self._symbolsDefined

    @property
    def undefinedSymbols(self):
        """Symbols read by the snippet that it does not define itself"""
        # previously the operands were swapped (defined - read), which hid every symbol the snippet depends on
        return self._symbolsRead - self._symbolsDefined


class PrintNode(CustomCppCode):
    """Custom code node that writes the name and current value of a symbol to std::cout"""

    def __init__(self, symbolToPrint):
        symbolName = symbolToPrint.name
        snippet = '\nstd::cout << "%s  =  " << %s << std::endl; \n' % (symbolName, symbolName)
        super(PrintNode, self).__init__(snippet, symbolsRead=[symbolToPrint], symbolsDefined=set())


# ------------------------------------------- Printer ------------------------------------------------------------------


Michael Kuron's avatar
Michael Kuron committed
52
class CBackend(object):
    """
    Prints pystencils AST nodes as C code.

    Nodes are dispatched by class name: a node of class ``X`` is printed by ``_print_X``; if no such method exists,
    the base classes of the node's type are tried in MRO order. Sympy expressions inside the nodes are printed
    with ``self.sympyPrinter.doprint``.
    """

    def __init__(self, constantsAsFloats=False, sympyPrinter=None, signatureOnly=False):
        """
        :param constantsAsFloats: forwarded to the default CustomSympyPrinter; ignored if sympyPrinter is given
        :param sympyPrinter: object with a ``doprint(expr)`` method used for all sympy expressions,
                             defaults to ``CustomSympyPrinter(constantsAsFloats)``
        :param signatureOnly: if True, kernel functions are printed as declaration only, without body
        """
        if sympyPrinter is None:
            self.sympyPrinter = CustomSympyPrinter(constantsAsFloats)
        else:
            self.sympyPrinter = sympyPrinter
        self._indent = "   "
        self._signatureOnly = signatureOnly

    def __call__(self, node):
        """Returns the C code for the given node as string"""
        return str(self._print(node))

    def _print(self, node):
        for cls in type(node).__mro__:
            methodName = "_print_" + cls.__name__
            if hasattr(self, methodName):
                return getattr(self, methodName)(node)
        # report the node's own type - after the loop 'cls' is always 'object'
        raise NotImplementedError("CBackend does not support node of type " + type(node).__name__)

    def _print_KernelFunction(self, node):
        functionArguments = ["%s %s" % (str(s.dtype), s.name) for s in node.parameters]
        funcDeclaration = "FUNC_PREFIX void %s(%s)" % (node.functionName, ", ".join(functionArguments))
        if self._signatureOnly:
            return funcDeclaration

        body = self._print(node.body)
        return funcDeclaration + "\n" + body

    def _print_Block(self, node):
        blockContents = "\n".join([self._print(child) for child in node.args])
        # prefix every line of the block contents with one indentation level
        return "{\n%s\n}" % (self._indent + self._indent.join(blockContents.splitlines(True)))

    def _print_PragmaBlock(self, node):
        return "%s\n%s" % (node.pragmaLine, self._print_Block(node))

    def _print_LoopOverCoordinate(self, node):
        counterVar = node.loopCounterName
        start = "int %s = %s" % (counterVar, self.sympyPrinter.doprint(node.start))
        condition = "%s < %s" % (counterVar, self.sympyPrinter.doprint(node.stop))
        update = "%s += %s" % (counterVar, self.sympyPrinter.doprint(node.step))
        loopStr = "for (%s; %s; %s)" % (start, condition, update)

        prefix = "\n".join(node.prefixLines)
        if prefix:
            prefix += "\n"
        return "%s%s\n%s" % (prefix, loopStr, self._print(node.body))

    def _print_SympyAssignment(self, node):
        typePrefix = ""
        if node.isDeclaration:
            typePrefix = str(node.lhs.dtype) + " "
            if node.isConst:
                typePrefix = "const " + typePrefix
        # typePrefix already ends with a space when non-empty; an extra separator produced
        # "double  x = ..." for declarations and " x = ..." for plain assignments
        lhs = self.sympyPrinter.doprint(node.lhs)
        rhs = self.sympyPrinter.doprint(node.rhs)
        return "%s%s = %s;" % (typePrefix, lhs, rhs)

    def _print_TemporaryMemoryAllocation(self, node):
        return "%s * %s = new %s[%s];" % (node.symbol.dtype, self.sympyPrinter.doprint(node.symbol),
                                         node.symbol.dtype, self.sympyPrinter.doprint(node.size))

    def _print_TemporaryMemoryFree(self, node):
        return "delete [] %s;" % (self.sympyPrinter.doprint(node.symbol),)

    def _print_CustomCppCode(self, node):
        return node.code

    def _print_Conditional(self, node):
        conditionExpr = self.sympyPrinter.doprint(node.conditionExpr)
        trueBlock = self._print_Block(node.trueBlock)
        result = "if (%s) \n %s " % (conditionExpr, trueBlock)
        if node.falseBlock:
            falseBlock = self._print_Block(node.falseBlock)
            result += "else " + falseBlock
        return result

# ------------------------------------------ Helper function & classes -------------------------------------------------


class CustomSympyPrinter(CCodePrinter):
    """
    Sympy C code printer with adaptions for pystencils kernels: small integer powers are written as
    multiplications, rationals are printed as evaluated decimals, and numeric constants optionally get
    a single precision 'f' suffix.
    """

    def __init__(self, constantsAsFloats=False):
        """
        :param constantsAsFloats: if True, Rational and Float constants are printed with an 'f' suffix
        """
        self._constantsAsFloats = constantsAsFloats
        super(CustomSympyPrinter, self).__init__()

    def _withFloatSuffix(self, literal):
        """Appends 'f' to the numeric literal string if constantsAsFloats is set"""
        return literal + "f" if self._constantsAsFloats else literal

    def _print_Pow(self, expr):
        """Don't use std::pow function, for small integer exponents, write as multiplication"""
        if expr.exp.is_integer and expr.exp.is_number and 0 < expr.exp < 8:
            return self._print(sp.Mul(*[expr.base] * expr.exp, evaluate=False))
        else:
            return super(CustomSympyPrinter, self)._print_Pow(expr)

    def _print_Rational(self, expr):
        """Evaluate all rationals i.e. print 0.25 instead of 1.0/4.0"""
        return self._withFloatSuffix(str(expr.evalf().num))

    def _print_Equality(self, expr):
        """Equality operator is not printable in default printer"""
        return '((' + self._print(expr.lhs) + ") == (" + self._print(expr.rhs) + '))'

    def _print_Piecewise(self, expr):
        """Print piecewise in one line (remove newlines)"""
        result = super(CustomSympyPrinter, self)._print_Piecewise(expr)
        return result.replace("\n", "")

    def _print_Float(self, expr):
        """Print float constants, with 'f' suffix if constantsAsFloats is set"""
        return self._withFloatSuffix(str(expr))

    def _print_Function(self, expr):
        """Print the special 'cast' function as pointer cast + dereference, other functions with the default printer"""
        name = str(expr.func).lower()
        if name == 'cast':
            # 'targetType' instead of 'type' to avoid shadowing the builtin
            arg, targetType = expr.args
            # take the address of 'arg', cast it to a pointer of the target type and dereference it
            return "*((%s)(& %s))" % (PointerType(targetType), self._print(arg))
        else:
            return super(CustomSympyPrinter, self)._print_Function(expr)