astnodes.py 25.5 KB
Newer Older
1
2
import collections.abc
import itertools
Stephan Seitz's avatar
Stephan Seitz committed
3
import uuid
4
5
from typing import Any, List, Optional, Sequence, Set, Union

6
import sympy as sp
7

8
from pystencils.data_types import TypedImaginaryUnit, TypedSymbol, cast_func, create_type
9
from pystencils.field import Field
Martin Bauer's avatar
Martin Bauer committed
10
from pystencils.kernelparameters import FieldPointerSymbol, FieldShapeSymbol, FieldStrideSymbol
Martin Bauer's avatar
Martin Bauer committed
11
from pystencils.sympyextensions import fast_subs
Martin Bauer's avatar
Martin Bauer committed
12
13

# Type of a child of an AST node: either another ``Node`` or a plain sympy expression.
NodeOrExpr = Union['Node', sp.Expr]
14
15


16
class Node:
    """Base class for all AST nodes.

    The properties ``args``, ``symbols_defined`` and ``undefined_symbols`` only raise
    ``NotImplementedError`` here and have to be supplied by the concrete node classes.
    """

    def __init__(self, parent: Optional['Node'] = None):
        # Back-reference to the enclosing node, ``None`` as long as there is none.
        self.parent = parent

    @property
    def args(self) -> List[NodeOrExpr]:
        """All arguments/children of this node."""
        raise NotImplementedError()

    @property
    def symbols_defined(self) -> Set[sp.Symbol]:
        """Symbols that are defined by this node."""
        raise NotImplementedError()

    @property
    def undefined_symbols(self) -> Set[sp.Symbol]:
        """Symbols that are used inside this node without being defined in it."""
        raise NotImplementedError()

    def subs(self, subs_dict) -> None:
        """In-place substitution, modelled after sympy's ``subs``.

        The substitution dictionary is forwarded to the ``subs`` method of every child;
        whatever those calls return is discarded and nothing is returned from here.
        """
        for child in self.args:
            child.subs(subs_dict)

    @property
    def func(self):
        """Class of this node."""
        return self.__class__

    def atoms(self, arg_type) -> Set[Any]:
        """Collect, recursively, every descendant that is an instance of ``arg_type``."""
        found = set()
        for child in self.args:
            if isinstance(child, arg_type):
                found.add(child)
            # descend into the child regardless of whether it matched itself
            found.update(child.atoms(arg_type))
        return found


56
class Conditional(Node):
    """Conditional that maps to a 'if' statement in C/C++.

    Try to avoid using this node inside of loops, since currently this construction can not be vectorized.
    Consider using assignments with sympy.Piecewise in this case.

    Args:
        condition_expr: sympy relational expression
        true_block: block which is run if conditional is true
        false_block: optional block which is run if conditional is false
    """

    def __init__(self, condition_expr: sp.Basic, true_block: Union['Block', 'SympyAssignment'],
                 false_block: Optional['Block'] = None) -> None:
        super().__init__(parent=None)

        self.condition_expr = condition_expr

        def as_owned_block(branch):
            # ``None`` stays ``None``; anything that is not a Block is wrapped into one,
            # and the resulting block is attached to this conditional.
            if branch is None:
                return None
            block = branch if isinstance(branch, Block) else Block([branch])
            block.parent = self
            return block

        self.true_block = as_owned_block(true_block)
        self.false_block = as_owned_block(false_block)

    def subs(self, subs_dict):
        """Substitute inside both branches (in place) and in the condition expression."""
        self.true_block.subs(subs_dict)
        if self.false_block:
            self.false_block.subs(subs_dict)
        # the condition is a sympy object, so ``subs`` returns a new expression
        substituted_condition = self.condition_expr.subs(subs_dict)
        self.condition_expr = substituted_condition

    @property
    def args(self):
        """Condition, true block and - if present - false block."""
        optional_branch = [self.false_block] if self.false_block else []
        return [self.condition_expr, self.true_block] + optional_branch

    @property
    def symbols_defined(self):
        """A conditional itself defines no symbols."""
        return set()

    @property
    def undefined_symbols(self):
        """Undefined symbols of the branches plus the symbols occurring in the condition."""
        symbols = self.true_block.undefined_symbols
        if self.false_block:
            symbols.update(self.false_block.undefined_symbols)
        # the condition may be a plain Python value without ``atoms``
        if hasattr(self.condition_expr, 'atoms'):
            symbols.update(self.condition_expr.atoms(sp.Symbol))
        return symbols

    def __str__(self):
        condition = self.condition_expr
        return 'if:({!s}) '.format(condition)

    def __repr__(self):
        condition = self.condition_expr
        return 'if:({!r}) '.format(condition)

    def replace_by_true_block(self):
        """Replaces the conditional by its True block"""
        self.parent.replace(self, [self.true_block])

    def replace_by_false_block(self):
        """Replaces the conditional by its False block (by nothing if there is no False block)"""
        substitutes = [self.false_block] if self.false_block else []
        self.parent.replace(self, substitutes)

