cbackend.py 40.3 KB
Newer Older
1
import re
Martin Bauer's avatar
Martin Bauer committed
2
from collections import namedtuple
Michael Kuron's avatar
Michael Kuron committed
3
import hashlib
4
from typing import Set
5

6
import numpy as np
7
8
import sympy as sp
from sympy.core import S
Michael Kuron's avatar
Michael Kuron committed
9
from sympy.core.cache import cacheit
10
from sympy.logic.boolalg import BooleanFalse, BooleanTrue
11
from sympy.functions.elementary.trigonometric import TrigonometricFunction, InverseTrigonometricFunction
12
from sympy.functions.elementary.hyperbolic import HyperbolicFunction
13

14
from pystencils.astnodes import KernelFunction, LoopOverCoordinate, Node
15
from pystencils.cpu.vectorization import vec_all, vec_any, CachelineSize
Markus Holzer's avatar
Markus Holzer committed
16
17
18
from pystencils.typing import (
    PointerType, VectorType, CastFunc, create_type, get_type_of_expression,
    ReinterpretCastFunc, VectorMemoryAccess, BasicType, TypedSymbol)
Jan Hönig's avatar
Jan Hönig committed
19
from pystencils.enums import Backend
Martin Bauer's avatar
Martin Bauer committed
20
from pystencils.fast_approximation import fast_division, fast_inv_sqrt, fast_sqrt
Markus Holzer's avatar
Markus Holzer committed
21
from pystencils.functions import DivFunc, AddressOf
Martin Bauer's avatar
Martin Bauer committed
22
from pystencils.integer_functions import (
Stephan Seitz's avatar
Stephan Seitz committed
23
24
    bit_shift_left, bit_shift_right, bitwise_and, bitwise_or, bitwise_xor,
    int_div, int_power_of_2, modulo_ceil)
25

Martin Bauer's avatar
Martin Bauer committed
26
try:
27
    from sympy.printing.c import C99CodePrinter as CCodePrinter  # for sympy versions > 1.6
Martin Bauer's avatar
Martin Bauer committed
28
except ImportError:
29
    from sympy.printing.ccode import C99CodePrinter as CCodePrinter
Martin Bauer's avatar
Martin Bauer committed
30

31
# Public API of this module.
__all__ = [
    'generate_c',
    'CustomCodeNode',
    'PrintNode',
    'get_headers',
    'CustomSympyPrinter',
]
32

33
34
35

# A header must be spelled either as a system include in angle brackets (<cmath>) or as a local
# include in double quotes ("myheader.h"); the opening and closing delimiters have to match.
HEADER_REGEX = re.compile(r'^(?:<[^<>]*>|"[^"]*")$')

Martin Bauer's avatar
Fixes    
Martin Bauer committed
36

37
38
def generate_c(ast_node: Node,
               signature_only: bool = False,
               dialect: Backend = Backend.C,
               custom_backend=None,
               with_globals=True) -> str:
    """Prints an abstract syntax tree node as C or CUDA code.

    This function does not need to distinguish for most AST nodes between C, C++ or CUDA code, it just prints 'C-like'
    code as encoded in the abstract syntax tree (AST). The AST is built differently for C or CUDA by calling different
    create_kernel functions.

    As a side effect, the symbols defined by all global declarations found in the tree are recorded in
    ``ast_node.global_variables``.

    Args:
        ast_node: ast representation of kernel
        signature_only: generate signature without function body
        dialect: `Backend`: 'C' or 'CUDA'
        custom_backend: use own custom printer for code generation
        with_globals: enable usage of global variables

    Returns:
        C-like code for the ast node and its descendants

    Raises:
        ValueError: if no custom backend is given and ``dialect`` is neither C nor CUDA
    """
    global_declarations = get_global_declarations(ast_node)
    for d in global_declarations:
        if hasattr(ast_node, "global_variables"):
            ast_node.global_variables.update(d.symbols_defined)
        else:
            # Copy: ``symbols_defined`` may be the declaration's own internal set (CustomCodeNode returns
            # ``self._symbols_defined``); storing it directly would let the update() above mutate it.
            ast_node.global_variables = set(d.symbols_defined)

    if custom_backend:
        printer = custom_backend
    elif dialect == Backend.C:
        # TODO Vectorization Revamp: instruction_set should not be just slapped on ast
        instruction_set = getattr(ast_node, 'instruction_set', None)
        printer = CBackend(signature_only=signature_only,
                           vector_instruction_set=instruction_set)
    elif dialect == Backend.CUDA:
        from pystencils.backends.cuda_backend import CudaBackend
        printer = CudaBackend(signature_only=signature_only)
    else:
        raise ValueError(f'Unknown {dialect=}')

    code = printer(ast_node)
    if not signature_only and isinstance(ast_node, KernelFunction):
        if with_globals and global_declarations:
            # Each declaration is prepended, so they appear above the kernel in reverse iteration order.
            code = "\n" + code
            for declaration in global_declarations:
                code = printer(declaration) + "\n" + code

    return code


def get_global_declarations(ast):
    """Collect the global declarations required anywhere in ``ast``.

    The tree is walked depth-first (a node before its children) through each node's ``args``;
    the ``required_global_declarations`` of every node that has them are gathered. Duplicates are
    removed and the result is returned as a list sorted by ``str``.
    """
    def walk(node):
        # A node's own declarations come first, then those of its descendants.
        yield from getattr(node, "required_global_declarations", ())
        for child in getattr(node, "args", ()):
            yield from walk(child)

    unique_declarations = set(walk(ast))
    return sorted(unique_declarations, key=str)
103
104


