astnodes.py 26.3 KB
Newer Older
1
2
import collections.abc
import itertools
Stephan Seitz's avatar
Stephan Seitz committed
3
import uuid
4
5
from typing import Any, List, Optional, Sequence, Set, Union

6
import sympy as sp
7

8
import pystencils
9
from pystencils.data_types import TypedImaginaryUnit, TypedSymbol, cast_func, create_type
10
from pystencils.field import Field
Martin Bauer's avatar
Martin Bauer committed
11
from pystencils.kernelparameters import FieldPointerSymbol, FieldShapeSymbol, FieldStrideSymbol
Martin Bauer's avatar
Martin Bauer committed
12
from pystencils.sympyextensions import fast_subs
Martin Bauer's avatar
Martin Bauer committed
13
14

# Type alias for AST children: either a pystencils AST ``Node`` or a sympy expression.
NodeOrExpr = Union['Node', sp.Expr]
15
16


17
class Node:
    """Base class for all AST nodes.

    Subclasses provide ``args``, ``symbols_defined`` and ``undefined_symbols``; the generic
    helpers ``subs`` and ``atoms`` are implemented on top of ``args``.
    """

    def __init__(self, parent: Optional['Node'] = None):
        # back-reference to the enclosing node; None when the node is not attached to a tree
        self.parent = parent

    @property
    def args(self) -> List[NodeOrExpr]:
        """Returns all arguments/children of this node."""
        raise NotImplementedError()

    @property
    def symbols_defined(self) -> Set[sp.Symbol]:
        """Set of symbols which are defined by this node."""
        raise NotImplementedError()

    @property
    def undefined_symbols(self) -> Set[sp.Symbol]:
        """Symbols which are used but are not defined inside this node."""
        raise NotImplementedError()

    def subs(self, subs_dict) -> None:
        """Inplace! Substitute, similar to sympy's but modifies the AST inplace."""
        for child in self.args:
            # NOTE(review): for plain sympy children ``subs`` returns a new object which is dropped here;
            # subclasses that hold expressions (e.g. SympyAssignment) override ``subs`` to store the result.
            child.subs(subs_dict)

    @property
    def func(self):
        # presumably mirrors sympy's ``expr.func`` so generic tree code can query the node type
        return type(self)

    def atoms(self, arg_type) -> Set[Any]:
        """Returns a set of all descendants recursively, which are an instance of the given type."""
        found = set()
        for child in self.args:
            found.update(child.atoms(arg_type))
            if isinstance(child, arg_type):
                found.add(child)
        return found


57
class Conditional(Node):
    """Conditional that maps to an 'if' statement in C/C++.

    Try to avoid using this node inside of loops, since currently this construction can not be vectorized.
    Consider using assignments with sympy.Piecewise in this case.

    Args:
        condition_expr: sympy relational expression
        true_block: block which is run if conditional is true; a non-Block node is wrapped into a Block
        false_block: optional block which is run if conditional is false; a non-Block node is wrapped into a Block
    """

    def __init__(self, condition_expr: sp.Basic, true_block: Union['Block', 'SympyAssignment'],
                 false_block: Optional['Block'] = None) -> None:
        super(Conditional, self).__init__(parent=None)

        self.condition_expr = condition_expr

        def handle_child(c):
            # Normalize a branch into a Block owned by this conditional; a missing branch stays None.
            if c is None:
                return None
            if not isinstance(c, Block):
                c = Block([c])
            c.parent = self
            return c

        self.true_block = handle_child(true_block)
        self.false_block = handle_child(false_block)

    def subs(self, subs_dict):
        """Inplace substitution in both branches and in the condition."""
        self.true_block.subs(subs_dict)
        if self.false_block:
            self.false_block.subs(subs_dict)
        self.condition_expr = self.condition_expr.subs(subs_dict)

    @property
    def args(self):
        """Condition, true block and - if present - the false block."""
        result = [self.condition_expr, self.true_block]
        if self.false_block:
            result.append(self.false_block)
        return result

    @property
    def symbols_defined(self):
        # the conditional node itself defines no symbols
        return set()

    @property
    def undefined_symbols(self):
        """Undefined symbols of both branches plus all symbols occurring in the condition."""
        result = self.true_block.undefined_symbols
        if self.false_block:
            result.update(self.false_block.undefined_symbols)
        # condition may be a non-sympy value without ``atoms`` (e.g. a plain bool)
        if hasattr(self.condition_expr, 'atoms'):
            result.update(self.condition_expr.atoms(sp.Symbol))
        return result

    def __str__(self):
        return self.__repr__()

    def __repr__(self):
        # accumulate into a local (not named ``repr`` to avoid shadowing the builtin used below)
        text = 'if:({!r}) '.format(self.condition_expr)
        if self.true_block:
            text += '\n\t{} '.format(self.true_block)
        if self.false_block:
            # append the else-part; it must not replace the already printed if-part
            text += 'else: '
            text += '\n\t{} '.format(self.false_block)

        return text

    def replace_by_true_block(self):
        """Replaces the conditional by its True block"""
        self.parent.replace(self, [self.true_block])

    def replace_by_false_block(self):
        """Replaces the conditional by its False block"""
        self.parent.replace(self, [self.false_block] if self.false_block else [])

