astnodes.py 25.2 KB
Newer Older
Stephan Seitz's avatar
Stephan Seitz committed
1
import uuid
2
3
from typing import Any, List, Optional, Sequence, Set, Union

4
import sympy as sp
5
6

from pystencils.data_types import TypedSymbol, cast_func, create_type
7
from pystencils.field import Field
Martin Bauer's avatar
Martin Bauer committed
8
from pystencils.kernelparameters import FieldPointerSymbol, FieldShapeSymbol, FieldStrideSymbol
Martin Bauer's avatar
Martin Bauer committed
9
from pystencils.sympyextensions import fast_subs
Martin Bauer's avatar
Martin Bauer committed
10
11

# Anything that may appear among an AST node's ``args``: another AST node or a sympy expression.
NodeOrExpr = Union["Node", sp.Expr]
12
13


14
class Node:
    """Base class for all AST nodes.

    Subclasses provide their children via ``args`` and report symbol usage via
    ``symbols_defined`` / ``undefined_symbols``; the base versions of these three
    properties raise ``NotImplementedError``.
    """

    def __init__(self, parent: Optional['Node'] = None):
        # Back-reference to the enclosing node; None while the node is not attached.
        self.parent = parent

    @property
    def args(self) -> List[NodeOrExpr]:
        """All arguments/children of this node (AST nodes or sympy expressions)."""
        raise NotImplementedError()

    @property
    def symbols_defined(self) -> Set[sp.Symbol]:
        """Set of symbols which are defined by this node."""
        raise NotImplementedError()

    @property
    def undefined_symbols(self) -> Set[sp.Symbol]:
        """Symbols which are used but are not defined inside this node."""
        raise NotImplementedError()

    def subs(self, subs_dict) -> None:
        """In-place substitution: like sympy's ``subs``, but modifies this AST instead of returning a copy.

        The default implementation forwards ``subs_dict`` to the ``subs`` of every child in ``args``.
        """
        for child in self.args:
            child.subs(subs_dict)

    @property
    def func(self):
        # Mirrors sympy's ``Basic.func``: the class of this node.
        return self.__class__

    def atoms(self, arg_type) -> Set[Any]:
        """Return every descendant (searched recursively through ``args``) that is an instance of ``arg_type``."""
        matches = set()
        for child in self.args:
            if isinstance(child, arg_type):
                matches.add(child)
            matches.update(child.atoms(arg_type))
        return matches


54
class Conditional(Node):
    """Conditional that maps to an 'if' statement in C/C++.

    Try to avoid using this node inside of loops, since currently this construction can not be vectorized.
    Consider using assignments with sympy.Piecewise in this case.

    Args:
        condition_expr: sympy relational expression
        true_block: block which is run if conditional is true; a non-Block node is wrapped into a Block
        false_block: optional block which is run if conditional is false; wrapped the same way
    """

    def __init__(self, condition_expr: sp.Basic, true_block: Union['Block', 'SympyAssignment'],
                 false_block: Optional['Block'] = None) -> None:
        super(Conditional, self).__init__(parent=None)
        self.condition_expr = condition_expr
        self.true_block = self._adopt_as_block(true_block)
        self.false_block = self._adopt_as_block(false_block)

    def _adopt_as_block(self, child):
        """Return ``child`` as a Block whose parent is this node; ``None`` is passed through."""
        if child is None:
            return None
        block = child if isinstance(child, Block) else Block([child])
        block.parent = self
        return block

    def subs(self, subs_dict):
        # Branches are substituted in place; the sympy condition is replaced by its substituted copy.
        self.true_block.subs(subs_dict)
        if self.false_block:
            self.false_block.subs(subs_dict)
        self.condition_expr = self.condition_expr.subs(subs_dict)

    @property
    def args(self):
        optional_else = [self.false_block] if self.false_block else []
        return [self.condition_expr, self.true_block] + optional_else

    @property
    def symbols_defined(self):
        return set()

    @property
    def undefined_symbols(self):
        # symbols needed by either branch plus every symbol occurring in the condition
        branches = [self.true_block] + ([self.false_block] if self.false_block else [])
        used = set().union(*(branch.undefined_symbols for branch in branches))
        return used.union(self.condition_expr.atoms(sp.Symbol))

    def __str__(self):
        return 'if:({!s}) '.format(self.condition_expr)

    def __repr__(self):
        return 'if:({!r}) '.format(self.condition_expr)

    def replace_by_true_block(self):
        """Replaces the conditional by its True block"""
        self.parent.replace(self, [self.true_block])

    def replace_by_false_block(self):
        """Replaces the conditional by its False block (removes it if there is no False block)"""
        replacement = [self.false_block] if self.false_block else []
        self.parent.replace(self, replacement)

