astnodes.py 20.1 KB
Newer Older
1
import sympy as sp
2
from sympy.tensor import IndexedBase
3
from pystencils.field import Field
Martin Bauer's avatar
Martin Bauer committed
4
from pystencils.data_types import TypedSymbol, create_type, cast_func
Martin Bauer's avatar
Martin Bauer committed
5
from pystencils.sympyextensions import fast_subs
Martin Bauer's avatar
Martin Bauer committed
6
7
8
from typing import List, Set, Optional, Union, Any

# Type alias for a child of an AST node: either another Node or a plain sympy expression (cf. Node.args).
NodeOrExpr = Union['Node', sp.Expr]
9
10


11
class Node:
    """Common base class of every AST node.

    Subclasses expose their children through ``args`` and report symbol usage through
    ``symbols_defined`` and ``undefined_symbols``.
    """

    def __init__(self, parent: Optional['Node'] = None):
        self.parent = parent

    @property
    def args(self) -> List[NodeOrExpr]:
        """Children of this node: nested nodes and/or sympy expressions."""
        raise NotImplementedError()

    @property
    def symbols_defined(self) -> Set[sp.Symbol]:
        """Symbols introduced (defined) by this node."""
        raise NotImplementedError()

    @property
    def undefined_symbols(self) -> Set[sp.Symbol]:
        """Symbols this node references without defining them itself."""
        raise NotImplementedError()

    def subs(self, subs_dict) -> None:
        """In-place substitution: like sympy's ``subs``, but the AST is modified instead of copied.

        The default implementation forwards ``subs_dict`` to every child in ``args``.
        """
        for child in self.args:
            child.subs(subs_dict)

    @property
    def func(self):
        """The class of this node (mirrors sympy's ``expr.func``)."""
        return type(self)

    def atoms(self, arg_type) -> Set[Any]:
        """Collect every descendant, searched recursively, that is an instance of ``arg_type``.

        The node itself is not part of the result, only its descendants.
        """
        found = set()
        for child in self.args:
            if isinstance(child, arg_type):
                found.add(child)
            found.update(child.atoms(arg_type))
        return found


51
class Conditional(Node):
    """Conditional that maps to an 'if' statement in C/C++.

    Try to avoid using this node inside of loops, since currently this construction can not be vectorized.
    Consider using assignments with sympy.Piecewise in this case.

    Args:
        condition_expr: sympy relational (or boolean) expression
        true_block: block which is run if the condition is true; a non-Block node is wrapped into a Block
        false_block: optional block which is run if the condition is false
    """

    def __init__(self, condition_expr: sp.Basic, true_block: Union['Block', 'SympyAssignment'],
                 false_block: Optional['Block'] = None) -> None:
        super(Conditional, self).__init__(parent=None)

        assert condition_expr.is_Boolean or condition_expr.is_Relational
        self.condition_expr = condition_expr

        self.true_block = self._adopt_child(true_block)
        self.false_block = self._adopt_child(false_block)

    def _adopt_child(self, child):
        """Wrap ``child`` into a Block unless it already is one, and make this node its parent."""
        if child is None:
            return None
        block = child if isinstance(child, Block) else Block([child])
        block.parent = self
        return block

    def subs(self, subs_dict):
        """In-place substitution in both branches and in the condition."""
        self.true_block.subs(subs_dict)
        if self.false_block:
            self.false_block.subs(subs_dict)
        self.condition_expr = self.condition_expr.subs(subs_dict)

    @property
    def args(self):
        optional_branch = [self.false_block] if self.false_block else []
        return [self.condition_expr, self.true_block] + optional_branch

    @property
    def symbols_defined(self):
        # a conditional itself never introduces symbols
        return set()

    @property
    def undefined_symbols(self):
        symbols = self.true_block.undefined_symbols
        if self.false_block:
            symbols.update(self.false_block.undefined_symbols)
        symbols.update(self.condition_expr.atoms(sp.Symbol))
        return symbols

    def __str__(self):
        return 'if:({!s}) '.format(self.condition_expr)

    def __repr__(self):
        return 'if:({!r}) '.format(self.condition_expr)

    def replace_by_true_block(self):
        """Replaces the conditional by its True block"""
        self.parent.replace(self, [self.true_block])

    def replace_by_false_block(self):
        """Replaces the conditional by its False block (or removes it if there is none)"""
        replacement = [self.false_block] if self.false_block else []
        self.parent.replace(self, replacement)

