astnodes.py 24.2 KB
Newer Older
Stephan Seitz's avatar
Stephan Seitz committed
1
import uuid
2
3
from typing import Any, List, Optional, Sequence, Set, Union

4
import sympy as sp
5
6

from pystencils.data_types import TypedSymbol, cast_func, create_type
7
from pystencils.field import Field
Martin Bauer's avatar
Martin Bauer committed
8
from pystencils.kernelparameters import FieldPointerSymbol, FieldShapeSymbol, FieldStrideSymbol
Martin Bauer's avatar
Martin Bauer committed
9
from pystencils.sympyextensions import fast_subs
Martin Bauer's avatar
Martin Bauer committed
10
11

# Anything that may appear among a node's ``args``: another AST ``Node`` or a plain sympy expression.
NodeOrExpr = Union["Node", sp.Expr]


class Node:
    """Base class for all AST nodes.

    Concrete nodes provide ``args`` (their children), ``symbols_defined`` and ``undefined_symbols``;
    the generic ``subs`` and ``atoms`` helpers below are built on top of ``args``.

    Args:
        parent: enclosing node, or None for a node that is not (yet) attached to a tree
    """

    def __init__(self, parent: Optional['Node'] = None):
        self.parent = parent

    @property
    def args(self) -> List[NodeOrExpr]:
        """Returns all arguments/children of this node."""
        raise NotImplementedError()

    @property
    def symbols_defined(self) -> Set[sp.Symbol]:
        """Set of symbols which are defined by this node."""
        raise NotImplementedError()

    @property
    def undefined_symbols(self) -> Set[sp.Symbol]:
        """Symbols which are used but are not defined inside this node."""
        raise NotImplementedError()

    def subs(self, subs_dict) -> None:
        """Inplace! substitute, similar to sympy's but modifies the AST inplace.

        Forwards ``subs_dict`` to every child's ``subs``. Note that sympy children return a new
        expression rather than changing in place, so nodes holding sympy expressions override this.
        """
        for child in self.args:
            child.subs(subs_dict)

    @property
    def func(self):
        """The node's class (mirrors sympy's ``func`` attribute)."""
        return self.__class__

    def atoms(self, arg_type) -> Set[Any]:
        """Returns a set of all descendants recursively, which are an instance of the given type."""
        found = set()
        for child in self.args:
            if isinstance(child, arg_type):
                found.add(child)
            # Works for Node children and sympy children alike, both provide ``atoms``.
            found.update(child.atoms(arg_type))
        return found


class Conditional(Node):
    """Conditional that maps to an 'if' statement in C/C++.

    Try to avoid using this node inside of loops, since currently this construction can not be vectorized.
    Consider using assignments with sympy.Piecewise in this case.

    Args:
        condition_expr: sympy relational expression
        true_block: block which is run if conditional is true; a non-Block node is wrapped in a Block
        false_block: optional block which is run if conditional is false; wrapped like ``true_block``
    """

    def __init__(self, condition_expr: sp.Basic, true_block: Union['Block', 'SympyAssignment'],
                 false_block: Optional['Block'] = None) -> None:
        super(Conditional, self).__init__(parent=None)
        self.condition_expr = condition_expr

        def as_owned_block(child):
            # None stays None; anything that is not already a Block gets a one-element Block around it.
            if child is None:
                return None
            wrapped = child if isinstance(child, Block) else Block([child])
            wrapped.parent = self
            return wrapped

        self.true_block = as_owned_block(true_block)
        self.false_block = as_owned_block(false_block)

    def subs(self, subs_dict):
        """Substitute in place in both branches, then in the condition expression."""
        self.true_block.subs(subs_dict)
        if self.false_block:
            self.false_block.subs(subs_dict)
        self.condition_expr = self.condition_expr.subs(subs_dict)

    @property
    def args(self):
        """Condition, true block and - only if present - the false block."""
        optional_branch = [self.false_block] if self.false_block else []
        return [self.condition_expr, self.true_block] + optional_branch

    @property
    def symbols_defined(self):
        return set()

    @property
    def undefined_symbols(self):
        """Undefined symbols of both branches plus all symbols occurring in the condition."""
        symbols = self.true_block.undefined_symbols
        if self.false_block:
            symbols.update(self.false_block.undefined_symbols)
        symbols.update(self.condition_expr.atoms(sp.Symbol))
        return symbols

    def __str__(self):
        return 'if:({!s}) '.format(self.condition_expr)

    def __repr__(self):
        return 'if:({!r}) '.format(self.condition_expr)

    def replace_by_true_block(self):
        """Replaces the conditional by its True block"""
        self.parent.replace(self, [self.true_block])

    def replace_by_false_block(self):
        """Replaces the conditional by its False block (or removes it if there is none)"""
        replacement = [] if not self.false_block else [self.false_block]
        self.parent.replace(self, replacement)


