cbackend.py 17.4 KB
Newer Older
Martin Bauer's avatar
Martin Bauer committed
1
import sympy as sp
Martin Bauer's avatar
Martin Bauer committed
2
3
from collections import namedtuple
from sympy.core import S
4
from typing import Set
5
from sympy.printing.ccode import C89CodePrinter
Martin Bauer's avatar
Martin Bauer committed
6
7
try:
    from sympy.printing.ccode import C99CodePrinter as CCodePrinter
Martin Bauer's avatar
Martin Bauer committed
8
9
except ImportError:
    from sympy.printing.ccode import CCodePrinter  # for sympy versions < 1.1
Martin Bauer's avatar
Martin Bauer committed
10

Martin Bauer's avatar
Martin Bauer committed
11
from pystencils.integer_functions import bitwise_xor, bit_shift_right, bit_shift_left, bitwise_and, \
Martin Bauer's avatar
Martin Bauer committed
12
    bitwise_or, modulo_floor, modulo_ceil
13
from pystencils.astnodes import Node, KernelFunction
Martin Bauer's avatar
Martin Bauer committed
14
15
from pystencils.data_types import create_type, PointerType, get_type_of_expression, VectorType, cast_func, \
    vector_memory_access
16

Martin Bauer's avatar
Martin Bauer committed
17
__all__ = ['generate_c', 'CustomCppCode', 'PrintNode', 'get_headers', 'CustomSympyPrinter']
18

Martin Bauer's avatar
Martin Bauer committed
19

20
def generate_c(ast_node: Node, signature_only: bool = False) -> str:
    """Prints an abstract syntax tree node as C or CUDA code.

    This function does not need to distinguish between C, C++ or CUDA code, it just prints 'C-like' code as encoded
    in the abstract syntax tree (AST). The AST is built differently for C or CUDA by calling different create_kernel
    functions.

    Args:
        ast_node: root of the (sub-)tree to print. If the node has an ``instruction_set`` attribute, its value is
                  handed to the backend as vector instruction set; a node without that attribute is printed
                  without a vector instruction set.
        signature_only: forwarded to :class:`CBackend`; if True, a kernel function is printed as its declaration
                        only, without the body.

    Returns:
        C-like code for the ast node and its descendants
    """
    # Not every node carries an instruction set (get_headers only reads it from KernelFunction nodes),
    # so fall back to None instead of failing with an AttributeError on sub-trees.
    instruction_set = getattr(ast_node, 'instruction_set', None)
    printer = CBackend(signature_only=signature_only,
                       vector_instruction_set=instruction_set)
    return printer(ast_node)
37
38


Martin Bauer's avatar
Martin Bauer committed
39
40
def get_headers(ast_node: Node) -> Set[str]:
    """Collect the header files needed to compile the C-like code printed for ``ast_node``.

    Three sources are merged into one set:

    * the ``'headers'`` entry of the instruction set, if ``ast_node`` is a ``KernelFunction`` with a
      truthy ``instruction_set``
    * the node's own ``headers`` attribute, if it has one
    * the headers of every child in ``ast_node.args`` that is itself a ``Node`` (collected recursively)

    Args:
        ast_node: root of the (sub-)tree to inspect

    Returns:
        set of header names
    """
    required = set()

    if isinstance(ast_node, KernelFunction) and ast_node.instruction_set:
        required.update(ast_node.instruction_set['headers'])

    if hasattr(ast_node, 'headers'):
        required.update(ast_node.headers)

    # Descend only into real AST nodes; other args contribute no headers.
    child_nodes = (child for child in ast_node.args if isinstance(child, Node))
    required.update(*map(get_headers, child_nodes))

    return required
53
54


55
56
57
58
# --------------------------------------- Backend Specific Nodes -------------------------------------------------------


