cbackend.py 28.9 KB
Newer Older
1
import re
from collections import namedtuple
from typing import List, Set

import numpy as np
import sympy as sp
from sympy.core import S
from sympy.logic.boolalg import BooleanFalse, BooleanTrue
from sympy.printing.ccode import C89CodePrinter

from pystencils.astnodes import KernelFunction, Node
from pystencils.cpu.vectorization import vec_all, vec_any
from pystencils.data_types import (
    PointerType, VectorType, address_of, cast_func, create_type, get_type_of_expression,
    reinterpret_cast_func, vector_memory_access)
from pystencils.fast_approximation import fast_division, fast_inv_sqrt, fast_sqrt
from pystencils.integer_functions import (
    bit_shift_left, bit_shift_right, bitwise_and, bitwise_or, bitwise_xor,
    int_div, int_power_of_2, modulo_ceil)

try:
    from sympy.printing.ccode import C99CodePrinter as CCodePrinter
except ImportError:
    from sympy.printing.ccode import CCodePrinter  # for sympy versions < 1.1
Martin Bauer's avatar
Martin Bauer committed
25

26
# Names exported by a star-import of this module.
__all__ = ['generate_c', 'CustomCodeNode', 'PrintNode', 'get_headers', 'CustomSympyPrinter']

# A header entry has to start with '<' or '"' and end with '"' or '>'; get_headers asserts this.
HEADER_REGEX = re.compile(r'^[<"].*[">]$')

# Module-level switch, disabled by default. It is not read by the code visible in this part of the file.
KERNCRAFT_NO_TERNARY_MODE = False

Martin Bauer's avatar
Fixes    
Martin Bauer committed
33

34
35
36
37
38
def generate_c(ast_node: Node,
               signature_only: bool = False,
               dialect='c',
               custom_backend=None,
               with_globals=True) -> str:
    """Prints an abstract syntax tree node as C-like code (C, CUDA or OpenCL).

    This function does not need to distinguish for most AST nodes between C, C++ or CUDA code, it just prints 'C-like'
    code as encoded in the abstract syntax tree (AST). The AST is built differently for C or CUDA by calling different
    create_kernel functions.

    As a side effect, the symbols defined by the global declarations found in the AST are added to
    ``ast_node.global_variables`` (the attribute is created if it does not exist yet).

    Args:
        ast_node: root node of the AST to print
        signature_only: forwarded to the dialect backend; when set, global declarations are never prepended
        dialect: 'c', 'cuda' or 'opencl'; ignored if ``custom_backend`` is given
        custom_backend: callable that is used as printer instead of one of the dialect backends
        with_globals: if True, the printed global declarations are put in front of a printed ``KernelFunction``

    Returns:
        C-like code for the ast node and its descendants

    Raises:
        ValueError: if no custom backend is given and ``dialect`` is not one of the known dialects
    """
    global_declarations = get_global_declarations(ast_node)
    for declaration in global_declarations:
        if hasattr(ast_node, "global_variables"):
            ast_node.global_variables.update(declaration.symbols_defined)
        else:
            # Copy the set: storing the declaration's own set would let the update() calls of the
            # following iterations modify that declaration.
            ast_node.global_variables = set(declaration.symbols_defined)

    if custom_backend:
        printer = custom_backend
    elif dialect == 'c':
        # Not every node has an instruction set; fall back to the scalar printer in that case.
        instruction_set = getattr(ast_node, 'instruction_set', None)
        printer = CBackend(signature_only=signature_only,
                           vector_instruction_set=instruction_set)
    elif dialect == 'cuda':
        from pystencils.backends.cuda_backend import CudaBackend
        printer = CudaBackend(signature_only=signature_only)
    elif dialect == 'opencl':
        from pystencils.backends.opencl_backend import OpenClBackend
        printer = OpenClBackend(signature_only=signature_only)
    else:
        raise ValueError("Unknown dialect: " + str(dialect))

    code = printer(ast_node)
    if not signature_only and isinstance(ast_node, KernelFunction):
        if with_globals and global_declarations:
            code = "\n" + code
            # Each declaration is put in front of what was printed so far, i.e. the declarations
            # end up in reverse order of `global_declarations`.
            for declaration in global_declarations:
                code = printer(declaration) + "\n" + code

    return code


def get_global_declarations(ast):
    """Collect the global declarations required by ``ast`` and all of its descendants.

    Every visited object that has a ``required_global_declarations`` attribute contributes its entries;
    objects that have an ``args`` attribute are traversed recursively.

    Args:
        ast: root object of the traversal

    Returns:
        list of the distinct declarations, sorted by their ``str`` representation
    """
    collected = []

    def visit_node(sub_ast):
        collected.extend(getattr(sub_ast, "required_global_declarations", ()))
        for child in getattr(sub_ast, "args", ()):
            visit_node(child)

    visit_node(ast)

    # The set removes duplicates (declarations have to be hashable), sorting makes the order deterministic.
    return sorted(set(collected), key=str)
