cbackend.py 26.8 KB
Newer Older
Martin Bauer's avatar
Martin Bauer committed
1
from collections import namedtuple
from typing import List, Set

import numpy as np
import sympy as sp
from sympy.core import S
from sympy.printing.ccode import C89CodePrinter

from pystencils.astnodes import KernelFunction, Node
from pystencils.cpu.vectorization import vec_all, vec_any
from pystencils.data_types import (
    PointerType, VectorType, address_of, cast_func, create_type, get_type_of_expression,
    reinterpret_cast_func, vector_memory_access)
from pystencils.fast_approximation import fast_division, fast_inv_sqrt, fast_sqrt
from pystencils.integer_functions import (
    bit_shift_left, bit_shift_right, bitwise_and, bitwise_or, bitwise_xor,
    int_div, int_power_of_2, modulo_ceil)

try:
    from sympy.printing.ccode import C99CodePrinter as CCodePrinter
except ImportError:
    from sympy.printing.ccode import CCodePrinter  # for sympy versions < 1.1
Martin Bauer's avatar
Martin Bauer committed
22

23
# Public names of this module, i.e. what a star-import of it provides.
__all__ = ['generate_c', 'CustomCodeNode', 'PrintNode', 'get_headers', 'CustomSympyPrinter']
24

25 26
# Module-level switch read by VectorizedCustomSympyPrinter._print_Piecewise: when True, piecewise branches whose
# condition is a scalar boolean are not printed as ternary expressions; a warning is printed and the branch is skipped.
KERNCRAFT_NO_TERNARY_MODE = False

Martin Bauer's avatar
Fixes  
Martin Bauer committed
27

28 29 30 31 32
def generate_c(ast_node: Node,
               signature_only: bool = False,
               dialect='c',
               custom_backend=None,
               with_globals=True) -> str:
    """Prints an abstract syntax tree node as C or CUDA code.

    This function does not need to distinguish between C, C++ or CUDA code, it just prints 'C-like' code as encoded
    in the abstract syntax tree (AST). The AST is built differently for C or CUDA by calling different create_kernel
    functions.

    As a side effect the symbols defined by all required global declarations are added to
    ``ast_node.global_variables`` (the attribute is created if it does not exist yet).

    Args:
        ast_node: root node of the AST to print
        signature_only: passed on to the built-in backends; not used when ``custom_backend`` is given, except that
                        global declarations are never prepended when it is True
        dialect: 'c', 'cuda' or 'opencl'; ignored when ``custom_backend`` is given
        custom_backend: callable that is used as printer instead of one of the built-in backends
        with_globals: if True and ``ast_node`` is a KernelFunction, the printed global declarations are prepended
                      to the generated code

    Returns:
        C-like code for the ast node and its descendants

    Raises:
        ValueError: if no custom backend is given and ``dialect`` is not one of the known dialects
    """
    global_declarations = get_global_declarations(ast_node)
    for d in global_declarations:
        if hasattr(ast_node, "global_variables"):
            ast_node.global_variables.update(d.symbols_defined)
        else:
            # Copy the set: storing the declaration's own set here would let the update() calls of the
            # following iterations modify the symbols_defined of the first declaration.
            ast_node.global_variables = set(d.symbols_defined)

    if custom_backend:
        printer = custom_backend
    elif dialect == 'c':
        # Nodes without an ``instruction_set`` attribute are printed without vectorization support
        instruction_set = getattr(ast_node, 'instruction_set', None)
        printer = CBackend(signature_only=signature_only,
                           vector_instruction_set=instruction_set)
    elif dialect == 'cuda':
        from pystencils.backends.cuda_backend import CudaBackend
        printer = CudaBackend(signature_only=signature_only)
    elif dialect == 'opencl':
        from pystencils.backends.opencl_backend import OpenClBackend
        printer = OpenClBackend(signature_only=signature_only)
    else:
        raise ValueError("Unknown dialect: " + str(dialect))

    code = printer(ast_node)
    if not signature_only and isinstance(ast_node, KernelFunction):
        if with_globals and global_declarations:
            code = "\n" + code
            # Each declaration is put in front of everything printed so far
            for declaration in global_declarations:
                code = printer(declaration) + "\n" + code

    return code


