astnodes.py 27.6 KB
Newer Older
1
2
import collections.abc
import itertools
Stephan Seitz's avatar
Stephan Seitz committed
3
import uuid
4
5
from typing import Any, List, Optional, Sequence, Set, Union

6
import sympy as sp
7

8
import pystencils
9
from pystencils.data_types import TypedImaginaryUnit, TypedSymbol, cast_func, create_type
Jan Hönig's avatar
Jan Hönig committed
10
from pystencils.enums import Target, Backend
11
from pystencils.field import Field
Martin Bauer's avatar
Martin Bauer committed
12
from pystencils.kernelparameters import FieldPointerSymbol, FieldShapeSymbol, FieldStrideSymbol
Martin Bauer's avatar
Martin Bauer committed
13
from pystencils.sympyextensions import fast_subs
Martin Bauer's avatar
Martin Bauer committed
14
15

NodeOrExpr = Union['Node', sp.Expr]
16
17


18
class Node:
    """Common base class of every node in the abstract syntax tree."""

    def __init__(self, parent: Optional['Node'] = None):
        # Back-reference to the enclosing node; None while the node is detached.
        self.parent = parent

    @property
    def args(self) -> List[NodeOrExpr]:
        """Children (arguments) of this node.

        Has to be implemented by subclasses.
        """
        raise NotImplementedError()

    @property
    def symbols_defined(self) -> Set[sp.Symbol]:
        """Symbols that get defined by this node.

        Has to be implemented by subclasses.
        """
        raise NotImplementedError()

    @property
    def undefined_symbols(self) -> Set[sp.Symbol]:
        """Symbols that are used inside this node without being defined in it.

        Has to be implemented by subclasses.
        """
        raise NotImplementedError()

    def subs(self, subs_dict) -> None:
        """Substitute in all children, modifying the AST in place.

        Works like sympy's ``subs`` but returns nothing.

        Args:
            subs_dict: substitutions, forwarded unchanged to the ``subs`` of every child
        """
        for position, child in enumerate(self.args):
            substituted = child.subs(subs_dict)
            if not isinstance(child, sp.Expr):
                # everything that is not a sympy expression substitutes in place
                assert substituted is None
                continue
            # sympy's subs is out-of-place, so the new expression has to be stored
            self.args[position] = substituted

    @property
    def func(self):
        """The class of this node."""
        return self.__class__

    def atoms(self, arg_type) -> Set[Any]:
        """Collect all descendants that are instances of ``arg_type``.

        Args:
            arg_type: type (or tuple of types) accepted by ``isinstance``

        Returns:
            set of matching descendants, gathered recursively through ``args``
        """
        collected = set()
        for child in self.args:
            if isinstance(child, arg_type):
                collected.add(child)
            collected.update(child.atoms(arg_type))
        return collected


62
class Conditional(Node):
    """Conditional that maps to a 'if' statement in C/C++.

    Try to avoid using this node inside of loops, since currently this construction can not be vectorized.
    Consider using assignments with sympy.Piecewise in this case.

    Args:
        condition_expr: sympy relational expression
        true_block: block which is run if conditional is true
        false_block: optional block which is run if conditional is false
    """

    def __init__(self, condition_expr: sp.Basic, true_block: Union['Block', 'SympyAssignment'],
                 false_block: Optional['Block'] = None) -> None:
        super(Conditional, self).__init__(parent=None)

        self.condition_expr = condition_expr

        def handle_child(c):
            """Wrap a non-Block child into a Block and attach it to this conditional."""
            if c is None:
                return None
            if not isinstance(c, Block):
                c = Block([c])
            c.parent = self
            return c

        self.true_block = handle_child(true_block)
        self.false_block = handle_child(false_block)

    def subs(self, subs_dict):
        """Substitute in both branches (in place) and in the condition expression."""
        self.true_block.subs(subs_dict)
        if self.false_block:
            self.false_block.subs(subs_dict)
        # the condition's subs is out-of-place, so its result has to be stored
        self.condition_expr = self.condition_expr.subs(subs_dict)

    @property
    def args(self):
        """Condition, true block and - if present - false block, in that order."""
        result = [self.condition_expr, self.true_block]
        if self.false_block:
            result.append(self.false_block)
        return result

    @property
    def symbols_defined(self):
        """A conditional reports no defined symbols: always an empty set."""
        return set()

    @property
    def undefined_symbols(self):
        """Undefined symbols of both branches plus the symbols of the condition."""
        # copy, so that the set handed out by the true block is never mutated
        result = set(self.true_block.undefined_symbols)
        if self.false_block:
            result.update(self.false_block.undefined_symbols)
        # the condition may be a plain Python value without 'atoms'
        if hasattr(self.condition_expr, 'atoms'):
            result.update(self.condition_expr.atoms(sp.Symbol))
        return result

    def __str__(self):
        return self.__repr__()

    def __repr__(self):
        """Textual form: the condition, the true block and - if present - the else branch."""
        result = f'if:({self.condition_expr!r}) '
        if self.true_block:
            result += f'\n\t{self.true_block} '
        if self.false_block:
            # append (not overwrite), otherwise condition and true branch get lost
            result += '\nelse: '
            result += f'\n\t{self.false_block} '

        return result

    def replace_by_true_block(self):
        """Replaces the conditional by its True block"""
        self.parent.replace(self, [self.true_block])

    def replace_by_false_block(self):
        """Replaces the conditional by its False block (or by nothing if there is no False block)"""
        self.parent.replace(self, [self.false_block] if self.false_block else [])