133

134
135
class KernelFunction(Node):
    """Root node of a generated kernel: holds the kernel body and the metadata needed to print and compile it.

    Args:
        body: root node of the kernel body; its ``parent`` is set to this node
        target: stored as-is and exposed via the ``target`` property
        backend: stored as-is and exposed via the ``backend`` property
        compile_function: callable invoked as ``compile_function(self, *args, **kwargs)`` by ``compile``;
                          may be None, in which case ``compile`` raises ValueError
        ghost_layers: stored as-is
        function_name: name of the generated function
        assignments: optional, stored as-is
    """

    class Parameter:
        """Function parameter.

        Each undefined symbol in a `KernelFunction` node becomes a parameter to the function.
        Parameters are either symbols introduced by the user that never occur on the left hand side of an
        Assignment, or are related to fields/arrays passed to the function.

        A parameter consists of the typed symbol (symbol property). For field related parameters this is a symbol
        defined in pystencils.kernelparameters.
        If the parameter is related to one or multiple fields, these fields are referenced in the fields property.
        """

        def __init__(self, symbol, fields):
            self.symbol = symbol  # type: TypedSymbol
            self.fields = fields  # type: Sequence[Field]

        def __repr__(self):
            return repr(self.symbol)

        @property
        def is_field_stride(self):
            return isinstance(self.symbol, FieldStrideSymbol)

        @property
        def is_field_shape(self):
            return isinstance(self.symbol, FieldShapeSymbol)

        @property
        def is_field_pointer(self):
            return isinstance(self.symbol, FieldPointerSymbol)

        @property
        def is_field_parameter(self):
            return self.is_field_pointer or self.is_field_shape or self.is_field_stride

        @property
        def field_name(self):
            # name of the first associated field; raises IndexError for non-field parameters
            return self.fields[0].name

    def __init__(self, body, target, backend, compile_function, ghost_layers, function_name="kernel", assignments=None):
        super(KernelFunction, self).__init__()
        self._body = body
        self._body.parent = self
        self.function_name = function_name
        self.ghost_layers = ghost_layers
        self._target = target
        self._backend = backend
        # these variables are assumed to be global, so no automatic parameter is generated for them
        self.global_variables = set()
        self.instruction_set = None  # used in `vectorize` function to tell the backend which i.s. (SSE,AVX) to use
        # function that compiles the node to a Python callable, is set by the backends
        self._compile_function = compile_function
        self.assignments = assignments

    @property
    def target(self):
        """Currently either 'cpu' or 'gpu' """
        return self._target

    @property
    def backend(self):
        """Backend for generating the code e.g. 'llvm', 'c', 'cuda' """
        return self._backend

    @property
    def symbols_defined(self):
        return set()

    @property
    def undefined_symbols(self):
        # free symbols of the body become function parameters (see get_parameters), so nothing is undefined here
        return set()

    @property
    def body(self):
        return self._body

    @body.setter
    def body(self, value):
        self._body = value
        self._body.parent = self

    @property
    def args(self):
        return self._body,

    @property
    def fields_accessed(self) -> Set[Field]:
        """Set of Field instances: fields which are accessed inside this kernel function"""
        # local import, presumably to avoid an import cycle with the interpolation module
        from pystencils.interpolation_astnodes import InterpolatorAccess
        return set(o.field for o in itertools.chain(self.atoms(ResolvedFieldAccess), self.atoms(InterpolatorAccess)))

    @property
    def fields_written(self) -> Set[Field]:
        """Fields whose resolved accesses appear on the left hand side of an assignment."""
        assignments = self.atoms(SympyAssignment)
        return {a.lhs.field for a in assignments if isinstance(a.lhs, ResolvedFieldAccess)}

    @property
    def fields_read(self) -> Set[Field]:
        """Fields referenced by free symbols (having a ``field`` attribute) on assignment right hand sides."""
        assignments = self.atoms(SympyAssignment)
        return set().union(itertools.chain.from_iterable([f.field for f in a.rhs.free_symbols if hasattr(f, 'field')]
                                                         for a in assignments))

    def get_parameters(self) -> Sequence['KernelFunction.Parameter']:
        """Returns list of parameters for this function.

        This function is expensive, cache the result where possible!
        """
        field_map = {f.name: f for f in self.fields_accessed}

        def get_fields(symbol):
            # field related symbols carry either a single field name or several field names
            if hasattr(symbol, 'field_name'):
                return field_map[symbol.field_name],
            elif hasattr(symbol, 'field_names'):
                return tuple(field_map[fn] for fn in symbol.field_names)
            return ()

        argument_symbols = self._body.undefined_symbols - self.global_variables
        parameters = [self.Parameter(symbol, get_fields(symbol)) for symbol in argument_symbols]
        if hasattr(self, 'indexing'):
            parameters += [self.Parameter(s, []) for s in self.indexing.symbolic_parameters()]
        # sort by name for a deterministic parameter order
        parameters.sort(key=lambda p: p.symbol.name)
        return parameters

    def __str__(self):
        params = [p.symbol for p in self.get_parameters()]
        return '{0} {1}({2})\n{3}'.format(type(self).__name__, self.function_name, params,
                                          ("\t" + "\t".join(str(self.body).splitlines(True))))

    def __repr__(self):
        params = [p.symbol for p in self.get_parameters()]
        return '{0} {1}({2})'.format(type(self).__name__, self.function_name, params)

    def compile(self, *args, **kwargs):
        """Compile this kernel using the backend-provided compile function.

        Raises:
            ValueError: if no compile function was provided
        """
        if self._compile_function is None:
            raise ValueError("No compile-function provided for this KernelFunction node")
        return self._compile_function(self, *args, **kwargs)