class CustomCppCode(Node):
    """AST node that carries a verbatim snippet of C/C++ code.

    The backend prints the stored code unchanged (see ``CBackend._print_CustomCppCode``). Because the snippet
    is opaque, the symbols it reads and the symbols it defines have to be passed in explicitly.
    """

    def __init__(self, code, symbols_read, symbols_defined, parent=None):
        """
        Args:
            code: the C/C++ code as string; a newline is prepended before it is stored
            symbols_read: iterable of symbols the snippet reads
            symbols_defined: iterable of symbols the snippet defines
            parent: parent node, forwarded to the ``Node`` constructor
        """
        super(CustomCppCode, self).__init__(parent=parent)
        self._code = "\n" + code
        self._symbolsRead = set(symbols_read)
        self._symbolsDefined = set(symbols_defined)
        # headers needed by the snippet; picked up by get_headers() via the 'headers' attribute
        self.headers = []

    @property
    def code(self):
        """The stored code snippet (with the leading newline added in the constructor)."""
        return self._code

    @property
    def args(self):
        """A custom code node has no child nodes."""
        return []

    @property
    def symbols_defined(self):
        """Set of symbols defined by the snippet."""
        return self._symbolsDefined

    @property
    def undefined_symbols(self):
        """Set of symbols that the snippet reads but does not define itself."""
        # Fixed: the operands used to be swapped (defined - read), which e.g. reported no undefined
        # symbols for a PrintNode although it reads the printed symbol and defines nothing.
        return self._symbolsRead - self._symbolsDefined
81
82
83


class PrintNode(CustomCppCode):
    """Custom code node that writes the name and value of a symbol to ``std::cout``."""

    # noinspection SpellCheckingInspection
    def __init__(self, symbol_to_print):
        """
        Args:
            symbol_to_print: symbol whose ``name`` is used both as label and as printed expression
        """
        name = symbol_to_print.name
        statement = '\nstd::cout << "%s  =  " << %s << std::endl; \n' % (name, name)
        super().__init__(statement, symbols_read=[symbol_to_print], symbols_defined=set())
        # std::cout / std::endl are used by the generated statement
        self.headers.append("<iostream>")
89
90
91
92


# ------------------------------------------- Printer ------------------------------------------------------------------

93