def get_global_declarations(ast):
    """Collect the global declarations required by ``ast`` and by every node reachable through ``args``.

    A node contributes the entries of its ``required_global_declarations`` attribute, if it has one. Children are
    visited through the ``args`` attribute, if present.

    Returns:
        list of the distinct declarations, ordered by their string representation
    """
    collected = []

    def collect(current):
        collected.extend(getattr(current, "required_global_declarations", ()))
        for child in getattr(current, "args", ()):
            collect(child)

    collect(ast)

    unique_declarations = set(collected)
    return sorted(unique_declarations, key=str)
94 95


Martin Bauer's avatar
Martin Bauer committed
96 97
def get_headers(ast_node: Node) -> List[str]:
    """Return the header files necessary to compile the printed C-like code.

    Headers are gathered from the instruction set of a KernelFunction, from the ``headers`` attribute of the node,
    from all child nodes in ``args`` and from the required global declarations.

    Args:
        ast_node: root node of the AST to inspect

    Returns:
        sorted list of header names without duplicates
    """
    headers = set()

    if isinstance(ast_node, KernelFunction) and ast_node.instruction_set:
        headers.update(ast_node.instruction_set['headers'])

    if hasattr(ast_node, 'headers'):
        headers.update(ast_node.headers)
    for a in ast_node.args:
        if isinstance(a, Node):
            headers.update(get_headers(a))

    for g in get_global_declarations(ast_node):
        if isinstance(g, Node):
            headers.update(get_headers(g))

    # Sorting makes the result independent of set iteration order
    return sorted(headers)
114 115


116 117 118
# --------------------------------------- Backend Specific Nodes -------------------------------------------------------


119
class CustomCodeNode(Node):
    """AST node that is printed as a fixed piece of source code.

    Two nodes compare equal (and hash equal) if they carry the same code, so duplicates collapse when
    nodes are collected in a set.
    """

    def __init__(self, code, symbols_read, symbols_defined, parent=None):
        """
        Args:
            code: source code of this node; stored with a leading newline
            symbols_read: iterable of symbols the code reads
            symbols_defined: iterable of symbols the code defines
            parent: parent node, passed on to the Node constructor
        """
        super(CustomCodeNode, self).__init__(parent=parent)
        self._code = "\n" + code
        self._symbols_read = set(symbols_read)
        self._symbols_defined = set(symbols_defined)
        # Header files needed by the code; read by get_headers
        self.headers = []

    def get_code(self, dialect, vector_instruction_set):
        """Return the stored code; both arguments are ignored by this base implementation."""
        return self._code

    @property
    def args(self):
        """This node has no children."""
        return []

    @property
    def symbols_defined(self):
        """Set of symbols defined by the code (the internal set, not a copy)."""
        return self._symbols_defined

    @property
    def undefined_symbols(self):
        """Symbols that are read but not defined by the code."""
        return self._symbols_read - self._symbols_defined

    def __eq__(self, other):
        # Was misspelled as ``__eq___`` and therefore never used by ``==`` or by sets/dicts.
        if not isinstance(other, CustomCodeNode):
            return NotImplemented
        return self._code == other._code

    def __hash__(self):
        return hash(self._code)

148

149
class PrintNode(CustomCodeNode):
    """Custom code node that writes the name and the value of a symbol to ``std::cout``."""

    # noinspection SpellCheckingInspection
    def __init__(self, symbol_to_print):
        """
        Args:
            symbol_to_print: symbol whose ``name`` is used both as label and as printed expression
        """
        name = symbol_to_print.name
        statement = '\nstd::cout << "{0}  =  " << {0} << std::endl; \n'.format(name)
        super(PrintNode, self).__init__(statement, symbols_read=[symbol_to_print], symbols_defined=set())
        # std::cout requires the iostream header
        self.headers.append("<iostream>")
155 156 157 158


# ------------------------------------------- Printer ------------------------------------------------------------------

159