Martin Bauer's avatar
Martin Bauer committed
105
106
def get_headers(ast_node: Node) -> Set[str]:
    """Return a set of header files, necessary to compile the printed C-like code.

    Headers are collected from ``ast_node`` and every descendant reachable through ``args`` that is a
    sympy expression or AST node, from a kernel's vector instruction set, and from every global
    declaration required anywhere in the tree.

    Raises:
        AssertionError: if a header is not spelled as ``"..."`` or ``<...>``.
    """
    headers = _collect_local_headers(ast_node)

    # get_global_declarations already walks the whole tree, so a single call here covers the declarations
    # required by every descendant (previously it was repeated at each recursion level: quadratic work).
    for g in get_global_declarations(ast_node):
        if isinstance(g, Node):
            headers.update(get_headers(g))

    for h in headers:
        assert HEADER_REGEX.match(h), f'header /{h}/ does not follow the pattern /"..."/ or /<...>/'

    return headers


def _collect_local_headers(ast_node) -> Set[str]:
    """Return headers declared by ``ast_node`` and its ``args`` descendants, ignoring global declarations."""
    headers = set()

    if isinstance(ast_node, KernelFunction) and ast_node.instruction_set:
        headers.update(ast_node.instruction_set['headers'])

    if hasattr(ast_node, 'headers'):
        headers.update(ast_node.headers)
    for a in ast_node.args:
        if isinstance(a, (sp.Expr, Node)):
            headers.update(_collect_local_headers(a))

    return headers
126
127


128
129
# --------------------------------------- Backend Specific Nodes -------------------------------------------------------

130
# TODO future CustomCodeNode should not be backend specific move it elsewhere
131
class CustomCodeNode(Node):
    """AST node that wraps a verbatim snippet of code.

    Args:
        code: source text that is emitted as-is (a leading newline is prepended)
        symbols_read: symbols the snippet reads
        symbols_defined: symbols the snippet defines
        parent: optional parent AST node
    """

    def __init__(self, code, symbols_read, symbols_defined, parent=None):
        super(CustomCodeNode, self).__init__(parent=parent)
        self._code = "\n" + code
        self._symbols_read = set(symbols_read)
        self._symbols_defined = set(symbols_defined)
        # Extra headers needed by the snippet (e.g. "<iostream>"); picked up by get_headers().
        self.headers = []

    def get_code(self, dialect, vector_instruction_set, print_arg):
        """Return the stored code.

        The arguments are unused here; they are what the printer passes, so subclasses can emit
        dialect- or instruction-set-specific code.
        """
        return self._code

    @property
    def args(self):
        return []

    @property
    def symbols_defined(self):
        return self._symbols_defined

    @property
    def undefined_symbols(self):
        return self._symbols_read - self._symbols_defined

    def __eq__(self, other):
        # Previously misspelled ``__eq___``, so equality silently fell back to identity. That was
        # inconsistent with __hash__ and defeated set()-based de-duplication of global declarations.
        if not isinstance(other, CustomCodeNode):
            return NotImplemented
        return self._code == other._code

    def __hash__(self):
        return hash(self._code)

160

161
class PrintNode(CustomCodeNode):
    """Custom code node that prints ``<name>  =  <value>`` of a symbol to ``std::cout``."""

    # noinspection SpellCheckingInspection
    def __init__(self, symbol_to_print):
        name = symbol_to_print.name
        statement = f'\nstd::cout << "{name}  =  " << {name} << std::endl; \n'
        super().__init__(statement, symbols_read=[symbol_to_print], symbols_defined=set())
        # std::cout is declared in <iostream>.
        self.headers.append("<iostream>")
167
168


Michael Kuron's avatar
Michael Kuron committed
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
class CFunction(TypedSymbol):
    """Typed symbol that names a function, built from a ``function`` name and a ``dtype``.

    CBackend._print_KernelFunction skips kernel parameters whose symbol is exactly of this type,
    so such symbols do not appear in the printed function signature.
    """

    def __new__(cls, function, dtype):
        # Construction goes through the cacheit-wrapped stage-2 constructor below, so repeated calls
        # with the same (cls, function, dtype) can be served from sympy's cache.
        return CFunction.__xnew_cached_(cls, function, dtype)

    def __new_stage2__(cls, function, dtype):
        # Uncached constructor: delegates to TypedSymbol's own __xnew__.
        return super(CFunction, cls).__xnew__(cls, function, dtype)

    # Uncached and cached constructor variants, following the __xnew__/__xnew_cached_ pattern
    # used with sympy's cacheit.
    __xnew__ = staticmethod(__new_stage2__)
    __xnew_cached_ = staticmethod(cacheit(__new_stage2__))

    def __getnewargs__(self):
        # Positional arguments handed to __new__ when the object is unpickled/copied.
        return self.name, self.dtype

    def __getnewargs_ex__(self):
        # Same as __getnewargs__, in the (args, kwargs) form preferred by pickle protocol 2+.
        return (self.name, self.dtype), {}


186
187
# ------------------------------------------- Printer ------------------------------------------------------------------

188

