cbackend.py 22.4 KB
Newer Older
Martin Bauer's avatar
Martin Bauer committed
1
import sympy as sp
Martin Bauer's avatar
Martin Bauer committed
2
3
from collections import namedtuple
from sympy.core import S
4
from typing import Set
5
from sympy.printing.ccode import C89CodePrinter
6

7
from pystencils.cpu.vectorization import vec_any, vec_all
8
9
10
11
12
from pystencils.data_types import (PointerType, VectorType, address_of,
                                   cast_func, create_type, reinterpret_cast_func,
                                   get_type_of_expression,
                                   vector_memory_access)
from pystencils.fast_approximation import fast_division, fast_inv_sqrt, fast_sqrt
13

Martin Bauer's avatar
Martin Bauer committed
14
15
try:
    from sympy.printing.ccode import C99CodePrinter as CCodePrinter
Martin Bauer's avatar
Martin Bauer committed
16
17
except ImportError:
    from sympy.printing.ccode import CCodePrinter  # for sympy versions < 1.1
Martin Bauer's avatar
Martin Bauer committed
18

Martin Bauer's avatar
Martin Bauer committed
19
from pystencils.integer_functions import bitwise_xor, bit_shift_right, bit_shift_left, bitwise_and, \
20
    bitwise_or, modulo_ceil, int_div, int_power_of_2
21
from pystencils.astnodes import Node, KernelFunction
22

23
# Names exported by a star-import of this module.
__all__ = ['generate_c', 'CustomCodeNode', 'PrintNode', 'get_headers', 'CustomSympyPrinter']
24

Martin Bauer's avatar
Martin Bauer committed
25

26
27
# When True, the vectorized Piecewise printer does not emit a C ternary expression for conditions that
# are a cast of a bool-typed expression; it prints a warning and skips that branch instead.
KERNCRAFT_NO_TERNARY_MODE = False

Martin Bauer's avatar
Fixes    
Martin Bauer committed
28

29
def generate_c(ast_node: Node, signature_only: bool = False, dialect='c') -> str:
    """Prints an abstract syntax tree node as C or CUDA code.

    This function does not need to distinguish between C, C++ or CUDA code, it just prints 'C-like' code as encoded
    in the abstract syntax tree (AST). The AST is built differently for C or CUDA by calling different create_kernel
    functions.

    Args:
        ast_node: root node of the AST to print. If it has an ``instruction_set`` attribute that is not None,
                  the backend is created with that vector instruction set.
        signature_only: forwarded to the backend; when True, a kernel function is printed as its
                        declaration only, without the body
        dialect: 'c' or 'cuda'

    Returns:
        C-like code for the ast node and its descendants
    """
    # get_headers only reads ``instruction_set`` from KernelFunction nodes, so other node types may not
    # provide the attribute at all. Treat a missing attribute like "no vectorization" instead of
    # failing with an AttributeError.
    instruction_set = getattr(ast_node, 'instruction_set', None)
    printer = CBackend(signature_only=signature_only,
                       vector_instruction_set=instruction_set,
                       dialect=dialect)
    return printer(ast_node)
47
48


Martin Bauer's avatar
Martin Bauer committed
49
50
def get_headers(ast_node: Node) -> Set[str]:
    """Collect the header files required to compile the C-like code printed for ``ast_node``.

    Headers come from three places: the ``'headers'`` entry of the instruction set of a kernel function,
    the ``headers`` attribute of the node itself (if present), and, recursively, all ``Node`` arguments.
    """
    collected = set()

    uses_instruction_set = isinstance(ast_node, KernelFunction) and ast_node.instruction_set
    if uses_instruction_set:
        collected.update(ast_node.instruction_set['headers'])

    if hasattr(ast_node, 'headers'):
        collected.update(ast_node.headers)

    # Descend only into arguments that are AST nodes themselves
    for child in ast_node.args:
        if not isinstance(child, Node):
            continue
        collected.update(get_headers(child))

    return collected
63
64


65
66
67
# --------------------------------------- Backend Specific Nodes -------------------------------------------------------


