# cbackend.py
from collections import namedtuple
2
from typing import Set
3
4
5
6

import jinja2
import sympy as sp
from sympy.core import S
7
from sympy.printing.ccode import C89CodePrinter
8

9
10
11
from pystencils.astnodes import (DestructuringBindingsForFieldClass,
                                 KernelFunction, Node)
from pystencils.cpu.vectorization import vec_all, vec_any
12
from pystencils.data_types import (PointerType, VectorType, address_of,
13
                                   cast_func, create_type,
14
                                   get_type_of_expression,
15
16
17
18
19
20
21
                                   reinterpret_cast_func, vector_memory_access)
from pystencils.fast_approximation import (fast_division, fast_inv_sqrt,
                                           fast_sqrt)
from pystencils.integer_functions import (bit_shift_left, bit_shift_right,
                                          bitwise_and, bitwise_or, bitwise_xor,
                                          int_div, int_power_of_2, modulo_ceil)
from pystencils.kernelparameters import FieldPointerSymbol
22

try:
    # Prefer the C99 printer; older sympy releases only provide CCodePrinter
    from sympy.printing.ccode import C99CodePrinter as CCodePrinter
except ImportError:
    from sympy.printing.ccode import CCodePrinter  # for sympy versions < 1.1

# Names exported by `from <this module> import *`
__all__ = ['generate_c', 'CustomCodeNode', 'PrintNode', 'get_headers', 'CustomSympyPrinter']

# If True, VectorizedCustomSympyPrinter._print_Piecewise does not emit "?:" ternaries for scalar boolean
# conditions: that Piecewise branch is omitted from the generated code and a warning is printed instead.
# NOTE(review): presumably a workaround for kerncraft, which the name suggests cannot parse ternaries -- confirm.
KERNCRAFT_NO_TERNARY_MODE = False


class UnsupportedCDialect(Exception):
    """Raised by generate_c when the requested dialect is neither 'c' nor 'cuda'."""

    def __init__(self):
        # No message: the exception type itself identifies the problem
        super().__init__()


40
def generate_c(ast_node: Node, signature_only: bool = False, dialect='c') -> str:
    """Prints an abstract syntax tree node as C or CUDA code.

    This function does not need to distinguish between C, C++ or CUDA code, it just prints 'C-like' code as encoded
    in the abstract syntax tree (AST). The AST is built differently for C or CUDA by calling different create_kernel
    functions.

    Args:
        ast_node: root of the AST to print, usually a KernelFunction
        signature_only: if True, only the function declaration is printed and no global declarations are
                        prepended
        dialect: 'c' or 'cuda'

    Returns:
        C-like code for the ast node and its descendants

    Raises:
        UnsupportedCDialect: if ``dialect`` is neither 'c' nor 'cuda'
    """
    # Select the printer first, so an unsupported dialect is rejected before ast_node is modified below
    if dialect == 'c':
        # Nodes without an instruction_set attribute are printed without vector intrinsics
        printer = CBackend(signature_only=signature_only,
                           vector_instruction_set=getattr(ast_node, 'instruction_set', None))
    elif dialect == 'cuda':
        from pystencils.backends.cuda_backend import CudaBackend
        printer = CudaBackend(signature_only=signature_only)
    else:
        raise UnsupportedCDialect

    # Register every symbol defined by a global declaration as a global variable of the printed node
    global_declarations = get_global_declarations(ast_node)
    for d in global_declarations:
        if hasattr(ast_node, "global_variables"):
            ast_node.global_variables.update(d.symbols_defined)
        else:
            # Copy: storing d.symbols_defined itself would let later updates mutate that declaration's set
            ast_node.global_variables = set(d.symbols_defined)

    code = printer(ast_node)
    if not signature_only and isinstance(ast_node, KernelFunction):
        # Global declarations go in front of the kernel function
        code = "\n" + code
        for declaration in global_declarations:
            code = printer(declaration) + "\n" + code

    return code


def get_global_declarations(ast):
    """Collect the ``required_global_declarations`` of ``ast`` and of all its descendants.

    Descendants are reached through the ``args`` attribute of each visited object. Objects lacking either
    attribute are simply skipped for that part. The result is a set, so duplicate declarations are merged.
    """
    collected = set()
    pending = [ast]
    # Explicit work list instead of recursion; the result is a set, so visiting order is irrelevant
    while pending:
        current = pending.pop()
        if hasattr(current, "required_global_declarations"):
            collected.update(current.required_global_declarations)
        if hasattr(current, "args"):
            pending.extend(current.args)
    return collected
