"""Printing of pystencils abstract syntax trees as C-like (C, C++, CUDA) source code."""
from collections import namedtuple
from typing import Set

import sympy as sp
from sympy.core import S
from sympy.printing.ccode import C89CodePrinter

try:
    from sympy.printing.ccode import C99CodePrinter as CCodePrinter
except ImportError:
    from sympy.printing.ccode import CCodePrinter  # for sympy versions < 1.1

from pystencils.integer_functions import bitwise_xor, bit_shift_right, bit_shift_left, bitwise_and, \
    bitwise_or, modulo_ceil
from pystencils.astnodes import Node, KernelFunction
from pystencils.data_types import create_type, PointerType, get_type_of_expression, VectorType, cast_func, \
    vector_memory_access

__all__ = ['generate_c', 'CustomCppCode', 'PrintNode', 'get_headers', 'CustomSympyPrinter']


def generate_c(ast_node: Node, signature_only: bool = False) -> str:
    """Prints an abstract syntax tree node as C or CUDA code.

    This function does not need to distinguish between C, C++ or CUDA code, it just prints 'C-like' code as encoded
    in the abstract syntax tree (AST). The AST is built differently for C or CUDA by calling different create_kernel
    functions.

    Args:
        ast_node: root of the (sub-)tree to print. If it has an ``instruction_set`` attribute that is not None,
                  vectorized code is printed using that instruction set; otherwise the scalar printer is used.
        signature_only: if True, kernel functions are printed as their declaration only, without the body

    Returns:
        C-like code for the ast node and its descendants
    """
    # Not every node type is known to carry an instruction set, so a missing attribute means "no vectorization"
    # instead of raising an AttributeError.
    instruction_set = getattr(ast_node, 'instruction_set', None)
    printer = CBackend(signature_only=signature_only, vector_instruction_set=instruction_set)
    return printer(ast_node)


def get_headers(ast_node: Node) -> Set[str]:
    """Collect the header files required to compile the C-like code printed for ``ast_node``.

    Sources of headers, gathered recursively over all child nodes in ``ast_node.args``:
      * the ``'headers'`` entry of a kernel function's instruction set (if it has one),
      * the ``headers`` attribute of any node that defines it.
    """
    required = set()

    if isinstance(ast_node, KernelFunction) and ast_node.instruction_set:
        required.update(ast_node.instruction_set['headers'])
    if hasattr(ast_node, 'headers'):
        required.update(ast_node.headers)

    child_nodes = (child for child in ast_node.args if isinstance(child, Node))
    for child in child_nodes:
        required |= get_headers(child)
    return required


# --------------------------------------- Backend Specific Nodes -------------------------------------------------------


class CustomCppCode(Node):
    """AST node holding a verbatim snippet of C/C++ code.

    Args:
        code: source text that is printed as-is (a newline is prepended to it)
        symbols_read: symbols the snippet uses; those not in ``symbols_defined`` must be provided by enclosing code
        symbols_defined: symbols the snippet defines
        parent: parent node in the AST
    """

    def __init__(self, code, symbols_read, symbols_defined, parent=None):
        super(CustomCppCode, self).__init__(parent=parent)
        self._code = "\n" + code
        self._symbolsRead = set(symbols_read)
        self._symbolsDefined = set(symbols_defined)
        # header files needed by the snippet, e.g. "<iostream>"; read via the ``headers`` attribute by get_headers()
        self.headers = []

    @property
    def code(self):
        """The code snippet, prefixed with a newline."""
        return self._code

    @property
    def args(self):
        """The snippet is opaque to the AST, so it has no child nodes."""
        return []

    @property
    def symbols_defined(self):
        """Set of symbols defined by the snippet."""
        return self._symbolsDefined

    @property
    def undefined_symbols(self):
        """Symbols the snippet reads but does not define itself, i.e. those that have to come from outside."""
        # read minus defined (not the other way round): only these are required from the surrounding code
        return self._symbolsRead - self.symbols_defined


class PrintNode(CustomCppCode):
    """Debug node emitting a C++ statement that writes ``<name>  =  <value>`` of a symbol to ``std::cout``."""

    # noinspection SpellCheckingInspection
    def __init__(self, symbol_to_print):
        symbol_name = symbol_to_print.name
        statement = '\nstd::cout << "%s  =  " << %s << std::endl; \n' % (symbol_name, symbol_name)
        super(PrintNode, self).__init__(statement, symbols_read=[symbol_to_print], symbols_defined=set())
        # std::cout / std::endl are declared in <iostream>
        self.headers.append("<iostream>")


