cbackend.py 17.8 KB
Newer Older
Martin Bauer's avatar
Martin Bauer committed
1
import sympy as sp
Martin Bauer's avatar
Martin Bauer committed
2
3
from collections import namedtuple
from sympy.core import S
4
from typing import Set
5
from sympy.printing.ccode import C89CodePrinter
Martin Bauer's avatar
Martin Bauer committed
6
7
try:
    from sympy.printing.ccode import C99CodePrinter as CCodePrinter
Martin Bauer's avatar
Martin Bauer committed
8
9
except ImportError:
    from sympy.printing.ccode import CCodePrinter  # for sympy versions < 1.1
Martin Bauer's avatar
Martin Bauer committed
10

Martin Bauer's avatar
Martin Bauer committed
11
from pystencils.integer_functions import bitwise_xor, bit_shift_right, bit_shift_left, bitwise_and, \
Martin Bauer's avatar
Martin Bauer committed
12
    bitwise_or, modulo_ceil
13
from pystencils.astnodes import Node, KernelFunction
Martin Bauer's avatar
Martin Bauer committed
14
15
from pystencils.data_types import create_type, PointerType, get_type_of_expression, VectorType, cast_func, \
    vector_memory_access
16

Martin Bauer's avatar
Martin Bauer committed
17
__all__ = [
    'generate_c',
    'CustomCppCode',
    'PrintNode',
    'get_headers',
    'CustomSympyPrinter',
]


def generate_c(ast_node: Node, signature_only: bool = False) -> str:
    """Prints an abstract syntax tree node as C or CUDA code.

    This function does not need to distinguish between C, C++ or CUDA code, it just prints 'C-like' code as encoded
    in the abstract syntax tree (AST). The AST is built differently for C or CUDA by calling different create_kernel
    functions.

    Args:
        ast_node: root of the (sub-)tree to print. If the node has an ``instruction_set`` attribute (as
                  ``KernelFunction`` nodes do) and it is not None, expressions are printed with the vectorized
                  printer for that instruction set. Nodes without this attribute are printed with the scalar
                  printer instead of raising ``AttributeError``.
        signature_only: if True, a ``KernelFunction`` node is printed as its declaration only, without the body

    Returns:
        C-like code for the ast node and its descendants
    """
    # Only KernelFunction nodes are guaranteed to carry an instruction set; other nodes print as scalar code.
    instruction_set = getattr(ast_node, 'instruction_set', None)
    printer = CBackend(signature_only=signature_only,
                       vector_instruction_set=instruction_set)
    return printer(ast_node)


def get_headers(ast_node: Node) -> Set[str]:
    """Collect the header files needed to compile the C-like code printed for ``ast_node``.

    The result is the union of:
      * the ``'headers'`` entry of the instruction set, if ``ast_node`` is a ``KernelFunction`` with an instruction set
      * the node's own ``headers`` attribute, if it has one
      * recursively, the headers of every child in ``ast_node.args`` that is a ``Node``
    """
    required = set()

    has_instruction_set = isinstance(ast_node, KernelFunction) and ast_node.instruction_set
    if has_instruction_set:
        required.update(ast_node.instruction_set['headers'])

    required.update(getattr(ast_node, 'headers', ()))

    for child in ast_node.args:
        if not isinstance(child, Node):
            continue
        required |= get_headers(child)

    return required


# --------------------------------------- Backend Specific Nodes -------------------------------------------------------


class CustomCppCode(Node):
    """AST node holding a snippet of hand-written C/C++ code that is emitted verbatim.

    Args:
        code: source text of the snippet; a newline is prepended before it is stored
        symbols_read: symbols the snippet reads
        symbols_defined: symbols the snippet defines
        parent: optional parent node, forwarded to ``Node.__init__``
    """

    def __init__(self, code, symbols_read, symbols_defined, parent=None):
        super(CustomCppCode, self).__init__(parent=parent)
        self._code = "\n" + code
        self._symbolsRead = set(symbols_read)
        self._symbolsDefined = set(symbols_defined)
        # Header files required by the snippet (e.g. "<iostream>"); collected by get_headers()
        self.headers = []

    @property
    def code(self):
        """The stored code snippet, including the newline prepended in ``__init__``."""
        return self._code

    @property
    def args(self):
        """A custom code node has no child nodes."""
        return []

    @property
    def symbols_defined(self):
        """Set of symbols defined by the snippet."""
        return self._symbolsDefined

    @property
    def undefined_symbols(self):
        """Symbols the snippet reads but does not define itself, i.e. that the surrounding code must provide."""
        return self._symbolsRead - self._symbolsDefined