93
94


def get_headers(ast_node: Node) -> Set[str]:
    """Return a set of header files, necessary to compile the printed C-like code.

    Gathers the instruction set's headers (for vectorized kernel functions), the node's own ``headers``
    attribute if present, and recursively the headers of every Node among ``ast_node.args``.
    """
    found = set()

    is_vectorized_kernel = isinstance(ast_node, KernelFunction) and ast_node.instruction_set
    if is_vectorized_kernel:
        found.update(ast_node.instruction_set['headers'])

    if hasattr(ast_node, 'headers'):
        found.update(ast_node.headers)

    child_nodes = (child for child in ast_node.args if isinstance(child, Node))
    for child in child_nodes:
        found.update(get_headers(child))

    return found
109
110


111
112
113
# --------------------------------------- Backend Specific Nodes -------------------------------------------------------


114
class CustomCodeNode(Node):
    """AST node that carries a verbatim piece of code.

    Args:
        code: source text; stored with a leading newline and returned unchanged by get_code
        symbols_read: symbols the code uses
        symbols_defined: symbols the code defines
        parent: optional parent node
    """

    def __init__(self, code, symbols_read, symbols_defined, parent=None):
        super().__init__(parent=parent)
        self._code = "\n" + code
        self._symbols_read, self._symbols_defined = set(symbols_read), set(symbols_defined)
        # Additional header files the code needs (collected by get_headers via the `headers` attribute)
        self.headers = []

    def get_code(self, dialect, vector_instruction_set):
        """Return the stored code; the same text is used for every dialect and instruction set."""
        return self._code

    @property
    def args(self):
        """A custom code node has no child nodes."""
        return []

    @property
    def symbols_defined(self):
        """Symbols defined by the code."""
        return self._symbols_defined

    @property
    def undefined_symbols(self):
        """Symbols read by the code that it does not define itself."""
        return self._symbols_read.difference(self._symbols_defined)
136
137


138
class PrintNode(CustomCodeNode):
    """Custom code node that prints the name and value of a symbol with ``std::cout``."""

    # noinspection SpellCheckingInspection
    def __init__(self, symbol_to_print):
        symbol_name = symbol_to_print.name
        statement = '\nstd::cout << "%s  =  " << %s << std::endl; \n' % (symbol_name, symbol_name)
        super().__init__(statement, symbols_read=[symbol_to_print], symbols_defined=set())
        # std::cout / std::endl come from <iostream>
        self.headers.append("<iostream>")
144
145
146
147


# ------------------------------------------- Printer ------------------------------------------------------------------

148