138

139
class KernelFunction(Node):
    """AST node for a kernel function: wraps a body node together with target, backend and compile hook."""

    class Parameter:
        """Function parameter.

        Each undefined symbol in a `KernelFunction` node becomes a parameter to the function.
        Parameters are either symbols introduced by the user that never occur on the left hand side of an
        Assignment, or are related to fields/arrays passed to the function.

        A parameter consists of the typed symbol (symbol property). For field related parameters this is a symbol
        defined in pystencils.kernelparameters.
        If the parameter is related to one or multiple fields, these fields are referenced in the fields property.
        """

        def __init__(self, symbol, fields):
            self.symbol = symbol  # type: TypedSymbol
            self.fields = fields  # type: Sequence[Field]

        def __repr__(self):
            return repr(self.symbol)

        @property
        def is_field_stride(self):
            """True if the parameter symbol is a `FieldStrideSymbol`."""
            return isinstance(self.symbol, FieldStrideSymbol)

        @property
        def is_field_shape(self):
            """True if the parameter symbol is a `FieldShapeSymbol`."""
            return isinstance(self.symbol, FieldShapeSymbol)

        @property
        def is_field_pointer(self):
            """True if the parameter symbol is a `FieldPointerSymbol`."""
            return isinstance(self.symbol, FieldPointerSymbol)

        @property
        def is_field_parameter(self):
            """True if the parameter is a field pointer, field shape or field stride."""
            return any((self.is_field_pointer, self.is_field_shape, self.is_field_stride))

        @property
        def field_name(self):
            """Name of the first field referenced by this parameter."""
            return self.fields[0].name

    def __init__(self, body, target: Target, backend: Backend, compile_function, ghost_layers,
                 function_name: str = "kernel",
                 assignments=None):
        super().__init__()
        self._body = body
        self._body.parent = self
        self.function_name = function_name
        self.ghost_layers = ghost_layers
        self._target = target
        self._backend = backend
        # these variables are assumed to be global, so no automatic parameter is generated for them
        self.global_variables = set()
        # used in `vectorize` function to tell the backend which i.s. (SSE,AVX) to use
        self.instruction_set = None
        # function that compiles the node to a Python callable, is set by the backends
        self._compile_function = compile_function
        self.assignments = assignments

    @property
    def target(self):
        """See pystencils.Target"""
        return self._target

    @property
    def backend(self):
        """Backend for generating the code: `Backend`"""
        return self._backend

    @property
    def symbols_defined(self):
        """Always an empty set."""
        return set()

    @property
    def undefined_symbols(self):
        """Always an empty set."""
        return set()

    @property
    def body(self):
        """The body node of the kernel function."""
        return self._body

    @body.setter
    def body(self, value):
        # keep the parent link of the new body consistent
        self._body = value
        self._body.parent = self

    @property
    def args(self):
        """One-element tuple containing the body."""
        return (self._body,)

    @property
    def fields_accessed(self) -> Set[Field]:
        """Set of Field instances: fields which are accessed inside this kernel function"""
        return {access.field for access in self.atoms(ResolvedFieldAccess)}

    @property
    def fields_written(self) -> Set[Field]:
        """Fields that occur as a `ResolvedFieldAccess` on the left hand side of an assignment."""
        written = set()
        for assignment in self.atoms(SympyAssignment):
            if isinstance(assignment.lhs, ResolvedFieldAccess):
                written.add(assignment.lhs.field)
        return written

    @property
    def fields_read(self) -> Set[Field]:
        """Fields of all right hand side free symbols that carry a ``field`` attribute."""
        return {symbol.field
                for assignment in self.atoms(SympyAssignment)
                for symbol in assignment.rhs.free_symbols
                if hasattr(symbol, 'field')}

    def get_parameters(self) -> Sequence['KernelFunction.Parameter']:
        """Returns list of parameters for this function.

        This function is expensive, cache the result where possible!
        """
        fields_by_name = {field.name: field for field in self.fields_accessed}

        def related_fields(symbol):
            """Look up the field(s) a symbol refers to via ``field_name`` / ``field_names``."""
            if hasattr(symbol, 'field_name'):
                return (fields_by_name[symbol.field_name],)
            if hasattr(symbol, 'field_names'):
                return tuple(fields_by_name[name] for name in symbol.field_names)
            return ()

        parameters = []
        for symbol in self._body.undefined_symbols - self.global_variables:
            parameters.append(self.Parameter(symbol, related_fields(symbol)))
        # 'indexing' is not set in __init__; it is only used when the attribute exists
        if hasattr(self, 'indexing'):
            parameters.extend(self.Parameter(s, []) for s in self.indexing.symbolic_parameters())
        parameters.sort(key=lambda p: p.symbol.name)
        return parameters

    def __str__(self):
        params = [p.symbol for p in self.get_parameters()]
        # indent every line of the body by one tab
        indented_body = "\t" + "\t".join(str(self.body).splitlines(True))
        return f'{type(self).__name__} {self.function_name}({params})\n{indented_body}'

    def __repr__(self):
        params = [p.symbol for p in self.get_parameters()]
        return f'{type(self).__name__} {self.function_name}({params})'

    def compile(self, *args, **kwargs):
        """Call the stored compile function with this node and the given arguments.

        Raises:
            ValueError: if no compile function was provided
        """
        if self._compile_function is not None:
            return self._compile_function(self, *args, **kwargs)
        raise ValueError("No compile-function provided for this KernelFunction node")

