cbackend.py 29.3 KB
Newer Older
1
import re
Martin Bauer's avatar
Martin Bauer committed
2
from collections import namedtuple
3
from typing import Set
4

5
import numpy as np
6
7
import sympy as sp
from sympy.core import S
8
9
from sympy.logic.boolalg import BooleanFalse, BooleanTrue

10
from pystencils.astnodes import KernelFunction, Node
11
from pystencils.cpu.vectorization import vec_all, vec_any
Martin Bauer's avatar
Martin Bauer committed
12
from pystencils.data_types import (
Stephan Seitz's avatar
Stephan Seitz committed
13
14
    PointerType, VectorType, address_of, cast_func, create_type, get_type_of_expression,
    reinterpret_cast_func, vector_memory_access)
Martin Bauer's avatar
Martin Bauer committed
15
16
from pystencils.fast_approximation import fast_division, fast_inv_sqrt, fast_sqrt
from pystencils.integer_functions import (
Stephan Seitz's avatar
Stephan Seitz committed
17
18
    bit_shift_left, bit_shift_right, bitwise_and, bitwise_or, bitwise_xor,
    int_div, int_power_of_2, modulo_ceil)
19

Martin Bauer's avatar
Martin Bauer committed
20
21
try:
    from sympy.printing.ccode import C99CodePrinter as CCodePrinter
Martin Bauer's avatar
Martin Bauer committed
22
except ImportError:
23
24
25
    try:
        from sympy.printing.ccode import CCodePrinter  # for sympy versions < 1.1
    except ImportError:
26
        from sympy.printing.c import C11CodePrinter as CCodePrinter  # for sympy versions > 1.6
Martin Bauer's avatar
Martin Bauer committed
27

28
__all__ = ['generate_c', 'CustomCodeNode', 'PrintNode', 'get_headers', 'CustomSympyPrinter']
29

30
31
32

# Pattern used by get_headers() to validate header names: the string must start with '<' or '"'
# and end with '"' or '>' (e.g. <cmath> or "myheader.h"). Opening/closing delimiters are not
# required to be of the same kind.
HEADER_REGEX = re.compile(r'^[<"].*[">]$')

33
34
# Module-level switch, off by default. It is not read in the part of the file visible here;
# presumably it changes how ternary expressions are emitted for kerncraft - TODO confirm at the usage site.
KERNCRAFT_NO_TERNARY_MODE = False

Martin Bauer's avatar
Fixes    
Martin Bauer committed
35

36
37
38
39
40
def generate_c(ast_node: Node,
               signature_only: bool = False,
               dialect='c',
               custom_backend=None,
               with_globals=True) -> str:
    """Prints an abstract syntax tree node as C, CUDA or OpenCL code.

    This function does not need to distinguish for most AST nodes between C, C++ or CUDA code, it just prints 'C-like'
    code as encoded in the abstract syntax tree (AST). The AST is built differently for C or CUDA by calling different
    create_kernel functions.

    As a side effect, the symbols defined by all global declarations found in the AST are stored in
    ``ast_node.global_variables`` (the attribute is created if it does not exist yet).

    Args:
        ast_node: root node of the AST to print
        signature_only: forwarded to the built-in backends; if true only the function signature is printed
        dialect: 'c', 'cuda' or 'opencl'; ignored if ``custom_backend`` is given
        custom_backend: callable taking an AST node and returning its code; used instead of the built-in backends
        with_globals: if true (and ``signature_only`` is false and ``ast_node`` is a ``KernelFunction``),
                      the printed global declarations are prepended to the generated code

    Returns:
        C-like code for the ast node and its descendants

    Raises:
        ValueError: if no custom backend is given and ``dialect`` is not one of the known dialects
    """
    global_declarations = get_global_declarations(ast_node)
    for d in global_declarations:
        if hasattr(ast_node, "global_variables"):
            ast_node.global_variables.update(d.symbols_defined)
        else:
            # Copy instead of aliasing: 'symbols_defined' may hand out the declaration's own set, and the
            # update() calls of later iterations must not add foreign symbols to that declaration.
            ast_node.global_variables = set(d.symbols_defined)

    if custom_backend:
        printer = custom_backend
    elif dialect == 'c':
        # Best effort: nodes without a (working) 'instruction_set' attribute are printed without vectorization.
        try:
            instruction_set = ast_node.instruction_set
        except Exception:
            instruction_set = None
        printer = CBackend(signature_only=signature_only,
                           vector_instruction_set=instruction_set)
    elif dialect == 'cuda':
        from pystencils.backends.cuda_backend import CudaBackend
        printer = CudaBackend(signature_only=signature_only)
    elif dialect == 'opencl':
        from pystencils.backends.opencl_backend import OpenClBackend
        printer = OpenClBackend(signature_only=signature_only)
    else:
        raise ValueError("Unknown dialect: " + str(dialect))

    code = printer(ast_node)
    if not signature_only and isinstance(ast_node, KernelFunction):
        if with_globals and global_declarations:
            code = "\n" + code
            # Each declaration is put in front of what has been assembled so far.
            for declaration in global_declarations:
                code = printer(declaration) + "\n" + code

    return code