100
101


Martin Bauer's avatar
Martin Bauer committed
102
103
def get_headers(ast_node: Node) -> List[str]:
    """Return the sorted list of header files necessary to compile the printed C-like code.

    Headers are gathered from the instruction set of a ``KernelFunction``, from the ``headers`` attribute
    of the node, recursively from all arguments that are sympy expressions or nodes, and from the global
    declarations required by the node.

    Args:
        ast_node: node whose headers (including the ones of its descendants) are collected

    Returns:
        sorted list of distinct header names, each of the form ``"..."`` or ``<...>``

    Raises:
        AssertionError: if a collected header does not match ``HEADER_REGEX``
    """
    headers = set()

    if isinstance(ast_node, KernelFunction) and ast_node.instruction_set:
        headers.update(ast_node.instruction_set['headers'])

    if hasattr(ast_node, 'headers'):
        headers.update(ast_node.headers)

    for a in ast_node.args:
        if isinstance(a, (sp.Expr, Node)):
            headers.update(get_headers(a))

    for g in get_global_declarations(ast_node):
        if isinstance(g, Node):
            headers.update(get_headers(g))

    for h in headers:
        assert HEADER_REGEX.match(h), f'header /{h}/ does not follow the pattern /"..."/ or /<...>/'

    return sorted(headers)
123
124


125
126
127
# --------------------------------------- Backend Specific Nodes -------------------------------------------------------


128
class CustomCodeNode(Node):
    """AST node that carries a verbatim piece of backend code.

    Two nodes of the same class compare equal if their code is equal; the hash is derived from the code
    as well, so equal nodes collapse to one entry in sets and dicts.
    """

    def __init__(self, code, symbols_read, symbols_defined, parent=None):
        """
        Args:
            code: source code of the node; stored with a leading newline
            symbols_read: iterable of the symbols the code reads
            symbols_defined: iterable of the symbols the code defines
            parent: parent node, forwarded to ``Node``
        """
        super(CustomCodeNode, self).__init__(parent=parent)
        self._code = "\n" + code
        self._symbols_read = set(symbols_read)
        self._symbols_defined = set(symbols_defined)
        self.headers = []

    def get_code(self, dialect, vector_instruction_set):
        """Return the stored code; both arguments are ignored by this base implementation."""
        return self._code

    @property
    def args(self):
        """A custom code node has no child nodes."""
        return []

    @property
    def symbols_defined(self):
        """Set of symbols defined by the code."""
        return self._symbols_defined

    @property
    def undefined_symbols(self):
        """Symbols that are read by the code but not defined by it."""
        return self._symbols_read - self._symbols_defined

    def __eq__(self, other):
        # Only nodes of exactly the same class are compared by code; for anything else Python's
        # default handling of NotImplemented applies.
        if type(self) is not type(other):
            return NotImplemented
        return self._code == other._code

    def __hash__(self):
        return hash(self._code)

157

158
class PrintNode(CustomCodeNode):
    """Custom code node that writes ``<name>  =  <value>`` of a symbol to ``std::cout``.

    The node reads the printed symbol, defines no symbols and requires the ``<iostream>`` header.
    """

    # noinspection SpellCheckingInspection
    def __init__(self, symbol_to_print):
        """
        Args:
            symbol_to_print: symbol whose ``name`` is used both as label and as printed expression
        """
        name = symbol_to_print.name
        code = '\nstd::cout << "%s  =  " << %s << std::endl; \n' % (name, name)
        super(PrintNode, self).__init__(code, symbols_read=[symbol_to_print], symbols_defined=set())
        self.headers.append("<iostream>")
164
165
166
167


# ------------------------------------------- Printer ------------------------------------------------------------------

168