class KernelFunction(Node):
    """Root node of a generated kernel: the kernel body plus code-generation metadata.

    Args:
        body: root node of the kernel body; its ``parent`` is set to this node
        target: target identifier, e.g. 'cpu' or 'gpu'
        backend: code-generation backend identifier, e.g. 'llvm', 'c', 'cuda'
        compile_function: callable invoked as ``compile_function(self, *args, **kwargs)`` by ``compile``,
                          or None if this node cannot be compiled
        ghost_layers: stored unchanged as ``self.ghost_layers``
        function_name: name of the generated function
    """

    class Parameter:
        """Function parameter.

        Each undefined symbol in a `KernelFunction` node becomes a parameter to the function.
        Parameters are either symbols introduced by the user that never occur on the left hand side of an
        Assignment, or are related to fields/arrays passed to the function.

        A parameter consists of the typed symbol (symbol property). For field related parameters this is a symbol
        defined in pystencils.kernelparameters.
        If the parameter is related to one or multiple fields, these fields are referenced in the fields property.
        """

        def __init__(self, symbol, fields):
            self.symbol = symbol  # type: TypedSymbol
            self.fields = fields  # type: Sequence[Field]

        def __repr__(self):
            return repr(self.symbol)

        @property
        def is_field_stride(self):
            return isinstance(self.symbol, FieldStrideSymbol)

        @property
        def is_field_shape(self):
            return isinstance(self.symbol, FieldShapeSymbol)

        @property
        def is_field_pointer(self):
            return isinstance(self.symbol, FieldPointerSymbol)

        @property
        def is_field_parameter(self):
            return self.is_field_pointer or self.is_field_shape or self.is_field_stride

        @property
        def field_name(self):
            """Name of the first associated field (raises IndexError for non-field parameters)."""
            return self.fields[0].name

    def __init__(self, body, target, backend, compile_function, ghost_layers, function_name="kernel"):
        super(KernelFunction, self).__init__()
        self._body = body
        body.parent = self
        self.function_name = function_name
        self.ghost_layers = ghost_layers
        self._target = target
        self._backend = backend
        # these variables are assumed to be global, so no automatic parameter is generated for them
        self.global_variables = set()
        self.instruction_set = None  # used in `vectorize` function to tell the backend which i.s. (SSE,AVX) to use
        # function that compiles the node to a Python callable, is set by the backends
        self._compile_function = compile_function

    @property
    def target(self):
        """Currently either 'cpu' or 'gpu' """
        return self._target

    @property
    def backend(self):
        """Backend for generating the code e.g. 'llvm', 'c', 'cuda' """
        return self._backend

    @property
    def symbols_defined(self):
        return set()

    @property
    def undefined_symbols(self):
        # Everything the body needs becomes a function parameter, see get_parameters().
        return set()

    @property
    def body(self):
        return self._body

    @body.setter
    def body(self, value):
        self._body = value
        self._body.parent = self

    @property
    def args(self):
        return self._body,

    @property
    def fields_accessed(self) -> Set[Field]:
        """Set of Field instances: fields which are accessed inside this kernel function"""
        return {access.field for access in self.atoms(ResolvedFieldAccess)}

    def fields_written(self):
        """Set of fields that are the target of an assignment (lhs is a ResolvedFieldAccess)."""
        assignments = self.atoms(SympyAssignment)
        return {a.lhs.field for a in assignments if isinstance(a.lhs, ResolvedFieldAccess)}

    def get_parameters(self) -> Sequence['KernelFunction.Parameter']:
        """Returns list of parameters for this function, sorted by symbol name.

        This function is expensive, cache the result where possible!
        """
        field_map = {f.name: f for f in self.fields_accessed}

        def get_fields(symbol):
            # Field related symbols reference one field (field_name) or several (field_names).
            if hasattr(symbol, 'field_name'):
                return field_map[symbol.field_name],
            elif hasattr(symbol, 'field_names'):
                return tuple(field_map[fn] for fn in symbol.field_names)
            return ()

        argument_symbols = self._body.undefined_symbols - self.global_variables
        parameters = [self.Parameter(symbol, get_fields(symbol)) for symbol in argument_symbols]
        # An optional ``indexing`` attribute (not set in this class) contributes extra parameters.
        if hasattr(self, 'indexing'):
            parameters += [self.Parameter(s, []) for s in self.indexing.symbolic_parameters()]
        parameters.sort(key=lambda p: p.symbol.name)
        return parameters

    def __str__(self):
        params = [p.symbol for p in self.get_parameters()]
        return '{0} {1}({2})\n{3}'.format(type(self).__name__, self.function_name, params,
                                          ("\t" + "\t".join(str(self.body).splitlines(True))))

    def __repr__(self):
        params = [p.symbol for p in self.get_parameters()]
        return '{0} {1}({2})'.format(type(self).__name__, self.function_name, params)

    def compile(self, *args, **kwargs):
        """Compile this node with the backend-provided compile function.

        Raises:
            ValueError: if no compile function was provided
        """
        if self._compile_function is None:
            raise ValueError("No compile-function provided for this KernelFunction node")
        return self._compile_function(self, *args, **kwargs)