279

Martin Bauer's avatar
Martin Bauer committed
280
281
282
283
284
285
286
287
288
289
290
291
292
293
class SkipIteration(Node):
    """Leaf node without children that neither defines nor uses any symbol."""

    @property
    def args(self):
        """No children: always an empty list."""
        return list()

    @property
    def symbols_defined(self):
        """No symbols are defined: always an empty set."""
        return set()

    @property
    def undefined_symbols(self):
        """No symbols are used: always an empty set."""
        return set()


294
class Block(Node):
    """Ordered sequence of child nodes."""

    def __init__(self, nodes: List[Node]):
        super().__init__()
        self._nodes = nodes
        self.parent = None
        for child in self._nodes:
            # children that do not accept a 'parent' attribute are left untouched
            try:
                child.parent = self
            except AttributeError:
                pass

    @property
    def args(self):
        """The list of child nodes (the internal list itself, not a copy)."""
        return self._nodes

    def subs(self, subs_dict) -> None:
        """Call ``subs`` on every child; return values are ignored."""
        for child in self.args:
            child.subs(subs_dict)

    def fast_subs(self, subs_dict, skip=None):
        """Replace every child by the result of `fast_subs` and return this block."""
        self._nodes = [fast_subs(child, subs_dict, skip) for child in self._nodes]
        return self

    def insert_front(self, node, if_not_exists=False):
        """Insert a node, or all nodes of an iterable, at the beginning of the block.

        Args:
            node: a single node or an iterable of nodes
            if_not_exists: if True, nothing happens when the current first child equals ``node``
        """
        if if_not_exists and self._nodes and self._nodes[0] == node:
            return
        if not isinstance(node, collections.abc.Iterable):
            node.parent = self
            self._nodes.insert(0, node)
            return
        new_nodes = list(node)
        for n in new_nodes:
            n.parent = self
        self._nodes = new_nodes + self._nodes

    def insert_before(self, new_node, insert_before, if_not_exists=False):
        """Insert ``new_node`` in front of the child ``insert_before``.

        Args:
            new_node: node to insert
            insert_before: existing child, has to occur exactly once in this block
            if_not_exists: if True, skip the insertion when the node at the target position equals ``new_node``
        """
        new_node.parent = self
        assert self._nodes.count(insert_before) == 1
        position = self._nodes.index(insert_before)

        # declarations are moved up past directly preceding loops and conditionals
        if isinstance(new_node, SympyAssignment) and new_node.is_declaration:
            while position > 0 and isinstance(self._nodes[position - 1], (LoopOverCoordinate, Conditional)):
                position -= 1
        if not if_not_exists or self._nodes[position] != new_node:
            self._nodes.insert(position, new_node)

    def insert_after(self, new_node, insert_after, if_not_exists=False):
        """Insert ``new_node`` behind the child ``insert_after``.

        Args:
            new_node: node to insert
            insert_after: existing child, has to occur exactly once in this block
            if_not_exists: if True, skip the insertion when a neighbour of the target position equals ``new_node``
        """
        new_node.parent = self
        assert self._nodes.count(insert_after) == 1
        position = self._nodes.index(insert_after) + 1

        # declarations are moved up past directly preceding loops and conditionals
        if isinstance(new_node, SympyAssignment) and new_node.is_declaration:
            while position > 0 and isinstance(self._nodes[position - 1], (LoopOverCoordinate, Conditional)):
                position -= 1
        if if_not_exists:
            if self._nodes[position - 1] == new_node:
                return
            if position < len(self._nodes) and self._nodes[position] == new_node:
                return
        self._nodes.insert(position, new_node)

    def append(self, node):
        """Append a node, or every node of a list/tuple, to the end of the block."""
        if not isinstance(node, (list, tuple)):
            node.parent = self
            self._nodes.append(node)
            return
        for n in node:
            n.parent = self
            self._nodes.append(n)

    def take_child_nodes(self):
        """Return the current list of children and leave the block empty."""
        taken, self._nodes = self._nodes, []
        return taken

    def replace(self, child, replacements):
        """Replace ``child`` by a single node or by a list of nodes.

        Args:
            child: existing child, has to occur exactly once in this block
            replacements: a node, or a list of nodes (checked with an exact ``list`` type test)
        """
        assert self._nodes.count(child) == 1
        position = self._nodes.index(child)
        del self._nodes[position]
        if type(replacements) is not list:
            replacements.parent = self
            self._nodes.insert(position, replacements)
            return
        for replacement in replacements:
            replacement.parent = self
        self._nodes = self._nodes[:position] + replacements + self._nodes[position:]

    @property
    def symbols_defined(self):
        """Union of the symbols defined by all children."""
        defined = set()
        for child in self.args:
            if isinstance(child, pystencils.Assignment):
                # NOTE(review): all free symbols are taken here, while undefined_symbols
                # only counts the lhs of an Assignment as defined - confirm this is intended
                defined.update(child.free_symbols)
            else:
                defined.update(child.symbols_defined)
        return defined

    @property
    def undefined_symbols(self):
        """Symbols used by the children that none of the children defines."""
        used = set()
        defined = set()
        for child in self.args:
            if isinstance(child, pystencils.Assignment):
                used.update(child.free_symbols)
                defined.add(child.lhs)
            else:
                used.update(child.undefined_symbols)
                defined.update(child.symbols_defined)
        return used - defined

    def __str__(self):
        return "Block " + ''.join(f'{node!s}\n' for node in self._nodes)

    def __repr__(self):
        return "Block"
