cbackend.py 16.2 KB
Newer Older
Martin Bauer's avatar
Martin Bauer committed
1
import sympy as sp
Martin Bauer's avatar
Martin Bauer committed
2
3
4
from collections import namedtuple
from sympy.core import S
from typing import Optional
Martin Bauer's avatar
Martin Bauer committed
5

Martin Bauer's avatar
Martin Bauer committed
6
7
try:
    from sympy.printing.ccode import C99CodePrinter as CCodePrinter
Martin Bauer's avatar
Martin Bauer committed
8
9
except ImportError:
    from sympy.printing.ccode import CCodePrinter  # for sympy versions < 1.1
Martin Bauer's avatar
Martin Bauer committed
10

Martin Bauer's avatar
Martin Bauer committed
11
from pystencils.bitoperations import xor, rightShift, leftShift, bitwiseAnd, bitwiseOr
12
from pystencils.astnodes import Node, ResolvedFieldAccess, SympyAssignment
Martin Bauer's avatar
Martin Bauer committed
13
from pystencils.data_types import create_type, PointerType, get_type_of_expression, VectorType, castFunc
14
from pystencils.backends.simd_instruction_sets import selectedInstructionSet
15

Martin Bauer's avatar
Martin Bauer committed
16
# Public API of this module: only print_c is exported by ``from ... import *``.
__all__ = ['print_c']
17

Martin Bauer's avatar
Martin Bauer committed
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32

def print_c(ast_node: Node, signature_only: bool = False, use_float_constants: Optional[bool] = None) -> str:
    """Prints an abstract syntax tree node as C or CUDA code.

    This function does not need to distinguish between C, C++ or CUDA code, it just prints 'C-like' code as encoded
    in the abstract syntax tree (AST). The AST is built differently for C or CUDA by calling different create_kernel
    functions.

    Args:
        ast_node: root node of the AST that should be printed
        signature_only: if True, kernel functions are printed as their declaration only (no body)
        use_float_constants: if True, floating point and rational constants get an 'f' suffix. If None, this is
                             decided automatically: the suffix is used unless some field accessed in the AST
                             (via ResolvedFieldAccess) has dtype 'double'

    Returns:
        C-like code for the ast node and its descendants
    """
    if use_float_constants is None:
        accessed_dtypes = {access.field.dtype for access in ast_node.atoms(ResolvedFieldAccess)}
        use_float_constants = create_type('double') not in accessed_dtypes

    # The double-precision entry of the selected SIMD instruction set is always handed to the backend.
    backend = CBackend(constants_as_floats=use_float_constants,
                       signature_only=signature_only,
                       vector_instruction_set=selectedInstructionSet['double'])
    return backend(ast_node)
43
44


Martin Bauer's avatar
Martin Bauer committed
45
def get_headers(ast_node):
    """Collects the header includes required by ast_node and all of its descendant nodes.

    A node that has a ``headers`` attribute contributes those entries. Otherwise, a SympyAssignment whose
    right-hand side has a VectorType contributes the headers of the double-precision SIMD instruction set.
    Children (entries of ``args`` that are AST Nodes) are searched recursively.

    Args:
        ast_node: root node of the (sub-)tree to inspect

    Returns:
        set of header strings
    """
    if hasattr(ast_node, 'headers'):
        collected = set(ast_node.headers)
    elif isinstance(ast_node, SympyAssignment) and type(get_type_of_expression(ast_node.rhs)) is VectorType:
        collected = set(selectedInstructionSet['double']['headers'])
    else:
        collected = set()

    for child in ast_node.args:
        if isinstance(child, Node):
            collected |= get_headers(child)
    return collected
59
60


61
62
63
64
# --------------------------------------- Backend Specific Nodes -------------------------------------------------------