class SkipIteration(Node):
    """Leaf node without children that neither defines nor uses any symbols.

    ``early_out`` places it as the sole statement of a conditional block; presumably a backend
    renders it as a jump to the next iteration (confirm in the printers).
    """

    @property
    def args(self):
        return list()

    @property
    def symbols_defined(self):
        return set()

    @property
    def undefined_symbols(self):
        return set()


class Block(Node):
    """Ordered sequence of child nodes; every child's ``parent`` is set to this block.

    Args:
        nodes: list of child nodes; the list object is stored as-is (not copied)
    """

    def __init__(self, nodes: List[Node]):
        super(Block, self).__init__()  # parent starts out as None
        self._nodes = nodes
        for n in self._nodes:
            n.parent = self

    @property
    def args(self):
        return self._nodes

    def subs(self, subs_dict) -> None:
        for a in self.args:
            a.subs(subs_dict)

    def insert_front(self, node):
        """Insert ``node`` as the first child."""
        node.parent = self
        self._nodes.insert(0, node)

    def insert_before(self, new_node, insert_before):
        """Insert ``new_node`` in front of the child ``insert_before``.

        Declarations (SympyAssignment with ``is_declaration``) are additionally moved up past any directly
        preceding loops/conditionals, so the definition comes before those statements.
        """
        new_node.parent = self
        idx = self._nodes.index(insert_before)

        # move all assignment (definitions to the top)
        if isinstance(new_node, SympyAssignment) and new_node.is_declaration:
            while idx > 0:
                pn = self._nodes[idx - 1]
                if isinstance(pn, (LoopOverCoordinate, Conditional)):
                    idx -= 1
                else:
                    break
        self._nodes.insert(idx, new_node)

    def append(self, node):
        """Append a single node, or every node of a list/tuple, at the end."""
        if isinstance(node, (list, tuple)):
            for n in node:
                n.parent = self
                self._nodes.append(n)
        else:
            node.parent = self
            self._nodes.append(node)

    def take_child_nodes(self):
        """Detach and return the current list of children, leaving this block empty."""
        tmp = self._nodes
        self._nodes = []
        return tmp

    def replace(self, child, replacements):
        """Replace ``child`` by ``replacements``.

        ``replacements`` is either a single node or a list/tuple of nodes that is spliced in at the
        child's position (accepting tuples matches ``append``).

        Raises:
            ValueError: if ``child`` is not a direct child of this block
        """
        idx = self._nodes.index(child)
        del self._nodes[idx]
        if isinstance(replacements, (list, tuple)):
            for e in replacements:
                e.parent = self
            self._nodes = self._nodes[:idx] + list(replacements) + self._nodes[idx:]
        else:
            replacements.parent = self
            self._nodes.insert(idx, replacements)

    @property
    def symbols_defined(self):
        result = set()
        for a in self.args:
            result.update(a.symbols_defined)
        return result

    @property
    def undefined_symbols(self):
        # Symbols used anywhere in the block minus symbols defined anywhere in the block.
        result = set()
        defined_symbols = set()
        for a in self.args:
            result.update(a.undefined_symbols)
            defined_symbols.update(a.symbols_defined)
        return result - defined_symbols

    def __str__(self):
        return "Block " + ''.join('{!s}\n'.format(node) for node in self._nodes)

    def __repr__(self):
        return "Block"