274

Martin Bauer's avatar
Martin Bauer committed
275
276
277
278
279
280
281
282
283
284
285
286
287
288
class SkipIteration(Node):
    """Leaf node without children; it neither defines nor uses any symbol.

    NOTE(review): judging by the name, backends presumably print it as a loop ``continue`` - confirm there.
    """

    @property
    def args(self):
        return list()

    @property
    def symbols_defined(self):
        return set()

    @property
    def undefined_symbols(self):
        return set()


289
class Block(Node):
    """Ordered sequence of child nodes.

    Args:
        nodes: child nodes; the list is stored (not copied) and every child's ``parent`` is set to this
               block where the child accepts the attribute
    """

    def __init__(self, nodes: List[Node]):
        super(Block, self).__init__()  # parent stays None until the block is attached somewhere
        self._nodes = nodes
        for n in self._nodes:
            try:
                n.parent = self
            except AttributeError:
                # NOTE(review): presumably for children that do not accept new attributes
                pass

    @property
    def args(self):
        return self._nodes

    def subs(self, subs_dict) -> None:
        for a in self.args:
            a.subs(subs_dict)

    def fast_subs(self, subs_dict, skip=None):
        self._nodes = [fast_subs(a, subs_dict, skip) for a in self._nodes]
        return self

    def insert_front(self, node):
        """Insert a node, or an iterable of nodes (keeping their order), at the beginning of the block."""
        if isinstance(node, collections.abc.Iterable):
            node = list(node)
            for n in node:
                n.parent = self

            self._nodes = node + self._nodes
        else:
            node.parent = self
            self._nodes.insert(0, node)

    def insert_before(self, new_node, insert_before):
        """Insert ``new_node`` in front of the child ``insert_before``.

        Declarations are additionally moved in front of any directly preceding loops and conditionals.
        """
        new_node.parent = self
        assert self._nodes.count(insert_before) == 1
        idx = self._nodes.index(insert_before)

        # move declarations up, past directly preceding loops / conditionals
        if isinstance(new_node, SympyAssignment) and new_node.is_declaration:
            while idx > 0:
                pn = self._nodes[idx - 1]
                if isinstance(pn, LoopOverCoordinate) or isinstance(pn, Conditional):
                    idx -= 1
                else:
                    break
        self._nodes.insert(idx, new_node)

    def append(self, node):
        """Append a node, or a list/tuple of nodes, at the end of the block."""
        if isinstance(node, list) or isinstance(node, tuple):
            for n in node:
                n.parent = self
                self._nodes.append(n)
        else:
            node.parent = self
            self._nodes.append(node)

    def take_child_nodes(self):
        """Remove and return all children; the block is empty afterwards."""
        tmp = self._nodes
        self._nodes = []
        return tmp

    def replace(self, child, replacements):
        """Replace ``child`` by a single node or by a list/tuple of nodes at the same position."""
        assert self._nodes.count(child) == 1
        idx = self._nodes.index(child)
        del self._nodes[idx]
        # accept tuples as well as lists, consistent with ``append``
        if isinstance(replacements, (list, tuple)):
            replacements = list(replacements)
            for e in replacements:
                e.parent = self
            # builds a new list (the previous list object, e.g. returned by ``args``, is not extended)
            self._nodes = self._nodes[:idx] + replacements + self._nodes[idx:]
        else:
            replacements.parent = self
            self._nodes.insert(idx, replacements)

    @property
    def symbols_defined(self):
        """Symbols defined by the children; for a raw Assignment only its left hand side is defined."""
        result = set()
        for a in self.args:
            if isinstance(a, pystencils.Assignment):
                # only the assigned symbol is defined - consistent with ``undefined_symbols`` below
                result.add(a.lhs)
            else:
                result.update(a.symbols_defined)
        return result

    @property
    def undefined_symbols(self):
        """Symbols used by the children that are not defined by any child of this block."""
        result = set()
        defined_symbols = set()
        for a in self.args:
            if isinstance(a, pystencils.Assignment):
                result.update(a.free_symbols)
                defined_symbols.update({a.lhs})
            else:
                result.update(a.undefined_symbols)
                defined_symbols.update(a.symbols_defined)
        return result - defined_symbols

    def __str__(self):
        return "Block " + ''.join('{!s}\n'.format(node) for node in self._nodes)

    def __repr__(self):
        return "Block"
