"""Abstract syntax tree (AST) node classes used by pystencils."""
import collections.abc
import itertools
import uuid
from typing import Any, List, Optional, Sequence, Set, Union

import sympy as sp

import pystencils
from pystencils.data_types import TypedImaginaryUnit, TypedSymbol, cast_func, create_type
from pystencils.field import Field
from pystencils.kernelparameters import FieldPointerSymbol, FieldShapeSymbol, FieldStrideSymbol
from pystencils.sympyextensions import fast_subs

# Anything that can appear as a child of an AST node: another node or a sympy expression.
NodeOrExpr = Union['Node', sp.Expr]


class Node:
    """Base class for all AST nodes.

    Subclasses provide ``args``, ``symbols_defined`` and ``undefined_symbols``; the generic
    ``subs`` and ``atoms`` helpers defined here walk over ``args``.
    """

    def __init__(self, parent: Optional['Node'] = None):
        # link to the enclosing node; None for a detached node
        self.parent = parent

    @property
    def args(self) -> List[NodeOrExpr]:
        """Returns all arguments/children of this node."""
        raise NotImplementedError()

    @property
    def symbols_defined(self) -> Set[sp.Symbol]:
        """Set of symbols which are defined by this node."""
        raise NotImplementedError()

    @property
    def undefined_symbols(self) -> Set[sp.Symbol]:
        """Symbols which are used but are not defined inside this node."""
        raise NotImplementedError()

    def subs(self, subs_dict) -> None:
        """Inplace! Substitute, similar to sympy's but modifies the AST inplace.

        sympy children are replaced by their substituted copies; every other child is expected to
        substitute in place and return None.
        """
        for position, child in enumerate(self.args):
            substituted = child.subs(subs_dict)
            if not isinstance(child, sp.Expr):
                # non-sympy children substitute in place and therefore return nothing
                assert substituted is None
                continue
            # sympy's subs is out-of-place, so the new expression has to be stored back
            self.args[position] = substituted

    @property
    def func(self):
        return self.__class__

    def atoms(self, arg_type) -> Set[Any]:
        """Returns a set of all descendants recursively, which are an instance of the given type."""
        matches = set()
        for child in self.args:
            matches.update(child.atoms(arg_type))
            if isinstance(child, arg_type):
                matches.add(child)
        return matches


class Conditional(Node):
    """Conditional that maps to an 'if' statement in C/C++.

    Try to avoid using this node inside of loops, since currently this construction can not be vectorized.
    Consider using assignments with sympy.Piecewise in this case.

    Args:
        condition_expr: sympy relational expression
        true_block: block which is run if conditional is true; a non-Block node is wrapped into a Block
        false_block: optional block which is run if conditional is false; wrapped like ``true_block``
    """

    def __init__(self, condition_expr: sp.Basic, true_block: Union['Block', 'SympyAssignment'],
                 false_block: Optional['Block'] = None) -> None:
        super(Conditional, self).__init__(parent=None)

        self.condition_expr = condition_expr

        def handle_child(c):
            # Wrap single nodes into a Block so both branches can always be handled as blocks.
            if c is None:
                return None
            if not isinstance(c, Block):
                c = Block([c])
            c.parent = self
            return c

        self.true_block = handle_child(true_block)
        self.false_block = handle_child(false_block)

    def subs(self, subs_dict):
        """In-place substitution in both branches and in the condition."""
        self.true_block.subs(subs_dict)
        if self.false_block:
            self.false_block.subs(subs_dict)
        self.condition_expr = self.condition_expr.subs(subs_dict)

    @property
    def args(self):
        result = [self.condition_expr, self.true_block]
        if self.false_block:
            result.append(self.false_block)
        return result

    @property
    def symbols_defined(self):
        # symbols assigned inside the branches are not reported as defined by the conditional itself
        return set()

    @property
    def undefined_symbols(self):
        """Undefined symbols of both branches plus the symbols occurring in the condition."""
        result = self.true_block.undefined_symbols
        if self.false_block:
            result.update(self.false_block.undefined_symbols)
        if hasattr(self.condition_expr, 'atoms'):
            result.update(self.condition_expr.atoms(sp.Symbol))
        return result

    def __str__(self):
        return self.__repr__()

    def __repr__(self):
        result = f'if:({self.condition_expr!r}) '
        if self.true_block:
            result += f'\n\t{self.true_block} '
        if self.false_block:
            # append, so the condition and the true branch stay part of the output
            result += 'else: '
            result += f'\n\t{self.false_block} '
        return result

    def replace_by_true_block(self):
        """Replaces the conditional by its True block"""
        self.parent.replace(self, [self.true_block])

    def replace_by_false_block(self):
        """Replaces the conditional by its False block"""
        self.parent.replace(self, [self.false_block] if self.false_block else [])


