astnodes.py 25.8 KB
Newer Older
1
2
import collections.abc
import itertools
Stephan Seitz's avatar
Stephan Seitz committed
3
import uuid
4
5
from typing import Any, List, Optional, Sequence, Set, Union

6
import sympy as sp
7

8
import pystencils
9
from pystencils.data_types import TypedImaginaryUnit, TypedSymbol, cast_func, create_type
10
from pystencils.field import Field
Martin Bauer's avatar
Martin Bauer committed
11
from pystencils.kernelparameters import FieldPointerSymbol, FieldShapeSymbol, FieldStrideSymbol
Martin Bauer's avatar
Martin Bauer committed
12
from pystencils.sympyextensions import fast_subs
Martin Bauer's avatar
Martin Bauer committed
13
14

# A child of an AST node: either another AST ``Node`` or a plain sympy expression.
NodeOrExpr = Union["Node", sp.Expr]
15
16


17
class Node:
    """Base class for all AST nodes.

    Concrete subclasses provide ``args``, ``symbols_defined`` and ``undefined_symbols``;
    the generic ``subs`` and ``atoms`` implementations here walk ``args``.
    """

    def __init__(self, parent: Optional['Node'] = None):
        # Back-reference to the enclosing node (containers such as Block re-assign it).
        self.parent = parent

    @property
    def args(self) -> List[NodeOrExpr]:
        """Children of this node (AST nodes and/or sympy expressions). Must be overridden."""
        raise NotImplementedError()

    @property
    def symbols_defined(self) -> Set[sp.Symbol]:
        """Symbols defined by this node. Must be overridden."""
        raise NotImplementedError()

    @property
    def undefined_symbols(self) -> Set[sp.Symbol]:
        """Symbols used inside this node but not defined within it. Must be overridden."""
        raise NotImplementedError()

    def subs(self, subs_dict) -> None:
        """Inplace! Substitute, similar to sympy's but modifies the AST inplace.

        This default just forwards the call to every child; whatever the children's
        ``subs`` return is discarded.
        """
        for child in self.args:
            child.subs(subs_dict)

    @property
    def func(self):
        """The class of this node (analogous to sympy's ``func`` attribute)."""
        return self.__class__

    def atoms(self, arg_type) -> Set[Any]:
        """Collect, recursively, every descendant that is an instance of ``arg_type``."""
        collected = set()
        for child in self.args:
            if isinstance(child, arg_type):
                collected.add(child)
            # recurse regardless of whether the child itself matched
            collected.update(child.atoms(arg_type))
        return collected


57
class Conditional(Node):
    """Conditional that maps to an 'if' statement in C/C++.

    Try to avoid using this node inside of loops, since currently this construction can not be vectorized.
    Consider using assignments with sympy.Piecewise in this case.

    Args:
        condition_expr: sympy relational expression
        true_block: block which is run if conditional is true; a non-Block node is wrapped in a Block
        false_block: optional block which is run if conditional is false; wrapped like true_block
    """

    def __init__(self, condition_expr: sp.Basic, true_block: Union['Block', 'SympyAssignment'],
                 false_block: Optional['Block'] = None) -> None:
        super().__init__(parent=None)
        self.condition_expr = condition_expr

        def adopt(branch):
            # Normalise a branch to a Block owned by this conditional (None stays None).
            if branch is None:
                return None
            block = branch if isinstance(branch, Block) else Block([branch])
            block.parent = self
            return block

        self.true_block = adopt(true_block)
        self.false_block = adopt(false_block)

    def subs(self, subs_dict):
        """Substitute in both branches (in place) and in the condition expression."""
        self.true_block.subs(subs_dict)
        if self.false_block:
            self.false_block.subs(subs_dict)
        self.condition_expr = self.condition_expr.subs(subs_dict)

    @property
    def args(self):
        """Condition expression, true block and - if present - false block."""
        optional_tail = [self.false_block] if self.false_block else []
        return [self.condition_expr, self.true_block] + optional_tail

    @property
    def symbols_defined(self):
        """A conditional itself defines no symbols."""
        return set()

    @property
    def undefined_symbols(self):
        """Undefined symbols of both branches plus all symbols of the condition."""
        symbols = self.true_block.undefined_symbols
        if self.false_block:
            symbols.update(self.false_block.undefined_symbols)
        # the condition may be a non-sympy value without an ``atoms`` method
        if hasattr(self.condition_expr, 'atoms'):
            symbols.update(self.condition_expr.atoms(sp.Symbol))
        return symbols

    def __str__(self):
        return 'if:({!s}) '.format(self.condition_expr)

    def __repr__(self):
        return 'if:({!r}) '.format(self.condition_expr)

    def replace_by_true_block(self):
        """Replaces the conditional by its True block"""
        self.parent.replace(self, [self.true_block])

    def replace_by_false_block(self):
        """Replaces the conditional by its False block (or removes it if there is none)"""
        replacement = [self.false_block] if self.false_block else []
        self.parent.replace(self, replacement)