Martin Bauer's avatar
Martin Bauer committed
160 161
# noinspection PyPep8Naming
class CBackend:
    """Printer that turns AST nodes into C-like source code.

    A node is dispatched to the method ``_print_<ClassName>``, where the class names along the method resolution
    order of the node's type are tried in turn. Sympy expressions inside nodes are printed by ``self.sympy_printer``.
    """

    def __init__(self, sympy_printer=None, signature_only=False, vector_instruction_set=None, dialect='c'):
        """
        Args:
            sympy_printer: printer for sympy expressions; if None, a VectorizedCustomSympyPrinter is created when
                           ``vector_instruction_set`` is given and a CustomSympyPrinter otherwise
            signature_only: if True, kernel functions are printed as declaration only, without body
            vector_instruction_set: instruction set mapping used to print vectorized stores, or None
            dialect: name of the dialect; 'cuda' adds ``__launch_bounds__`` to kernel declarations
        """
        if sympy_printer is None:
            if vector_instruction_set is not None:
                self.sympy_printer = VectorizedCustomSympyPrinter(vector_instruction_set)
            else:
                self.sympy_printer = CustomSympyPrinter()
        else:
            self.sympy_printer = sympy_printer

        self._vector_instruction_set = vector_instruction_set
        self._indent = "   "
        self._dialect = dialect
        self._signatureOnly = signature_only

    def __call__(self, node):
        """Print ``node`` and return the code as string.

        ``VectorType.instruction_set`` is set to this backend's instruction set while printing and is restored
        afterwards, also when printing raises an exception.
        """
        prev_is = VectorType.instruction_set
        VectorType.instruction_set = self._vector_instruction_set
        try:
            return str(self._print(node))
        finally:
            # Restore the class-level state even on errors, so a failed print does not leak into later ones
            VectorType.instruction_set = prev_is

    def _print(self, node):
        """Dispatch ``node`` to the matching ``_print_<ClassName>`` method; strings are returned unchanged.

        Raises:
            NotImplementedError: if no print method exists for any class in the MRO of the node's type
        """
        if isinstance(node, str):
            return node
        for cls in type(node).__mro__:
            method_name = "_print_" + cls.__name__
            if hasattr(self, method_name):
                return getattr(self, method_name)(node)
        raise NotImplementedError(self.__class__.__name__ + " does not support node of type " + node.__class__.__name__)

    def _print_Type(self, node):
        return str(node)

    def _print_KernelFunction(self, node):
        """Print the function declaration and, unless only the signature is requested, the function body."""
        function_arguments = ["%s %s" % (self._print(s.symbol.dtype), s.symbol.name) for s in node.get_parameters()]
        launch_bounds = ""
        if self._dialect == 'cuda':
            max_threads = node.indexing.max_threads_per_block()
            if max_threads:
                launch_bounds = "__launch_bounds__({}) ".format(max_threads)
        func_declaration = "FUNC_PREFIX %svoid %s(%s)" % (launch_bounds, node.function_name,
                                                          ", ".join(function_arguments))
        if self._signatureOnly:
            return func_declaration

        body = self._print(node.body)
        return func_declaration + "\n" + body

    def _print_Block(self, node):
        """Print all children, one per line, indented and enclosed in curly braces."""
        block_contents = "\n".join([self._print(child) for child in node.args])
        # splitlines(True) keeps the line endings, so joining with the indent indents every following line
        return "{\n%s\n}" % (self._indent + self._indent.join(block_contents.splitlines(True)))

    def _print_PragmaBlock(self, node):
        return "%s\n%s" % (node.pragma_line, self._print_Block(node))

    def _print_LoopOverCoordinate(self, node):
        """Print a ``for`` loop with an ``int`` counter, preceded by the node's prefix lines."""
        counter_symbol = node.loop_counter_name
        start = "int %s = %s" % (counter_symbol, self.sympy_printer.doprint(node.start))
        condition = "%s < %s" % (counter_symbol, self.sympy_printer.doprint(node.stop))
        update = "%s += %s" % (counter_symbol, self.sympy_printer.doprint(node.step),)
        loop_str = "for (%s; %s; %s)" % (start, condition, update)

        prefix = "\n".join(node.prefix_lines)
        if prefix:
            prefix += "\n"
        return "%s%s\n%s" % (prefix, loop_str, self._print(node.body))

    def _print_SympyAssignment(self, node):
        """Print a declaration, a plain assignment or, for vector typed casts on the left, a vector store."""
        if node.is_declaration:
            if node.is_const:
                prefix = 'const '
            else:
                prefix = ''
            # A ' const' inside the printed type is removed; constness is expressed by the prefix only
            data_type = prefix + self._print(node.lhs.dtype).replace(' const', '') + " "
            return "%s%s = %s;" % (data_type, self.sympy_printer.doprint(node.lhs),
                                   self.sympy_printer.doprint(node.rhs))
        else:
            lhs_type = get_type_of_expression(node.lhs)
            printed_mask = ""
            if type(lhs_type) is VectorType and isinstance(node.lhs, cast_func):
                arg, data_type, aligned, nontemporal, mask = node.lhs.args
                instr = 'storeU'
                if aligned:
                    instr = 'stream' if nontemporal else 'storeA'
                if mask != True:  # NOQA
                    instr = 'maskStore' if aligned else 'maskStoreU'
                    printed_mask = self.sympy_printer.doprint(mask)
                    if self._vector_instruction_set['dataTypePrefix']['double'] == '__mm256d':
                        printed_mask = "_mm256_castpd_si256({})".format(printed_mask)

                # A scalar right hand side is wrapped in a cast to the corresponding vector type
                rhs_type = get_type_of_expression(node.rhs)
                if type(rhs_type) is not VectorType:
                    rhs = cast_func(node.rhs, VectorType(rhs_type))
                else:
                    rhs = node.rhs

                return self._vector_instruction_set[instr].format("&" + self.sympy_printer.doprint(node.lhs.args[0]),
                                                                  self.sympy_printer.doprint(rhs),
                                                                  printed_mask) + ';'
            else:
                return "%s = %s;" % (self.sympy_printer.doprint(node.lhs), self.sympy_printer.doprint(node.rhs))

    def _print_TemporaryMemoryAllocation(self, node):
        """Print an ``aligned_alloc`` call; the returned pointer is shifted by ``node.offset(align)``."""
        align = 64
        np_dtype = node.symbol.dtype.base_type.numpy_dtype
        required_size = np_dtype.itemsize * node.size + align
        size = modulo_ceil(required_size, align)
        code = "{dtype} {name}=({dtype})aligned_alloc({align}, {size}) + {offset};"
        return code.format(dtype=node.symbol.dtype,
                           name=self.sympy_printer.doprint(node.symbol.name),
                           size=self.sympy_printer.doprint(size),
                           offset=int(node.offset(align)),
                           align=align)

    def _print_TemporaryMemoryFree(self, node):
        """Print the ``free`` call; the offset added at allocation is subtracted again."""
        # Has to be the same alignment as in _print_TemporaryMemoryAllocation
        align = 64
        return "free(%s - %d);" % (self.sympy_printer.doprint(node.symbol.name), node.offset(align))

    def _print_SkipIteration(self, _):
        return "continue;"

    def _print_CustomCodeNode(self, node):
        return node.get_code(self._dialect, self._vector_instruction_set)

    def _print_SourceCodeComment(self, node):
        return "/* " + node.text + " */"

    def _print_EmptyLine(self, node):
        return ""

    def _print_Conditional(self, node):
        """Print an ``if`` statement with optional ``else`` block.

        Raises:
            ValueError: if the condition has a vector type
        """
        cond_type = get_type_of_expression(node.condition_expr)
        if isinstance(cond_type, VectorType):
            raise ValueError("Problem with Conditional inside vectorized loop - use vec_any or vec_all")
        condition_expr = self.sympy_printer.doprint(node.condition_expr)
        true_block = self._print_Block(node.true_block)
        result = "if (%s)\n%s " % (condition_expr, true_block)
        if node.false_block:
            false_block = self._print_Block(node.false_block)
            result += "else " + false_block
        return result

