cbackend.py 16.5 KB
Newer Older
Martin Bauer's avatar
Martin Bauer committed
1
import sympy as sp
Martin Bauer's avatar
Martin Bauer committed
2
3
from collections import namedtuple
from sympy.core import S
Martin Bauer's avatar
Martin Bauer committed
4
from typing import Optional, Set
Martin Bauer's avatar
Martin Bauer committed
5

Martin Bauer's avatar
Martin Bauer committed
6
7
try:
    from sympy.printing.ccode import C99CodePrinter as CCodePrinter
Martin Bauer's avatar
Martin Bauer committed
8
9
except ImportError:
    from sympy.printing.ccode import CCodePrinter  # for sympy versions < 1.1
Martin Bauer's avatar
Martin Bauer committed
10

Martin Bauer's avatar
Martin Bauer committed
11
from pystencils.bitoperations import bitwise_xor, bit_shift_right, bit_shift_left, bitwise_and, bitwise_or
12
from pystencils.astnodes import Node, ResolvedFieldAccess, SympyAssignment
Martin Bauer's avatar
Martin Bauer committed
13
from pystencils.data_types import create_type, PointerType, get_type_of_expression, VectorType, cast_func
Martin Bauer's avatar
Martin Bauer committed
14
from pystencils.backends.simd_instruction_sets import selected_instruction_set
15

16
__all__ = ['generate_c', 'CustomCppCode', 'PrintNode', 'get_headers']
17

Martin Bauer's avatar
Martin Bauer committed
18

Martin Bauer's avatar
Martin Bauer committed
19
def generate_c(ast_node: Node, signature_only: bool = False, use_float_constants: Optional[bool] = None) -> str:
    """Print an abstract syntax tree node as C or CUDA code.

    The printer does not distinguish between C, C++ or CUDA: it emits the 'C-like' code that is encoded in the
    abstract syntax tree (AST). The AST itself is built differently for each target by the respective
    create_kernel functions.

    Args:
        ast_node: root of the (sub)tree to print
        signature_only: if True, kernel functions are printed as declaration only, without their body
        use_float_constants: if True, floating point literals are printed with an 'f' suffix. If None, this is
                             decided automatically: the suffix is used unless a field accessed inside
                             ``ast_node`` has data type 'double'.

    Returns:
        C-like code for the ast node and its descendants
    """
    if use_float_constants is None:
        # single precision literals unless some accessed field stores doubles
        accessed_field_dtypes = {access.field.dtype for access in ast_node.atoms(ResolvedFieldAccess)}
        use_float_constants = create_type('double') not in accessed_field_dtypes

    # The 'double' SIMD instruction set is always handed to the backend; the vectorized expression printer
    # falls back to scalar printing for every expression that is not of vector type.
    backend = CBackend(constants_as_floats=use_float_constants,
                       signature_only=signature_only,
                       vector_instruction_set=selected_instruction_set['double'])
    return backend(ast_node)
43
44


Martin Bauer's avatar
Martin Bauer committed
45
46
def get_headers(ast_node: Node) -> Set[str]:
    """Collect the header files needed to compile the C-like code printed for ``ast_node``.

    The tree is walked recursively. A node that carries a ``headers`` attribute contributes those headers.
    Otherwise, an assignment whose right-hand side is of vector type contributes the headers of the 'double'
    SIMD instruction set.
    """
    collected = set()

    if hasattr(ast_node, 'headers'):
        collected.update(ast_node.headers)
    elif isinstance(ast_node, SympyAssignment) and type(get_type_of_expression(ast_node.rhs)) is VectorType:
        collected.update(selected_instruction_set['double']['headers'])

    # only AST nodes can contribute headers; plain sympy arguments are skipped
    child_nodes = [child for child in ast_node.args if isinstance(child, Node)]
    for child in child_nodes:
        collected.update(get_headers(child))

    return collected
60
61


62
63
64
65
# --------------------------------------- Backend Specific Nodes -------------------------------------------------------