126

127
128
class KernelFunction(Node):
    """Root node of a kernel: wraps the body and the metadata needed to print/compile it."""

    class Parameter:
        """Function parameter.

        Each undefined symbol in a `KernelFunction` node becomes a parameter to the function.
        Parameters are either symbols introduced by the user that never occur on the left hand side of an
        Assignment, or are related to fields/arrays passed to the function.

        A parameter consists of the typed symbol (symbol property). For field related parameters this is a symbol
        defined in pystencils.kernelparameters.
        If the parameter is related to one or multiple fields, these fields are referenced in the fields property.
        """

        def __init__(self, symbol, fields):
            self.symbol = symbol  # type: TypedSymbol
            self.fields = fields  # type: Sequence[Field]

        def __repr__(self):
            return repr(self.symbol)

        @property
        def is_field_stride(self):
            return isinstance(self.symbol, FieldStrideSymbol)

        @property
        def is_field_shape(self):
            return isinstance(self.symbol, FieldShapeSymbol)

        @property
        def is_field_pointer(self):
            return isinstance(self.symbol, FieldPointerSymbol)

        @property
        def is_field_parameter(self):
            """True for pointer, shape or stride parameters of a field."""
            return self.is_field_pointer or self.is_field_shape or self.is_field_stride

        @property
        def field_name(self):
            """Name of the first associated field."""
            return self.fields[0].name

    def __init__(self, body, target, backend, compile_function, ghost_layers, function_name="kernel"):
        super().__init__()
        self._body = body
        body.parent = self
        self.function_name = function_name
        self.ghost_layers = ghost_layers
        self._target = target
        self._backend = backend
        # these variables are assumed to be global, so no automatic parameter is generated for them
        self.global_variables = set()
        # used in `vectorize` function to tell the backend which i.s. (SSE,AVX) to use
        self.instruction_set = None
        # function that compiles the node to a Python callable, is set by the backends
        self._compile_function = compile_function

    @property
    def target(self):
        """Currently either 'cpu' or 'gpu' """
        return self._target

    @property
    def backend(self):
        """Backend for generating the code e.g. 'llvm', 'c', 'cuda' """
        return self._backend

    @property
    def symbols_defined(self):
        return set()

    @property
    def undefined_symbols(self):
        return set()

    @property
    def body(self):
        return self._body

    @body.setter
    def body(self, value):
        self._body = value
        self._body.parent = self

    @property
    def args(self):
        """A one-element tuple holding the body."""
        return (self._body,)

    @property
    def fields_accessed(self) -> Set[Field]:
        """Set of Field instances: fields which are accessed inside this kernel function"""
        from pystencils.interpolation_astnodes import InterpolatorAccess
        accesses = itertools.chain(self.atoms(ResolvedFieldAccess), self.atoms(InterpolatorAccess))
        return {access.field for access in accesses}

    @property
    def fields_written(self) -> Set[Field]:
        """Fields that appear as resolved field access on the lhs of an assignment."""
        return {assignment.lhs.field
                for assignment in self.atoms(SympyAssignment)
                if isinstance(assignment.lhs, ResolvedFieldAccess)}

    @property
    def fields_read(self) -> Set[Field]:
        """Fields of all rhs free symbols that carry a ``field`` attribute."""
        return {symbol.field
                for assignment in self.atoms(SympyAssignment)
                for symbol in assignment.rhs.free_symbols
                if hasattr(symbol, 'field')}

    def get_parameters(self) -> Sequence['KernelFunction.Parameter']:
        """Returns list of parameters for this function, sorted by symbol name.

        This function is expensive, cache the result where possible!
        """
        fields_by_name = {field.name: field for field in self.fields_accessed}

        def fields_of(symbol):
            # field-related symbols name one field (field_name) or several (field_names)
            if hasattr(symbol, 'field_name'):
                return fields_by_name[symbol.field_name],
            if hasattr(symbol, 'field_names'):
                return tuple(fields_by_name[name] for name in symbol.field_names)
            return ()

        candidates = self._body.undefined_symbols - self.global_variables
        parameters = [self.Parameter(symbol, fields_of(symbol)) for symbol in candidates]
        if hasattr(self, 'indexing'):
            parameters.extend(self.Parameter(s, []) for s in self.indexing.symbolic_parameters())
        parameters.sort(key=lambda param: param.symbol.name)
        return parameters

    def __str__(self):
        params = [p.symbol for p in self.get_parameters()]
        indented_body = "\t" + "\t".join(str(self.body).splitlines(True))
        return '{0} {1}({2})\n{3}'.format(type(self).__name__, self.function_name, params, indented_body)

    def __repr__(self):
        params = [p.symbol for p in self.get_parameters()]
        return '{0} {1}({2})'.format(type(self).__name__, self.function_name, params)

    def compile(self, *args, **kwargs):
        """Compile via the backend-provided compile function; raises ValueError if none was set."""
        if self._compile_function is None:
            raise ValueError("No compile-function provided for this KernelFunction node")
        return self._compile_function(self, *args, **kwargs)