class CustomCppCode(Node):
    """AST node holding a snippet of C++ code that is printed verbatim.

    Args:
        code: C++ source text; it is stored with a leading newline prepended
        symbols_read: symbols the snippet reads
        symbols_defined: symbols the snippet defines
        parent: parent AST node, forwarded to Node
    """

    def __init__(self, code, symbols_read, symbols_defined, parent=None):
        super(CustomCppCode, self).__init__(parent=parent)
        self._code = "\n" + code
        self._symbolsRead = set(symbols_read)
        self._symbolsDefined = set(symbols_defined)
        # header includes needed by the snippet (get_headers in this module reads this attribute)
        self.headers = []

    @property
    def code(self):
        """The stored code text, including the prepended newline."""
        return self._code

    @property
    def args(self):
        """Custom code has no child nodes."""
        return []

    @property
    def symbols_defined(self):
        """Set of symbols defined by the snippet."""
        return self._symbolsDefined

    @property
    def undefined_symbols(self):
        """Symbols the snippet reads but does not define itself.

        These must be provided by the surrounding code. The previous implementation computed the inverted
        difference (defined - read).
        """
        return self._symbolsRead - self.symbols_defined
87
88
89


class PrintNode(CustomCppCode):
    """Custom code node that writes ``<name>  =  <value>`` of a symbol to std::cout.

    Args:
        symbol_to_print: symbol whose ``name`` is used both as the printed label and as the C++ expression
    """

    # noinspection SpellCheckingInspection
    def __init__(self, symbol_to_print):
        name = symbol_to_print.name
        snippet = '\nstd::cout << "%s  =  " << %s << std::endl; \n' % (name, name)
        super(PrintNode, self).__init__(snippet, symbols_read=[symbol_to_print], symbols_defined=set())
        # std::cout requires the iostream header
        self.headers.append("<iostream>")
95
96
97
98


# ------------------------------------------- Printer ------------------------------------------------------------------

99

Martin Bauer's avatar
Martin Bauer committed
100
101
# noinspection PyPep8Naming
class CBackend:
    """Prints pystencils AST nodes as C-like source code.

    Nodes are dispatched by class name: the classes in the node type's MRO are tried in order, and the first
    method named ``_print_<ClassName>`` found on the backend is used. Sympy expressions inside nodes are printed
    with ``self.sympyPrinter``.
    """

    def __init__(self, constants_as_floats=False, sympy_printer=None,
                 signature_only=False, vector_instruction_set=None):
        """
        Args:
            constants_as_floats: forwarded to the default sympy printer; ignored if sympy_printer is given
            sympy_printer: printer for sympy expressions. If None, a VectorizedCustomSympyPrinter is created when
                           vector_instruction_set is given, otherwise a CustomSympyPrinter
            signature_only: if True, kernel functions are printed as declaration only
            vector_instruction_set: instruction set dict used for vector stores; it is also installed on
                                    VectorType.instructionSet while printing
        """
        # Only construct the printer that is actually used.
        if sympy_printer is not None:
            self.sympyPrinter = sympy_printer
        elif vector_instruction_set is not None:
            self.sympyPrinter = VectorizedCustomSympyPrinter(vector_instruction_set, constants_as_floats)
        else:
            self.sympyPrinter = CustomSympyPrinter(constants_as_floats)

        self._vectorInstructionSet = vector_instruction_set
        self._indent = "   "
        self._signatureOnly = signature_only

    def __call__(self, node):
        """Prints node and returns the generated code as string.

        VectorType.instructionSet is set to this backend's instruction set for the duration of printing. It is
        restored afterwards even if printing raises, so a failed print does not leak global state.
        """
        previous_instruction_set = VectorType.instructionSet
        VectorType.instructionSet = self._vectorInstructionSet
        try:
            return str(self._print(node))
        finally:
            VectorType.instructionSet = previous_instruction_set

    def _print(self, node):
        """Dispatches to the ``_print_<ClassName>`` method of the first matching class in the node's MRO."""
        for cls in type(node).__mro__:
            method_name = "_print_" + cls.__name__
            if hasattr(self, method_name):
                return getattr(self, method_name)(node)
        raise NotImplementedError("CBackend does not support node of type " + str(type(node)))

    def _print_KernelFunction(self, node):
        function_arguments = ["%s %s" % (str(s.dtype), s.name) for s in node.parameters]
        func_declaration = "FUNC_PREFIX void %s(%s)" % (node.functionName, ", ".join(function_arguments))
        if self._signatureOnly:
            return func_declaration

        body = self._print(node.body)
        return func_declaration + "\n" + body

    def _print_Block(self, node):
        block_contents = "\n".join([self._print(child) for child in node.args])
        # indent every line of the contents by one level
        return "{\n%s\n}" % (self._indent + self._indent.join(block_contents.splitlines(True)))

    def _print_PragmaBlock(self, node):
        return "%s\n%s" % (node.pragmaLine, self._print_Block(node))

    def _print_LoopOverCoordinate(self, node):
        counter_symbol = node.loop_counter_name
        start = "int %s = %s" % (counter_symbol, self.sympyPrinter.doprint(node.start))
        condition = "%s < %s" % (counter_symbol, self.sympyPrinter.doprint(node.stop))
        update = "%s += %s" % (counter_symbol, self.sympyPrinter.doprint(node.step),)
        loopStr = "for (%s; %s; %s)" % (start, condition, update)

        prefix = "\n".join(node.prefixLines)
        if prefix:
            prefix += "\n"
        return "%s%s\n%s" % (prefix, loopStr, self._print(node.body))

    def _print_SympyAssignment(self, node):
        if node.is_declaration:
            data_type = "const " + str(node.lhs.dtype) + " " if node.is_const else str(node.lhs.dtype) + " "
            return "%s %s = %s;" % (data_type, self.sympyPrinter.doprint(node.lhs), self.sympyPrinter.doprint(node.rhs))
        else:
            lhs_type = get_type_of_expression(node.lhs)
            # assignment to a vector-cast lhs is emitted as an unaligned vector store to its address
            if type(lhs_type) is VectorType and node.lhs.func == castFunc:
                return self._vectorInstructionSet['storeU'].format("&" + self.sympyPrinter.doprint(node.lhs.args[0]),
                                                                   self.sympyPrinter.doprint(node.rhs)) + ';'
            else:
                return "%s = %s;" % (self.sympyPrinter.doprint(node.lhs), self.sympyPrinter.doprint(node.rhs))

    def _print_TemporaryMemoryAllocation(self, node):
        return "%s %s = new %s[%s];" % (node.symbol.dtype, self.sympyPrinter.doprint(node.symbol.name),
                                        node.symbol.dtype.base_type, self.sympyPrinter.doprint(node.size))

    def _print_TemporaryMemoryFree(self, node):
        return "delete [] %s;" % (self.sympyPrinter.doprint(node.symbol.name),)

    @staticmethod
    def _print_CustomCppCode(node):
        return node.code

    def _print_Conditional(self, node):
        condition_expr = self.sympyPrinter.doprint(node.conditionExpr)
        true_block = self._print_Block(node.trueBlock)
        result = "if (%s)\n%s " % (condition_expr, true_block)
        if node.falseBlock:
            false_block = self._print_Block(node.falseBlock)
            result += "else " + false_block
        return result