# ------------------------------------------- Printer ------------------------------------------------------------------


# noinspection PyPep8Naming
class CBackend:
    """Prints pystencils AST nodes as C-like source code.

    Dispatch is by class hierarchy: for a node of class ``Foo`` the method ``_print_Foo`` is used; if it does not exist,
    the base classes of the node are tried in MRO order. Sympy expressions inside nodes are printed with
    ``self.sympy_printer``.

    Args:
        sympy_printer: printer for sympy expressions; if None, a ``VectorizedCustomSympyPrinter`` is created when
                       ``vector_instruction_set`` is given, otherwise a ``CustomSympyPrinter``
        signature_only: if True, kernel functions are printed as their declaration only
        vector_instruction_set: mapping of instruction names (e.g. 'storeA', 'storeU', 'stream') to format strings,
                                or None for scalar code
    """

    # alignment in bytes of temporary buffers; allocation and free must agree on it to compute the same offset
    _temporary_memory_alignment = 64

    def __init__(self, sympy_printer=None,
                 signature_only=False, vector_instruction_set=None):
        if sympy_printer is None:
            if vector_instruction_set is not None:
                self.sympy_printer = VectorizedCustomSympyPrinter(vector_instruction_set)
            else:
                self.sympy_printer = CustomSympyPrinter()
        else:
            self.sympy_printer = sympy_printer

        self._vectorInstructionSet = vector_instruction_set
        self._indent = "   "
        self._signatureOnly = signature_only

    def __call__(self, node):
        """Print ``node`` and all its descendants and return the code as a string."""
        # VectorType.instruction_set is global (class-level) state: always restore the previous value,
        # also when printing raises, so later printing is not affected.
        previous_instruction_set = VectorType.instruction_set
        VectorType.instruction_set = self._vectorInstructionSet
        try:
            return str(self._print(node))
        finally:
            VectorType.instruction_set = previous_instruction_set

    def _print(self, node):
        """Dispatch to ``_print_<ClassName>`` for the most specific class of ``node`` that has such a method."""
        for cls in type(node).__mro__:
            method_name = "_print_" + cls.__name__
            if hasattr(self, method_name):
                return getattr(self, method_name)(node)
        raise NotImplementedError("CBackend does not support node of type " + str(type(node)))

    def _print_KernelFunction(self, node):
        function_arguments = ["%s %s" % (str(s.symbol.dtype), s.symbol.name) for s in node.get_parameters()]
        func_declaration = "FUNC_PREFIX void %s(%s)" % (node.function_name, ", ".join(function_arguments))
        if self._signatureOnly:
            return func_declaration

        body = self._print(node.body)
        return func_declaration + "\n" + body

    def _print_Block(self, node):
        block_contents = "\n".join([self._print(child) for child in node.args])
        # indent every line of the block contents by one level
        return "{\n%s\n}" % (self._indent + self._indent.join(block_contents.splitlines(True)))

    def _print_PragmaBlock(self, node):
        return "%s\n%s" % (node.pragma_line, self._print_Block(node))

    def _print_LoopOverCoordinate(self, node):
        counter_symbol = node.loop_counter_name
        start = "int %s = %s" % (counter_symbol, self.sympy_printer.doprint(node.start))
        condition = "%s < %s" % (counter_symbol, self.sympy_printer.doprint(node.stop))
        update = "%s += %s" % (counter_symbol, self.sympy_printer.doprint(node.step),)
        loop_str = "for (%s; %s; %s)" % (start, condition, update)

        prefix = "\n".join(node.prefix_lines)
        if prefix:
            prefix += "\n"
        return "%s%s\n%s" % (prefix, loop_str, self._print(node.body))

    def _print_SympyAssignment(self, node):
        printer = self.sympy_printer
        if node.is_declaration:
            data_type = "const " + str(node.lhs.dtype) + " " if node.is_const else str(node.lhs.dtype) + " "
            return "%s%s = %s;" % (data_type, printer.doprint(node.lhs), printer.doprint(node.rhs))

        lhs_type = get_type_of_expression(node.lhs)
        if type(lhs_type) is VectorType and isinstance(node.lhs, cast_func):
            # vector store; the lhs is expected to carry (address_expr, data_type, aligned, nontemporal)
            arg, _, aligned, nontemporal = node.lhs.args
            if aligned:
                instr = 'stream' if nontemporal else 'storeA'
            else:
                instr = 'storeU'
            return self._vectorInstructionSet[instr].format("&" + printer.doprint(arg),
                                                            printer.doprint(node.rhs)) + ';'
        return "%s = %s;" % (printer.doprint(node.lhs), printer.doprint(node.rhs))

    def _print_TemporaryMemoryAllocation(self, node):
        align = self._temporary_memory_alignment
        np_dtype = node.symbol.dtype.base_type.numpy_dtype
        # One extra alignment unit is allocated, presumably so the pointer can be shifted by node.offset(align).
        # The size is rounded up to a multiple of the alignment, as C11 aligned_alloc requires.
        required_size = np_dtype.itemsize * node.size + align
        size = modulo_ceil(required_size, align)
        code = "{dtype} {name}=({dtype})aligned_alloc({align}, {size}) + {offset};"
        return code.format(dtype=node.symbol.dtype,
                           name=self.sympy_printer.doprint(node.symbol.name),
                           size=self.sympy_printer.doprint(size),
                           offset=int(node.offset(align)),
                           align=align)

    def _print_TemporaryMemoryFree(self, node):
        align = self._temporary_memory_alignment
        # undo the pointer shift applied at allocation time before handing the pointer back to free()
        return "free(%s - %d);" % (self.sympy_printer.doprint(node.symbol.name), int(node.offset(align)))

    @staticmethod
    def _print_CustomCppCode(node):
        return node.code

    def _print_Conditional(self, node):
        condition_expr = self.sympy_printer.doprint(node.condition_expr)
        true_block = self._print_Block(node.true_block)
        result = "if (%s)\n%s " % (condition_expr, true_block)
        if node.false_block:
            false_block = self._print_Block(node.false_block)
            result += "else " + false_block
        return result


