cbackend.py 6.44 KB
Newer Older
Martin Bauer's avatar
Martin Bauer committed
1
import sympy as sp
Martin Bauer's avatar
Martin Bauer committed
2
3
4
5
6
try:
    from sympy.utilities.codegen import CCodePrinter
except ImportError:
    from sympy.printing.ccode import C99CodePrinter as CCodePrinter

Martin Bauer's avatar
Martin Bauer committed
7
from pystencils.astnodes import Node
8
from pystencils.types import createType, PointerType
9
10


11
def generateC(astNode, signatureOnly=False):
    """
    Prints the abstract syntax tree as C function.

    Numeric constants are printed with an ``f`` suffix (single precision) unless at least one of the
    fields accessed by ``astNode`` has the dtype ``double``.

    :param astNode: root node of the AST; must provide ``fieldsAccessed`` (objects with a ``dtype``)
    :param signatureOnly: if True, a kernel function is printed as declaration only, without its body
    :return: the generated C code as string
    """
    accessedTypes = {field.dtype for field in astNode.fieldsAccessed}
    anyDoubleField = createType("double") in accessedTypes
    backend = CBackend(constantsAsFloats=not anyDoubleField, signatureOnly=signatureOnly)
    return backend(astNode)


# --------------------------------------- Backend Specific Nodes -------------------------------------------------------


class CustomCppCode(Node):
    """
    AST node wrapping a verbatim snippet of C/C++ code.

    The snippet is opaque to the AST, so the caller declares which symbols it reads and which it defines.
    """

    def __init__(self, code, symbolsRead, symbolsDefined):
        """
        :param code: C/C++ source text; stored with a leading newline prepended
        :param symbolsRead: iterable of symbols the snippet uses
        :param symbolsDefined: iterable of symbols the snippet introduces
        """
        self._code = "\n" + code
        self._symbolsRead = set(symbolsRead)
        self._symbolsDefined = set(symbolsDefined)

    @property
    def code(self):
        """The stored code string (including the prepended newline)."""
        return self._code

    @property
    def args(self):
        """Custom code has no child nodes."""
        return []

    @property
    def symbolsDefined(self):
        """Set of symbols defined by this snippet."""
        return self._symbolsDefined

    @property
    def undefinedSymbols(self):
        """Symbols the snippet reads but does not define itself, i.e. that have to come from outside."""
        # read minus defined: a symbol the snippet defines is not undefined, one it only reads is
        return self._symbolsRead - self._symbolsDefined
45
46
47
48
49
50
51
52
53
54
55


class PrintNode(CustomCppCode):
    """Custom code node emitting a ``std::cout`` statement that prints a symbol's name and value."""

    def __init__(self, symbolToPrint):
        symbolName = symbolToPrint.name
        printStatement = '\nstd::cout << "%s  =  " << %s << std::endl; \n' % (symbolName, symbolName)
        # the statement only reads the symbol, it defines nothing
        super(PrintNode, self).__init__(printStatement, symbolsRead=[symbolToPrint], symbolsDefined=set())


# ------------------------------------------- Printer ------------------------------------------------------------------


Michael Kuron's avatar
Michael Kuron committed
56
class CBackend(object):
    """
    Prints an abstract syntax tree as C code.

    A node is printed by the method ``_print_<ClassName>``. The classes of the node's method resolution
    order are tried in turn, so a subclass without its own method uses its base class' printer.
    Sympy expressions contained in nodes are rendered with ``self.sympyPrinter.doprint``.
    """

    def __init__(self, constantsAsFloats=False, sympyPrinter=None, signatureOnly=False):
        """
        :param constantsAsFloats: forwarded to the default CustomSympyPrinter; ignored if sympyPrinter is given
        :param sympyPrinter: printer for sympy expressions (must provide ``doprint``)
        :param signatureOnly: print kernel functions as declaration only, without body
        """
        if sympyPrinter is None:
            sympyPrinter = CustomSympyPrinter(constantsAsFloats)
        self.sympyPrinter = sympyPrinter
        self._indent = "   "
        self._signatureOnly = signatureOnly

    def __call__(self, node):
        """Return the C code for ``node`` as string."""
        return str(self._print(node))

    def _print(self, node):
        """Dispatch ``node`` to the most specific ``_print_<ClassName>`` method available."""
        nodeType = type(node)
        for cls in nodeType.__mro__:
            printMethod = getattr(self, "_print_" + cls.__name__, None)
            if printMethod is not None:
                return printMethod(node)
        # report the node's own type - after the loop ``cls`` would always be ``object``
        raise NotImplementedError("CBackend does not support node of type " + nodeType.__name__)

    def _print_KernelFunction(self, node):
        """Print ``FUNC_PREFIX void <name>(<params>)``, followed by the body unless only the signature is wanted."""
        functionArguments = ["%s %s" % (str(s.dtype), s.name) for s in node.parameters]
        funcDeclaration = "FUNC_PREFIX void %s(%s)" % (node.functionName, ", ".join(functionArguments))
        if self._signatureOnly:
            return funcDeclaration

        body = self._print(node.body)
        return funcDeclaration + "\n" + body

    def _print_Block(self, node):
        """Print the children one after another in braces, every line indented by ``self._indent``."""
        blockContents = "\n".join([self._print(child) for child in node.args])
        return "{\n%s\n}" % (self._indent + self._indent.join(blockContents.splitlines(True)))

    def _print_PragmaBlock(self, node):
        """Print the pragma line directly in front of the block."""
        return "%s\n%s" % (node.pragmaLine, self._print_Block(node))

    def _print_LoopOverCoordinate(self, node):
        """Print optional prefix lines, the ``for`` header with an ``int`` counter, and the loop body."""
        counterVar = node.loopCounterName
        start = "int %s = %s" % (counterVar, self.sympyPrinter.doprint(node.start))
        condition = "%s < %s" % (counterVar, self.sympyPrinter.doprint(node.stop))
        update = "%s += %s" % (counterVar, self.sympyPrinter.doprint(node.step),)
        loopStr = "for (%s; %s; %s)" % (start, condition, update)

        prefix = "\n".join(node.prefixLines)
        if prefix:
            prefix += "\n"
        return "%s%s\n%s" % (prefix, loopStr, self._print(node.body))

    def _print_SympyAssignment(self, node):
        """Print ``lhs = rhs;``; declarations are prefixed by the lhs dtype (and ``const`` for constants)."""
        declaration = ""
        if node.isDeclaration:
            declaration = ("const " if node.isConst else "") + str(node.lhs.dtype) + " "
        # ``declaration`` is either empty or ends with a space, so no separator is inserted here
        lhs = self.sympyPrinter.doprint(node.lhs)
        rhs = self.sympyPrinter.doprint(node.rhs)
        return "%s%s = %s;" % (declaration, lhs, rhs)

    def _print_TemporaryMemoryAllocation(self, node):
        """Print a ``new[]`` allocation of ``node.size`` elements of the symbol's dtype."""
        return "%s * %s = new %s[%s];" % (node.symbol.dtype, self.sympyPrinter.doprint(node.symbol),
                                         node.symbol.dtype, self.sympyPrinter.doprint(node.size))

    def _print_TemporaryMemoryFree(self, node):
        """Print the ``delete []`` matching _print_TemporaryMemoryAllocation."""
        return "delete [] %s;" % (self.sympyPrinter.doprint(node.symbol),)

    def _print_CustomCppCode(self, node):
        """Emit the node's code string verbatim."""
        return node.code

    def _print_Conditional(self, node):
        """Print ``if (cond)`` with the true block, plus an ``else`` branch if a false block is present."""
        conditionExpr = self.sympyPrinter.doprint(node.conditionExpr)
        trueBlock = self._print_Block(node.trueBlock)
        result = "if (%s) \n %s " % (conditionExpr, trueBlock)
        if node.falseBlock:
            falseBlock = self._print_Block(node.falseBlock)
            result += "else " + falseBlock
        return result

133
134
135
136
137

# ------------------------------------------ Helper function & classes -------------------------------------------------


class CustomSympyPrinter(CCodePrinter):
    """
    Sympy C code printer adapted for generated kernels.

    Compared to the base printer:
    - integer powers with exponent 1..7 are written as multiplications instead of ``pow`` calls,
    - rationals are printed as evaluated decimal numbers,
    - with ``constantsAsFloats`` rationals and floats get an ``f`` suffix,
    - ``Equality`` and the function ``cast`` can be printed,
    - piecewise expressions are printed on one line.
    """

    def __init__(self, constantsAsFloats=False):
        self._constantsAsFloats = constantsAsFloats
        super(CustomSympyPrinter, self).__init__()

    def _print_Pow(self, expr):
        """Expand integer exponents 1..7 into an explicit product; other powers use the base printer."""
        exponent = expr.exp
        isSmallPositiveInteger = exponent.is_integer and exponent.is_number and 0 < exponent < 8
        if not isSmallPositiveInteger:
            return super(CustomSympyPrinter, self)._print_Pow(expr)
        factors = [expr.base] * exponent
        return self._print(sp.Mul(*factors, evaluate=False))

    def _print_Rational(self, expr):
        """Print the evaluated value, i.e. 0.25 instead of 1.0/4.0."""
        text = str(expr.evalf().num)
        return text + "f" if self._constantsAsFloats else text

    def _print_Equality(self, expr):
        """Print ``((lhs) == (rhs))``; the default printer cannot print Equality."""
        lhsCode = self._print(expr.lhs)
        rhsCode = self._print(expr.rhs)
        return "((%s) == (%s))" % (lhsCode, rhsCode)

    def _print_Piecewise(self, expr):
        """Print piecewise expressions on a single line (newlines removed)."""
        multiLine = super(CustomSympyPrinter, self)._print_Piecewise(expr)
        return multiLine.replace("\n", "")

    def _print_Float(self, expr):
        """Print the float as sympy formats it, with an ``f`` suffix if constants are single precision."""
        suffix = "f" if self._constantsAsFloats else ""
        return str(expr) + suffix

    def _print_Function(self, expr):
        """Handle ``cast(arg, type)``: reinterpret ``arg`` through a pointer of ``PointerType(type)``."""
        if str(expr.func).lower() != 'cast':
            return super(CustomSympyPrinter, self)._print_Function(expr)
        castArg, targetType = expr.args
        return "*((%s)(& %s))" % (PointerType(targetType), self._print(castArg))
Martin Bauer's avatar
Martin Bauer committed
179