cbackend.py 15.1 KB
Newer Older
Martin Bauer's avatar
Martin Bauer committed
1
import sympy as sp
Martin Bauer's avatar
Martin Bauer committed
2
3
4

from pystencils.bitoperations import xor, rightShift, leftShift

Martin Bauer's avatar
Martin Bauer committed
5
6
7
8
9
try:
    from sympy.utilities.codegen import CCodePrinter
except ImportError:
    from sympy.printing.ccode import C99CodePrinter as CCodePrinter

10
11
12
13
14
from collections import namedtuple
from sympy.core.mul import _keep_coeff
from sympy.core import S

from pystencils.astnodes import Node, ResolvedFieldAccess, SympyAssignment
15
from pystencils.data_types import createType, PointerType, getTypeOfExpression, VectorType, castFunc
16
from pystencils.backends.simd_instruction_sets import selectedInstructionSet
17
18


19
def generateC(astNode, signatureOnly=False):
    """Generate C code for a kernel AST.

    :param astNode: root node to print; the dtypes of its ``fieldsAccessed`` decide literal precision
    :param signatureOnly: if True, only the function declaration is produced, without a body
    :return: the generated code as string

    Floating point literals are printed with an 'f' suffix unless at least one accessed field has
    dtype double. The backend is always configured with the double-precision entry of
    ``selectedInstructionSet``.
    """
    accessedTypes = {field.dtype for field in astNode.fieldsAccessed}
    backend = CBackend(constantsAsFloats=createType("double") not in accessedTypes,
                       signatureOnly=signatureOnly,
                       vectorInstructionSet=selectedInstructionSet['double'])
    return backend(astNode)


31
32
33
34
35
36
37
38
39
40
41
42
43
44
def getHeaders(astNode):
    """Return the set of headers required by the code generated for astNode and its children.

    A node with a ``headers`` attribute contributes those entries. A node without that attribute
    that is a SympyAssignment with a vector-typed right-hand side contributes the headers of the
    double-precision SIMD instruction set. All children in ``astNode.args`` that are Nodes are
    searched recursively.
    """
    required = set()

    if hasattr(astNode, 'headers'):
        required.update(astNode.headers)
    elif isinstance(astNode, SympyAssignment) and type(getTypeOfExpression(astNode.rhs)) is VectorType:
        required.update(selectedInstructionSet['double']['headers'])

    childNodes = (arg for arg in astNode.args if isinstance(arg, Node))
    for child in childNodes:
        required |= getHeaders(child)

    return required
45
46


47
48
49
50
51
52
53
54
# --------------------------------------- Backend Specific Nodes -------------------------------------------------------


class CustomCppCode(Node):
    """AST node holding a verbatim snippet of C/C++ code.

    The backend prints ``code`` unchanged. Since the snippet is opaque to analysis, the caller
    declares which symbols it reads and which it defines.
    """

    def __init__(self, code, symbolsRead, symbolsDefined):
        """
        :param code: C/C++ source text; a leading newline is prepended before it is stored
        :param symbolsRead: iterable of symbols the snippet reads
        :param symbolsDefined: iterable of symbols the snippet defines
        """
        self._code = "\n" + code
        self._symbolsRead = set(symbolsRead)
        self._symbolsDefined = set(symbolsDefined)
        # Header names (e.g. "<iostream>") this snippet needs; picked up by getHeaders via 'headers'.
        self.headers = []

    @property
    def code(self):
        """The stored code snippet (including the prepended newline)."""
        return self._code

    @property
    def args(self):
        """A custom snippet has no child nodes."""
        return []

    @property
    def symbolsDefined(self):
        """Symbols defined by the snippet, as declared by the caller."""
        return self._symbolsDefined

    @property
    def undefinedSymbols(self):
        """Symbols the snippet reads but does not define itself (must be provided from outside)."""
        return self._symbolsRead - self._symbolsDefined
72
73
74
75
76
77