125

126
127
class KernelFunction(Node):
    """AST node that holds a kernel body together with its function name, target and backend.

    Args:
        body: node that becomes the single child of the function; its parent is set to this node
        target: returned unchanged by the `target` property
        backend: returned unchanged by the `backend` property
        compile_function: callable used by `compile`, or None
        ghost_layers: stored unchanged in the ``ghost_layers`` attribute
        function_name: name used when the function is printed
    """

    class Parameter:
        """Function parameter.

        Each undefined symbol in a `KernelFunction` node becomes a parameter to the function.
        Parameters are either symbols introduced by the user that never occur on the left hand side of an
        Assignment, or are related to fields/arrays passed to the function.

        A parameter consists of the typed symbol (symbol property). For field related parameters this is a symbol
        defined in pystencils.kernelparameters.
        If the parameter is related to one or multiple fields, these fields are referenced in the fields property.
        """

        def __init__(self, symbol, fields):
            self.symbol = symbol  # type: TypedSymbol
            self.fields = fields  # type: Sequence[Field]

        def __repr__(self):
            return repr(self.symbol)

        @property
        def is_field_stride(self):
            """True if the symbol is a `FieldStrideSymbol`."""
            return isinstance(self.symbol, FieldStrideSymbol)

        @property
        def is_field_shape(self):
            """True if the symbol is a `FieldShapeSymbol`."""
            return isinstance(self.symbol, FieldShapeSymbol)

        @property
        def is_field_pointer(self):
            """True if the symbol is a `FieldPointerSymbol`."""
            return isinstance(self.symbol, FieldPointerSymbol)

        @property
        def is_field_parameter(self):
            """True if the symbol is a field pointer, field shape or field stride symbol."""
            return any((self.is_field_pointer, self.is_field_shape, self.is_field_stride))

        @property
        def field_name(self):
            """Name of the first field this parameter is related to."""
            first_field = self.fields[0]
            return first_field.name

    def __init__(self, body, target, backend, compile_function, ghost_layers, function_name="kernel"):
        super().__init__()
        self._body = body
        self._body.parent = self
        self.function_name = function_name
        self.ghost_layers = ghost_layers
        self._target = target
        self._backend = backend
        # these variables are assumed to be global, so no automatic parameter is generated for them
        self.global_variables = set()
        self.instruction_set = None  # used in `vectorize` function to tell the backend which i.s. (SSE,AVX) to use
        # function that compiles the node to a Python callable, is set by the backends
        self._compile_function = compile_function

    @property
    def target(self):
        """Currently either 'cpu' or 'gpu' """
        return self._target

    @property
    def backend(self):
        """Backend for generating the code e.g. 'llvm', 'c', 'cuda' """
        return self._backend

    @property
    def symbols_defined(self):
        """Always the empty set."""
        return set()

    @property
    def undefined_symbols(self):
        """Always the empty set; the symbols the body needs are exposed through `get_parameters`."""
        return set()

    @property
    def body(self):
        """The node that forms the function body."""
        return self._body

    @body.setter
    def body(self, value):
        value_as_body = value
        self._body = value_as_body
        self._body.parent = self

    @property
    def args(self):
        """One-element tuple containing the body."""
        return (self._body,)

    @property
    def fields_accessed(self) -> Set[Field]:
        """Set of Field instances: fields which are accessed inside this kernel function"""
        from pystencils.interpolation_astnodes import InterpolatorAccess
        resolved_accesses = self.atoms(ResolvedFieldAccess)
        interpolator_accesses = self.atoms(InterpolatorAccess)
        return {access.field for access in itertools.chain(resolved_accesses, interpolator_accesses)}

    @property
    def fields_written(self) -> Set[Field]:
        """Fields that occur as `ResolvedFieldAccess` on the left hand side of an assignment."""
        written = set()
        for assignment in self.atoms(SympyAssignment):
            if isinstance(assignment.lhs, ResolvedFieldAccess):
                written.add(assignment.lhs.field)
        return written

    @property
    def fields_read(self) -> Set[Field]:
        """Fields of all free symbols with a ``field`` attribute on the right hand side of an assignment."""
        return {symbol.field
                for assignment in self.atoms(SympyAssignment)
                for symbol in assignment.rhs.free_symbols
                if hasattr(symbol, 'field')}

    def get_parameters(self) -> Sequence['KernelFunction.Parameter']:
        """Returns list of parameters for this function.

        This function is expensive, cache the result where possible!
        """
        fields_by_name = {field.name: field for field in self.fields_accessed}

        def related_fields(symbol):
            # symbols may refer to a single field, to several fields, or to none at all
            if hasattr(symbol, 'field_name'):
                return (fields_by_name[symbol.field_name],)
            if hasattr(symbol, 'field_names'):
                return tuple(fields_by_name[name] for name in symbol.field_names)
            return ()

        parameters = [self.Parameter(symbol, related_fields(symbol))
                      for symbol in self._body.undefined_symbols - self.global_variables]
        if hasattr(self, 'indexing'):
            parameters.extend(self.Parameter(symbol, []) for symbol in self.indexing.symbolic_parameters())
        # the result is ordered by symbol name
        return sorted(parameters, key=lambda parameter: parameter.symbol.name)

    def __str__(self):
        symbols = [parameter.symbol for parameter in self.get_parameters()]
        # every line of the printed body is prefixed with a tab
        body_lines = str(self.body).splitlines(True)
        indented_body = "\t" + "\t".join(body_lines)
        return '{0} {1}({2})\n{3}'.format(type(self).__name__, self.function_name, symbols, indented_body)

    def __repr__(self):
        symbols = [parameter.symbol for parameter in self.get_parameters()]
        return '{0} {1}({2})'.format(type(self).__name__, self.function_name, symbols)

    def compile(self, *args, **kwargs):
        """Call the compile function passed to the constructor with this node and the given arguments.

        Raises:
            ValueError: if no compile function was provided
        """
        compile_function = self._compile_function
        if compile_function is None:
            raise ValueError("No compile-function provided for this KernelFunction node")
        return compile_function(self, *args, **kwargs)