# noinspection PyPep8Naming
class CBackend:
    """Printer that turns a pystencils AST into C-like source code.

    Every AST node type is printed by a ``_print_<ClassName>`` method; ``_print`` walks the node's MRO to find
    the most specific one. Sympy expressions inside the nodes are delegated to ``self.sympy_printer``.
    """

    def __init__(self, sympy_printer=None, signature_only=False, vector_instruction_set=None, dialect='c'):
        """
        Args:
            sympy_printer: printer for sympy expressions. If None, a VectorizedCustomSympyPrinter is created when
                           ``vector_instruction_set`` is given, otherwise a CustomSympyPrinter.
            signature_only: if True, kernel functions are printed as their declaration only (no body)
            vector_instruction_set: mapping from operation names (e.g. 'storeU', 'storeA', 'stream') to format
                                    strings of the corresponding intrinsics, or None for scalar code
            dialect: 'c' or 'cuda'; forwarded to CustomCodeNode.get_code and used to decide whether CUDA launch
                     bounds are emitted
        """
        if sympy_printer is None:
            if vector_instruction_set is not None:
                self.sympy_printer = VectorizedCustomSympyPrinter(vector_instruction_set)
            else:
                self.sympy_printer = CustomSympyPrinter()
        else:
            self.sympy_printer = sympy_printer

        self._vector_instruction_set = vector_instruction_set
        self._indent = "   "
        self._dialect = dialect
        self._signatureOnly = signature_only

    def __call__(self, node):
        """Print ``node`` and its descendants and return the code as a string."""
        # VectorType.instruction_set is global state; set it for the duration of printing only
        prev_is = VectorType.instruction_set
        VectorType.instruction_set = self._vector_instruction_set
        try:
            return str(self._print(node))
        finally:
            # restore even if printing raises, so later printers do not see a stale instruction set
            VectorType.instruction_set = prev_is

    def _print(self, node):
        """Dispatch to the ``_print_<ClassName>`` method of the most specific class in the node's MRO.

        Raises:
            NotImplementedError: if no printing method exists for any class in the MRO
        """
        for cls in type(node).__mro__:
            method_name = "_print_" + cls.__name__
            if hasattr(self, method_name):
                return getattr(self, method_name)(node)
        raise NotImplementedError(self.__class__.__name__ + " does not support node of type " + str(type(node)))

    def _print_KernelFunction(self, node):
        function_arguments = ["%s %s" % (str(s.symbol.dtype), s.symbol.name) for s in node.get_parameters()]
        launch_bounds = ""
        # Launch bounds are a CUDA-only qualifier; they need the kernel's indexing to know the block size
        indexing = getattr(node, 'indexing', None)
        if self._dialect == 'cuda' and indexing is not None:
            max_threads = indexing.max_threads_per_block()
            if max_threads:
                launch_bounds = "__launch_bounds__({}) ".format(max_threads)
        func_declaration = "FUNC_PREFIX %svoid %s(%s)" % (launch_bounds, node.function_name,
                                                          ", ".join(function_arguments))
        if self._signatureOnly:
            return func_declaration

        body = self._print(node.body)
        return func_declaration + "\n" + body

    def _print_Block(self, node):
        block_contents = "\n".join([self._print(child) for child in node.args])
        # indent the first line and every line following a line break
        return "{\n%s\n}" % (self._indent + self._indent.join(block_contents.splitlines(True)))

    def _print_PragmaBlock(self, node):
        return "%s\n%s" % (node.pragma_line, self._print_Block(node))

    def _print_LoopOverCoordinate(self, node):
        counter_symbol = node.loop_counter_name
        start = "int %s = %s" % (counter_symbol, self.sympy_printer.doprint(node.start))
        condition = "%s < %s" % (counter_symbol, self.sympy_printer.doprint(node.stop))
        update = "%s += %s" % (counter_symbol, self.sympy_printer.doprint(node.step),)
        loop_str = "for (%s; %s; %s)" % (start, condition, update)

        prefix = "\n".join(node.prefix_lines)
        if prefix:
            prefix += "\n"
        return "%s%s\n%s" % (prefix, loop_str, self._print(node.body))

    def _print_SympyAssignment(self, node):
        if node.is_declaration:
            data_type = "const " + str(node.lhs.dtype) + " " if node.is_const else str(node.lhs.dtype) + " "
            return "%s%s = %s;" % (data_type, self.sympy_printer.doprint(node.lhs),
                                   self.sympy_printer.doprint(node.rhs))
        else:
            lhs_type = get_type_of_expression(node.lhs)
            if type(lhs_type) is VectorType and isinstance(node.lhs, cast_func):
                # vector store: lhs.args unpacks to (address expression, data type, aligned, nontemporal)
                arg, _, aligned, nontemporal = node.lhs.args
                instr = 'storeU'
                if aligned:
                    instr = 'stream' if nontemporal else 'storeA'

                # broadcast a scalar right-hand side to a vector before storing it
                rhs_type = get_type_of_expression(node.rhs)
                if type(rhs_type) is not VectorType:
                    rhs = cast_func(node.rhs, VectorType(rhs_type))
                else:
                    rhs = node.rhs

                return self._vector_instruction_set[instr].format("&" + self.sympy_printer.doprint(arg),
                                                                  self.sympy_printer.doprint(rhs)) + ';'
            else:
                return "%s = %s;" % (self.sympy_printer.doprint(node.lhs), self.sympy_printer.doprint(node.rhs))

    def _print_TemporaryMemoryAllocation(self, node):
        align = 64
        np_dtype = node.symbol.dtype.base_type.numpy_dtype
        # over-allocate by `align` bytes and round the size up to a multiple of `align`
        # (C11 aligned_alloc requires the size to be a multiple of the alignment)
        required_size = np_dtype.itemsize * node.size + align
        size = modulo_ceil(required_size, align)
        code = "{dtype} {name}=({dtype})aligned_alloc({align}, {size}) + {offset};"
        return code.format(dtype=node.symbol.dtype,
                           name=self.sympy_printer.doprint(node.symbol.name),
                           size=self.sympy_printer.doprint(size),
                           offset=int(node.offset(align)),
                           align=align)

    def _print_TemporaryMemoryFree(self, node):
        align = 64
        # undo the offset added in _print_TemporaryMemoryAllocation before freeing
        return "free(%s - %d);" % (self.sympy_printer.doprint(node.symbol.name), node.offset(align))

    def _print_SkipIteration(self, _):
        return "continue;"

    def _print_CustomCodeNode(self, node):
        return node.get_code(self._dialect, self._vector_instruction_set)

    def _print_Conditional(self, node):
        cond_type = get_type_of_expression(node.condition_expr)
        if isinstance(cond_type, VectorType):
            raise ValueError("Problem with Conditional inside vectorized loop - use vec_any or vec_all")
        condition_expr = self.sympy_printer.doprint(node.condition_expr)
        true_block = self._print_Block(node.true_block)
        result = "if (%s)\n%s " % (condition_expr, true_block)
        if node.false_block:
            false_block = self._print_Block(node.false_block)
            result += "else " + false_block
        return result

    def _print_DestructuringBindingsForFieldClass(self, node: 'Node'):
        # Define all undefined symbols
        undefined_field_symbols = node.symbols_defined
        destructuring_bindings = ["%s = %s.%s%s;" %
                                  (u.name,
                                   u.field_name if hasattr(u, 'field_name') else u.field_names[0],
                                   DestructuringBindingsForFieldClass.CLASS_TO_MEMBER_DICT[u.__class__],
                                   "" if type(u) == FieldPointerSymbol else ("[%i]" % u.coordinate))
                                  for u in undefined_field_symbols
                                  ]
        destructuring_bindings.sort()  # only for code aesthetics
        template = jinja2.Template(
            """{
   {% for binding in bindings -%}
   {{ binding | indent(3) }}
   {% endfor -%}
   {{ block | indent(3) }}
}

""")
        code = template.render(bindings=destructuring_bindings,
                               block=self._print(node.body))
        return code

