cbackend.py 18 KB
Newer Older
Martin Bauer's avatar
Martin Bauer committed
1
import sympy as sp
Martin Bauer's avatar
Martin Bauer committed
2
3
from collections import namedtuple
from sympy.core import S
4
from typing import Set
5
from sympy.printing.ccode import C89CodePrinter
Martin Bauer's avatar
Martin Bauer committed
6
7
try:
    from sympy.printing.ccode import C99CodePrinter as CCodePrinter
Martin Bauer's avatar
Martin Bauer committed
8
9
except ImportError:
    from sympy.printing.ccode import CCodePrinter  # for sympy versions < 1.1
Martin Bauer's avatar
Martin Bauer committed
10

Martin Bauer's avatar
Martin Bauer committed
11
from pystencils.integer_functions import bitwise_xor, bit_shift_right, bit_shift_left, bitwise_and, \
Martin Bauer's avatar
Martin Bauer committed
12
    bitwise_or, modulo_ceil
13
from pystencils.astnodes import Node, KernelFunction
Martin Bauer's avatar
Martin Bauer committed
14
15
from pystencils.data_types import create_type, PointerType, get_type_of_expression, VectorType, cast_func, \
    vector_memory_access
16

17
__all__ = ['generate_c', 'CustomCodeNode', 'PrintNode', 'get_headers', 'CustomSympyPrinter']
18

Martin Bauer's avatar
Martin Bauer committed
19

20
def generate_c(ast_node: Node, signature_only: bool = False, dialect='c') -> str:
    """Prints an abstract syntax tree node as C or CUDA code.

    This function does not need to distinguish between C, C++ or CUDA code, it just prints 'C-like' code as encoded
    in the abstract syntax tree (AST). The AST is built differently for C or CUDA by calling different create_kernel
    functions.

    Args:
        ast_node: root of the (sub-)tree to print. If the node carries a non-None ``instruction_set`` attribute
                  (kernel functions do), expressions are printed with that SIMD instruction set; otherwise scalar
                  C code is printed.
        signature_only: if True, kernel functions are printed as their declaration only (no body)
        dialect: 'c' or 'cuda'

    Returns:
        C-like code for the ast node and its descendants
    """
    # Not every node type has an ``instruction_set`` attribute (get_headers only relies on it for
    # KernelFunction), so fall back to scalar printing instead of failing with AttributeError.
    instruction_set = getattr(ast_node, 'instruction_set', None)
    printer = CBackend(signature_only=signature_only,
                       vector_instruction_set=instruction_set,
                       dialect=dialect)
    return printer(ast_node)
38
39


Martin Bauer's avatar
Martin Bauer committed
40
41
def get_headers(ast_node: Node) -> Set[str]:
    """Collect the header files needed to compile the C-like code printed for ``ast_node``.

    Headers are gathered recursively from three sources:
      * the ``'headers'`` entry of a kernel function's instruction set, if one is set,
      * a ``headers`` attribute on any node that defines one (e.g. ``"<iostream>"`` added by PrintNode),
      * every child in ``ast_node.args`` that is itself a ``Node``.

    Returns:
        set of header strings
    """
    collected = set()

    has_vector_headers = isinstance(ast_node, KernelFunction) and ast_node.instruction_set
    if has_vector_headers:
        collected.update(ast_node.instruction_set['headers'])

    if hasattr(ast_node, 'headers'):
        collected.update(ast_node.headers)

    child_nodes = (child for child in ast_node.args if isinstance(child, Node))
    for child in child_nodes:
        collected |= get_headers(child)

    return collected
54
55


56
57
58
# --------------------------------------- Backend Specific Nodes -------------------------------------------------------