class CustomCppCode(Node):
    """AST node that holds a verbatim snippet of C/C++ code.

    The printer emits ``code`` unchanged. Since the snippet is opaque to pystencils, the caller declares which
    symbols it reads and which it defines.

    Args:
        code: the C/C++ source text of the snippet
        symbols_read: symbols the snippet reads
        symbols_defined: symbols the snippet defines
        parent: optional parent AST node
    """

    def __init__(self, code, symbols_read, symbols_defined, parent=None):
        super(CustomCppCode, self).__init__(parent=parent)
        # the snippet is stored with a leading newline
        self._code = "\n" + code
        self._symbolsRead = set(symbols_read)
        self._symbolsDefined = set(symbols_defined)
        # extra header files required by the snippet (collected by get_headers)
        self.headers = []

    @property
    def code(self):
        return self._code

    @property
    def args(self):
        # the snippet is a leaf: it has no child nodes
        return []

    @property
    def symbols_defined(self):
        return self._symbolsDefined

    @property
    def undefined_symbols(self):
        """Symbols the snippet reads but does not define itself."""
        return self._symbolsRead - self._symbolsDefined
88
89
90


class PrintNode(CustomCppCode):
    """Custom code node that writes the name and current value of a symbol to ``std::cout``."""

    # noinspection SpellCheckingInspection
    def __init__(self, symbol_to_print):
        symbol_name = symbol_to_print.name
        cout_statement = '\nstd::cout << "%s  =  " << %s << std::endl; \n' % (symbol_name, symbol_name)
        super(PrintNode, self).__init__(cout_statement, symbols_read=[symbol_to_print], symbols_defined=set())
        # std::cout is declared in <iostream>
        self.headers.append("<iostream>")
96
97
98
99


# ------------------------------------------- Printer ------------------------------------------------------------------

100

Martin Bauer's avatar
Martin Bauer committed
101
102
# noinspection PyPep8Naming
class CBackend:
    """Printer that turns pystencils AST nodes into C-like source code.

    For a node of type ``T`` the method ``_print_T`` is called. If no such method exists, the base classes of
    ``T`` are tried in method resolution order. Symbolic expressions contained in the nodes are printed by
    ``self.sympy_printer``.
    """

    def __init__(self, constants_as_floats=False, sympy_printer=None,
                 signature_only=False, vector_instruction_set=None):
        """
        Args:
            constants_as_floats: forwarded to the default expression printer; if True, float literals are
                                 printed with an 'f' suffix
            sympy_printer: printer for symbolic expressions. If None, a VectorizedCustomSympyPrinter is created
                           when ``vector_instruction_set`` is given, and a CustomSympyPrinter otherwise.
            signature_only: if True, kernel functions are printed as declaration only, without their body
            vector_instruction_set: SIMD instruction set description (mapping of operation names like 'storeU'
                                    to intrinsic format strings), or None for scalar code
        """
        if sympy_printer is not None:
            self.sympy_printer = sympy_printer
        elif vector_instruction_set is not None:
            self.sympy_printer = VectorizedCustomSympyPrinter(vector_instruction_set, constants_as_floats)
        else:
            self.sympy_printer = CustomSympyPrinter(constants_as_floats)

        self._vectorInstructionSet = vector_instruction_set
        self._indent = "   "
        self._signatureOnly = signature_only

    def __call__(self, node):
        """Print ``node`` and return the code as string.

        ``VectorType.instruction_set`` is a class-level setting. It is switched to this backend's instruction set
        while printing and is always restored afterwards, even if printing raises.
        """
        prev_is = VectorType.instruction_set
        VectorType.instruction_set = self._vectorInstructionSet
        try:
            return str(self._print(node))
        finally:
            VectorType.instruction_set = prev_is

    def _print(self, node):
        """Dispatch to ``_print_<ClassName>`` for the most specific class of ``node`` that has a print method."""
        for cls in type(node).__mro__:
            method_name = "_print_" + cls.__name__
            if hasattr(self, method_name):
                return getattr(self, method_name)(node)
        raise NotImplementedError("CBackend does not support node of type " + str(type(node)))

    def _print_KernelFunction(self, node):
        function_arguments = ["%s %s" % (str(s.dtype), s.name) for s in node.parameters]
        func_declaration = "FUNC_PREFIX void %s(%s)" % (node.function_name, ", ".join(function_arguments))
        if self._signatureOnly:
            return func_declaration

        body = self._print(node.body)
        return func_declaration + "\n" + body

    def _print_Block(self, node):
        block_contents = "\n".join([self._print(child) for child in node.args])
        # indent every line of the content; splitlines(True) keeps the line endings for the join
        return "{\n%s\n}" % (self._indent + self._indent.join(block_contents.splitlines(True)))

    def _print_PragmaBlock(self, node):
        return "%s\n%s" % (node.pragma_line, self._print_Block(node))

    def _print_LoopOverCoordinate(self, node):
        counter_symbol = node.loop_counter_name
        start = "int %s = %s" % (counter_symbol, self.sympy_printer.doprint(node.start))
        condition = "%s < %s" % (counter_symbol, self.sympy_printer.doprint(node.stop))
        update = "%s += %s" % (counter_symbol, self.sympy_printer.doprint(node.step),)
        loop_str = "for (%s; %s; %s)" % (start, condition, update)

        prefix = "\n".join(node.prefix_lines)
        if prefix:
            prefix += "\n"
        return "%s%s\n%s" % (prefix, loop_str, self._print(node.body))

    def _print_SympyAssignment(self, node):
        if node.is_declaration:
            data_type = str(node.lhs.dtype)
            if node.is_const:
                data_type = "const " + data_type
            lhs_code = self.sympy_printer.doprint(node.lhs)
            rhs_code = self.sympy_printer.doprint(node.rhs)
            # exactly one space between type and variable name
            return "%s %s = %s;" % (data_type, lhs_code, rhs_code)

        lhs_type = get_type_of_expression(node.lhs)
        if type(lhs_type) is VectorType and node.lhs.func == cast_func:
            # vector store: the left-hand side is a cast to a vector type; store to the address of the cast's
            # argument via the unaligned store intrinsic
            return self._vectorInstructionSet['storeU'].format("&" + self.sympy_printer.doprint(node.lhs.args[0]),
                                                               self.sympy_printer.doprint(node.rhs)) + ';'
        return "%s = %s;" % (self.sympy_printer.doprint(node.lhs), self.sympy_printer.doprint(node.rhs))

    def _print_TemporaryMemoryAllocation(self, node):
        return "%s %s = new %s[%s];" % (node.symbol.dtype, self.sympy_printer.doprint(node.symbol.name),
                                        node.symbol.dtype.base_type, self.sympy_printer.doprint(node.size))

    def _print_TemporaryMemoryFree(self, node):
        return "delete [] %s;" % (self.sympy_printer.doprint(node.symbol.name),)

    @staticmethod
    def _print_CustomCppCode(node):
        return node.code

    def _print_Conditional(self, node):
        condition_expr = self.sympy_printer.doprint(node.condition_expr)
        true_block = self._print_Block(node.true_block)
        result = "if (%s)\n%s " % (condition_expr, true_block)
        if node.false_block:
            false_block = self._print_Block(node.false_block)
            result += "else " + false_block
        return result