266

Martin Bauer's avatar
Martin Bauer committed
267
268
269
270
271
272
273
274
275
276
277
278
279
280
class SkipIteration(Node):
    """Leaf node without children that neither defines nor uses any symbol.

    NOTE(review): presumably printed as a loop ``continue`` by the backends - not visible here.
    """

    @property
    def args(self):
        return list()

    @property
    def symbols_defined(self):
        return set()

    @property
    def undefined_symbols(self):
        return set()


281
class Block(Node):
    """Ordered sequence of AST nodes; every child's ``parent`` points back to the block."""

    def __init__(self, nodes: List[Node]):
        super().__init__()  # sets self.parent = None
        self._nodes = nodes
        for n in self._nodes:
            n.parent = self

    @property
    def args(self):
        return self._nodes

    def subs(self, subs_dict) -> None:
        """Substitute in place in every child."""
        for a in self.args:
            a.subs(subs_dict)

    def fast_subs(self, subs_dict, skip=None):
        """Replace every child by ``fast_subs(child, ...)`` and return this block."""
        self._nodes = [fast_subs(a, subs_dict, skip) for a in self._nodes]
        return self

    def insert_front(self, node):
        """Prepend a single node or an iterable of nodes."""
        if isinstance(node, collections.abc.Iterable):
            node = list(node)
            for n in node:
                n.parent = self
            self._nodes = node + self._nodes
        else:
            node.parent = self
            self._nodes.insert(0, node)

    def insert_before(self, new_node, insert_before):
        """Insert ``new_node`` before ``insert_before`` (which must occur exactly once).

        Declarations are hoisted further up past any directly preceding loops/conditionals.
        """
        new_node.parent = self
        assert self._nodes.count(insert_before) == 1
        idx = self._nodes.index(insert_before)

        # move all assignment (definitions to the top)
        if isinstance(new_node, SympyAssignment) and new_node.is_declaration:
            while idx > 0 and isinstance(self._nodes[idx - 1], (LoopOverCoordinate, Conditional)):
                idx -= 1
        self._nodes.insert(idx, new_node)

    def append(self, node):
        """Append a single node or a list/tuple of nodes."""
        if isinstance(node, (list, tuple)):
            for n in node:
                n.parent = self
                self._nodes.append(n)
        else:
            node.parent = self
            self._nodes.append(node)

    def take_child_nodes(self):
        """Remove and return all children, leaving this block empty."""
        nodes, self._nodes = self._nodes, []
        return nodes

    def replace(self, child, replacements):
        """Replace ``child`` (which must occur exactly once) by a node or a list/tuple of nodes.

        Tuples are accepted like lists, consistent with ``append``; previously a tuple
        crashed with AttributeError when its ``parent`` was set.
        """
        assert self._nodes.count(child) == 1
        idx = self._nodes.index(child)
        del self._nodes[idx]
        if isinstance(replacements, (list, tuple)):
            replacements = list(replacements)
            for e in replacements:
                e.parent = self
            # rebind (rather than mutate) the node list, as before
            self._nodes = self._nodes[:idx] + replacements + self._nodes[idx:]
        else:
            replacements.parent = self
            self._nodes.insert(idx, replacements)

    @property
    def symbols_defined(self):
        """Union of the symbols defined by the children.

        NOTE(review): for plain ``pystencils.Assignment`` children this adds *all* free symbols
        (rhs included), whereas ``undefined_symbols`` treats only the lhs as defined. Kept
        unchanged; confirm whether this asymmetry is intended.
        """
        result = set()
        for a in self.args:
            if isinstance(a, pystencils.Assignment):
                result.update(a.free_symbols)
            else:
                result.update(a.symbols_defined)
        return result

    @property
    def undefined_symbols(self):
        """Symbols used by the children minus the symbols defined by them."""
        result = set()
        defined_symbols = set()
        for a in self.args:
            if isinstance(a, pystencils.Assignment):
                result.update(a.free_symbols)
                defined_symbols.add(a.lhs)
            else:
                result.update(a.undefined_symbols)
                defined_symbols.update(a.symbols_defined)
        return result - defined_symbols

    def __str__(self):
        return "Block " + ''.join('{!s}\n'.format(node) for node in self._nodes)

    def __repr__(self):
        return "Block"