class PragmaBlock(Block):
    """A ``Block`` that additionally carries a pragma line, which is also its ``repr``.

    Args:
        pragma_line: the pragma text
        nodes: child nodes, handled exactly like in ``Block``
    """

    def __init__(self, pragma_line, nodes):
        # Block.__init__ already makes this PragmaBlock the parent of every node in ``nodes``.
        super(PragmaBlock, self).__init__(nodes)
        self.pragma_line = pragma_line

    def __repr__(self):
        return self.pragma_line


class LoopOverCoordinate(Node):
    """Counting loop ``for(ctr=start; ctr<stop; ctr+=step) body`` over one coordinate.

    Args:
        body: loop body node; its ``parent`` is set to this loop
        coordinate_to_loop_over: coordinate index, formatted into the loop counter name
        start: first counter value (int or sympy expression)
        stop: exclusive upper bound (int or sympy expression)
        step: increment (int or sympy expression)
        is_block_loop: if True the counter is named with ``BlOCK_LOOP_COUNTER_NAME_PREFIX`` instead of
                       ``LOOP_COUNTER_NAME_PREFIX``
    """
    LOOP_COUNTER_NAME_PREFIX = "ctr"
    BlOCK_LOOP_COUNTER_NAME_PREFIX = "_blockctr"

    def __init__(self, body, coordinate_to_loop_over, start, stop, step=1, is_block_loop=False):
        super(LoopOverCoordinate, self).__init__(parent=None)
        self.body = body
        body.parent = self
        self.coordinate_to_loop_over = coordinate_to_loop_over
        self.start = start
        self.stop = stop
        self.step = step
        self.prefix_lines = []
        self.is_block_loop = is_block_loop

    def new_loop_with_different_body(self, new_body):
        """Copy of this loop (same bounds, counter and prefix lines) around ``new_body``."""
        result = LoopOverCoordinate(new_body, self.coordinate_to_loop_over, self.start, self.stop,
                                    self.step, self.is_block_loop)
        result.prefix_lines = list(self.prefix_lines)
        return result

    def subs(self, subs_dict):
        """Substitute in the body (in place) and in any bound that supports ``subs``."""
        self.body.subs(subs_dict)
        if hasattr(self.start, "subs"):
            self.start = self.start.subs(subs_dict)
        if hasattr(self.stop, "subs"):
            self.stop = self.stop.subs(subs_dict)
        if hasattr(self.step, "subs"):
            self.step = self.step.subs(subs_dict)

    @property
    def args(self):
        """The body, followed by those of start/stop/step that have ``args`` (i.e. are not plain ints)."""
        result = [self.body]
        for e in [self.start, self.stop, self.step]:
            if hasattr(e, "args"):
                result.append(e)
        return result

    def replace(self, child, replacement):
        """Replace the body or one of start/stop/step by ``replacement``; unknown children are ignored."""
        if child == self.body:
            # Keep the parent link consistent with __init__ when the body is swapped out.
            replacement.parent = self
            self.body = replacement
        elif child == self.start:
            self.start = replacement
        elif child == self.step:
            self.step = replacement
        elif child == self.stop:
            self.stop = replacement

    @property
    def symbols_defined(self):
        return {self.loop_counter_symbol}

    @property
    def undefined_symbols(self):
        result = self.body.undefined_symbols
        for possible_symbol in [self.start, self.stop, self.step]:
            if isinstance(possible_symbol, (Node, sp.Basic)):
                result.update(possible_symbol.atoms(sp.Symbol))
        return result - {self.loop_counter_symbol}

    @staticmethod
    def get_loop_counter_name(coordinate_to_loop_over):
        return "%s_%s" % (LoopOverCoordinate.LOOP_COUNTER_NAME_PREFIX, coordinate_to_loop_over)

    @staticmethod
    def get_block_loop_counter_name(coordinate_to_loop_over):
        return "%s_%s" % (LoopOverCoordinate.BlOCK_LOOP_COUNTER_NAME_PREFIX, coordinate_to_loop_over)

    @property
    def loop_counter_name(self):
        if self.is_block_loop:
            return LoopOverCoordinate.get_block_loop_counter_name(self.coordinate_to_loop_over)
        else:
            return LoopOverCoordinate.get_loop_counter_name(self.coordinate_to_loop_over)

    @staticmethod
    def is_loop_counter_symbol(symbol):
        """Return the coordinate of an int-typed ``ctr_<n>`` loop counter symbol, otherwise None.

        Names that merely start with the prefix (e.g. ``ctrl_0``) or whose suffix is not an integer
        (e.g. ``ctr_x``) yield None instead of raising ValueError.
        """
        prefix = LoopOverCoordinate.LOOP_COUNTER_NAME_PREFIX + "_"
        if not symbol.name.startswith(prefix):
            return None
        if symbol.dtype != create_type('int'):
            return None
        try:
            return int(symbol.name[len(prefix):])
        except ValueError:
            return None

    @staticmethod
    def get_loop_counter_symbol(coordinate_to_loop_over):
        return TypedSymbol(LoopOverCoordinate.get_loop_counter_name(coordinate_to_loop_over), 'int')

    @staticmethod
    def get_block_loop_counter_symbol(coordinate_to_loop_over):
        return TypedSymbol(LoopOverCoordinate.get_block_loop_counter_name(coordinate_to_loop_over), 'int')

    @property
    def loop_counter_symbol(self):
        if self.is_block_loop:
            return self.get_block_loop_counter_symbol(self.coordinate_to_loop_over)
        else:
            return self.get_loop_counter_symbol(self.coordinate_to_loop_over)

    @property
    def is_outermost_loop(self):
        """True if no enclosing LoopOverCoordinate exists."""
        from pystencils.transformations import get_next_parent_of_type
        return get_next_parent_of_type(self, LoopOverCoordinate) is None

    @property
    def is_innermost_loop(self):
        """True if no LoopOverCoordinate occurs among the descendants."""
        return len(self.atoms(LoopOverCoordinate)) == 0

    def __str__(self):
        return 'for({!s}={!s}; {!s}<{!s}; {!s}+={!s})\n{!s}'.format(self.loop_counter_name, self.start,
                                                                    self.loop_counter_name, self.stop,
                                                                    self.loop_counter_name, self.step,
                                                                    ("\t" + "\t".join(str(self.body).splitlines(True))))

    def __repr__(self):
        return 'for({!s}={!s}; {!s}<{!s}; {!s}+={!s})'.format(self.loop_counter_name, self.start,
                                                              self.loop_counter_name, self.stop,
                                                              self.loop_counter_name, self.step)