def get_global_declarations(ast):
    """Collect the global declarations required by an AST.

    The tree is walked depth-first through the ``args`` attribute of each node (nodes without ``args`` are
    leaves). Every node may announce declarations through a ``required_global_declarations`` attribute.

    Returns:
        list of the distinct declarations, ordered by their string representation
    """

    def iter_declarations(node):
        # declarations of the node itself come first, then those of its children in order
        if hasattr(node, "required_global_declarations"):
            yield from node.required_global_declarations
        if hasattr(node, "args"):
            for child in node.args:
                yield from iter_declarations(child)

    unique_declarations = set(iter_declarations(ast))
    return sorted(unique_declarations, key=str)
102
103


Martin Bauer's avatar
Martin Bauer committed
104
105
def get_headers(ast_node: Node) -> Set[str]:
    """Return the header files necessary to compile the printed C-like code.

    Headers are gathered from the instruction set of a vectorized ``KernelFunction``, from the node's own
    ``headers`` attribute (if present), recursively from all arguments that are sympy expressions or AST
    nodes, and from the global declarations required by the tree.

    Note that, despite the annotation, the result is a *sorted list* of the distinct header strings.
    Every header has to look like ``"..."`` or ``<...>``; this is checked with an assertion.
    """
    required = set()

    if isinstance(ast_node, KernelFunction) and ast_node.instruction_set:
        required.update(ast_node.instruction_set['headers'])

    required.update(getattr(ast_node, 'headers', ()))

    printable_children = [child for child in ast_node.args if isinstance(child, (sp.Expr, Node))]
    for child in printable_children:
        required.update(get_headers(child))

    declaration_nodes = [decl for decl in get_global_declarations(ast_node) if isinstance(decl, Node)]
    for decl in declaration_nodes:
        required.update(get_headers(decl))

    for header in required:
        assert HEADER_REGEX.match(header), f'header /{header}/ does not follow the pattern /"..."/ or /<...>/'

    return sorted(required)
125
126


127
128
129
# --------------------------------------- Backend Specific Nodes -------------------------------------------------------


130
class CustomCodeNode(Node):
    """AST node that carries a verbatim piece of target code.

    Two nodes of the same class compare equal (and hash equal) when their code is identical.
    """

    def __init__(self, code, symbols_read, symbols_defined, parent=None):
        """
        Args:
            code: code string that is printed as-is; a newline is put in front of it
            symbols_read: iterable of symbols the code reads
            symbols_defined: iterable of symbols the code defines
            parent: parent node, passed on to the ``Node`` constructor
        """
        super(CustomCodeNode, self).__init__(parent=parent)
        self._code = "\n" + code
        self._symbols_read = set(symbols_read)
        self._symbols_defined = set(symbols_defined)
        # header files needed by the code; subclasses and users append to this list
        self.headers = []

    def get_code(self, dialect, vector_instruction_set):
        """Return the stored code; both arguments are ignored here (hook for subclasses)."""
        return self._code

    @property
    def args(self):
        """A custom code node has no child nodes."""
        return []

    @property
    def symbols_defined(self):
        """Set of symbols defined by the code (the internal set itself, not a copy)."""
        return self._symbols_defined

    @property
    def undefined_symbols(self):
        """Symbols that are read but not defined by the code."""
        return self._symbols_read - self._symbols_defined

    def __eq__(self, other):
        # The method used to be spelled '__eq___' (three trailing underscores) and was therefore never
        # picked up as equality operator, although __hash__ was already based on the code.
        return type(self) == type(other) and self._code == other._code

    def __hash__(self):
        return hash(self._code)

159

160
class PrintNode(CustomCodeNode):
    """Custom code node that writes the name and the value of a symbol to ``std::cout``."""

    # noinspection SpellCheckingInspection
    def __init__(self, symbol_to_print):
        """
        Args:
            symbol_to_print: symbol (object with a ``name`` attribute) whose value is printed
        """
        symbol_name = symbol_to_print.name
        statement = f'\nstd::cout << "{symbol_name}  =  " << {symbol_name} << std::endl; \n'
        super().__init__(statement, symbols_read=[symbol_to_print], symbols_defined=set())
        # std::cout / std::endl live in <iostream>
        self.headers.append("<iostream>")
166
167
168
169


# ------------------------------------------- Printer ------------------------------------------------------------------

170