Martin Bauer's avatar
Martin Bauer committed
189
190
# noinspection PyPep8Naming
class CBackend:
    """Printer that turns a pystencils AST into C-like source code.

    Nodes are dispatched by type: for a node the first method ``_print_<ClassName>`` found along
    ``type(node).__mro__`` is called; plain strings are returned unchanged. Sympy expressions inside
    the nodes are delegated to ``self.sympy_printer``.

    Args:
        sympy_printer: printer for sympy expressions; defaults to a (vectorized) CustomSympyPrinter
        signature_only: print only the kernel function signature, without body
        vector_instruction_set: dictionary of vector instruction templates, or None for scalar code
        dialect: backend dialect (``Backend.C`` or ``Backend.CUDA``)
    """

    def __init__(self, sympy_printer=None, signature_only=False, vector_instruction_set=None, dialect=Backend.C):
        if sympy_printer is None:
            if vector_instruction_set is not None:
                self.sympy_printer = VectorizedCustomSympyPrinter(vector_instruction_set)
            else:
                self.sympy_printer = CustomSympyPrinter()
        else:
            self.sympy_printer = sympy_printer

        self._vector_instruction_set = vector_instruction_set
        self._indent = "   "
        self._dialect = dialect
        self._signatureOnly = signature_only
        # Shared with the sympy printer: values stored here (e.g. by _print_LoopOverCoordinate) are passed
        # as keyword arguments when instruction-set templates are formatted.
        self._kwargs = {}
        self.sympy_printer._kwargs = self._kwargs

    def __call__(self, node):
        """Print ``node`` and return the code as a string.

        ``VectorType.instruction_set`` is class-level state; it is set to this printer's instruction set
        while printing and restored afterwards, also when printing raises.
        """
        prev_is = VectorType.instruction_set
        VectorType.instruction_set = self._vector_instruction_set
        try:
            return str(self._print(node))
        finally:
            VectorType.instruction_set = prev_is

    def _print(self, node):
        if isinstance(node, str):
            return node
        for cls in type(node).__mro__:
            method_name = f"_print_{cls.__name__}"
            if hasattr(self, method_name):
                return getattr(self, method_name)(node)
        raise NotImplementedError(f"{self.__class__.__name__} does not support node of type {node.__class__.__name__}")

    def _print_AbstractType(self, node):
        return str(node)

    def _print_KernelFunction(self, node):
        # CFunction symbols name functions, not data, and are therefore not kernel arguments.
        function_arguments = [f"{self._print(s.symbol.dtype)} {s.symbol.name}" for s in node.get_parameters()
                              if type(s.symbol) is not CFunction]
        launch_bounds = ""
        if self._dialect == Backend.CUDA:
            max_threads = node.indexing.max_threads_per_block()
            if max_threads:
                launch_bounds = f"__launch_bounds__({max_threads}) "
        func_declaration = "FUNC_PREFIX %svoid %s(%s)" % (launch_bounds, node.function_name,
                                                          ", ".join(function_arguments))
        if self._signatureOnly:
            return func_declaration

        body = self._print(node.body)
        return func_declaration + "\n" + body

    def _print_Block(self, node):
        block_contents = "\n".join([self._print(child) for child in node.args])
        return "{\n%s\n}" % (self._indent + self._indent.join(block_contents.splitlines(True)))

    def _print_PragmaBlock(self, node):
        return f"{node.pragma_line}\n{self._print_Block(node)}"

    def _print_LoopOverCoordinate(self, node):
        counter_symbol = node.loop_counter_name
        start = f"int64_t {counter_symbol} = {self.sympy_printer.doprint(node.start)}"
        condition = f"{counter_symbol} < {self.sympy_printer.doprint(node.stop)}"
        update = f"{counter_symbol} += {self.sympy_printer.doprint(node.step)}"
        loop_str = f"for ({start}; {condition}; {update})"
        # Made available to instruction-set templates formatted while printing the loop body.
        self._kwargs['loop_counter'] = counter_symbol
        self._kwargs['loop_stop'] = node.stop

        prefix = "\n".join(node.prefix_lines)
        if prefix:
            prefix += "\n"
        return f"{prefix}{loop_str}\n{self._print(node.body)}"

    def _print_SympyAssignment(self, node):
        if node.is_declaration:
            if node.use_auto:
                data_type = 'auto '
            else:
                prefix = 'const ' if node.is_const else ''
                data_type = prefix + self._print(node.lhs.dtype).replace(' const', '') + " "

            return "%s%s = %s;" % (data_type,
                                   self.sympy_printer.doprint(node.lhs),
                                   self.sympy_printer.doprint(node.rhs))
        else:
            lhs_type = get_type_of_expression(node.lhs)  # TODO: this should have been typed
            printed_mask = ""
            if type(lhs_type) is VectorType and isinstance(node.lhs, CastFunc):
                # Vector store: the lhs cast carries the store parameters.
                arg, data_type, aligned, nontemporal, mask, stride = node.lhs.args
                instr = 'storeU'
                if aligned:
                    instr = 'stream' if nontemporal and 'stream' in self._vector_instruction_set else 'storeA'
                if mask != True:  # NOQA
                    instr = 'maskStoreA' if aligned else 'maskStoreU'
                    if instr not in self._vector_instruction_set:
                        # No native masked store: emulate as store(ptr, blendv(load(ptr), value, mask)).
                        self._vector_instruction_set[instr] = self._vector_instruction_set['store' + instr[-1]].format(
                            '{0}', self._vector_instruction_set['blendv'].format(
                                self._vector_instruction_set['load' + instr[-1]].format('{0}', **self._kwargs),
                                '{1}', '{2}', **self._kwargs), **self._kwargs)
                    printed_mask = self.sympy_printer.doprint(mask)
                    # SSE/AVX masked stores take an integer mask vector; reinterpret the floating-point mask.
                    if data_type.base_type.c_name == 'double':
                        if self._vector_instruction_set['double'] == '__m256d':
                            printed_mask = f"_mm256_castpd_si256({printed_mask})"
                        elif self._vector_instruction_set['double'] == '__m128d':
                            printed_mask = f"_mm_castpd_si128({printed_mask})"
                    elif data_type.base_type.c_name == 'float':
                        if self._vector_instruction_set['float'] == '__m256':
                            printed_mask = f"_mm256_castps_si256({printed_mask})"
                        elif self._vector_instruction_set['float'] == '__m128':
                            printed_mask = f"_mm_castps_si128({printed_mask})"

                rhs_type = get_type_of_expression(node.rhs)
                if type(rhs_type) is not VectorType:
                    raise ValueError(f'Cannot vectorize {node.rhs} of type {rhs_type} inside of the pretty printer! '
                                     f'This should have happen earlier!')
                rhs = node.rhs

                ptr = "&" + self.sympy_printer.doprint(node.lhs.args[0])

                if stride != 1:
                    # Strided (scatter) store.
                    instr = 'maskStoreS' if mask != True else 'storeS'  # NOQA
                    return self._vector_instruction_set[instr].format(ptr, self.sympy_printer.doprint(rhs),
                                                                      stride, printed_mask, **self._kwargs) + ';'

                pre_code = ''
                if nontemporal and 'cachelineZero' in self._vector_instruction_set:
                    # Zero the cache line before streaming into it when the store address starts a cache line
                    # and the line lies within the field (offset in elements + line length < field size).
                    first_cond = f"((uintptr_t) {ptr} & {CachelineSize.mask_symbol}) == 0"
                    offset = sp.Add(*[sp.Symbol(LoopOverCoordinate.get_loop_counter_name(i))
                                      * node.lhs.args[0].field.spatial_strides[i] for i in
                                      range(len(node.lhs.args[0].field.spatial_strides))])
                    if stride == 1:
                        offset = offset.subs({node.lhs.args[0].field.spatial_strides[0]: 1})
                    size = sp.Mul(*node.lhs.args[0].field.spatial_shape)
                    element_size = 8 if data_type.base_type.c_name == 'double' else 4
                    size_cond = f"({offset} + {CachelineSize.symbol/element_size}) < {size}"
                    pre_code = f"if ({first_cond} && {size_cond}) " + "{\n\t" + \
                        self._vector_instruction_set['cachelineZero'].format(ptr, **self._kwargs) + ';\n}\n'

                code = self._vector_instruction_set[instr].format(ptr, self.sympy_printer.doprint(rhs),
                                                                  printed_mask, **self._kwargs) + ';'
                # True when ptr addresses the last vector of a cache line (see _print_CachelineSize).
                flushcond = f"((uintptr_t) {ptr} & {CachelineSize.mask_symbol}) == {CachelineSize.last_symbol}"
                if nontemporal and 'flushCacheline' in self._vector_instruction_set:
                    code2 = self._vector_instruction_set['flushCacheline'].format(
                        ptr, self.sympy_printer.doprint(rhs), **self._kwargs) + ';'
                    code = f"{code}\nif ({flushcond}) {{\n\t{code2}\n}}"
                elif nontemporal and 'storeAAndFlushCacheline' in self._vector_instruction_set:
                    # Evaluate the rhs once into a temporary, then either store+flush or plain store.
                    tmpvar = '_tmp_' + hashlib.sha1(self.sympy_printer.doprint(rhs).encode('ascii')).hexdigest()[:8]
                    code = 'const ' + self._print(node.lhs.dtype).replace(' const', '') + ' ' + tmpvar + ' = ' \
                        + self.sympy_printer.doprint(rhs) + ';'
                    code1 = self._vector_instruction_set[instr].format(ptr, tmpvar, printed_mask, **self._kwargs) + ';'
                    code2 = self._vector_instruction_set['storeAAndFlushCacheline'].format(ptr, tmpvar, printed_mask,
                                                                                           **self._kwargs) + ';'
                    code += f"\nif ({flushcond}) {{\n\t{code2}\n}} else {{\n\t{code1}\n}}"
                return pre_code + code
            else:
                return f"{self.sympy_printer.doprint(node.lhs)} = {self.sympy_printer.doprint(node.rhs)};"

    def _print_NontemporalFence(self, _):
        # Without a vector instruction set there is no fence to emit (and `in None` would raise).
        if self._vector_instruction_set and 'streamFence' in self._vector_instruction_set:
            return self._vector_instruction_set['streamFence'] + ';'
        else:
            return ''

    def _print_CachelineSize(self, node):
        # Emits the cache line size, its address mask, and the offset of the last vector within a line.
        if self._vector_instruction_set and 'cachelineSize' in self._vector_instruction_set:
            code = f'const size_t {node.symbol} = {self._vector_instruction_set["cachelineSize"]};\n'
            code += f'const size_t {node.mask_symbol} = {node.symbol} - 1;\n'
            vectorsize = self._vector_instruction_set['bytes']
            code += f'const size_t {node.last_symbol} = {node.symbol} - {vectorsize};\n'
            return code
        else:
            return ''

    def _temporary_memory_alignment(self, node):
        """Alignment for temporary buffers: vector width in bytes if vectorizing, else the element size."""
        if self._vector_instruction_set:
            return self._vector_instruction_set['bytes']
        return node.symbol.dtype.base_type.numpy_dtype.itemsize

    def _print_TemporaryMemoryAllocation(self, node):
        align = self._temporary_memory_alignment(node)

        # Over-allocate by `align` bytes (rounded up to a multiple of `align`) so the typed pointer can be
        # shifted by node.offset(align); _print_TemporaryMemoryFree undoes that shift before freeing.
        np_dtype = node.symbol.dtype.base_type.numpy_dtype
        required_size = np_dtype.itemsize * node.size + align
        size = modulo_ceil(required_size, align)
        code = "#if defined(_MSC_VER)\n"
        code += "{dtype} {name}=({dtype})_aligned_malloc({size}, {align}) + {offset};\n"
        code += "#elif __cplusplus >= 201703L || __STDC_VERSION__ >= 201112L\n"
        code += "{dtype} {name}=({dtype})aligned_alloc({align}, {size}) + {offset};\n"
        code += "#else\n"
        code += "{dtype} {name};\n"
        code += "posix_memalign((void**) &{name}, {align}, {size});\n"
        code += "{name} += {offset};\n"
        code += "#endif"
        return code.format(dtype=node.symbol.dtype,
                           name=self.sympy_printer.doprint(node.symbol.name),
                           size=self.sympy_printer.doprint(size),
                           offset=int(node.offset(align)),
                           align=align)

    def _print_TemporaryMemoryFree(self, node):
        align = self._temporary_memory_alignment(node)

        code = "#if defined(_MSC_VER)\n"
        code += "_aligned_free(%s - %d);\n" % (self.sympy_printer.doprint(node.symbol.name), node.offset(align))
        code += "#else\n"
        code += "free(%s - %d);\n" % (self.sympy_printer.doprint(node.symbol.name), node.offset(align))
        code += "#endif"
        return code

    def _print_SkipIteration(self, _):
        return "continue;"

    def _print_CustomCodeNode(self, node):
        return node.get_code(self._dialect, self._vector_instruction_set, print_arg=self.sympy_printer._print)

    def _print_SourceCodeComment(self, node):
        return f"/* {node.text} */"

    def _print_EmptyLine(self, node):
        return ""

    def _print_Conditional(self, node):
        # Statically decidable conditions collapse to the taken branch.
        if type(node.condition_expr) is BooleanTrue:
            return self._print_Block(node.true_block)
        elif type(node.condition_expr) is BooleanFalse:
            # The else branch is optional; without it, nothing remains to print.
            return self._print_Block(node.false_block) if node.false_block else ""
        cond_type = get_type_of_expression(node.condition_expr)
        if isinstance(cond_type, VectorType):
            raise ValueError("Problem with Conditional inside vectorized loop - use vec_any or vec_all")
        condition_expr = self.sympy_printer.doprint(node.condition_expr)
        true_block = self._print_Block(node.true_block)
        result = f"if ({condition_expr})\n{true_block} "
        if node.false_block:
            false_block = self._print_Block(node.false_block)
            result += f"else {false_block}"
        return result