class KernelFunction(Node):
    """AST node for a complete kernel function.

    Wraps a body node; the parameters of the function are derived from the symbols the body
    uses but does not define (see `get_parameters`).
    """

    class Parameter:
        """Function parameter.

        Each undefined symbol in a `KernelFunction` node becomes a parameter to the function.
        Parameters are either symbols introduced by the user that never occur on the left hand side of an
        Assignment, or are related to fields/arrays passed to the function.

        A parameter consists of the typed symbol (symbol property). For field related parameters this is a symbol
        defined in pystencils.kernelparameters.
        If the parameter is related to one or multiple fields, these fields are referenced in the fields property.
        """

        def __init__(self, symbol, fields):
            self.symbol = symbol  # type: TypedSymbol
            self.fields = fields  # type: Sequence[Field]

        def __repr__(self):
            return repr(self.symbol)

        @property
        def is_field_stride(self):
            return isinstance(self.symbol, FieldStrideSymbol)

        @property
        def is_field_shape(self):
            return isinstance(self.symbol, FieldShapeSymbol)

        @property
        def is_field_pointer(self):
            return isinstance(self.symbol, FieldPointerSymbol)

        @property
        def is_field_parameter(self):
            """True if the symbol is a field pointer, shape or stride symbol."""
            return any((self.is_field_pointer, self.is_field_shape, self.is_field_stride))

        @property
        def field_name(self):
            """Name of the first associated field."""
            return self.fields[0].name

    def __init__(self, body, target, backend, compile_function, ghost_layers, function_name="kernel", assignments=None):
        super(KernelFunction, self).__init__()
        self._body = body
        body.parent = self
        self.function_name = function_name
        self.ghost_layers = ghost_layers
        self._target = target
        self._backend = backend
        # these variables are assumed to be global, so no automatic parameter is generated for them
        self.global_variables = set()
        self.instruction_set = None  # used in `vectorize` function to tell the backend which i.s. (SSE,AVX) to use
        # function that compiles the node to a Python callable, is set by the backends
        self._compile_function = compile_function
        self.assignments = assignments

    @property
    def target(self):
        """Currently either 'cpu' or 'gpu' """
        return self._target

    @property
    def backend(self):
        """Backend for generating the code e.g. 'llvm', 'c', 'cuda' """
        return self._backend

    @property
    def symbols_defined(self):
        return set()

    @property
    def undefined_symbols(self):
        return set()

    @property
    def body(self):
        return self._body

    @body.setter
    def body(self, value):
        self._body = value
        self._body.parent = self

    @property
    def args(self):
        return (self._body,)

    @property
    def fields_accessed(self) -> Set[Field]:
        """Set of Field instances: fields which are accessed inside this kernel function"""
        from pystencils.interpolation_astnodes import InterpolatorAccess
        accesses = itertools.chain(self.atoms(ResolvedFieldAccess), self.atoms(InterpolatorAccess))
        return {access.field for access in accesses}

    @property
    def fields_written(self) -> Set[Field]:
        """Fields whose resolved accesses appear on the left-hand side of an assignment."""
        return {assignment.lhs.field
                for assignment in self.atoms(SympyAssignment)
                if isinstance(assignment.lhs, ResolvedFieldAccess)}

    @property
    def fields_read(self) -> Set[Field]:
        """Fields referenced by symbols (having a ``field`` attribute) on right-hand sides of assignments."""
        read = set()
        for assignment in self.atoms(SympyAssignment):
            read.update(sym.field for sym in assignment.rhs.free_symbols if hasattr(sym, 'field'))
        return read

    def get_parameters(self) -> Sequence['KernelFunction.Parameter']:
        """Returns list of parameters for this function, sorted by symbol name.

        This function is expensive, cache the result where possible!
        """
        fields_by_name = {f.name: f for f in self.fields_accessed}

        def fields_of(symbol):
            # symbols tied to one field expose `field_name`, symbols tied to several expose `field_names`
            if hasattr(symbol, 'field_name'):
                return (fields_by_name[symbol.field_name],)
            if hasattr(symbol, 'field_names'):
                return tuple(fields_by_name[name] for name in symbol.field_names)
            return ()

        free_symbols = self._body.undefined_symbols - self.global_variables
        params = [self.Parameter(sym, fields_of(sym)) for sym in free_symbols]
        if hasattr(self, 'indexing'):
            # an optional `indexing` attribute (not set by this class) contributes extra parameters
            params.extend(self.Parameter(sym, []) for sym in self.indexing.symbolic_parameters())
        return sorted(params, key=lambda p: p.symbol.name)

    def __str__(self):
        params = [p.symbol for p in self.get_parameters()]
        indented_body = "\t" + "\t".join(str(self.body).splitlines(True))
        return '{0} {1}({2})\n{3}'.format(type(self).__name__, self.function_name, params, indented_body)

    def __repr__(self):
        params = [p.symbol for p in self.get_parameters()]
        return f'{type(self).__name__} {self.function_name}({params})'

    def compile(self, *args, **kwargs):
        """Compile this node with the backend-provided compile function."""
        compile_function = self._compile_function
        if compile_function is None:
            raise ValueError("No compile-function provided for this KernelFunction node")
        return compile_function(self, *args, **kwargs)