381

382
383

class PragmaBlock(Block):
    """Block that additionally carries a pragma line, which is also its ``repr``."""

    def __init__(self, pragma_line, nodes):
        super().__init__(nodes)
        self.pragma_line = pragma_line
        # Block.__init__ already re-parented the nodes; repeated here unchanged.
        for child in nodes:
            child.parent = self

    def __repr__(self):
        return self.pragma_line
392
393
394
395


class LoopOverCoordinate(Node):
    """Counting loop ``for(ctr=start; ctr<stop; ctr+=step)`` over one coordinate.

    The counter symbol is ``ctr_<coordinate>`` (or ``_blockctr_<coordinate>`` for block loops).
    """
    LOOP_COUNTER_NAME_PREFIX = "ctr"
    BlOCK_LOOP_COUNTER_NAME_PREFIX = "_blockctr"

    def __init__(self, body, coordinate_to_loop_over, start, stop, step=1, is_block_loop=False):
        super().__init__(parent=None)
        self.body = body
        body.parent = self
        self.coordinate_to_loop_over = coordinate_to_loop_over
        self.start = start
        self.stop = stop
        self.step = step
        # copied into loops created by new_loop_with_different_body
        self.prefix_lines = []
        self.is_block_loop = is_block_loop

    def new_loop_with_different_body(self, new_body):
        """Return a copy of this loop (same bounds, counter kind, prefix lines) around ``new_body``."""
        result = LoopOverCoordinate(new_body, self.coordinate_to_loop_over, self.start, self.stop,
                                    self.step, self.is_block_loop)
        result.prefix_lines = list(self.prefix_lines)
        return result

    def subs(self, subs_dict):
        """Substitute in the body (in place) and in any start/stop/step that supports ``subs``."""
        self.body.subs(subs_dict)
        if hasattr(self.start, "subs"):
            self.start = self.start.subs(subs_dict)
        if hasattr(self.stop, "subs"):
            self.stop = self.stop.subs(subs_dict)
        if hasattr(self.step, "subs"):
            self.step = self.step.subs(subs_dict)

    def fast_subs(self, subs_dict, skip=None):
        """Apply module-level ``fast_subs`` to the body and to sympy-valued bounds; returns self."""
        self.body = fast_subs(self.body, subs_dict, skip)
        if isinstance(self.start, sp.Basic):
            self.start = fast_subs(self.start, subs_dict, skip)
        if isinstance(self.stop, sp.Basic):
            self.stop = fast_subs(self.stop, subs_dict, skip)
        if isinstance(self.step, sp.Basic):
            self.step = fast_subs(self.step, subs_dict, skip)
        return self

    @property
    def args(self):
        """The body followed by those of start/stop/step that have an ``args`` attribute."""
        result = [self.body]
        for e in [self.start, self.stop, self.step]:
            if hasattr(e, "args"):
                result.append(e)
        return result

    def replace(self, child, replacement):
        """Replace the body or one of the bounds; does nothing if ``child`` matches none of them."""
        if child == self.body:
            self.body = replacement
            # keep the parent link consistent, as Block.replace does
            if isinstance(replacement, Node):
                replacement.parent = self
        elif child == self.start:
            self.start = replacement
        elif child == self.step:
            self.step = replacement
        elif child == self.stop:
            self.stop = replacement

    @property
    def symbols_defined(self):
        return {self.loop_counter_symbol}

    @property
    def undefined_symbols(self):
        """Undefined symbols of body and bounds, excluding this loop's own counter."""
        result = self.body.undefined_symbols
        for possible_symbol in [self.start, self.stop, self.step]:
            if isinstance(possible_symbol, (Node, sp.Basic)):
                result.update(possible_symbol.atoms(sp.Symbol))
        return result - {self.loop_counter_symbol}

    @staticmethod
    def get_loop_counter_name(coordinate_to_loop_over):
        return "%s_%s" % (LoopOverCoordinate.LOOP_COUNTER_NAME_PREFIX, coordinate_to_loop_over)

    @staticmethod
    def get_block_loop_counter_name(coordinate_to_loop_over):
        return "%s_%s" % (LoopOverCoordinate.BlOCK_LOOP_COUNTER_NAME_PREFIX, coordinate_to_loop_over)

    @property
    def loop_counter_name(self):
        if self.is_block_loop:
            return LoopOverCoordinate.get_block_loop_counter_name(self.coordinate_to_loop_over)
        else:
            return LoopOverCoordinate.get_loop_counter_name(self.coordinate_to_loop_over)

    @staticmethod
    def is_loop_counter_symbol(symbol):
        """Return the coordinate if ``symbol`` is a (non-block) loop counter, else None.

        A loop counter is named ``ctr_<int>`` (see get_loop_counter_name) and has dtype 'int'.
        Symbols without a ``dtype`` or with a non-integer suffix yield None instead of raising.
        """
        prefix = LoopOverCoordinate.LOOP_COUNTER_NAME_PREFIX + '_'
        if not symbol.name.startswith(prefix):
            return None
        if getattr(symbol, 'dtype', None) != create_type('int'):
            return None
        try:
            return int(symbol.name[len(prefix):])
        except ValueError:
            # prefix matched, but no integer coordinate follows (e.g. 'ctr_x')
            return None

    @staticmethod
    def get_loop_counter_symbol(coordinate_to_loop_over):
        return TypedSymbol(LoopOverCoordinate.get_loop_counter_name(coordinate_to_loop_over), 'int', nonnegative=True)

    @staticmethod
    def get_block_loop_counter_symbol(coordinate_to_loop_over):
        return TypedSymbol(LoopOverCoordinate.get_block_loop_counter_name(coordinate_to_loop_over),
                           'int',
                           nonnegative=True)

    @property
    def loop_counter_symbol(self):
        if self.is_block_loop:
            return self.get_block_loop_counter_symbol(self.coordinate_to_loop_over)
        else:
            return self.get_loop_counter_symbol(self.coordinate_to_loop_over)

    @property
    def is_outermost_loop(self):
        from pystencils.transformations import get_next_parent_of_type
        return get_next_parent_of_type(self, LoopOverCoordinate) is None

    @property
    def is_innermost_loop(self):
        return len(self.atoms(LoopOverCoordinate)) == 0

    def __str__(self):
        return 'for({!s}={!s}; {!s}<{!s}; {!s}+={!s})\n{!s}'.format(self.loop_counter_name, self.start,
                                                                    self.loop_counter_name, self.stop,
                                                                    self.loop_counter_name, self.step,
                                                                    ("\t" + "\t".join(str(self.body).splitlines(True))))

    def __repr__(self):
        return 'for({!s}={!s}; {!s}<{!s}; {!s}+={!s})'.format(self.loop_counter_name, self.start,
                                                              self.loop_counter_name, self.stop,
                                                              self.loop_counter_name, self.step)