class PrintNode(CustomCppCode):
    """Custom C++ code node that writes the name and value of one symbol to ``std::cout``.

    The node reads ``symbol_to_print``, defines no symbols and adds ``<iostream>`` to its headers.
    """

    # noinspection SpellCheckingInspection
    def __init__(self, symbol_to_print):
        symbol_name = symbol_to_print.name
        snippet = '\nstd::cout << "%s  =  " << %s << std::endl; \n' % (symbol_name, symbol_name)
        super(PrintNode, self).__init__(snippet, symbols_read=[symbol_to_print], symbols_defined=set())
        self.headers.append("<iostream>")


# ------------------------------------------- Printer ------------------------------------------------------------------

# noinspection PyPep8Naming
class CBackend:
    """Printer that turns a pystencils AST into C-like source code.

    For every node, ``_print`` walks the MRO of the node's type and calls the first ``_print_<ClassName>`` method
    it finds, so subclasses of a supported node type are printed like their base class. Sympy expressions stored
    in the nodes are printed with ``self.sympy_printer``.
    """

    # Alignment (bytes) for temporary memory. Allocation and free must use the same value, because the printed
    # pointer is shifted by ``node.offset(align)`` after allocation and shifted back before ``free``.
    _TEMPORARY_MEMORY_ALIGNMENT = 64

    def __init__(self, sympy_printer=None,
                 signature_only=False, vector_instruction_set=None):
        """Create a backend.

        Args:
            sympy_printer: printer for sympy expressions. If None, a ``VectorizedCustomSympyPrinter`` is created
                           when ``vector_instruction_set`` is given, otherwise a ``CustomSympyPrinter``.
            signature_only: if True, kernel functions are printed as their declaration only, without the body
            vector_instruction_set: mapping from operation names (e.g. 'storeU', 'storeA', 'stream') to format
                                    strings, or None for scalar code
        """
        if sympy_printer is None:
            if vector_instruction_set is not None:
                self.sympy_printer = VectorizedCustomSympyPrinter(vector_instruction_set)
            else:
                self.sympy_printer = CustomSympyPrinter()
        else:
            self.sympy_printer = sympy_printer

        self._vectorInstructionSet = vector_instruction_set
        self._indent = "   "
        self._signatureOnly = signature_only

    def __call__(self, node):
        """Return the code for ``node`` as a string.

        ``VectorType.instruction_set`` is set to this backend's instruction set while printing and restored to its
        previous value afterwards, also when printing raises.
        """
        prev_is = VectorType.instruction_set
        VectorType.instruction_set = self._vectorInstructionSet
        try:
            return str(self._print(node))
        finally:
            VectorType.instruction_set = prev_is

    def _print(self, node):
        """Dispatch to the ``_print_<ClassName>`` method of the most derived class of ``node`` that has one.

        Raises:
            NotImplementedError: if no class in the MRO of ``node`` has a matching print method
        """
        for cls in type(node).__mro__:
            method_name = "_print_" + cls.__name__
            if hasattr(self, method_name):
                return getattr(self, method_name)(node)
        raise NotImplementedError("CBackend does not support node of type " + str(type(node)))

    def _print_KernelFunction(self, node):
        """Print the kernel declaration and, unless ``signature_only`` was requested, its body.

        ``FUNC_PREFIX`` is emitted as a literal token; presumably it is defined as a macro by the code that
        compiles the kernel (confirm with the compilation setup).
        """
        function_arguments = ["%s %s" % (str(s.symbol.dtype), s.symbol.name) for s in node.get_parameters()]
        func_declaration = "FUNC_PREFIX void %s(%s)" % (node.function_name, ", ".join(function_arguments))
        if self._signatureOnly:
            return func_declaration

        body = self._print(node.body)
        return func_declaration + "\n" + body

    def _print_Block(self, node):
        """Print the children of ``node`` inside braces, indenting every line of their code by one level."""
        block_contents = "\n".join([self._print(child) for child in node.args])
        return "{\n%s\n}" % (self._indent + self._indent.join(block_contents.splitlines(True)))

    def _print_PragmaBlock(self, node):
        """Print the pragma line followed by the block."""
        return "%s\n%s" % (node.pragma_line, self._print_Block(node))

    def _print_LoopOverCoordinate(self, node):
        """Print a counting ``for`` loop (``int`` counter), preceded by the node's prefix lines, if any."""
        counter_symbol = node.loop_counter_name
        start = "int %s = %s" % (counter_symbol, self.sympy_printer.doprint(node.start))
        condition = "%s < %s" % (counter_symbol, self.sympy_printer.doprint(node.stop))
        update = "%s += %s" % (counter_symbol, self.sympy_printer.doprint(node.step),)
        loop_str = "for (%s; %s; %s)" % (start, condition, update)

        prefix = "\n".join(node.prefix_lines)
        if prefix:
            prefix += "\n"
        return "%s%s\n%s" % (prefix, loop_str, self._print(node.body))

    def _print_SympyAssignment(self, node):
        """Print a declaration, a vector store or a plain assignment."""
        if node.is_declaration:
            data_type = "const " + str(node.lhs.dtype) + " " if node.is_const else str(node.lhs.dtype) + " "
            return "%s%s = %s;" % (data_type, self.sympy_printer.doprint(node.lhs),
                                   self.sympy_printer.doprint(node.rhs))
        else:
            lhs_type = get_type_of_expression(node.lhs)
            if type(lhs_type) is VectorType and isinstance(node.lhs, cast_func):
                # Vector store: unaligned -> 'storeU', aligned -> 'storeA', aligned and non-temporal -> 'stream'
                arg, data_type, aligned, nontemporal = node.lhs.args
                instr = 'storeU'
                if aligned:
                    instr = 'stream' if nontemporal else 'storeA'

                # A scalar right-hand side is wrapped in a cast to the vector type before it is stored
                rhs_type = get_type_of_expression(node.rhs)
                if type(rhs_type) is not VectorType:
                    rhs = cast_func(node.rhs, VectorType(rhs_type))
                else:
                    rhs = node.rhs

                return self._vectorInstructionSet[instr].format("&" + self.sympy_printer.doprint(arg),
                                                                self.sympy_printer.doprint(rhs)) + ';'
            else:
                return "%s = %s;" % (self.sympy_printer.doprint(node.lhs), self.sympy_printer.doprint(node.rhs))

    def _print_TemporaryMemoryAllocation(self, node):
        """Print an ``aligned_alloc`` whose result pointer is shifted by ``node.offset(align)``."""
        align = self._TEMPORARY_MEMORY_ALIGNMENT
        np_dtype = node.symbol.dtype.base_type.numpy_dtype
        # One extra alignment unit is reserved, presumably to make room for the pointer shift by the offset
        required_size = np_dtype.itemsize * node.size + align
        size = modulo_ceil(required_size, align)
        code = "{dtype} {name}=({dtype})aligned_alloc({align}, {size}) + {offset};"
        return code.format(dtype=node.symbol.dtype,
                           name=self.sympy_printer.doprint(node.symbol.name),
                           size=self.sympy_printer.doprint(size),
                           offset=int(node.offset(align)),
                           align=align)

    def _print_TemporaryMemoryFree(self, node):
        """Print ``free`` on the pointer shifted back by the same offset used at allocation."""
        align = self._TEMPORARY_MEMORY_ALIGNMENT
        return "free(%s - %d);" % (self.sympy_printer.doprint(node.symbol.name), int(node.offset(align)))

    @staticmethod
    def _print_CustomCppCode(node):
        """Custom code is emitted verbatim."""
        return node.code

    def _print_Conditional(self, node):
        """Print an ``if`` statement with an optional ``else`` block."""
        condition_expr = self.sympy_printer.doprint(node.condition_expr)
        true_block = self._print_Block(node.true_block)
        result = "if (%s)\n%s " % (condition_expr, true_block)
        if node.false_block:
            false_block = self._print_Block(node.false_block)
            result += "else " + false_block
        return result