120

121
122
123
class KernelFunction(Node):
    """Root node of a kernel: wraps the body and derives the kernel's parameter list from it.

    Args:
        body: body node of the kernel
        ghost_layers: stored on the node as-is
        function_name: name of the kernel function
        backend: backend name, stored on the node as-is
    """

    class Argument:
        """A single kernel parameter, created from an undefined symbol of the kernel body.

        Parameters whose name starts with one of the Field data/shape/stride prefixes are marked
        as field arguments, and the corresponding accessed Field is looked up.
        """

        def __init__(self, name, dtype, symbol, kernel_function_node):
            from pystencils.transformations import symbol_name_to_variable_name
            self.name = name
            self.dtype = dtype
            self.is_field_ptr_argument = False
            self.is_field_shape_argument = False
            self.is_field_stride_argument = False
            self.is_field_argument = False
            self.field_name = ""
            self.coordinate = None
            self.symbol = symbol

            if name.startswith(Field.DATA_PREFIX):
                self.is_field_ptr_argument = True
                self.is_field_argument = True
                self.field_name = name[len(Field.DATA_PREFIX):]
            elif name.startswith(Field.SHAPE_PREFIX):
                self.is_field_shape_argument = True
                self.is_field_argument = True
                self.field_name = name[len(Field.SHAPE_PREFIX):]
            elif name.startswith(Field.STRIDE_PREFIX):
                self.is_field_stride_argument = True
                self.is_field_argument = True
                self.field_name = name[len(Field.STRIDE_PREFIX):]

            self.field = None
            if self.is_field_argument:
                field_map = {symbol_name_to_variable_name(f.name): f for f in kernel_function_node.fields_accessed}
                self.field = field_map[self.field_name]

        def __lt__(self, other):
            # Sort order: field pointers, then field shapes, then field strides, then all others;
            # ties are broken alphabetically by name.
            def score(arg):
                if arg.is_field_ptr_argument:
                    return -4
                elif arg.is_field_shape_argument:
                    return -3
                elif arg.is_field_stride_argument:
                    return -2
                return 0

            return (score(self), self.name) < (score(other), other.name)

        def __repr__(self):
            return '<{0} {1}>'.format(self.dtype, self.name)

    def __init__(self, body, ghost_layers=None, function_name="kernel", backend=""):
        super(KernelFunction, self).__init__()
        self._body = body
        body.parent = self
        self._parameters = None
        self.function_name = function_name
        self.compile = None
        self.ghost_layers = ghost_layers
        # these variables are assumed to be global, so no automatic parameter is generated for them
        self.global_variables = set()
        self.backend = backend
        self.instruction_set = None  # used in `vectorize` function to tell the backend which i.s. (SSE,AVX) to use

    @property
    def symbols_defined(self):
        return set()

    @property
    def undefined_symbols(self):
        # everything undefined in the body becomes a parameter, so nothing is undefined from outside
        return set()

    @property
    def parameters(self):
        """Sorted list of Argument objects, recomputed from the current body on every access."""
        self._update_parameters()
        return self._parameters

    @property
    def body(self):
        return self._body

    @property
    def args(self):
        return [self._body]

    @property
    def fields_accessed(self):
        """Set of Field instances: fields which are accessed inside this kernel function"""
        return {access.field for access in self.atoms(ResolvedFieldAccess)}

    def _update_parameters(self):
        undefined_symbols = self._body.undefined_symbols - self.global_variables
        self._parameters = [KernelFunction.Argument(s.name, s.dtype, s, self) for s in undefined_symbols]
        self._parameters.sort()

    def __str__(self):
        # `parameters` already recomputes the list; compute it only once per call
        params = self.parameters
        indented_body = "\t" + "\t".join(str(self.body).splitlines(True))
        return '{0} {1}({2})\n{3}'.format(type(self).__name__, self.function_name, params, indented_body)

    def __repr__(self):
        return '{0} {1}({2})'.format(type(self).__name__, self.function_name, self.parameters)