122

123
124
class KernelFunction(Node):
    """Root node of a kernel AST: a named function whose statements live in ``body``.

    Args:
        body: node containing the function's code; this node becomes its parent
        target: stored and exposed via the ``target`` property
        backend: stored and exposed via the ``backend`` property
        compile_function: callable invoked as ``compile_function(self, *args, **kwargs)`` by ``compile``;
                          may be None, in which case ``compile`` raises ValueError
        ghost_layers: stored unchanged as ``ghost_layers``
        function_name: name of the generated function
    """

    class Parameter:
        """Function parameter.

        Each undefined symbol in a `KernelFunction` node becomes a parameter to the function.
        Parameters are either symbols introduced by the user that never occur on the left hand side of an
        Assignment, or are related to fields/arrays passed to the function.

        A parameter consists of the typed symbol (symbol property). For field related parameters this is a symbol
        defined in pystencils.kernelparameters.
        If the parameter is related to one or multiple fields, these fields are referenced in the fields property.
        """

        def __init__(self, symbol, fields):
            self.symbol = symbol  # type: TypedSymbol
            self.fields = fields  # type: Sequence[Field]

        def __repr__(self):
            return repr(self.symbol)

        @property
        def is_field_stride(self):
            return isinstance(self.symbol, FieldStrideSymbol)

        @property
        def is_field_shape(self):
            return isinstance(self.symbol, FieldShapeSymbol)

        @property
        def is_field_pointer(self):
            return isinstance(self.symbol, FieldPointerSymbol)

        @property
        def is_field_parameter(self):
            return self.is_field_pointer or self.is_field_shape or self.is_field_stride

        @property
        def field_name(self):
            # name of the first related field; raises IndexError when ``fields`` is empty
            return self.fields[0].name

    def __init__(self, body, target, backend, compile_function, ghost_layers, function_name="kernel"):
        super(KernelFunction, self).__init__()
        self._body = body
        body.parent = self
        self.function_name = function_name
        self.ghost_layers = ghost_layers
        self._target = target
        self._backend = backend
        # these variables are assumed to be global, so no automatic parameter is generated for them
        self.global_variables = set()
        self.instruction_set = None  # used in `vectorize` function to tell the backend which i.s. (SSE,AVX) to use
        # function that compiles the node to a Python callable, is set by the backends
        self._compile_function = compile_function

    @property
    def target(self):
        """Currently either 'cpu' or 'gpu' """
        return self._target

    @property
    def backend(self):
        """Backend for generating the code e.g. 'llvm', 'c', 'cuda' """
        return self._backend

    @property
    def symbols_defined(self):
        return set()

    @property
    def undefined_symbols(self):
        return set()

    @property
    def body(self):
        return self._body

    @body.setter
    def body(self, value):
        self._body = value
        self._body.parent = self

    @property
    def args(self):
        return self._body,

    @property
    def fields_accessed(self) -> Set[Field]:
        """Set of Field instances: fields which are accessed inside this kernel function"""
        return {o.field for o in self.atoms(ResolvedFieldAccess)}

    def fields_written(self):
        """Set of Field instances whose resolved accesses appear on the lhs of an assignment in this kernel."""
        assignments = self.atoms(SympyAssignment)
        return {a.lhs.field for a in assignments if isinstance(a.lhs, ResolvedFieldAccess)}

    def get_parameters(self) -> Sequence['KernelFunction.Parameter']:
        """Returns list of parameters for this function, sorted by symbol name.

        Every undefined symbol of the body that is not in ``global_variables`` becomes a parameter.
        Symbols carrying ``field_name`` / ``field_names`` are linked to the matching accessed fields.

        This function is expensive, cache the result where possible!
        """
        field_map = {f.name: f for f in self.fields_accessed}

        def get_fields(symbol):
            if hasattr(symbol, 'field_name'):
                return field_map[symbol.field_name],
            elif hasattr(symbol, 'field_names'):
                return tuple(field_map[fn] for fn in symbol.field_names)
            return ()

        argument_symbols = self._body.undefined_symbols - self.global_variables
        parameters = [self.Parameter(symbol, get_fields(symbol)) for symbol in argument_symbols]
        # ``indexing`` is never assigned in this class.
        # NOTE(review): presumably attached by the GPU backend — confirm.
        if hasattr(self, 'indexing'):
            parameters += [self.Parameter(s, []) for s in self.indexing.symbolic_parameters()]
        parameters.sort(key=lambda p: p.symbol.name)
        return parameters

    def __str__(self):
        params = [p.symbol for p in self.get_parameters()]
        return '{0} {1}({2})\n{3}'.format(type(self).__name__, self.function_name, params,
                                          ("\t" + "\t".join(str(self.body).splitlines(True))))

    def __repr__(self):
        params = [p.symbol for p in self.get_parameters()]
        return '{0} {1}({2})'.format(type(self).__name__, self.function_name, params)

    def compile(self, *args, **kwargs):
        """Compile this kernel using the ``compile_function`` passed to the constructor."""
        if self._compile_function is None:
            raise ValueError("No compile-function provided for this KernelFunction node")
        return self._compile_function(self, *args, **kwargs)