392

393
394

class PragmaBlock(Block):
    """Block annotated with a pragma line; ``repr`` returns that line.

    Args:
        pragma_line: text of the pragma line
        nodes: child nodes, handled exactly as in ``Block``
    """

    def __init__(self, pragma_line, nodes):
        # Block.__init__ already sets ``parent`` of every child to this instance and tolerates children
        # that reject the attribute, so no second (non-tolerant) loop over ``nodes`` is needed here.
        super(PragmaBlock, self).__init__(nodes)
        self.pragma_line = pragma_line

    def __repr__(self):
        return self.pragma_line
403
404
405
406


class LoopOverCoordinate(Node):
    """Counting loop ``for(ctr=start; ctr<stop; ctr+=step)`` over one coordinate.

    Args:
        body: loop body; its ``parent`` is set to this loop
        coordinate_to_loop_over: coordinate index, used to build the loop counter name
        start: first counter value (int or sympy expression)
        stop: exclusive upper bound (int or sympy expression)
        step: counter increment (int or sympy expression)
        is_block_loop: if True the counter uses ``BlOCK_LOOP_COUNTER_NAME_PREFIX`` instead of the regular prefix
    """
    LOOP_COUNTER_NAME_PREFIX = "ctr"
    BlOCK_LOOP_COUNTER_NAME_PREFIX = "_blockctr"

    def __init__(self, body, coordinate_to_loop_over, start, stop, step=1, is_block_loop=False):
        super(LoopOverCoordinate, self).__init__(parent=None)
        self.body = body
        self.body.parent = self
        self.coordinate_to_loop_over = coordinate_to_loop_over
        self.start = start
        self.stop = stop
        self.step = step
        self.prefix_lines = []
        self.is_block_loop = is_block_loop

    def new_loop_with_different_body(self, new_body):
        """Copy of this loop (same bounds, counter and prefix lines) with ``new_body`` as body."""
        result = LoopOverCoordinate(new_body, self.coordinate_to_loop_over, self.start, self.stop,
                                    self.step, self.is_block_loop)
        result.prefix_lines = list(self.prefix_lines)
        return result

    def subs(self, subs_dict):
        self.body.subs(subs_dict)
        # bounds may be plain ints without a ``subs`` method
        if hasattr(self.start, "subs"):
            self.start = self.start.subs(subs_dict)
        if hasattr(self.stop, "subs"):
            self.stop = self.stop.subs(subs_dict)
        if hasattr(self.step, "subs"):
            self.step = self.step.subs(subs_dict)

    def fast_subs(self, subs_dict, skip=None):
        self.body = fast_subs(self.body, subs_dict, skip)
        if isinstance(self.start, sp.Basic):
            self.start = fast_subs(self.start, subs_dict, skip)
        if isinstance(self.stop, sp.Basic):
            self.stop = fast_subs(self.stop, subs_dict, skip)
        if isinstance(self.step, sp.Basic):
            self.step = fast_subs(self.step, subs_dict, skip)
        return self

    @property
    def args(self):
        """Body plus those loop bounds that are expressions (have ``args``)."""
        result = [self.body]
        for e in [self.start, self.stop, self.step]:
            if hasattr(e, "args"):
                result.append(e)
        return result

    def replace(self, child, replacement):
        if child == self.body:
            self.body = replacement
        elif child == self.start:
            self.start = replacement
        elif child == self.step:
            self.step = replacement
        elif child == self.stop:
            self.stop = replacement

    @property
    def symbols_defined(self):
        return {self.loop_counter_symbol}

    @property
    def undefined_symbols(self):
        """Undefined symbols of body and bounds, excluding the loop counter itself."""
        result = self.body.undefined_symbols
        for possible_symbol in [self.start, self.stop, self.step]:
            if isinstance(possible_symbol, Node) or isinstance(possible_symbol, sp.Basic):
                result.update(possible_symbol.atoms(sp.Symbol))
        return result - {self.loop_counter_symbol}

    @staticmethod
    def get_loop_counter_name(coordinate_to_loop_over):
        return "%s_%s" % (LoopOverCoordinate.LOOP_COUNTER_NAME_PREFIX, coordinate_to_loop_over)

    @staticmethod
    def get_block_loop_counter_name(coordinate_to_loop_over):
        return "%s_%s" % (LoopOverCoordinate.BlOCK_LOOP_COUNTER_NAME_PREFIX, coordinate_to_loop_over)

    @property
    def loop_counter_name(self):
        if self.is_block_loop:
            return LoopOverCoordinate.get_block_loop_counter_name(self.coordinate_to_loop_over)
        else:
            return LoopOverCoordinate.get_loop_counter_name(self.coordinate_to_loop_over)

    @staticmethod
    def is_loop_counter_symbol(symbol):
        """Return the coordinate if ``symbol`` is a regular ``int`` loop counter (``ctr_<coordinate>``), else None."""
        prefix = LoopOverCoordinate.LOOP_COUNTER_NAME_PREFIX
        if not symbol.name.startswith(prefix):
            return None
        # untyped symbols have no ``dtype`` attribute and can not be loop counters
        if getattr(symbol, 'dtype', None) != create_type('int'):
            return None
        try:
            return int(symbol.name[len(prefix) + 1:])
        except ValueError:
            # starts with the prefix but has no integer suffix, e.g. 'ctrl' or 'ctr_x'
            return None

    @staticmethod
    def get_loop_counter_symbol(coordinate_to_loop_over):
        return TypedSymbol(LoopOverCoordinate.get_loop_counter_name(coordinate_to_loop_over), 'int', nonnegative=True)

    @staticmethod
    def get_block_loop_counter_symbol(coordinate_to_loop_over):
        return TypedSymbol(LoopOverCoordinate.get_block_loop_counter_name(coordinate_to_loop_over),
                           'int',
                           nonnegative=True)

    @property
    def loop_counter_symbol(self):
        if self.is_block_loop:
            return self.get_block_loop_counter_symbol(self.coordinate_to_loop_over)
        else:
            return self.get_loop_counter_symbol(self.coordinate_to_loop_over)

    @property
    def is_outermost_loop(self):
        from pystencils.transformations import get_next_parent_of_type
        return get_next_parent_of_type(self, LoopOverCoordinate) is None

    @property
    def is_innermost_loop(self):
        return len(self.atoms(LoopOverCoordinate)) == 0

    def __str__(self):
        return 'for({!s}={!s}; {!s}<{!s}; {!s}+={!s})\n{!s}'.format(self.loop_counter_name, self.start,
                                                                    self.loop_counter_name, self.stop,
                                                                    self.loop_counter_name, self.step,
                                                                    ("\t" + "\t".join(str(self.body).splitlines(True))))

    def __repr__(self):
        return 'for({!s}={!s}; {!s}<{!s}; {!s}+={!s})'.format(self.loop_counter_name, self.start,
                                                              self.loop_counter_name, self.stop,
                                                              self.loop_counter_name, self.step)