228

229
230

class Block(Node):
    """Ordered sequence of child nodes; every child gets this block as its parent."""

    def __init__(self, nodes: List[Node]):
        super(Block, self).__init__()  # also initializes parent to None
        self._nodes = nodes
        for n in self._nodes:
            n.parent = self

    @property
    def args(self):
        return self._nodes

    def subs(self, subs_dict) -> None:
        """In-place substitution.

        A declaration whose right-hand side is a key of the substitution is removed from the block.
        Its left-hand side is then substituted by the same value in the remaining nodes of this block
        and in their children. These extra substitutions are collected in a local copy, so the
        caller's ``subs_dict`` is left unchanged and enclosing or sibling scopes never see them.
        """
        local_subs = dict(subs_dict)
        remaining_nodes = []
        for a in self.args:
            if isinstance(a, SympyAssignment) and a.is_declaration and a.rhs in local_subs:
                local_subs[a.lhs] = local_subs[a.rhs]
            else:
                remaining_nodes.append(a)
        self._nodes = remaining_nodes

        for a in self.args:
            a.subs(local_subs)

    def insert_front(self, node):
        """Insert ``node`` as the first child of this block."""
        node.parent = self
        self._nodes.insert(0, node)

    def insert_before(self, new_node, insert_before):
        """Insert ``new_node`` in front of the child ``insert_before``.

        Declarations are moved further up, past any loops or conditionals directly preceding the
        insertion point.
        """
        new_node.parent = self
        idx = self._nodes.index(insert_before)

        # hoist declarations above immediately preceding loops/conditionals
        if isinstance(new_node, SympyAssignment) and new_node.is_declaration:
            while idx > 0:
                pn = self._nodes[idx - 1]
                if isinstance(pn, (LoopOverCoordinate, Conditional)):
                    idx -= 1
                else:
                    break
        self._nodes.insert(idx, new_node)

    def append(self, node):
        """Append a single node, or every node of a list/tuple, to the end of this block."""
        if isinstance(node, (list, tuple)):
            for n in node:
                n.parent = self
                self._nodes.append(n)
        else:
            node.parent = self
            self._nodes.append(node)

    def take_child_nodes(self):
        """Detach and return the list of children, leaving this block empty."""
        tmp = self._nodes
        self._nodes = []
        return tmp

    def replace(self, child, replacements):
        """Replace ``child`` by a single node or splice in a list of nodes at its position."""
        idx = self._nodes.index(child)
        del self._nodes[idx]
        if type(replacements) is list:
            for e in replacements:
                e.parent = self
            self._nodes = self._nodes[:idx] + replacements + self._nodes[idx:]
        else:
            replacements.parent = self
            self._nodes.insert(idx, replacements)

    @property
    def symbols_defined(self):
        result = set()
        for a in self.args:
            result.update(a.symbols_defined)
        return result

    @property
    def undefined_symbols(self):
        result = set()
        defined_symbols = set()
        for a in self.args:
            result.update(a.undefined_symbols)
            defined_symbols.update(a.symbols_defined)
        return result - defined_symbols

    def __str__(self):
        return "Block " + ''.join('{!s}\n'.format(node) for node in self._nodes)

    def __repr__(self):
        return "Block"
318

319
320

class PragmaBlock(Block):
    """Block annotated with a ``pragma_line`` string.

    NOTE(review): presumably the backend prints ``pragma_line`` in front of the block; confirm in the printer.
    """

    def __init__(self, pragma_line, nodes):
        # Block.__init__ already makes this block the parent of every node in ``nodes``
        super(PragmaBlock, self).__init__(nodes)
        self.pragma_line = pragma_line

    def __repr__(self):
        return self.pragma_line
329
330
331
332
333