193
194
195
196

# ------------------------------------------ Helper function & classes -------------------------------------------------


Martin Bauer's avatar
Martin Bauer committed
197
# noinspection PyPep8Naming
198
class CustomSympyPrinter(CCodePrinter):
    """C code printer for sympy expressions, adapted for pystencils kernels.

    Compared to sympy's C printer it writes small integer powers as products, prints rationals as evaluated
    decimal numbers, optionally adds an 'f' suffix to floating point literals, and supports ``cast_func`` as
    well as the bit operations from ``pystencils.bitoperations``.
    """

    def __init__(self, constants_as_floats=False):
        self._constantsAsFloats = constants_as_floats
        super(CustomSympyPrinter, self).__init__()

    def _with_float_suffix(self, literal):
        """Append the single precision suffix 'f' to ``literal`` if constants are printed as floats."""
        if self._constantsAsFloats:
            return literal + "f"
        return literal

    def _print_Pow(self, expr):
        """Don't use std::pow function, for small integer exponents (1..7), write as multiplication"""
        exponent = expr.exp
        if not (exponent.is_integer and exponent.is_number and 0 < exponent < 8):
            return super(CustomSympyPrinter, self)._print_Pow(expr)
        product = sp.Mul(*([expr.base] * exponent), evaluate=False)
        return "(%s)" % self._print(product)

    def _print_Rational(self, expr):
        """Evaluate all rationals i.e. print 0.25 instead of 1.0/4.0"""
        return self._with_float_suffix(str(expr.evalf().num))

    def _print_Equality(self, expr):
        """Equality operator is not printable in default printer"""
        return "((%s) == (%s))" % (self._print(expr.lhs), self._print(expr.rhs))

    def _print_Piecewise(self, expr):
        """Print piecewise in one line (remove newlines)"""
        multi_line = super(CustomSympyPrinter, self)._print_Piecewise(expr)
        return "".join(multi_line.split("\n"))

    def _print_Float(self, expr):
        return self._with_float_suffix(str(expr))

    def _print_Function(self, expr):
        """Print casts and pystencils bit operations; everything else is left to the C printer."""
        if expr.func == cast_func:
            arg, data_type = expr.args
            # reinterpret the storage of ``arg`` as ``data_type`` through a pointer cast
            return "*((%s)(& %s))" % (PointerType(data_type), self._print(arg))

        infix_operators = {
            bitwise_xor: '^',
            bit_shift_right: '>>',
            bit_shift_left: '<<',
            bitwise_or: '|',
            bitwise_and: '&',
        }
        c_operator = infix_operators.get(expr.func)
        if c_operator is None:
            return super(CustomSympyPrinter, self)._print_Function(expr)
        return "(%s %s %s)" % (self._print(expr.args[0]), c_operator, self._print(expr.args[1]))