526

527
528

class SympyAssignment(Node):
    """Assignment ``lhs ← rhs`` where rhs is a sympy expression.

    Args:
        lhs_symbol: target; a symbol (declaration) or a cast/field access/indexed
                    expression/TemporaryMemoryAllocation (plain store, not a declaration)
        rhs_expr: value; passed through ``sp.sympify``
        is_const: stored and exposed via ``is_const``
        use_auto: stored as-is (NOTE(review): presumably consumed by a code printer - not visible here)
    """

    def __init__(self, lhs_symbol, rhs_expr, is_const=True, use_auto=False):
        super(SympyAssignment, self).__init__(parent=None)
        self._lhs_symbol = lhs_symbol
        self.rhs = sp.sympify(rhs_expr)
        self._is_const = is_const
        self._is_declaration = self.__is_declaration()
        self.use_auto = use_auto

    def __is_declaration(self):
        """An assignment declares its lhs unless the lhs is a cast, access, indexed or allocation."""
        non_declaring = (cast_func, Field.Access, sp.Indexed, TemporaryMemoryAllocation)
        return not isinstance(self._lhs_symbol, non_declaring)

    @property
    def lhs(self):
        return self._lhs_symbol

    @lhs.setter
    def lhs(self, new_value):
        self._lhs_symbol = new_value
        # declaration status depends on the kind of lhs, so recompute it
        self._is_declaration = self.__is_declaration()

    def subs(self, subs_dict):
        self.lhs = fast_subs(self.lhs, subs_dict)
        self.rhs = fast_subs(self.rhs, subs_dict)

    def optimize(self, optimizations):
        """Best-effort rewrite of rhs with ``sympy.codegen.rewriting.optimize``.

        Any failure (including an unavailable sympy module) deliberately leaves rhs unchanged.
        """
        try:
            from sympy.codegen.rewriting import optimize
            self.rhs = optimize(self.rhs, optimizations)
        except Exception:
            pass

    @property
    def args(self):
        return [self._lhs_symbol, self.rhs, sp.sympify(self._is_const)]

    @property
    def symbols_defined(self):
        if not self._is_declaration:
            return set()
        return {self._lhs_symbol}

    @property
    def undefined_symbols(self):
        """Free rhs symbols (minus Indexed/imaginary units), loop counters of field accesses, lhs symbols."""
        result = {s for s in self.rhs.free_symbols if not isinstance(s, sp.Indexed)}
        # Add loop counters if there are field accesses
        loop_counters = set()
        for symbol in result:
            if isinstance(symbol, Field.Access):
                for i in range(len(symbol.offsets)):
                    loop_counters.add(LoopOverCoordinate.get_loop_counter_symbol(i))
        result = {r for r in result if not isinstance(r, TypedImaginaryUnit)}
        result.update(loop_counters)
        result.update(self._lhs_symbol.atoms(sp.Symbol))
        return result

    @property
    def is_declaration(self):
        return self._is_declaration

    @property
    def is_const(self):
        return self._is_const

    def replace(self, child, replacement):
        """Replace lhs or rhs; raises ValueError naming ``child`` if it is neither."""
        if child == self.lhs:
            replacement.parent = self
            self.lhs = replacement
        elif child == self.rhs:
            replacement.parent = self
            self.rhs = replacement
        else:
            # report the child that was not found (previously the replacement was reported)
            raise ValueError('%s is not in args of %s' % (child, self.__class__))

    def __repr__(self):
        return repr(self.lhs) + " ← " + repr(self.rhs)

    def _repr_html_(self):
        printed_lhs = sp.latex(self.lhs)
        printed_rhs = sp.latex(self.rhs)
        return "${printed_lhs} \\leftarrow {printed_rhs}$".format(printed_lhs=printed_lhs, printed_rhs=printed_rhs)