417

418
419

class PragmaBlock(Block):
    """Block that additionally stores a pragma line."""

    def __init__(self, pragma_line, nodes):
        super().__init__(nodes)
        self.pragma_line = pragma_line
        # unlike Block.__init__, a child that rejects the attribute raises here
        for child in nodes:
            child.parent = self

    def __repr__(self):
        """The stored pragma line."""
        return self.pragma_line
428
429
430
431


class LoopOverCoordinate(Node):
    """Loop over one coordinate, running from ``start`` to ``stop`` in increments of ``step``.

    Args:
        body: node that is executed in every iteration
        coordinate_to_loop_over: coordinate that determines the name of the loop counter
        start: first value of the loop counter
        stop: exclusive upper bound of the loop counter
        step: increment of the loop counter
        is_block_loop: if True the block loop counter name/symbol is used
    """
    LOOP_COUNTER_NAME_PREFIX = "ctr"
    BLOCK_LOOP_COUNTER_NAME_PREFIX = "_blockctr"

    def __init__(self, body, coordinate_to_loop_over, start, stop, step=1, is_block_loop=False):
        super(LoopOverCoordinate, self).__init__(parent=None)
        self.body = body
        body.parent = self
        self.coordinate_to_loop_over = coordinate_to_loop_over
        self.start = start
        self.stop = stop
        self.step = step
        self.prefix_lines = []
        self.is_block_loop = is_block_loop

    def new_loop_with_different_body(self, new_body):
        """Create a loop with the same bounds, step and prefix lines, but a different body."""
        result = LoopOverCoordinate(new_body, self.coordinate_to_loop_over, self.start, self.stop,
                                    self.step, self.is_block_loop)
        # copy, so that both loops can change their prefix lines independently
        result.prefix_lines = list(self.prefix_lines)
        return result

    def subs(self, subs_dict):
        """Substitute in the body (in place) and in start/stop/step if they offer ``subs``."""
        self.body.subs(subs_dict)
        if hasattr(self.start, "subs"):
            self.start = self.start.subs(subs_dict)
        if hasattr(self.stop, "subs"):
            self.stop = self.stop.subs(subs_dict)
        if hasattr(self.step, "subs"):
            self.step = self.step.subs(subs_dict)

    def fast_subs(self, subs_dict, skip=None):
        """Apply `fast_subs` to the body and to start/stop/step if they are sympy objects; returns self."""
        self.body = fast_subs(self.body, subs_dict, skip)
        if isinstance(self.start, sp.Basic):
            self.start = fast_subs(self.start, subs_dict, skip)
        if isinstance(self.stop, sp.Basic):
            self.stop = fast_subs(self.stop, subs_dict, skip)
        if isinstance(self.step, sp.Basic):
            self.step = fast_subs(self.step, subs_dict, skip)
        return self

    @property
    def args(self):
        """The body followed by those of start/stop/step that have an ``args`` attribute."""
        result = [self.body]
        for e in [self.start, self.stop, self.step]:
            # plain Python numbers have no 'args' and are left out
            if hasattr(e, "args"):
                result.append(e)
        return result

    def replace(self, child, replacement):
        """Replace the first of body, start, step, stop that equals ``child``; no-op if none matches."""
        if child == self.body:
            self.body = replacement
        elif child == self.start:
            self.start = replacement
        elif child == self.step:
            self.step = replacement
        elif child == self.stop:
            self.stop = replacement

    @property
    def symbols_defined(self):
        """The loop defines exactly its loop counter symbol."""
        return {self.loop_counter_symbol}

    @property
    def undefined_symbols(self):
        """Undefined symbols of the body and symbols in start/stop/step, without the loop counter."""
        result = self.body.undefined_symbols
        for possible_symbol in [self.start, self.stop, self.step]:
            if isinstance(possible_symbol, (Node, sp.Basic)):
                result.update(possible_symbol.atoms(sp.Symbol))
        return result - {self.loop_counter_symbol}

    @staticmethod
    def get_loop_counter_name(coordinate_to_loop_over):
        """Name of the (non-block) loop counter of the given coordinate."""
        return f"{LoopOverCoordinate.LOOP_COUNTER_NAME_PREFIX}_{coordinate_to_loop_over}"

    @staticmethod
    def get_block_loop_counter_name(coordinate_to_loop_over):
        """Name of the block loop counter of the given coordinate."""
        return f"{LoopOverCoordinate.BLOCK_LOOP_COUNTER_NAME_PREFIX}_{coordinate_to_loop_over}"

    @property
    def loop_counter_name(self):
        """Counter name of this loop, depending on ``is_block_loop``."""
        if self.is_block_loop:
            return LoopOverCoordinate.get_block_loop_counter_name(self.coordinate_to_loop_over)
        else:
            return LoopOverCoordinate.get_loop_counter_name(self.coordinate_to_loop_over)

    @staticmethod
    def is_loop_counter_symbol(symbol):
        """Return the coordinate of a (non-block) loop counter symbol.

        Args:
            symbol: symbol with ``name`` and ``dtype`` attributes

        Returns:
            the integer coordinate if the symbol is an 'int' typed symbol whose name has the form
            produced by `get_loop_counter_name`, otherwise None
        """
        # include the separator, so that unrelated names like 'ctrl0' are not mistaken for counters
        prefix = LoopOverCoordinate.LOOP_COUNTER_NAME_PREFIX + "_"
        if not symbol.name.startswith(prefix):
            return None
        if symbol.dtype != create_type('int'):
            return None
        try:
            return int(symbol.name[len(prefix):])
        except ValueError:
            # e.g. a user symbol called 'ctr_max': has the prefix but is no loop counter
            return None

    @staticmethod
    def get_loop_counter_symbol(coordinate_to_loop_over):
        """Non-negative 'int' typed symbol of the (non-block) loop counter of the given coordinate."""
        return TypedSymbol(LoopOverCoordinate.get_loop_counter_name(coordinate_to_loop_over), 'int', nonnegative=True)

    @staticmethod
    def get_block_loop_counter_symbol(coordinate_to_loop_over):
        """Non-negative 'int' typed symbol of the block loop counter of the given coordinate."""
        return TypedSymbol(LoopOverCoordinate.get_block_loop_counter_name(coordinate_to_loop_over),
                           'int',
                           nonnegative=True)

    @property
    def loop_counter_symbol(self):
        """Counter symbol of this loop, depending on ``is_block_loop``."""
        if self.is_block_loop:
            return self.get_block_loop_counter_symbol(self.coordinate_to_loop_over)
        else:
            return self.get_loop_counter_symbol(self.coordinate_to_loop_over)

    @property
    def is_outermost_loop(self):
        """True if no `LoopOverCoordinate` is found among the parents of this loop."""
        # imported inside the property rather than at module level
        from pystencils.transformations import get_next_parent_of_type
        return get_next_parent_of_type(self, LoopOverCoordinate) is None

    @property
    def is_innermost_loop(self):
        """True if no descendant of this loop is a `LoopOverCoordinate`."""
        return len(self.atoms(LoopOverCoordinate)) == 0

    def __str__(self):
        # loop header followed by the body, every body line indented by one tab
        return 'for({!s}={!s}; {!s}<{!s}; {!s}+={!s})\n{!s}'.format(self.loop_counter_name, self.start,
                                                                    self.loop_counter_name, self.stop,
                                                                    self.loop_counter_name, self.step,
                                                                    ("\t" + "\t".join(str(self.body).splitlines(True))))

    def __repr__(self):
        return 'for({!s}={!s}; {!s}<{!s}; {!s}+={!s})'.format(self.loop_counter_name, self.start,
                                                              self.loop_counter_name, self.stop,
                                                              self.loop_counter_name, self.step)