537

538
539

class SympyAssignment(Node):
    """Assignment ``lhs ← rhs`` of sympy expressions.

    Args:
        lhs_symbol: assigned symbol/expression (sympified)
        rhs_expr: assigned value (sympified)
        is_const: stored and exposed via ``is_const``; also part of ``args``
        use_auto: stored as-is; NOTE(review): presumably lets a backend emit ``auto`` as type - confirm
    """

    def __init__(self, lhs_symbol, rhs_expr, is_const=True, use_auto=False):
        super(SympyAssignment, self).__init__(parent=None)
        self._lhs_symbol = sp.sympify(lhs_symbol)
        self.rhs = sp.sympify(rhs_expr)
        self._is_const = is_const
        self._is_declaration = self.__is_declaration()
        self.use_auto = use_auto

    def __is_declaration(self):
        # Assignments to casts, field accesses, indexed expressions and memory allocations are not declarations;
        # any other left hand side introduces a new symbol.
        if isinstance(self._lhs_symbol, cast_func):
            return False
        if any(isinstance(self._lhs_symbol, c) for c in (Field.Access, sp.Indexed, TemporaryMemoryAllocation)):
            return False
        return True

    @property
    def lhs(self):
        return self._lhs_symbol

    @lhs.setter
    def lhs(self, new_value):
        self._lhs_symbol = new_value
        # the declaration status depends on the kind of left hand side, so recompute it
        self._is_declaration = self.__is_declaration()

    def subs(self, subs_dict):
        self.lhs = fast_subs(self.lhs, subs_dict)
        self.rhs = fast_subs(self.rhs, subs_dict)

    def optimize(self, optimizations):
        """Apply sympy codegen rewrite ``optimizations`` to the right hand side.

        Best effort: if the rewriting module can not be imported or the rewrite fails, rhs stays unchanged.
        """
        try:
            from sympy.codegen.rewriting import optimize
            self.rhs = optimize(self.rhs, optimizations)
        except Exception:
            pass

    @property
    def args(self):
        return [self._lhs_symbol, self.rhs, sp.sympify(self._is_const)]

    @property
    def symbols_defined(self):
        if not self._is_declaration:
            return set()
        return {self._lhs_symbol}

    @property
    def undefined_symbols(self):
        """Free symbols of the right hand side, loop counters required by field accesses and lhs symbols."""
        result = {s for s in self.rhs.free_symbols if not isinstance(s, sp.Indexed)}
        # Add loop counters if there a field accesses
        loop_counters = set()
        for symbol in result:
            if isinstance(symbol, Field.Access):
                for i in range(len(symbol.offsets)):
                    loop_counters.add(LoopOverCoordinate.get_loop_counter_symbol(i))
        result = {r for r in result if not isinstance(r, TypedImaginaryUnit)}
        result.update(loop_counters)
        result.update(self._lhs_symbol.atoms(sp.Symbol))
        return result

    @property
    def is_declaration(self):
        return self._is_declaration

    @property
    def is_const(self):
        return self._is_const

    def replace(self, child, replacement):
        """Replace ``child`` (lhs or rhs) by ``replacement``.

        Raises:
            ValueError: if ``child`` is neither lhs nor rhs
        """
        if child == self.lhs:
            replacement.parent = self
            self.lhs = replacement
        elif child == self.rhs:
            replacement.parent = self
            self.rhs = replacement
        else:
            # report the child that was not found (not the replacement)
            raise ValueError('%s is not in args of %s' % (child, self.__class__))

    def __repr__(self):
        return repr(self.lhs) + " ← " + repr(self.rhs)

    def _repr_html_(self):
        printed_lhs = sp.latex(self.lhs)
        printed_rhs = sp.latex(self.rhs)
        return "${printed_lhs} \\leftarrow {printed_rhs}$".format(printed_lhs=printed_lhs, printed_rhs=printed_rhs)

    def __hash__(self):
        return hash((self.lhs, self.rhs))

    def __eq__(self, other):
        return type(self) == type(other) and (self.lhs, self.rhs) == (other.lhs, other.rhs)