class SkipIteration(Node):
    """Leaf node without children that neither defines nor uses any symbols.

    NOTE(review): judging by its name this presumably represents skipping the current loop
    iteration (e.g. a C ``continue``) — confirm against the code printer.
    """

    @property
    def args(self):
        return list()

    @property
    def symbols_defined(self):
        return set()

    @property
    def undefined_symbols(self):
        return set()


class Block(Node):
    """Ordered sequence of child nodes.

    Args:
        nodes: child nodes; the list is stored as-is (not copied)
    """

    def __init__(self, nodes: List[Node]):
        super(Block, self).__init__()
        self._nodes = nodes
        for n in self._nodes:
            try:
                n.parent = self
            except AttributeError:
                # children that do not support attribute assignment are kept without a parent link
                pass

    @property
    def args(self):
        return self._nodes

    def subs(self, subs_dict) -> None:
        """In-place substitution in all children.

        AST children substitute in place. sympy children (e.g. ``pystencils.Assignment``) substitute
        out-of-place, so their results are stored back into the block.
        """
        for i, a in enumerate(self._nodes):
            result = a.subs(subs_dict)
            if isinstance(a, sp.Basic):
                self._nodes[i] = result

    def fast_subs(self, subs_dict, skip=None):
        self._nodes = [fast_subs(a, subs_dict, skip) for a in self._nodes]
        return self

    def insert_front(self, node, if_not_exists=False):
        """Prepend `node` (or every node of an iterable of nodes) to this block.

        With `if_not_exists`, nothing is done if `node` already is the first child.
        """
        if if_not_exists and len(self._nodes) > 0 and self._nodes[0] == node:
            return
        if isinstance(node, collections.abc.Iterable):
            node = list(node)
            for n in node:
                n.parent = self
            self._nodes = node + self._nodes
        else:
            node.parent = self
            self._nodes.insert(0, node)

    def insert_before(self, new_node, insert_before, if_not_exists=False):
        """Insert `new_node` in front of `insert_before`, which must occur exactly once in this block.

        Declarations are moved further up, past any loops/conditionals that directly precede the
        insertion point. With `if_not_exists`, nothing is inserted if `new_node` already is at that position.
        """
        new_node.parent = self
        assert self._nodes.count(insert_before) == 1
        idx = self._nodes.index(insert_before)

        if isinstance(new_node, SympyAssignment) and new_node.is_declaration:
            while idx > 0 and isinstance(self._nodes[idx - 1], (LoopOverCoordinate, Conditional)):
                idx -= 1
        if not if_not_exists or self._nodes[idx] != new_node:
            self._nodes.insert(idx, new_node)

    def insert_after(self, new_node, insert_after, if_not_exists=False):
        """Insert `new_node` behind `insert_after`, which must occur exactly once in this block.

        Declarations are moved further up, past any loops/conditionals that directly precede the
        insertion point. With `if_not_exists`, nothing is inserted if `new_node` already is a direct
        neighbour of the insertion point.
        """
        new_node.parent = self
        assert self._nodes.count(insert_after) == 1
        idx = self._nodes.index(insert_after) + 1

        if isinstance(new_node, SympyAssignment) and new_node.is_declaration:
            while idx > 0 and isinstance(self._nodes[idx - 1], (LoopOverCoordinate, Conditional)):
                idx -= 1
        if if_not_exists:
            # Only look at neighbours that exist: with idx == 0, `self._nodes[idx - 1]` would wrap
            # around to the last node of the block and could wrongly suppress the insertion.
            already_before = idx > 0 and self._nodes[idx - 1] == new_node
            already_after = idx < len(self._nodes) and self._nodes[idx] == new_node
            if already_before or already_after:
                return
        self._nodes.insert(idx, new_node)

    def append(self, node):
        """Append `node`, or each element of a list/tuple of nodes, to this block."""
        if isinstance(node, (list, tuple)):
            for n in node:
                n.parent = self
                self._nodes.append(n)
        else:
            node.parent = self
            self._nodes.append(node)

    def take_child_nodes(self):
        """Remove all children from this block and return them."""
        tmp = self._nodes
        self._nodes = []
        return tmp

    def replace(self, child, replacements):
        """Replace `child` (which must occur exactly once) by a node or a list/tuple of nodes."""
        assert self._nodes.count(child) == 1
        idx = self._nodes.index(child)
        del self._nodes[idx]
        if isinstance(replacements, (list, tuple)):
            replacements = list(replacements)
            for e in replacements:
                e.parent = self
            self._nodes = self._nodes[:idx] + replacements + self._nodes[idx:]
        else:
            replacements.parent = self
            self._nodes.insert(idx, replacements)

    @property
    def symbols_defined(self):
        result = set()
        for a in self.args:
            if isinstance(a, pystencils.Assignment):
                # only the left-hand side is defined (consistent with `undefined_symbols`)
                result.add(a.lhs)
            else:
                result.update(a.symbols_defined)
        return result

    @property
    def undefined_symbols(self):
        """Symbols used by any child minus the symbols defined by any child."""
        result = set()
        defined_symbols = set()
        for a in self.args:
            if isinstance(a, pystencils.Assignment):
                result.update(a.free_symbols)
                defined_symbols.update({a.lhs})
            else:
                result.update(a.undefined_symbols)
                defined_symbols.update(a.symbols_defined)
        return result - defined_symbols

    def __str__(self):
        return "Block " + ''.join('{!s}\n'.format(node) for node in self._nodes)

    def __repr__(self):
        return "Block"


class PragmaBlock(Block):
    """Block that additionally carries a ``pragma_line``.

    Args:
        pragma_line: pragma text; also returned as this node's ``repr``
        nodes: child nodes, forwarded to ``Block``
    """

    def __init__(self, pragma_line, nodes):
        # Block.__init__ already makes this node the parent of every child (and tolerates children
        # that do not accept attributes), so no second pass over `nodes` is needed.
        super(PragmaBlock, self).__init__(nodes)
        self.pragma_line = pragma_line

    def __repr__(self):
        return self.pragma_line


class LoopOverCoordinate(Node):
    """Counting loop ``for (counter = start; counter < stop; counter += step) body`` over one coordinate.

    Args:
        body: loop body node; its parent is set to this loop
        coordinate_to_loop_over: coordinate index, used to build the loop counter name
        start: initial counter value
        stop: exclusive upper bound of the counter
        step: counter increment
        is_block_loop: if True, the counter is named with ``BLOCK_LOOP_COUNTER_NAME_PREFIX``
            instead of ``LOOP_COUNTER_NAME_PREFIX``
    """
    LOOP_COUNTER_NAME_PREFIX = "ctr"
    BLOCK_LOOP_COUNTER_NAME_PREFIX = "_blockctr"

    def __init__(self, body, coordinate_to_loop_over, start, stop, step=1, is_block_loop=False):
        super(LoopOverCoordinate, self).__init__(parent=None)
        self.body = body
        body.parent = self
        self.coordinate_to_loop_over = coordinate_to_loop_over
        self.start = start
        self.stop = stop
        self.step = step
        self.prefix_lines = []
        self.is_block_loop = is_block_loop

    def new_loop_with_different_body(self, new_body):
        """Return a copy of this loop (bounds, step, block flag, prefix lines) around `new_body`."""
        result = LoopOverCoordinate(new_body, self.coordinate_to_loop_over, self.start, self.stop,
                                    self.step, self.is_block_loop)
        result.prefix_lines = list(self.prefix_lines)
        return result

    def subs(self, subs_dict):
        """In-place substitution in the body and in start/stop/step (where they support subs)."""
        self.body.subs(subs_dict)
        if hasattr(self.start, "subs"):
            self.start = self.start.subs(subs_dict)
        if hasattr(self.stop, "subs"):
            self.stop = self.stop.subs(subs_dict)
        if hasattr(self.step, "subs"):
            self.step = self.step.subs(subs_dict)

    def fast_subs(self, subs_dict, skip=None):
        self.body = fast_subs(self.body, subs_dict, skip)
        if isinstance(self.start, sp.Basic):
            self.start = fast_subs(self.start, subs_dict, skip)
        if isinstance(self.stop, sp.Basic):
            self.stop = fast_subs(self.stop, subs_dict, skip)
        if isinstance(self.step, sp.Basic):
            self.step = fast_subs(self.step, subs_dict, skip)
        return self

    @property
    def args(self):
        # plain Python numbers have no `args` and are therefore not reported as children
        result = [self.body]
        for e in [self.start, self.stop, self.step]:
            if hasattr(e, "args"):
                result.append(e)
        return result

    def replace(self, child, replacement):
        if child == self.body:
            # keep the parent link of the new body consistent, as done in __init__
            replacement.parent = self
            self.body = replacement
        elif child == self.start:
            self.start = replacement
        elif child == self.step:
            self.step = replacement
        elif child == self.stop:
            self.stop = replacement

    @property
    def symbols_defined(self):
        return {self.loop_counter_symbol}

    @property
    def undefined_symbols(self):
        result = self.body.undefined_symbols
        for possible_symbol in [self.start, self.stop, self.step]:
            if isinstance(possible_symbol, (Node, sp.Basic)):
                result.update(possible_symbol.atoms(sp.Symbol))
        return result - {self.loop_counter_symbol}

    @staticmethod
    def get_loop_counter_name(coordinate_to_loop_over):
        return f"{LoopOverCoordinate.LOOP_COUNTER_NAME_PREFIX}_{coordinate_to_loop_over}"

    @staticmethod
    def get_block_loop_counter_name(coordinate_to_loop_over):
        return f"{LoopOverCoordinate.BLOCK_LOOP_COUNTER_NAME_PREFIX}_{coordinate_to_loop_over}"

    @property
    def loop_counter_name(self):
        if self.is_block_loop:
            return LoopOverCoordinate.get_block_loop_counter_name(self.coordinate_to_loop_over)
        else:
            return LoopOverCoordinate.get_loop_counter_name(self.coordinate_to_loop_over)

    @staticmethod
    def is_loop_counter_symbol(symbol):
        """Return the coordinate encoded in a (non-block) loop counter symbol, or None.

        A loop counter is an 'int'-typed symbol named ``<LOOP_COUNTER_NAME_PREFIX>_<int>``. Any other
        symbol, including one that merely shares the prefix or has no dtype, yields None.
        """
        prefix = LoopOverCoordinate.LOOP_COUNTER_NAME_PREFIX
        if not symbol.name.startswith(prefix + "_"):
            return None
        dtype = getattr(symbol, 'dtype', None)
        if dtype is None or dtype != create_type('int'):
            return None
        try:
            return int(symbol.name[len(prefix) + 1:])
        except ValueError:
            # e.g. "ctr_x": shares the prefix but carries no integer coordinate
            return None

    @staticmethod
    def get_loop_counter_symbol(coordinate_to_loop_over):
        return TypedSymbol(LoopOverCoordinate.get_loop_counter_name(coordinate_to_loop_over), 'int', nonnegative=True)

    @staticmethod
    def get_block_loop_counter_symbol(coordinate_to_loop_over):
        return TypedSymbol(LoopOverCoordinate.get_block_loop_counter_name(coordinate_to_loop_over),
                           'int',
                           nonnegative=True)

    @property
    def loop_counter_symbol(self):
        if self.is_block_loop:
            return self.get_block_loop_counter_symbol(self.coordinate_to_loop_over)
        else:
            return self.get_loop_counter_symbol(self.coordinate_to_loop_over)

    @property
    def is_outermost_loop(self):
        from pystencils.transformations import get_next_parent_of_type
        return get_next_parent_of_type(self, LoopOverCoordinate) is None

    @property
    def is_innermost_loop(self):
        return len(self.atoms(LoopOverCoordinate)) == 0

    def __str__(self):
        return 'for({!s}={!s}; {!s}<{!s}; {!s}+={!s})\n{!s}'.format(self.loop_counter_name, self.start,
                                                                    self.loop_counter_name, self.stop,
                                                                    self.loop_counter_name, self.step,
                                                                    ("\t" + "\t".join(str(self.body).splitlines(True))))

    def __repr__(self):
        return 'for({!s}={!s}; {!s}<{!s}; {!s}+={!s})'.format(self.loop_counter_name, self.start,
                                                              self.loop_counter_name, self.stop,
                                                              self.loop_counter_name, self.step)


class SympyAssignment(Node):
    """Assignment ``lhs ← rhs`` with sympy objects on both sides.

    Args:
        lhs_symbol: assignment target; passed through ``sp.sympify``
        rhs_expr: assigned expression; passed through ``sp.sympify``
        is_const: exposed via ``is_const`` and as last element of ``args``
        use_auto: stored unchanged on ``self.use_auto``
    """

    def __init__(self, lhs_symbol, rhs_expr, is_const=True, use_auto=False):
        super(SympyAssignment, self).__init__(parent=None)
        self._lhs_symbol = sp.sympify(lhs_symbol)
        self.rhs = sp.sympify(rhs_expr)
        self._is_const = is_const
        self._is_declaration = self.__is_declaration()
        self.use_auto = use_auto

    def __is_declaration(self):
        # casts, field accesses, indexed expressions and TemporaryMemoryAllocation targets on the
        # left-hand side do not count as declarations; everything else does
        if isinstance(self._lhs_symbol, cast_func):
            return False
        if any(isinstance(self._lhs_symbol, c) for c in (Field.Access, sp.Indexed, TemporaryMemoryAllocation)):
            return False
        return True

    @property
    def lhs(self):
        return self._lhs_symbol

    @lhs.setter
    def lhs(self, new_value):
        # the declaration flag depends on the type of the left-hand side, so recompute it
        self._lhs_symbol = new_value
        self._is_declaration = self.__is_declaration()

    def subs(self, subs_dict):
        self.lhs = fast_subs(self.lhs, subs_dict)
        self.rhs = fast_subs(self.rhs, subs_dict)

    def optimize(self, optimizations):
        """Apply sympy's codegen rewriting `optimizations` to the right-hand side (best effort)."""
        try:
            from sympy.codegen.rewriting import optimize
            self.rhs = optimize(self.rhs, optimizations)
        except Exception:
            # deliberately best effort: if the import or the rewrite fails, the rhs stays unchanged
            pass

    @property
    def args(self):
        return [self._lhs_symbol, self.rhs, sp.sympify(self._is_const)]

    @property
    def symbols_defined(self):
        if not self._is_declaration:
            return set()
        return {self._lhs_symbol}

    @property
    def undefined_symbols(self):
        result = {s for s in self.rhs.free_symbols if not isinstance(s, sp.Indexed)}
        # Add loop counters if there are field accesses
        loop_counters = set()
        for symbol in result:
            if isinstance(symbol, Field.Access):
                for i in range(len(symbol.offsets)):
                    loop_counters.add(LoopOverCoordinate.get_loop_counter_symbol(i))
        result = {r for r in result if not isinstance(r, TypedImaginaryUnit)}
        result.update(loop_counters)
        result.update(self._lhs_symbol.atoms(sp.Symbol))
        return result

    @property
    def is_declaration(self):
        return self._is_declaration

    @property
    def is_const(self):
        return self._is_const

    def replace(self, child, replacement):
        """Replace the lhs or rhs equal to `child` by `replacement`.

        Raises:
            ValueError: if `child` is neither the lhs nor the rhs.
        """
        if child == self.lhs:
            replacement.parent = self
            self.lhs = replacement
        elif child == self.rhs:
            replacement.parent = self
            self.rhs = replacement
        else:
            raise ValueError(f'{child} is not in args of {self.__class__}')

    def __repr__(self):
        return repr(self.lhs) + " ← " + repr(self.rhs)

    def _repr_html_(self):
        printed_lhs = sp.latex(self.lhs)
        printed_rhs = sp.latex(self.rhs)
        return f"${printed_lhs} \\leftarrow {printed_rhs}$"

    def __hash__(self):
        return hash((self.lhs, self.rhs))

    def __eq__(self, other):
        return type(self) == type(other) and (self.lhs, self.rhs) == (other.lhs, other.rhs)


class ResolvedFieldAccess(sp.Indexed):
    """Field access already resolved to ``base[linearized_index]``.

    Besides the sympy ``Indexed`` arguments (base and a single linearized index),
    the originating ``field``, the access ``offsets`` and ``idx_coordinate_values``
    are stored as plain attributes. They are part of the hashable content and are
    carried over by substitution and pickling.
    """

    def __new__(cls, base, linearized_index, field, offsets, idx_coordinate_values):
        """Create the resolved access.

        Args:
            base: an ``sp.IndexedBase``, or a ``TypedSymbol`` which is wrapped into
                ``sp.IndexedBase(base, shape=(1,))``.
            linearized_index: the single index expression into ``base``.
            field: field this access originates from (stored as given).
            offsets: access offsets (stored as given).
            idx_coordinate_values: index coordinate values (stored as given).

        Raises:
            TypeError: if ``base`` is neither an ``sp.IndexedBase`` nor a ``TypedSymbol``.
        """
        if not isinstance(base, sp.IndexedBase):
            # Explicit check rather than ``assert``: asserts are stripped under ``python -O``,
            # which would silently let an untyped base through.
            if not isinstance(base, TypedSymbol):
                raise TypeError(f"base must be an sp.IndexedBase or a TypedSymbol, got {type(base).__name__}")
            base = sp.IndexedBase(base, shape=(1,))
            assert isinstance(base.label, TypedSymbol)  # internal invariant of the wrapping above
        obj = super().__new__(cls, base, linearized_index)
        obj.field = field
        obj.offsets = offsets
        obj.idx_coordinate_values = idx_coordinate_values
        return obj

    def _eval_subs(self, old, new):
        # Only the linearized index is substituted; the base and the metadata are kept.
        return ResolvedFieldAccess(self.args[0],
                                   self.args[1].subs(old, new),
                                   self.field, self.offsets, self.idx_coordinate_values)

    def fast_subs(self, substitutions, skip=None):
        """Substitute using the dict ``substitutions``; ``skip`` is accepted but unused here.

        A direct hit on this access returns its replacement. Otherwise both the base and
        the linearized index are substituted, and the metadata is kept.
        """
        if self in substitutions:
            return substitutions[self]
        return ResolvedFieldAccess(self.args[0].subs(substitutions),
                                   self.args[1].subs(substitutions),
                                   self.field, self.offsets, self.idx_coordinate_values)

    def _hashable_content(self):
        # Include the metadata so accesses with equal base/index but different offsets,
        # coordinate values or fields hash and compare as distinct.
        super_class_contents = super()._hashable_content()
        return super_class_contents + tuple(self.offsets) + (repr(self.idx_coordinate_values), hash(self.field))

    @property
    def typed_symbol(self):
        """The ``TypedSymbol`` labelling the indexed base."""
        return self.base.label

    def __str__(self):
        top = super().__str__()
        return f"{top} ({self.typed_symbol.dtype})"

    def __getnewargs__(self):
        # Arguments for ``__new__`` when unpickling.
        return self.base, self.indices[0], self.field, self.offsets, self.idx_coordinate_values

    def __getnewargs_ex__(self):
        return self.__getnewargs__(), {}

Martin Bauer's avatar
Martin Bauer committed
698

699
class TemporaryMemoryAllocation(Node):
    """Node for temporary memory buffer allocation.

    Always allocates aligned memory.

    Args:
        typed_symbol: symbol used as pointer (has to be typed)
        size: number of elements to allocate
        align_offset: the align_offset's element is aligned
    """

    def __init__(self, typed_symbol: TypedSymbol, size, align_offset):
        super().__init__(parent=None)
        self.symbol = typed_symbol
        self.size = size
        self.headers = ['<stdlib.h>']
        self._align_offset = align_offset

    @property
    def symbols_defined(self):
        """The pointer symbol is defined by this allocation."""
        return {self.symbol}

    @property
    def undefined_symbols(self):
        """Symbols appearing in ``size`` when it is a sympy expression, otherwise none."""
        if isinstance(self.size, sp.Basic):
            return self.size.atoms(sp.Symbol)
        else:
            return set()

    @property
    def args(self):
        return [self.symbol]

    def offset(self, byte_alignment):
        """Number of ELEMENTS to skip for a pointer that is aligned to byte_alignment.

        Args:
            byte_alignment: alignment in bytes; must be a multiple of the element size.

        Returns:
            ``-align_offset`` modulo the number of elements that fit into
            ``byte_alignment`` (an integer for integer ``align_offset``).

        Raises:
            ValueError: if ``byte_alignment`` is not a multiple of the element size.
        """
        np_dtype = self.symbol.dtype.base_type.numpy_dtype
        if byte_alignment % np_dtype.itemsize != 0:
            raise ValueError(f"byte_alignment {byte_alignment} is not a multiple of "
                             f"the element size {np_dtype.itemsize}")
        # Divisibility was checked above, so floor division is exact and keeps the
        # element count an integer (true division would yield e.g. 2.0).
        elements_per_alignment = byte_alignment // np_dtype.itemsize
        return -self._align_offset % elements_per_alignment

738
739

class TemporaryMemoryFree(Node):
    """Node that releases the buffer of a previously created allocation node.

    Pointer symbol and element offset are delegated to ``alloc_node``
    (presumably a ``TemporaryMemoryAllocation`` — it must provide ``symbol`` and ``offset``).
    """

    def __init__(self, alloc_node):
        super().__init__(parent=None)
        self.alloc_node = alloc_node

    @property
    def symbol(self):
        """Pointer symbol of the referenced allocation node."""
        alloc = self.alloc_node
        return alloc.symbol

    def offset(self, byte_alignment):
        """Element offset reported by the referenced allocation node for ``byte_alignment``."""
        alloc = self.alloc_node
        return alloc.offset(byte_alignment)

    @property
    def symbols_defined(self):
        """Freeing defines no symbols."""
        return set()

    @property
    def undefined_symbols(self):
        """No symbols are reported as used by this node."""
        return set()

    @property
    def args(self):
        """This node has no children."""
        return []
Martin Bauer's avatar
Martin Bauer committed
762
763
764
765
766


def early_out(condition):
    """Return a ``Conditional`` on ``vec_all(condition)`` whose body is a ``SkipIteration`` block.

    NOTE(review): ``vec_all`` comes from the CPU vectorization module; presumably it is true
    only when ``condition`` holds in every vector lane — confirm there.
    """
    from pystencils.cpu.vectorization import vec_all
    all_lanes = vec_all(condition)
    skip_block = Block([SkipIteration()])
    return Conditional(all_lanes, skip_block)
767
768


Stephan Seitz's avatar
Stephan Seitz committed
769
def get_dummy_symbol(dtype='bool'):
    """Return a fresh ``TypedSymbol`` named ``dummy`` + 32 hex digits, typed ``create_type(dtype)``.

    The random UUID4 suffix makes name clashes practically impossible.
    """
    name = 'dummy' + uuid.uuid4().hex
    symbol_type = create_type(dtype)
    return TypedSymbol(name, symbol_type)
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816


class SourceCodeComment(Node):
    """Leaf node whose string form is the C-style block comment ``/* text */``."""

    def __init__(self, text):
        # Initialise ``parent`` like every other Node. Without this call the attribute only
        # exists once a container assigns it.
        super().__init__(parent=None)
        # NOTE(review): ``text`` is inserted verbatim; a ``*/`` inside it would end the comment early.
        self.text = text

    @property
    def args(self):
        """A comment has no children."""
        return []

    @property
    def symbols_defined(self):
        """A comment defines no symbols."""
        return set()

    @property
    def undefined_symbols(self):
        """A comment uses no symbols."""
        return set()

    def __str__(self):
        return "/* " + self.text + " */"

    def __repr__(self):
        return self.__str__()


class EmptyLine(Node):
    """Leaf node whose string form is the empty string (a blank line in the output)."""

    def __init__(self):
        # Initialise ``parent`` like every other Node. The old body was a bare ``pass``,
        # which left the attribute undefined.
        super().__init__(parent=None)

    @property
    def args(self):
        """A blank line has no children."""
        return []

    @property
    def symbols_defined(self):
        """A blank line defines no symbols."""
        return set()

    @property
    def undefined_symbols(self):
        """A blank line uses no symbols."""
        return set()

    def __str__(self):
        return ""

    def __repr__(self):
        return self.__str__()
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841


class ConditionalFieldAccess(sp.Function):
    """
    :class:`pystencils.Field.Access` that is only executed if a certain condition is met.
    Can be used, for instance, for out-of-bound checks.

    The sympy arguments are, in order: the field access, the out-of-bounds condition,
    and ``outofbounds_value`` (sympified with ``sp.S``; defaults to 0).
    """

    def __new__(cls, field_access, outofbounds_condition, outofbounds_value=0):
        fallback = sp.S(outofbounds_value)
        return sp.Function.__new__(cls, field_access, outofbounds_condition, fallback)

    @property
    def access(self):
        """First argument: the wrapped field access."""
        first, _, _ = self.args[:3]
        return first

    @property
    def outofbounds_condition(self):
        """Second argument: the out-of-bounds condition."""
        _, condition, _ = self.args[:3]
        return condition

    @property
    def outofbounds_value(self):
        """Third argument: the sympified out-of-bounds value."""
        _, _, value = self.args[:3]
        return value

    def __getnewargs__(self):
        # Arguments for ``__new__`` when unpickling.
        new_args = (self.access, self.outofbounds_condition, self.outofbounds_value)
        return new_args

    def __getnewargs_ex__(self):
        new_args = (self.access, self.outofbounds_condition, self.outofbounds_value)
        return new_args, {}