class LoopOverCoordinate(Node):
    """Loop ``for(ctr_<c>=start; ctr_<c><stop; ctr_<c>+=step)`` around ``body``.

    Args:
        body: loop body node
        coordinate_to_loop_over: coordinate index; determines the loop counter name ``ctr_<c>``
        start: first counter value (int or sympy expression)
        stop: exclusive upper bound (int or sympy expression)
        step: increment (int or sympy expression)
    """
    LOOP_COUNTER_NAME_PREFIX = "ctr"

    def __init__(self, body, coordinate_to_loop_over, start, stop, step=1):
        super(LoopOverCoordinate, self).__init__(parent=None)
        self.body = body
        body.parent = self
        self.coordinate_to_loop_over = coordinate_to_loop_over
        self.start = start
        self.stop = stop
        self.step = step
        self.prefix_lines = []

    def new_loop_with_different_body(self, new_body):
        """Copy of this loop (same coordinate, bounds and prefix lines) around ``new_body``."""
        result = LoopOverCoordinate(new_body, self.coordinate_to_loop_over, self.start, self.stop, self.step)
        result.prefix_lines = list(self.prefix_lines)
        return result

    def subs(self, subs_dict):
        """In-place substitution in the body and in every loop bound that supports ``subs``."""
        self.body.subs(subs_dict)
        if hasattr(self.start, "subs"):
            self.start = self.start.subs(subs_dict)
        if hasattr(self.stop, "subs"):
            self.stop = self.stop.subs(subs_dict)
        if hasattr(self.step, "subs"):
            self.step = self.step.subs(subs_dict)

    @property
    def args(self):
        result = [self.body]
        # plain ints have no `args`, so only expression-valued bounds are reported as children
        for e in [self.start, self.stop, self.step]:
            if hasattr(e, "args"):
                result.append(e)
        return result

    def replace(self, child, replacement):
        """Replace the body or one of the loop bounds by ``replacement``.

        A new body node gets this loop as its parent. Children that are not found are ignored.
        """
        if child == self.body:
            if isinstance(replacement, Node):
                replacement.parent = self
            self.body = replacement
        elif child == self.start:
            self.start = replacement
        elif child == self.step:
            self.step = replacement
        elif child == self.stop:
            self.stop = replacement

    @property
    def symbols_defined(self):
        return {self.loop_counter_symbol}

    @property
    def undefined_symbols(self):
        result = self.body.undefined_symbols
        for possible_symbol in [self.start, self.stop, self.step]:
            if isinstance(possible_symbol, (Node, sp.Basic)):
                result.update(possible_symbol.atoms(sp.Symbol))
        return result - {self.loop_counter_symbol}

    @staticmethod
    def get_loop_counter_name(coordinate_to_loop_over):
        return "%s_%s" % (LoopOverCoordinate.LOOP_COUNTER_NAME_PREFIX, coordinate_to_loop_over)

    @property
    def loop_counter_name(self):
        return LoopOverCoordinate.get_loop_counter_name(self.coordinate_to_loop_over)

    @staticmethod
    def is_loop_counter_symbol(symbol):
        """Return the coordinate of a loop counter symbol, or None if ``symbol`` is no loop counter.

        A loop counter is named ``<prefix>_<integer>`` (see get_loop_counter_name) and has dtype int.
        Symbols that merely share the prefix (e.g. ``ctrl``) or that carry no dtype yield None
        instead of raising.
        """
        prefix = LoopOverCoordinate.LOOP_COUNTER_NAME_PREFIX + "_"
        if not symbol.name.startswith(prefix):
            return None
        dtype = getattr(symbol, 'dtype', None)
        if dtype is None or dtype != create_type('int'):
            return None
        try:
            return int(symbol.name[len(prefix):])
        except ValueError:
            return None

    @staticmethod
    def get_loop_counter_symbol(coordinate_to_loop_over):
        return TypedSymbol(LoopOverCoordinate.get_loop_counter_name(coordinate_to_loop_over), 'int')

    @property
    def loop_counter_symbol(self):
        return LoopOverCoordinate.get_loop_counter_symbol(self.coordinate_to_loop_over)

    @property
    def is_outermost_loop(self):
        from pystencils.transformations import get_next_parent_of_type
        return get_next_parent_of_type(self, LoopOverCoordinate) is None

    @property
    def is_innermost_loop(self):
        return len(self.atoms(LoopOverCoordinate)) == 0

    def __str__(self):
        return 'for({!s}={!s}; {!s}<{!s}; {!s}+={!s})\n{!s}'.format(self.loop_counter_name, self.start,
                                                                    self.loop_counter_name, self.stop,
                                                                    self.loop_counter_name, self.step,
                                                                    ("\t" + "\t".join(str(self.body).splitlines(True))))

    def __repr__(self):
        return 'for({!s}={!s}; {!s}<{!s}; {!s}+={!s})'.format(self.loop_counter_name, self.start,
                                                              self.loop_counter_name, self.stop,
                                                              self.loop_counter_name, self.step)