265

Martin Bauer's avatar
Martin Bauer committed
266
267
268
269
270
271
272
273
274
275
276
277
278
279
class SkipIteration(Node):
    """Leaf node: it has no children, defines no symbols and uses no symbols.

    NOTE(review): judging by the name this marks that the current loop iteration is skipped;
    the handling itself is not part of this class - confirm in the backends.
    """

    @property
    def args(self):
        """No children."""
        return list()

    @property
    def symbols_defined(self):
        """No symbols are defined."""
        return set()

    @property
    def undefined_symbols(self):
        """No symbols are used."""
        return set()


280
class Block(Node):
    """Ordered sequence of child nodes.

    Every node that is put into the block gets the block as its ``parent``.

    Args:
        nodes: initial children; the list object is stored as is, it is not copied
    """

    def __init__(self, nodes: List[Node]):
        super().__init__(parent=None)
        self._nodes = nodes
        for n in self._nodes:
            n.parent = self

    @property
    def args(self):
        """The list of child nodes (the internal list itself, not a copy)."""
        return self._nodes

    def subs(self, subs_dict) -> None:
        """Forward the in-place substitution to all children."""
        for a in self.args:
            a.subs(subs_dict)

    def fast_subs(self, subs_dict, skip=None):
        """Replace every child by the result of the module-level ``fast_subs``; returns this block."""
        self._nodes = [fast_subs(a, subs_dict, skip) for a in self._nodes]
        return self

    def insert_front(self, node):
        """Insert a single node, or all nodes of an iterable (keeping their order), at the beginning."""
        if isinstance(node, collections.abc.Iterable):
            node = list(node)
            for n in node:
                n.parent = self

            self._nodes = node + self._nodes
        else:
            node.parent = self
            self._nodes.insert(0, node)

    def insert_before(self, new_node, insert_before):
        """Insert ``new_node`` in front of the child ``insert_before``.

        ``insert_before`` has to occur exactly once in this block. If ``new_node`` is a declaring
        `SympyAssignment`, it is additionally moved in front of all loops and conditionals that
        directly precede the insertion position.
        """
        new_node.parent = self
        assert self._nodes.count(insert_before) == 1
        idx = self._nodes.index(insert_before)

        # hoist declarations above the loops/conditionals directly preceding the insertion position
        if isinstance(new_node, SympyAssignment) and new_node.is_declaration:
            while idx > 0:
                pn = self._nodes[idx - 1]
                if isinstance(pn, (LoopOverCoordinate, Conditional)):
                    idx -= 1
                else:
                    break
        self._nodes.insert(idx, new_node)

    def append(self, node):
        """Append a single node, or every node of a list/tuple, at the end."""
        if isinstance(node, (list, tuple)):
            for n in node:
                n.parent = self
                self._nodes.append(n)
        else:
            node.parent = self
            self._nodes.append(node)

    def take_child_nodes(self):
        """Remove all children from the block and return them as a list.

        The ``parent`` of the returned nodes is left untouched.
        """
        tmp = self._nodes
        self._nodes = []
        return tmp

    def replace(self, child, replacements):
        """Replace ``child`` by a single node or by a list/tuple of nodes.

        ``child`` has to occur exactly once in this block. An empty list/tuple simply removes it.
        """
        assert self._nodes.count(child) == 1
        idx = self._nodes.index(child)
        del self._nodes[idx]
        # Accept the same sequence types as `append`. With the former exact ``type(...) is list``
        # test a tuple ended up in the single-node branch and failed after the child was deleted.
        if isinstance(replacements, (list, tuple)):
            replacements = list(replacements)
            for e in replacements:
                e.parent = self
            self._nodes = self._nodes[:idx] + replacements + self._nodes[idx:]
        else:
            replacements.parent = self
            self._nodes.insert(idx, replacements)

    @property
    def symbols_defined(self):
        """Union of the symbols defined by the children."""
        result = set()
        for a in self.args:
            result.update(a.symbols_defined)
        return result

    @property
    def undefined_symbols(self):
        """Symbols used by the children that are not defined by any child of this block."""
        result = set()
        defined_symbols = set()
        for a in self.args:
            result.update(a.undefined_symbols)
            defined_symbols.update(a.symbols_defined)
        return result - defined_symbols

    def __str__(self):
        return "Block " + ''.join('{!s}\n'.format(node) for node in self._nodes)

    def __repr__(self):
        return "Block"