434
435
436
437

# ------------------------------------------ Helper function & classes -------------------------------------------------


Martin Bauer's avatar
Martin Bauer committed
438
# noinspection PyPep8Naming
439
class CustomSympyPrinter(CCodePrinter):
    """C99 sympy printer extended with pystencils-specific printing.

    Adds support for casts, reinterpret casts, address-of, fast approximations, integer/bitwise helper
    functions, typed literals and ternary-based Min/Max.
    """

    # Integer helper functions that are printed as binary infix operators.
    _INFIX_FUNCTIONS = {
        bitwise_xor: '^',
        bit_shift_right: '>>',
        bit_shift_left: '<<',
        bitwise_or: '|',
        bitwise_and: '&',
    }

    def __init__(self):
        super().__init__()

    def _print_Pow(self, expr):
        """Delegate to the C99 printer; powers without free symbols must have been simplified beforehand.

        Raises:
            NotImplementedError: if ``expr`` has no free symbols.
        """
        if not expr.free_symbols:
            raise NotImplementedError("This pow should be simplified already?")
        return super()._print_Pow(expr)

    # TODO don't print ones in sp.Mul

    def _print_Rational(self, expr):
        """Evaluate all rationals i.e. print 0.25 instead of 1.0/4.0"""
        return str(expr.evalf(17))

    def _print_Equality(self, expr):
        """Equality operator is not printable in default printer"""
        return '((' + self._print(expr.lhs) + ") == (" + self._print(expr.rhs) + '))'

    def _print_Piecewise(self, expr):
        """Print piecewise in one line (remove newlines)"""
        result = super()._print_Piecewise(expr)
        return result.replace("\n", "")

    def _print_Abs(self, expr):
        """Use integer ``abs`` for integer arguments and ``fabs`` otherwise."""
        if expr.args[0].is_integer:
            return f'abs({self._print(expr.args[0])})'
        else:
            return f'fabs({self._print(expr.args[0])})'

    def _print_AbstractType(self, node):
        return str(node)

    def _print_Function(self, expr):
        """Print pystencils-specific functions; any other function is printed as ``name(arg, ...)``."""
        if hasattr(expr, 'to_c'):
            return expr.to_c(self._print)
        if isinstance(expr, ReinterpretCastFunc):
            arg, data_type = expr.args
            return f"*(({self._print(PointerType(data_type, restrict=False))})(& {self._print(arg)}))"
        elif isinstance(expr, AddressOf):
            assert len(expr.args) == 1, "address_of must only have one argument"
            return f"&({self._print(expr.args[0])})"
        elif isinstance(expr, CastFunc):
            arg, data_type = expr.args
            if arg.is_Number and not isinstance(arg, (sp.core.numbers.Infinity, sp.core.numbers.NegativeInfinity)):
                return self._typed_number(arg, data_type)
            elif isinstance(arg, (InverseTrigonometricFunction, TrigonometricFunction, HyperbolicFunction)) \
                    and data_type == BasicType('float32'):
                # Switch to the single-precision variant by appending 'f' to the C function name.
                known = self.known_functions[arg.__class__.__name__.lower()]
                code = self._print(arg)
                return code.replace(known, f"{known}f")
            elif isinstance(arg, (sp.Pow, sp.exp)) and data_type == BasicType('float32'):
                # Same for powers/exponentials: rename the first known function found in the printed code.
                known = ['sqrt', 'cbrt', 'pow', 'exp']
                code = self._print(arg)
                for k in known:
                    if k in code:
                        return code.replace(k, f'{k}f')
                raise ValueError(f"{code} doesn't give {known=} function back.")
            else:
                return f"(({data_type})({self._print(arg)}))"
        elif isinstance(expr, fast_division):
            return f"({self._print(expr.args[0] / expr.args[1])})"
        elif isinstance(expr, fast_sqrt):
            return f"({self._print(sp.sqrt(expr.args[0]))})"
        elif isinstance(expr, (vec_any, vec_all)):
            return self._print(expr.args[0])
        elif isinstance(expr, fast_inv_sqrt):
            return f"({self._print(1 / sp.sqrt(expr.args[0]))})"
        elif isinstance(expr, sp.Abs):
            return f"abs({self._print(expr.args[0])})"
        elif isinstance(expr, sp.Mod):
            if expr.args[0].is_integer and expr.args[1].is_integer:
                return f"({self._print(expr.args[0])} % {self._print(expr.args[1])})"
            else:
                return f"fmod({self._print(expr.args[0])}, {self._print(expr.args[1])})"
        elif expr.func in self._INFIX_FUNCTIONS:
            return f"({self._print(expr.args[0])} {self._INFIX_FUNCTIONS[expr.func]} {self._print(expr.args[1])})"
        elif expr.func == int_power_of_2:
            return f"(1 << ({self._print(expr.args[0])}))"
        elif expr.func == int_div:
            return f"(({self._print(expr.args[0])}) / ({self._print(expr.args[1])}))"
        elif expr.func == DivFunc:
            # NOTE(review): the numerator printed here is ``expr.divisor``; confirm this matches the argument
            # naming of DivFunc in pystencils.functions.
            return f'(({self._print(expr.divisor)}) / ({self._print(expr.dividend)}))'
        else:
            name = expr.name if hasattr(expr, 'name') else expr.__class__.__name__
            arg_str = ', '.join(self._print(a) for a in expr.args)
            return f'{name}({arg_str})'

    def _typed_number(self, number, dtype):
        """Print ``number`` as a C literal of type ``dtype``.

        float32 literals get an ``f`` suffix, float64 literals are given a decimal point if they lack one,
        and integer types drop a zero fractional part. Other types are printed unchanged.

        Raises:
            ValueError: if a number with a non-zero fractional part is requested as an integer.
        """
        res = self._print(number)
        if dtype.numpy_dtype == np.float32:
            return res + '.0f' if '.' not in res else res + 'f'
        elif dtype.numpy_dtype == np.float64:
            return res + '.0' if '.' not in res else res
        elif dtype.is_int():
            tokens = res.split('.')
            if len(tokens) == 1:
                return res
            elif int(tokens[1]) != 0:
                raise ValueError(f"Cannot print non-integer number {res} as an integer.")
            else:
                return tokens[0]
        else:
            return res

    def _print_ConditionalFieldAccess(self, node):
        return self._print(sp.Piecewise((node.outofbounds_value, node.outofbounds_condition), (node.access, True)))

    def _nested_ternary(self, args, comparison):
        """Print a balanced nest of C ternaries selecting the extreme element of ``args``.

        ``comparison`` is ``'>'`` for a maximum and ``'<'`` for a minimum.
        """
        if len(args) == 1:
            return self._print(args[0])
        half = len(args) // 2
        a = self._nested_ternary(args[:half], comparison)
        b = self._nested_ternary(args[half:], comparison)
        return f"(({a} {comparison} {b}) ? {a} : {b})"

    def _print_Max(self, expr):
        return self._nested_ternary(expr.args, '>')

    def _print_Min(self, expr):
        return self._nested_ternary(expr.args, '<')
