cbackend.py 6.16 KB
Newer Older
1
from sympy.utilities.codegen import CCodePrinter
Martin Bauer's avatar
Martin Bauer committed
2
from pystencils.astnodes import Node
3
from pystencils.types import createType, PointerType
4
5


6
def generateC(astNode):
    """
    Prints the abstract syntax tree as C function

    Constants are printed in single precision (``constantsAsFloats=True``) unless at least
    one accessed field has dtype ``double``. A kernel that accesses no fields at all
    therefore also gets single-precision constants.

    :param astNode: kernel AST root; must provide ``fieldsAccessed`` (objects with a ``dtype``)
    :return: the generated C code as a string
    """
    fieldTypes = {f.dtype for f in astNode.fieldsAccessed}
    useFloatConstants = createType("double") not in fieldTypes
    printer = CBackend(constantsAsFloats=useFloatConstants)
    return printer(astNode)


# --------------------------------------- Backend Specific Nodes -------------------------------------------------------


class CustomCppCode(Node):
    """AST node wrapping a verbatim C/C++ code snippet.

    The snippet is opaque to the AST, so its creator declares which symbols it reads
    and which it defines.

    :param code: C/C++ source text; stored with a leading newline prepended
    :param symbolsRead: iterable of symbols the snippet uses
    :param symbolsDefined: iterable of symbols the snippet defines
    """

    def __init__(self, code, symbolsRead, symbolsDefined):
        self._code = "\n" + code
        self._symbolsRead = set(symbolsRead)
        self._symbolsDefined = set(symbolsDefined)

    @property
    def code(self):
        """The snippet text, including the leading newline added in ``__init__``."""
        return self._code

    @property
    def args(self):
        """Child nodes -- always empty, the snippet is not decomposed further."""
        return []

    @property
    def symbolsDefined(self):
        """Set of symbols defined by the snippet."""
        return self._symbolsDefined

    @property
    def undefinedSymbols(self):
        """Symbols the snippet reads but does not define itself.

        Presumably these have to be provided by code preceding the snippet.
        """
        # read minus defined (the previous ``defined - read`` was inverted and hid the
        # snippet's dependencies on externally defined symbols)
        return self._symbolsRead - self._symbolsDefined
40
41
42
43
44
45
46
47
48
49
50


class PrintNode(CustomCppCode):
    """Debug node emitting a C++ statement that writes ``<name>  =  <value>`` to ``std::cout``.

    The printed symbol is registered as read; the node defines no symbols.
    """

    def __init__(self, symbolToPrint):
        symbolName = symbolToPrint.name
        template = '\nstd::cout << "%s  =  " << %s << std::endl; \n'
        super(PrintNode, self).__init__(template % (symbolName, symbolName),
                                        symbolsRead=[symbolToPrint],
                                        symbolsDefined=set())


# ------------------------------------------- Printer ------------------------------------------------------------------


Michael Kuron's avatar
Michael Kuron committed
51
class CBackend(object):
    """Prints a kernel AST as C source code.

    Each node is printed by a ``_print_<ClassName>`` method; ``_print`` walks the node's
    MRO so subclasses fall back to the printer of their base class. Sympy expressions
    inside nodes are printed via ``self.sympyPrinter.doprint``.
    """

    def __init__(self, constantsAsFloats=False, sympyPrinter=None):
        """
        :param constantsAsFloats: forwarded to the default ``CustomSympyPrinter``; ignored
                                  when ``sympyPrinter`` is given
        :param sympyPrinter: object with a ``doprint(expr)`` method used for all sympy
                             expressions; defaults to ``CustomSympyPrinter(constantsAsFloats)``
        """
        if sympyPrinter is None:
            self.sympyPrinter = CustomSympyPrinter(constantsAsFloats)
        else:
            self.sympyPrinter = sympyPrinter
        self._indent = "   "

    def __call__(self, node):
        """Return the C code for ``node`` and all of its children as a string."""
        return str(self._print(node))

    def _print(self, node):
        """Dispatch to the most specific ``_print_<ClassName>`` method along the node's MRO.

        :raises NotImplementedError: if no class in the node's MRO has a printer method
        """
        for cls in type(node).__mro__:
            methodName = "_print_" + cls.__name__
            if hasattr(self, methodName):
                return getattr(self, methodName)(node)
        # Report the node's own type: after an exhausted loop ``cls`` is always ``object``.
        raise NotImplementedError("CBackend does not support node of type " + type(node).__name__)

    def _print_KernelFunction(self, node):
        """Print the function signature, a newline, then the printed body."""
        # FUNC_PREFIX is emitted verbatim -- presumably a macro defined by whoever compiles
        # the generated code; confirm against the build setup.
        functionArguments = ["%s %s" % (str(s.dtype), s.name) for s in node.parameters]
        funcDeclaration = "FUNC_PREFIX void %s(%s)" % (node.functionName, ", ".join(functionArguments))
        body = self._print(node.body)
        return funcDeclaration + "\n" + body

    def _print_Block(self, node):
        """Print all children one after another inside braces, indenting every line."""
        blockContents = "\n".join([self._print(child) for child in node.args])
        # insert the indent after every line break so multi-line children are indented too
        return "{\n%s\n}" % (self._indent + self._indent.join(blockContents.splitlines(True)))

    def _print_PragmaBlock(self, node):
        """Print the pragma line directly above the braced block."""
        return "%s\n%s" % (node.pragmaLine, self._print_Block(node))

    def _print_LoopOverCoordinate(self, node):
        """Print optional prefix lines, the ``for`` header and the loop body."""
        counterVar = node.loopCounterName
        start = "int %s = %s" % (counterVar, self.sympyPrinter.doprint(node.start))
        condition = "%s < %s" % (counterVar, self.sympyPrinter.doprint(node.stop))
        update = "%s += %s" % (counterVar, self.sympyPrinter.doprint(node.step),)
        loopStr = "for (%s; %s; %s)" % (start, condition, update)

        prefix = "\n".join(node.prefixLines)
        if prefix:
            prefix += "\n"
        return "%s%s\n%s" % (prefix, loopStr, self._print(node.body))

    def _print_SympyAssignment(self, node):
        """Print ``lhs = rhs;``; declarations are prefixed with the (optionally const) lhs type."""
        dtype = ""
        if node.isDeclaration:
            if node.isConst:
                dtype = "const " + str(node.lhs.dtype) + " "
            else:
                dtype = str(node.lhs.dtype) + " "
        # NOTE: the format adds a second space after ``dtype`` (and a leading space when
        # there is no type); kept unchanged so the generated text does not change.
        return "%s %s = %s;" % (str(dtype), self.sympyPrinter.doprint(node.lhs), self.sympyPrinter.doprint(node.rhs))

    def _print_TemporaryMemoryAllocation(self, node):
        """Print a ``new[]`` array allocation of ``node.size`` elements of the symbol's dtype."""
        return "%s * %s = new %s[%s];" % (node.symbol.dtype, self.sympyPrinter.doprint(node.symbol),
                                         node.symbol.dtype, self.sympyPrinter.doprint(node.size))

    def _print_TemporaryMemoryFree(self, node):
        """Print the matching ``delete []`` for a temporary allocation."""
        return "delete [] %s;" % (self.sympyPrinter.doprint(node.symbol),)

    def _print_CustomCppCode(self, node):
        """Emit the custom code snippet verbatim."""
        return node.code

    def _print_Conditional(self, node):
        """Print an ``if`` statement with a braced true block and an optional ``else`` block."""
        conditionExpr = self.sympyPrinter.doprint(node.conditionExpr)
        trueBlock = self._print_Block(node.trueBlock)
        result = "if (%s) \n %s " % (conditionExpr, trueBlock)
        if node.falseBlock:
            falseBlock = self._print_Block(node.falseBlock)
            result += "else " + falseBlock
        return result