Martin Bauer's avatar
Martin Bauer committed
613

614

Martin Bauer's avatar
Martin Bauer committed
615
class ResolvedFieldAccess(sp.Indexed):
    """Field access lowered to ``base[linearized_index]``.

    Besides the sympy ``Indexed`` arguments it carries ``field``, ``offsets`` and
    ``idx_coordinate_values`` as plain attributes, which also take part in hashing.
    """

    def __new__(cls, base, linearized_index, field, offsets, idx_coordinate_values):
        if not isinstance(base, sp.IndexedBase):
            # a bare pointer symbol is wrapped into an IndexedBase
            assert isinstance(base, TypedSymbol)
            base = sp.IndexedBase(base, shape=(1,))
            assert isinstance(base.label, TypedSymbol)
        instance = super(ResolvedFieldAccess, cls).__new__(cls, base, linearized_index)
        instance.field = field
        instance.offsets = offsets
        instance.idx_coordinate_values = idx_coordinate_values
        return instance

    def _eval_subs(self, old, new):
        # only the linearized index is substituted; the base is kept as-is
        base, index = self.args[0], self.args[1]
        return ResolvedFieldAccess(base, index.subs(old, new),
                                   self.field, self.offsets, self.idx_coordinate_values)

    def fast_subs(self, substitutions, skip=None):
        if self in substitutions:
            return substitutions[self]
        base, index = self.args[0], self.args[1]
        new_base = base.subs(substitutions)
        new_index = index.subs(substitutions)
        return ResolvedFieldAccess(new_base, new_index,
                                   self.field, self.offsets, self.idx_coordinate_values)

    def _hashable_content(self):
        inherited = super(ResolvedFieldAccess, self)._hashable_content()
        extra = (repr(self.idx_coordinate_values), hash(self.field))
        return inherited + tuple(self.offsets) + extra

    @property
    def typed_symbol(self):
        """The pointer symbol, i.e. the label of the IndexedBase."""
        return self.base.label

    def __str__(self):
        indexed_str = super(ResolvedFieldAccess, self).__str__()
        return "%s (%s)" % (indexed_str, self.typed_symbol.dtype)

    def __getnewargs__(self):
        # constructor arguments used when pickling/copying
        return self.base, self.indices[0], self.field, self.offsets, self.idx_coordinate_values