class SympyAssignment(Node):
    """Assignment ``lhs ← rhs`` of sympy expressions.

    The assignment counts as a declaration (it defines ``lhs``) unless the left-hand side is a
    ``cast_func``, ``Field.Access``, ``sp.Indexed`` or ``TemporaryMemoryAllocation``.

    Args:
        lhs_symbol: target of the assignment
        rhs_expr: assigned expression
        is_const: flag exposed via the ``is_const`` property
    """

    def __init__(self, lhs_symbol, rhs_expr, is_const=True):
        super(SympyAssignment, self).__init__(parent=None)
        self._lhs_symbol = lhs_symbol
        self.rhs = rhs_expr
        self._is_const = is_const
        self._is_declaration = self.__is_declaration()

    def __is_declaration(self):
        return not isinstance(self._lhs_symbol,
                              (cast_func, Field.Access, sp.Indexed, TemporaryMemoryAllocation))

    @property
    def lhs(self):
        return self._lhs_symbol

    @lhs.setter
    def lhs(self, new_value):
        self._lhs_symbol = new_value
        # declaration status depends on the lhs type, so recompute it
        self._is_declaration = self.__is_declaration()

    def subs(self, subs_dict):
        self.lhs = fast_subs(self.lhs, subs_dict)
        self.rhs = fast_subs(self.rhs, subs_dict)

    @property
    def args(self):
        return [self._lhs_symbol, self.rhs]

    @property
    def symbols_defined(self):
        if not self._is_declaration:
            return set()
        return {self._lhs_symbol}

    @property
    def undefined_symbols(self):
        result = self.rhs.atoms(sp.Symbol)
        # Add loop counters if there are field accesses: one counter per offset dimension
        loop_counters = set()
        for symbol in result:
            if isinstance(symbol, Field.Access):
                for i in range(len(symbol.offsets)):
                    loop_counters.add(LoopOverCoordinate.get_loop_counter_symbol(i))
        result.update(loop_counters)
        result.update(self._lhs_symbol.atoms(sp.Symbol))
        return result

    @property
    def is_declaration(self):
        return self._is_declaration

    @property
    def is_const(self):
        return self._is_const

    def replace(self, child, replacement):
        """Replace the lhs or rhs equal to ``child`` by ``replacement``.

        Raises:
            ValueError: if ``child`` is neither the lhs nor the rhs
        """
        if child == self.lhs:
            replacement.parent = self
            self.lhs = replacement
        elif child == self.rhs:
            replacement.parent = self
            self.rhs = replacement
        else:
            # report the child that could not be found, not the replacement
            raise ValueError('%s is not in args of %s' % (child, self.__class__))

    def __repr__(self):
        return repr(self.lhs) + " ← " + repr(self.rhs)

    def _repr_html_(self):
        printed_lhs = sp.latex(self.lhs)
        printed_rhs = sp.latex(self.rhs)
        return "${printed_lhs} \\leftarrow {printed_rhs}$".format(printed_lhs=printed_lhs, printed_rhs=printed_rhs)