Martin Bauer's avatar
Martin Bauer committed
94
95
# noinspection PyPep8Naming
class CBackend:
    """Printer that turns a pystencils AST into C-like source code.

    AST nodes are dispatched to ``_print_<ClassName>`` methods (see :meth:`_print`); sympy expressions
    inside the nodes are delegated to ``self.sympy_printer``.
    """

    def __init__(self, sympy_printer=None,
                 signature_only=False, vector_instruction_set=None):
        """
        Args:
            sympy_printer: printer used for sympy expressions. If None, a ``VectorizedCustomSympyPrinter`` is
                           created when an instruction set is given, otherwise a ``CustomSympyPrinter``.
            signature_only: if True, kernel functions are printed as declaration only (no body)
            vector_instruction_set: mapping with the vector instruction templates, or None for scalar code
        """
        if sympy_printer is None:
            if vector_instruction_set is not None:
                self.sympy_printer = VectorizedCustomSympyPrinter(vector_instruction_set)
            else:
                self.sympy_printer = CustomSympyPrinter()
        else:
            self.sympy_printer = sympy_printer

        self._vectorInstructionSet = vector_instruction_set
        self._indent = "   "
        self._signatureOnly = signature_only

    def __call__(self, node):
        """Print ``node`` and return the generated code as string.

        While printing, ``VectorType.instruction_set`` is set to this backend's instruction set; the previous
        value is restored afterwards.
        """
        prev_is = VectorType.instruction_set
        VectorType.instruction_set = self._vectorInstructionSet
        try:
            return str(self._print(node))
        finally:
            # VectorType.instruction_set is class-level state: restore it even if printing raises,
            # otherwise a failed call would leak this backend's instruction set into later calls.
            VectorType.instruction_set = prev_is

    def _print(self, node):
        """Dispatch ``node`` to the first ``_print_<ClassName>`` method found along its MRO.

        Raises:
            NotImplementedError: if no class in the MRO of ``node`` has a matching print method
        """
        for cls in type(node).__mro__:
            method_name = "_print_" + cls.__name__
            if hasattr(self, method_name):
                return getattr(self, method_name)(node)

        raise NotImplementedError("CBackend does not support node of type " + str(type(node)))

    def _print_KernelFunction(self, node):
        """Print the function declaration and, unless only the signature is requested, the body."""
        function_arguments = ["%s %s" % (str(s.dtype), s.name) for s in node.parameters]
        func_declaration = "FUNC_PREFIX void %s(%s)" % (node.function_name, ", ".join(function_arguments))
        if self._signatureOnly:
            return func_declaration

        body = self._print(node.body)
        return func_declaration + "\n" + body

    def _print_Block(self, node):
        """Print all children of ``node`` inside curly braces, each line indented by one level."""
        block_contents = "\n".join([self._print(child) for child in node.args])
        # splitlines(True) keeps the line endings, so joining with the indent string prefixes every line
        # after the first; the first line gets its indent from the explicit prefix.
        return "{\n%s\n}" % (self._indent + self._indent.join(block_contents.splitlines(True)))

    def _print_PragmaBlock(self, node):
        """Print the pragma line of ``node`` followed by its contents as a regular block."""
        return "%s\n%s" % (node.pragma_line, self._print_Block(node))

    def _print_LoopOverCoordinate(self, node):
        """Print a ``for`` loop with an ``int`` counter, preceded by the node's prefix lines (if any)."""
        counter_symbol = node.loop_counter_name
        start = "int %s = %s" % (counter_symbol, self.sympy_printer.doprint(node.start))
        condition = "%s < %s" % (counter_symbol, self.sympy_printer.doprint(node.stop))
        update = "%s += %s" % (counter_symbol, self.sympy_printer.doprint(node.step),)
        loop_str = "for (%s; %s; %s)" % (start, condition, update)

        prefix = "\n".join(node.prefix_lines)
        if prefix:
            prefix += "\n"
        return "%s%s\n%s" % (prefix, loop_str, self._print(node.body))

    def _print_SympyAssignment(self, node):
        """Print an assignment as declaration, as vector store instruction or as plain assignment."""
        if node.is_declaration:
            data_type = "const " + str(node.lhs.dtype) + " " if node.is_const else str(node.lhs.dtype) + " "
            return "%s%s = %s;" % (data_type, self.sympy_printer.doprint(node.lhs),
                                   self.sympy_printer.doprint(node.rhs))
        else:
            lhs_type = get_type_of_expression(node.lhs)
            if type(lhs_type) is VectorType and isinstance(node.lhs, cast_func):
                # Writing to a vector-typed memory location needs a store instruction instead of '='.
                arg, _, aligned, nontemporal = node.lhs.args
                # unaligned store by default; for aligned addresses either a streaming or an aligned store
                instr = 'storeU'
                if aligned:
                    instr = 'stream' if nontemporal else 'storeA'

                return self._vectorInstructionSet[instr].format("&" + self.sympy_printer.doprint(arg),
                                                                self.sympy_printer.doprint(node.rhs)) + ';'
            else:
                return "%s = %s;" % (self.sympy_printer.doprint(node.lhs), self.sympy_printer.doprint(node.rhs))

    def _print_TemporaryMemoryAllocation(self, node):
        """Print an ``aligned_alloc`` call for a temporary buffer, shifting the pointer by the node's offset."""
        align = 64
        np_dtype = node.symbol.dtype.base_type.numpy_dtype
        # reserve 'align' extra bytes on top of the payload
        required_size = np_dtype.itemsize * node.size + align
        # NOTE(review): modulo_ceil presumably rounds up to a multiple of 'align' - confirm in integer_functions
        size = modulo_ceil(required_size, align)
        code = "{dtype} {name}=({dtype})aligned_alloc({align}, {size}) + {offset};"
        return code.format(dtype=node.symbol.dtype,
                           name=self.sympy_printer.doprint(node.symbol.name),
                           size=self.sympy_printer.doprint(size),
                           offset=int(node.offset(align)),
                           align=align)

    def _print_TemporaryMemoryFree(self, node):
        """Print the ``free`` call for a temporary buffer, undoing the offset applied at allocation."""
        # must use the same alignment as _print_TemporaryMemoryAllocation to compute the same offset
        align = 64
        return "free(%s - %d);" % (self.sympy_printer.doprint(node.symbol.name), node.offset(align))

    @staticmethod
    def _print_CustomCppCode(node):
        """Custom code nodes are printed verbatim."""
        return node.code

    def _print_Conditional(self, node):
        """Print an ``if`` statement, with an ``else`` branch if the node has a false block."""
        condition_expr = self.sympy_printer.doprint(node.condition_expr)
        true_block = self._print_Block(node.true_block)
        result = "if (%s)\n%s " % (condition_expr, true_block)
        if node.false_block:
            false_block = self._print_Block(node.false_block)
            result += "else " + false_block
        return result