305 306 307 308

# ------------------------------------------ Helper function & classes -------------------------------------------------


Martin Bauer's avatar
Martin Bauer committed
309
# noinspection PyPep8Naming
class CustomSympyPrinter(CCodePrinter):
    """C code printer for sympy expressions with special handling of the custom functions used in this module."""

    # Binary functions that are printed as infix operators
    _infix_functions = {
        bitwise_xor: '^',
        bit_shift_right: '>>',
        bit_shift_left: '<<',
        bitwise_or: '|',
        bitwise_and: '&',
    }

    def __init__(self):
        super(CustomSympyPrinter, self).__init__()
        self._float_type = create_type("float32")
        # Min and Max are printed by the C89 implementations assigned at the end of this class
        self.known_functions.pop('Min', None)
        self.known_functions.pop('Max', None)

    def _print_Pow(self, expr):
        """Don't use std::pow function, for small integer exponents, write as multiplication"""
        if not expr.free_symbols:
            return self._typed_number(expr.evalf(), get_type_of_expression(expr))

        small_integer_exponent = expr.exp.is_integer and expr.exp.is_number
        if small_integer_exponent and 0 < expr.exp < 8:
            return "(" + self._print(sp.Mul(*[expr.base] * expr.exp, evaluate=False)) + ")"
        elif small_integer_exponent and - 8 < expr.exp < 0:
            # Parenthesized like the positive case: callers decide about parentheses based on the precedence of a
            # power, but the printed text is a division.
            return "(1 / ({}))".format(self._print(sp.Mul(*[expr.base] * (-expr.exp), evaluate=False)))
        else:
            return super(CustomSympyPrinter, self)._print_Pow(expr)

    def _print_Rational(self, expr):
        """Evaluate all rationals i.e. print 0.25 instead of 1.0/4.0"""
        res = str(expr.evalf().num)
        return res

    def _print_Equality(self, expr):
        """Equality operator is not printable in default printer"""
        return '((' + self._print(expr.lhs) + ") == (" + self._print(expr.rhs) + '))'

    def _print_Piecewise(self, expr):
        """Print piecewise in one line (remove newlines)"""
        result = super(CustomSympyPrinter, self)._print_Piecewise(expr)
        return result.replace("\n", "")

    def _print_Function(self, expr):
        """Print the custom functions of this package; everything else is left to the base class.

        Expressions providing a ``to_c`` method print themselves. The order of the type checks is kept as is,
        since some of the function classes may derive from each other.
        """
        if hasattr(expr, 'to_c'):
            return expr.to_c(self._print)
        if isinstance(expr, reinterpret_cast_func):
            arg, data_type = expr.args
            return "*((%s)(& %s))" % (PointerType(data_type, restrict=False), self._print(arg))
        elif isinstance(expr, address_of):
            assert len(expr.args) == 1, "address_of must only have one argument"
            return "&(%s)" % self._print(expr.args[0])
        elif isinstance(expr, cast_func):
            arg, data_type = expr.args
            if isinstance(arg, sp.Number) and arg.is_finite:
                # Finite numbers get a typed literal instead of a cast
                return self._typed_number(arg, data_type)
            else:
                return "((%s)(%s))" % (data_type, self._print(arg))
        elif isinstance(expr, fast_division):
            return "({})".format(self._print(expr.args[0] / expr.args[1]))
        elif isinstance(expr, fast_sqrt):
            return "({})".format(self._print(sp.sqrt(expr.args[0])))
        elif isinstance(expr, (vec_any, vec_all)):
            return self._print(expr.args[0])
        elif isinstance(expr, fast_inv_sqrt):
            return "({})".format(self._print(1 / sp.sqrt(expr.args[0])))
        elif expr.func in self._infix_functions:
            return "(%s %s %s)" % (self._print(expr.args[0]), self._infix_functions[expr.func],
                                   self._print(expr.args[1]))
        elif expr.func == int_power_of_2:
            return "(1 << (%s))" % (self._print(expr.args[0]))
        elif expr.func == int_div:
            return "((%s) / (%s))" % (self._print(expr.args[0]), self._print(expr.args[1]))
        else:
            return super(CustomSympyPrinter, self)._print_Function(expr)

    def _typed_number(self, number, dtype):
        """Print ``number`` as literal of type ``dtype``.

        float32 literals get an ``f`` suffix, float32 and float64 literals always contain a decimal point,
        all other types are printed unchanged.
        """
        res = self._print(number)
        if dtype.numpy_dtype == np.float32:
            return res + '.0f' if '.' not in res else res + 'f'
        elif dtype.numpy_dtype == np.float64:
            return res + '.0' if '.' not in res else res
        else:
            return res

    def _reduction_lambda(self, expr, accumulator, neutral_element, update_operator):
        """Print a Sum/Product like expression as immediately invoked lambda containing a loop.

        Only the first entry of ``expr.limits`` is used; the upper limit is inclusive.

        Args:
            expr: expression with ``limits``, ``function`` and ``args``
            accumulator: name of the accumulator variable in the generated code
            neutral_element: initial value of the accumulator, as string
            update_operator: compound assignment operator used to update the accumulator
        """
        template = """[&]() {{
    {dtype} {accumulator} = ({dtype}) {neutral_element};
    for ( {iterator_dtype} {var} = {start}; {condition}; {var} += {increment} ) {{
        {accumulator} {update_operator} {expr};
    }}
    return {accumulator};
}}()"""
        var = expr.limits[0][0]
        start = expr.limits[0][1]
        end = expr.limits[0][2]
        return template.format(
            dtype=get_type_of_expression(expr.args[0]),
            accumulator=accumulator,
            neutral_element=neutral_element,
            update_operator=update_operator,
            iterator_dtype='int',
            var=self._print(var),
            start=self._print(start),
            expr=self._print(expr.function),
            increment=str(1),
            condition=self._print(var) + ' <= ' + self._print(end)  # if start < end else '>='
        )

    def _print_Sum(self, expr):
        """Print a sum as lambda with a loop that adds up the summand."""
        return self._reduction_lambda(expr, accumulator='sum', neutral_element='0', update_operator='+=')

    def _print_Product(self, expr):
        """Print a product as lambda with a loop that multiplies the factors."""
        return self._reduction_lambda(expr, accumulator='product', neutral_element='1', update_operator='*=')

    def _print_ConditionalFieldAccess(self, node):
        """Print as piecewise: the out-of-bounds value if the out-of-bounds condition holds, else the access."""
        return self._print(sp.Piecewise((node.outofbounds_value, node.outofbounds_condition), (node.access, True)))

    _print_Max = C89CodePrinter._print_Max
    _print_Min = C89CodePrinter._print_Min