296
297
298
299

# ------------------------------------------ Helper function & classes -------------------------------------------------


Martin Bauer's avatar
Martin Bauer committed
300
# noinspection PyPep8Naming
301
class CustomSympyPrinter(CCodePrinter):
Martin Bauer's avatar
Martin Bauer committed
302

303
    def __init__(self):
Martin Bauer's avatar
Martin Bauer committed
304
        super(CustomSympyPrinter, self).__init__()
305
        self._float_type = create_type("float32")
306
307
308
309
        if 'Min' in self.known_functions:
            del self.known_functions['Min']
        if 'Max' in self.known_functions:
            del self.known_functions['Max']
Martin Bauer's avatar
Martin Bauer committed
310

311
312
313
    def _print_Pow(self, expr):
        """Don't use std::pow function, for small integer exponents, write as multiplication"""
        if expr.exp.is_integer and expr.exp.is_number and 0 < expr.exp < 8:
314
            return "(" + self._print(sp.Mul(*[expr.base] * expr.exp, evaluate=False)) + ")"
315
316
        elif expr.exp.is_integer and expr.exp.is_number and - 8 < expr.exp < 0:
            return "1 / ({})".format(self._print(sp.Mul(*[expr.base] * (-expr.exp), evaluate=False)))
317
318
319
320
321
        else:
            return super(CustomSympyPrinter, self)._print_Pow(expr)

    def _print_Rational(self, expr):
        """Evaluate all rationals i.e. print 0.25 instead of 1.0/4.0"""
Martin Bauer's avatar
Martin Bauer committed
322
323
        res = str(expr.evalf().num)
        return res
324
325
326
327
328
329
330
331

    def _print_Equality(self, expr):
        """Equality operator is not printable in default printer"""
        return '((' + self._print(expr.lhs) + ") == (" + self._print(expr.rhs) + '))'

    def _print_Piecewise(self, expr):
        """Print piecewise in one line (remove newlines)"""
        result = super(CustomSympyPrinter, self)._print_Piecewise(expr)