Martin Bauer's avatar
Martin Bauer committed
169
170
# noinspection PyPep8Naming
class CBackend:
    """Printer that turns AST nodes into C-like source code.

    Nodes are dispatched by class name: for a node whose class (or one of its base classes, in MRO order)
    is called ``Foo`` the method ``_print_Foo`` is used. Strings are passed through unchanged. Sympy
    expressions inside the nodes are printed with ``self.sympy_printer``.
    """

    def __init__(self, sympy_printer=None, signature_only=False, vector_instruction_set=None, dialect='c'):
        """
        Args:
            sympy_printer: printer used for sympy expressions. If None, a ``VectorizedCustomSympyPrinter``
                           is created when an instruction set is given, a ``CustomSympyPrinter`` otherwise.
            signature_only: if True, only the declaration of a kernel function is printed, not its body
            vector_instruction_set: instruction set dictionary used for vectorized stores, or None
            dialect: dialect name; 'cuda' adds launch bounds to kernel declarations and the value is
                     handed to custom code nodes
        """
        if sympy_printer is None:
            if vector_instruction_set is not None:
                self.sympy_printer = VectorizedCustomSympyPrinter(vector_instruction_set)
            else:
                self.sympy_printer = CustomSympyPrinter()
        else:
            self.sympy_printer = sympy_printer

        self._vector_instruction_set = vector_instruction_set
        self._indent = "   "
        self._dialect = dialect
        self._signatureOnly = signature_only

    def __call__(self, node):
        """Print ``node`` and return the code as string.

        While printing, ``VectorType.instruction_set`` is set to the instruction set of this backend;
        the previous value is restored afterwards, also if printing raises.
        """
        prev_is = VectorType.instruction_set
        VectorType.instruction_set = self._vector_instruction_set
        try:
            return str(self._print(node))
        finally:
            VectorType.instruction_set = prev_is

    def _print(self, node):
        """Dispatch ``node`` to the first ``_print_<ClassName>`` method found along its MRO.

        Raises:
            NotImplementedError: if no print method exists for any class in the MRO of the node
        """
        if isinstance(node, str):
            return node
        for cls in type(node).__mro__:
            method_name = "_print_" + cls.__name__
            if hasattr(self, method_name):
                return getattr(self, method_name)(node)
        raise NotImplementedError(self.__class__.__name__ + " does not support node of type " + node.__class__.__name__)

    def _print_Type(self, node):
        return str(node)

    def _print_KernelFunction(self, node):
        """Print the function declaration and, unless only the signature is requested, the body."""
        function_arguments = ["%s %s" % (self._print(s.symbol.dtype), s.symbol.name) for s in node.get_parameters()]
        launch_bounds = ""
        if self._dialect == 'cuda':
            max_threads = node.indexing.max_threads_per_block()
            if max_threads:
                launch_bounds = "__launch_bounds__({}) ".format(max_threads)
        func_declaration = "FUNC_PREFIX %svoid %s(%s)" % (launch_bounds, node.function_name,
                                                          ", ".join(function_arguments))
        if self._signatureOnly:
            return func_declaration

        body = self._print(node.body)
        return func_declaration + "\n" + body

    def _print_Block(self, node):
        """Print all children, one per line, indented and enclosed in braces."""
        block_contents = "\n".join([self._print(child) for child in node.args])
        # splitlines(True) keeps the line ends, so joining with the indent prefixes every further line
        return "{\n%s\n}" % (self._indent + self._indent.join(block_contents.splitlines(True)))

    def _print_PragmaBlock(self, node):
        return "%s\n%s" % (node.pragma_line, self._print_Block(node))

    def _print_LoopOverCoordinate(self, node):
        """Print a ``for`` loop with an ``int`` counter, preceded by the prefix lines of the node."""
        counter_symbol = node.loop_counter_name
        start = "int %s = %s" % (counter_symbol, self.sympy_printer.doprint(node.start))
        condition = "%s < %s" % (counter_symbol, self.sympy_printer.doprint(node.stop))
        update = "%s += %s" % (counter_symbol, self.sympy_printer.doprint(node.step),)
        loop_str = "for (%s; %s; %s)" % (start, condition, update)

        prefix = "\n".join(node.prefix_lines)
        if prefix:
            prefix += "\n"
        return "%s%s\n%s" % (prefix, loop_str, self._print(node.body))

    def _print_SympyAssignment(self, node):
        """Print a declaration, a vector store or a plain assignment, depending on the node."""
        if node.is_declaration:
            if node.use_auto:
                data_type = 'auto '
            else:
                prefix = 'const ' if node.is_const else ''
                # constness is expressed by the prefix only, so it is stripped from the printed type
                data_type = prefix + self._print(node.lhs.dtype).replace(' const', '') + " "

            return "%s%s = %s;" % (data_type,
                                   self.sympy_printer.doprint(node.lhs),
                                   self.sympy_printer.doprint(node.rhs))

        lhs_type = get_type_of_expression(node.lhs)
        if type(lhs_type) is VectorType and isinstance(node.lhs, cast_func):
            arg, _, aligned, nontemporal, mask = node.lhs.args
            instr = 'storeU'
            if aligned:
                instr = 'stream' if nontemporal else 'storeA'

            printed_mask = ""
            if mask != True:  # NOQA
                instr = 'maskStore' if aligned else 'maskStoreU'
                printed_mask = self.sympy_printer.doprint(mask)
                if self._vector_instruction_set['dataTypePrefix']['double'] == '__mm256d':
                    printed_mask = "_mm256_castpd_si256({})".format(printed_mask)

            # a scalar right hand side is cast to a vector type before it is printed
            rhs_type = get_type_of_expression(node.rhs)
            if type(rhs_type) is not VectorType:
                rhs = cast_func(node.rhs, VectorType(rhs_type))
            else:
                rhs = node.rhs

            return self._vector_instruction_set[instr].format("&" + self.sympy_printer.doprint(arg),
                                                              self.sympy_printer.doprint(rhs),
                                                              printed_mask) + ';'

        return "%s = %s;" % (self.sympy_printer.doprint(node.lhs), self.sympy_printer.doprint(node.rhs))

    def _print_TemporaryMemoryAllocation(self, node):
        """Print an ``aligned_alloc`` call; the size is enlarged by ``align`` and rounded up to a multiple of it."""
        align = 64
        np_dtype = node.symbol.dtype.base_type.numpy_dtype
        required_size = np_dtype.itemsize * node.size + align
        size = modulo_ceil(required_size, align)
        code = "{dtype} {name}=({dtype})aligned_alloc({align}, {size}) + {offset};"
        return code.format(dtype=node.symbol.dtype,
                           name=self.sympy_printer.doprint(node.symbol.name),
                           size=self.sympy_printer.doprint(size),
                           offset=int(node.offset(align)),
                           align=align)

    def _print_TemporaryMemoryFree(self, node):
        """Print the ``free`` call; the offset added at allocation is subtracted from the pointer again."""
        align = 64
        return "free(%s - %d);" % (self.sympy_printer.doprint(node.symbol.name), node.offset(align))

    def _print_SkipIteration(self, _):
        return "continue;"

    def _print_CustomCodeNode(self, node):
        return node.get_code(self._dialect, self._vector_instruction_set)

    def _print_SourceCodeComment(self, node):
        return "/* " + node.text + " */"

    def _print_EmptyLine(self, node):
        return ""

    def _print_Conditional(self, node):
        """Print an ``if``/``else`` statement; constant conditions are resolved at print time.

        Raises:
            ValueError: if the condition has a vector type
        """
        if type(node.condition_expr) is BooleanTrue:
            return self._print_Block(node.true_block)
        elif type(node.condition_expr) is BooleanFalse:
            if node.false_block is None:
                # condition is never fulfilled and there is no else branch: nothing to print
                return ""
            return self._print_Block(node.false_block)

        cond_type = get_type_of_expression(node.condition_expr)
        if isinstance(cond_type, VectorType):
            raise ValueError("Problem with Conditional inside vectorized loop - use vec_any or vec_all")

        condition_expr = self.sympy_printer.doprint(node.condition_expr)
        true_block = self._print_Block(node.true_block)
        result = "if (%s)\n%s " % (condition_expr, true_block)
        if node.false_block:
            false_block = self._print_Block(node.false_block)
            result += "else " + false_block
        return result