577

578

Martin Bauer's avatar
Martin Bauer committed
579
# noinspection PyPep8Naming
580
581
582
class VectorizedCustomSympyPrinter(CustomSympyPrinter):
    SummandInfo = namedtuple("SummandInfo", ['sign', 'term'])

583
584
    def __init__(self, instruction_set):
        """Create a printer that emits code using the intrinsics in ``instruction_set``.

        Args:
            instruction_set: mapping from operation keys (e.g. ``'+'``, ``'/'``, ``'loadU'``,
                ``'makeVecConst'``) to format strings, plus metadata such as ``'width'``.
        """
        super().__init__()
        self.instruction_set = instruction_set
586

Martin Bauer's avatar
Martin Bauer committed
587
588
589
590
    def _scalarFallback(self, func_name, expr, *args, **kwargs):
        """Print ``expr`` with the scalar parent printer unless it is vector-typed.

        Returns:
            The printout of the parent-class method named ``func_name`` when the type of
            ``expr`` is not a VectorType; ``None`` when it is, signalling the caller to
            emit vector intrinsics instead. In the vector case the expression's width is
            asserted to match ``instruction_set['width']``.
        """
        expr_type = get_type_of_expression(expr)
        if type(expr_type) is VectorType:
            assert self.instruction_set['width'] == expr_type.width
            return None
        scalar_printer = getattr(super(VectorizedCustomSympyPrinter, self), func_name)
        return scalar_printer(expr, *args, **kwargs)