class ResolvedFieldAccess(sp.Indexed):
    """``sp.Indexed`` with a single linearized index that remembers which field access it came from.

    Args (of ``__new__``):
        base: an ``sp.IndexedBase``, or anything else, which is then wrapped as ``IndexedBase(base, shape=(1,))``
        linearized_index: the single index expression
        field: originating field, stored as ``self.field``
        offsets: stored as ``self.offsets``
        idx_coordinate_values: stored as ``self.idx_coordinate_values``
    """

    def __new__(cls, base, linearized_index, field, offsets, idx_coordinate_values):
        indexed_base = base if isinstance(base, sp.IndexedBase) else sp.IndexedBase(base, shape=(1,))
        obj = super(ResolvedFieldAccess, cls).__new__(cls, indexed_base, linearized_index)
        obj.field, obj.offsets, obj.idx_coordinate_values = field, offsets, idx_coordinate_values
        return obj

    def _eval_subs(self, old, new):
        # sympy hook: only the linearized index is substituted, base and metadata are carried over.
        base, index = self.args[0], self.args[1]
        return ResolvedFieldAccess(base, index.subs(old, new),
                                   self.field, self.offsets, self.idx_coordinate_values)

    def fast_subs(self, substitutions):
        """Return ``substitutions[self]`` if present, else a copy with base and index substituted."""
        if self in substitutions:
            return substitutions[self]
        new_base = self.args[0].subs(substitutions)
        new_index = self.args[1].subs(substitutions)
        return ResolvedFieldAccess(new_base, new_index, self.field, self.offsets, self.idx_coordinate_values)

    def _hashable_content(self):
        # Extend sympy's hash/equality content with the extra metadata.
        base_content = super(ResolvedFieldAccess, self)._hashable_content()
        return base_content + tuple(self.offsets) + (repr(self.idx_coordinate_values), hash(self.field))

    @property
    def typed_symbol(self):
        """Label of the indexed base."""
        return self.base.label

    def __str__(self):
        plain = super(ResolvedFieldAccess, self).__str__()
        return "%s (%s)" % (plain, self.typed_symbol.dtype)

    def __getnewargs__(self):
        # Arguments for re-creating the object via __new__ (e.g. when unpickling).
        return (self.base, self.indices[0], self.field, self.offsets, self.idx_coordinate_values)