class PrintNode(CustomCppCode):
    """Debug node emitting a C++ statement that writes ``<name>  =  <value>`` to std::cout."""

    def __init__(self, symbolToPrint):
        name = symbolToPrint.name
        statement = '\nstd::cout << "%s  =  " << %s << std::endl; \n' % (name, name)
        super(PrintNode, self).__init__(statement, symbolsRead=[symbolToPrint], symbolsDefined=set())
        # std::cout requires <iostream> in the generated file
        self.headers += ["<iostream>"]
79
80
81
82


# ------------------------------------------- Printer ------------------------------------------------------------------

83

Michael Kuron's avatar
Michael Kuron committed
84
class CBackend(object):
    """Printer that converts a pystencils AST into C source code.

    Nodes are dispatched by class name: the first ``_print_<ClassName>`` method found along
    ``type(node).__mro__`` is called. Sympy expressions inside nodes are rendered by
    ``self.sympyPrinter``.
    """

    def __init__(self, constantsAsFloats=False, sympyPrinter=None, signatureOnly=False, vectorInstructionSet=None):
        """
        :param constantsAsFloats: forwarded to the default sympy printer (float literals get an 'f' suffix)
        :param sympyPrinter: printer for sympy expressions; when None, a VectorizedCustomSympyPrinter is
                             created if vectorInstructionSet is given, otherwise a CustomSympyPrinter
        :param signatureOnly: if True, kernel functions are printed as declaration only (no body)
        :param vectorInstructionSet: instruction-set dict used for vector stores; it is also installed
                                     as VectorType.instructionSet while printing
        """
        if sympyPrinter is not None:
            self.sympyPrinter = sympyPrinter
        elif vectorInstructionSet is not None:
            self.sympyPrinter = VectorizedCustomSympyPrinter(vectorInstructionSet, constantsAsFloats)
        else:
            self.sympyPrinter = CustomSympyPrinter(constantsAsFloats)

        self._vectorInstructionSet = vectorInstructionSet
        self._indent = "   "
        self._signatureOnly = signatureOnly

    def __call__(self, node):
        """Print node and return the generated code as string."""
        # VectorType.instructionSet is class-level (global) state: install ours for the duration of
        # printing and always restore the previous value, even if printing raises.
        prevIs = VectorType.instructionSet
        VectorType.instructionSet = self._vectorInstructionSet
        try:
            return str(self._print(node))
        finally:
            VectorType.instructionSet = prevIs

    def _print(self, node):
        """Dispatch to the ``_print_<ClassName>`` method of the most specific class in node's MRO.

        :raises NotImplementedError: if no class in the MRO has a matching print method
        """
        for cls in type(node).__mro__:
            methodName = "_print_" + cls.__name__
            if hasattr(self, methodName):
                return getattr(self, methodName)(node)
        # report the node's own type, not the last MRO entry (which is always 'object')
        raise NotImplementedError("CBackend does not support node of type " + type(node).__name__)

    def _print_KernelFunction(self, node):
        functionArguments = ["%s %s" % (str(s.dtype), s.name) for s in node.parameters]
        funcDeclaration = "FUNC_PREFIX void %s(%s)" % (node.functionName, ", ".join(functionArguments))
        if self._signatureOnly:
            return funcDeclaration

        body = self._print(node.body)
        return funcDeclaration + "\n" + body

    def _print_Block(self, node):
        """Print children line by line, each line indented, wrapped in curly braces."""
        blockContents = "\n".join([self._print(child) for child in node.args])
        return "{\n%s\n}" % (self._indent + self._indent.join(blockContents.splitlines(True)))

    def _print_PragmaBlock(self, node):
        return "%s\n%s" % (node.pragmaLine, self._print_Block(node))

    def _print_LoopOverCoordinate(self, node):
        counterVar = node.loopCounterName
        start = "int %s = %s" % (counterVar, self.sympyPrinter.doprint(node.start))
        condition = "%s < %s" % (counterVar, self.sympyPrinter.doprint(node.stop))
        update = "%s += %s" % (counterVar, self.sympyPrinter.doprint(node.step))
        loopStr = "for (%s; %s; %s)" % (start, condition, update)

        prefix = "\n".join(node.prefixLines)
        if prefix:
            prefix += "\n"
        return "%s%s\n%s" % (prefix, loopStr, self._print(node.body))

    def _print_SympyAssignment(self, node):
        if node.isDeclaration:
            dtype = str(node.lhs.dtype)
            if node.isConst:
                dtype = "const " + dtype
            # single space between type and name (previously the type carried a trailing space too)
            return "%s %s = %s;" % (dtype, self.sympyPrinter.doprint(node.lhs), self.sympyPrinter.doprint(node.rhs))

        lhsType = getTypeOfExpression(node.lhs)
        if type(lhsType) is VectorType and node.lhs.func == castFunc:
            # assignment to a vector-cast location: emit the instruction set's 'storeU' operation
            return self._vectorInstructionSet['storeU'].format("&" + self.sympyPrinter.doprint(node.lhs.args[0]),
                                                               self.sympyPrinter.doprint(node.rhs)) + ';'
        return "%s = %s;" % (self.sympyPrinter.doprint(node.lhs), self.sympyPrinter.doprint(node.rhs))

    def _print_TemporaryMemoryAllocation(self, node):
        return "%s %s = new %s[%s];" % (node.symbol.dtype, self.sympyPrinter.doprint(node.symbol.name),
                                        node.symbol.dtype.baseType, self.sympyPrinter.doprint(node.size))

    def _print_TemporaryMemoryFree(self, node):
        return "delete [] %s;" % (self.sympyPrinter.doprint(node.symbol.name),)

    def _print_CustomCppCode(self, node):
        return node.code

    def _print_Conditional(self, node):
        conditionExpr = self.sympyPrinter.doprint(node.conditionExpr)
        trueBlock = self._print_Block(node.trueBlock)
        result = "if (%s)\n%s " % (conditionExpr, trueBlock)
        if node.falseBlock:
            falseBlock = self._print_Block(node.falseBlock)
            result += "else " + falseBlock
        return result