201
202
203
204

# ------------------------------------------ Helper function & classes -------------------------------------------------


Martin Bauer's avatar
Martin Bauer committed
205
# noinspection PyPep8Naming
206
class CustomSympyPrinter(CCodePrinter):
    """sympy C code printer with adaptations for the expressions occurring in pystencils kernels."""

    def __init__(self):
        super().__init__()
        # numbers cast to this type are printed with an 'f' suffix, see _typed_number
        self._float_type = create_type("float32")

    def _print_Pow(self, expr):
        """Don't use std::pow function, for small integer exponents, write as multiplication"""
        exponent = expr.exp
        # The range comparisons are only evaluated for numeric integer exponents.
        if exponent.is_integer and exponent.is_number:
            if 0 < exponent < 8:
                product = sp.Mul(*[expr.base] * exponent, evaluate=False)
                return "(" + self._print(product) + ")"
            if -8 < exponent < 0:
                product = sp.Mul(*[expr.base] * (-exponent), evaluate=False)
                return "1 / ({})".format(self._print(product))
        return super()._print_Pow(expr)

    def _print_Rational(self, expr):
        """Evaluate all rationals i.e. print 0.25 instead of 1.0/4.0"""
        return str(expr.evalf().num)

    def _print_Equality(self, expr):
        """Equality operator is not printable in default printer"""
        left = self._print(expr.lhs)
        right = self._print(expr.rhs)
        return "((%s) == (%s))" % (left, right)

    def _print_Piecewise(self, expr):
        """Print piecewise in one line (remove newlines)"""
        multi_line = super()._print_Piecewise(expr)
        return multi_line.replace("\n", "")

    def _print_Function(self, expr):
        """Print functions with special handling for ``to_c`` providers, casts and integer infix functions.

        Order of precedence:

        1. expressions with a ``to_c`` attribute print themselves
        2. ``cast_func``: typed literal for numbers, pointer reinterpretation for everything else
        3. bitwise / shift functions are printed as infix operators
        4. everything else is handled by the parent printer
        """
        if hasattr(expr, 'to_c'):
            return expr.to_c(self._print)

        func = expr.func

        if func == cast_func:
            arg, data_type = expr.args
            if isinstance(arg, sp.Number):
                return self._typed_number(arg, data_type)
            return "*((%s)(& %s))" % (PointerType(data_type), self._print(arg))

        infix_operators = {
            bitwise_xor: '^',
            bit_shift_right: '>>',
            bit_shift_left: '<<',
            bitwise_or: '|',
            bitwise_and: '&',
        }
        if func in infix_operators:
            left = self._print(expr.args[0])
            right = self._print(expr.args[1])
            return "(%s %s %s)" % (left, infix_operators[func], right)

        return super()._print_Function(expr)

    def _typed_number(self, number, dtype):
        """Print ``number`` as literal of type ``dtype``.

        Only literals of the printer's float type (float32) are changed: they get an ``f`` suffix,
        preceded by ``.0`` if the printed number has no decimal point.
        """
        literal = self._print(number)
        if dtype.is_float() and dtype == self._float_type:
            literal += "f" if '.' in literal else ".0f"
        return literal

    # reuse the C89 implementations for Max and Min
    _print_Max = C89CodePrinter._print_Max
    _print_Min = C89CodePrinter._print_Min

271