Markus Holzer's avatar
Markus Holzer committed
595
    def _print_Abs(self, expr):
        """Print ``Abs``; uses the ``'abs'`` intrinsic only for a direct vector memory load.

        All other cases (no ``'abs'`` entry, or a different operand) use the parent printer.
        """
        operand = expr.args[0]
        if 'abs' not in self.instruction_set or not isinstance(operand, VectorMemoryAccess):
            return super()._print_Abs(expr)
        return self.instruction_set['abs'].format(self._print(operand), **self._kwargs)

Markus Holzer's avatar
Markus Holzer committed
600
601
602
603
604
605
    def _typed_vectorized_number(self, expr, data_type):
        """Broadcast the number ``expr`` into a vector of ``data_type``.

        The scalar literal is printed for ``data_type.base_type``. The broadcast
        intrinsic is ``makeVecConstBool`` for bool, ``makeVecConstInt`` for integer and
        ``makeVecConst`` for any other base type.
        """
        element_type = data_type.base_type
        literal = self._typed_number(expr, element_type)
        if element_type.is_bool():
            key = 'makeVecConstBool'
        elif element_type.is_int():
            # TODO Vectorization Revamp: is int, or sint, or uint (my guess is sint)
            key = 'makeVecConstInt'
        else:
            key = 'makeVecConst'
        return self.instruction_set[key].format(literal, **self._kwargs)

    def _typed_vectorized_symbol(self, expr, data_type):
        """Broadcast the scalar TypedSymbol ``expr`` into a vector of ``data_type``.

        If the symbol's dtype differs from the vector's base type, a C-style cast to the
        base type is applied before broadcasting. The broadcast intrinsic is
        ``makeVecConstBool`` for bool, ``makeVecConstInt`` for integer and
        ``makeVecConst`` for any other base type.

        Raises:
            ValueError: if ``expr`` is not a TypedSymbol.
        """
        if not isinstance(expr, TypedSymbol):
            # Report the Python type: arbitrary sympy expressions have no ``.type`` attribute,
            # so reading it here would raise AttributeError instead of this ValueError.
            raise ValueError(f'{expr} is not a TypedSymbol. It is {type(expr)}')
        basic_data_type = data_type.base_type
        symbol = self._print(expr)
        if basic_data_type != expr.dtype:
            symbol = f'(({basic_data_type})({symbol}))'

        instruction = 'makeVecConst'
        if basic_data_type.is_bool():
            instruction = 'makeVecConstBool'
        elif basic_data_type.is_int():
            # TODO Vectorization Revamp: is int, or sint, or uint (my guess is sint)
            instruction = 'makeVecConstInt'
        return self.instruction_set[instruction].format(symbol, **self._kwargs)

    def _print_CastFunc(self, expr):
        """Print a cast; casts to a VectorType become vector-construction intrinsics.

        For a VectorType target:

        * ``sp.Tuple`` argument: one value per lane, printed with ``makeVecBool``,
          ``makeVecInt`` or ``makeVec`` depending on the type of the first lane. Integer
          lanes with a single constant increment use ``makeVecIndex`` when available.
        * numeric argument (other than +/- infinity): broadcast constant.
        * TypedSymbol argument: broadcast symbol.
        * anything else: NotImplementedError.

        Casts to non-vector types are delegated to the scalar printer.

        Raises:
            NotImplementedError: for vector casts of unsupported argument kinds.
        """
        arg, data_type = expr.args
        if type(data_type) is not VectorType:
            return self._scalarFallback('_print_Function', expr)

        # vector_memory_access is a cast_func itself so it shouldn't be directly inside a cast_func
        assert not isinstance(arg, VectorMemoryAccess)

        if isinstance(arg, sp.Tuple):
            first_lane_type = get_type_of_expression(arg[0])
            is_boolean = first_lane_type == create_type("bool")
            is_integer = first_lane_type == create_type("int")
            printed_args = [self._print(a) for a in arg]
            instruction = 'makeVecBool' if is_boolean else 'makeVecInt' if is_integer else 'makeVec'
            if instruction == 'makeVecInt' and 'makeVecIndex' in self.instruction_set:
                increments = np.array(arg)[1:] - np.array(arg)[:-1]
                # a single distinct step means the lanes form an arithmetic index sequence
                if len(set(increments)) == 1:
                    return self.instruction_set['makeVecIndex'].format(printed_args[0], increments[0],
                                                                       **self._kwargs)
            return self.instruction_set[instruction].format(*printed_args, **self._kwargs)

        if arg.is_Number and not isinstance(arg, (sp.core.numbers.Infinity, sp.core.numbers.NegativeInfinity)):
            return self._typed_vectorized_number(arg, data_type)
        if isinstance(arg, TypedSymbol):
            return self._typed_vectorized_symbol(arg, data_type)

        # data_type is a VectorType here, so its element type is what must be float32.
        # Comparing the VectorType itself with BasicType('float32') is presumably never
        # equal, which made the two specific messages below unreachable.
        is_float32 = data_type.base_type == BasicType('float32')
        if is_float32 and isinstance(arg, (InverseTrigonometricFunction, TrigonometricFunction, HyperbolicFunction)):
            raise NotImplementedError('Vectorizer is not tested for trigonometric functions yet')
        if is_float32 and isinstance(arg, sp.Pow):
            raise NotImplementedError('Vectorizer cannot print casted aka. not double pow')
        raise NotImplementedError('Vectorizer cannot cast between different datatypes')