562

563
564

class SympyAssignment(Node):
    """Assignment of a right hand side expression to a left hand side symbol.

    Args:
        lhs_symbol: left hand side, passed through ``sp.sympify``
        rhs_expr: right hand side, passed through ``sp.sympify``
        is_const: stored and exposed as ``is_const``
        use_auto: stored as the ``use_auto`` attribute
    """

    def __init__(self, lhs_symbol, rhs_expr, is_const=True, use_auto=False):
        super(SympyAssignment, self).__init__(parent=None)
        self._lhs_symbol = sp.sympify(lhs_symbol)
        self.rhs = sp.sympify(rhs_expr)
        self._is_const = is_const
        self._is_declaration = self.__is_declaration()
        self.use_auto = use_auto

    def __is_declaration(self):
        """Classify the current lhs: casts, field accesses, indexed objects and temporary
        memory allocations are no declarations, everything else is."""
        if isinstance(self._lhs_symbol, cast_func):
            return False
        if isinstance(self._lhs_symbol, (Field.Access, sp.Indexed, TemporaryMemoryAllocation)):
            return False
        return True

    @property
    def lhs(self):
        """Left hand side of the assignment."""
        return self._lhs_symbol

    @lhs.setter
    def lhs(self, new_value):
        self._lhs_symbol = new_value
        # the declaration status depends on the type of the lhs, so it is re-evaluated
        self._is_declaration = self.__is_declaration()

    def subs(self, subs_dict):
        """Apply `fast_subs` to both sides and store the results."""
        self.lhs = fast_subs(self.lhs, subs_dict)
        self.rhs = fast_subs(self.rhs, subs_dict)

    def optimize(self, optimizations):
        """Try to rewrite the rhs with sympy's ``optimize``.

        Best effort: if the import or the optimization fails, the rhs is left unchanged.
        """
        try:
            from sympy.codegen.rewriting import optimize
            self.rhs = optimize(self.rhs, optimizations)
        except Exception:
            pass

    @property
    def args(self):
        """lhs, rhs and the sympified ``is_const`` flag."""
        return [self._lhs_symbol, self.rhs, sp.sympify(self._is_const)]

    @property
    def symbols_defined(self):
        """The lhs symbol if this assignment is a declaration, otherwise an empty set."""
        if not self._is_declaration:
            return set()
        return {self._lhs_symbol}

    @property
    def undefined_symbols(self):
        """Symbols used by this assignment.

        Consists of the free symbols of the rhs (without `sp.Indexed` and `TypedImaginaryUnit`
        instances), the loop counter symbols belonging to field accesses on the rhs, and all
        symbols of the lhs.
        """
        rhs_symbols = {s for s in self.rhs.free_symbols if not isinstance(s, sp.Indexed)}
        # Add loop counters if there are field accesses
        loop_counters = set()
        for symbol in rhs_symbols:
            if isinstance(symbol, Field.Access):
                # one loop counter per entry of the access' offsets
                for i in range(len(symbol.offsets)):
                    loop_counters.add(LoopOverCoordinate.get_loop_counter_symbol(i))
        result = {r for r in rhs_symbols if not isinstance(r, TypedImaginaryUnit)}
        result.update(loop_counters)
        result.update(self._lhs_symbol.atoms(sp.Symbol))
        return result

    @property
    def is_declaration(self):
        """Whether this assignment declares its lhs (see ``symbols_defined``)."""
        return self._is_declaration

    @property
    def is_const(self):
        """The ``is_const`` flag passed to the constructor."""
        return self._is_const

    def replace(self, child, replacement):
        """Replace the lhs or the rhs by ``replacement``.

        Args:
            child: the current lhs or rhs
            replacement: new value; its ``parent`` attribute is set to this node

        Raises:
            ValueError: if ``child`` equals neither the lhs nor the rhs
        """
        if child == self.lhs:
            replacement.parent = self
            self.lhs = replacement
        elif child == self.rhs:
            replacement.parent = self
            self.rhs = replacement
        else:
            # report the node that was searched for, not the node that should replace it
            raise ValueError(f'{child} is not in args of {self.__class__}')

    def __repr__(self):
        return repr(self.lhs) + " ← " + repr(self.rhs)

    def _repr_html_(self):
        printed_lhs = sp.latex(self.lhs)
        printed_rhs = sp.latex(self.rhs)
        return f"${printed_lhs} \\leftarrow {printed_rhs}$"

    def __hash__(self):
        # hash and equality both depend on lhs and rhs only
        return hash((self.lhs, self.rhs))

    def __eq__(self, other):
        return type(self) == type(other) and (self.lhs, self.rhs) == (other.lhs, other.rhs)