323
324
325
326

# ------------------------------------------ Helper function & classes -------------------------------------------------


Martin Bauer's avatar
Martin Bauer committed
327
# noinspection PyPep8Naming
class CustomSympyPrinter(CCodePrinter):
    """Sympy C code printer, extended by print methods for the functions and types used in this module."""

    # C++ lambda that accumulates a term over an inclusive integer range; shared by Sum and Product.
    _REDUCTION_TEMPLATE = """[&]() {{
    {dtype} {accumulator} = ({dtype}) {initial_value};
    for ( {iterator_dtype} {var} = {start}; {condition}; {var} += {increment} ) {{
        {accumulator} {operator} {expr};
    }}
    return {accumulator};
}}()"""

    def __init__(self):
        super(CustomSympyPrinter, self).__init__()
        self._float_type = create_type("float32")
        # Presumably removed so that Min/Max are handled by the _print_Min/_print_Max methods bound
        # at the end of this class instead of a function name mapping - TODO confirm.
        if 'Min' in self.known_functions:
            del self.known_functions['Min']
        if 'Max' in self.known_functions:
            del self.known_functions['Max']

    def _print_Pow(self, expr):
        """Don't use std::pow function, for small integer exponents, write as multiplication"""
        if not expr.free_symbols:
            # constant expression: evaluate it and print it as typed literal
            return self._typed_number(expr.evalf(), get_type_of_expression(expr))

        if expr.exp.is_integer and expr.exp.is_number and 0 < expr.exp < 8:
            return "(" + self._print(sp.Mul(*[expr.base] * expr.exp, evaluate=False)) + ")"
        elif expr.exp.is_integer and expr.exp.is_number and - 8 < expr.exp < 0:
            return "1 / ({})".format(self._print(sp.Mul(*[expr.base] * (-expr.exp), evaluate=False)))
        else:
            return super(CustomSympyPrinter, self)._print_Pow(expr)

    def _print_Rational(self, expr):
        """Evaluate all rationals i.e. print 0.25 instead of 1.0/4.0"""
        res = str(expr.evalf().num)
        return res

    def _print_Equality(self, expr):
        """Equality operator is not printable in default printer"""
        return '((' + self._print(expr.lhs) + ") == (" + self._print(expr.rhs) + '))'

    def _print_Piecewise(self, expr):
        """Print piecewise in one line (remove newlines)"""
        result = super(CustomSympyPrinter, self)._print_Piecewise(expr)
        return result.replace("\n", "")

    def _print_Abs(self, expr):
        """Print ``abs`` for integer arguments and ``fabs`` otherwise."""
        if expr.args[0].is_integer:
            return 'abs({0})'.format(self._print(expr.args[0]))
        else:
            return 'fabs({0})'.format(self._print(expr.args[0]))

    def _print_Type(self, node):
        return str(node)

    def _print_Function(self, expr):
        """Print the functions known to this module; unknown functions are printed as ``name(args)``.

        Expressions that provide a ``to_c`` method are printed by that method, before any other rule applies.
        """
        infix_functions = {
            bitwise_xor: '^',
            bit_shift_right: '>>',
            bit_shift_left: '<<',
            bitwise_or: '|',
            bitwise_and: '&',
        }
        if hasattr(expr, 'to_c'):
            return expr.to_c(self._print)

        # The order of the following checks is kept on purpose: some of the function classes may be
        # subclasses of others.
        if isinstance(expr, reinterpret_cast_func):
            arg, data_type = expr.args
            return "*((%s)(& %s))" % (self._print(PointerType(data_type, restrict=False)), self._print(arg))
        elif isinstance(expr, address_of):
            assert len(expr.args) == 1, "address_of must only have one argument"
            return "&(%s)" % self._print(expr.args[0])
        elif isinstance(expr, cast_func):
            arg, data_type = expr.args
            if isinstance(arg, sp.Number) and arg.is_finite:
                # finite numbers become typed literals instead of C casts
                return self._typed_number(arg, data_type)
            else:
                return "((%s)(%s))" % (data_type, self._print(arg))
        elif isinstance(expr, fast_division):
            return "({})".format(self._print(expr.args[0] / expr.args[1]))
        elif isinstance(expr, fast_sqrt):
            return "({})".format(self._print(sp.sqrt(expr.args[0])))
        elif isinstance(expr, vec_any) or isinstance(expr, vec_all):
            return self._print(expr.args[0])
        elif isinstance(expr, fast_inv_sqrt):
            return "({})".format(self._print(1 / sp.sqrt(expr.args[0])))
        elif isinstance(expr, sp.Abs):
            return "abs({})".format(self._print(expr.args[0]))
        elif isinstance(expr, sp.Mod):
            if expr.args[0].is_integer and expr.args[1].is_integer:
                return "({} % {})".format(self._print(expr.args[0]), self._print(expr.args[1]))
            else:
                return "fmod({}, {})".format(self._print(expr.args[0]), self._print(expr.args[1]))
        elif expr.func in infix_functions:
            return "(%s %s %s)" % (self._print(expr.args[0]), infix_functions[expr.func], self._print(expr.args[1]))
        elif expr.func == int_power_of_2:
            return "(1 << (%s))" % (self._print(expr.args[0]))
        elif expr.func == int_div:
            return "((%s) / (%s))" % (self._print(expr.args[0]), self._print(expr.args[1]))
        else:
            name = expr.name if hasattr(expr, 'name') else expr.__class__.__name__
            arg_str = ', '.join(self._print(a) for a in expr.args)
            return f'{name}({arg_str})'

    def _typed_number(self, number, dtype):
        """Print ``number`` as literal of type ``dtype``.

        float32 literals get an ``f`` suffix, float32 and float64 literals without a decimal point
        get ``.0`` appended; for all other types the plain printed number is returned.
        """
        res = self._print(number)
        has_decimal_point = '.' in res
        if dtype.numpy_dtype == np.float32:
            return res + 'f' if has_decimal_point else res + '.0f'
        elif dtype.numpy_dtype == np.float64:
            return res if has_decimal_point else res + '.0'
        else:
            return res

    def _reduction_lambda(self, expr, accumulator, initial_value, operator):
        """Print a sum/product expression as immediately invoked C++ lambda.

        Only the first limit of ``expr`` is used; the loop runs from its lower to its upper bound
        (inclusive) with an ``int`` counter and increment 1.
        """
        var, start, end = expr.limits[0][0], expr.limits[0][1], expr.limits[0][2]
        printed_var = self._print(var)
        return self._REDUCTION_TEMPLATE.format(
            dtype=get_type_of_expression(expr.args[0]),
            accumulator=accumulator,
            initial_value=initial_value,
            operator=operator,
            iterator_dtype='int',
            var=printed_var,
            start=self._print(start),
            expr=self._print(expr.function),
            increment=str(1),
            condition=printed_var + ' <= ' + self._print(end)  # if start < end else '>='
        )

    def _print_Sum(self, expr):
        return self._reduction_lambda(expr, accumulator='sum', initial_value='0', operator='+=')

    def _print_Product(self, expr):
        return self._reduction_lambda(expr, accumulator='product', initial_value='1', operator='*=')

    def _print_ConditionalFieldAccess(self, node):
        """Print as piecewise: the out-of-bounds value if the out-of-bounds condition holds, else the access."""
        return self._print(sp.Piecewise((node.outofbounds_value, node.outofbounds_condition), (node.access, True)))

    _print_Max = C89CodePrinter._print_Max
    _print_Min = C89CodePrinter._print_Min

    def _print_re(self, expr):
        return f"real({self._print(expr.args[0])})"

    def _print_im(self, expr):
        return f"imag({self._print(expr.args[0])})"

    def _print_ImaginaryUnit(self, expr):
        return "complex<double>{0,1}"

    def _print_TypedImaginaryUnit(self, expr):
        """Print the imaginary unit with the precision of ``expr.dtype``.

        Raises:
            NotImplementedError: for data types other than complex64 and complex128
        """
        if expr.dtype.numpy_dtype == np.complex64:
            return "complex<float>{0,1}"
        elif expr.dtype.numpy_dtype == np.complex128:
            return "complex<double>{0,1}"
        else:
            raise NotImplementedError(
                "only complex64 and complex128 supported")

    def _print_Complex(self, expr):
        # NOTE(review): _typed_number reads `dtype.numpy_dtype`, which the numpy scalar type passed
        # here does not provide - this call looks like it cannot succeed; confirm the intended type.
        return self._typed_number(expr, np.complex64)