192
193
194
195

# ------------------------------------------ Helper function & classes -------------------------------------------------


Martin Bauer's avatar
Martin Bauer committed
196
# noinspection PyPep8Naming
197
class CustomSympyPrinter(CCodePrinter):
    """C code printer for sympy expressions with pystencils-specific adjustments.

    Compared to the sympy base printer, this printer makes the following changes:
      - small positive integer powers (exponent 1..7) are written as products instead of pow() calls
      - rationals are printed as evaluated decimals
      - an 'f' suffix is appended to rationals/floats when constants_as_floats is set
      - Equality and Piecewise get a printable form
      - pystencils bit operations and castFunc are mapped to C syntax
    """

    def __init__(self, constants_as_floats=False):
        # attribute is set before the base constructor runs
        self._constantsAsFloats = constants_as_floats
        super().__init__()

    def _with_float_suffix(self, literal):
        """Returns literal with an 'f' appended if constants are printed as single precision."""
        return literal + "f" if self._constantsAsFloats else literal

    def _print_Pow(self, expr):
        """Don't use std::pow function, for small integer exponents, write as multiplication"""
        exponent = expr.exp
        small_integer_exponent = exponent.is_integer and exponent.is_number and 0 < exponent < 8
        if not small_integer_exponent:
            return super()._print_Pow(expr)
        factors = [expr.base] * exponent
        return "(%s)" % self._print(sp.Mul(*factors, evaluate=False))

    def _print_Rational(self, expr):
        """Evaluate all rationals i.e. print 0.25 instead of 1.0/4.0"""
        return self._with_float_suffix(str(expr.evalf().num))

    def _print_Equality(self, expr):
        """Equality operator is not printable in default printer"""
        lhs_code = self._print(expr.lhs)
        rhs_code = self._print(expr.rhs)
        return "((%s) == (%s))" % (lhs_code, rhs_code)

    def _print_Piecewise(self, expr):
        """Print piecewise in one line (remove newlines)"""
        return super()._print_Piecewise(expr).replace("\n", "")

    def _print_Float(self, expr):
        return self._with_float_suffix(str(expr))

    def _print_Function(self, expr):
        infix_operators = {
            xor: '^',
            rightShift: '>>',
            leftShift: '<<',
            bitwiseOr: '|',
            bitwiseAnd: '&',
        }
        if expr.func == castFunc:
            # reinterpret the argument's memory as data_type: *((data_type*)(& arg))
            arg, data_type = expr.args
            return "*((%s)(& %s))" % (PointerType(data_type), self._print(arg))
        if expr.func not in infix_operators:
            return super()._print_Function(expr)
        left, right = expr.args[0], expr.args[1]
        return "(%s %s %s)" % (self._print(left), infix_operators[expr.func], self._print(right))
Martin Bauer's avatar
Martin Bauer committed
247

248

Martin Bauer's avatar
Martin Bauer committed
249
# noinspection PyPep8Naming
250
251
252
class VectorizedCustomSympyPrinter(CustomSympyPrinter):
    """Sympy printer that emits SIMD instructions for vector-typed expressions.

    Each overridden ``_print_*`` method first calls ``_scalarFallback``. If the expression is not vector-typed,
    the scalar printing of CustomSympyPrinter is returned. Otherwise the code is assembled from the format
    strings of ``instructionSet``. The keys used here are 'width', 'loadU', 'makeVec', '&', '|', '+', '-', '*',
    '/', 'sqrt', '==', 'blendv' and the relational operators.
    """
    SummandInfo = namedtuple("SummandInfo", ['sign', 'term'])

    def __init__(self, instruction_set, constants_as_floats=False):
        super(VectorizedCustomSympyPrinter, self).__init__(constants_as_floats)
        self.instructionSet = instruction_set

    def _scalarFallback(self, func_name, expr, *args, **kwargs):
        """Prints expr with the base class method ``func_name`` if expr is not vector-typed.

        Returns:
            the scalar code string, or None if expr has a vector type (the caller then prints it itself)
        """
        expr_type = get_type_of_expression(expr)
        if type(expr_type) is not VectorType:
            return getattr(super(VectorizedCustomSympyPrinter, self), func_name)(expr, *args, **kwargs)
        else:
            assert self.instructionSet['width'] == expr_type.width
            return None

    def _fold_vector_op(self, op_key, expr):
        """Combines the printed args of expr left to right with the binary instruction instructionSet[op_key]."""
        arg_strings = [self._print(a) for a in expr.args]
        assert len(arg_strings) > 0
        result = arg_strings[0]
        for item in arg_strings[1:]:
            result = self.instructionSet[op_key].format(result, item)
        return result

    def _print_Function(self, expr):
        if expr.func == castFunc:
            arg, data_type = expr.args
            if type(data_type) is VectorType:
                # field accesses are loaded from memory, other values go through 'makeVec'
                if type(arg) is ResolvedFieldAccess:
                    return self.instructionSet['loadU'].format("& " + self._print(arg))
                else:
                    return self.instructionSet['makeVec'].format(self._print(arg))

        return super(VectorizedCustomSympyPrinter, self)._print_Function(expr)

    def _print_And(self, expr):
        result = self._scalarFallback('_print_And', expr)
        if result is not None:
            return result
        return self._fold_vector_op('&', expr)

    def _print_Or(self, expr):
        result = self._scalarFallback('_print_Or', expr)
        if result is not None:
            return result
        return self._fold_vector_op('|', expr)

    def _print_Add(self, expr, order=None):
        result = self._scalarFallback('_print_Add', expr)
        if result is not None:
            return result

        summands = []
        for term in expr.args:
            if term.func == sp.Mul:
                # NOTE(review): if this Mul term is not vector-typed itself, _print_Mul returns the scalar code
                # string instead of a (sign, term) tuple, which breaks this unpacking.
                sign, t = self._print_Mul(term, inside_add=True)
            else:
                t = self._print(term)
                sign = 1
            summands.append(self.SummandInfo(sign, t))
        # Use positive terms first (the sort is stable, so order within each sign group is kept)
        summands.sort(key=lambda e: e.sign, reverse=True)
        # if no positive term exists, prepend a zero
        if summands[0].sign == -1:
            summands.insert(0, self.SummandInfo(1, "0"))

        assert len(summands) >= 2
        processed = summands[0].term
        for summand in summands[1:]:
            func = self.instructionSet['-'] if summand.sign == -1 else self.instructionSet['+']
            processed = func.format(processed, summand.term)
        return processed

    def _print_Pow(self, expr):
        result = self._scalarFallback('_print_Pow', expr)
        if result is not None:
            return result

        if expr.exp.is_integer and expr.exp.is_number and 0 < expr.exp < 8:
            return "(" + self._print(sp.Mul(*[expr.base] * expr.exp, evaluate=False)) + ")"
        elif expr.exp == -1:
            one = self.instructionSet['makeVec'].format(1.0)
            return self.instructionSet['/'].format(one, self._print(expr.base))
        elif expr.exp == sp.Rational(1, 2) or expr.exp == 0.5:
            # compare with the exact Rational too: since sympy 1.13, Rational(1, 2) == 0.5 is False
            return self.instructionSet['sqrt'].format(self._print(expr.base))
        else:
            raise ValueError("Generic exponential not supported")

    def _print_Mul(self, expr, inside_add=False):
        """Prints a vector-typed product.

        Args:
            expr: the Mul expression
            inside_add: if True, a (sign, code) tuple is returned so that _print_Add can pick the '+'/'-'
                        instruction. Otherwise a negative coefficient is printed as a multiplication by -1.
        """
        # noinspection PyProtectedMember
        from sympy.core.mul import _keep_coeff

        result = self._scalarFallback('_print_Mul', expr)
        if result is not None:
            return result

        c, e = expr.as_coeff_Mul()
        if c < 0:
            expr = _keep_coeff(-c, e)
            sign = -1
        else:
            sign = 1

        a = []  # items in the numerator
        b = []  # items that are in the denominator (if any)

        # Gather args for numerator/denominator
        for item in expr.as_ordered_factors():
            if item.is_commutative and item.is_Pow and item.exp.is_Rational and item.exp.is_negative:
                if item.exp != -1:
                    b.append(sp.Pow(item.base, -item.exp, evaluate=False))
                else:
                    b.append(sp.Pow(item.base, -item.exp))
            else:
                a.append(item)

        a = a or [S.One]

        a_str = [self._print(x) for x in a]
        b_str = [self._print(x) for x in b]

        result = a_str[0]
        for item in a_str[1:]:
            result = self.instructionSet['*'].format(result, item)

        if len(b) > 0:
            denominator_str = b_str[0]
            for item in b_str[1:]:
                denominator_str = self.instructionSet['*'].format(denominator_str, item)
            result = self.instructionSet['/'].format(result, denominator_str)

        if inside_add:
            return sign, result
        else:
            if sign < 0:
                return self.instructionSet['*'].format(self._print(S.NegativeOne), result)
            else:
                return result

    def _print_Relational(self, expr):
        result = self._scalarFallback('_print_Relational', expr)
        if result is not None:
            return result
        return self.instructionSet[expr.rel_op].format(self._print(expr.lhs), self._print(expr.rhs))

    def _print_Equality(self, expr):
        result = self._scalarFallback('_print_Equality', expr)
        if result is not None:
            return result
        return self.instructionSet['=='].format(self._print(expr.lhs), self._print(expr.rhs))

    def _print_Piecewise(self, expr):
        result = self._scalarFallback('_print_Piecewise', expr)
        if result is not None:
            return result

        # NOTE(review): the check inspects cond.args[0], so conditions presumably arrive wrapped (e.g. in a
        # castFunc) by the vectorization pass - TODO confirm
        if expr.args[-1].cond.args[0] is not sp.sympify(True):
            # We need the last conditional to be a True, otherwise the resulting
            # function may not return a result.
            raise ValueError("All Piecewise expressions must contain an "
                             "(expr, True) statement to be used as a default "
                             "condition. Without one, the generated "
                             "expression may not evaluate to anything under "
                             "some condition.")

        # start from the default value and blend in the other branches, last to first
        result = self._print(expr.args[-1][0])
        for trueExpr, condition in reversed(expr.args[:-1]):
            # noinspection SpellCheckingInspection
            result = self.instructionSet['blendv'].format(result, self._print(trueExpr), self._print(condition))
        return result