Martin Bauer's avatar
Martin Bauer committed
272
# noinspection PyPep8Naming
273
274
275
class VectorizedCustomSympyPrinter(CustomSympyPrinter):
    """sympy printer that emits vector instructions for expressions of vector type.

    Every overridden ``_print_*`` method first tries the scalar printer of the parent class
    (see :meth:`_scalarFallback`). Only expressions whose type is a ``VectorType`` are printed with the
    format templates stored in ``instruction_set`` (keys used here: ``'width'``, ``'loadA'``, ``'loadU'``,
    ``'makeVec'``, ``'sqrt'``, ``'blendv'``, the arithmetic operators ``'+'``, ``'-'``, ``'*'``, ``'/'``,
    the logical operators ``'&'``, ``'|'``, ``'=='`` and the relational operators of sympy).
    """
    # one summand of a vectorized sum: sign is 1 or -1, term the printed (unsigned) summand
    SummandInfo = namedtuple("SummandInfo", ['sign', 'term'])

    def __init__(self, instruction_set):
        """
        Args:
            instruction_set: mapping from operation name to format template of the vector instruction
        """
        super(VectorizedCustomSympyPrinter, self).__init__()
        self.instruction_set = instruction_set

    def _scalarFallback(self, func_name, expr, *args, **kwargs):
        """Print ``expr`` with the parent class method ``func_name`` if it is not of vector type.

        Returns:
            the printed scalar expression, or None if ``expr`` is of vector type and has to be
            printed by the caller
        """
        expr_type = get_type_of_expression(expr)
        if type(expr_type) is not VectorType:
            return getattr(super(VectorizedCustomSympyPrinter, self), func_name)(expr, *args, **kwargs)
        else:
            assert self.instruction_set['width'] == expr_type.width
            return None

    def _print_Function(self, expr):
        """Print vector loads and casts to vector type; everything else is handled by the parent class."""
        if expr.func == vector_memory_access:
            arg, data_type, aligned, _ = expr.args
            instruction = self.instruction_set['loadA'] if aligned else self.instruction_set['loadU']
            return instruction.format("& " + self._print(arg))
        elif expr.func == cast_func:
            arg, data_type = expr.args
            if type(data_type) is VectorType:
                # cast to vector type: build a vector from the printed argument
                return self.instruction_set['makeVec'].format(self._print(arg))

        return super(VectorizedCustomSympyPrinter, self)._print_Function(expr)

    def _print_And(self, expr):
        """Print a vector conjunction by chaining the binary ``'&'`` instruction over all arguments."""
        result = self._scalarFallback('_print_And', expr)
        if result:
            return result

        arg_strings = [self._print(a) for a in expr.args]
        assert len(arg_strings) > 0
        result = arg_strings[0]
        for item in arg_strings[1:]:
            result = self.instruction_set['&'].format(result, item)
        return result

    def _print_Or(self, expr):
        """Print a vector disjunction by chaining the binary ``'|'`` instruction over all arguments."""
        result = self._scalarFallback('_print_Or', expr)
        if result:
            return result

        arg_strings = [self._print(a) for a in expr.args]
        assert len(arg_strings) > 0
        result = arg_strings[0]
        for item in arg_strings[1:]:
            result = self.instruction_set['|'].format(result, item)
        return result

    def _print_Add(self, expr, order=None):
        """Print a vector sum as chain of binary ``'+'`` / ``'-'`` instructions."""
        result = self._scalarFallback('_print_Add', expr)
        if result:
            return result

        summands = []
        for term in expr.args:
            if term.func == sp.Mul:
                # products report their sign separately so that a subtraction can be emitted
                sign, t = self._print_Mul(term, inside_add=True)
            else:
                t = self._print(term)
                sign = 1
            summands.append(self.SummandInfo(sign, t))
        # Use positive terms first
        summands.sort(key=lambda e: e.sign, reverse=True)
        # if no positive term exists, prepend a zero
        if summands[0].sign == -1:
            summands.insert(0, self.SummandInfo(1, "0"))

        assert len(summands) >= 2
        processed = summands[0].term
        for summand in summands[1:]:
            func = self.instruction_set['-'] if summand.sign == -1 else self.instruction_set['+']
            processed = func.format(processed, summand.term)
        return processed

    def _print_Pow(self, expr):
        """Print a vector power for small integer exponents, the reciprocal and the square root.

        Raises:
            ValueError: for any other exponent
        """
        result = self._scalarFallback('_print_Pow', expr)
        if result:
            return result

        if expr.exp.is_integer and expr.exp.is_number and 0 < expr.exp < 8:
            return "(" + self._print(sp.Mul(*[expr.base] * expr.exp, evaluate=False)) + ")"
        elif expr.exp == -1:
            one = self.instruction_set['makeVec'].format(1.0)
            return self.instruction_set['/'].format(one, self._print(expr.base))
        elif expr.exp == 0.5:
            return self.instruction_set['sqrt'].format(self._print(expr.base))
        elif expr.exp.is_integer and expr.exp.is_number and - 8 < expr.exp < 0:
            one = self.instruction_set['makeVec'].format(1.0)
            return self.instruction_set['/'].format(one,
                                                    self._print(sp.Mul(*[expr.base] * (-expr.exp), evaluate=False)))
        else:
            raise ValueError("Generic exponential not supported: " + str(expr))

    def _print_Mul(self, expr, inside_add=False):
        """Print a vector product (and quotient) with the ``'*'`` and ``'/'`` instructions.

        Args:
            expr: the product to print
            inside_add: set by :meth:`_print_Add`. If True the sign of the numeric coefficient is not
                        printed but returned separately, and the scalar fallback is skipped.

        Returns:
            the printed expression, or a tuple ``(sign, printed_expression)`` with sign 1 or -1
            if ``inside_add`` is True
        """
        # noinspection PyProtectedMember
        from sympy.core.mul import _keep_coeff

        # The scalar fallback returns a plain string, but with inside_add=True the caller unpacks a
        # (sign, term) tuple. The enclosing sum is already known to be of vector type in that case,
        # so the fallback is only tried for stand-alone products.
        if not inside_add:
            result = self._scalarFallback('_print_Mul', expr)
            if result:
                return result

        # split off the sign of the numeric coefficient
        c, e = expr.as_coeff_Mul()
        if c < 0:
            expr = _keep_coeff(-c, e)
            sign = -1
        else:
            sign = 1

        a = []  # items in the numerator
        b = []  # items that are in the denominator (if any)

        # Gather args for numerator/denominator
        for item in expr.as_ordered_factors():
            if item.is_commutative and item.is_Pow and item.exp.is_Rational and item.exp.is_negative:
                if item.exp != -1:
                    b.append(sp.Pow(item.base, -item.exp, evaluate=False))
                else:
                    b.append(sp.Pow(item.base, -item.exp))
            else:
                a.append(item)

        a = a or [S.One]

        a_str = [self._print(x) for x in a]
        b_str = [self._print(x) for x in b]

        result = a_str[0]
        for item in a_str[1:]:
            result = self.instruction_set['*'].format(result, item)

        if len(b) > 0:
            denominator_str = b_str[0]
            for item in b_str[1:]:
                denominator_str = self.instruction_set['*'].format(denominator_str, item)
            result = self.instruction_set['/'].format(result, denominator_str)

        if inside_add:
            return sign, result
        else:
            if sign < 0:
                return self.instruction_set['*'].format(self._print(S.NegativeOne), result)
            else:
                return result

    def _print_Relational(self, expr):
        """Print a vector comparison with the instruction registered under the relational operator."""
        result = self._scalarFallback('_print_Relational', expr)
        if result:
            return result
        return self.instruction_set[expr.rel_op].format(self._print(expr.lhs), self._print(expr.rhs))

    def _print_Equality(self, expr):
        """Print a vector equality test with the ``'=='`` instruction."""
        result = self._scalarFallback('_print_Equality', expr)
        if result:
            return result
        return self.instruction_set['=='].format(self._print(expr.lhs), self._print(expr.rhs))

    def _print_Piecewise(self, expr):
        """Print a vector piecewise expression as nested ``'blendv'`` instructions.

        Raises:
            ValueError: if the condition of the last case is not True
        """
        result = self._scalarFallback('_print_Piecewise', expr)
        if result:
            return result

        # NOTE(review): the condition is unwrapped with .args[0], presumably because conditions are
        # wrapped (e.g. in a cast) for vectorized expressions - confirm against the vectorization code
        if expr.args[-1].cond.args[0] is not sp.sympify(True):
            # We need the last conditional to be a True, otherwise the resulting
            # function may not return a result.
            raise ValueError("All Piecewise expressions must contain an "
                             "(expr, True) statement to be used as a default "
                             "condition. Without one, the generated "
                             "expression may not evaluate to anything under "
                             "some condition.")

        # start with the default case and blend in the other cases from last to first,
        # so that earlier cases take precedence
        result = self._print(expr.args[-1][0])
        for true_expr, condition in reversed(expr.args[:-1]):
            # noinspection SpellCheckingInspection
            result = self.instruction_set['blendv'].format(result, self._print(true_expr), self._print(condition))
        return result