434

435
436

class SympyAssignment(Node):
    """Assignment ``lhs ← rhs`` with sympy expressions on both sides.

    The assignment is a declaration (it defines ``lhs``) unless the left-hand side is a cast,
    a field access, an indexed expression or a TemporaryMemoryAllocation.
    """

    def __init__(self, lhs_symbol, rhs_expr, is_const=True):
        super(SympyAssignment, self).__init__(parent=None)
        self._lhs_symbol = lhs_symbol
        self.rhs = rhs_expr
        self._is_const = is_const
        self._is_declaration = self.__is_declaration()

    def __is_declaration(self):
        non_declaring_types = (cast_func, Field.Access, sp.Indexed, TemporaryMemoryAllocation)
        return not isinstance(self._lhs_symbol, non_declaring_types)

    @property
    def lhs(self):
        return self._lhs_symbol

    @lhs.setter
    def lhs(self, new_value):
        # the declaration status depends on the lhs type, so it is recomputed on every change
        self._lhs_symbol = new_value
        self._is_declaration = self.__is_declaration()

    def subs(self, subs_dict):
        self.lhs = fast_subs(self.lhs, subs_dict)
        self.rhs = fast_subs(self.rhs, subs_dict)

    @property
    def args(self):
        return [self._lhs_symbol, self.rhs]

    @property
    def symbols_defined(self):
        if not self._is_declaration:
            return set()
        return {self._lhs_symbol}

    @property
    def undefined_symbols(self):
        result = self.rhs.atoms(sp.Symbol)
        # Add loop counters if there are field accesses
        loop_counters = set()
        for symbol in result:
            if isinstance(symbol, Field.Access):
                for i in range(len(symbol.offsets)):
                    loop_counters.add(LoopOverCoordinate.get_loop_counter_symbol(i))
        result.update(loop_counters)
        result.update(self._lhs_symbol.atoms(sp.Symbol))
        return result

    @property
    def is_declaration(self):
        return self._is_declaration

    @property
    def is_const(self):
        return self._is_const

    def replace(self, child, replacement):
        """Replace ``lhs`` or ``rhs`` (whichever equals ``child``) by ``replacement``.

        Raises:
            ValueError: if ``child`` is neither the lhs nor the rhs of this assignment
        """
        if child == self.lhs:
            replacement.parent = self
            self.lhs = replacement
        elif child == self.rhs:
            replacement.parent = self
            self.rhs = replacement
        else:
            raise ValueError('%s is not in args of %s' % (child, self.__class__))

    def __repr__(self):
        return repr(self.lhs) + " ← " + repr(self.rhs)

    def _repr_html_(self):
        printed_lhs = sp.latex(self.lhs)
        printed_rhs = sp.latex(self.rhs)
        # raw string: "\l" is not a valid escape sequence in a normal string literal
        return r"${printed_lhs} \leftarrow {printed_rhs}$".format(printed_lhs=printed_lhs, printed_rhs=printed_rhs)
Martin Bauer's avatar
Martin Bauer committed
512

513