631

Martin Bauer's avatar
Martin Bauer committed
632
class ResolvedFieldAccess(sp.Indexed):
    """Field access resolved to a single linearized index into a base pointer.

    The sympy content is ``(base, linearized_index)``; ``field``, ``offsets`` and ``idx_coordinate_values``
    are kept as plain Python attributes and included in the hash via ``_hashable_content``.
    """

    def __new__(cls, base, linearized_index, field, offsets, idx_coordinate_values):
        if isinstance(base, sp.IndexedBase):
            indexed_base = base
        else:
            # a raw typed pointer symbol gets wrapped into an IndexedBase
            assert isinstance(base, TypedSymbol)
            indexed_base = sp.IndexedBase(base, shape=(1,))
            assert isinstance(indexed_base.label, TypedSymbol)
        access = super(ResolvedFieldAccess, cls).__new__(cls, indexed_base, linearized_index)
        access.field, access.offsets, access.idx_coordinate_values = field, offsets, idx_coordinate_values
        return access

    def _eval_subs(self, old, new):
        # only the linearized index is substituted, the base is kept
        indexed_base, flat_index = self.args[0], self.args[1]
        return ResolvedFieldAccess(indexed_base, flat_index.subs(old, new),
                                   self.field, self.offsets, self.idx_coordinate_values)

    def fast_subs(self, substitutions, skip=None):
        if self in substitutions:
            return substitutions[self]
        new_base = self.args[0].subs(substitutions)
        new_index = self.args[1].subs(substitutions)
        return ResolvedFieldAccess(new_base, new_index, self.field, self.offsets, self.idx_coordinate_values)

    def _hashable_content(self):
        base_content = super(ResolvedFieldAccess, self)._hashable_content()
        extra_content = tuple(self.offsets) + (repr(self.idx_coordinate_values), hash(self.field))
        return base_content + extra_content

    @property
    def typed_symbol(self):
        return self.base.label

    def __str__(self):
        return "%s (%s)" % (super(ResolvedFieldAccess, self).__str__(), self.typed_symbol.dtype)

    def __getnewargs__(self):
        # arguments for ``__new__`` when unpickling
        return (self.base, self.indices[0], self.field, self.offsets, self.idx_coordinate_values)