671
    def _print_Function(self, expr):
        """Print function nodes, using vector intrinsics where the instruction set allows.

        Handled specially:

        * VectorMemoryAccess: strided (``loadS``), aligned (``loadA``) or unaligned (``loadU``)
          load from the address of the accessed element.
        * DivFunc and fast_division: the ``'/'`` intrinsic.
        * fast_sqrt: printed as ``sqrt`` of the argument.
        * fast_inv_sqrt: ``'rsqrt'`` when available, otherwise ``1 / sqrt(x)``.
        * vec_any / vec_all: the ``'any'`` / ``'all'`` reductions, with a fused
          ``(instr, rel_op)`` intrinsic for relational arguments when one exists.

        Scalar-typed operands of the division/inverse-sqrt functions, and all other
        functions, are printed by the parent printer.
        """
        if isinstance(expr, VectorMemoryAccess):
            arg, data_type, aligned, _, mask, stride = expr.args
            if stride != 1:
                return self.instruction_set['loadS'].format(f"& {self._print(arg)}", stride, **self._kwargs)
            instruction = self.instruction_set['loadA'] if aligned else self.instruction_set['loadU']
            return instruction.format(f"& {self._print(arg)}", **self._kwargs)
        elif expr.func == DivFunc:
            result = self._scalarFallback('_print_Function', expr)
            if not result:
                # expr.divisor is passed as the first operand of the '/' intrinsic
                result = self.instruction_set['/'].format(self._print(expr.divisor), self._print(expr.dividend),
                                                          **self._kwargs)
            return result
        elif expr.func == fast_division:
            result = self._scalarFallback('_print_Function', expr)
            if not result:
                result = self.instruction_set['/'].format(self._print(expr.args[0]), self._print(expr.args[1]),
                                                          **self._kwargs)
            return result
        elif expr.func == fast_sqrt:
            return f"({self._print(sp.sqrt(expr.args[0]))})"
        elif expr.func == fast_inv_sqrt:
            result = self._scalarFallback('_print_Function', expr)
            if result:
                # Scalar operand: reuse the parent printout instead of printing the whole
                # expression a second time through the fall-through below.
                return result
            if 'rsqrt' in self.instruction_set:
                return self.instruction_set['rsqrt'].format(self._print(expr.args[0]), **self._kwargs)
            return f"({self._print(1 / sp.sqrt(expr.args[0]))})"
        elif isinstance(expr, (vec_any, vec_all)):
            instr = 'any' if isinstance(expr, vec_any) else 'all'
            expr_type = get_type_of_expression(expr.args[0])
            if type(expr_type) is not VectorType:
                return self._print(expr.args[0])
            if isinstance(expr.args[0], sp.Rel):
                op = expr.args[0].rel_op
                if (instr, op) in self.instruction_set:
                    return self.instruction_set[(instr, op)].format(*[self._print(a) for a in expr.args[0].args],
                                                                    **self._kwargs)
            return self.instruction_set[instr].format(self._print(expr.args[0]), **self._kwargs)

        return super(VectorizedCustomSympyPrinter, self)._print_Function(expr)