68
class CustomCodeNode(Node):
    """AST node that carries a verbatim piece of code.

    Args:
        code: the code text; it is stored with a newline prepended
        symbols_read: iterable of symbols that the code reads
        symbols_defined: iterable of symbols that the code defines
        parent: forwarded to the ``Node`` constructor
    """

    def __init__(self, code, symbols_read, symbols_defined, parent=None):
        super(CustomCodeNode, self).__init__(parent=parent)
        self.headers = []
        self._symbols_defined = set(symbols_defined)
        self._symbols_read = set(symbols_read)
        self._code = "\n" + code

    def get_code(self, dialect, vector_instruction_set):
        """Return the stored code; ``dialect`` and ``vector_instruction_set`` are not used here."""
        return self._code

    @property
    def args(self):
        """This node has no child nodes."""
        return list()

    @property
    def symbols_defined(self):
        """Set of symbols defined by the custom code."""
        return self._symbols_defined

    @property
    def undefined_symbols(self):
        """Symbols that are read by the custom code but not defined by it."""
        return self._symbols_read.difference(self._symbols_defined)
90
91


92
class PrintNode(CustomCodeNode):
    """Custom code node that writes the name and value of a symbol to ``std::cout``."""

    # noinspection SpellCheckingInspection
    def __init__(self, symbol_to_print):
        name = symbol_to_print.name
        cout_statement = '\nstd::cout << "%s  =  " << %s << std::endl; \n' % (name, name)
        super(PrintNode, self).__init__(cout_statement, symbols_read=[symbol_to_print], symbols_defined=set())
        # std::cout requires the iostream header
        self.headers.append("<iostream>")
98
99
100
101


# ------------------------------------------- Printer ------------------------------------------------------------------

102

Martin Bauer's avatar
Martin Bauer committed
103
104
# noinspection PyPep8Naming
class CBackend:
105

Martin Bauer's avatar
Martin Bauer committed
106
    def __init__(self, sympy_printer=None, signature_only=False, vector_instruction_set=None, dialect='c'):
Martin Bauer's avatar
Martin Bauer committed
107
108
        if sympy_printer is None:
            if vector_instruction_set is not None:
109
                self.sympy_printer = VectorizedCustomSympyPrinter(vector_instruction_set, dialect)
110
            else:
111
                self.sympy_printer = CustomSympyPrinter(dialect)
112
        else:
Martin Bauer's avatar
Martin Bauer committed
113
            self.sympy_printer = sympy_printer
114

115
        self._vector_instruction_set = vector_instruction_set
116
        self._indent = "   "
117
        self._dialect = dialect
Martin Bauer's avatar
Martin Bauer committed
118
        self._signatureOnly = signature_only
119
120

    def __call__(self, node):
        """Print ``node`` and return the result as a string.

        While printing, ``VectorType.instruction_set`` is set to the instruction set of this backend.
        The previous value is restored afterwards, also when printing raises, so that a failed print
        does not leave the class-level setting modified.
        """
        prev_is = VectorType.instruction_set
        VectorType.instruction_set = self._vector_instruction_set
        try:
            return str(self._print(node))
        finally:
            VectorType.instruction_set = prev_is
126
127
128

    def _print(self, node):
        for cls in type(node).__mro__:
Martin Bauer's avatar
Martin Bauer committed
129
130
131
132
            method_name = "_print_" + cls.__name__
            if hasattr(self, method_name):
                return getattr(self, method_name)(node)
        raise NotImplementedError("CBackend does not support node of type " + str(type(node)))
133
134

    def _print_KernelFunction(self, node):
135
        function_arguments = ["%s %s" % (str(s.symbol.dtype), s.symbol.name) for s in node.get_parameters()]
136
137
138
139
140
141
142
        launch_bounds = ""
        if self._dialect == 'cuda':
            max_threads = node.indexing.max_threads_per_block()
            if max_threads:
                launch_bounds = "__launch_bounds__({}) ".format(max_threads)
        func_declaration = "FUNC_PREFIX %svoid %s(%s)" % (launch_bounds, node.function_name,
                                                          ", ".join(function_arguments))
143
        if self._signatureOnly:
Martin Bauer's avatar
Martin Bauer committed
144
            return func_declaration
145

146
        body = self._print(node.body)
Martin Bauer's avatar
Martin Bauer committed
147
        return func_declaration + "\n" + body
148
149

    def _print_Block(self, node):
Martin Bauer's avatar
Martin Bauer committed
150
151
        block_contents = "\n".join([self._print(child) for child in node.args])
        return "{\n%s\n}" % (self._indent + self._indent.join(block_contents.splitlines(True)))