Martin Bauer's avatar
Martin Bauer committed
332
333
        return result.replace("\n", "")

334
    def _print_Function(self, expr):
335
        infix_functions = {
Martin Bauer's avatar
Martin Bauer committed
336
337
338
339
340
            bitwise_xor: '^',
            bit_shift_right: '>>',
            bit_shift_left: '<<',
            bitwise_or: '|',
            bitwise_and: '&',
Martin Bauer's avatar
Martin Bauer committed
341
        }
Martin Bauer's avatar
Martin Bauer committed
342
343
        if hasattr(expr, 'to_c'):
            return expr.to_c(self._print)
344
345
346
        if isinstance(expr, reinterpret_cast_func):
            arg, data_type = expr.args
            return "*((%s)(& %s))" % (PointerType(data_type, restrict=False), self._print(arg))
347
348
349
        elif isinstance(expr, address_of):
            assert len(expr.args) == 1, "address_of must only have one argument"
            return "&(%s)" % self._print(expr.args[0])
350
        elif isinstance(expr, cast_func):
Martin Bauer's avatar
Martin Bauer committed
351
            arg, data_type = expr.args
352
353
354
            if isinstance(arg, sp.Number):
                return self._typed_number(arg, data_type)
            else:
355
356
                return "((%s)(%s))" % (data_type, self._print(arg))
        elif isinstance(expr, fast_division):
357
            return "({})".format(self._print(expr.args[0] / expr.args[1]))
358
        elif isinstance(expr, fast_sqrt):
359
            return "({})".format(self._print(sp.sqrt(expr.args[0])))
360
361
        elif isinstance(expr, vec_any) or isinstance(expr, vec_all):
            return self._print(expr.args[0])
362
        elif isinstance(expr, fast_inv_sqrt):
363
            return "({})".format(self._print(1 / sp.sqrt(expr.args[0])))
364
365
        elif expr.func in infix_functions:
            return "(%s %s %s)" % (self._print(expr.args[0]), infix_functions[expr.func], self._print(expr.args[1]))
366
367
368
369
        elif expr.func == int_power_of_2:
            return "(1 << (%s))" % (self._print(expr.args[0]))
        elif expr.func == int_div:
            return "((%s) / (%s))" % (self._print(expr.args[0]), self._print(expr.args[1]))
370
        else:
371
            return super(CustomSympyPrinter, self)._print_Function(expr)
Martin Bauer's avatar
Martin Bauer committed
372

373
374
    def _typed_number(self, number, dtype):
        res = self._print(number)
375
        if dtype.is_float():
376
377
378
379
380
381
382
383
            if dtype == self._float_type:
                if '.' not in res:
                    res += ".0f"
                else:
                    res += "f"
            return res
        else:
            return res
384

385
386
387
    _print_Max = C89CodePrinter._print_Max
    _print_Min = C89CodePrinter._print_Min

388