# ------------------------------------------ Helper function & classes -------------------------------------------------


# noinspection PyPep8Naming
class CustomSympyPrinter(CCodePrinter):
    """Sympy C code printer adapted to pystencils kernels.

    Compared to the base printer: small integer powers become explicit products, rationals are printed as evaluated
    decimals, equalities and pystencils' bitwise integer functions become C operators, piecewise expressions are
    printed on a single line, and Min/Max are printed with the C89 printer's methods.
    """

    def __init__(self):
        super().__init__()
        self._float_type = create_type("float32")
        # Min/Max are taken out of known_functions; they are printed by the C89 methods assigned at the end
        for function_name in ('Min', 'Max'):
            self.known_functions.pop(function_name, None)

    def _print_Pow(self, expr):
        """Don't use std::pow function, for small integer exponents, write as multiplication"""
        exponent = expr.exp
        small_integer = exponent.is_integer and exponent.is_number
        if small_integer and 0 < exponent < 8:
            product = sp.Mul(*[expr.base] * exponent, evaluate=False)
            return "(" + self._print(product) + ")"
        if small_integer and -8 < exponent < 0:
            product = sp.Mul(*[expr.base] * (-exponent), evaluate=False)
            return "1 / ({})".format(self._print(product))
        return super()._print_Pow(expr)

    def _print_Rational(self, expr):
        """Evaluate all rationals i.e. print 0.25 instead of 1.0/4.0"""
        return str(expr.evalf().num)

    def _print_Equality(self, expr):
        """Equality operator is not printable in default printer"""
        return "((%s) == (%s))" % (self._print(expr.lhs), self._print(expr.rhs))

    def _print_Piecewise(self, expr):
        """Print piecewise in one line (remove newlines)"""
        return super()._print_Piecewise(expr).replace("\n", "")

    def _print_Function(self, expr):
        # functions that know how to print themselves take precedence
        if hasattr(expr, 'to_c'):
            return expr.to_c(self._print)

        if isinstance(expr, cast_func):
            arg, data_type = expr.args
            if isinstance(arg, sp.Number):
                return self._typed_number(arg, data_type)
            # reinterpret the memory of ``arg`` as ``data_type``
            return "*((%s)(& %s))" % (PointerType(data_type, restrict=False), self._print(arg))

        operator = {
            bitwise_xor: '^',
            bit_shift_right: '>>',
            bit_shift_left: '<<',
            bitwise_or: '|',
            bitwise_and: '&',
        }.get(expr.func)
        if operator is not None:
            return "(%s %s %s)" % (self._print(expr.args[0]), operator, self._print(expr.args[1]))

        return super()._print_Function(expr)

    def _typed_number(self, number, dtype):
        """Print ``number``; for float32 add the ``f`` literal suffix (and ``.0`` if there is no decimal point)."""
        printed = self._print(number)
        if not (dtype.is_float() and dtype == self._float_type):
            return printed
        return printed + ("f" if '.' in printed else ".0f")

    _print_Max = C89CodePrinter._print_Max
    _print_Min = C89CodePrinter._print_Min