152
153

    def _print_PragmaBlock(self, node):
Martin Bauer's avatar
Martin Bauer committed
154
        return "%s\n%s" % (node.pragma_line, self._print_Block(node))
155
156

    def _print_LoopOverCoordinate(self, node):
Martin Bauer's avatar
Martin Bauer committed
157
        counter_symbol = node.loop_counter_name
Martin Bauer's avatar
Martin Bauer committed
158
159
160
161
        start = "int %s = %s" % (counter_symbol, self.sympy_printer.doprint(node.start))
        condition = "%s < %s" % (counter_symbol, self.sympy_printer.doprint(node.stop))
        update = "%s += %s" % (counter_symbol, self.sympy_printer.doprint(node.step),)
        loop_str = "for (%s; %s; %s)" % (start, condition, update)
162

Martin Bauer's avatar
Martin Bauer committed
163
        prefix = "\n".join(node.prefix_lines)
164
165
        if prefix:
            prefix += "\n"
Martin Bauer's avatar
Martin Bauer committed
166
        return "%s%s\n%s" % (prefix, loop_str, self._print(node.body))
167
168

    def _print_SympyAssignment(self, node):
Martin Bauer's avatar
Martin Bauer committed
169
170
        if node.is_declaration:
            data_type = "const " + str(node.lhs.dtype) + " " if node.is_const else str(node.lhs.dtype) + " "
171
172
            return "%s%s = %s;" % (data_type, self.sympy_printer.doprint(node.lhs),
                                   self.sympy_printer.doprint(node.rhs))
173
        else:
Martin Bauer's avatar
Martin Bauer committed
174
            lhs_type = get_type_of_expression(node.lhs)
Martin Bauer's avatar
Martin Bauer committed
175
176
177
178
179
180
            if type(lhs_type) is VectorType and isinstance(node.lhs, cast_func):
                arg, data_type, aligned, nontemporal = node.lhs.args
                instr = 'storeU'
                if aligned:
                    instr = 'stream' if nontemporal else 'storeA'

181
182
183
184
185
186
                rhs_type = get_type_of_expression(node.rhs)
                if type(rhs_type) is not VectorType:
                    rhs = cast_func(node.rhs, VectorType(rhs_type))
                else:
                    rhs = node.rhs

187
188
                return self._vector_instruction_set[instr].format("&" + self.sympy_printer.doprint(node.lhs.args[0]),
                                                                  self.sympy_printer.doprint(rhs)) + ';'
189
            else:
Martin Bauer's avatar
Martin Bauer committed
190
                return "%s = %s;" % (self.sympy_printer.doprint(node.lhs), self.sympy_printer.doprint(node.rhs))
191
192

    def _print_TemporaryMemoryAllocation(self, node):
        """Print an ``aligned_alloc`` call for the symbol of the node, using an alignment of 64.

        The requested size is the item size times ``node.size`` plus the alignment, passed through
        ``modulo_ceil`` with the alignment; ``node.offset(align)`` is added to the returned pointer.
        """
        align = 64
        doprint = self.sympy_printer.doprint

        item_size = node.symbol.dtype.base_type.numpy_dtype.itemsize
        size = modulo_ceil(item_size * node.size + align, align)

        template = "{dtype} {name}=({dtype})aligned_alloc({align}, {size}) + {offset};"
        return template.format(dtype=node.symbol.dtype,
                               name=doprint(node.symbol.name),
                               size=doprint(size),
                               offset=int(node.offset(align)),
                               align=align)
203
204

    def _print_TemporaryMemoryFree(self, node):
205
        align = 64
Martin Bauer's avatar
Martin Bauer committed
206
        return "free(%s - %d);" % (self.sympy_printer.doprint(node.symbol.name), node.offset(align))
207

Martin Bauer's avatar
Martin Bauer committed
208
209
210
211
212
213
    def _print_SkipIteration(self, _):
        if self._dialect == 'cuda':
            return "return;"
        else:
            return "continue;"

214
215
    def _print_CustomCodeNode(self, node):
        return node.get_code(self._dialect, self._vector_instruction_set)
216