Martin Bauer's avatar
Martin Bauer committed
670
671


672
class TemporaryMemoryAllocation(Node):
    """Node for temporary memory buffer allocation.

    Always allocates aligned memory.

    Args:
        typed_symbol: symbol used as pointer (has to be typed)
        size: number of elements to allocate
        align_offset: the align_offset's element is aligned
    """

    def __init__(self, typed_symbol: TypedSymbol, size, align_offset):
        super(TemporaryMemoryAllocation, self).__init__(parent=None)
        self.symbol = typed_symbol
        self.size = size
        # Header(s) associated with this node; presumably collected by the
        # backend to emit #include lines — confirm against the code printer.
        self.headers = ['<stdlib.h>']
        self._align_offset = align_offset

    @property
    def symbols_defined(self):
        """The pointer symbol is the only symbol this node defines."""
        return {self.symbol}

    @property
    def undefined_symbols(self):
        """Free symbols of a symbolic ``size``; empty when ``size`` is not a sympy object."""
        if isinstance(self.size, sp.Basic):
            return self.size.atoms(sp.Symbol)
        else:
            return set()

    @property
    def args(self):
        return [self.symbol]

    def offset(self, byte_alignment):
        """Number of ELEMENTS to skip for a pointer that is aligned to byte_alignment.

        Args:
            byte_alignment: alignment in bytes; must be a multiple of the element size.

        Returns:
            ``(-align_offset) mod (byte_alignment // itemsize)``. For an integer
            ``align_offset`` this is an integer element count.
        """
        np_dtype = self.symbol.dtype.base_type.numpy_dtype
        assert byte_alignment % np_dtype.itemsize == 0
        # The assert guarantees exact divisibility, so floor division loses nothing
        # and keeps the result an integer element count (true division gave a float).
        elements_per_alignment = byte_alignment // np_dtype.itemsize
        return -self._align_offset % elements_per_alignment