656

Martin Bauer's avatar
Martin Bauer committed
657
class ResolvedFieldAccess(sp.Indexed):
    """``sp.Indexed`` expression that additionally carries field-access metadata.

    Besides the two sympy arguments (base and linearized index), every instance
    stores the ``field``, ``offsets`` and ``idx_coordinate_values`` passed at
    construction as plain attributes.
    """

    def __new__(cls, base, linearized_index, field, offsets, idx_coordinate_values):
        # A base that is not yet an IndexedBase has to be a TypedSymbol;
        # it is wrapped into an IndexedBase of shape (1,).
        if not isinstance(base, sp.IndexedBase):
            assert isinstance(base, TypedSymbol)
            base = sp.IndexedBase(base, shape=(1,))
            assert isinstance(base.label, TypedSymbol)

        instance = super().__new__(cls, base, linearized_index)
        instance.field = field
        instance.offsets = offsets
        instance.idx_coordinate_values = idx_coordinate_values
        return instance

    def _eval_subs(self, old, new):
        """Substitute only inside the index; base and metadata are kept as they are."""
        substituted_index = self.args[1].subs(old, new)
        return ResolvedFieldAccess(self.args[0], substituted_index,
                                   self.field, self.offsets, self.idx_coordinate_values)

    def fast_subs(self, substitutions, skip=None):
        """Return the replacement for this access, or a copy with base and index substituted.

        The ``skip`` argument is accepted but not used by this implementation.
        """
        if self in substitutions:
            return substitutions[self]

        new_base = self.args[0].subs(substitutions)
        new_index = self.args[1].subs(substitutions)
        return ResolvedFieldAccess(new_base, new_index,
                                   self.field, self.offsets, self.idx_coordinate_values)

    def _hashable_content(self):
        """Extend the inherited hashable content by offsets, coordinate values and field."""
        inherited = super()._hashable_content()
        offsets_part = tuple(self.offsets)
        metadata_part = (repr(self.idx_coordinate_values), hash(self.field))
        return inherited + offsets_part + metadata_part

    @property
    def typed_symbol(self):
        """The label of the indexed base."""
        return self.base.label

    def __str__(self):
        inherited_str = super().__str__()
        dtype = self.typed_symbol.dtype
        return f"{inherited_str} ({dtype})"

    def __getnewargs__(self):
        """Positional arguments for ``__new__`` (used when unpickling/copying)."""
        new_args = (self.base, self.indices[0], self.field, self.offsets, self.idx_coordinate_values)
        return new_args

    def __getnewargs_ex__(self):
        """Positional and (empty) keyword arguments for ``__new__``."""
        new_args = (self.base, self.indices[0], self.field, self.offsets, self.idx_coordinate_values)
        return new_args, {}