class TemporaryMemoryAllocation(Node):
    """Node for temporary memory buffer allocation.

    Always allocates aligned memory.

    Args:
        typed_symbol: symbol used as pointer (has to be typed)
        size: number of elements to allocate (int or sympy expression)
        align_offset: the align_offset's element is aligned
    """

    def __init__(self, typed_symbol: TypedSymbol, size, align_offset):
        super(TemporaryMemoryAllocation, self).__init__(parent=None)
        self.symbol = typed_symbol
        self.size = size
        self.headers = ['<stdlib.h>']
        self._align_offset = align_offset

    @property
    def symbols_defined(self):
        return {self.symbol}

    @property
    def undefined_symbols(self):
        # A symbolic size may reference symbols defined elsewhere.
        if isinstance(self.size, sp.Basic):
            return self.size.atoms(sp.Symbol)
        else:
            return set()

    @property
    def args(self):
        return [self.symbol]

    def offset(self, byte_alignment):
        """Number of ELEMENTS to skip for a pointer that is aligned to byte_alignment.

        ``byte_alignment`` must be a multiple of the element size (asserted). The result is an integer
        element count; integer division is used so it is not returned as a float.
        """
        np_dtype = self.symbol.dtype.base_type.numpy_dtype
        assert byte_alignment % np_dtype.itemsize == 0
        elements_per_alignment = byte_alignment // np_dtype.itemsize
        return -self._align_offset % elements_per_alignment


class TemporaryMemoryFree(Node):
    """Counterpart of a ``TemporaryMemoryAllocation`` node; pointer symbol and offset are taken from it.

    Args:
        alloc_node: the allocation node whose buffer this node refers to
    """

    def __init__(self, alloc_node):
        super(TemporaryMemoryFree, self).__init__(parent=None)
        self.alloc_node = alloc_node

    @property
    def symbol(self):
        """Pointer symbol of the associated allocation node."""
        return self.alloc_node.symbol

    def offset(self, byte_alignment):
        """Same element offset as the associated allocation node."""
        return self.alloc_node.offset(byte_alignment)

    @property
    def symbols_defined(self):
        return set()

    @property
    def undefined_symbols(self):
        return set()

    @property
    def args(self):
        return list()


