cbackend.py 14.7 KB
Newer Older
Martin Bauer's avatar
Martin Bauer committed
1
import sympy as sp
Martin Bauer's avatar
Martin Bauer committed
2
3
4
5
6
try:
    from sympy.utilities.codegen import CCodePrinter
except ImportError:
    from sympy.printing.ccode import C99CodePrinter as CCodePrinter

7
8
9
10
11
from collections import namedtuple
from sympy.core.mul import _keep_coeff
from sympy.core import S

from pystencils.astnodes import Node, ResolvedFieldAccess, SympyAssignment
12
from pystencils.data_types import createType, PointerType, getTypeOfExpression, VectorType, castFunc
13
from pystencils.backends.simd_instruction_sets import selectedInstructionSet
14
15


16
def generateC(astNode, signatureOnly=False):
    """
    Prints the abstract syntax tree as C function.

    Literal constants are printed as single precision floats unless one of the accessed fields has type double.
    The printer is always configured with the double precision SIMD instruction set; the vectorizing sympy printer
    falls back to scalar code for expressions that are not of vector type.

    :param astNode: kernel AST to print
    :param signatureOnly: if True, only the function declaration is printed, without body
    :return: the generated C code as string
    """
    accessedFieldTypes = {field.dtype for field in astNode.fieldsAccessed}
    printer = CBackend(constantsAsFloats=createType("double") not in accessedFieldTypes,
                       signatureOnly=signatureOnly,
                       vectorInstructionSet=selectedInstructionSet['double'])
    return printer(astNode)


28
29
30
31
32
33
34
35
36
37
38
39
40
41
def getHeaders(astNode):
    """Collect the header files required by ``astNode`` and all of its descendant nodes.

    A node that has a ``headers`` attribute contributes exactly those headers. A ``SympyAssignment`` without such an
    attribute whose right-hand side has vector type contributes the headers of the double precision SIMD
    instruction set.

    :param astNode: root of the (sub-)tree to inspect
    :return: set of header strings
    """
    collected = set()

    if hasattr(astNode, 'headers'):
        collected.update(astNode.headers)
    elif isinstance(astNode, SympyAssignment) and type(getTypeOfExpression(astNode.rhs)) is VectorType:
        collected.update(selectedInstructionSet['double']['headers'])

    # recurse only into AST nodes; other entries of ``args`` contribute nothing
    for child in astNode.args:
        if isinstance(child, Node):
            collected |= getHeaders(child)

    return collected
42
43


44
45
46
47
48
49
50
51
# --------------------------------------- Backend Specific Nodes -------------------------------------------------------


class CustomCppCode(Node):
    """AST node that inserts a verbatim snippet of C++ code into the generated kernel.

    The snippet itself is not analysed, so the caller states which symbols it reads and which it defines.
    """

    def __init__(self, code, symbolsRead, symbolsDefined):
        """
        :param code: C++ source text; a newline is prepended so that the snippet starts on its own line
        :param symbolsRead: symbols used by the snippet
        :param symbolsDefined: symbols defined by the snippet
        """
        self._code = "\n" + code
        self._symbolsRead = set(symbolsRead)
        self._symbolsDefined = set(symbolsDefined)
        # header files (e.g. "<iostream>") the snippet needs; collected via the 'headers' attribute by getHeaders
        self.headers = []

    @property
    def code(self):
        """The C++ code of this node, including the leading newline."""
        return self._code

    @property
    def args(self):
        """A code snippet has no child nodes."""
        return []

    @property
    def symbolsDefined(self):
        """Symbols that the snippet defines."""
        return self._symbolsDefined

    @property
    def undefinedSymbols(self):
        """Symbols that are read by the snippet but not defined by it, i.e. have to be provided from outside."""
        return self._symbolsRead - self._symbolsDefined
69
70
71
72
73
74


class PrintNode(CustomCppCode):
    """Custom C++ snippet that writes ``<name>  =  <value>`` of a single symbol to ``std::cout``."""

    def __init__(self, symbolToPrint):
        """:param symbolToPrint: symbol whose value is printed; it is registered as read by the snippet"""
        symbolName = symbolToPrint.name
        code = '\nstd::cout << "%s  =  " << %s << std::endl; \n' % (symbolName, symbolName)
        super(PrintNode, self).__init__(code, symbolsRead=[symbolToPrint], symbolsDefined=set())
        # std::cout requires this header
        self.headers.append("<iostream>")