Martin Bauer's avatar
Martin Bauer committed
171
172
# noinspection PyPep8Naming
class CBackend:
    """Printer that turns AST nodes into C-like source code.

    An instance is called with a node and returns the code as string. Printing is dispatched on the class
    name of the node: a node of class ``Foo`` is handled by ``_print_Foo``; if that method does not exist
    the base classes of the node are tried in MRO order. Sympy expressions contained in the nodes are
    handed to ``self.sympy_printer``.
    """

    def __init__(self, sympy_printer=None, signature_only=False, vector_instruction_set=None, dialect='c'):
        """
        Args:
            sympy_printer: printer for sympy expressions; if None, a ``VectorizedCustomSympyPrinter`` is
                           created when an instruction set is given, otherwise a ``CustomSympyPrinter``
            signature_only: if true, kernel functions are printed without body
            vector_instruction_set: instruction set dictionary used for vectorized stores, or None
            dialect: dialect string, handed to custom code nodes and checked for 'cuda' in kernel signatures
        """
        if sympy_printer is None:
            if vector_instruction_set is not None:
                self.sympy_printer = VectorizedCustomSympyPrinter(vector_instruction_set)
            else:
                self.sympy_printer = CustomSympyPrinter()
        else:
            self.sympy_printer = sympy_printer

        self._vector_instruction_set = vector_instruction_set
        self._indent = "   "
        self._dialect = dialect
        self._signatureOnly = signature_only

    def __call__(self, node):
        """Print ``node`` and return the code as string.

        While printing, ``VectorType.instruction_set`` is set to the instruction set of this backend.
        The previous value is restored afterwards, also when printing fails with an exception.
        """
        prev_is = VectorType.instruction_set
        VectorType.instruction_set = self._vector_instruction_set
        try:
            result = str(self._print(node))
        finally:
            # do not leak the temporarily changed class attribute if printing raises
            VectorType.instruction_set = prev_is
        return result

    def _print(self, node):
        """Dispatch to the ``_print_<ClassName>`` method of the node's class or of one of its bases.

        Strings are returned unchanged.

        Raises:
            NotImplementedError: if no print method exists for any class in the MRO of the node
        """
        if isinstance(node, str):
            return node
        for cls in type(node).__mro__:
            method_name = "_print_" + cls.__name__
            if hasattr(self, method_name):
                return getattr(self, method_name)(node)
        raise NotImplementedError(self.__class__.__name__ + " does not support node of type " + node.__class__.__name__)

    def _print_Type(self, node):
        """Types are printed through their string representation."""
        return str(node)

    def _print_KernelFunction(self, node):
        """Print the signature of a kernel and, unless only the signature is requested, its body."""
        function_arguments = [f"{self._print(s.symbol.dtype)} {s.symbol.name}" for s in node.get_parameters()]
        launch_bounds = ""
        if self._dialect == 'cuda':
            max_threads = node.indexing.max_threads_per_block()
            if max_threads:
                launch_bounds = f"__launch_bounds__({max_threads}) "
        func_declaration = "FUNC_PREFIX %svoid %s(%s)" % (launch_bounds, node.function_name,
                                                          ", ".join(function_arguments))
        if self._signatureOnly:
            return func_declaration

        body = self._print(node.body)
        return func_declaration + "\n" + body

    def _print_Block(self, node):
        """Print all children of the block, indented by one level and wrapped in braces."""
        block_contents = "\n".join([self._print(child) for child in node.args])
        # splitlines(True) keeps the line endings, so joining with the indent indents every following line
        return "{\n%s\n}" % (self._indent + self._indent.join(block_contents.splitlines(True)))

    def _print_PragmaBlock(self, node):
        """Print a block preceded by its pragma line."""
        return f"{node.pragma_line}\n{self._print_Block(node)}"

    def _print_LoopOverCoordinate(self, node):
        """Print a ``for`` loop with an ``int`` counter, preceded by the prefix lines of the node."""
        counter_symbol = node.loop_counter_name
        start = f"int {counter_symbol} = {self.sympy_printer.doprint(node.start)}"
        condition = f"{counter_symbol} < {self.sympy_printer.doprint(node.stop)}"
        update = f"{counter_symbol} += {self.sympy_printer.doprint(node.step)}"
        loop_str = f"for ({start}; {condition}; {update})"

        prefix = "\n".join(node.prefix_lines)
        if prefix:
            prefix += "\n"
        return f"{prefix}{loop_str}\n{self._print(node.body)}"

    def _print_SympyAssignment(self, node):
        """Print an assignment.

        Three cases are distinguished: a declaration (type or ``auto`` in front of the left hand side),
        a vectorized store (left hand side is a ``cast_func`` of vector type; printed with the store
        instruction of the instruction set) and a plain assignment.
        """
        if node.is_declaration:
            if node.use_auto:
                data_type = 'auto '
            else:
                if node.is_const:
                    prefix = 'const '
                else:
                    prefix = ''
                # a ' const' inside the printed type is dropped; constness is expressed by the prefix only
                data_type = prefix + self._print(node.lhs.dtype).replace(' const', '') + " "

            return "%s%s = %s;" % (data_type,
                                   self.sympy_printer.doprint(node.lhs),
                                   self.sympy_printer.doprint(node.rhs))
        else:
            lhs_type = get_type_of_expression(node.lhs)
            printed_mask = ""
            if type(lhs_type) is VectorType and isinstance(node.lhs, cast_func):
                arg, data_type, aligned, nontemporal, mask = node.lhs.args
                instr = 'storeU'
                if aligned:
                    instr = 'stream' if nontemporal else 'storeA'
                if mask != True:  # NOQA
                    instr = 'maskStore' if aligned else 'maskStoreU'
                    printed_mask = self.sympy_printer.doprint(mask)
                    # NOTE(review): the literal below is spelled '__mm256d' (double 'm'), whereas the AVX type
                    # is '__m256d' - confirm against the instruction set table that this branch can be taken.
                    if self._vector_instruction_set['dataTypePrefix']['double'] == '__mm256d':
                        printed_mask = f"_mm256_castpd_si256({printed_mask})"

                # a scalar right hand side is wrapped in a cast to the corresponding vector type
                rhs_type = get_type_of_expression(node.rhs)
                if type(rhs_type) is not VectorType:
                    rhs = cast_func(node.rhs, VectorType(rhs_type))
                else:
                    rhs = node.rhs

                return self._vector_instruction_set[instr].format("&" + self.sympy_printer.doprint(node.lhs.args[0]),
                                                                  self.sympy_printer.doprint(rhs),
                                                                  printed_mask) + ';'
            else:
                return f"{self.sympy_printer.doprint(node.lhs)} = {self.sympy_printer.doprint(node.rhs)};"

    def _print_TemporaryMemoryAllocation(self, node):
        """Print an ``aligned_alloc`` call; the resulting pointer is shifted by ``node.offset(align)``."""
        align = 64
        np_dtype = node.symbol.dtype.base_type.numpy_dtype
        # one extra 'align' bytes on top of the payload, then rounded up to a multiple of 'align'
        required_size = np_dtype.itemsize * node.size + align
        size = modulo_ceil(required_size, align)
        code = "{dtype} {name}=({dtype})aligned_alloc({align}, {size}) + {offset};"
        return code.format(dtype=node.symbol.dtype,
                           name=self.sympy_printer.doprint(node.symbol.name),
                           size=self.sympy_printer.doprint(size),
                           offset=int(node.offset(align)),
                           align=align)

    def _print_TemporaryMemoryFree(self, node):
        """Print the ``free`` call matching ``_print_TemporaryMemoryAllocation`` (pointer shift is undone)."""
        align = 64
        return "free(%s - %d);" % (self.sympy_printer.doprint(node.symbol.name), node.offset(align))

    def _print_SkipIteration(self, _):
        return "continue;"

    def _print_CustomCodeNode(self, node):
        """Custom code nodes provide their code themselves."""
        return node.get_code(self._dialect, self._vector_instruction_set)

    def _print_SourceCodeComment(self, node):
        return f"/* {node.text} */"

    def _print_EmptyLine(self, node):
        return ""

    def _print_Conditional(self, node):
        """Print an if / else statement.

        If the condition is a constant ``true`` or ``false`` only the block that is taken is printed.

        Raises:
            ValueError: if the condition has a vector type
        """
        if type(node.condition_expr) is BooleanTrue:
            return self._print_Block(node.true_block)
        elif type(node.condition_expr) is BooleanFalse:
            if node.false_block is None:
                # constant-false condition without else branch: nothing to print
                # (printing the missing block would fail on 'None.args')
                return ""
            return self._print_Block(node.false_block)
        cond_type = get_type_of_expression(node.condition_expr)
        if isinstance(cond_type, VectorType):
            raise ValueError("Problem with Conditional inside vectorized loop - use vec_any or vec_all")
        condition_expr = self.sympy_printer.doprint(node.condition_expr)
        true_block = self._print_Block(node.true_block)
        result = f"if ({condition_expr})\n{true_block} "
        if node.false_block:
            false_block = self._print_Block(node.false_block)
            result += f"else {false_block}"
        return result