173
174
175
176
177

# ------------------------------------------ Helper function & classes -------------------------------------------------


class CustomSympyPrinter(CCodePrinter):
    """Sympy C code printer with pystencils-specific output rules.

    - small positive integer powers are expanded into products instead of calling pow
    - rationals are printed as evaluated decimal literals
    - float and rational literals get an 'f' suffix when constantsAsFloats is set
    - castFunc, xor, rightShift and leftShift are printed as C expressions
    """

    def __init__(self, constantsAsFloats=False):
        self._constantsAsFloats = constantsAsFloats
        super(CustomSympyPrinter, self).__init__()

    def _print_Pow(self, expr):
        """Don't use std::pow function, for small integer exponents, write as multiplication"""
        if expr.exp.is_integer and expr.exp.is_number and 0 < expr.exp < 8:
            return "(" + self._print(sp.Mul(*[expr.base] * expr.exp, evaluate=False)) + ")"
        else:
            return super(CustomSympyPrinter, self)._print_Pow(expr)

    def _print_Rational(self, expr):
        """Evaluate all rationals i.e. print 0.25 instead of 1.0/4.0"""
        res = str(expr.evalf().num)
        if self._constantsAsFloats:
            res += "f"
        return res

    def _print_Equality(self, expr):
        """Equality operator is not printable in default printer"""
        return '((' + self._print(expr.lhs) + ") == (" + self._print(expr.rhs) + '))'

    def _print_Piecewise(self, expr):
        """Print piecewise in one line (remove newlines)"""
        result = super(CustomSympyPrinter, self)._print_Piecewise(expr)
        return result.replace("\n", "")

    def _print_Float(self, expr):
        res = str(expr)
        if self._constantsAsFloats:
            res += "f"
        return res

    def _print_Function(self, expr):
        """Print pystencils-specific functions; everything else goes to the base printer."""
        if expr.func == castFunc:
            # reinterpret the argument's storage as the target type
            arg, dtype = expr.args
            return "*((%s)(& %s))" % (PointerType(dtype), self._print(arg))
        elif expr.func == xor:
            return "(%s ^ %s)" % (self._print(expr.args[0]), self._print(expr.args[1]))
        elif expr.func == rightShift:
            return "(%s >> %s)" % (self._print(expr.args[0]), self._print(expr.args[1]))
        elif expr.func == leftShift:
            return "(%s << %s)" % (self._print(expr.args[0]), self._print(expr.args[1]))
        else:
            return super(CustomSympyPrinter, self)._print_Function(expr)