59
class CustomCodeNode(Node):
    """AST node that emits a verbatim snippet of C-like code.

    Args:
        code: code string printed as-is (a newline is prepended)
        symbols_read: symbols the snippet reads; they are reported via ``undefined_symbols`` because they
                      have to be provided by the surrounding code
        symbols_defined: symbols the snippet itself defines
        parent: optional parent node
    """

    def __init__(self, code, symbols_read, symbols_defined, parent=None):
        super(CustomCodeNode, self).__init__(parent=parent)
        self._code = "\n" + code
        self._symbolsRead = set(symbols_read)
        self._symbolsDefined = set(symbols_defined)
        # header files (e.g. "<iostream>") required by the snippet; picked up by get_headers()
        self.headers = []

    def get_code(self, dialect, vector_instruction_set):
        """Return the code snippet.

        The backend passes ``dialect`` and ``vector_instruction_set``; this base implementation ignores them
        (presumably so that subclasses can specialize their output).
        """
        return self._code

    @property
    def args(self):
        """Custom code nodes have no child nodes."""
        return []

    @property
    def symbols_defined(self):
        """Set of symbols defined by the snippet."""
        return self._symbolsDefined

    @property
    def undefined_symbols(self):
        """Symbols the snippet reads but does not define itself, i.e. those that must come from outside."""
        # Previously this was ``symbols_defined - symbols_read``, which inverted the meaning and hid every
        # symbol the snippet depends on.
        return self._symbolsRead - self._symbolsDefined
81
82


83
class PrintNode(CustomCodeNode):
    """Custom code node that writes the value of one symbol to ``std::cout`` (C++ only).

    The emitted statement has the form ``std::cout << "name  =  " << name << std::endl;``. The node reads
    the printed symbol, defines nothing, and registers ``<iostream>`` as a required header.
    """

    # noinspection SpellCheckingInspection
    def __init__(self, symbol_to_print):
        symbol_name = symbol_to_print.name
        snippet = '\nstd::cout << "%s  =  " << %s << std::endl; \n' % (symbol_name, symbol_name)
        super().__init__(snippet, symbols_read=[symbol_to_print], symbols_defined=set())
        self.headers.append("<iostream>")
89
90
91
92


# ------------------------------------------- Printer ------------------------------------------------------------------

93