711
712

class TemporaryMemoryFree(Node):
    """Node that frees the buffer of a paired allocation node.

    ``symbol`` and ``offset`` are delegated to ``alloc_node`` (presumably a
    TemporaryMemoryAllocation — confirm at call sites), so both nodes always
    refer to the same pointer and alignment offset.
    """

    def __init__(self, alloc_node):
        super().__init__(parent=None)
        self.alloc_node = alloc_node

    @property
    def symbol(self):
        """Pointer symbol of the paired allocation node."""
        paired = self.alloc_node
        return paired.symbol

    def offset(self, byte_alignment):
        """Element offset reported by the paired allocation node for ``byte_alignment``."""
        paired = self.alloc_node
        return paired.offset(byte_alignment)

    @property
    def args(self):
        return []

    @property
    def symbols_defined(self):
        return set()

    @property
    def undefined_symbols(self):
        return set()
Martin Bauer's avatar
Martin Bauer committed
735
736
737
738
739


def early_out(condition):
    """Build a Conditional that runs a ``SkipIteration`` when ``vec_all(condition)`` holds.

    NOTE(review): ``vec_all`` presumably reduces ``condition`` over all vector
    lanes — confirm in ``pystencils.cpu.vectorization``.
    """
    from pystencils.cpu.vectorization import vec_all
    all_true = vec_all(condition)
    skip_block = Block([SkipIteration()])
    return Conditional(all_true, skip_block)
740
741


Stephan Seitz's avatar
Stephan Seitz committed
742
743
def get_dummy_symbol(dtype='bool'):
    """Create a TypedSymbol named ``dummy`` + a random uuid4 hex string.

    Args:
        dtype: type specification forwarded to ``create_type``.
    """
    unique_name = 'dummy%s' % uuid.uuid4().hex
    symbol_type = create_type(dtype)
    return TypedSymbol(unique_name, symbol_type)
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789


class SourceCodeComment(Node):
    """AST node holding comment ``text``; its string form is ``/* text */``."""

    def __init__(self, text):
        # Run Node.__init__ so ``parent`` exists from construction on, as it does
        # for the other node types; previously it was absent until assigned externally.
        super(SourceCodeComment, self).__init__(parent=None)
        self.text = text

    @property
    def args(self):
        return []

    @property
    def symbols_defined(self):
        return set()

    @property
    def undefined_symbols(self):
        return set()

    def __str__(self):
        return "/* " + self.text + " */"

    def __repr__(self):
        return self.__str__()


class EmptyLine(Node):
    """AST node representing a blank line; its string form is the empty string."""

    def __init__(self):
        # Run Node.__init__ so ``parent`` exists from construction on, as it does
        # for the other node types (the original body was just ``pass``).
        super(EmptyLine, self).__init__(parent=None)

    @property
    def args(self):
        return []

    @property
    def symbols_defined(self):
        return set()

    @property
    def undefined_symbols(self):
        return set()

    def __str__(self):
        return ""

    def __repr__(self):
        return self.__str__()
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814


class ConditionalFieldAccess(sp.Function):
    """
    :class:`pystencils.Field.Access` that is only executed if a certain condition is met.
    Can be used, for instance, for out-of-bound checks.

    Stored sympy arguments, in order: ``(access, outofbounds_condition,
    outofbounds_value)``. ``outofbounds_value`` is sympified with ``sp.S``.
    """

    def __new__(cls, field_access, outofbounds_condition, outofbounds_value=0):
        fallback_value = sp.S(outofbounds_value)
        return sp.Function.__new__(cls, field_access, outofbounds_condition, fallback_value)

    @property
    def access(self):
        """First argument: the wrapped field access."""
        return self.args[0]

    @property
    def outofbounds_condition(self):
        """Second argument: the out-of-bounds condition."""
        return self.args[1]

    @property
    def outofbounds_value(self):
        """Third argument: the sympified out-of-bounds value."""
        return self.args[2]

    def __getnewargs__(self):
        """Arguments passed back to ``__new__`` when the object is unpickled."""
        return (self.access, self.outofbounds_condition, self.outofbounds_value)