325
326
327
328

# ------------------------------------------ Helper function & classes -------------------------------------------------


Martin Bauer's avatar
Martin Bauer committed
329
# noinspection PyPep8Naming
330
class CustomSympyPrinter(CCodePrinter):
    """Sympy code printer producing C-like expressions, extended by the custom functions of this package
    (casts, bit operations, integer division, fast approximations, ...)."""

    def __init__(self):
        super(CustomSympyPrinter, self).__init__()
        self._float_type = create_type("float32")

    def _print_Pow(self, expr):
        """Don't use std::pow function, for small integer exponents, write as multiplication"""
        if not expr.free_symbols:
            # constant expression: evaluate it and print it as typed number literal
            return self._typed_number(expr.evalf(), get_type_of_expression(expr))

        if expr.exp.is_integer and expr.exp.is_number and 0 < expr.exp < 8:
            return f"({self._print(sp.Mul(*[expr.base] * expr.exp, evaluate=False))})"
        elif expr.exp.is_integer and expr.exp.is_number and - 8 < expr.exp < 0:
            return f"1 / ({self._print(sp.Mul(*([expr.base] * -expr.exp), evaluate=False))})"
        else:
            return super(CustomSympyPrinter, self)._print_Pow(expr)

    def _print_Rational(self, expr):
        """Evaluate all rationals i.e. print 0.25 instead of 1.0/4.0"""
        res = str(expr.evalf().num)
        return res

    def _print_Equality(self, expr):
        """Equality operator is not printable in default printer"""
        return '((' + self._print(expr.lhs) + ") == (" + self._print(expr.rhs) + '))'

    def _print_Piecewise(self, expr):
        """Print piecewise in one line (remove newlines)"""
        result = super(CustomSympyPrinter, self)._print_Piecewise(expr)
        return result.replace("\n", "")

    def _print_Abs(self, expr):
        """Use ``abs`` for integer arguments and ``fabs`` otherwise."""
        if expr.args[0].is_integer:
            return f'abs({self._print(expr.args[0])})'
        else:
            return f'fabs({self._print(expr.args[0])})'

    def _print_Type(self, node):
        return str(node)

    def _print_Function(self, expr):
        """Print function-like expressions.

        Expressions offering a ``to_c`` method print themselves. Otherwise the known custom functions are
        handled one by one; everything else is printed as a call ``name(arg, ...)``.
        """
        infix_functions = {
            bitwise_xor: '^',
            bit_shift_right: '>>',
            bit_shift_left: '<<',
            bitwise_or: '|',
            bitwise_and: '&',
        }
        if hasattr(expr, 'to_c'):
            return expr.to_c(self._print)
        if isinstance(expr, reinterpret_cast_func):
            arg, data_type = expr.args
            return f"*(({self._print(PointerType(data_type, restrict=False))})(& {self._print(arg)}))"
        elif isinstance(expr, address_of):
            assert len(expr.args) == 1, "address_of must only have one argument"
            return f"&({self._print(expr.args[0])})"
        elif isinstance(expr, cast_func):
            arg, data_type = expr.args
            if isinstance(arg, sp.Number) and arg.is_finite:
                # finite numbers become typed literals instead of casts
                return self._typed_number(arg, data_type)
            else:
                return f"(({data_type})({self._print(arg)}))"
        elif isinstance(expr, fast_division):
            # the fast approximations are printed as their exact counterparts here
            return f"({self._print(expr.args[0] / expr.args[1])})"
        elif isinstance(expr, fast_sqrt):
            return f"({self._print(sp.sqrt(expr.args[0]))})"
        elif isinstance(expr, vec_any) or isinstance(expr, vec_all):
            # in this (non-vectorized) printer only the argument is printed
            return self._print(expr.args[0])
        elif isinstance(expr, fast_inv_sqrt):
            return f"({self._print(1 / sp.sqrt(expr.args[0]))})"
        # NOTE(review): this class also defines _print_Abs and _print_Max; presumably the printer
        # dispatches to those first, so the next two branches look unreachable - TODO confirm.
        elif isinstance(expr, sp.Abs):
            return f"abs({self._print(expr.args[0])})"
        elif isinstance(expr, sp.Max):
            return self._print(expr)
        elif isinstance(expr, sp.Mod):
            if expr.args[0].is_integer and expr.args[1].is_integer:
                return f"({self._print(expr.args[0])} % {self._print(expr.args[1])})"
            else:
                return f"fmod({self._print(expr.args[0])}, {self._print(expr.args[1])})"
        elif expr.func in infix_functions:
            return f"({self._print(expr.args[0])} {infix_functions[expr.func]} {self._print(expr.args[1])})"
        elif expr.func == int_power_of_2:
            return f"(1 << ({self._print(expr.args[0])}))"
        elif expr.func == int_div:
            return f"(({self._print(expr.args[0])}) / ({self._print(expr.args[1])}))"
        else:
            name = expr.name if hasattr(expr, 'name') else expr.__class__.__name__
            arg_str = ', '.join(self._print(a) for a in expr.args)
            return f'{name}({arg_str})'

    def _typed_number(self, number, dtype):
        """Print ``number`` as literal of type ``dtype``.

        For float32 an ``f`` suffix is appended, and for both float32 and float64 a ``.0`` is added
        when the printed number contains no dot. Other types are printed unchanged.
        """
        res = self._print(number)
        if dtype.numpy_dtype == np.float32:
            return res + '.0f' if '.' not in res else res + 'f'
        elif dtype.numpy_dtype == np.float64:
            return res + '.0' if '.' not in res else res
        else:
            return res

    def _reduction_lambda(self, expr, accumulator, neutral_element, update_operator):
        """Shared implementation of ``_print_Sum`` and ``_print_Product``.

        Prints an immediately invoked lambda that loops over the first limit of ``expr`` (upper bound
        inclusive, ``int`` counter, step 1) and folds ``expr.function`` into an accumulator variable.

        Args:
            expr: expression with ``limits``, ``function`` and ``args`` as used by sympy sums/products
            accumulator: name of the accumulator variable in the generated code
            neutral_element: initial value of the accumulator, as code string
            update_operator: compound assignment operator that folds a term into the accumulator
        """
        template = """[&]() {{
    {dtype} {acc} = ({dtype}) {neutral};
    for ( {iterator_dtype} {var} = {start}; {condition}; {var} += {increment} ) {{
        {acc} {op} {expr};
    }}
    return {acc};
}}()"""
        # only the first limit tuple (variable, start, end) is taken into account
        var, start, end = expr.limits[0][0], expr.limits[0][1], expr.limits[0][2]
        printed_var = self._print(var)
        return template.format(
            dtype=get_type_of_expression(expr.args[0]),
            acc=accumulator,
            neutral=neutral_element,
            op=update_operator,
            iterator_dtype='int',
            var=printed_var,
            start=self._print(start),
            expr=self._print(expr.function),
            increment=str(1),
            condition=printed_var + ' <= ' + self._print(end)  # if start < end else '>='
        )

    def _print_Sum(self, expr):
        """Print a sum as immediately invoked lambda containing a loop."""
        return self._reduction_lambda(expr, 'sum', '0', '+=')

    def _print_Product(self, expr):
        """Print a product as immediately invoked lambda containing a loop."""
        return self._reduction_lambda(expr, 'product', '1', '*=')

    def _print_ConditionalFieldAccess(self, node):
        """Print as piecewise: the out-of-bounds value if the out-of-bounds condition holds, else the access."""
        return self._print(sp.Piecewise((node.outofbounds_value, node.outofbounds_condition), (node.access, True)))

    def _print_Max(self, expr):
        """Print a maximum as tree of nested ternary expressions."""
        def inner_print_max(args):
            if len(args) == 1:
                return self._print(args[0])
            half = len(args) // 2
            a = inner_print_max(args[:half])
            b = inner_print_max(args[half:])
            return f"(({a} > {b}) ? {a} : {b})"
        return inner_print_max(expr.args)

    def _print_Min(self, expr):
        """Print a minimum as tree of nested ternary expressions."""
        def inner_print_min(args):
            if len(args) == 1:
                return self._print(args[0])
            half = len(args) // 2
            a = inner_print_min(args[:half])
            b = inner_print_min(args[half:])
            return f"(({a} < {b}) ? {a} : {b})"
        return inner_print_min(expr.args)

    def _print_re(self, expr):
        return f"real({self._print(expr.args[0])})"

    def _print_im(self, expr):
        return f"imag({self._print(expr.args[0])})"

    def _print_ImaginaryUnit(self, expr):
        return "complex<double>{0,1}"

    def _print_TypedImaginaryUnit(self, expr):
        """Print the imaginary unit with the precision given by ``expr.dtype``.

        Raises:
            NotImplementedError: for types other than complex64 and complex128
        """
        if expr.dtype.numpy_dtype == np.complex64:
            return "complex<float>{0,1}"
        elif expr.dtype.numpy_dtype == np.complex128:
            return "complex<double>{0,1}"
        else:
            raise NotImplementedError(
                "only complex64 and complex128 supported")

    def _print_Complex(self, expr):
        # NOTE(review): _typed_number reads 'dtype.numpy_dtype', which the plain numpy scalar type passed
        # here presumably does not provide - confirm whether this method is ever reached.
        return self._typed_number(expr, np.complex64)

520

Martin Bauer's avatar
Martin Bauer committed
521
# noinspection PyPep8Naming
522
523
524
class VectorizedCustomSympyPrinter(CustomSympyPrinter):
    SummandInfo = namedtuple("SummandInfo", ['sign', 'term'])

525
526
    def __init__(self, instruction_set):
        """Create a printer that emits code through ``instruction_set``.

        Args:
            instruction_set: mapping indexed by operation key. The methods
                of this class look up entries such as ``'+'``, ``'*'``,
                ``'loadA'`` and ``'width'`` in it.
        """
        super(VectorizedCustomSympyPrinter, self).__init__()
        self.instruction_set = instruction_set
528

Martin Bauer's avatar
Martin Bauer committed
529
530
531
532
    def _scalarFallback(self, func_name, expr, *args, **kwargs):
        """Delegate to the parent printer when ``expr`` is not vector-typed.

        Args:
            func_name: name of the parent-class print method to call.
            expr: expression whose type selects between the two paths.
            *args: forwarded unchanged to the parent method.
            **kwargs: forwarded unchanged to the parent method.

        Returns:
            The parent method's result if the type of ``expr`` is not exactly
            ``VectorType``. Otherwise ``None``, which the callers in this
            class treat as "print the vectorized form".
        """
        expr_type = get_type_of_expression(expr)
        if type(expr_type) is not VectorType:
            scalar_printer = getattr(super(VectorizedCustomSympyPrinter, self), func_name)
            return scalar_printer(expr, *args, **kwargs)

        # The expression's vector width has to match the instruction set in use.
        assert self.instruction_set['width'] == expr_type.width
        return None

537
    def _print_Function(self, expr):
        """Print function nodes, using the instruction set for the vector cases.

        Handled here:

        - ``vector_memory_access``: ``'loadA'`` (aligned) or ``'loadU'``
          (unaligned) applied to the address of the first argument.
        - ``cast_func`` to a ``VectorType``: ``'makeVec'`` / ``'makeVecBool'``
          for an ``sp.Tuple`` argument, ``'makeVecConst'`` /
          ``'makeVecConstBool'`` otherwise.
        - ``fast_division`` and ``fast_inv_sqrt``: the scalar fallback is tried
          first; otherwise ``'/'`` respectively ``'rsqrt'`` is used (a printed
          ``1 / sqrt`` when the ``'rsqrt'`` entry is falsy).
        - ``fast_sqrt``: printed as ``sp.sqrt`` of the argument.
        - ``vec_any`` / ``vec_all``: ``'any'`` / ``'all'`` for vector-typed
          arguments, the plain printed argument otherwise.

        Everything else, including casts to non-vector types, is delegated to
        the parent class.
        """
        if isinstance(expr, vector_memory_access):
            # Only the accessed expression and the alignment flag are needed here.
            arg, _, aligned, _, _ = expr.args
            instruction = self.instruction_set['loadA'] if aligned else self.instruction_set['loadU']
            return instruction.format("& " + self._print(arg))
        elif isinstance(expr, cast_func):
            arg, data_type = expr.args
            if type(data_type) is VectorType:
                if isinstance(arg, sp.Tuple):
                    # One printed value per tuple entry.
                    is_boolean = get_type_of_expression(arg[0]) == create_type("bool")
                    printed_args = [self._print(a) for a in arg]
                    instruction = 'makeVecBool' if is_boolean else 'makeVec'
                    return self.instruction_set[instruction].format(*printed_args)
                else:
                    # A single printed value is passed to the instruction.
                    is_boolean = get_type_of_expression(arg) == create_type("bool")
                    instruction = 'makeVecConstBool' if is_boolean else 'makeVecConst'
                    return self.instruction_set[instruction].format(self._print(arg))
            # Casts to non-vector types fall through to the parent class below.
        elif expr.func == fast_division:
            result = self._scalarFallback('_print_Function', expr)
            if not result:
                result = self.instruction_set['/'].format(self._print(expr.args[0]), self._print(expr.args[1]))
            return result
        elif expr.func == fast_sqrt:
            return f"({self._print(sp.sqrt(expr.args[0]))})"
        elif expr.func == fast_inv_sqrt:
            result = self._scalarFallback('_print_Function', expr)
            if result:
                # Return the scalar result directly instead of discarding it and
                # printing the same expression again through the parent class.
                return result
            if self.instruction_set['rsqrt']:
                return self.instruction_set['rsqrt'].format(self._print(expr.args[0]))
            return f"({self._print(1 / sp.sqrt(expr.args[0]))})"
        elif isinstance(expr, vec_any):
            expr_type = get_type_of_expression(expr.args[0])
            if type(expr_type) is not VectorType:
                return self._print(expr.args[0])
            return self.instruction_set['any'].format(self._print(expr.args[0]))
        elif isinstance(expr, vec_all):
            expr_type = get_type_of_expression(expr.args[0])
            if type(expr_type) is not VectorType:
                return self._print(expr.args[0])
            return self.instruction_set['all'].format(self._print(expr.args[0]))

        return super(VectorizedCustomSympyPrinter, self)._print_Function(expr)

583
584
585
586
587
    def _print_And(self, expr):
        result = self._scalarFallback('_print_And', expr)
        if result:
            return result

Martin Bauer's avatar
Martin Bauer committed
588
589
590
591
        arg_strings = [self._print(a) for a in expr.args]
        assert len(arg_strings) > 0
        result = arg_strings[0]
        for item in arg_strings[1:]:
Martin Bauer's avatar
Martin Bauer committed
592
            result = self.instruction_set['&'].format(result, item)
593
594
595
596
597
598
599
        return result

    def _print_Or(self, expr):
        result = self._scalarFallback('_print_Or', expr)
        if result:
            return result

Martin Bauer's avatar
Martin Bauer committed
600
601
602
603
        arg_strings = [self._print(a) for a in expr.args]
        assert len(arg_strings) > 0
        result = arg_strings[0]
        for item in arg_strings[1:]:
Martin Bauer's avatar
Martin Bauer committed
604
            result = self.instruction_set['|'].format(result, item)
605
606
        return result

607
    def _print_Add(self, expr, order=None):
        """Print a sum as a chain of instruction-set ``'+'`` / ``'-'`` operations.

        Args:
            expr: the sum to print.
            order: accepted for signature compatibility; not used by this method.

        Returns:
            The truthy scalar-fallback result if there is one. Otherwise the
            summands folded left to right, positive terms first; a summand
            with sign ``-1`` is attached with ``'-'``, any other with ``'+'``.
        """
        result = self._scalarFallback('_print_Add', expr)
        if result:
            return result

        summands = []
        for term in expr.args:
            if term.func == sp.Mul:
                # _print_Mul reports a negative coefficient as a separate sign.
                sign, printed_term = self._print_Mul(term, inside_add=True)
            else:
                sign, printed_term = 1, self._print(term)
            summands.append(self.SummandInfo(sign, printed_term))

        # Use positive terms first
        summands.sort(key=lambda e: e.sign, reverse=True)
        # if no positive term exists, prepend a zero
        if summands[0].sign == -1:
            summands.insert(0, self.SummandInfo(1, "0"))

        assert len(summands) >= 2
        processed = summands[0].term
        for summand in summands[1:]:
            func = self.instruction_set['-'] if summand.sign == -1 else self.instruction_set['+']
            processed = func.format(processed, summand.term)
        return processed

633
    def _print_Pow(self, expr):
634
635
636
        result = self._scalarFallback('_print_Pow', expr)
        if result:
            return result
637

638
        one = self.instruction_set['makeVecConst'].format(1.0)
639

640
641
        if expr.exp.is_integer and expr.exp.is_number and 0 < expr.exp < 8:
            return "(" + self._print(sp.Mul(*[expr.base] * expr.exp, evaluate=False)) + ")"
642
        elif expr.exp == -1:
643
            one = self.instruction_set['makeVecConst'].format(1.0)
644
645
646
            return self.instruction_set['/'].format(one, self._print(expr.base))
        elif expr.exp == 0.5:
            return self.instruction_set['sqrt'].format(self._print(expr.base))
647
648
649
        elif expr.exp == -0.5:
            root = self.instruction_set['sqrt'].format(self._print(expr.base))
            return self.instruction_set['/'].format(one, root)
650
651
652
        elif expr.exp.is_integer and expr.exp.is_number and - 8 < expr.exp < 0:
            return self.instruction_set['/'].format(one,
                                                    self._print(sp.Mul(*[expr.base] * (-expr.exp), evaluate=False)))
653
        else:
654
            raise ValueError("Generic exponential not supported: " + str(expr))
655

Martin Bauer's avatar
Martin Bauer committed
656
657
658
659
    def _print_Mul(self, expr, inside_add=False):
        """Print a product as instruction-set ``'*'`` and ``'/'`` operations.

        Args:
            expr: the product to print.
            inside_add: if true, a negative coefficient is not printed but
                reported to the caller as a separate sign.

        Returns:
            The truthy scalar-fallback result if there is one. Otherwise the
            printed product when ``inside_add`` is false, or a
            ``(sign, printed_product)`` tuple with ``sign`` in ``{1, -1}``
            when ``inside_add`` is true.
        """
        # noinspection PyProtectedMember
        from sympy.core.mul import _keep_coeff

        result = self._scalarFallback('_print_Mul', expr)
        if result:
            # NOTE(review): this is a plain string even when inside_add is true,
            # while _print_Add unpacks a (sign, term) pair - confirm that a
            # scalar-typed product cannot reach this point from _print_Add.
            return result

        # Split off a negative coefficient so that it can be handled as a sign.
        c, e = expr.as_coeff_Mul()
        if c < 0:
            expr = _keep_coeff(-c, e)
            sign = -1
        else:
            sign = 1

        a = []  # items in the numerator
        b = []  # items that are in the denominator (if any)

        # Gather args for numerator/denominator
        for item in expr.as_ordered_factors():
            if item.is_commutative and item.is_Pow and item.exp.is_Rational and item.exp.is_negative:
                if item.exp != -1:
                    b.append(sp.Pow(item.base, -item.exp, evaluate=False))
                else:
                    b.append(sp.Pow(item.base, -item.exp))
            else:
                a.append(item)

        a = a or [S.One]

        a_str = [self._print(x) for x in a]
        b_str = [self._print(x) for x in b]

        result = a_str[0]
        for item in a_str[1:]:
            result = self.instruction_set['*'].format(result, item)

        if b_str:
            denominator_str = b_str[0]
            for item in b_str[1:]:
                denominator_str = self.instruction_set['*'].format(denominator_str, item)
            result = self.instruction_set['/'].format(result, denominator_str)

        if inside_add:
            return sign, result
        if sign < 0:
            # Outside of a sum the sign is applied by multiplying with -1.
            return self.instruction_set['*'].format(self._print(S.NegativeOne), result)
        return result

707
    def _print_Relational(self, expr):
708
709
710
        result = self._scalarFallback('_print_Relational', expr)
        if result:
            return result
Martin Bauer's avatar
Martin Bauer committed
711
        return self.instruction_set[expr.rel_op].format(self._print(expr.lhs), self._print(expr.rhs))
712
713

    def _print_Equality(self, expr):
714
715
716
        result = self._scalarFallback('_print_Equality', expr)
        if result:
            return result
Martin Bauer's avatar
Martin Bauer committed
717
        return self.instruction_set['=='].format(self._print(expr.lhs), self._print(expr.rhs))
718
719

    def _print_Piecewise(self, expr):
        """Print a piecewise expression as nested selects.

        A truthy scalar-fallback result is returned unchanged. Otherwise the
        default (last) branch is printed first and the remaining branches are
        wrapped around it from last to first:

        - a ``cast_func`` condition whose first argument is of type ``bool``
          becomes a C ternary on that argument; when
          ``KERNCRAFT_NO_TERNARY_MODE`` is set, such a branch is skipped and a
          warning is printed instead,
        - any other condition uses the instruction-set ``'blendv'`` entry.

        Raises:
            ValueError: if the first argument of the last branch's condition
                is not ``True``.
        """
        result = self._scalarFallback('_print_Piecewise', expr)
        if result:
            return result

        if expr.args[-1].cond.args[0] is not sp.sympify(True):
            # We need the last conditional to be a True, otherwise the resulting
            # function may not return a result.
            raise ValueError("All Piecewise expressions must contain an "
                             "(expr, True) statement to be used as a default "
                             "condition. Without one, the generated "
                             "expression may not evaluate to anything under "
                             "some condition.")

        result = self._print(expr.args[-1][0])
        for true_expr, condition in reversed(expr.args[:-1]):
            if isinstance(condition, cast_func) and get_type_of_expression(condition.args[0]) == create_type("bool"):
                if not KERNCRAFT_NO_TERNARY_MODE:
                    result = "(({}) ? ({}) : ({}))".format(self._print(condition.args[0]), self._print(true_expr),
                                                           result)
                else:
                    print("Warning - skipping ternary op")
            else:
                # noinspection SpellCheckingInspection
                result = self.instruction_set['blendv'].format(result, self._print(true_expr), self._print(condition))
        return result