Martin Bauer's avatar
Martin Bauer committed
94
95
# noinspection PyPep8Naming
class CBackend:
    """Prints pystencils AST nodes as C-like source code.

    Dispatch is by node type: for a node of class ``X`` the method ``_print_X`` is used, walking the node's MRO
    so that subclasses fall back to the printer of a base class. Sympy expressions contained in nodes are printed
    with ``self.sympy_printer``.

    Args:
        sympy_printer: printer for sympy expressions; if None, a ``VectorizedCustomSympyPrinter`` is created when
                       ``vector_instruction_set`` is given, otherwise a ``CustomSympyPrinter``
        signature_only: if True, kernel functions are printed as declaration only
        vector_instruction_set: mapping from instruction names (e.g. 'storeA', 'storeU', 'stream') to format
                                strings, or None for scalar code
        dialect: 'c' or 'cuda'; forwarded to ``CustomCodeNode.get_code``
    """

    def __init__(self, sympy_printer=None,
                 signature_only=False, vector_instruction_set=None, dialect='c'):
        if sympy_printer is None:
            if vector_instruction_set is not None:
                self.sympy_printer = VectorizedCustomSympyPrinter(vector_instruction_set)
            else:
                self.sympy_printer = CustomSympyPrinter()
        else:
            self.sympy_printer = sympy_printer

        self._vector_instruction_set = vector_instruction_set
        self._indent = "   "
        self._dialect = dialect
        self._signatureOnly = signature_only

    def __call__(self, node):
        """Print ``node`` and all of its descendants and return the code as a string."""
        # VectorType.instruction_set is a class-level (i.e. global) setting that is swapped in for the duration
        # of printing. Restore it in ``finally`` so that an exception raised while printing does not leak this
        # backend's instruction set into later, unrelated printing.
        prev_is = VectorType.instruction_set
        VectorType.instruction_set = self._vector_instruction_set
        try:
            return str(self._print(node))
        finally:
            VectorType.instruction_set = prev_is

    def _print(self, node):
        """Dispatch to ``_print_<ClassName>`` for the most specific class in the node's MRO."""
        for cls in type(node).__mro__:
            method_name = "_print_" + cls.__name__
            if hasattr(self, method_name):
                return getattr(self, method_name)(node)
        raise NotImplementedError("CBackend does not support node of type " + str(type(node)))

    def _print_KernelFunction(self, node):
        function_arguments = ["%s %s" % (str(s.symbol.dtype), s.symbol.name) for s in node.get_parameters()]
        func_declaration = "FUNC_PREFIX void %s(%s)" % (node.function_name, ", ".join(function_arguments))
        if self._signatureOnly:
            return func_declaration

        body = self._print(node.body)
        return func_declaration + "\n" + body

    def _print_Block(self, node):
        block_contents = "\n".join([self._print(child) for child in node.args])
        # indent every line of the contents by one level
        return "{\n%s\n}" % (self._indent + self._indent.join(block_contents.splitlines(True)))

    def _print_PragmaBlock(self, node):
        return "%s\n%s" % (node.pragma_line, self._print_Block(node))

    def _print_LoopOverCoordinate(self, node):
        counter_symbol = node.loop_counter_name
        start = "int %s = %s" % (counter_symbol, self.sympy_printer.doprint(node.start))
        condition = "%s < %s" % (counter_symbol, self.sympy_printer.doprint(node.stop))
        update = "%s += %s" % (counter_symbol, self.sympy_printer.doprint(node.step),)
        loop_str = "for (%s; %s; %s)" % (start, condition, update)

        prefix = "\n".join(node.prefix_lines)
        if prefix:
            prefix += "\n"
        return "%s%s\n%s" % (prefix, loop_str, self._print(node.body))

    def _print_SympyAssignment(self, node):
        if node.is_declaration:
            data_type = "const " + str(node.lhs.dtype) + " " if node.is_const else str(node.lhs.dtype) + " "
            return "%s%s = %s;" % (data_type, self.sympy_printer.doprint(node.lhs),
                                   self.sympy_printer.doprint(node.rhs))
        else:
            lhs_type = get_type_of_expression(node.lhs)
            if type(lhs_type) is VectorType and isinstance(node.lhs, cast_func):
                # vector store: the lhs carries (address expression, type, aligned, nontemporal)
                arg, data_type, aligned, nontemporal = node.lhs.args
                instr = 'storeU'
                if aligned:
                    instr = 'stream' if nontemporal else 'storeA'

                # a scalar right-hand side is cast to a vector type before it is stored
                rhs_type = get_type_of_expression(node.rhs)
                if type(rhs_type) is not VectorType:
                    rhs = cast_func(node.rhs, VectorType(rhs_type))
                else:
                    rhs = node.rhs

                return self._vector_instruction_set[instr].format("&" + self.sympy_printer.doprint(arg),
                                                                  self.sympy_printer.doprint(rhs)) + ';'
            else:
                return "%s = %s;" % (self.sympy_printer.doprint(node.lhs), self.sympy_printer.doprint(node.rhs))

    def _print_TemporaryMemoryAllocation(self, node):
        align = 64
        np_dtype = node.symbol.dtype.base_type.numpy_dtype
        # One extra alignment unit is allocated, presumably to leave room for the pointer shift by
        # node.offset(align). The size is then rounded up to a multiple of the alignment, because
        # C11 aligned_alloc requires that.
        required_size = np_dtype.itemsize * node.size + align
        size = modulo_ceil(required_size, align)
        code = "{dtype} {name}=({dtype})aligned_alloc({align}, {size}) + {offset};"
        return code.format(dtype=node.symbol.dtype,
                           name=self.sympy_printer.doprint(node.symbol.name),
                           size=self.sympy_printer.doprint(size),
                           offset=int(node.offset(align)),
                           align=align)

    def _print_TemporaryMemoryFree(self, node):
        align = 64
        # subtract the offset that was added at allocation time to recover the pointer from aligned_alloc
        return "free(%s - %d);" % (self.sympy_printer.doprint(node.symbol.name), node.offset(align))

    def _print_CustomCodeNode(self, node):
        return node.get_code(self._dialect, self._vector_instruction_set)

    def _print_Conditional(self, node):
        condition_expr = self.sympy_printer.doprint(node.condition_expr)
        true_block = self._print_Block(node.true_block)
        result = "if (%s)\n%s " % (condition_expr, true_block)
        if node.false_block:
            false_block = self._print_Block(node.false_block)
            result += "else " + false_block
        return result