# noinspection PyPep8Naming
class VectorizedCustomSympyPrinter(CustomSympyPrinter):
    """Sympy printer emitting SIMD instructions for vector-typed expressions.

    Expressions whose type (as reported by ``get_type_of_expression``) is not a ``VectorType`` are printed by the
    scalar ``CustomSympyPrinter``. Vector expressions are printed through the format-string templates of
    ``instruction_set``, e.g. ``instruction_set['+'].format(a, b)``.

    Args:
        instruction_set: mapping from operation names ('+', '-', '*', '/', 'sqrt', 'makeVec', 'loadA', 'loadU',
                         'blendv', relational operators, ...) to format strings, plus the vector 'width'
    """
    SummandInfo = namedtuple("SummandInfo", ['sign', 'term'])

    def __init__(self, instruction_set):
        super(VectorizedCustomSympyPrinter, self).__init__()
        self.instruction_set = instruction_set

    def _scalarFallback(self, func_name, expr, *args, **kwargs):
        """Print ``expr`` with the scalar base-class method ``func_name`` if it is not vector-typed.

        Returns:
            the printed code for non-vector expressions; None for vector expressions, which the caller must print
        """
        expr_type = get_type_of_expression(expr)
        if type(expr_type) is not VectorType:
            return getattr(super(VectorizedCustomSympyPrinter, self), func_name)(expr, *args, **kwargs)
        else:
            assert self.instruction_set['width'] == expr_type.width
            return None

    def _vector_constant(self, value):
        """Broadcast a scalar constant into a vector via the 'makeVec' template.

        Constants introduced by this printer must be vectors, since a bare scalar literal is not a valid
        operand of the vector instruction templates.
        """
        return self.instruction_set['makeVec'].format(value)

    def _fold_arguments(self, instruction, expr):
        """Combine the printed arguments of ``expr`` from left to right with the binary template ``instruction``."""
        arg_strings = [self._print(a) for a in expr.args]
        assert len(arg_strings) > 0
        result = arg_strings[0]
        for item in arg_strings[1:]:
            result = self.instruction_set[instruction].format(result, item)
        return result

    def _print_Function(self, expr):
        if isinstance(expr, vector_memory_access):
            arg, data_type, aligned, _ = expr.args
            instruction = self.instruction_set['loadA'] if aligned else self.instruction_set['loadU']
            return instruction.format("& " + self._print(arg))
        elif isinstance(expr, cast_func):
            arg, data_type = expr.args
            if type(data_type) is VectorType:
                return self.instruction_set['makeVec'].format(self._print(arg))

        return super(VectorizedCustomSympyPrinter, self)._print_Function(expr)

    def _print_And(self, expr):
        result = self._scalarFallback('_print_And', expr)
        if result is not None:
            return result
        return self._fold_arguments('&', expr)

    def _print_Or(self, expr):
        result = self._scalarFallback('_print_Or', expr)
        if result is not None:
            return result
        return self._fold_arguments('|', expr)

    def _print_Add(self, expr, order=None):
        result = self._scalarFallback('_print_Add', expr)
        if result is not None:
            return result

        summands = []
        for term in expr.args:
            if term.func == sp.Mul:
                sign, t = self._print_Mul(term, inside_add=True)
            else:
                t = self._print(term)
                sign = 1
            summands.append(self.SummandInfo(sign, t))
        # Use positive terms first (the sort is stable, so the order within each sign group is kept)
        summands.sort(key=lambda e: e.sign, reverse=True)
        # if no positive term exists, start from a zero vector (the '-' template needs vector operands)
        if summands[0].sign == -1:
            summands.insert(0, self.SummandInfo(1, self._vector_constant(0.0)))

        assert len(summands) >= 2
        processed = summands[0].term
        for summand in summands[1:]:
            func = self.instruction_set['-'] if summand.sign == -1 else self.instruction_set['+']
            processed = func.format(processed, summand.term)
        return processed

    def _print_Pow(self, expr):
        result = self._scalarFallback('_print_Pow', expr)
        if result is not None:
            return result

        if expr.exp.is_integer and expr.exp.is_number and 0 < expr.exp < 8:
            return "(" + self._print(sp.Mul(*[expr.base] * expr.exp, evaluate=False)) + ")"
        elif expr.exp == -1:
            return self.instruction_set['/'].format(self._vector_constant(1.0), self._print(expr.base))
        elif expr.exp == 0.5:
            return self.instruction_set['sqrt'].format(self._print(expr.base))
        elif expr.exp.is_integer and expr.exp.is_number and - 8 < expr.exp < 0:
            denominator = self._print(sp.Mul(*[expr.base] * (-expr.exp), evaluate=False))
            return self.instruction_set['/'].format(self._vector_constant(1.0), denominator)
        else:
            raise ValueError("Generic exponential not supported: " + str(expr))

    def _print_Mul(self, expr, inside_add=False):
        """Print a product; with ``inside_add=True`` return ``(sign, code)`` so that _print_Add can subtract."""
        # noinspection PyProtectedMember
        from sympy.core.mul import _keep_coeff

        result = self._scalarFallback('_print_Mul', expr)
        if result is not None:
            # keep the (sign, code) contract expected by _print_Add; the scalar code already contains the sign
            return (1, result) if inside_add else result

        # pull a negative numeric coefficient out as a sign
        c, e = expr.as_coeff_Mul()
        if c < 0:
            expr = _keep_coeff(-c, e)
            sign = -1
        else:
            sign = 1

        a = []  # items in the numerator
        b = []  # items that are in the denominator (if any)

        # Gather args for numerator/denominator
        for item in expr.as_ordered_factors():
            if item.is_commutative and item.is_Pow and item.exp.is_Rational and item.exp.is_negative:
                if item.exp != -1:
                    b.append(sp.Pow(item.base, -item.exp, evaluate=False))
                else:
                    b.append(sp.Pow(item.base, -item.exp))
            else:
                a.append(item)

        a = a or [S.One]

        a_str = [self._print(x) for x in a]
        b_str = [self._print(x) for x in b]

        result = a_str[0]
        for item in a_str[1:]:
            result = self.instruction_set['*'].format(result, item)

        if len(b) > 0:
            denominator_str = b_str[0]
            for item in b_str[1:]:
                denominator_str = self.instruction_set['*'].format(denominator_str, item)
            result = self.instruction_set['/'].format(result, denominator_str)

        if inside_add:
            return sign, result
        if sign < 0:
            # negate by multiplying with a broadcast -1
            return self.instruction_set['*'].format(self._vector_constant(-1.0), result)
        return result

    def _print_Relational(self, expr):
        result = self._scalarFallback('_print_Relational', expr)
        if result is not None:
            return result
        return self.instruction_set[expr.rel_op].format(self._print(expr.lhs), self._print(expr.rhs))

    def _print_Equality(self, expr):
        result = self._scalarFallback('_print_Equality', expr)
        if result is not None:
            return result
        return self.instruction_set['=='].format(self._print(expr.lhs), self._print(expr.rhs))

    def _print_Piecewise(self, expr):
        result = self._scalarFallback('_print_Piecewise', expr)
        if result is not None:
            return result

        true_value = sp.sympify(True)
        default_condition = expr.args[-1].cond
        # The default condition is expected to wrap True as its first argument (presumably a cast to a vector
        # type); a bare True is accepted as well instead of failing with an IndexError on its empty args.
        is_default = default_condition is true_value or (
            len(default_condition.args) > 0 and default_condition.args[0] is true_value)
        if not is_default:
            # We need the last conditional to be a True, otherwise the resulting
            # function may not return a result.
            raise ValueError("All Piecewise expressions must contain an "
                             "(expr, True) statement to be used as a default "
                             "condition. Without one, the generated "
                             "expression may not evaluate to anything under "
                             "some condition.")

        # start from the default value and blend in the other branches, last to first
        result = self._print(expr.args[-1][0])
        for true_expr, condition in reversed(expr.args[:-1]):
            # noinspection SpellCheckingInspection
            result = self.instruction_set['blendv'].format(result, self._print(true_expr), self._print(condition))
        return result