def early_out(condition):
    """Return a ``Conditional`` on ``vec_all(condition)`` whose only statement is a ``SkipIteration``."""
    from pystencils.cpu.vectorization import vec_all
    guard = vec_all(condition)
    skip_body = Block([SkipIteration()])
    return Conditional(guard, skip_body)


class DestructuringBindingsForFieldClass(Node):
    """
    Defines all variables needed for describing a field (shape, pointer, strides)

    The field-related symbols (pointer/shape/stride) left undefined by ``body`` are treated as defined
    here; in exchange, one ``PyStencilsField<dtype, ndim>&`` symbol per corresponding field becomes undefined.

    Args:
        body: wrapped node; its ``parent`` is set to this node
    """
    CLASS_TO_MEMBER_DICT = {
        FieldPointerSymbol: "data",
        FieldShapeSymbol: "shape[%i]",
        FieldStrideSymbol: "stride[%i]"
    }
    CLASS_NAME_TEMPLATE = "PyStencilsField<{dtype}, {ndim}>"

    @property
    def fields_accessed(self) -> Set[Field]:
        """Set of Field instances: fields which are accessed inside this kernel function"""
        return set(o.field for o in self.atoms(ResolvedFieldAccess))

    def __init__(self, body):
        super(DestructuringBindingsForFieldClass, self).__init__()
        self.headers = ['<PyStencilsField.h>']
        self.body = body
        # Link the body back so parent-walking code can find this node.
        body.parent = self

    @property
    def args(self) -> List[NodeOrExpr]:
        """Returns all arguments/children of this node: the wrapped body.

        Previously an empty set was returned, which hid the body from generic AST traversal.
        """
        return [self.body]

    @property
    def symbols_defined(self) -> Set[sp.Symbol]:
        """Set of symbols which are defined by this node."""
        undefined_field_symbols = {s for s in self.body.undefined_symbols
                                   if isinstance(s, (FieldPointerSymbol, FieldShapeSymbol, FieldStrideSymbol))}
        return undefined_field_symbols

    @property
    def undefined_symbols(self) -> Set[sp.Symbol]:
        """Field-class reference symbols plus the body's non-field undefined symbols."""
        field_map = {f.name: f for f in self.fields_accessed}
        undefined_field_symbols = self.symbols_defined
        corresponding_field_names = {s.field_name for s in undefined_field_symbols if hasattr(s, 'field_name')}
        # Symbols shared by several fields only contribute their first field.
        corresponding_field_names |= {s.field_names[0] for s in undefined_field_symbols if hasattr(s, 'field_names')}
        return {TypedSymbol(f, self.CLASS_NAME_TEMPLATE.format(dtype=field_map[f].dtype, ndim=field_map[f].ndim) + '&')
                for f in corresponding_field_names} | \
            (self.body.undefined_symbols - undefined_field_symbols)

    def subs(self, subs_dict) -> None:
        """Inplace! substitute, similar to sympy's but modifies the AST inplace."""
        self.body.subs(subs_dict)

    @property
    def func(self):
        return self.__class__

    def atoms(self, arg_type) -> Set[Any]:
        return self.body.atoms(arg_type) | {s for s in self.symbols_defined if isinstance(s, arg_type)}


def get_dummy_symbol(dtype='bool'):
    """Return a new ``TypedSymbol`` named ``dummy<uuid4 hex>`` with type ``create_type(dtype)``.

    The random uuid4 suffix makes name clashes between generated dummies practically impossible.
    """
    name = 'dummy%s' % uuid.uuid4().hex
    return TypedSymbol(name, create_type(dtype))