206
207
208
209

# ------------------------------------------ Helper function & classes -------------------------------------------------


Martin Bauer's avatar
Martin Bauer committed
210
# noinspection PyPep8Naming
211
class CustomSympyPrinter(CCodePrinter):
    """C printer for sympy expressions with pystencils-specific adjustments.

    Compared to the sympy C printer it derives from:
      * small integer powers are expanded into products instead of calling ``pow``,
      * rationals are printed as evaluated decimals,
      * ``Equality``, single-line ``Piecewise``, the bitwise/shift integer functions and ``cast_func`` are supported,
      * ``Min``/``Max`` use sympy's C89 printer implementation.
    """

    def __init__(self):
        super().__init__()
        self._float_type = create_type("float32")
        # Remove Min/Max from the known-function table. The C89 _print_Min/_print_Max assigned at the end of
        # this class presumably consult that table and would otherwise print a plain function call.
        for function_name in ('Min', 'Max'):
            self.known_functions.pop(function_name, None)

    def _print_Pow(self, expr):
        """Print powers; integer exponents in (-8, 8) \\ {0} are expanded into explicit products."""
        exponent = expr.exp
        is_int_literal = exponent.is_integer and exponent.is_number
        if is_int_literal and 0 < exponent < 8:
            factors = [expr.base] * exponent
            return "(" + self._print(sp.Mul(*factors, evaluate=False)) + ")"
        if is_int_literal and -8 < exponent < 0:
            factors = [expr.base] * (-exponent)
            return "1 / ({})".format(self._print(sp.Mul(*factors, evaluate=False)))
        return super()._print_Pow(expr)

    def _print_Rational(self, expr):
        """Print rationals as evaluated decimals, e.g. 0.25 instead of 1.0/4.0."""
        return str(expr.evalf().num)

    def _print_Equality(self, expr):
        """The default printer cannot print Equality; emit a parenthesized C comparison instead."""
        return "((%s) == (%s))" % (self._print(expr.lhs), self._print(expr.rhs))

    def _print_Piecewise(self, expr):
        """Print a Piecewise on one line by removing the newlines the base printer inserts."""
        return super()._print_Piecewise(expr).replace("\n", "")

    def _print_Function(self, expr):
        operator_of = {
            bitwise_xor: '^',
            bit_shift_right: '>>',
            bit_shift_left: '<<',
            bitwise_or: '|',
            bitwise_and: '&',
        }
        if hasattr(expr, 'to_c'):
            # functions that know how to print themselves
            return expr.to_c(self._print)
        if isinstance(expr, cast_func):
            arg, data_type = expr.args
            if isinstance(arg, sp.Number):
                return self._typed_number(arg, data_type)
            # non-literal argument: reinterpret its storage as the target type through a pointer cast
            return "*((%s)(& %s))" % (PointerType(data_type, restrict=False), self._print(arg))
        if expr.func in operator_of:
            lhs, rhs = expr.args[0], expr.args[1]
            return "(%s %s %s)" % (self._print(lhs), operator_of[expr.func], self._print(rhs))
        return super()._print_Function(expr)

    def _typed_number(self, number, dtype):
        """Print a numeric literal; single-precision floats get an ``f`` suffix (and ``.0`` if needed)."""
        printed = self._print(number)
        if dtype.is_float() and dtype == self._float_type:
            printed += "f" if '.' in printed else ".0f"
        return printed

    _print_Max = C89CodePrinter._print_Max
    _print_Min = C89CodePrinter._print_Min

280