Martin Bauer's avatar
Martin Bauer committed
699

700
class TemporaryMemoryAllocation(Node):
    """Node for temporary memory buffer allocation.

    Always allocates aligned memory.

    Args:
        typed_symbol: symbol used as pointer (has to be typed)
        size: number of elements to allocate
        align_offset: index of the element that is to be aligned, i.e. the
                      ``align_offset``-th element of the buffer is the aligned one
    """

    def __init__(self, typed_symbol: TypedSymbol, size, align_offset):
        super(TemporaryMemoryAllocation, self).__init__(parent=None)
        self.symbol = typed_symbol
        self.size = size
        self.headers = ['<stdlib.h>']
        self._align_offset = align_offset

    @property
    def symbols_defined(self):
        """The pointer symbol is the only symbol defined by this node."""
        return {self.symbol}

    @property
    def undefined_symbols(self):
        """Symbols occurring in ``size`` if it is a sympy object, otherwise an empty set."""
        if isinstance(self.size, sp.Basic):
            return self.size.atoms(sp.Symbol)
        else:
            return set()

    @property
    def args(self):
        return [self.symbol]

    def offset(self, byte_alignment):
        """Number of ELEMENTS to skip for a pointer that is aligned to byte_alignment.

        Args:
            byte_alignment: alignment in bytes; has to be a multiple of the
                            item size of the pointer's base type

        Returns:
            the element count as an integral value
        """
        np_dtype = self.symbol.dtype.base_type.numpy_dtype
        assert byte_alignment % np_dtype.itemsize == 0
        # Divisibility is asserted above, so floor division is exact. Unlike true
        # division it keeps the result integral instead of turning it into a float.
        elements_per_alignment = byte_alignment // np_dtype.itemsize
        return -self._align_offset % elements_per_alignment