# noinspection PyPep8Naming
class VectorizedCustomSympyPrinter(CustomSympyPrinter):
    """Sympy printer that emits SIMD intrinsics for expressions of VectorType.

    ``instruction_set`` maps operation names (e.g. '+', '-', '*', '/', '==', '&', '|', 'loadA', 'loadU',
    'makeVec', 'sqrt', 'rsqrt', 'blendv', 'any', 'all') to format strings of the target intrinsics, and
    'width' to the vector width. Sub-expressions that are not of VectorType are printed with the scalar
    CustomSympyPrinter rules.
    """
    SummandInfo = namedtuple("SummandInfo", ['sign', 'term'])

    def __init__(self, instruction_set):
        super(VectorizedCustomSympyPrinter, self).__init__()
        self.instruction_set = instruction_set

    def _scalarFallback(self, func_name, expr, *args, **kwargs):
        """Print ``expr`` with the parent method ``func_name`` unless ``expr`` is vector-typed.

        Returns:
            the scalar code, or None if ``expr`` has a VectorType (the caller must then emit intrinsics)
        """
        expr_type = get_type_of_expression(expr)
        if type(expr_type) is not VectorType:
            return getattr(super(VectorizedCustomSympyPrinter, self), func_name)(expr, *args, **kwargs)
        else:
            assert self.instruction_set['width'] == expr_type.width
            return None

    def _print_Function(self, expr):
        if isinstance(expr, vector_memory_access):
            arg, data_type, aligned, _ = expr.args
            instruction = self.instruction_set['loadA'] if aligned else self.instruction_set['loadU']
            return instruction.format("& " + self._print(arg))
        elif isinstance(expr, cast_func):
            arg, data_type = expr.args
            if type(data_type) is VectorType:
                # cast to a vector type = broadcast
                return self.instruction_set['makeVec'].format(self._print(arg))
            # scalar casts are printed by the parent printer (below)
        elif expr.func == fast_division:
            result = self._scalarFallback('_print_Function', expr)
            if result is None:
                result = self.instruction_set['/'].format(self._print(expr.args[0]), self._print(expr.args[1]))
            return result
        elif expr.func == fast_sqrt:
            return "({})".format(self._print(sp.sqrt(expr.args[0])))
        elif expr.func == fast_inv_sqrt:
            result = self._scalarFallback('_print_Function', expr)
            if result is not None:
                return result
            # use a dedicated reciprocal square root intrinsic when the instruction set provides one
            if self.instruction_set.get('rsqrt'):
                return self.instruction_set['rsqrt'].format(self._print(expr.args[0]))
            else:
                return "({})".format(self._print(1 / sp.sqrt(expr.args[0])))
        elif isinstance(expr, (vec_any, vec_all)):
            expr_type = get_type_of_expression(expr.args[0])
            if type(expr_type) is not VectorType:
                return self._print(expr.args[0])
            reduction = 'any' if isinstance(expr, vec_any) else 'all'
            return self.instruction_set[reduction].format(self._print(expr.args[0]))

        return super(VectorizedCustomSympyPrinter, self)._print_Function(expr)

    def _print_And(self, expr):
        result = self._scalarFallback('_print_And', expr)
        if result is not None:
            return result

        arg_strings = [self._print(a) for a in expr.args]
        assert len(arg_strings) > 0
        result = arg_strings[0]
        for item in arg_strings[1:]:
            result = self.instruction_set['&'].format(result, item)
        return result

    def _print_Or(self, expr):
        result = self._scalarFallback('_print_Or', expr)
        if result is not None:
            return result

        arg_strings = [self._print(a) for a in expr.args]
        assert len(arg_strings) > 0
        result = arg_strings[0]
        for item in arg_strings[1:]:
            result = self.instruction_set['|'].format(result, item)
        return result

    def _print_Add(self, expr, order=None):
        result = self._scalarFallback('_print_Add', expr)
        if result is not None:
            return result

        summands = []
        for term in expr.args:
            if term.func == sp.Mul:
                # products report their sign separately so it can become a subtraction
                sign, t = self._print_Mul(term, inside_add=True)
            else:
                t = self._print(term)
                sign = 1
            summands.append(self.SummandInfo(sign, t))
        # Use positive terms first
        summands.sort(key=lambda e: e.sign, reverse=True)
        # if no positive term exists, prepend a zero
        if summands[0].sign == -1:
            summands.insert(0, self.SummandInfo(1, "0"))

        assert len(summands) >= 2
        processed = summands[0].term
        for summand in summands[1:]:
            func = self.instruction_set['-'] if summand.sign == -1 else self.instruction_set['+']
            processed = func.format(processed, summand.term)
        return processed

    def _print_Pow(self, expr):
        """Print vector powers with intrinsics.

        Supports constant integer exponents in (-8, 8) except 0, and the exponents 0.5 and -0.5.

        Raises:
            ValueError: for any other exponent
        """
        result = self._scalarFallback('_print_Pow', expr)
        if result is not None:
            return result

        one = self.instruction_set['makeVec'].format(1.0)
        exponent = expr.exp
        is_constant_integer = exponent.is_integer and exponent.is_number

        # Newer sympy releases no longer treat Rational(1, 2) and the float 0.5 as equal, so compare with the
        # exact rational as well as with the float (the latter matches Float exponents).
        is_sqrt = exponent == sp.Rational(1, 2) or exponent == 0.5
        is_inv_sqrt = exponent == sp.Rational(-1, 2) or exponent == -0.5

        if is_constant_integer and 0 < exponent < 8:
            return "(" + self._print(sp.Mul(*[expr.base] * exponent, evaluate=False)) + ")"
        elif exponent == -1:
            return self.instruction_set['/'].format(one, self._print(expr.base))
        elif is_sqrt:
            return self.instruction_set['sqrt'].format(self._print(expr.base))
        elif is_inv_sqrt:
            root = self.instruction_set['sqrt'].format(self._print(expr.base))
            return self.instruction_set['/'].format(one, root)
        elif is_constant_integer and -8 < exponent < 0:
            return self.instruction_set['/'].format(one,
                                                    self._print(sp.Mul(*[expr.base] * (-exponent), evaluate=False)))
        else:
            raise ValueError("Generic exponential not supported: " + str(expr))

    def _print_Mul(self, expr, inside_add=False):
        """Print a product with vector intrinsics.

        Args:
            expr: the product
            inside_add: if True, return a ``(sign, code)`` pair with the sign of the numeric coefficient
                        separated out (used by _print_Add); otherwise return the code string

        Returns:
            the printed product, or a ``(sign, code)`` pair if ``inside_add`` is True
        """
        # noinspection PyProtectedMember
        from sympy.core.mul import _keep_coeff

        result = self._scalarFallback('_print_Mul', expr)
        if result is not None:
            # scalar product: the sign is already part of the printed code, but _print_Add expects a pair
            return (1, result) if inside_add else result

        c, e = expr.as_coeff_Mul()
        if c < 0:
            expr = _keep_coeff(-c, e)
            sign = -1
        else:
            sign = 1

        a = []  # items in the numerator
        b = []  # items that are in the denominator (if any)

        # Gather args for numerator/denominator
        for item in expr.as_ordered_factors():
            if item.is_commutative and item.is_Pow and item.exp.is_Rational and item.exp.is_negative:
                if item.exp != -1:
                    b.append(sp.Pow(item.base, -item.exp, evaluate=False))
                else:
                    b.append(sp.Pow(item.base, -item.exp))
            else:
                a.append(item)

        a = a or [S.One]

        a_str = [self._print(x) for x in a]
        b_str = [self._print(x) for x in b]

        result = a_str[0]
        for item in a_str[1:]:
            result = self.instruction_set['*'].format(result, item)

        if len(b) > 0:
            denominator_str = b_str[0]
            for item in b_str[1:]:
                denominator_str = self.instruction_set['*'].format(denominator_str, item)
            result = self.instruction_set['/'].format(result, denominator_str)

        if inside_add:
            return sign, result
        else:
            if sign < 0:
                return self.instruction_set['*'].format(self._print(S.NegativeOne), result)
            else:
                return result

    def _print_Relational(self, expr):
        result = self._scalarFallback('_print_Relational', expr)
        if result is not None:
            return result
        return self.instruction_set[expr.rel_op].format(self._print(expr.lhs), self._print(expr.rhs))

    def _print_Equality(self, expr):
        result = self._scalarFallback('_print_Equality', expr)
        if result is not None:
            return result
        return self.instruction_set['=='].format(self._print(expr.lhs), self._print(expr.rhs))

    def _print_Piecewise(self, expr):
        """Print a vector Piecewise as nested blends (or ternaries for scalar boolean conditions).

        Raises:
            ValueError: if the last condition is not the catch-all True condition
        """
        result = self._scalarFallback('_print_Piecewise', expr)
        if result is not None:
            return result

        if expr.args[-1].cond.args[0] is not sp.sympify(True):
            # We need the last conditional to be a True, otherwise the resulting
            # function may not return a result.
            raise ValueError("All Piecewise expressions must contain an "
                             "(expr, True) statement to be used as a default "
                             "condition. Without one, the generated "
                             "expression may not evaluate to anything under "
                             "some condition.")

        # build from the default case outwards, so the first condition ends up outermost
        result = self._print(expr.args[-1][0])
        for true_expr, condition in reversed(expr.args[:-1]):
            if isinstance(condition, cast_func) and get_type_of_expression(condition.args[0]) == create_type("bool"):
                if not KERNCRAFT_NO_TERNARY_MODE:
                    result = "(({}) ? ({}) : ({}))".format(self._print(condition.args[0]), self._print(true_expr),
                                                           result)
                else:
                    print("Warning - skipping ternary op")
            else:
                # noinspection SpellCheckingInspection
                result = self.instruction_set['blendv'].format(result, self._print(true_expr), self._print(condition))
        return result