Martin Bauer's avatar
Martin Bauer committed
653
654


655
class TemporaryMemoryAllocation(Node):
    """Node for temporary memory buffer allocation.

    Always allocates aligned memory.

    Args:
        typed_symbol: symbol used as pointer (has to be typed)
        size: number of elements to allocate
        align_offset: index of the element that should lie on an aligned address;
                      see :meth:`offset` for how the pointer is shifted to achieve this
    """

    def __init__(self, typed_symbol: TypedSymbol, size, align_offset):
        super(TemporaryMemoryAllocation, self).__init__(parent=None)
        self.symbol = typed_symbol
        self.size = size
        # Header names attached to this node; presumably consumed by the code generator.
        self.headers = ['<stdlib.h>']
        self._align_offset = align_offset

    @property
    def symbols_defined(self):
        """The pointer symbol is defined by this allocation."""
        return {self.symbol}

    @property
    def undefined_symbols(self):
        """Free symbols of ``size`` when it is a sympy expression, otherwise none."""
        if isinstance(self.size, sp.Basic):
            return self.size.atoms(sp.Symbol)
        else:
            return set()

    @property
    def args(self):
        return [self.symbol]

    def offset(self, byte_alignment):
        """Number of ELEMENTS to skip for a pointer that is aligned to byte_alignment.

        Shifting an address aligned to ``byte_alignment`` by the returned number of
        elements places element number ``align_offset`` on an aligned address.

        Args:
            byte_alignment: alignment in bytes; must be a multiple of the element size

        Returns:
            the element count, an int in ``[0, byte_alignment // itemsize)`` when
            ``byte_alignment`` and ``align_offset`` are ints
        """
        np_dtype = self.symbol.dtype.base_type.numpy_dtype
        assert byte_alignment % np_dtype.itemsize == 0
        # Floor division is exact here (guaranteed by the assert above) and keeps the
        # result an integer element count; true division ``/`` yielded a float like 2.0.
        elements_per_alignment = byte_alignment // np_dtype.itemsize
        return -self._align_offset % elements_per_alignment

694
695