739
740

class TemporaryMemoryFree(Node):
    """Node that refers back to an allocation node.

    ``symbol`` and ``offset`` are forwarded from the stored ``alloc_node``.
    Judging by its name, this node represents freeing the buffer that was
    created by that allocation node.
    """

    def __init__(self, alloc_node):
        super().__init__(parent=None)
        self.alloc_node = alloc_node

    @property
    def symbol(self):
        """Pointer symbol of the referenced allocation node."""
        allocation = self.alloc_node
        return allocation.symbol

    def offset(self, byte_alignment):
        """Forward to ``offset`` of the referenced allocation node."""
        allocation = self.alloc_node
        return allocation.offset(byte_alignment)

    @property
    def symbols_defined(self):
        """This node defines no symbols."""
        return set()

    @property
    def undefined_symbols(self):
        """This node reports no undefined symbols."""
        return set()

    @property
    def args(self):
        """This node has no children."""
        return []
Martin Bauer's avatar
Martin Bauer committed
763
764
765
766
767


def early_out(condition):
    """Build a ``Conditional`` on ``vec_all(condition)`` whose body is a single ``SkipIteration``."""
    # Imported locally, as in the rest of this function's history; the module
    # is only needed when this helper is actually called.
    from pystencils.cpu.vectorization import vec_all

    all_true = vec_all(condition)
    skip_block = Block([SkipIteration()])
    return Conditional(all_true, skip_block)
768
769


Stephan Seitz's avatar
Stephan Seitz committed
770
def get_dummy_symbol(dtype='bool'):
    """Create a ``TypedSymbol`` of the given dtype named ``dummy`` + a random UUID4 hex string."""
    unique_suffix = uuid.uuid4().hex
    symbol_type = create_type(dtype)
    return TypedSymbol(f'dummy{unique_suffix}', symbol_type)
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817


class SourceCodeComment(Node):
    """Node holding a text that is rendered as a ``/* ... */`` comment.

    Args:
        text: the comment text (without comment delimiters)
    """

    def __init__(self, text):
        # Initialize the Node base class so that ``parent`` always exists,
        # consistent with the other node classes.
        super().__init__(parent=None)
        self.text = text

    @property
    def args(self):
        """A comment has no children."""
        return []

    @property
    def symbols_defined(self):
        """A comment defines no symbols."""
        return set()

    @property
    def undefined_symbols(self):
        """A comment uses no symbols."""
        return set()

    def __str__(self):
        return "/* " + self.text + " */"

    def __repr__(self):
        return self.__str__()


class EmptyLine(Node):
    """Node without content; its string representation is the empty string."""

    def __init__(self):
        # Initialize the Node base class so that ``parent`` always exists,
        # consistent with the other node classes.
        super().__init__(parent=None)

    @property
    def args(self):
        """An empty line has no children."""
        return []

    @property
    def symbols_defined(self):
        """An empty line defines no symbols."""
        return set()

    @property
    def undefined_symbols(self):
        """An empty line uses no symbols."""
        return set()

    def __str__(self):
        return ""

    def __repr__(self):
        return self.__str__()
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842


class ConditionalFieldAccess(sp.Function):
    """
    :class:`pystencils.Field.Access` that is only executed if a certain condition is met.
    Can be used, for instance, for out-of-bound checks.

    The three sympy arguments are, in this order: the field access, the
    out-of-bounds condition and the (sympified) out-of-bounds value.
    """

    def __new__(cls, field_access, outofbounds_condition, outofbounds_value=0):
        # The fallback value is converted to a sympy object before it is stored.
        fallback = sp.S(outofbounds_value)
        return sp.Function.__new__(cls, field_access, outofbounds_condition, fallback)

    @property
    def access(self):
        """First argument: the wrapped field access."""
        return self.args[0]

    @property
    def outofbounds_condition(self):
        """Second argument: the out-of-bounds condition."""
        return self.args[1]

    @property
    def outofbounds_value(self):
        """Third argument: the out-of-bounds value."""
        return self.args[2]

    def __getnewargs__(self):
        """Positional arguments for ``__new__`` (used when unpickling/copying)."""
        new_args = (self.access, self.outofbounds_condition, self.outofbounds_value)
        return new_args

    def __getnewargs_ex__(self):
        """Positional and (empty) keyword arguments for ``__new__``."""
        new_args = (self.access, self.outofbounds_condition, self.outofbounds_value)
        return new_args, {}