Martin Bauer's avatar
Martin Bauer committed
248

249

Martin Bauer's avatar
Martin Bauer committed
250
# noinspection PyPep8Naming
251
252
253
class VectorizedCustomSympyPrinter(CustomSympyPrinter):
    """Sympy printer that emits SIMD intrinsics for expressions of vector type.

    Expressions whose type (according to ``get_type_of_expression``) is not a ``VectorType`` are printed by the
    scalar ``CustomSympyPrinter``. Vector expressions are printed with the format strings in ``instruction_set``
    (keys such as '+', '-', '*', '/', '&', '|', 'sqrt', 'makeVec', 'loadU', 'blendv', and relational operators).
    """
    SummandInfo = namedtuple("SummandInfo", ['sign', 'term'])

    def __init__(self, instruction_set, constants_as_floats=False):
        super(VectorizedCustomSympyPrinter, self).__init__(constants_as_floats)
        self.instruction_set = instruction_set

    def _scalarFallback(self, func_name, expr, *args, **kwargs):
        """Print ``expr`` with the scalar parent printer if it is not of vector type.

        Returns:
            the scalar code, or None if ``expr`` is a vector expression that the caller has to print itself
        """
        expr_type = get_type_of_expression(expr)
        if type(expr_type) is not VectorType:
            return getattr(super(VectorizedCustomSympyPrinter, self), func_name)(expr, *args, **kwargs)
        else:
            assert self.instruction_set['width'] == expr_type.width
            return None

    def _fold_vector_op(self, op_key, exprs):
        """Print ``exprs`` and combine them left to right with the binary intrinsic ``instruction_set[op_key]``."""
        arg_strings = [self._print(a) for a in exprs]
        assert len(arg_strings) > 0
        result = arg_strings[0]
        for item in arg_strings[1:]:
            result = self.instruction_set[op_key].format(result, item)
        return result

    def _print_Function(self, expr):
        if expr.func == cast_func:
            arg, data_type = expr.args
            if type(data_type) is VectorType:
                if type(arg) is ResolvedFieldAccess:
                    # cast of a field access: unaligned load from the field's address
                    return self.instruction_set['loadU'].format("& " + self._print(arg))
                else:
                    # cast of any other expression: broadcast the scalar into a vector
                    return self.instruction_set['makeVec'].format(self._print(arg))

        return super(VectorizedCustomSympyPrinter, self)._print_Function(expr)

    def _print_And(self, expr):
        result = self._scalarFallback('_print_And', expr)
        if result is not None:
            return result
        return self._fold_vector_op('&', expr.args)

    def _print_Or(self, expr):
        result = self._scalarFallback('_print_Or', expr)
        if result is not None:
            return result
        return self._fold_vector_op('|', expr.args)

    def _print_Add(self, expr, order=None):
        """Print a vector sum; positive summands come first, negative ones are applied with the '-' intrinsic."""
        result = self._scalarFallback('_print_Add', expr)
        if result is not None:
            return result

        summands = []
        for term in expr.args:
            if term.func == sp.Mul:
                sign, t = self._print_Mul(term, inside_add=True)
            else:
                t = self._print(term)
                sign = 1
            summands.append(self.SummandInfo(sign, t))
        # Use positive terms first
        summands.sort(key=lambda e: e.sign, reverse=True)
        # if no positive term exists, prepend a zero
        if summands[0].sign == -1:
            summands.insert(0, self.SummandInfo(1, "0"))

        assert len(summands) >= 2
        processed = summands[0].term
        for summand in summands[1:]:
            func = self.instruction_set['-'] if summand.sign == -1 else self.instruction_set['+']
            processed = func.format(processed, summand.term)
        return processed

    def _print_Pow(self, expr):
        """Print a vector power.

        Small positive integer exponents (1..7) become products, exponent -1 becomes a division of a vector of
        ones, and exponent 0.5 uses the 'sqrt' intrinsic. Other exponents raise ValueError.
        """
        result = self._scalarFallback('_print_Pow', expr)
        if result is not None:
            return result

        if expr.exp.is_integer and expr.exp.is_number and 0 < expr.exp < 8:
            return "(" + self._print(sp.Mul(*[expr.base] * expr.exp, evaluate=False)) + ")"
        elif expr.exp == -1:
            one = self.instruction_set['makeVec'].format(1.0)
            return self.instruction_set['/'].format(one, self._print(expr.base))
        elif expr.exp == 0.5:
            return self.instruction_set['sqrt'].format(self._print(expr.base))
        else:
            raise ValueError("Generic exponential not supported")

    def _print_Mul(self, expr, inside_add=False):
        """Print a vector product.

        Factors with negative rational exponents are collected in a denominator that is applied with a single
        division. A negative numeric coefficient is split off as sign. With ``inside_add=True``, the tuple
        ``(sign, code)`` is returned so that ``_print_Add`` can emit a subtraction. Otherwise a negative product
        is multiplied by -1.
        """
        # noinspection PyProtectedMember
        from sympy.core.mul import _keep_coeff

        result = self._scalarFallback('_print_Mul', expr)
        if result is not None:
            return result

        c, e = expr.as_coeff_Mul()
        if c < 0:
            expr = _keep_coeff(-c, e)
            sign = -1
        else:
            sign = 1

        a = []  # items in the numerator
        b = []  # items that are in the denominator (if any)

        # Gather args for numerator/denominator
        for item in expr.as_ordered_factors():
            if item.is_commutative and item.is_Pow and item.exp.is_Rational and item.exp.is_negative:
                if item.exp != -1:
                    b.append(sp.Pow(item.base, -item.exp, evaluate=False))
                else:
                    b.append(sp.Pow(item.base, -item.exp))
            else:
                a.append(item)

        a = a or [S.One]

        a_str = [self._print(x) for x in a]
        b_str = [self._print(x) for x in b]

        result = a_str[0]
        for item in a_str[1:]:
            result = self.instruction_set['*'].format(result, item)

        if len(b) > 0:
            denominator_str = b_str[0]
            for item in b_str[1:]:
                denominator_str = self.instruction_set['*'].format(denominator_str, item)
            result = self.instruction_set['/'].format(result, denominator_str)

        if inside_add:
            return sign, result
        else:
            if sign < 0:
                return self.instruction_set['*'].format(self._print(S.NegativeOne), result)
            else:
                return result

    def _print_Relational(self, expr):
        result = self._scalarFallback('_print_Relational', expr)
        if result is not None:
            return result
        return self.instruction_set[expr.rel_op].format(self._print(expr.lhs), self._print(expr.rhs))

    def _print_Equality(self, expr):
        result = self._scalarFallback('_print_Equality', expr)
        if result is not None:
            return result
        return self.instruction_set['=='].format(self._print(expr.lhs), self._print(expr.rhs))

    def _print_Piecewise(self, expr):
        """Print a vector piecewise expression as a chain of 'blendv' intrinsics, starting from the default."""
        result = self._scalarFallback('_print_Piecewise', expr)
        if result is not None:
            return result

        if expr.args[-1].cond.args[0] is not sp.sympify(True):
            # We need the last conditional to be a True, otherwise the resulting
            # function may not return a result.
            raise ValueError("All Piecewise expressions must contain an "
                             "(expr, True) statement to be used as a default "
                             "condition. Without one, the generated "
                             "expression may not evaluate to anything under "
                             "some condition.")

        result = self._print(expr.args[-1][0])
        # walk the branches backwards so that earlier conditions take precedence over later ones
        for true_expr, condition in reversed(expr.args[:-1]):
            # noinspection SpellCheckingInspection
            result = self.instruction_set['blendv'].format(result, self._print(true_expr), self._print(condition))
        return result