class TemporaryMemoryFree(Node):
    """Counterpart node to a :class:`TemporaryMemoryAllocation`.

    It only stores a reference to the allocation node and forwards its ``symbol`` and
    ``offset``; any actual release of memory is left to whatever consumes the AST.

    Args:
        alloc_node: the allocation node this node refers to
    """

    def __init__(self, alloc_node):
        super().__init__(parent=None)
        self.alloc_node = alloc_node

    @property
    def symbol(self):
        """Pointer symbol of the referenced allocation."""
        allocation = self.alloc_node
        return allocation.symbol

    def offset(self, byte_alignment):
        """Element offset of the referenced allocation for ``byte_alignment``."""
        allocation = self.alloc_node
        return allocation.offset(byte_alignment)

    @property
    def args(self):
        return []

    @property
    def symbols_defined(self):
        return set()

    @property
    def undefined_symbols(self):
        return set()
Martin Bauer's avatar
Martin Bauer committed
718
719
720
721
722


def early_out(condition):
    """Build a ``Conditional`` whose condition is ``vec_all(condition)`` and whose body
    is a block holding a single ``SkipIteration``.

    NOTE(review): presumably meant to skip the iteration only when ``condition`` holds
    for all vector lanes — confirm against ``vec_all``.
    """
    from pystencils.cpu.vectorization import vec_all
    guard = vec_all(condition)
    return Conditional(guard, Block([SkipIteration()]))
723
724


Stephan Seitz's avatar
Stephan Seitz committed
725
726
def get_dummy_symbol(dtype='bool'):
    """Create a fresh TypedSymbol with a random, practically unique name.

    The name is ``'dummy'`` followed by the hex digits of a random UUID4; ``dtype`` is
    converted with ``create_type``.
    """
    unique_suffix = uuid.uuid4().hex
    symbol_type = create_type(dtype)
    return TypedSymbol('dummy%s' % unique_suffix, symbol_type)
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772


class SourceCodeComment(Node):
    """AST node carrying free comment text; ``str()`` renders it as ``/* text */``.

    It has no children and neither defines nor uses symbols.

    Args:
        text: the comment text (concatenated as a string in ``__str__``)
    """

    def __init__(self, text):
        # Run the Node initializer like the other node types do, so that ``parent``
        # exists (initially None) instead of raising AttributeError when accessed.
        super(SourceCodeComment, self).__init__(parent=None)
        self.text = text

    @property
    def args(self):
        return []

    @property
    def symbols_defined(self):
        return set()

    @property
    def undefined_symbols(self):
        return set()

    def __str__(self):
        return "/* " + self.text + " */"

    def __repr__(self):
        return self.__str__()


class EmptyLine(Node):
    """AST node standing for a blank line; ``str()`` renders it as an empty string.

    It has no children and neither defines nor uses symbols.
    """

    def __init__(self):
        # Run the Node initializer like the other node types do, so that ``parent``
        # exists (initially None) instead of raising AttributeError when accessed.
        super(EmptyLine, self).__init__(parent=None)

    @property
    def args(self):
        return []

    @property
    def symbols_defined(self):
        return set()

    @property
    def undefined_symbols(self):
        return set()

    def __str__(self):
        return ""

    def __repr__(self):
        return self.__str__()
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797


class ConditionalFieldAccess(sp.Function):
    """
    :class:`pystencils.Field.Access` that is only executed if a certain condition is met.
    Can be used, for instance, for out-of-bound checks.

    The sympy arguments are, in order: the wrapped access, the out-of-bounds condition,
    and ``sp.S(outofbounds_value)`` (default ``0``).
    """

    def __new__(cls, field_access, outofbounds_condition, outofbounds_value=0):
        sympified_value = sp.S(outofbounds_value)
        return sp.Function.__new__(cls, field_access, outofbounds_condition, sympified_value)

    @property
    def access(self):
        """The wrapped field access (first argument)."""
        first, = self.args[0:1]
        return first

    @property
    def outofbounds_condition(self):
        """The condition expression (second argument)."""
        return self.args[1]

    @property
    def outofbounds_value(self):
        """The sympified out-of-bounds value (third argument)."""
        return self.args[2]

    def __getnewargs__(self):
        """Positional arguments handed to ``__new__`` when the object is unpickled."""
        ctor_args = (self.access, self.outofbounds_condition, self.outofbounds_value)
        return ctor_args