254

Martin Bauer's avatar
Martin Bauer committed
255
256
257
258
259
260
261
262
263
264
265
266
267
268
class SkipIteration(Node):
    """Leaf node without children that neither defines nor uses any symbol.

    NOTE(review): presumably printed as a ``continue``-like statement by the backends — confirm.
    """

    @property
    def args(self):
        return []

    @property
    def symbols_defined(self):
        return set()

    # Same property object: a fresh empty set is returned on every access.
    undefined_symbols = symbols_defined


269
class Block(Node):
    """Ordered sequence of child nodes; every child's ``parent`` is set to the block."""

    def __init__(self, nodes: List[Node]):
        super(Block, self).__init__()  # also initializes ``parent`` to None
        # the given list object is stored without copying
        self._nodes = nodes
        for n in self._nodes:
            n.parent = self

    @property
    def args(self):
        return self._nodes

    def subs(self, subs_dict) -> None:
        for a in self.args:
            a.subs(subs_dict)

    def insert_front(self, node):
        """Insert ``node`` as the first child."""
        node.parent = self
        self._nodes.insert(0, node)

    def insert_before(self, new_node, insert_before):
        """Insert ``new_node`` in front of the existing child ``insert_before``.

        Declarations (``SympyAssignment`` with ``is_declaration``) are additionally moved up past
        any loops / conditionals that directly precede the insertion point.

        Raises:
            ValueError: if ``insert_before`` is not a child of this block
        """
        new_node.parent = self
        idx = self._nodes.index(insert_before)

        if isinstance(new_node, SympyAssignment) and new_node.is_declaration:
            while idx > 0 and isinstance(self._nodes[idx - 1], (LoopOverCoordinate, Conditional)):
                idx -= 1
        self._nodes.insert(idx, new_node)

    def append(self, node):
        """Append a single node, or every node of a list/tuple, as children."""
        if isinstance(node, (list, tuple)):
            for n in node:
                n.parent = self
                self._nodes.append(n)
        else:
            node.parent = self
            self._nodes.append(node)

    def take_child_nodes(self):
        """Remove and return all children, leaving this block empty."""
        tmp = self._nodes
        self._nodes = []
        return tmp

    def replace(self, child, replacements):
        """Replace ``child`` by a single node, or splice in a list/tuple of nodes at its position.

        Raises:
            ValueError: if ``child`` is not a child of this block
        """
        idx = self._nodes.index(child)
        del self._nodes[idx]
        # tuples are spliced like lists (consistent with ``append``); previously a tuple was
        # treated as a single node and failed on ``tuple.parent = ...``
        if isinstance(replacements, (list, tuple)):
            for e in replacements:
                e.parent = self
            self._nodes = self._nodes[:idx] + list(replacements) + self._nodes[idx:]
        else:
            replacements.parent = self
            self._nodes.insert(idx, replacements)

    @property
    def symbols_defined(self):
        result = set()
        for a in self.args:
            result.update(a.symbols_defined)
        return result

    @property
    def undefined_symbols(self):
        # symbols used by any child minus symbols defined by any child of this block
        result = set()
        defined_symbols = set()
        for a in self.args:
            result.update(a.undefined_symbols)
            defined_symbols.update(a.symbols_defined)
        return result - defined_symbols

    def __str__(self):
        return "Block " + ''.join('{!s}\n'.format(node) for node in self._nodes)

    def __repr__(self):
        return "Block"