Martin Bauer's avatar
Martin Bauer committed
224

225
226
227
228
229
230
231
232

class VectorizedCustomSympyPrinter(CustomSympyPrinter):
    """Sympy printer that emits SIMD intrinsics for vector-typed expressions.

    Each ``_print_*`` override first calls ``_scalarFallback``. Expressions that are not vector-typed
    are printed by the scalar CustomSympyPrinter. Vector-typed expressions are printed with the
    format strings stored in ``instructionSet`` (keys used here: '+', '-', '*', '/', '&', '|',
    'sqrt', 'makeVec', 'loadU', 'blendv', '==', relational operators and 'width').
    """
    SummandInfo = namedtuple("SummandInfo", ['sign', 'term'])

    def __init__(self, instructionSet, constantsAsFloats=False):
        super(VectorizedCustomSympyPrinter, self).__init__(constantsAsFloats)
        self.instructionSet = instructionSet

    def _scalarFallback(self, funcName, expr, *args, **kwargs):
        """Print expr with the scalar base-class method funcName unless it is vector-typed.

        :return: the printed string for scalar expressions, None for vector expressions (the caller
                 then prints them with intrinsics)
        """
        exprType = getTypeOfExpression(expr)
        if type(exprType) is not VectorType:
            return getattr(super(VectorizedCustomSympyPrinter, self), funcName)(expr, *args, **kwargs)
        else:
            assert self.instructionSet['width'] == exprType.width
            return None

    def _vectorConstant(self, value):
        """Broadcast a numeric constant into a vector operand using the 'makeVec' instruction."""
        return self.instructionSet['makeVec'].format(value)

    def _foldBinary(self, op, operands):
        """Left-fold already printed, non-empty operands with the binary instruction op."""
        assert len(operands) > 0
        result = operands[0]
        for item in operands[1:]:
            result = self.instructionSet[op].format(result, item)
        return result

    def _print_Function(self, expr):
        if expr.func == castFunc:
            arg, dtype = expr.args
            if type(dtype) is VectorType:
                if type(arg) is ResolvedFieldAccess:
                    return self.instructionSet['loadU'].format("& " + self._print(arg))
                else:
                    return self.instructionSet['makeVec'].format(self._print(arg))

        return super(VectorizedCustomSympyPrinter, self)._print_Function(expr)

    def _print_And(self, expr):
        result = self._scalarFallback('_print_And', expr)
        if result is not None:
            return result
        return self._foldBinary('&', [self._print(a) for a in expr.args])

    def _print_Or(self, expr):
        result = self._scalarFallback('_print_Or', expr)
        if result is not None:
            return result
        return self._foldBinary('|', [self._print(a) for a in expr.args])

    def _print_Add(self, expr, order=None):
        result = self._scalarFallback('_print_Add', expr)
        if result is not None:
            return result

        summands = []
        for term in expr.args:
            if term.func == sp.Mul:
                sign, t = self._print_Mul(term, insideAdd=True)
            else:
                t = self._print(term)
                sign = 1
            summands.append(self.SummandInfo(sign, t))
        # Use positive terms first
        summands.sort(key=lambda e: e.sign, reverse=True)
        # if no positive term exists, prepend a zero (broadcast: a scalar literal is no vector operand)
        if summands[0].sign == -1:
            summands.insert(0, self.SummandInfo(1, self._vectorConstant(0.0)))

        assert len(summands) >= 2
        processed = summands[0].term
        for summand in summands[1:]:
            func = self.instructionSet['-'] if summand.sign == -1 else self.instructionSet['+']
            processed = func.format(processed, summand.term)
        return processed

    def _print_Pow(self, expr):
        result = self._scalarFallback('_print_Pow', expr)
        if result is not None:
            return result

        if expr.exp.is_integer and expr.exp.is_number and 0 < expr.exp < 8:
            return "(" + self._print(sp.Mul(*[expr.base] * expr.exp, evaluate=False)) + ")"
        elif expr.exp == -1:
            one = self._vectorConstant(1.0)
            return self.instructionSet['/'].format(one, self._print(expr.base))
        # check Rational and Float explicitly: newer SymPy no longer treats Rational(1, 2) == 0.5
        elif expr.exp == S.Half or expr.exp == 0.5:
            return self.instructionSet['sqrt'].format(self._print(expr.base))
        else:
            raise ValueError("Generic exponential not supported")

    def _print_Mul(self, expr, insideAdd=False):
        """Print a vector product; with insideAdd=True return (sign, term) for use by _print_Add."""
        result = self._scalarFallback('_print_Mul', expr)
        if result is not None:
            # keep the (sign, term) contract _print_Add relies on, also for scalar products
            return (1, result) if insideAdd else result

        c, e = expr.as_coeff_Mul()
        if c < 0:
            expr = _keep_coeff(-c, e)
            sign = -1
        else:
            sign = 1

        a = []  # items in the numerator
        b = []  # items that are in the denominator (if any)

        # Gather args for numerator/denominator
        for item in expr.as_ordered_factors():
            if item.is_commutative and item.is_Pow and item.exp.is_Rational and item.exp.is_negative:
                if item.exp != -1:
                    b.append(sp.Pow(item.base, -item.exp, evaluate=False))
                else:
                    b.append(sp.Pow(item.base, -item.exp))
            else:
                a.append(item)

        if a:
            result = self._foldBinary('*', [self._print(x) for x in a])
        else:
            # only denominator factors: the numerator is a broadcast 1.0, not a scalar literal
            result = self._vectorConstant(1.0)

        if b:
            denominator = self._foldBinary('*', [self._print(x) for x in b])
            result = self.instructionSet['/'].format(result, denominator)

        if insideAdd:
            return sign, result
        if sign < 0:
            return self.instructionSet['*'].format(self._vectorConstant(-1.0), result)
        return result

    def _print_Relational(self, expr):
        result = self._scalarFallback('_print_Relational', expr)
        if result is not None:
            return result
        return self.instructionSet[expr.rel_op].format(self._print(expr.lhs), self._print(expr.rhs))

    def _print_Equality(self, expr):
        result = self._scalarFallback('_print_Equality', expr)
        if result is not None:
            return result
        return self.instructionSet['=='].format(self._print(expr.lhs), self._print(expr.rhs))

    def _print_Piecewise(self, expr):
        """Print a Piecewise as nested 'blendv' operations, starting from the default (True) branch."""
        result = self._scalarFallback('_print_Piecewise', expr)
        if result is not None:
            return result

        if expr.args[-1].cond != True:
            # We need the last conditional to be a True, otherwise the resulting
            # function may not return a result.
            raise ValueError("All Piecewise expressions must contain an "
                             "(expr, True) statement to be used as a default "
                             "condition. Without one, the generated "
                             "expression may not evaluate to anything under "
                             "some condition.")

        result = self._print(expr.args[-1][0])
        for trueExpr, condition in reversed(expr.args[:-1]):
            result = self.instructionSet['blendv'].format(result, self._print(trueExpr), self._print(condition))
        return result