503

Martin Bauer's avatar
Martin Bauer committed
504
# noinspection PyPep8Naming
505
506
507
class VectorizedCustomSympyPrinter(CustomSympyPrinter):
    """Sympy printer that emits vector operations through ``instruction_set`` templates.

    Each overridden ``_print_*`` method first asks ``_scalarFallback`` whether the
    expression is of ``VectorType``. If it is not, the parent printer's output is
    used; otherwise the operation is printed by filling the matching format
    template looked up in ``self.instruction_set``.
    """
    # (sign, term) pair collected per summand while printing a vectorized Add.
    SummandInfo = namedtuple("SummandInfo", ['sign', 'term'])

    def __init__(self, instruction_set):
        """Create a printer bound to one vector instruction set.

        Args:
            instruction_set: mapping indexed by operation key (for example ``'+'``,
                ``'*'``, ``'loadA'``, ``'makeVecConst'``, ``'width'``); the
                operation entries are format templates filled in by the
                ``_print_*`` methods.
        """
        super(VectorizedCustomSympyPrinter, self).__init__()
        self.instruction_set = instruction_set
    def _scalarFallback(self, func_name, expr, *args, **kwargs):
        """Delegate to the parent printer when ``expr`` is not of vector type.

        Args:
            func_name: name of the parent-class print method to call, e.g. ``'_print_Add'``.
            expr: expression whose type decides between scalar and vector printing.
            *args: extra positional arguments forwarded to the parent method.
            **kwargs: extra keyword arguments forwarded to the parent method.

        Returns:
            The parent method's result if the type of ``expr`` is not exactly
            ``VectorType``; ``None`` otherwise, which tells the caller to emit
            vector code itself.
        """
        expr_type = get_type_of_expression(expr)
        if type(expr_type) is not VectorType:
            parent_method = getattr(super(VectorizedCustomSympyPrinter, self), func_name)
            return parent_method(expr, *args, **kwargs)

        # Vector expressions must match the width of the configured instruction set.
        assert self.instruction_set['width'] == expr_type.width
        return None

    def _print_Function(self, expr):
        """Print function-like nodes, using vector instructions where applicable.

        Handled specially:

        * ``vector_memory_access``: aligned or unaligned load of the address of the argument.
        * ``cast_func`` to a ``VectorType``: builds a vector from a tuple of values
          or broadcasts a single value (separate templates for boolean arguments).
        * ``fast_division``, ``fast_sqrt``, ``fast_inv_sqrt``.
        * ``vec_any`` / ``vec_all``: reduced with the ``'any'`` / ``'all'`` template
          when the argument is a vector, printed unchanged otherwise.

        Everything else is passed to the parent class.
        """
        if isinstance(expr, vector_memory_access):
            # Only the accessed expression and the alignment flag are needed here.
            arg, _, aligned, _, _ = expr.args
            instruction = self.instruction_set['loadA'] if aligned else self.instruction_set['loadU']
            return instruction.format("& " + self._print(arg))
        elif isinstance(expr, cast_func):
            arg, data_type = expr.args
            if type(data_type) is VectorType:
                if isinstance(arg, sp.Tuple):
                    is_boolean = get_type_of_expression(arg[0]) == create_type("bool")
                    printed_args = [self._print(a) for a in arg]
                    instruction = 'makeVecBool' if is_boolean else 'makeVec'
                    return self.instruction_set[instruction].format(*printed_args)
                else:
                    is_boolean = get_type_of_expression(arg) == create_type("bool")
                    instruction = 'makeVecConstBool' if is_boolean else 'makeVecConst'
                    return self.instruction_set[instruction].format(self._print(arg))
            # casts to non-vector types fall through to the parent printer below
        elif expr.func == fast_division:
            result = self._scalarFallback('_print_Function', expr)
            if not result:
                result = self.instruction_set['/'].format(self._print(expr.args[0]), self._print(expr.args[1]))
            return result
        elif expr.func == fast_sqrt:
            return "({})".format(self._print(sp.sqrt(expr.args[0])))
        elif expr.func == fast_inv_sqrt:
            result = self._scalarFallback('_print_Function', expr)
            if result:
                # Scalar case: reuse the parent's output instead of printing it a second time.
                return result
            if self.instruction_set['rsqrt']:
                return self.instruction_set['rsqrt'].format(self._print(expr.args[0]))
            else:
                # No reciprocal-sqrt template available: print 1/sqrt(x) instead.
                return "({})".format(self._print(1 / sp.sqrt(expr.args[0])))
        elif isinstance(expr, vec_any):
            expr_type = get_type_of_expression(expr.args[0])
            if type(expr_type) is not VectorType:
                return self._print(expr.args[0])
            else:
                return self.instruction_set['any'].format(self._print(expr.args[0]))
        elif isinstance(expr, vec_all):
            expr_type = get_type_of_expression(expr.args[0])
            if type(expr_type) is not VectorType:
                return self._print(expr.args[0])
            else:
                return self.instruction_set['all'].format(self._print(expr.args[0]))

        return super(VectorizedCustomSympyPrinter, self)._print_Function(expr)

    def _print_And(self, expr):
        result = self._scalarFallback('_print_And', expr)
        if result:
            return result