349

350
351

class PragmaBlock(Block):
    """Block that additionally carries a pragma line (stored verbatim in ``pragma_line``).

    NOTE(review): presumably the backends print ``pragma_line`` in front of the block — confirm.
    """

    def __init__(self, pragma_line, nodes):
        # Block.__init__ already makes this very object the parent of every node in ``nodes``.
        super(PragmaBlock, self).__init__(nodes)
        self.pragma_line = pragma_line

    def __repr__(self):
        return self.pragma_line
360
361
362
363


class LoopOverCoordinate(Node):
    """Loop over a single coordinate, printed as ``for(ctr=start; ctr<stop; ctr+=step)`` followed by the body.

    Args:
        body: loop body node; this loop becomes its parent
        coordinate_to_loop_over: coordinate whose value is embedded in the loop counter name
            (``ctr_<coordinate>``, or ``_blockctr_<coordinate>`` for block loops)
        start, stop, step: loop bounds and increment (plain numbers or sympy expressions)
        is_block_loop: if True the counter uses ``BlOCK_LOOP_COUNTER_NAME_PREFIX``
            instead of ``LOOP_COUNTER_NAME_PREFIX``
    """
    LOOP_COUNTER_NAME_PREFIX = "ctr"
    BlOCK_LOOP_COUNTER_NAME_PREFIX = "_blockctr"

    def __init__(self, body, coordinate_to_loop_over, start, stop, step=1, is_block_loop=False):
        super(LoopOverCoordinate, self).__init__(parent=None)
        self.body = body
        body.parent = self
        self.coordinate_to_loop_over = coordinate_to_loop_over
        self.start = start
        self.stop = stop
        self.step = step
        # copied by new_loop_with_different_body
        # NOTE(review): presumably emitted in front of the loop by the backends — confirm
        self.prefix_lines = []
        self.is_block_loop = is_block_loop

    def new_loop_with_different_body(self, new_body):
        """Return a copy of this loop (bounds, block flag, prefix lines) around ``new_body``."""
        result = LoopOverCoordinate(new_body, self.coordinate_to_loop_over, self.start, self.stop,
                                    self.step, self.is_block_loop)
        result.prefix_lines = list(self.prefix_lines)
        return result

    def subs(self, subs_dict):
        self.body.subs(subs_dict)
        # bounds may be plain numbers which have no ``subs``
        if hasattr(self.start, "subs"):
            self.start = self.start.subs(subs_dict)
        if hasattr(self.stop, "subs"):
            self.stop = self.stop.subs(subs_dict)
        if hasattr(self.step, "subs"):
            self.step = self.step.subs(subs_dict)

    @property
    def args(self):
        result = [self.body]
        for e in [self.start, self.stop, self.step]:
            if hasattr(e, "args"):
                result.append(e)
        return result

    def replace(self, child, replacement):
        """Replace the body or one of the bounds; does nothing if ``child`` is none of them."""
        if child == self.body:
            self.body = replacement
            # keep parent links consistent, as Block.replace / SympyAssignment.replace do
            replacement.parent = self
        elif child == self.start:
            self.start = replacement
        elif child == self.step:
            self.step = replacement
        elif child == self.stop:
            self.stop = replacement

    @property
    def symbols_defined(self):
        return {self.loop_counter_symbol}

    @property
    def undefined_symbols(self):
        result = self.body.undefined_symbols
        for possible_symbol in [self.start, self.stop, self.step]:
            if isinstance(possible_symbol, (Node, sp.Basic)):
                result.update(possible_symbol.atoms(sp.Symbol))
        return result - {self.loop_counter_symbol}

    @staticmethod
    def get_loop_counter_name(coordinate_to_loop_over):
        return "%s_%s" % (LoopOverCoordinate.LOOP_COUNTER_NAME_PREFIX, coordinate_to_loop_over)

    @staticmethod
    def get_block_loop_counter_name(coordinate_to_loop_over):
        return "%s_%s" % (LoopOverCoordinate.BlOCK_LOOP_COUNTER_NAME_PREFIX, coordinate_to_loop_over)

    @property
    def loop_counter_name(self):
        if self.is_block_loop:
            return LoopOverCoordinate.get_block_loop_counter_name(self.coordinate_to_loop_over)
        else:
            return LoopOverCoordinate.get_loop_counter_name(self.coordinate_to_loop_over)

    @staticmethod
    def is_loop_counter_symbol(symbol):
        """Return the coordinate of a (non-block) loop counter symbol ``ctr_<int>``, or None otherwise."""
        prefix = LoopOverCoordinate.LOOP_COUNTER_NAME_PREFIX
        if not symbol.name.startswith(prefix):
            return None
        # untyped symbols cannot be loop counters (which are created as 'int' TypedSymbols)
        dtype = getattr(symbol, 'dtype', None)
        if dtype is None or dtype != create_type('int'):
            return None
        try:
            return int(symbol.name[len(prefix) + 1:])
        except ValueError:
            # shares the prefix but carries no integer coordinate, e.g. 'ctrl' or 'ctr_x'
            return None

    @staticmethod
    def get_loop_counter_symbol(coordinate_to_loop_over):
        return TypedSymbol(LoopOverCoordinate.get_loop_counter_name(coordinate_to_loop_over), 'int')

    @staticmethod
    def get_block_loop_counter_symbol(coordinate_to_loop_over):
        return TypedSymbol(LoopOverCoordinate.get_block_loop_counter_name(coordinate_to_loop_over), 'int')

    @property
    def loop_counter_symbol(self):
        if self.is_block_loop:
            return self.get_block_loop_counter_symbol(self.coordinate_to_loop_over)
        else:
            return self.get_loop_counter_symbol(self.coordinate_to_loop_over)

    @property
    def is_outermost_loop(self):
        from pystencils.transformations import get_next_parent_of_type
        return get_next_parent_of_type(self, LoopOverCoordinate) is None

    @property
    def is_innermost_loop(self):
        return len(self.atoms(LoopOverCoordinate)) == 0

    def __str__(self):
        return 'for({!s}={!s}; {!s}<{!s}; {!s}+={!s})\n{!s}'.format(self.loop_counter_name, self.start,
                                                                    self.loop_counter_name, self.stop,
                                                                    self.loop_counter_name, self.step,
                                                                    ("\t" + "\t".join(str(self.body).splitlines(True))))

    def __repr__(self):
        return 'for({!s}={!s}; {!s}<{!s}; {!s}+={!s})'.format(self.loop_counter_name, self.start,
                                                              self.loop_counter_name, self.stop,
                                                              self.loop_counter_name, self.step)