# ------------------------------------------ Helper function & classes -------------------------------------------------


# noinspection PyPep8Naming
class CustomSympyPrinter(CCodePrinter):
    """C code printer for sympy expressions with the extensions pystencils needs.

    Differences to the sympy base printer:
      * small integer powers are written as products (or ``1 / (product)``) instead of a ``pow`` call
      * rationals are printed as evaluated floating point numbers
      * ``Equality``, ``cast_func``, the bitwise/shift functions of ``pystencils.integer_functions`` and functions
        that provide a ``to_c`` method are supported
      * piecewise expressions are printed on a single line
      * ``Min``/``Max`` are printed with the implementation taken from ``C89CodePrinter``
    """

    def __init__(self):
        super(CustomSympyPrinter, self).__init__()
        self._float_type = create_type("float32")
        # Remove Min/Max from the known-function table; presumably so that the _print_Min/_print_Max copied from
        # C89CodePrinter (end of this class) are used instead of a function-call mapping - verify per sympy version.
        for function_name in ('Min', 'Max'):
            self.known_functions.pop(function_name, None)

    def _print_Pow(self, expr):
        """Don't use std::pow function, for small integer exponents, write as multiplication"""
        exponent = expr.exp
        integer_number = exponent.is_integer and exponent.is_number
        if integer_number and 0 < exponent < 8:
            product = sp.Mul(*[expr.base] * exponent, evaluate=False)
            return "(" + self._print(product) + ")"
        if integer_number and - 8 < exponent < 0:
            product = sp.Mul(*[expr.base] * (-exponent), evaluate=False)
            return "1 / ({})".format(self._print(product))
        return super(CustomSympyPrinter, self)._print_Pow(expr)

    def _print_Rational(self, expr):
        """Evaluate all rationals i.e. print 0.25 instead of 1.0/4.0"""
        return str(expr.evalf().num)

    def _print_Equality(self, expr):
        """Equality operator is not printable in default printer"""
        lhs_code = self._print(expr.lhs)
        rhs_code = self._print(expr.rhs)
        return '((' + lhs_code + ") == (" + rhs_code + '))'

    def _print_Piecewise(self, expr):
        """Print piecewise in one line (remove newlines)"""
        multi_line = super(CustomSympyPrinter, self)._print_Piecewise(expr)
        return multi_line.replace("\n", "")

    def _print_Function(self, expr):
        """Print functions with their own C code, casts and infix integer operators; defer the rest to sympy."""
        if hasattr(expr, 'to_c'):
            return expr.to_c(self._print)

        if isinstance(expr, cast_func):
            arg, data_type = expr.args
            if isinstance(arg, sp.Number):
                return self._typed_number(arg, data_type)
            # Non-numeric arguments: reinterpret the argument's memory as data_type via a pointer cast
            return "*((%s)(& %s))" % (PointerType(data_type, restrict=False), self._print(arg))

        infix_functions = {
            bitwise_xor: '^',
            bit_shift_right: '>>',
            bit_shift_left: '<<',
            bitwise_or: '|',
            bitwise_and: '&',
        }
        infix_operator = infix_functions.get(expr.func)
        if infix_operator is not None:
            return "(%s %s %s)" % (self._print(expr.args[0]), infix_operator, self._print(expr.args[1]))

        return super(CustomSympyPrinter, self)._print_Function(expr)

    def _typed_number(self, number, dtype):
        """Print ``number``; for float32 an ``f`` suffix is appended (``.0f`` if the literal has no decimal point)."""
        res = self._print(number)
        if dtype.is_float() and dtype == self._float_type:
            return res + ("f" if '.' in res else ".0f")
        return res

    _print_Max = C89CodePrinter._print_Max
    _print_Min = C89CodePrinter._print_Min