Martin Bauer's avatar
Martin Bauer committed
514
class ResolvedFieldAccess(sp.Indexed):
    """Field access lowered to an indexed pointer access ``base[linearized_index]``.

    In addition to the sympy ``Indexed`` content it stores the originating ``field``, the access
    ``offsets`` and the ``idx_coordinate_values``; all three are part of the hashable content.
    """

    def __new__(cls, base, linearized_index, field, offsets, idx_coordinate_values):
        indexed_base = base if isinstance(base, IndexedBase) else IndexedBase(base, shape=(1,))
        obj = super(ResolvedFieldAccess, cls).__new__(cls, indexed_base, linearized_index)
        obj.field, obj.offsets, obj.idx_coordinate_values = field, offsets, idx_coordinate_values
        return obj

    def _eval_subs(self, old, new):
        # only the linearized index is substituted, the pointer base is kept unchanged
        pointer_base, index = self.args[0], self.args[1]
        return ResolvedFieldAccess(pointer_base, index.subs(old, new),
                                   self.field, self.offsets, self.idx_coordinate_values)

    def fast_subs(self, substitutions):
        if self in substitutions:
            return substitutions[self]
        new_base = self.args[0].subs(substitutions)
        new_index = self.args[1].subs(substitutions)
        return ResolvedFieldAccess(new_base, new_index, self.field, self.offsets, self.idx_coordinate_values)

    def _hashable_content(self):
        base_content = super(ResolvedFieldAccess, self)._hashable_content()
        extra_content = tuple(self.offsets) + (repr(self.idx_coordinate_values), hash(self.field))
        return base_content + extra_content

    @property
    def typed_symbol(self):
        """The label of the indexed base."""
        return self.base.label

    def __str__(self):
        indexed_text = super(ResolvedFieldAccess, self).__str__()
        return "%s (%s)" % (indexed_text, self.typed_symbol.dtype)

    def __getnewargs__(self):
        # arguments needed to re-create the object through __new__ (used by pickling/copying)
        return self.base, self.indices[0], self.field, self.offsets, self.idx_coordinate_values
Martin Bauer's avatar
Martin Bauer committed
550
551


552
class TemporaryMemoryAllocation(Node):
    """Node for temporary memory buffer allocation.

    Always allocates aligned memory.

    Args:
        typed_symbol: symbol used as pointer (has to be typed)
        size: number of elements to allocate
        align_offset: the align_offset's element is aligned
    """
    def __init__(self, typed_symbol: TypedSymbol, size, align_offset):
        super(TemporaryMemoryAllocation, self).__init__(parent=None)
        self.symbol = typed_symbol
        self.size = size
        self.headers = ['<stdlib.h>']
        self._align_offset = align_offset

    @property
    def symbols_defined(self):
        return {self.symbol}

    @property
    def undefined_symbols(self):
        # only a symbolic size can reference other symbols
        if isinstance(self.size, sp.Basic):
            return self.size.atoms(sp.Symbol)
        else:
            return set()

    @property
    def args(self):
        return [self.symbol]

    def offset(self, byte_alignment):
        """Number of ELEMENTS to skip for a pointer that is aligned to byte_alignment.

        ``byte_alignment`` must be a multiple of the element size. The result is an integer
        element count, because the element count per alignment unit is computed with integer division.
        """
        np_dtype = self.symbol.dtype.base_type.numpy_dtype
        assert byte_alignment % np_dtype.itemsize == 0
        elements_per_alignment = byte_alignment // np_dtype.itemsize
        return -self._align_offset % elements_per_alignment

590
591

class TemporaryMemoryFree(Node):
    """Node marking the release of the buffer created by a TemporaryMemoryAllocation.

    Pointer symbol and element offset are delegated to ``alloc_node``, so both nodes always agree.
    NOTE(review): the actual free call is emitted by the backend, which is not shown here.
    """

    def __init__(self, alloc_node):
        super(TemporaryMemoryFree, self).__init__(parent=None)
        self.alloc_node = alloc_node

    @property
    def args(self):
        return []

    @property
    def symbols_defined(self):
        return set()

    @property
    def undefined_symbols(self):
        return set()

    @property
    def symbol(self):
        """Pointer symbol of the associated allocation."""
        return self.alloc_node.symbol

    def offset(self, byte_alignment):
        """Element offset of the associated allocation for the given byte alignment."""
        return self.alloc_node.offset(byte_alignment)