482

483
484

class SympyAssignment(Node):
    """Assignment of a sympy expression ``rhs`` to ``lhs`` (represented as ``lhs ← rhs``).

    Args:
        lhs_symbol: assignment target, e.g. a symbol, field access, indexed element or cast
        rhs_expr: sympy expression that is assigned
        is_const: stored and exposed via ``is_const``
            (NOTE(review): presumably marks declarations as const in generated code — confirm)
    """

    def __init__(self, lhs_symbol, rhs_expr, is_const=True):
        super(SympyAssignment, self).__init__(parent=None)
        self._lhs_symbol = lhs_symbol
        self.rhs = rhs_expr
        self._is_const = is_const
        self._is_declaration = self.__is_declaration()

    def __is_declaration(self):
        # An lhs that is a cast, field access, indexed element or temporary allocation is not a
        # declaration; any other lhs is treated as one.
        return not isinstance(self._lhs_symbol,
                              (cast_func, Field.Access, sp.Indexed, TemporaryMemoryAllocation))

    @property
    def lhs(self):
        return self._lhs_symbol

    @lhs.setter
    def lhs(self, new_value):
        self._lhs_symbol = new_value
        self._is_declaration = self.__is_declaration()

    def subs(self, subs_dict):
        self.lhs = fast_subs(self.lhs, subs_dict)
        self.rhs = fast_subs(self.rhs, subs_dict)

    def optimize(self, optimizations):
        """Apply ``sympy.codegen.rewriting.optimize`` to the rhs on a best-effort basis.

        Any failure, including a sympy version without ``sympy.codegen.rewriting``, is
        deliberately ignored and leaves the rhs unchanged.
        """
        try:
            from sympy.codegen.rewriting import optimize
            self.rhs = optimize(self.rhs, optimizations)
        except Exception:
            pass

    @property
    def args(self):
        return [self._lhs_symbol, self.rhs]

    @property
    def symbols_defined(self):
        if not self._is_declaration:
            return set()
        return {self._lhs_symbol}

    @property
    def undefined_symbols(self):
        result = {s for s in self.rhs.free_symbols if not isinstance(s, sp.Indexed)}
        # Add loop counters if there a field accesses: one counter per offset dimension of the access
        loop_counters = set()
        for symbol in result:
            if isinstance(symbol, Field.Access):
                for i in range(len(symbol.offsets)):
                    loop_counters.add(LoopOverCoordinate.get_loop_counter_symbol(i))
        result.update(loop_counters)
        result.update(self._lhs_symbol.atoms(sp.Symbol))
        return result

    @property
    def is_declaration(self):
        return self._is_declaration

    @property
    def is_const(self):
        return self._is_const

    def replace(self, child, replacement):
        """Replace ``child`` (which must be the lhs or the rhs) by ``replacement``.

        Raises:
            ValueError: if ``child`` is neither the lhs nor the rhs
        """
        if child == self.lhs:
            replacement.parent = self
            self.lhs = replacement
        elif child == self.rhs:
            replacement.parent = self
            self.rhs = replacement
        else:
            # report the child that was looked up, not the replacement
            raise ValueError('%s is not in args of %s' % (child, self.__class__))

    def __repr__(self):
        return repr(self.lhs) + " ← " + repr(self.rhs)

    def _repr_html_(self):
        printed_lhs = sp.latex(self.lhs)
        printed_rhs = sp.latex(self.rhs)
        return "${printed_lhs} \\leftarrow {printed_rhs}$".format(printed_lhs=printed_lhs, printed_rhs=printed_rhs)
Martin Bauer's avatar
Martin Bauer committed
567

568

Martin Bauer's avatar
Martin Bauer committed
569
class ResolvedFieldAccess(sp.Indexed):
    """Field access that has been resolved to one linearized index into a base pointer.

    Besides the indexed expression it stores the originating ``field``, the access ``offsets``
    and ``idx_coordinate_values``; all three are included in ``_hashable_content``.
    """

    def __new__(cls, base, linearized_index, field, offsets, idx_coordinate_values):
        # a bare base (e.g. a pointer symbol) is wrapped into an IndexedBase with shape (1,)
        indexed_base = base if isinstance(base, sp.IndexedBase) else sp.IndexedBase(base, shape=(1,))
        instance = super().__new__(cls, indexed_base, linearized_index)
        instance.field = field
        instance.offsets = offsets
        instance.idx_coordinate_values = idx_coordinate_values
        return instance

    def _eval_subs(self, old, new):
        # only the linearized index is substituted; base and metadata are carried over unchanged
        base, index = self.args[0], self.args[1]
        return ResolvedFieldAccess(base, index.subs(old, new),
                                   self.field, self.offsets, self.idx_coordinate_values)

    def fast_subs(self, substitutions):
        if self in substitutions:
            return substitutions[self]
        new_base = self.args[0].subs(substitutions)
        new_index = self.args[1].subs(substitutions)
        return ResolvedFieldAccess(new_base, new_index,
                                   self.field, self.offsets, self.idx_coordinate_values)

    def _hashable_content(self):
        indexed_content = super()._hashable_content()
        metadata = tuple(self.offsets) + (repr(self.idx_coordinate_values), hash(self.field))
        return indexed_content + metadata

    @property
    def typed_symbol(self):
        return self.base.label

    def __str__(self):
        indexed_str = super().__str__()
        return "%s (%s)" % (indexed_str, self.typed_symbol.dtype)

    def __getnewargs__(self):
        # constructor arguments used when pickling / copying
        return self.base, self.indices[0], self.field, self.offsets, self.idx_coordinate_values