714
715
716
717
718
    def _print_And(self, expr):
        """Print a logical conjunction.

        Scalar-typed expressions are printed by the parent printer; vector-typed ones are
        folded left-to-right with the ``'&'`` intrinsic.
        """
        scalar = self._scalarFallback('_print_And', expr)
        if scalar:
            return scalar

        operands = [self._print(a) for a in expr.args]
        assert len(operands) > 0
        combined, *remaining = operands
        for operand in remaining:
            combined = self.instruction_set['&'].format(combined, operand, **self._kwargs)
        return combined

    def _print_Or(self, expr):
        """Print a logical disjunction.

        Scalar-typed expressions are printed by the parent printer; vector-typed ones are
        folded left-to-right with the ``'|'`` intrinsic.
        """
        scalar = self._scalarFallback('_print_Or', expr)
        if scalar:
            return scalar

        operands = [self._print(a) for a in expr.args]
        assert len(operands) > 0
        combined, *remaining = operands
        for operand in remaining:
            combined = self.instruction_set['|'].format(combined, operand, **self._kwargs)
        return combined

738
    def _print_Add(self, expr, order=None):
        """Print a sum using the vector ``'+'`` / ``'-'`` intrinsics.

        Scalar-typed sums (or sums whose type cannot be determined) go to the parent
        printer. If every summand is integer-like (an int CastFunc, a sympy Integer or an
        int TypedSymbol), the ``'+int'`` / ``'-int'`` intrinsics are used instead. Positive
        summands are emitted first; if there is none, the chain starts from ``"0"``.
        """
        try:
            scalar = self._scalarFallback('_print_Add', expr)
        except Exception:
            scalar = None
        if scalar:
            return scalar

        terms = expr.args

        def is_integer_like(e):
            return ((type(e) is CastFunc and str(e.dtype) == self.instruction_set['int'])
                    or isinstance(e, sp.Integer)
                    or (type(e) is TypedSymbol and isinstance(e.dtype, BasicType) and e.dtype.is_int()))

        # special treatment for all-integer args, for loop index arithmetic until we have proper int vectorization
        suffix = ""
        # a list (not a generator) so every term is checked, exactly as before
        if all([is_integer_like(e) for e in terms]):
            cast_types = {e.dtype for e in terms if type(e) is CastFunc}
            assert len(cast_types) == 1
            common_type = cast_types.pop()
            terms = [CastFunc(e, common_type) if isinstance(e, (sp.Integer, TypedSymbol)) else e
                     for e in terms]
            suffix = "int"

        summands = []
        for term in terms:
            if term.func == sp.Mul:
                sign, printed = self._print_Mul(term, inside_add=True)
            else:
                sign, printed = 1, self._print(term)
            summands.append(self.SummandInfo(sign, printed))

        # positive terms first (stable sort keeps their relative order)
        summands.sort(key=lambda s: s.sign, reverse=True)
        # without any positive term, start the chain from zero
        if summands[0].sign == -1:
            summands.insert(0, self.SummandInfo(1, "0"))

        assert len(summands) >= 2
        processed = summands[0].term
        for summand in summands[1:]:
            operator = ('-' if summand.sign == -1 else '+') + suffix
            processed = self.instruction_set[operator].format(processed, summand.term, **self._kwargs)
        return processed

779
    def _print_Pow(self, expr):
        """Print a power with vector intrinsics.

        Scalar-typed powers go to the parent printer. A numeric exponent wrapped in a
        CastFunc is unwrapped first. Supported exponents:

        * integers in (0, 8): expanded into a product of the base,
        * -1: ``1 / base``,
        * 0.5: the ``'sqrt'`` intrinsic,
        * -0.5: ``1 / sqrt(base)``,
        * integers in (-8, 0): ``1 / (base * ... * base)``.

        Raises:
            ValueError: for any other exponent.
        """
        result = self._scalarFallback('_print_Pow', expr)
        if result:
            return result

        if isinstance(expr.exp, CastFunc) and expr.exp.args[0].is_number:
            exp = expr.exp.args[0]
        else:
            exp = expr.exp

        def vector_one():
            # Broadcast 1.0; built only on the reciprocal paths (previously it was built
            # eagerly on every call and once more in the exp == -1 branch).
            return self.instruction_set['makeVecConst'].format(1.0, **self._kwargs)

        if exp.is_integer and exp.is_number and 0 < exp < 8:
            return "(" + self._print(sp.Mul(*[expr.base] * exp, evaluate=False)) + ")"
        elif exp == -1:
            return self.instruction_set['/'].format(vector_one(), self._print(expr.base), **self._kwargs)
        elif exp == 0.5:
            return self.instruction_set['sqrt'].format(self._print(expr.base), **self._kwargs)
        elif exp == -0.5:
            root = self.instruction_set['sqrt'].format(self._print(expr.base), **self._kwargs)
            return self.instruction_set['/'].format(vector_one(), root, **self._kwargs)
        elif exp.is_integer and exp.is_number and -8 < exp < 0:
            product = self._print(sp.Mul(*[expr.base] * (-exp), evaluate=False))
            return self.instruction_set['/'].format(vector_one(), product, **self._kwargs)
        else:
            raise ValueError("Generic exponential not supported: " + str(expr))
807

Martin Bauer's avatar
Martin Bauer committed
808
809
810
811
    def _print_Mul(self, expr, inside_add=False):
        # noinspection PyProtectedMember
        from sympy.core.mul import _keep_coeff

812
813
814
        result = self._scalarFallback('_print_Mul', expr)
        if result:
            return result
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842

        c, e = expr.as_coeff_Mul()
        if c < 0:
            expr = _keep_coeff(-c, e)
            sign = -1
        else:
            sign = 1

        a = []  # items in the numerator
        b = []  # items that are in the denominator (if any)

        # Gather args for numerator/denominator
        for item in expr.as_ordered_factors():
            if item.is_commutative and item.is_Pow and item.exp.is_Rational and item.exp.is_negative:
                if item.exp != -1:
                    b.append(sp.Pow(item.base, -item.exp, evaluate=False))
                else:
                    b.append(sp.Pow(item.base, -item.exp))
            else:
                a.append(item)

        a = a or [S.One]

        a_str = [self._print(x) for x in a]
        b_str = [self._print(x) for x in b]

        result = a_str[0]
        for item in a_str[1:]: