astnodes.py 18.4 KB
Newer Older
1
import sympy as sp
2
from sympy.tensor import IndexedBase
3
from pystencils.field import Field
Martin Bauer's avatar
Martin Bauer committed
4
from pystencils.data_types import TypedSymbol, create_type, cast_func
Martin Bauer's avatar
Martin Bauer committed
5
from pystencils.sympyextensions import fast_subs
Martin Bauer's avatar
Martin Bauer committed
6
7
8
from typing import List, Set, Optional, Union, Any

# Type alias for AST children: either an AST Node or a sympy expression (see Node.args).
NodeOrExpr = Union['Node', sp.Expr]
9
10


11
class Node(object):
    """Base class for all AST nodes.

    A node stores a reference to its parent (``None`` for a root node) and exposes its
    children through :attr:`args`. Subclasses override the symbol bookkeeping properties
    to report which symbols they define and which they use without defining.
    """

    def __init__(self, parent: Optional['Node'] = None):
        self.parent = parent

    @property
    def args(self) -> List[NodeOrExpr]:
        """Children of this node (nodes or sympy expressions); the base node has none."""
        return []

    @property
    def symbols_defined(self) -> Set[sp.Symbol]:
        """Symbols introduced by this node; empty for the base node."""
        return set()

    @property
    def undefined_symbols(self) -> Set[sp.Symbol]:
        """Symbols used inside this node but not defined inside it.

        Raises:
            NotImplementedError: always; every concrete subclass must override this.
        """
        raise NotImplementedError()

    def subs(self, *args, **kwargs) -> None:
        """Inplace! substitution: forwards the sympy-style ``subs`` call to every child.

        NOTE(review): sympy expressions return a new object from ``subs``; that result is
        discarded here, so subclasses holding expressions override this method.
        """
        for child in self.args:
            child.subs(*args, **kwargs)

    @property
    def func(self):
        """The node's class, mirroring sympy's ``expr.func``."""
        return type(self)

    def atoms(self, arg_type) -> Set[Any]:
        """Return every descendant (searched recursively) that is an instance of ``arg_type``."""
        children = self.args
        found = {child for child in children if isinstance(child, arg_type)}
        for child in children:
            found.update(child.atoms(arg_type))
        return found


51
class Conditional(Node):
    """Conditional that maps to an 'if' statement in C/C++.

    Try to avoid using this node inside of loops, since currently this construction can not be vectorized.
    Consider using assignments with sympy.Piecewise in this case.

    Args:
        condition_expr: sympy relational expression (or boolean expression)
        true_block: block which is run if conditional is true; a non-Block node is wrapped into a Block
        false_block: optional block which is run if conditional is false; wrapped like ``true_block``
    """

    def __init__(self, condition_expr: sp.Basic, true_block: Union['Block', 'SympyAssignment'],
                 false_block: Optional['Block'] = None) -> None:
        super(Conditional, self).__init__(parent=None)

        assert condition_expr.is_Boolean or condition_expr.is_Relational
        self.conditionExpr = condition_expr

        # Wrap each branch into a Block (unless it already is one) and adopt it; None stays None.
        branches = []
        for branch in (true_block, false_block):
            if branch is not None:
                if not isinstance(branch, Block):
                    branch = Block([branch])
                branch.parent = self
            branches.append(branch)
        self.trueBlock, self.falseBlock = branches

    def subs(self, *args, **kwargs):
        """Substitute in both branches (in place) and in the condition expression."""
        self.trueBlock.subs(*args, **kwargs)
        if self.falseBlock:
            self.falseBlock.subs(*args, **kwargs)
        self.conditionExpr = self.conditionExpr.subs(*args, **kwargs)

    @property
    def args(self):
        optional_else = [self.falseBlock] if self.falseBlock else []
        return [self.conditionExpr, self.trueBlock] + optional_else

    @property
    def symbols_defined(self):
        return set()

    @property
    def undefined_symbols(self):
        """Undefined symbols of both branches plus all symbols appearing in the condition."""
        used = self.trueBlock.undefined_symbols
        if self.falseBlock:
            used.update(self.falseBlock.undefined_symbols)
        used.update(self.conditionExpr.atoms(sp.Symbol))
        return used

    def __str__(self):
        return 'if:({!s}) '.format(self.conditionExpr)

    def __repr__(self):
        return 'if:({!r}) '.format(self.conditionExpr)


112
113
114
class KernelFunction(Node):
    """Root node of a kernel function.

    Wraps the kernel ``body`` and derives the function's parameter list from the symbols
    the body uses but does not define, minus ``global_variables``.

    Args:
        body: block holding the kernel statements; its ``parent`` is set to this node
        ghost_layers: stored unchanged as ``self.ghost_layers``
        function_name: name of the kernel function
        backend: stored unchanged as ``self.backend``
    """

    class Argument:
        """One kernel parameter, classified by the prefix of its name.

        Names starting with ``Field.DATA_PREFIX``, ``Field.SHAPE_PREFIX`` or ``Field.STRIDE_PREFIX``
        mark field pointer / shape / stride arguments. For those, the matching ``Field`` accessed
        inside ``kernel_function_node`` is looked up and stored in ``self.field``.

        Raises:
            KeyError: if a field argument's name matches no field accessed in the kernel.
        """

        def __init__(self, name, dtype, symbol, kernel_function_node):
            from pystencils.transformations import symbol_name_to_variable_name
            self.name = name
            self.dtype = dtype
            self.isFieldPtrArgument = False
            self.isFieldShapeArgument = False
            self.isFieldStrideArgument = False
            self.isFieldArgument = False
            self.field_name = ""
            self.coordinate = None
            self.symbol = symbol

            if name.startswith(Field.DATA_PREFIX):
                self.isFieldPtrArgument = True
                self.isFieldArgument = True
                self.field_name = name[len(Field.DATA_PREFIX):]
            elif name.startswith(Field.SHAPE_PREFIX):
                self.isFieldShapeArgument = True
                self.isFieldArgument = True
                self.field_name = name[len(Field.SHAPE_PREFIX):]
            elif name.startswith(Field.STRIDE_PREFIX):
                self.isFieldStrideArgument = True
                self.isFieldArgument = True
                self.field_name = name[len(Field.STRIDE_PREFIX):]

            self.field = None
            if self.isFieldArgument:
                field_map = {symbol_name_to_variable_name(f.name): f
                             for f in kernel_function_node.fields_accessed}
                self.field = field_map[self.field_name]

        def _sort_key(self):
            """Order: field pointers, then shapes, then strides, then all others; ties broken by name."""
            if self.isFieldPtrArgument:
                rank = -4
            elif self.isFieldShapeArgument:
                rank = -3
            elif self.isFieldStrideArgument:
                rank = -2
            else:
                rank = 0
            return rank, self.name

        def __lt__(self, other):
            return self._sort_key() < other._sort_key()

        def __repr__(self):
            return '<{0} {1}>'.format(self.dtype, self.name)

    def __init__(self, body, ghost_layers=None, function_name="kernel", backend=""):
        super(KernelFunction, self).__init__()
        self._body = body
        body.parent = self
        self._parameters = None
        self.function_name = function_name
        self.compile = None
        self.ghost_layers = ghost_layers
        # these variables are assumed to be global, so no automatic parameter is generated for them
        self.global_variables = set()
        self.backend = backend

    @property
    def symbols_defined(self):
        return set()

    @property
    def undefined_symbols(self):
        return set()

    @property
    def parameters(self):
        """Sorted list of ``Argument`` objects; recomputed from the body on every access."""
        self._update_parameters()
        return self._parameters

    @property
    def body(self):
        return self._body

    @property
    def args(self):
        return [self._body]

    @property
    def fields_accessed(self):
        """Set of Field instances: fields which are accessed inside this kernel function"""
        return {o.field for o in self.atoms(ResolvedFieldAccess)}

    def _update_parameters(self):
        """Rebuild ``self._parameters`` from the body's undefined symbols (excluding globals)."""
        undefined_symbols = self._body.undefined_symbols - self.global_variables
        self._parameters = sorted(KernelFunction.Argument(s.name, s.dtype, s, self)
                                  for s in undefined_symbols)

    def __str__(self):
        # ``self.parameters`` already refreshes the parameter list; no extra update call needed.
        params = self.parameters
        return '{0} {1}({2})\n{3}'.format(type(self).__name__, self.function_name, params,
                                          ("\t" + "\t".join(str(self.body).splitlines(True))))

    def __repr__(self):
        return '{0} {1}({2})'.format(type(self).__name__, self.function_name, self.parameters)
218

219
220

class Block(Node):
    """Ordered sequence of child nodes; every child's ``parent`` is set to this block.

    Args:
        nodes: list of child nodes; stored by reference (not copied)
    """

    def __init__(self, nodes: List[Node]):
        super(Block, self).__init__()
        self._nodes = nodes
        for n in self._nodes:
            n.parent = self

    @property
    def args(self):
        return self._nodes

    def insert_front(self, node):
        """Insert ``node`` as the first child."""
        node.parent = self
        self._nodes.insert(0, node)

    def insert_before(self, new_node, insert_before):
        """Insert ``new_node`` before the child ``insert_before``.

        Declarations (``SympyAssignment`` with ``is_declaration``) are additionally hoisted above
        any loops/conditionals that directly precede the insertion point.

        Raises:
            ValueError: if ``insert_before`` is not a child of this block.
        """
        new_node.parent = self
        idx = self._nodes.index(insert_before)

        if isinstance(new_node, SympyAssignment) and new_node.is_declaration:
            while idx > 0 and isinstance(self._nodes[idx - 1], (LoopOverCoordinate, Conditional)):
                idx -= 1
        self._nodes.insert(idx, new_node)

    def append(self, node):
        """Append a single node, or every node of a list/tuple, as children."""
        if isinstance(node, (list, tuple)):
            for n in node:
                n.parent = self
                self._nodes.append(n)
        else:
            node.parent = self
            self._nodes.append(node)

    def take_child_nodes(self):
        """Detach and return the child list, leaving this block empty.

        The returned nodes keep their ``parent`` attribute unchanged.
        """
        tmp = self._nodes
        self._nodes = []
        return tmp

    def replace(self, child, replacements):
        """Replace ``child`` by a single node or by a list/tuple of nodes spliced in at its position.

        Raises:
            ValueError: if ``child`` is not a child of this block.
        """
        idx = self._nodes.index(child)
        del self._nodes[idx]
        # Accept any list/tuple (consistent with ``append``); previously a tuple or list subclass
        # fell through to the single-node branch and failed when setting ``.parent``.
        if isinstance(replacements, (list, tuple)):
            replacements = list(replacements)
            for e in replacements:
                e.parent = self
            self._nodes = self._nodes[:idx] + replacements + self._nodes[idx:]
        else:
            replacements.parent = self
            self._nodes.insert(idx, replacements)

    @property
    def symbols_defined(self):
        """Union of the symbols defined by all children."""
        result = set()
        for a in self.args:
            result.update(a.symbols_defined)
        return result

    @property
    def undefined_symbols(self):
        """Symbols used by any child that are not defined by any child of this block."""
        result = set()
        defined_symbols = set()
        for a in self.args:
            result.update(a.undefined_symbols)
            defined_symbols.update(a.symbols_defined)
        return result - defined_symbols

    def __str__(self):
        return "Block " + ''.join('{!s}\n'.format(node) for node in self._nodes)

    def __repr__(self):
        return "Block"
296

297
298

class PragmaBlock(Block):
    """Block that additionally carries a pragma line.

    Args:
        pragma_line: text stored as ``pragmaLine``; it is also what ``repr`` returns
        nodes: child nodes, handled as by ``Block``
    """

    def __init__(self, pragma_line, nodes):
        super(PragmaBlock, self).__init__(nodes)
        self.pragmaLine = pragma_line
        # Re-assert ownership of every child by this instance.
        for child_node in nodes:
            child_node.parent = self

    def __repr__(self):
        return self.pragmaLine
307
308
309
310
311


class LoopOverCoordinate(Node):
    """Loop over a single coordinate, printed as ``for(ctr_<c>=start; ctr_<c><stop; ctr_<c>+=step)``.

    Args:
        body: node executed in each iteration; its ``parent`` is set to this loop
        coordinate_to_loop_over: coordinate used to build the loop counter name
            (presumably an int index — ``is_loop_counter_symbol`` parses it back with ``int``)
        start, stop, step: loop bounds; plain values or sympy expressions
    """
    LOOP_COUNTER_NAME_PREFIX = "ctr"

    def __init__(self, body, coordinate_to_loop_over, start, stop, step=1):
        super(LoopOverCoordinate, self).__init__(parent=None)
        self.body = body
        body.parent = self
        self.coordinateToLoopOver = coordinate_to_loop_over
        self.start = start
        self.stop = stop
        self.step = step
        self.prefixLines = []

    def new_loop_with_different_body(self, new_body):
        """Return a copy of this loop (bounds and a copy of prefixLines) around ``new_body``."""
        result = LoopOverCoordinate(new_body, self.coordinateToLoopOver, self.start, self.stop, self.step)
        result.prefixLines = list(self.prefixLines)
        return result

    def subs(self, *args, **kwargs):
        """Substitute in the body (in place) and in every bound that supports ``subs``."""
        self.body.subs(*args, **kwargs)
        if hasattr(self.start, "subs"):
            self.start = self.start.subs(*args, **kwargs)
        if hasattr(self.stop, "subs"):
            self.stop = self.stop.subs(*args, **kwargs)
        if hasattr(self.step, "subs"):
            self.step = self.step.subs(*args, **kwargs)

    @property
    def args(self):
        """The body followed by those bounds that are expressions (have ``args``)."""
        result = [self.body]
        for e in [self.start, self.stop, self.step]:
            if hasattr(e, "args"):
                result.append(e)
        return result

    def replace(self, child, replacement):
        """Replace the body or one of start/step/stop; an unknown ``child`` is silently ignored."""
        if child == self.body:
            # Keep the parent chain consistent (as __init__ does); is_outermost_loop relies on it.
            replacement.parent = self
            self.body = replacement
        elif child == self.start:
            self.start = replacement
        elif child == self.step:
            self.step = replacement
        elif child == self.stop:
            self.stop = replacement

    @property
    def symbols_defined(self):
        return {self.loop_counter_symbol}

    @property
    def undefined_symbols(self):
        """Undefined symbols of the body and the bounds, excluding this loop's own counter."""
        result = self.body.undefined_symbols
        for possible_symbol in [self.start, self.stop, self.step]:
            if isinstance(possible_symbol, (Node, sp.Basic)):
                result.update(possible_symbol.atoms(sp.Symbol))
        return result - {self.loop_counter_symbol}

    @staticmethod
    def get_loop_counter_name(coordinate_to_loop_over):
        return "%s_%s" % (LoopOverCoordinate.LOOP_COUNTER_NAME_PREFIX, coordinate_to_loop_over)

    @property
    def loop_counter_name(self):
        return LoopOverCoordinate.get_loop_counter_name(self.coordinateToLoopOver)

    @staticmethod
    def is_loop_counter_symbol(symbol):
        """Return the coordinate of ``symbol`` if it is an int-typed loop counter, otherwise None."""
        prefix = LoopOverCoordinate.LOOP_COUNTER_NAME_PREFIX
        if not symbol.name.startswith(prefix):
            return None
        if symbol.dtype != create_type('int'):
            return None
        try:
            return int(symbol.name[len(prefix) + 1:])
        except ValueError:
            # An int symbol that merely shares the prefix (e.g. "ctrl" or "ctr_x") is not a counter.
            return None

    @staticmethod
    def get_loop_counter_symbol(coordinate_to_loop_over):
        return TypedSymbol(LoopOverCoordinate.get_loop_counter_name(coordinate_to_loop_over), 'int')

    @property
    def loop_counter_symbol(self):
        return LoopOverCoordinate.get_loop_counter_symbol(self.coordinateToLoopOver)

    @property
    def is_outermost_loop(self):
        from pystencils.transformations import get_next_parent_of_type
        return get_next_parent_of_type(self, LoopOverCoordinate) is None

    @property
    def is_innermost_loop(self):
        return not self.atoms(LoopOverCoordinate)

    def __str__(self):
        return 'for({!s}={!s}; {!s}<{!s}; {!s}+={!s})\n{!s}'.format(self.loop_counter_name, self.start,
                                                                    self.loop_counter_name, self.stop,
                                                                    self.loop_counter_name, self.step,
                                                                    ("\t" + "\t".join(str(self.body).splitlines(True))))

    def __repr__(self):
        return 'for({!s}={!s}; {!s}<{!s}; {!s}+={!s})'.format(self.loop_counter_name, self.start,
                                                              self.loop_counter_name, self.stop,
                                                              self.loop_counter_name, self.step)
412

413
414

class SympyAssignment(Node):
    """Assignment ``lhs = rhs`` between sympy expressions.

    The assignment is a declaration (it defines ``lhs``) unless the lhs is a ``Field.Access``,
    an ``sp.Indexed`` (which includes ``ResolvedFieldAccess``) or a ``cast_func`` expression.

    Args:
        lhs_symbol: left-hand side
        rhs_expr: right-hand side expression
        is_const: exposed as ``is_const``
    """

    def __init__(self, lhs_symbol, rhs_expr, is_const=True):
        super(SympyAssignment, self).__init__(parent=None)
        self._lhsSymbol = lhs_symbol
        self.rhs = rhs_expr
        self._isDeclaration = SympyAssignment._lhs_is_declaration(lhs_symbol)
        self._isConst = is_const

    @staticmethod
    def _lhs_is_declaration(lhs):
        """Decide whether assigning to ``lhs`` declares a new symbol.

        Shared by ``__init__`` and the ``lhs`` setter. Previously ``__init__`` tested
        ``ResolvedFieldAccess`` while the setter tested ``sp.Indexed``, so the flag could flip
        after a ``subs``.
        """
        is_cast = lhs.func == cast_func
        return not (isinstance(lhs, (Field.Access, sp.Indexed)) or is_cast)

    @property
    def lhs(self):
        return self._lhsSymbol

    @lhs.setter
    def lhs(self, new_value):
        self._lhsSymbol = new_value
        self._isDeclaration = SympyAssignment._lhs_is_declaration(new_value)

    def subs(self, *args, **kwargs):
        """Substitute in both sides (in place) using ``fast_subs``."""
        self.lhs = fast_subs(self.lhs, *args, **kwargs)
        self.rhs = fast_subs(self.rhs, *args, **kwargs)

    @property
    def args(self):
        return [self._lhsSymbol, self.rhs]

    @property
    def symbols_defined(self):
        if not self._isDeclaration:
            return set()
        return {self._lhsSymbol}

    @property
    def undefined_symbols(self):
        """All symbols of rhs and lhs, plus loop counters implied by field accesses in the rhs."""
        result = self.rhs.atoms(sp.Symbol)
        # A field access depends on the loop counters of all of its offset dimensions.
        loop_counters = {LoopOverCoordinate.get_loop_counter_symbol(i)
                         for symbol in result if isinstance(symbol, Field.Access)
                         for i in range(len(symbol.offsets))}
        result.update(loop_counters)
        result.update(self._lhsSymbol.atoms(sp.Symbol))
        return result

    @property
    def is_declaration(self):
        return self._isDeclaration

    @property
    def is_const(self):
        return self._isConst

    def replace(self, child, replacement):
        """Replace the lhs or rhs equal to ``child`` by ``replacement``.

        Raises:
            ValueError: if ``child`` is neither the lhs nor the rhs.
        """
        if child == self.lhs:
            replacement.parent = self
            self.lhs = replacement
        elif child == self.rhs:
            replacement.parent = self
            self.rhs = replacement
        else:
            # Report the child that was searched for, not the replacement.
            raise ValueError('%s is not in args of %s' % (child, self.__class__))

    def __repr__(self):
        return repr(self.lhs) + " = " + repr(self.rhs)

    def _repr_html_(self):
        printed_lhs = sp.latex(self.lhs)
        printed_rhs = sp.latex(self.rhs)
        return f"${printed_lhs} = {printed_rhs}$"

490

Martin Bauer's avatar
Martin Bauer committed
491
class ResolvedFieldAccess(sp.Indexed):
    """Field access resolved into an ``Indexed`` expression with a single linearized index.

    Besides the sympy ``Indexed`` data, it keeps the originating ``field``, its ``offsets`` and
    ``idxCoordinateValues``. These are carried through substitution (``_eval_subs``,
    ``fast_subs``), included in the hash (``_hashable_content``) and in pickling
    (``__getnewargs__``).
    """

    def __new__(cls, base, linearized_index, field, offsets, idx_coordinate_values):
        # A non-IndexedBase base is wrapped into a one-element-shaped IndexedBase.
        indexed_base = base if isinstance(base, IndexedBase) else IndexedBase(base, shape=(1,))
        instance = super(ResolvedFieldAccess, cls).__new__(cls, indexed_base, linearized_index)
        instance.field = field
        instance.offsets = offsets
        instance.idxCoordinateValues = idx_coordinate_values
        return instance

    def _eval_subs(self, old, new):
        # Only the linearized index is substituted; the base is kept as-is.
        new_index = self.args[1].subs(old, new)
        return ResolvedFieldAccess(self.args[0], new_index,
                                   self.field, self.offsets, self.idxCoordinateValues)

    def fast_subs(self, substitutions):
        """Return ``substitutions[self]`` if present, else substitute into base and index."""
        if self in substitutions:
            return substitutions[self]
        new_base = self.args[0].subs(substitutions)
        new_index = self.args[1].subs(substitutions)
        return ResolvedFieldAccess(new_base, new_index,
                                   self.field, self.offsets, self.idxCoordinateValues)

    def _hashable_content(self):
        base_content = super(ResolvedFieldAccess, self)._hashable_content()
        extra_content = tuple(self.offsets) + (repr(self.idxCoordinateValues), hash(self.field))
        return base_content + extra_content

    @property
    def typed_symbol(self):
        """Label of the IndexedBase (``__str__`` reads its ``dtype``)."""
        return self.base.label

    def __str__(self):
        indexed_text = super(ResolvedFieldAccess, self).__str__()
        return "%s (%s)" % (indexed_text, self.typed_symbol.dtype)

    def __getnewargs__(self):
        return self.base, self.indices[0], self.field, self.offsets, self.idxCoordinateValues


529
class TemporaryMemoryAllocation(Node):
    """AST node representing a temporary memory allocation bound to ``symbol``.

    The node defines ``symbol``. If ``size`` is a sympy expression, the symbols it contains
    are reported as undefined (i.e. they must be provided from outside).
    """

    def __init__(self, typed_symbol, size):
        super(TemporaryMemoryAllocation, self).__init__(parent=None)
        self.symbol = typed_symbol
        self.size = size

    @property
    def symbols_defined(self):
        return {self.symbol}

    @property
    def undefined_symbols(self):
        size = self.size
        return size.atoms(sp.Symbol) if isinstance(size, sp.Basic) else set()

    @property
    def args(self):
        return [self.symbol]
549
550
551


class TemporaryMemoryFree(Node):
    """AST node representing the release of the temporary memory bound to ``symbol``.

    For symbol bookkeeping it neither defines nor uses any symbol, and it has no children.
    """

    def __init__(self, typed_symbol):
        super(TemporaryMemoryFree, self).__init__(parent=None)
        self.symbol = typed_symbol

    @property
    def symbols_defined(self):
        no_symbols = set()
        return no_symbols

    @property
    def undefined_symbols(self):
        no_symbols = set()
        return no_symbols

    @property
    def args(self):
        no_children = []
        return no_children