217
    def _print_Conditional(self, node):
        """Print an ``if`` statement with an optional ``else`` block.

        Raises:
            ValueError: if the condition expression is of vector type
        """
        if isinstance(get_type_of_expression(node.condition_expr), VectorType):
            raise ValueError("Problem with Conditional inside vectorized loop - use vec_any or vec_all")

        condition = self.sympy_printer.doprint(node.condition_expr)
        pieces = ["if (%s)\n%s " % (condition, self._print_Block(node.true_block))]
        if node.false_block:
            pieces.append("else " + self._print_Block(node.false_block))
        return "".join(pieces)

229
230
231
232

# ------------------------------------------ Helper function & classes -------------------------------------------------


Martin Bauer's avatar
Martin Bauer committed
233
# noinspection PyPep8Naming
234
class CustomSympyPrinter(CCodePrinter):
Martin Bauer's avatar
Martin Bauer committed
235

236
    def __init__(self, dialect):
        """Create the printer for the given dialect and drop 'Min' and 'Max' from the known functions."""
        super(CustomSympyPrinter, self).__init__()
        self._dialect = dialect
        self._float_type = create_type("float32")
        for function_name in ('Min', 'Max'):
            self.known_functions.pop(function_name, None)
Martin Bauer's avatar
Martin Bauer committed
244

245
246
247
    def _print_Pow(self, expr):
        """Don't use std::pow function, for small integer exponents, write as multiplication"""
        exponent = expr.exp
        if exponent.is_integer and exponent.is_number:
            if 0 < exponent < 8:
                product = sp.Mul(*[expr.base] * exponent, evaluate=False)
                return "(" + self._print(product) + ")"
            if - 8 < exponent < 0:
                # negative exponent: print the reciprocal of the repeated product
                product = sp.Mul(*[expr.base] * (-exponent), evaluate=False)
                return "1 / ({})".format(self._print(product))
        return super(CustomSympyPrinter, self)._print_Pow(expr)

    def _print_Rational(self, expr):
        """Evaluate all rationals i.e. print 0.25 instead of 1.0/4.0"""
        evaluated = expr.evalf()
        return str(evaluated.num)
258
259
260
261
262
263
264
265

    def _print_Equality(self, expr):
        """Equality operator is not printable in default printer"""
        lhs = self._print(expr.lhs)
        rhs = self._print(expr.rhs)
        return "".join(['((', lhs, ") == (", rhs, '))'])

    def _print_Piecewise(self, expr):
        """Print piecewise in one line (remove newlines)"""
        multi_line = super(CustomSympyPrinter, self)._print_Piecewise(expr)
        return "".join(multi_line.split("\n"))

268
    def _print_Function(self, expr):
        """Print the function types known to this backend; everything else goes to the base printer.

        Expressions that provide a ``to_c`` method print themselves. The order of the type checks below
        is kept as is, the first matching check decides how the expression is printed.
        """
        infix_operators = {
            bitwise_xor: '^',
            bit_shift_right: '>>',
            bit_shift_left: '<<',
            bitwise_or: '|',
            bitwise_and: '&',
        }

        if hasattr(expr, 'to_c'):
            return expr.to_c(self._print)

        if isinstance(expr, reinterpret_cast_func):
            arg, data_type = expr.args
            return "*((%s)(& %s))" % (PointerType(data_type, restrict=False), self._print(arg))

        if isinstance(expr, address_of):
            assert len(expr.args) == 1, "address_of must only have one argument"
            return "&(%s)" % self._print(expr.args[0])

        if isinstance(expr, cast_func):
            arg, data_type = expr.args
            if isinstance(arg, sp.Number):
                return self._typed_number(arg, data_type)
            return "((%s)(%s))" % (data_type, self._print(arg))

        if isinstance(expr, fast_division):
            if self._dialect == "cuda":
                return "__fdividef(%s, %s)" % tuple(self._print(a) for a in expr.args)
            return "({})".format(self._print(expr.args[0] / expr.args[1]))

        if isinstance(expr, fast_sqrt):
            if self._dialect == "cuda":
                return "__fsqrt_rn(%s)" % tuple(self._print(a) for a in expr.args)
            return "({})".format(self._print(sp.sqrt(expr.args[0])))

        if isinstance(expr, (vec_any, vec_all)):
            # in this (non-vectorized) printer only the argument is printed
            return self._print(expr.args[0])

        if isinstance(expr, fast_inv_sqrt):
            if self._dialect == "cuda":
                return "__frsqrt_rn(%s)" % tuple(self._print(a) for a in expr.args)
            return "({})".format(self._print(1 / sp.sqrt(expr.args[0])))

        operator = infix_operators.get(expr.func)
        if operator is not None:
            return "(%s %s %s)" % (self._print(expr.args[0]), operator, self._print(expr.args[1]))

        if expr.func == int_power_of_2:
            return "(1 << (%s))" % (self._print(expr.args[0]))

        if expr.func == int_div:
            return "((%s) / (%s))" % (self._print(expr.args[0]), self._print(expr.args[1]))

        return super(CustomSympyPrinter, self)._print_Function(expr)
Martin Bauer's avatar
Martin Bauer committed
315

316
317
    def _typed_number(self, number, dtype):
        """Print ``number``; for the float type of this printer an ``f`` suffix is appended.

        If the printed number contains no ``.``, ``.0f`` is appended instead of ``f``. For all other
        data types the printed number is returned unchanged.
        """
        text = self._print(number)
        if not (dtype.is_float() and dtype == self._float_type):
            return text
        suffix = "f" if '.' in text else ".0f"
        return text + suffix
327

328
329
330
    # 'Min' and 'Max' are removed from known_functions in __init__; they are printed with the
    # implementations of the C89 printer instead.
    _print_Max = C89CodePrinter._print_Max
    _print_Min = C89CodePrinter._print_Min

331

Martin Bauer's avatar
Martin Bauer committed
332
# noinspection PyPep8Naming
333
334
335
class VectorizedCustomSympyPrinter(CustomSympyPrinter):
    # (sign, term) pair built for every summand while printing a vectorized Add
    SummandInfo = namedtuple("SummandInfo", ['sign', 'term'])

336
337
    def __init__(self, instruction_set, dialect):
        """Create the printer.

        Args:
            instruction_set: mapping that is indexed with operation names (e.g. '+', 'loadA', 'makeVec')
                             by the print methods to obtain format strings
            dialect: forwarded to CustomSympyPrinter
        """
        super(VectorizedCustomSympyPrinter, self).__init__(dialect=dialect)
        self.instruction_set = instruction_set
339

Martin Bauer's avatar
Martin Bauer committed
340
341
342
343
    def _scalarFallback(self, func_name, expr, *args, **kwargs):
        """Print ``expr`` with the parent class method ``func_name`` if it is not of vector type.

        Returns:
            the string printed by the parent class, or None if the expression is of vector type
            and has to be printed by the caller
        """
        expr_type = get_type_of_expression(expr)
        if type(expr_type) is VectorType:
            assert self.instruction_set['width'] == expr_type.width
            return None
        scalar_printer = getattr(super(VectorizedCustomSympyPrinter, self), func_name)
        return scalar_printer(expr, *args, **kwargs)

348
    def _print_Function(self, expr):
        """Print functions with the vector instruction set where the expression is of vector type.

        Handles vector memory accesses, casts to a vector type, the fast division / square root
        functions and ``vec_any`` / ``vec_all``. Everything that is not handled here is passed on to
        the ``_print_Function`` method of CustomSympyPrinter.
        """
        if isinstance(expr, vector_memory_access):
            arg, data_type, aligned, _ = expr.args
            instruction = self.instruction_set['loadA'] if aligned else self.instruction_set['loadU']
            return instruction.format("& " + self._print(arg))
        elif isinstance(expr, cast_func):
            arg, data_type = expr.args
            if type(data_type) is VectorType:
                return self.instruction_set['makeVec'].format(self._print(arg))
            # casts to non-vector types are handled by the parent class below
        elif expr.func == fast_division:
            result = self._scalarFallback('_print_Function', expr)
            if not result:
                result = self.instruction_set['/'].format(self._print(expr.args[0]), self._print(expr.args[1]))
            return result
        elif expr.func == fast_sqrt:
            return "({})".format(self._print(sp.sqrt(expr.args[0])))
        elif expr.func == fast_inv_sqrt:
            result = self._scalarFallback('_print_Function', expr)
            if result:
                # The scalar fallback already is the output of the parent class, return it directly
                # instead of dropping it and printing the expression a second time further down.
                return result
            # use the dedicated instruction if the instruction set offers one, otherwise 1 / sqrt
            rsqrt = self.instruction_set.get('rsqrt')
            if rsqrt:
                return rsqrt.format(self._print(expr.args[0]))
            else:
                return "({})".format(self._print(1 / sp.sqrt(expr.args[0])))
        elif isinstance(expr, vec_any):
            expr_type = get_type_of_expression(expr.args[0])
            if type(expr_type) is not VectorType:
                return self._print(expr.args[0])
            else:
                return self.instruction_set['any'].format(self._print(expr.args[0]))
        elif isinstance(expr, vec_all):
            expr_type = get_type_of_expression(expr.args[0])
            if type(expr_type) is not VectorType:
                return self._print(expr.args[0])
            else:
                return self.instruction_set['all'].format(self._print(expr.args[0]))

        return super(VectorizedCustomSympyPrinter, self)._print_Function(expr)

386
387
388
389
390
    def _print_And(self, expr):
        """Print a vector-typed And by chaining the '&' instruction over all arguments, left to right."""
        scalar = self._scalarFallback('_print_And', expr)
        if scalar:
            return scalar

        operands = [self._print(a) for a in expr.args]
        assert len(operands) > 0
        remaining = iter(operands)
        combined = next(remaining)
        for operand in remaining:
            combined = self.instruction_set['&'].format(combined, operand)
        return combined

    def _print_Or(self, expr):
        """Print a vector-typed Or by chaining the '|' instruction over all arguments, left to right."""
        scalar = self._scalarFallback('_print_Or', expr)
        if scalar:
            return scalar

        operands = [self._print(a) for a in expr.args]
        assert len(operands) > 0
        remaining = iter(operands)
        combined = next(remaining)
        for operand in remaining:
            combined = self.instruction_set['|'].format(combined, operand)
        return combined

410
    def _print_Add(self, expr, order=None):
        """Print a vector-typed sum with the '+' and '-' instructions.

        Products are printed without their sign so that negative summands can be subtracted;
        positive summands are placed first.
        """
        scalar = self._scalarFallback('_print_Add', expr)
        if scalar:
            return scalar

        collected = []
        for term in expr.args:
            if term.func == sp.Mul:
                sign, text = self._print_Mul(term, inside_add=True)
            else:
                text = self._print(term)
                sign = 1
            collected.append(self.SummandInfo(sign, text))

        # Use positive terms first (stable, so the relative order within each sign is kept)
        summands = sorted(collected, key=lambda s: -s.sign)
        # if no positive term exists, prepend a zero
        if summands[0].sign == -1:
            summands.insert(0, self.SummandInfo(1, "0"))

        assert len(summands) >= 2
        accumulated = summands[0].term
        for summand in summands[1:]:
            if summand.sign == -1:
                instruction = self.instruction_set['-']
            else:
                instruction = self.instruction_set['+']
            accumulated = instruction.format(accumulated, summand.term)
        return accumulated

436
    def _print_Pow(self, expr):
        """Print a vector-typed power for the supported exponents.

        Supported are integer exponents between -8 and 8 (exclusive, without zero) as well as
        0.5 and -0.5.

        Raises:
            ValueError: for any other exponent
        """
        scalar = self._scalarFallback('_print_Pow', expr)
        if scalar:
            return scalar

        instruction_set = self.instruction_set
        one = instruction_set['makeVec'].format(1.0)
        exponent = expr.exp
        small_integer = exponent.is_integer and exponent.is_number

        if small_integer and 0 < exponent < 8:
            return "(" + self._print(sp.Mul(*[expr.base] * exponent, evaluate=False)) + ")"
        if exponent == -1:
            return instruction_set['/'].format(one, self._print(expr.base))
        if exponent == 0.5:
            return instruction_set['sqrt'].format(self._print(expr.base))
        if exponent == -0.5:
            root = instruction_set['sqrt'].format(self._print(expr.base))
            return instruction_set['/'].format(one, root)
        if small_integer and - 8 < exponent < 0:
            # reciprocal of the repeated product
            product = sp.Mul(*[expr.base] * (-exponent), evaluate=False)
            return instruction_set['/'].format(one, self._print(product))
        raise ValueError("Generic exponential not supported: " + str(expr))
458