Martin Bauer's avatar
Martin Bauer committed
281
# noinspection PyPep8Naming
282
283
284
class VectorizedCustomSympyPrinter(CustomSympyPrinter):
    """Sympy printer that emits SIMD intrinsics for vector-typed expressions.

    Operations on expressions whose type is a ``VectorType`` are printed with the format strings in
    ``instruction_set``. The keys used here are '+', '-', '*', '/', '&', '|', '==', relational operators,
    'sqrt', 'makeVec', 'loadA', 'loadU', 'blendv' and 'width'. All other expressions are printed by the
    scalar ``CustomSympyPrinter``.
    """
    SummandInfo = namedtuple("SummandInfo", ['sign', 'term'])

    def __init__(self, instruction_set):
        super(VectorizedCustomSympyPrinter, self).__init__()
        self.instruction_set = instruction_set

    def _scalarFallback(self, func_name, expr, *args, **kwargs):
        """Print ``expr`` with the scalar base-class method ``func_name`` unless it is vector-typed.

        Returns:
            the scalar code string, or None if ``expr`` is vector-typed and must be printed with intrinsics
        """
        expr_type = get_type_of_expression(expr)
        if type(expr_type) is not VectorType:
            return getattr(super(VectorizedCustomSympyPrinter, self), func_name)(expr, *args, **kwargs)
        else:
            assert self.instruction_set['width'] == expr_type.width
            return None

    def _fold_with_instruction(self, instruction, expr):
        """Combine the printed arguments of ``expr`` left-to-right with the binary intrinsic ``instruction``."""
        arg_strings = [self._print(a) for a in expr.args]
        assert len(arg_strings) > 0
        result = arg_strings[0]
        for item in arg_strings[1:]:
            result = self.instruction_set[instruction].format(result, item)
        return result

    def _print_Function(self, expr):
        if isinstance(expr, vector_memory_access):
            arg, data_type, aligned, _ = expr.args
            instruction = self.instruction_set['loadA'] if aligned else self.instruction_set['loadU']
            return instruction.format("& " + self._print(arg))
        elif isinstance(expr, cast_func):
            arg, data_type = expr.args
            if type(data_type) is VectorType:
                return self.instruction_set['makeVec'].format(self._print(arg))

        return super(VectorizedCustomSympyPrinter, self)._print_Function(expr)

    def _print_And(self, expr):
        result = self._scalarFallback('_print_And', expr)
        if result is not None:
            return result
        return self._fold_with_instruction('&', expr)

    def _print_Or(self, expr):
        result = self._scalarFallback('_print_Or', expr)
        if result is not None:
            return result
        return self._fold_with_instruction('|', expr)

    def _print_Add(self, expr, order=None):
        result = self._scalarFallback('_print_Add', expr)
        if result is not None:
            return result

        summands = []
        for term in expr.args:
            if term.func == sp.Mul:
                sign, t = self._print_Mul(term, inside_add=True)
            else:
                t = self._print(term)
                sign = 1
            summands.append(self.SummandInfo(sign, t))
        # Use positive terms first
        summands.sort(key=lambda e: e.sign, reverse=True)
        # If no positive term exists, start from a zero. It must be a vector zero: a scalar literal is not a
        # valid operand of the SIMD subtract intrinsic.
        if summands[0].sign == -1:
            summands.insert(0, self.SummandInfo(1, self.instruction_set['makeVec'].format(0.0)))

        assert len(summands) >= 2
        processed = summands[0].term
        for summand in summands[1:]:
            func = self.instruction_set['-'] if summand.sign == -1 else self.instruction_set['+']
            processed = func.format(processed, summand.term)
        return processed

    def _print_Pow(self, expr):
        result = self._scalarFallback('_print_Pow', expr)
        if result is not None:
            return result

        if expr.exp.is_integer and expr.exp.is_number and 0 < expr.exp < 8:
            return "(" + self._print(sp.Mul(*[expr.base] * expr.exp, evaluate=False)) + ")"
        elif expr.exp == -1:
            one = self.instruction_set['makeVec'].format(1.0)
            return self.instruction_set['/'].format(one, self._print(expr.base))
        elif expr.exp == 0.5:
            return self.instruction_set['sqrt'].format(self._print(expr.base))
        elif expr.exp.is_integer and expr.exp.is_number and - 8 < expr.exp < 0:
            one = self.instruction_set['makeVec'].format(1.0)
            return self.instruction_set['/'].format(one,
                                                    self._print(sp.Mul(*[expr.base] * (-expr.exp), evaluate=False)))
        else:
            raise ValueError("Generic exponential not supported: " + str(expr))

    def _print_Mul(self, expr, inside_add=False):
        """Print a vector product.

        With ``inside_add=True`` a ``(sign, code)`` tuple is returned, where ``code`` is the product of the
        absolute value. ``_print_Add`` then picks the add or subtract intrinsic from ``sign``.
        """
        # noinspection PyProtectedMember
        from sympy.core.mul import _keep_coeff

        result = self._scalarFallback('_print_Mul', expr)
        if result is not None:
            return result

        c, e = expr.as_coeff_Mul()
        if c < 0:
            expr = _keep_coeff(-c, e)
            sign = -1
        else:
            sign = 1

        a = []  # items in the numerator
        b = []  # items that are in the denominator (if any)

        # Gather args for numerator/denominator
        for item in expr.as_ordered_factors():
            if item.is_commutative and item.is_Pow and item.exp.is_Rational and item.exp.is_negative:
                if item.exp != -1:
                    b.append(sp.Pow(item.base, -item.exp, evaluate=False))
                else:
                    b.append(sp.Pow(item.base, -item.exp))
            else:
                a.append(item)

        a = a or [S.One]

        a_str = [self._print(x) for x in a]
        b_str = [self._print(x) for x in b]

        result = a_str[0]
        for item in a_str[1:]:
            result = self.instruction_set['*'].format(result, item)

        if len(b) > 0:
            denominator_str = b_str[0]
            for item in b_str[1:]:
                denominator_str = self.instruction_set['*'].format(denominator_str, item)
            result = self.instruction_set['/'].format(result, denominator_str)

        if inside_add:
            return sign, result
        if sign < 0:
            # The product is vector-typed, so -1 must be broadcast to a vector too. The scalar literal "-1" is
            # not a valid operand of the SIMD multiply intrinsic.
            minus_one = self.instruction_set['makeVec'].format(-1.0)
            return self.instruction_set['*'].format(minus_one, result)
        return result

    def _print_Relational(self, expr):
        result = self._scalarFallback('_print_Relational', expr)
        if result is not None:
            return result
        return self.instruction_set[expr.rel_op].format(self._print(expr.lhs), self._print(expr.rhs))

    def _print_Equality(self, expr):
        result = self._scalarFallback('_print_Equality', expr)
        if result is not None:
            return result
        return self.instruction_set['=='].format(self._print(expr.lhs), self._print(expr.rhs))

    def _print_Piecewise(self, expr):
        result = self._scalarFallback('_print_Piecewise', expr)
        if result is not None:
            return result

        # The default condition may be wrapped (e.g. in a cast), in which case its first argument is checked.
        # A bare ``True`` has no args and previously raised IndexError here, so it is checked directly.
        default_cond = expr.args[-1].cond
        if default_cond.args:
            default_cond = default_cond.args[0]
        if default_cond is not sp.sympify(True):
            # We need the last conditional to be a True, otherwise the resulting
            # function may not return a result.
            raise ValueError("All Piecewise expressions must contain an "
                             "(expr, True) statement to be used as a default "
                             "condition. Without one, the generated "
                             "expression may not evaluate to anything under "
                             "some condition.")

        result = self._print(expr.args[-1][0])
        for true_expr, condition in reversed(expr.args[:-1]):
            # noinspection SpellCheckingInspection
            result = self.instruction_set['blendv'].format(result, self._print(true_expr), self._print(condition))
        return result