Martin Bauer's avatar
Martin Bauer committed
605
606


607
class TemporaryMemoryAllocation(Node):
    """Node for temporary memory buffer allocation.

    Always allocates aligned memory.

    Args:
        typed_symbol: symbol used as pointer (has to be typed)
        size: number of elements to allocate
        align_offset: the align_offset's element is aligned
    """

    def __init__(self, typed_symbol: TypedSymbol, size, align_offset):
        super(TemporaryMemoryAllocation, self).__init__(parent=None)
        self.symbol = typed_symbol
        self.size = size
        # NOTE(review): presumably collected by the code generator as required includes — confirm
        self.headers = ['<stdlib.h>']
        self._align_offset = align_offset

    @property
    def symbols_defined(self):
        return {self.symbol}

    @property
    def undefined_symbols(self):
        # a symbolic size may reference symbols defined elsewhere
        if isinstance(self.size, sp.Basic):
            return self.size.atoms(sp.Symbol)
        else:
            return set()

    @property
    def args(self):
        return [self.symbol]

    def offset(self, byte_alignment):
        """Number of ELEMENTS to skip for a pointer that is aligned to byte_alignment.

        Returns the smallest non-negative ``k`` with
        ``(align_offset + k) % (byte_alignment // itemsize) == 0``.
        """
        np_dtype = self.symbol.dtype.base_type.numpy_dtype
        assert byte_alignment % np_dtype.itemsize == 0
        # Integer division: divisibility is asserted above, and an element count must be an
        # integer -- true division ``/`` made this return a float (e.g. 1.0).
        elements_per_alignment = byte_alignment // np_dtype.itemsize
        return -self._align_offset % elements_per_alignment

646
647

class TemporaryMemoryFree(Node):
    """Counterpart node of a :class:`TemporaryMemoryAllocation`.

    It stores only the allocation node. The pointer symbol and element offset are
    looked up from that node on demand. Presumably the code printer renders it as
    the release of that buffer.
    """

    def __init__(self, alloc_node):
        super().__init__(parent=None)
        self.alloc_node = alloc_node

    @property
    def symbol(self):
        """Pointer symbol of the paired allocation node."""
        allocation = self.alloc_node
        return allocation.symbol

    def offset(self, byte_alignment):
        """Element offset of the paired allocation for ``byte_alignment`` (delegated)."""
        allocation = self.alloc_node
        return allocation.offset(byte_alignment)

    @property
    def symbols_defined(self):
        """This node introduces no symbols."""
        return set()

    @property
    def undefined_symbols(self):
        """This node reports no symbols as used."""
        return set()

    @property
    def args(self):
        """No child nodes."""
        return []
Martin Bauer's avatar
Martin Bauer committed
670
671
672
673
674


def early_out(condition):
    """Build a Conditional whose true-branch is a block holding a single SkipIteration.

    ``condition`` is first wrapped with ``vec_all`` from the vectorization module.
    Presumably the branch is taken only when the condition holds for every vector lane.
    """
    from pystencils.cpu.vectorization import vec_all
    reduced_condition = vec_all(condition)
    skip_block = Block([SkipIteration()])
    return Conditional(reduced_condition, skip_block)
675
676
677
678
679
680
681
682