76
77
78
79


# ------------------------------------------- Printer ------------------------------------------------------------------

80

Michael Kuron's avatar
Michael Kuron committed
81
class CBackend(object):
    """Prints a pystencils kernel AST as C code.

    Every AST node is dispatched by class name to a ``_print_<ClassName>`` method (see ``_print``); sympy
    expressions contained in the nodes are converted with ``self.sympyPrinter``.
    """

    def __init__(self, constantsAsFloats=False, sympyPrinter=None, signatureOnly=False, vectorInstructionSet=None):
        """
        :param constantsAsFloats: passed to the default sympy printer, which then prints float and rational
                                  literals with an ``f`` suffix
        :param sympyPrinter: printer for sympy expressions; if None, a VectorizedCustomSympyPrinter is created when
                             ``vectorInstructionSet`` is given, otherwise a CustomSympyPrinter
        :param signatureOnly: if True, kernel functions are printed as declaration only, without body
        :param vectorInstructionSet: mapping from operation names (e.g. 'storeU') to intrinsic format strings,
                                     or None for scalar code
        """
        # only build a default printer if none was supplied
        if sympyPrinter is not None:
            self.sympyPrinter = sympyPrinter
        elif vectorInstructionSet is not None:
            self.sympyPrinter = VectorizedCustomSympyPrinter(vectorInstructionSet, constantsAsFloats)
        else:
            self.sympyPrinter = CustomSympyPrinter(constantsAsFloats)

        self._vectorInstructionSet = vectorInstructionSet
        self._indent = "   "
        self._signatureOnly = signatureOnly

    def __call__(self, node):
        """Return the C code for ``node`` as string.

        ``VectorType.instructionSet`` is set to this backend's instruction set while printing and is restored
        afterwards, also when printing raises an exception.
        """
        previousInstructionSet = VectorType.instructionSet
        VectorType.instructionSet = self._vectorInstructionSet
        try:
            return str(self._print(node))
        finally:
            VectorType.instructionSet = previousInstructionSet

    def _print(self, node):
        """Dispatch ``node`` to ``_print_<ClassName>`` of the most specific class in its MRO that has one.

        :raises NotImplementedError: if no class in the MRO of ``node`` has a print method
        """
        nodeType = type(node)
        for cls in nodeType.__mro__:
            methodName = "_print_" + cls.__name__
            if hasattr(self, methodName):
                return getattr(self, methodName)(node)
        # report the actual node type - after the loop ``cls`` would always be ``object``
        raise NotImplementedError("CBackend does not support node of type " + nodeType.__name__)

    def _print_KernelFunction(self, node):
        """Print ``FUNC_PREFIX void <name>(<parameters>)`` followed by the body, or the declaration only."""
        functionArguments = ["%s %s" % (str(s.dtype), s.name) for s in node.parameters]
        funcDeclaration = "FUNC_PREFIX void %s(%s)" % (node.functionName, ", ".join(functionArguments))
        if self._signatureOnly:
            return funcDeclaration

        body = self._print(node.body)
        return funcDeclaration + "\n" + body

    def _print_Block(self, node):
        """Print all children inside braces, every line indented by one level."""
        blockContents = "\n".join([self._print(child) for child in node.args])
        return "{\n%s\n}" % (self._indent + self._indent.join(blockContents.splitlines(True)))

    def _print_PragmaBlock(self, node):
        """Print the pragma line followed by the block."""
        return "%s\n%s" % (node.pragmaLine, self._print_Block(node))

    def _print_LoopOverCoordinate(self, node):
        """Print the loop's prefix lines (if any) and a C ``for`` loop with an ``int`` counter."""
        counterVar = node.loopCounterName
        start = "int %s = %s" % (counterVar, self.sympyPrinter.doprint(node.start))
        condition = "%s < %s" % (counterVar, self.sympyPrinter.doprint(node.stop))
        update = "%s += %s" % (counterVar, self.sympyPrinter.doprint(node.step),)
        loopStr = "for (%s; %s; %s)" % (start, condition, update)

        prefix = "\n".join(node.prefixLines)
        if prefix:
            prefix += "\n"
        return "%s%s\n%s" % (prefix, loopStr, self._print(node.body))

    def _print_SympyAssignment(self, node):
        """Print a (const) declaration, a plain assignment, or - for a vector-typed cast on the left-hand
        side - the instruction set's 'storeU' template applied to the address of the cast argument."""
        if node.isDeclaration:
            dtype = "const " + str(node.lhs.dtype) + " " if node.isConst else str(node.lhs.dtype) + " "
            return "%s %s = %s;" % (dtype, self.sympyPrinter.doprint(node.lhs), self.sympyPrinter.doprint(node.rhs))
        else:
            lhsType = getTypeOfExpression(node.lhs)
            if type(lhsType) is VectorType and node.lhs.func == castFunc:
                return self._vectorInstructionSet['storeU'].format("&" + self.sympyPrinter.doprint(node.lhs.args[0]),
                                                                   self.sympyPrinter.doprint(node.rhs)) + ';'
            else:
                return "%s = %s;" % (self.sympyPrinter.doprint(node.lhs), self.sympyPrinter.doprint(node.rhs))

    def _print_TemporaryMemoryAllocation(self, node):
        """Print a heap allocation of ``node.size`` elements with ``new[]``."""
        return "%s %s = new %s[%s];" % (node.symbol.dtype, self.sympyPrinter.doprint(node.symbol.name),
                                        node.symbol.dtype.baseType, self.sympyPrinter.doprint(node.size))

    def _print_TemporaryMemoryFree(self, node):
        """Print the ``delete []`` matching a TemporaryMemoryAllocation."""
        return "delete [] %s;" % (self.sympyPrinter.doprint(node.symbol.name),)

    def _print_CustomCppCode(self, node):
        """Custom code is emitted verbatim."""
        return node.code

    def _print_Conditional(self, node):
        """Print an ``if`` with the true block and, if present, an ``else`` block."""
        conditionExpr = self.sympyPrinter.doprint(node.conditionExpr)
        trueBlock = self._print_Block(node.trueBlock)
        result = "if (%s)\n%s " % (conditionExpr, trueBlock)
        if node.falseBlock:
            falseBlock = self._print_Block(node.falseBlock)
            result += "else " + falseBlock
        return result

170
171
172
173
174

# ------------------------------------------ Helper function & classes -------------------------------------------------


class CustomSympyPrinter(CCodePrinter):
    """C code printer for sympy expressions with pystencils specific output.

    Compared to the sympy base printer:
    - small positive integer powers are written as products,
    - rationals are printed as evaluated decimal numbers,
    - equalities and piecewise expressions are printed inline,
    - ``castFunc`` is printed as a pointer reinterpretation.
    """

    def __init__(self, constantsAsFloats=False):
        """:param constantsAsFloats: if True, rational and float literals are printed with an ``f`` suffix"""
        self._constantsAsFloats = constantsAsFloats
        super(CustomSympyPrinter, self).__init__()

    def _floatLiteral(self, literal):
        """Append the single precision suffix ``f`` to ``literal`` when float constants are requested."""
        return literal + "f" if self._constantsAsFloats else literal

    def _print_Pow(self, expr):
        """Don't use std::pow function, for small integer exponents, write as multiplication"""
        exponent = expr.exp
        isSmallPositiveInteger = exponent.is_integer and exponent.is_number and 0 < exponent < 8
        if not isSmallPositiveInteger:
            return super(CustomSympyPrinter, self)._print_Pow(expr)
        factors = [expr.base] * exponent
        return "(" + self._print(sp.Mul(*factors, evaluate=False)) + ")"

    def _print_Rational(self, expr):
        """Evaluate all rationals i.e. print 0.25 instead of 1.0/4.0"""
        return self._floatLiteral(str(expr.evalf().num))

    def _print_Equality(self, expr):
        """Equality operator is not printable in default printer"""
        lhs = self._print(expr.lhs)
        rhs = self._print(expr.rhs)
        return '((' + lhs + ") == (" + rhs + '))'

    def _print_Piecewise(self, expr):
        """Print piecewise in one line (remove newlines)"""
        multiLine = super(CustomSympyPrinter, self)._print_Piecewise(expr)
        return multiLine.replace("\n", "")

    def _print_Float(self, expr):
        """Print the float via ``str``, with ``f`` suffix if float constants are requested."""
        return self._floatLiteral(str(expr))

    def _print_Function(self, expr):
        """Print ``castFunc(arg, type)`` as ``*((type*)(& arg))``; all other functions via the base printer."""
        if expr.func == castFunc:
            arg, targetType = expr.args
            return "*((%s)(& %s))" % (PointerType(targetType), self._print(arg))
        return super(CustomSympyPrinter, self)._print_Function(expr)
Martin Bauer's avatar
Martin Bauer committed
215

216
217
218
219
220
221
222
223

class VectorizedCustomSympyPrinter(CustomSympyPrinter):
    """Sympy printer that emits SIMD intrinsics for expressions of vector type.

    Every ``_print_*`` method first calls ``_scalarFallback``: expressions that are not of ``VectorType`` are printed
    by ``CustomSympyPrinter``. Vector expressions are rendered with the format strings of ``instructionSet``, a mapping
    from operation keys ('+', '-', '*', '/', '&', '|', '==', relational operators, 'sqrt', 'makeVec', 'loadU',
    'blendv') to intrinsic templates; 'width' holds the vector width.

    Operands of the intrinsic templates are vectors, so numeric constants that the printer introduces itself
    (0, 1, -1) are broadcast with the 'makeVec' template instead of being printed as scalar literals.
    """
    SummandInfo = namedtuple("SummandInfo", ['sign', 'term'])

    def __init__(self, instructionSet, constantsAsFloats=False):
        super(VectorizedCustomSympyPrinter, self).__init__(constantsAsFloats)
        self.instructionSet = instructionSet

    def _scalarFallback(self, funcName, expr, *args, **kwargs):
        """Print ``expr`` with the scalar base class method ``funcName`` unless ``expr`` has vector type.

        :return: the scalar code, or None if ``expr`` is a vector expression that the caller has to print
        """
        exprType = getTypeOfExpression(expr)
        if type(exprType) is not VectorType:
            return getattr(super(VectorizedCustomSympyPrinter, self), funcName)(expr, *args, **kwargs)
        else:
            assert self.instructionSet['width'] == exprType.width
            return None

    def _vectorConstant(self, value):
        """Broadcast the Python number ``value`` to all lanes using the 'makeVec' template."""
        return self.instructionSet['makeVec'].format(value)

    def _foldWith(self, operation, terms):
        """Left-fold the already printed ``terms`` with the binary template ``instructionSet[operation]``."""
        template = self.instructionSet[operation]
        result = terms[0]
        for term in terms[1:]:
            result = template.format(result, term)
        return result

    def _print_Function(self, expr):
        """Print vector casts as unaligned load (field accesses) or broadcast (everything else)."""
        if expr.func == castFunc:
            arg, dtype = expr.args
            if type(dtype) is VectorType:
                if type(arg) is ResolvedFieldAccess:
                    return self.instructionSet['loadU'].format("& " + self._print(arg))
                else:
                    return self.instructionSet['makeVec'].format(self._print(arg))

        return super(VectorizedCustomSympyPrinter, self)._print_Function(expr)

    def _print_And(self, expr):
        result = self._scalarFallback('_print_And', expr)
        if result is not None:
            return result

        argStrings = [self._print(a) for a in expr.args]
        assert len(argStrings) > 0
        return self._foldWith('&', argStrings)

    def _print_Or(self, expr):
        result = self._scalarFallback('_print_Or', expr)
        if result is not None:
            return result

        argStrings = [self._print(a) for a in expr.args]
        assert len(argStrings) > 0
        return self._foldWith('|', argStrings)

    def _print_Add(self, expr, order=None):
        """Print a vector sum; negative products are subtracted instead of being multiplied by -1."""
        result = self._scalarFallback('_print_Add', expr)
        if result is not None:
            return result

        summands = []
        for term in expr.args:
            if term.func == sp.Mul:
                sign, t = self._print_Mul(term, insideAdd=True)
            else:
                t = self._print(term)
                sign = 1
            summands.append(self.SummandInfo(sign, t))
        # Use positive terms first
        summands.sort(key=lambda e: e.sign, reverse=True)
        # if no positive term exists, start from a zero vector (not the scalar literal 0)
        if summands[0].sign == -1:
            summands.insert(0, self.SummandInfo(1, self._vectorConstant(0.0)))

        assert len(summands) >= 2
        processed = summands[0].term
        for summand in summands[1:]:
            func = self.instructionSet['-'] if summand.sign == -1 else self.instructionSet['+']
            processed = func.format(processed, summand.term)
        return processed

    def _print_Pow(self, expr):
        """Supports small positive integer exponents (as products), -1 (as division) and 0.5 (as sqrt)."""
        result = self._scalarFallback('_print_Pow', expr)
        if result is not None:
            return result

        if expr.exp.is_integer and expr.exp.is_number and 0 < expr.exp < 8:
            return "(" + self._print(sp.Mul(*[expr.base] * expr.exp, evaluate=False)) + ")"
        elif expr.exp == -1:
            return self.instructionSet['/'].format(self._vectorConstant(1.0), self._print(expr.base))
        elif expr.exp == 0.5:
            return self.instructionSet['sqrt'].format(self._print(expr.base))
        else:
            raise ValueError("Generic exponential not supported")

    def _print_Mul(self, expr, insideAdd=False):
        """Print a vector product, factors with negative rational exponents go into a single division.

        :param insideAdd: if True, return ``(sign, code)`` with the sign of the numeric coefficient pulled out, so
                          that ``_print_Add`` can emit a subtraction; otherwise return the code only
        """
        result = self._scalarFallback('_print_Mul', expr)
        if result is not None:
            # keep the (sign, term) contract that _print_Add relies on
            return (1, result) if insideAdd else result

        c, e = expr.as_coeff_Mul()
        if c < 0:
            expr = _keep_coeff(-c, e)
            sign = -1
        else:
            sign = 1

        numerator = []
        denominator = []
        for item in expr.as_ordered_factors():
            if item.is_commutative and item.is_Pow and item.exp.is_Rational and item.exp.is_negative:
                if item.exp != -1:
                    denominator.append(sp.Pow(item.base, -item.exp, evaluate=False))
                else:
                    denominator.append(sp.Pow(item.base, -item.exp))
            else:
                numerator.append(item)

        if numerator:
            result = self._foldWith('*', [self._print(x) for x in numerator])
        else:
            # all factors are in the denominator, e.g. 1/(x*y): the numerator is a vector of ones
            result = self._vectorConstant(1.0)

        if denominator:
            denominatorStr = self._foldWith('*', [self._print(x) for x in denominator])
            result = self.instructionSet['/'].format(result, denominatorStr)

        if insideAdd:
            return sign, result
        if sign < 0:
            return self.instructionSet['*'].format(self._vectorConstant(-1.0), result)
        return result

    def _print_Relational(self, expr):
        result = self._scalarFallback('_print_Relational', expr)
        if result is not None:
            return result
        return self.instructionSet[expr.rel_op].format(self._print(expr.lhs), self._print(expr.rhs))

    def _print_Equality(self, expr):
        result = self._scalarFallback('_print_Equality', expr)
        if result is not None:
            return result
        return self.instructionSet['=='].format(self._print(expr.lhs), self._print(expr.rhs))

    def _print_Piecewise(self, expr):
        """Print a piecewise expression as nested blends, starting from the default (True) branch."""
        result = self._scalarFallback('_print_Piecewise', expr)
        if result is not None:
            return result

        if expr.args[-1].cond != True:
            # We need the last conditional to be a True, otherwise the resulting
            # function may not return a result.
            raise ValueError("All Piecewise expressions must contain an "
                             "(expr, True) statement to be used as a default "
                             "condition. Without one, the generated "
                             "expression may not evaluate to anything under "
                             "some condition.")

        result = self._print(expr.args[-1][0])
        for trueExpr, condition in reversed(expr.args[:-1]):
            result = self.instructionSet['blendv'].format(result, self._print(trueExpr), self._print(condition))
        return result