Martin Bauer's avatar
Martin Bauer committed
459
460
461
462
    def _print_Mul(self, expr, inside_add=False):
        """Print a product with the '*' and '/' instructions of the instruction set.

        Args:
            expr: the product to print
            inside_add: if True, the sign is not printed but returned separately, so that the calling
                        ``_print_Add`` can decide between the '+' and '-' instruction

        Returns:
            the printed product, or a tuple ``(sign, printed product)`` if ``inside_add`` is True,
            where sign is 1 or -1
        """
        # noinspection PyProtectedMember
        from sympy.core.mul import _keep_coeff

        result = self._scalarFallback('_print_Mul', expr)
        if result:
            # A product that is not of vector type is printed completely (sign included) by the parent
            # class. With inside_add the caller unpacks a (sign, term) tuple, so the plain string must
            # not be returned as is.
            return (1, result) if inside_add else result

        # split off a negative coefficient so that the sign can be handled separately
        c, e = expr.as_coeff_Mul()
        if c < 0:
            expr = _keep_coeff(-c, e)
            sign = -1
        else:
            sign = 1

        a = []  # items in the numerator
        b = []  # items that are in the denominator (if any)

        # Gather args for numerator/denominator
        for item in expr.as_ordered_factors():
            if item.is_commutative and item.is_Pow and item.exp.is_Rational and item.exp.is_negative:
                if item.exp != -1:
                    b.append(sp.Pow(item.base, -item.exp, evaluate=False))
                else:
                    b.append(sp.Pow(item.base, -item.exp))
            else:
                a.append(item)

        a = a or [S.One]

        a_str = [self._print(x) for x in a]
        b_str = [self._print(x) for x in b]

        result = a_str[0]
        for item in a_str[1:]:
            result = self.instruction_set['*'].format(result, item)

        if len(b) > 0:
            denominator_str = b_str[0]
            for item in b_str[1:]:
                denominator_str = self.instruction_set['*'].format(denominator_str, item)
            result = self.instruction_set['/'].format(result, denominator_str)

        if inside_add:
            return sign, result
        else:
            if sign < 0:
                return self.instruction_set['*'].format(self._print(S.NegativeOne), result)
            else:
                return result

510
    def _print_Relational(self, expr):
        """Print a vector-typed relational with the instruction registered under its ``rel_op``."""
        scalar = self._scalarFallback('_print_Relational', expr)
        if scalar:
            return scalar
        instruction = self.instruction_set[expr.rel_op]
        return instruction.format(self._print(expr.lhs), self._print(expr.rhs))
515
516

    def _print_Equality(self, expr):
        """Print a vector-typed equality with the '==' instruction of the instruction set."""
        scalar = self._scalarFallback('_print_Equality', expr)
        if scalar:
            return scalar
        instruction = self.instruction_set['==']
        return instruction.format(self._print(expr.lhs), self._print(expr.rhs))
521
522

    def _print_Piecewise(self, expr):
        """Print a vector-typed Piecewise as nested 'blendv' instructions, starting from the default case.

        Conditions that are a ``cast_func`` around a bool-typed expression are printed as a C ternary
        expression instead, unless KERNCRAFT_NO_TERNARY_MODE is set; then a warning is printed and
        the branch is skipped.

        Raises:
            ValueError: if the first argument of the last condition is not True
        """
        scalar = self._scalarFallback('_print_Piecewise', expr)
        if scalar:
            return scalar

        if expr.args[-1].cond.args[0] is not sp.sympify(True):
            # We need the last conditional to be a True, otherwise the resulting
            # function may not return a result.
            raise ValueError("All Piecewise expressions must contain an "
                             "(expr, True) statement to be used as a default "
                             "condition. Without one, the generated "
                             "expression may not evaluate to anything under "
                             "some condition.")

        # Build the expression inside out: start with the default and wrap it with the other branches
        result = self._print(expr.args[-1][0])
        for true_expr, condition in reversed(expr.args[:-1]):
            is_scalar_bool = (isinstance(condition, cast_func)
                              and get_type_of_expression(condition.args[0]) == create_type("bool"))
            if not is_scalar_bool:
                # noinspection SpellCheckingInspection
                result = self.instruction_set['blendv'].format(result, self._print(true_expr), self._print(condition))
            elif KERNCRAFT_NO_TERNARY_MODE:
                print("Warning - skipping ternary op")
            else:
                result = "(({}) ? ({}) : ({}))".format(self._print(condition.args[0]), self._print(true_expr),
                                                       result)
        return result