class DestructuringBindingsForFieldClass(Node):
    """
    Defines all variables needed for describing a field (shape, pointer, strides)

    Field pointer/shape/stride symbols that ``body`` uses but does not define are
    reported as *defined* by this node. In their place, one symbol per corresponding
    field, typed ``PyStencilsField<dtype, ndim>&``, is reported as undefined.
    """
    # Member expression on the field object for each field-symbol class ("%i" takes the
    # coordinate index). Its keys also define which symbols this node destructures.
    CLASS_TO_MEMBER_DICT = {
        FieldPointerSymbol: "data",
        FieldShapeSymbol: "shape[%i]",
        FieldStrideSymbol: "stride[%i]"
    }
    CLASS_NAME_TEMPLATE = "PyStencilsField<{dtype}, {ndim}>"

    def __init__(self, body):
        super(DestructuringBindingsForFieldClass, self).__init__()
        self.headers = ['<PyStencilsField.h>']
        self.body = body

    @property
    def fields_accessed(self) -> Set[Field]:
        """Set of Field instances: fields accessed (via ResolvedFieldAccess) inside the wrapped body"""
        return set(o.field for o in self.atoms(ResolvedFieldAccess))

    @property
    def args(self) -> List[NodeOrExpr]:
        """Returns all arguments/children of this node, i.e. the wrapped body."""
        return [self.body]

    def _field_symbols_in(self, symbols) -> Set[sp.Symbol]:
        """Subset of ``symbols`` that are field pointer/shape/stride symbols."""
        field_symbol_classes = tuple(self.CLASS_TO_MEMBER_DICT)
        return {s for s in symbols if isinstance(s, field_symbol_classes)}

    @property
    def symbols_defined(self) -> Set[sp.Symbol]:
        """Set of symbols which are defined by this node: the body's undefined field symbols."""
        return self._field_symbols_in(self.body.undefined_symbols)

    @property
    def undefined_symbols(self) -> Set[sp.Symbol]:
        """Body's undefined symbols, with field symbols replaced by field-class reference symbols."""
        field_map = {f.name: f for f in self.fields_accessed}
        body_undefined = self.body.undefined_symbols  # evaluated once, reused below
        field_symbols = self._field_symbols_in(body_undefined)

        field_names = {s.field_name for s in field_symbols if hasattr(s, 'field_name')}
        # Symbols exposing `field_names` (plural) contribute only their first field name.
        field_names |= {s.field_names[0] for s in field_symbols if hasattr(s, 'field_names')}

        # The trailing '&' makes the generated parameter type a C++ reference.
        field_class_symbols = {
            TypedSymbol(name, self.CLASS_NAME_TEMPLATE.format(dtype=field_map[name].dtype,
                                                              ndim=field_map[name].ndim) + '&')
            for name in field_names
        }
        return field_class_symbols | (body_undefined - field_symbols)

    def subs(self, subs_dict) -> None:
        """Inplace! substitute, similar to sympy's but modifies the AST inplace."""
        self.body.subs(subs_dict)

    @property
    def func(self):
        return self.__class__

    def atoms(self, arg_type) -> Set[Any]:
        """Atoms of ``arg_type`` in the body plus matching symbols defined by this node."""
        return self.body.atoms(arg_type) | {s for s in self.symbols_defined if isinstance(s, arg_type)}
Stephan Seitz's avatar
Stephan Seitz committed
729
730
731
732


def get_dummy_symbol(dtype='bool'):
    """Create a TypedSymbol of type ``dtype`` with a practically unique name.

    The name is ``'dummy'`` followed by the hex string of a random ``uuid4``.
    """
    unique_name = 'dummy' + uuid.uuid4().hex
    return TypedSymbol(unique_name, create_type(dtype))
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778


class SourceCodeComment(Node):
    """AST leaf holding comment text; ``str()`` renders it as ``/* text */``."""

    def __init__(self, text):
        # Initialise the Node base so that `parent` exists (None) before insertion.
        super(SourceCodeComment, self).__init__(parent=None)
        self.text = text

    @property
    def args(self):
        """No child nodes."""
        return []

    @property
    def symbols_defined(self):
        """A comment defines no symbols."""
        return set()

    @property
    def undefined_symbols(self):
        """A comment uses no symbols."""
        return set()

    def __str__(self):
        return "/* " + self.text + " */"

    def __repr__(self):
        return self.__str__()


class EmptyLine(Node):
    """AST leaf representing an empty line; ``str()`` renders it as an empty string."""

    def __init__(self):
        # Initialise the Node base so that `parent` exists (None) before insertion.
        super(EmptyLine, self).__init__(parent=None)

    @property
    def args(self):
        """No child nodes."""
        return []

    @property
    def symbols_defined(self):
        """Defines no symbols."""
        return set()

    @property
    def undefined_symbols(self):
        """Uses no symbols."""
        return set()

    def __str__(self):
        return ""

    def __repr__(self):
        return self.__str__()