124
125
126
127
128

# ------------------------------------------ Helper function & classes -------------------------------------------------


class CustomSympyPrinter(CCodePrinter):
    """Sympy C code printer with pystencils-specific printing rules.

    - powers with integer exponent 1..7 are expanded into explicit multiplications
    - rationals are evaluated to decimal literals; rationals and floats get an ``f``
      suffix when ``constantsAsFloats`` is set
    - ``Equality`` is printed as ``==`` and the ``cast`` function as a pointer cast
    - piecewise expressions are printed on a single line
    """

    def __init__(self, constantsAsFloats=False):
        """:param constantsAsFloats: if True, append ``f`` to printed float/rational literals"""
        self._constantsAsFloats = constantsAsFloats
        super(CustomSympyPrinter, self).__init__()

    def _withFloatSuffix(self, literal):
        """Return ``literal`` with an ``f`` suffix appended if constants are printed as floats."""
        if self._constantsAsFloats:
            return literal + "f"
        return literal

    def _print_Pow(self, expr):
        """Don't use std::pow function, for small integer exponents, write as multiplication"""
        if expr.exp.is_integer and expr.exp.is_number and 0 < expr.exp < 8:
            exponent = int(expr.exp)
            factor = "(" + self._print(expr.base) + ")"
            return '(' + '*'.join([factor] * exponent) + ')'
        return super(CustomSympyPrinter, self)._print_Pow(expr)

    def _print_Rational(self, expr):
        """Evaluate all rationals i.e. print 0.25 instead of 1.0/4.0"""
        # NOTE(review): evalf() without a precision argument uses sympy's default of 15
        # significant digits, possibly fewer than a double can hold -- confirm intended.
        return self._withFloatSuffix(str(expr.evalf().num))

    def _print_Equality(self, expr):
        """Equality operator is not printable in default printer"""
        return '((' + self._print(expr.lhs) + ") == (" + self._print(expr.rhs) + '))'

    def _print_Piecewise(self, expr):
        """Print piecewise in one line (remove newlines)"""
        result = super(CustomSympyPrinter, self)._print_Piecewise(expr)
        return result.replace("\n", "")

    def _print_Float(self, expr):
        """Print the float via ``str``, with ``f`` suffix if constants are printed as floats."""
        return self._withFloatSuffix(str(expr))

    def _print_Function(self, expr):
        """Print ``cast(arg, targetType)`` as ``*((<PointerType(targetType)>)(& arg))``.

        This reinterprets the storage of ``arg`` through a pointer; all other functions are
        printed by the base printer. Assumes ``str(PointerType(t))`` is the C pointer type
        for ``t`` -- confirm in pystencils.types.
        """
        name = str(expr.func).lower()
        if name == 'cast':
            arg, targetType = expr.args
            return "*((%s)(& %s))" % (PointerType(targetType), self._print(arg))
        return super(CustomSympyPrinter, self)._print_Function(expr)