446

Martin Bauer's avatar
Martin Bauer committed
447
# noinspection PyPep8Naming
class VectorizedCustomSympyPrinter(CustomSympyPrinter):
    """Sympy printer that prints vector typed expressions with the instructions of an instruction set.

    Expressions that do not have a vector type are printed by the scalar CustomSympyPrinter.
    """
    # Printed summand together with its sign (+1 or -1), used by _print_Add
    SummandInfo = namedtuple("SummandInfo", ['sign', 'term'])

    def __init__(self, instruction_set):
        """
        Args:
            instruction_set: mapping from operation name (e.g. '+', 'loadA', 'makeVecConst') to a format string
        """
        super(VectorizedCustomSympyPrinter, self).__init__()
        self.instruction_set = instruction_set

    def _scalarFallback(self, func_name, expr, *args, **kwargs):
        """Print ``expr`` with method ``func_name`` of the scalar printer if it does not have a vector type.

        Returns:
            the printed expression, or None if the expression has a vector type
        """
        expr_type = get_type_of_expression(expr)
        if type(expr_type) is not VectorType:
            return getattr(super(VectorizedCustomSympyPrinter, self), func_name)(expr, *args, **kwargs)
        else:
            assert self.instruction_set['width'] == expr_type.width
            return None

    def _print_Function(self, expr):
        """Print vector loads, casts to vector types and the vectorized variants of the custom functions."""
        if isinstance(expr, vector_memory_access):
            arg, data_type, aligned, _, mask = expr.args
            instruction = self.instruction_set['loadA'] if aligned else self.instruction_set['loadU']
            return instruction.format("& " + self._print(arg))
        elif isinstance(expr, cast_func):
            arg, data_type = expr.args
            if type(data_type) is VectorType:
                if isinstance(arg, sp.Tuple):
                    # One value per vector entry
                    is_boolean = get_type_of_expression(arg[0]) == create_type("bool")
                    printed_args = [self._print(a) for a in arg]
                    instruction = 'makeVecBool' if is_boolean else 'makeVec'
                    return self.instruction_set[instruction].format(*printed_args)
                else:
                    # Single value for all vector entries
                    is_boolean = get_type_of_expression(arg) == create_type("bool")
                    instruction = 'makeVecConstBool' if is_boolean else 'makeVecConst'
                    return self.instruction_set[instruction].format(self._print(arg))
        elif expr.func == fast_division:
            result = self._scalarFallback('_print_Function', expr)
            if not result:
                result = self.instruction_set['/'].format(self._print(expr.args[0]), self._print(expr.args[1]))
            return result
        elif expr.func == fast_sqrt:
            return "({})".format(self._print(sp.sqrt(expr.args[0])))
        elif expr.func == fast_inv_sqrt:
            result = self._scalarFallback('_print_Function', expr)
            if result:
                # Scalar case is already printed, no need to print it a second time via the base class below
                return result
            if self.instruction_set['rsqrt']:
                return self.instruction_set['rsqrt'].format(self._print(expr.args[0]))
            else:
                return "({})".format(self._print(1 / sp.sqrt(expr.args[0])))
        elif isinstance(expr, vec_any):
            expr_type = get_type_of_expression(expr.args[0])
            if type(expr_type) is not VectorType:
                return self._print(expr.args[0])
            else:
                return self.instruction_set['any'].format(self._print(expr.args[0]))
        elif isinstance(expr, vec_all):
            expr_type = get_type_of_expression(expr.args[0])
            if type(expr_type) is not VectorType:
                return self._print(expr.args[0])
            else:
                return self.instruction_set['all'].format(self._print(expr.args[0]))

        return super(VectorizedCustomSympyPrinter, self)._print_Function(expr)

    def _print_And(self, expr):
        """Chain all arguments with the '&' instruction."""
        result = self._scalarFallback('_print_And', expr)
        if result:
            return result

        arg_strings = [self._print(a) for a in expr.args]
        assert len(arg_strings) > 0
        result = arg_strings[0]
        for item in arg_strings[1:]:
            result = self.instruction_set['&'].format(result, item)
        return result

    def _print_Or(self, expr):
        """Chain all arguments with the '|' instruction."""
        result = self._scalarFallback('_print_Or', expr)
        if result:
            return result

        arg_strings = [self._print(a) for a in expr.args]
        assert len(arg_strings) > 0
        result = arg_strings[0]
        for item in arg_strings[1:]:
            result = self.instruction_set['|'].format(result, item)
        return result

    def _print_Add(self, expr, order=None):
        """Chain the summands with the '+' and '-' instructions, depending on the sign of each summand."""
        result = self._scalarFallback('_print_Add', expr)
        if result:
            return result

        summands = []
        for term in expr.args:
            if term.func == sp.Mul:
                sign, t = self._print_Mul(term, inside_add=True)
            else:
                t = self._print(term)
                sign = 1
            summands.append(self.SummandInfo(sign, t))
        # Use positive terms first
        summands.sort(key=lambda e: e.sign, reverse=True)
        # if no positive term exists, prepend a zero
        if summands[0].sign == -1:
            summands.insert(0, self.SummandInfo(1, "0"))

        assert len(summands) >= 2
        processed = summands[0].term
        for summand in summands[1:]:
            func = self.instruction_set['-'] if summand.sign == -1 else self.instruction_set['+']
            processed = func.format(processed, summand.term)
        return processed

    def _print_Pow(self, expr):
        """Print powers with integer exponents between -8 and 8 and with exponents 0.5 and -0.5.

        Raises:
            ValueError: for all other exponents of vector typed expressions
        """
        result = self._scalarFallback('_print_Pow', expr)
        if result:
            return result

        one = self.instruction_set['makeVecConst'].format(1.0)

        if expr.exp.is_integer and expr.exp.is_number and 0 < expr.exp < 8:
            return "(" + self._print(sp.Mul(*[expr.base] * expr.exp, evaluate=False)) + ")"
        elif expr.exp == -1:
            return self.instruction_set['/'].format(one, self._print(expr.base))
        elif expr.exp == 0.5:
            return self.instruction_set['sqrt'].format(self._print(expr.base))
        elif expr.exp == -0.5:
            root = self.instruction_set['sqrt'].format(self._print(expr.base))
            return self.instruction_set['/'].format(one, root)
        elif expr.exp.is_integer and expr.exp.is_number and - 8 < expr.exp < 0:
            return self.instruction_set['/'].format(one,
                                                    self._print(sp.Mul(*[expr.base] * (-expr.exp), evaluate=False)))
        else:
            raise ValueError("Generic exponential not supported: " + str(expr))

    def _print_Mul(self, expr, inside_add=False):
        """Print a product with the '*' instruction, factors with negative exponent go to the denominator.

        Args:
            expr: the product to print
            inside_add: if True, a tuple ``(sign, printed_term)`` is returned where ``printed_term`` does not
                        contain the sign of the coefficient; otherwise only the printed product is returned
        """
        # noinspection PyProtectedMember
        from sympy.core.mul import _keep_coeff

        result = self._scalarFallback('_print_Mul', expr)
        if result:
            # The scalar text already contains its sign. _print_Add unpacks a (sign, term) pair, so the
            # pair has to be returned here as well instead of the bare string.
            return (1, result) if inside_add else result

        c, e = expr.as_coeff_Mul()
        if c < 0:
            expr = _keep_coeff(-c, e)
            sign = -1
        else:
            sign = 1

        a = []  # items in the numerator
        b = []  # items that are in the denominator (if any)

        # Gather args for numerator/denominator
        for item in expr.as_ordered_factors():
            if item.is_commutative and item.is_Pow and item.exp.is_Rational and item.exp.is_negative:
                if item.exp != -1:
                    b.append(sp.Pow(item.base, -item.exp, evaluate=False))
                else:
                    b.append(sp.Pow(item.base, -item.exp))
            else:
                a.append(item)

        a = a or [S.One]

        a_str = [self._print(x) for x in a]
        b_str = [self._print(x) for x in b]

        result = a_str[0]
        for item in a_str[1:]:
            result = self.instruction_set['*'].format(result, item)

        if len(b) > 0:
            denominator_str = b_str[0]
            for item in b_str[1:]:
                denominator_str = self.instruction_set['*'].format(denominator_str, item)
            result = self.instruction_set['/'].format(result, denominator_str)

        if inside_add:
            return sign, result
        else:
            if sign < 0:
                return self.instruction_set['*'].format(self._print(S.NegativeOne), result)
            else:
                return result

    def _print_Relational(self, expr):
        """Print with the instruction registered under the relational operator of the expression."""
        result = self._scalarFallback('_print_Relational', expr)
        if result:
            return result
        return self.instruction_set[expr.rel_op].format(self._print(expr.lhs), self._print(expr.rhs))

    def _print_Equality(self, expr):
        result = self._scalarFallback('_print_Equality', expr)
        if result:
            return result
        return self.instruction_set['=='].format(self._print(expr.lhs), self._print(expr.rhs))

    def _print_Piecewise(self, expr):
        """Print a piecewise expression as nested blend instructions, starting from the default case.

        Branches whose condition is a cast of a scalar boolean are printed as ternary expression instead, unless
        KERNCRAFT_NO_TERNARY_MODE is set; then a warning is printed and the branch is skipped.

        Raises:
            ValueError: if the last branch is not an ``(expr, True)`` default case
        """
        result = self._scalarFallback('_print_Piecewise', expr)
        if result:
            return result

        if expr.args[-1].cond.args[0] is not sp.sympify(True):
            # We need the last conditional to be a True, otherwise the resulting
            # function may not return a result.
            raise ValueError("All Piecewise expressions must contain an "
                             "(expr, True) statement to be used as a default "
                             "condition. Without one, the generated "
                             "expression may not evaluate to anything under "
                             "some condition.")

        result = self._print(expr.args[-1][0])
        # Walk backwards so that the first matching branch ends up outermost
        for true_expr, condition in reversed(expr.args[:-1]):
            if isinstance(condition, cast_func) and get_type_of_expression(condition.args[0]) == create_type("bool"):
                if not KERNCRAFT_NO_TERNARY_MODE:
                    result = "(({}) ? ({}) : ({}))".format(self._print(condition.args[0]), self._print(true_expr),
                                                           result)
                else:
                    print("Warning - skipping ternary op")
            else:
                # noinspection SpellCheckingInspection
                result = self.instruction_set['blendv'].format(result, self._print(true_expr), self._print(condition))
        return result