373

374
375

class PragmaBlock(Block):
    """Block that additionally carries a pragma line.

    Args:
        pragma_line: stored in the ``pragma_line`` attribute and returned by ``repr``
        nodes: children of the block, handled by `Block`
    """

    def __init__(self, pragma_line, nodes):
        # Block.__init__ already sets the parent of every node to this block,
        # so no second pass over ``nodes`` is needed here.
        super().__init__(nodes)
        self.pragma_line = pragma_line

    def __repr__(self):
        return self.pragma_line
384
385
386
387


class LoopOverCoordinate(Node):
    """Loop over one coordinate, printed as ``for(ctr=start; ctr<stop; ctr+=step)``.

    Args:
        body: node executed in the loop; its parent is set to this loop
        coordinate_to_loop_over: coordinate the loop runs over, used to build the loop counter name
        start: first value of the loop counter (number or sympy expression)
        stop: exclusive upper bound of the loop counter (number or sympy expression)
        step: increment of the loop counter (number or sympy expression)
        is_block_loop: selects the block loop counter name/symbol instead of the normal one
    """
    LOOP_COUNTER_NAME_PREFIX = "ctr"
    # the spelling (lower case 'l') is kept, since the attribute may be referenced from other modules
    BlOCK_LOOP_COUNTER_NAME_PREFIX = "_blockctr"

    def __init__(self, body, coordinate_to_loop_over, start, stop, step=1, is_block_loop=False):
        super().__init__(parent=None)
        self.body = body
        self.body.parent = self
        self.coordinate_to_loop_over = coordinate_to_loop_over
        self.start = start
        self.stop = stop
        self.step = step
        self.prefix_lines = []
        self.is_block_loop = is_block_loop

    def new_loop_with_different_body(self, new_body):
        """Create a loop with the same coordinate, bounds and step but with ``new_body`` as body.

        The prefix lines are copied into a new list.
        """
        result = LoopOverCoordinate(new_body, self.coordinate_to_loop_over, self.start, self.stop,
                                    self.step, self.is_block_loop)
        result.prefix_lines = list(self.prefix_lines)
        return result

    def subs(self, subs_dict):
        """Substitute in the body (in place) and in start, stop and step where they support ``subs``."""
        self.body.subs(subs_dict)
        if hasattr(self.start, "subs"):
            self.start = self.start.subs(subs_dict)
        if hasattr(self.stop, "subs"):
            self.stop = self.stop.subs(subs_dict)
        if hasattr(self.step, "subs"):
            self.step = self.step.subs(subs_dict)

    def fast_subs(self, subs_dict, skip=None):
        """Apply the module-level ``fast_subs`` to the body and to sympy-valued bounds; returns this loop."""
        self.body = fast_subs(self.body, subs_dict, skip)
        if isinstance(self.start, sp.Basic):
            self.start = fast_subs(self.start, subs_dict, skip)
        if isinstance(self.stop, sp.Basic):
            self.stop = fast_subs(self.stop, subs_dict, skip)
        if isinstance(self.step, sp.Basic):
            self.step = fast_subs(self.step, subs_dict, skip)
        return self

    @property
    def args(self):
        """The body, followed by those of start, stop and step that have an ``args`` attribute."""
        result = [self.body]
        for e in [self.start, self.stop, self.step]:
            if hasattr(e, "args"):
                result.append(e)
        return result

    def replace(self, child, replacement):
        """Replace the body, start, step or stop (checked in this order, compared with ``==``).

        Nothing happens if ``child`` equals none of them.
        NOTE(review): the ``parent`` of ``replacement`` is not modified here.
        """
        if child == self.body:
            self.body = replacement
        elif child == self.start:
            self.start = replacement
        elif child == self.step:
            self.step = replacement
        elif child == self.stop:
            self.stop = replacement

    @property
    def symbols_defined(self):
        """The loop counter symbol."""
        return {self.loop_counter_symbol}

    @property
    def undefined_symbols(self):
        """Undefined symbols of the body and symbols in start/stop/step, without the loop counter."""
        # copy, so that the set handed out by the body is never modified
        result = set(self.body.undefined_symbols)
        for possible_symbol in [self.start, self.stop, self.step]:
            if isinstance(possible_symbol, (Node, sp.Basic)):
                result.update(possible_symbol.atoms(sp.Symbol))
        return result - {self.loop_counter_symbol}

    @staticmethod
    def get_loop_counter_name(coordinate_to_loop_over):
        """Name of the loop counter for the given coordinate: ``<prefix>_<coordinate>``."""
        return "%s_%s" % (LoopOverCoordinate.LOOP_COUNTER_NAME_PREFIX, coordinate_to_loop_over)

    @staticmethod
    def get_block_loop_counter_name(coordinate_to_loop_over):
        """Name of the block loop counter for the given coordinate: ``<block prefix>_<coordinate>``."""
        return "%s_%s" % (LoopOverCoordinate.BlOCK_LOOP_COUNTER_NAME_PREFIX, coordinate_to_loop_over)

    @property
    def loop_counter_name(self):
        """Name of the counter of this loop, depending on ``is_block_loop``."""
        if self.is_block_loop:
            return LoopOverCoordinate.get_block_loop_counter_name(self.coordinate_to_loop_over)
        else:
            return LoopOverCoordinate.get_loop_counter_name(self.coordinate_to_loop_over)

    @staticmethod
    def is_loop_counter_symbol(symbol):
        """Return the coordinate if ``symbol`` is a (non-block) loop counter, otherwise None.

        A loop counter is an 'int' typed symbol whose name has the form produced by
        `get_loop_counter_name` with an integer coordinate. Symbols without ``dtype`` and names
        that only share the prefix (e.g. "ctrl") yield None instead of raising.
        """
        prefix = LoopOverCoordinate.LOOP_COUNTER_NAME_PREFIX + "_"
        name = symbol.name
        if not name.startswith(prefix):
            return None
        dtype = getattr(symbol, 'dtype', None)
        if dtype is None or dtype != create_type('int'):
            return None
        try:
            return int(name[len(prefix):])
        except ValueError:
            # the part after the prefix is not a coordinate number
            return None

    @staticmethod
    def get_loop_counter_symbol(coordinate_to_loop_over):
        """Non-negative 'int' typed symbol for the loop counter of the given coordinate."""
        return TypedSymbol(LoopOverCoordinate.get_loop_counter_name(coordinate_to_loop_over), 'int', nonnegative=True)

    @staticmethod
    def get_block_loop_counter_symbol(coordinate_to_loop_over):
        """Non-negative 'int' typed symbol for the block loop counter of the given coordinate."""
        return TypedSymbol(LoopOverCoordinate.get_block_loop_counter_name(coordinate_to_loop_over),
                           'int',
                           nonnegative=True)

    @property
    def loop_counter_symbol(self):
        """Counter symbol of this loop, depending on ``is_block_loop``."""
        if self.is_block_loop:
            return self.get_block_loop_counter_symbol(self.coordinate_to_loop_over)
        else:
            return self.get_loop_counter_symbol(self.coordinate_to_loop_over)

    @property
    def is_outermost_loop(self):
        """True if no `LoopOverCoordinate` is found by ``get_next_parent_of_type`` for this loop."""
        from pystencils.transformations import get_next_parent_of_type
        return get_next_parent_of_type(self, LoopOverCoordinate) is None

    @property
    def is_innermost_loop(self):
        """True if no descendant of this loop is a `LoopOverCoordinate`."""
        return len(self.atoms(LoopOverCoordinate)) == 0

    def __str__(self):
        # every line of the printed body is prefixed with a tab
        return 'for({!s}={!s}; {!s}<{!s}; {!s}+={!s})\n{!s}'.format(self.loop_counter_name, self.start,
                                                                    self.loop_counter_name, self.stop,
                                                                    self.loop_counter_name, self.step,
                                                                    ("\t" + "\t".join(str(self.body).splitlines(True))))

    def __repr__(self):
        return 'for({!s}={!s}; {!s}<{!s}; {!s}+={!s})'.format(self.loop_counter_name, self.start,
                                                              self.loop_counter_name, self.stop,
                                                              self.loop_counter_name, self.step)
518

519
520

class SympyAssignment(Node):
    """Assignment of a right hand side expression to a left hand side symbol.

    Args:
        lhs_symbol: left hand side of the assignment
        rhs_expr: right hand side; converted with ``sp.sympify``
        is_const: returned unchanged by the `is_const` property
        use_auto: stored unchanged in the ``use_auto`` attribute
    """

    def __init__(self, lhs_symbol, rhs_expr, is_const=True, use_auto=False):
        super().__init__(parent=None)
        self._lhs_symbol = lhs_symbol
        self.rhs = sp.sympify(rhs_expr)
        self._is_const = is_const
        self._is_declaration = self.__is_declaration()
        self.use_auto = use_auto

    def __is_declaration(self):
        # The assignment declares its left hand side unless that is a cast, a field access,
        # an indexed object or a temporary memory allocation.
        non_declaring_types = (cast_func, Field.Access, sp.Indexed, TemporaryMemoryAllocation)
        return not isinstance(self._lhs_symbol, non_declaring_types)

    @property
    def lhs(self):
        """Left hand side of the assignment."""
        return self._lhs_symbol

    @lhs.setter
    def lhs(self, new_value):
        self._lhs_symbol = new_value
        # whether this is a declaration depends on the type of the left hand side
        self._is_declaration = self.__is_declaration()

    def subs(self, subs_dict):
        """Apply the module-level ``fast_subs`` to both sides of the assignment."""
        self.lhs = fast_subs(self.lhs, subs_dict)
        self.rhs = fast_subs(self.rhs, subs_dict)

    def optimize(self, optimizations):
        """Rewrite the right hand side with ``sympy.codegen.rewriting.optimize``.

        This is best effort: if the import or the rewriting fails, the right hand side stays unchanged.
        """
        try:
            from sympy.codegen.rewriting import optimize
            self.rhs = optimize(self.rhs, optimizations)
        except Exception:
            pass

    @property
    def args(self):
        """Left hand side, right hand side and the sympified ``is_const`` flag."""
        return [self._lhs_symbol, self.rhs, sp.sympify(self._is_const)]

    @property
    def symbols_defined(self):
        """The left hand side if this assignment is a declaration, otherwise the empty set."""
        if not self._is_declaration:
            return set()
        return {self._lhs_symbol}

    @property
    def undefined_symbols(self):
        """Symbols of both sides, plus loop counters for field accesses on the right hand side.

        `sp.Indexed` and `TypedImaginaryUnit` objects among the free symbols of the right hand side
        are left out.
        """
        result = {s for s in self.rhs.free_symbols if not isinstance(s, sp.Indexed)}
        # Add loop counters if there are field accesses: one counter per entry of the access' offsets
        loop_counters = set()
        for symbol in result:
            if isinstance(symbol, Field.Access):
                for i in range(len(symbol.offsets)):
                    loop_counters.add(LoopOverCoordinate.get_loop_counter_symbol(i))
        result = {r for r in result if not isinstance(r, TypedImaginaryUnit)}
        result.update(loop_counters)
        result.update(self._lhs_symbol.atoms(sp.Symbol))
        return result

    @property
    def is_declaration(self):
        """True if the assignment declares its left hand side."""
        return self._is_declaration

    @property
    def is_const(self):
        """The ``is_const`` flag passed to the constructor."""
        return self._is_const

    def replace(self, child, replacement):
        """Replace the left or the right hand side (compared with ``==``) by ``replacement``.

        The parent of ``replacement`` is set to this assignment.

        Raises:
            ValueError: if ``child`` is neither the left nor the right hand side
        """
        if child == self.lhs:
            replacement.parent = self
            self.lhs = replacement
        elif child == self.rhs:
            replacement.parent = self
            self.rhs = replacement
        else:
            # report the node that was searched for, not the node that should have replaced it
            raise ValueError('%s is not in args of %s' % (child, self.__class__))

    def __repr__(self):
        return repr(self.lhs) + " ← " + repr(self.rhs)

    def _repr_html_(self):
        printed_lhs = sp.latex(self.lhs)
        printed_rhs = sp.latex(self.rhs)
        return "${printed_lhs} \\leftarrow {printed_rhs}$".format(printed_lhs=printed_lhs, printed_rhs=printed_rhs)
Martin Bauer's avatar
Martin Bauer committed
605

606

Martin Bauer's avatar
Martin Bauer committed
607
class ResolvedFieldAccess(sp.Indexed):
    """`sp.Indexed` object with a single (linearized) index that remembers the field it belongs to.

    Besides base and index, the attributes ``field``, ``offsets`` and ``idx_coordinate_values``
    are stored on the object and take part in hashing.
    """

    def __new__(cls, base, linearized_index, field, offsets, idx_coordinate_values):
        # a bare typed symbol is wrapped into an IndexedBase first
        if not isinstance(base, sp.IndexedBase):
            assert isinstance(base, TypedSymbol)
            base = sp.IndexedBase(base, shape=(1,))
            assert isinstance(base.label, TypedSymbol)
        access = super().__new__(cls, base, linearized_index)
        access.field = field
        access.offsets = offsets
        access.idx_coordinate_values = idx_coordinate_values
        return access

    def _eval_subs(self, old, new):
        # only the index is substituted, the base is taken over unchanged
        base, index = self.args[0], self.args[1]
        return ResolvedFieldAccess(base, index.subs(old, new),
                                   self.field, self.offsets, self.idx_coordinate_values)

    def fast_subs(self, substitutions, skip=None):
        """Return the replacement for this access, or a new access with substituted base and index.

        ``skip`` is accepted but not used.
        """
        if self in substitutions:
            return substitutions[self]
        new_base = self.args[0].subs(substitutions)
        new_index = self.args[1].subs(substitutions)
        return ResolvedFieldAccess(new_base, new_index,
                                   self.field, self.offsets, self.idx_coordinate_values)

    def _hashable_content(self):
        # extend the content of sp.Indexed by the additional attributes of this class
        additional_content = (repr(self.idx_coordinate_values), hash(self.field))
        return super()._hashable_content() + tuple(self.offsets) + additional_content

    @property
    def typed_symbol(self):
        """Label of the indexed base."""
        indexed_base = self.base
        return indexed_base.label

    def __str__(self):
        indexed_str = super().__str__()
        return "%s (%s)" % (indexed_str, self.typed_symbol.dtype)

    def __getnewargs__(self):
        # the values handed to __new__ when the object is re-created (e.g. on unpickling)
        return (self.base, self.indices[0], self.field, self.offsets, self.idx_coordinate_values)
Martin Bauer's avatar
Martin Bauer committed
645
646


647
class TemporaryMemoryAllocation(Node):
    """Node for temporary memory buffer allocation.

    Always allocates aligned memory.

    Args:
        typed_symbol: symbol used as pointer (has to be typed)
        size: number of elements to allocate
        align_offset: index of the element that is to be aligned
    """

    def __init__(self, typed_symbol: TypedSymbol, size, align_offset):
        super().__init__(parent=None)
        self.symbol = typed_symbol
        self.size = size
        self.headers = ['<stdlib.h>']
        self._align_offset = align_offset

    @property
    def symbols_defined(self):
        """The pointer symbol is the only symbol defined by this node."""
        return {self.symbol}

    @property
    def undefined_symbols(self):
        """Symbols occurring in ``size`` if it is a sympy object, otherwise an empty set."""
        size = self.size
        if not isinstance(size, sp.Basic):
            return set()
        return size.atoms(sp.Symbol)

    @property
    def args(self):
        """List containing only the pointer symbol."""
        return [self.symbol]

    def offset(self, byte_alignment):
        """Number of ELEMENTS to skip for a pointer that is aligned to byte_alignment.

        Args:
            byte_alignment: alignment in bytes; has to be a multiple of the
                            element size (checked with an assert)
        """
        item_size = self.symbol.dtype.base_type.numpy_dtype.itemsize
        assert byte_alignment % item_size == 0
        # True division: for plain numbers the result is a float, not an int.
        elements_per_alignment = byte_alignment / item_size
        return -self._align_offset % elements_per_alignment

686
687

class TemporaryMemoryFree(Node):
    """Counterpart node to an allocation node.

    Exposes the ``symbol`` and ``offset`` of the wrapped allocation node;
    presumably emitted where that buffer is released (confirm against the
    backend printer).

    Args:
        alloc_node: allocation node whose symbol and offset are forwarded
    """

    def __init__(self, alloc_node):
        super().__init__(parent=None)
        self.alloc_node = alloc_node

    @property
    def symbol(self):
        """Pointer symbol of the wrapped allocation node."""
        node = self.alloc_node
        return node.symbol

    def offset(self, byte_alignment):
        """Forward to ``alloc_node.offset(byte_alignment)``."""
        node = self.alloc_node
        return node.offset(byte_alignment)

    @property
    def symbols_defined(self):
        """This node defines no symbols."""
        return set()

    @property
    def undefined_symbols(self):
        """This node reports no undefined symbols."""
        return set()

    @property
    def args(self):
        """This node has no children."""
        return []
Martin Bauer's avatar
Martin Bauer committed
710
711
712
713
714


def early_out(condition):
    """Build ``Conditional(vec_all(condition), Block([SkipIteration()]))``.

    Args:
        condition: expression wrapped in ``vec_all`` and used as the condition
                   of the returned Conditional node
    """
    from pystencils.cpu.vectorization import vec_all

    guard = vec_all(condition)
    skip_block = Block([SkipIteration()])
    return Conditional(guard, skip_block)
715
716


Stephan Seitz's avatar
Stephan Seitz committed
717
718
def get_dummy_symbol(dtype='bool'):
    """Create a TypedSymbol named ``'dummy'`` followed by a random uuid4 hex string.

    Args:
        dtype: type specification passed to ``create_type``
    """
    unique_name = 'dummy' + uuid.uuid4().hex
    symbol_type = create_type(dtype)
    return TypedSymbol(unique_name, symbol_type)
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764


class SourceCodeComment(Node):
    """AST node holding a free-text comment.

    Its string form is the text wrapped in ``/* ... */``.

    Args:
        text: comment text; concatenated with string literals in ``__str__``
    """

    def __init__(self, text):
        # Run the base initializer so the node always has a ``parent``
        # attribute, even before a container assigns one.
        super().__init__(parent=None)
        self.text = text

    @property
    def args(self):
        """This node has no children."""
        return []

    @property
    def symbols_defined(self):
        """A comment defines no symbols."""
        return set()

    @property
    def undefined_symbols(self):
        """A comment uses no symbols."""
        return set()

    def __str__(self):
        return "/* " + self.text + " */"

    def __repr__(self):
        return self.__str__()


class EmptyLine(Node):
    """AST node without content; its string form is the empty string."""

    def __init__(self):
        # Run the base initializer so the node always has a ``parent``
        # attribute, even before a container assigns one.
        super().__init__(parent=None)

    @property
    def args(self):
        """This node has no children."""
        return []

    @property
    def symbols_defined(self):
        """An empty line defines no symbols."""
        return set()

    @property
    def undefined_symbols(self):
        """An empty line uses no symbols."""
        return set()

    def __str__(self):
        return ""

    def __repr__(self):
        return self.__str__()
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789


class ConditionalFieldAccess(sp.Function):
    """
    :class:`pystencils.Field.Access` that is only executed if a certain condition is met.
    Can be used, for instance, for out-of-bound checks.

    The three constructor arguments are stored as the sympy ``args`` of the
    function, in this order: field access, out-of-bounds condition,
    out-of-bounds value (sympified, defaults to 0).
    """

    def __new__(cls, field_access, outofbounds_condition, outofbounds_value=0):
        # Convert the fallback value to a sympy object before storing it.
        sympified_value = sp.S(outofbounds_value)
        return sp.Function.__new__(cls, field_access, outofbounds_condition, sympified_value)

    @property
    def access(self):
        """First argument: the wrapped field access."""
        return self.args[0]

    @property
    def outofbounds_condition(self):
        """Second argument: the out-of-bounds condition."""
        return self.args[1]

    @property
    def outofbounds_value(self):
        """Third argument: the sympified out-of-bounds value."""
        return self.args[2]

    def __getnewargs__(self):
        """Constructor arguments for re-creating this object (pickle protocol hook)."""
        return (
            self.access,
            self.outofbounds_condition,
            self.outofbounds_value,
        )