# noinspection PyPep8Naming
class VectorizedCustomSympyPrinter(CustomSympyPrinter):
    """Sympy printer that emits vector instructions for expressions of ``VectorType``.

    Operations are looked up as format strings in ``instruction_set``. The keys used here are 'width', 'makeVec',
    'loadA', 'loadU', '&', '|', '+', '-', '*', '/', 'sqrt', '==', the relational operators and 'blendv'.
    Expressions whose type is not a ``VectorType`` are printed by the scalar ``CustomSympyPrinter`` methods
    (see ``_scalarFallback``).
    """
    SummandInfo = namedtuple("SummandInfo", ['sign', 'term'])

    def __init__(self, instruction_set):
        super(VectorizedCustomSympyPrinter, self).__init__()
        self.instruction_set = instruction_set

    def _scalarFallback(self, func_name, expr, *args, **kwargs):
        """Print ``expr`` with the scalar base-class method ``func_name`` unless it has vector type.

        Returns:
            the scalar code, or None if ``expr`` is of ``VectorType`` and must be printed with vector instructions
        """
        expr_type = get_type_of_expression(expr)
        if type(expr_type) is not VectorType:
            return getattr(super(VectorizedCustomSympyPrinter, self), func_name)(expr, *args, **kwargs)
        else:
            assert self.instruction_set['width'] == expr_type.width
            return None

    def _make_vec(self, value):
        """Apply the 'makeVec' instruction to ``value``, turning a scalar constant into a vector operand."""
        return self.instruction_set['makeVec'].format(value)

    def _fold(self, instruction_key, operands):
        """Combine already printed ``operands`` left to right with the binary instruction ``instruction_key``."""
        assert len(operands) > 0
        result = operands[0]
        for item in operands[1:]:
            result = self.instruction_set[instruction_key].format(result, item)
        return result

    def _print_Function(self, expr):
        """Print vector loads and scalar-to-vector casts; defer everything else to the scalar printer."""
        if isinstance(expr, vector_memory_access):
            arg, data_type, aligned, _ = expr.args
            instruction = self.instruction_set['loadA'] if aligned else self.instruction_set['loadU']
            return instruction.format("& " + self._print(arg))
        elif isinstance(expr, cast_func):
            arg, data_type = expr.args
            if type(data_type) is VectorType:
                return self._make_vec(self._print(arg))

        return super(VectorizedCustomSympyPrinter, self)._print_Function(expr)

    def _print_And(self, expr):
        result = self._scalarFallback('_print_And', expr)
        if result is not None:
            return result
        return self._fold('&', [self._print(a) for a in expr.args])

    def _print_Or(self, expr):
        result = self._scalarFallback('_print_Or', expr)
        if result is not None:
            return result
        return self._fold('|', [self._print(a) for a in expr.args])

    def _print_Add(self, expr, order=None):
        result = self._scalarFallback('_print_Add', expr)
        if result is not None:
            return result

        summands = []
        for term in expr.args:
            if term.func == sp.Mul:
                sign, t = self._print_Mul(term, inside_add=True)
            else:
                t = self._print(term)
                sign = 1
            summands.append(self.SummandInfo(sign, t))
        # Use positive terms first
        summands.sort(key=lambda e: e.sign, reverse=True)
        # If no positive term exists, start from a zero vector. A bare scalar "0" is presumably not a valid operand
        # of the vector '-' instruction; 'makeVec' is used for constants just like in _print_Pow.
        if summands[0].sign == -1:
            summands.insert(0, self.SummandInfo(1, self._make_vec(0.0)))

        assert len(summands) >= 2
        processed = summands[0].term
        for summand in summands[1:]:
            func = self.instruction_set['-'] if summand.sign == -1 else self.instruction_set['+']
            processed = func.format(processed, summand.term)
        return processed

    def _print_Pow(self, expr):
        """Print small integer powers as products, x**-1 as a division, x**0.5 as 'sqrt'.

        Raises:
            ValueError: for any other exponent
        """
        result = self._scalarFallback('_print_Pow', expr)
        if result is not None:
            return result

        if expr.exp.is_integer and expr.exp.is_number and 0 < expr.exp < 8:
            return "(" + self._print(sp.Mul(*[expr.base] * expr.exp, evaluate=False)) + ")"
        elif expr.exp == -1:
            one = self._make_vec(1.0)
            return self.instruction_set['/'].format(one, self._print(expr.base))
        elif expr.exp == 0.5:
            return self.instruction_set['sqrt'].format(self._print(expr.base))
        elif expr.exp.is_integer and expr.exp.is_number and - 8 < expr.exp < 0:
            one = self._make_vec(1.0)
            return self.instruction_set['/'].format(one,
                                                    self._print(sp.Mul(*[expr.base] * (-expr.exp), evaluate=False)))
        else:
            raise ValueError("Generic exponential not supported: " + str(expr))

    def _print_Mul(self, expr, inside_add=False):
        """Print a product as vector multiplications, with one vector division for factors with negative exponent.

        Returns:
            the code string, or a ``(sign, code)`` tuple if ``inside_add`` is True. In that case a negative
            coefficient is reported through ``sign`` instead of being multiplied in.
        """
        # noinspection PyProtectedMember
        from sympy.core.mul import _keep_coeff

        result = self._scalarFallback('_print_Mul', expr)
        if result is not None:
            # _print_Add unpacks (sign, term) when inside_add is set, so keep that contract for scalar terms as well
            return (1, result) if inside_add else result

        c, e = expr.as_coeff_Mul()
        if c < 0:
            expr = _keep_coeff(-c, e)
            sign = -1
        else:
            sign = 1

        a = []  # items in the numerator
        b = []  # items that are in the denominator (if any)

        # Gather args for numerator/denominator
        for item in expr.as_ordered_factors():
            if item.is_commutative and item.is_Pow and item.exp.is_Rational and item.exp.is_negative:
                if item.exp != -1:
                    b.append(sp.Pow(item.base, -item.exp, evaluate=False))
                else:
                    b.append(sp.Pow(item.base, -item.exp))
            else:
                a.append(item)

        a = a or [S.One]

        a_str = [self._print(x) for x in a]
        b_str = [self._print(x) for x in b]

        result = self._fold('*', a_str)
        if len(b) > 0:
            result = self.instruction_set['/'].format(result, self._fold('*', b_str))

        if inside_add:
            return sign, result
        if sign < 0:
            # Multiply by a vector of -1 (a bare scalar -1 is presumably not a valid vector '*' operand)
            return self.instruction_set['*'].format(self._make_vec(-1.0), result)
        return result

    def _print_Relational(self, expr):
        result = self._scalarFallback('_print_Relational', expr)
        if result is not None:
            return result
        return self.instruction_set[expr.rel_op].format(self._print(expr.lhs), self._print(expr.rhs))

    def _print_Equality(self, expr):
        result = self._scalarFallback('_print_Equality', expr)
        if result is not None:
            return result
        return self.instruction_set['=='].format(self._print(expr.lhs), self._print(expr.rhs))

    def _print_Piecewise(self, expr):
        """Print a piecewise expression as nested 'blendv' instructions, starting from the default branch.

        Raises:
            ValueError: if the last branch is not a default (True) branch
        """
        result = self._scalarFallback('_print_Piecewise', expr)
        if result is not None:
            return result

        true_value = sp.sympify(True)
        last_condition = expr.args[-1].cond
        # The default condition is either True itself or an expression whose first argument is True
        # (presumably a cast of True to a vector type); a bare True has no args and previously raised IndexError.
        has_default = last_condition is true_value or \
            (len(last_condition.args) > 0 and last_condition.args[0] is true_value)
        if not has_default:
            # We need the last conditional to be a True, otherwise the resulting
            # function may not return a result.
            raise ValueError("All Piecewise expressions must contain an "
                             "(expr, True) statement to be used as a default "
                             "condition. Without one, the generated "
                             "expression may not evaluate to anything under "
                             "some condition.")

        result = self._print(expr.args[-1][0])
        for true_expr, condition in reversed(expr.args[:-1]):
            # noinspection SpellCheckingInspection
            result = self.instruction_set['blendv'].format(result, self._print(true_expr), self._print(condition))
        return result