Martin Bauer's avatar
Martin Bauer committed
571
572
573
574
        arg_strings = [self._print(a) for a in expr.args]
        assert len(arg_strings) > 0
        result = arg_strings[0]
        for item in arg_strings[1:]:
Martin Bauer's avatar
Martin Bauer committed
575
            result = self.instruction_set['&'].format(result, item)
576
577
578
579
580
581
582
        return result

    def _print_Or(self, expr):
        result = self._scalarFallback('_print_Or', expr)
        if result:
            return result

Martin Bauer's avatar
Martin Bauer committed
583
584
585
586
        arg_strings = [self._print(a) for a in expr.args]
        assert len(arg_strings) > 0
        result = arg_strings[0]
        for item in arg_strings[1:]:
Martin Bauer's avatar
Martin Bauer committed
587
            result = self.instruction_set['|'].format(result, item)
588
589
        return result

    def _print_Add(self, expr, order=None):
        """Print a sum; vector sums are chained through the ``'+'`` and ``'-'`` templates.

        Each ``sp.Mul`` term is printed without its sign via
        ``_print_Mul(..., inside_add=True)``, so negative terms can be emitted as
        subtractions. Terms with positive sign are placed first.

        Args:
            expr: the sum to print.
            order: accepted for signature compatibility; not used here.
        """
        result = self._scalarFallback('_print_Add', expr)
        if result:
            return result

        summands = []
        for term in expr.args:
            if term.func == sp.Mul:
                sign, t = self._print_Mul(term, inside_add=True)
            else:
                t = self._print(term)
                sign = 1
            summands.append(self.SummandInfo(sign, t))
        # Use positive terms first
        summands.sort(key=lambda e: e.sign, reverse=True)
        # if no positive term exists, prepend a zero
        # NOTE(review): the zero is the plain literal "0", not a 'makeVecConst' value -
        # confirm that the '-' templates accept it as first operand.
        if summands[0].sign == -1:
            summands.insert(0, self.SummandInfo(1, "0"))

        assert len(summands) >= 2
        processed = summands[0].term
        for summand in summands[1:]:
            func = self.instruction_set['-'] if summand.sign == -1 else self.instruction_set['+']
            processed = func.format(processed, summand.term)
        return processed

    def _print_Pow(self, expr):
617
618
619
        result = self._scalarFallback('_print_Pow', expr)
        if result:
            return result
620

621
        one = self.instruction_set['makeVecConst'].format(1.0)
622

623
624
        if expr.exp.is_integer and expr.exp.is_number and 0 < expr.exp < 8:
            return "(" + self._print(sp.Mul(*[expr.base] * expr.exp, evaluate=False)) + ")"
625
        elif expr.exp == -1:
626
            one = self.instruction_set['makeVecConst'].format(1.0)
627
628
629
            return self.instruction_set['/'].format(one, self._print(expr.base))
        elif expr.exp == 0.5:
            return self.instruction_set['sqrt'].format(self._print(expr.base))
630
631
632
        elif expr.exp == -0.5:
            root = self.instruction_set['sqrt'].format(self._print(expr.base))
            return self.instruction_set['/'].format(one, root)
633
634
635
        elif expr.exp.is_integer and expr.exp.is_number and - 8 < expr.exp < 0:
            return self.instruction_set['/'].format(one,
                                                    self._print(sp.Mul(*[expr.base] * (-expr.exp), evaluate=False)))
636
        else:
637
            raise ValueError("Generic exponential not supported: " + str(expr))
    def _print_Mul(self, expr, inside_add=False):
        """Print a product; vector products are chained through the ``'*'`` and ``'/'`` templates.

        A negative numeric coefficient is split off as a sign. Factors that are
        commutative powers with a negative rational exponent are collected into a
        denominator; all other factors form the numerator.

        Args:
            expr: the product to print.
            inside_add: when true, the sign is returned separately instead of being
                multiplied into the printed code.

        Returns:
            The printed product as a string. If ``inside_add`` is true, a
            ``(sign, code)`` tuple where ``sign`` is 1 or -1.
        """
        # noinspection PyProtectedMember
        from sympy.core.mul import _keep_coeff

        result = self._scalarFallback('_print_Mul', expr)
        if result:
            # The scalar printer already includes the sign in its output. Keep the
            # (sign, code) contract so that _print_Add can unpack the result.
            return (1, result) if inside_add else result

        c, e = expr.as_coeff_Mul()
        if c < 0:
            expr = _keep_coeff(-c, e)
            sign = -1
        else:
            sign = 1

        a = []  # items in the numerator
        b = []  # items that are in the denominator (if any)

        # Gather args for numerator/denominator
        for item in expr.as_ordered_factors():
            if item.is_commutative and item.is_Pow and item.exp.is_Rational and item.exp.is_negative:
                if item.exp != -1:
                    b.append(sp.Pow(item.base, -item.exp, evaluate=False))
                else:
                    b.append(sp.Pow(item.base, -item.exp))
            else:
                a.append(item)

        a = a or [S.One]

        a_str = [self._print(x) for x in a]
        b_str = [self._print(x) for x in b]

        multiply = self.instruction_set['*']
        result = a_str[0]
        for item in a_str[1:]:
            result = multiply.format(result, item)

        if b_str:
            denominator_str = b_str[0]
            for item in b_str[1:]:
                denominator_str = multiply.format(denominator_str, item)
            result = self.instruction_set['/'].format(result, denominator_str)

        if inside_add:
            return sign, result
        if sign < 0:
            return multiply.format(self._print(S.NegativeOne), result)
        return result

    def _print_Relational(self, expr):
691
692
693
        result = self._scalarFallback('_print_Relational', expr)
        if result:
            return result
Martin Bauer's avatar
Martin Bauer committed
694
        return self.instruction_set[expr.rel_op].format(self._print(expr.lhs), self._print(expr.rhs))
    def _print_Equality(self, expr):
697
698
699
        result = self._scalarFallback('_print_Equality', expr)
        if result:
            return result
Martin Bauer's avatar
Martin Bauer committed
700
        return self.instruction_set['=='].format(self._print(expr.lhs), self._print(expr.rhs))
    def _print_Piecewise(self, expr):
        """Print a piecewise expression; vector branches are merged with the ``'blendv'`` template.

        Printing starts from the default branch (the last argument) and wraps it
        with the remaining branches in reverse order. A branch whose condition is
        a ``cast_func`` around a ``bool``-typed expression is printed as a C
        ternary on the inner expression; if ``KERNCRAFT_NO_TERNARY_MODE`` is set,
        such a branch is skipped and a warning is printed instead. All other
        branches use ``'blendv'``.

        Raises:
            ValueError: if the first argument of the last branch's condition is not ``True``.
        """
        result = self._scalarFallback('_print_Piecewise', expr)
        if result:
            return result

        if expr.args[-1].cond.args[0] is not sp.sympify(True):
            # We need the last conditional to be a True, otherwise the resulting
            # function may not return a result.
            raise ValueError("All Piecewise expressions must contain an "
                             "(expr, True) statement to be used as a default "
                             "condition. Without one, the generated "
                             "expression may not evaluate to anything under "
                             "some condition.")

        result = self._print(expr.args[-1][0])
        for true_expr, condition in reversed(expr.args[:-1]):
            if isinstance(condition, cast_func) and get_type_of_expression(condition.args[0]) == create_type("bool"):
                if not KERNCRAFT_NO_TERNARY_MODE:
                    result = "(({}) ? ({}) : ({}))".format(self._print(condition.args[0]), self._print(true_expr),
                                                           result)
                else:
                    print("Warning - skipping ternary op")
            else:
                # noinspection SpellCheckingInspection
                result = self.instruction_set['blendv'].format(result, self._print(true_expr), self._print(condition))
        return result