astnodes.py 17.6 KB
Newer Older
1
import sympy as sp
2
from sympy.tensor import IndexedBase
3
from pystencils.field import Field
Martin Bauer's avatar
Martin Bauer committed
4
from pystencils.data_types import TypedSymbol, createType, castFunc
Martin Bauer's avatar
Martin Bauer committed
5
from pystencils.sympyextensions import fast_subs
6
7


8
class Node(object):
    """Base class for all AST nodes.

    Subclasses expose their children through the ``args`` property; the generic tree
    operations ``subs`` and ``atoms`` below walk the tree through it.
    """

    def __init__(self, parent=None):
        self.parent = parent

    @property
    def args(self):
        """Returns all arguments/children of this node (a leaf node has none).

        Must be a property: ``subs`` and ``atoms`` iterate ``self.args`` directly, and all
        subclasses override it as a property as well.
        """
        return []

    @property
    def symbolsDefined(self):
        """Set of symbols which are defined by this node."""
        return set()

    @property
    def undefinedSymbols(self):
        """Symbols which are used but are not defined inside this node."""
        raise NotImplementedError()

    def subs(self, *args, **kwargs):
        """Inplace! substitute, similar to sympys but modifies ast and returns None.

        Forwards the call to every child in ``args``.
        """
        for a in self.args:
            a.subs(*args, **kwargs)

    @property
    def func(self):
        """The node's class (mirrors sympy's ``func`` attribute)."""
        return self.__class__

    def atoms(self, argType):
        """Returns a set of all (transitive) children which are an instance of the given argType.

        The node itself is not included; children are searched recursively via their own
        ``atoms`` method.
        """
        result = set()
        for arg in self.args:
            if isinstance(arg, argType):
                result.add(arg)
            result.update(arg.atoms(argType))
        return result


class Conditional(Node):
    """If/else node: ``trueBlock`` runs when ``conditionExpr`` holds, else ``falseBlock`` (if any)."""

    def __init__(self, conditionExpr, trueBlock, falseBlock=None):
        """
        Create a new conditional node

        :param conditionExpr: sympy relational or boolean expression
        :param trueBlock: block which is run if conditional is true (single nodes are wrapped in a Block)
        :param falseBlock: block which is run if conditional is false, or None if not needed
        """
        # initialize Node state so ``parent`` exists even before this node is inserted anywhere
        super(Conditional, self).__init__(parent=None)
        # NOTE(review): assert is stripped under ``python -O``; kept to preserve the exception type
        assert conditionExpr.is_Boolean or conditionExpr.is_Relational
        self.conditionExpr = conditionExpr

        def handleChild(c):
            # wrap a non-Block child into a Block and make this conditional its parent
            if c is None:
                return None
            if not isinstance(c, Block):
                c = Block([c])
            c.parent = self
            return c

        self.trueBlock = handleChild(trueBlock)
        self.falseBlock = handleChild(falseBlock)

    def subs(self, *args, **kwargs):
        """In-place substitution in both branches and in the condition expression."""
        self.trueBlock.subs(*args, **kwargs)
        if self.falseBlock is not None:
            self.falseBlock.subs(*args, **kwargs)
        self.conditionExpr = self.conditionExpr.subs(*args, **kwargs)

    @property
    def args(self):
        result = [self.conditionExpr, self.trueBlock]
        if self.falseBlock is not None:
            result.append(self.falseBlock)
        return result

    @property
    def symbolsDefined(self):
        return set()

    @property
    def undefinedSymbols(self):
        """Union of the branches' undefined symbols and the symbols used in the condition."""
        # copy so the set returned by the child block is never mutated in place
        result = set(self.trueBlock.undefinedSymbols)
        if self.falseBlock is not None:
            result.update(self.falseBlock.undefinedSymbols)
        result.update(self.conditionExpr.atoms(sp.Symbol))
        return result

    def __str__(self):
        return 'if:({!s}) '.format(self.conditionExpr)

    def __repr__(self):
        return 'if:({!r}) '.format(self.conditionExpr)


class KernelFunction(Node):
    """AST node for a complete kernel function.

    Wraps the body block; the kernel parameters are derived on demand from the body's
    undefined symbols, excluding those listed in ``globalVariables``.
    """

    class Argument:
        """One kernel parameter.

        Field pointer/shape/stride parameters are recognized by the name prefixes
        ``Field.DATA_PREFIX`` / ``Field.SHAPE_PREFIX`` / ``Field.STRIDE_PREFIX``; for those,
        ``field`` is looked up among the fields accessed by the owning kernel.
        """

        def __init__(self, name, dtype, symbol, kernelFunctionNode):
            from pystencils.transformations import symbolNameToVariableName
            self.name = name
            self.dtype = dtype
            self.isFieldPtrArgument = False
            self.isFieldShapeArgument = False
            self.isFieldStrideArgument = False
            self.isFieldArgument = False
            self.fieldName = ""
            self.coordinate = None
            self.symbol = symbol

            # the first matching prefix wins (checked in the order data, shape, stride)
            prefixToFlag = ((Field.DATA_PREFIX, 'isFieldPtrArgument'),
                            (Field.SHAPE_PREFIX, 'isFieldShapeArgument'),
                            (Field.STRIDE_PREFIX, 'isFieldStrideArgument'))
            for prefix, flagName in prefixToFlag:
                if name.startswith(prefix):
                    setattr(self, flagName, True)
                    self.isFieldArgument = True
                    self.fieldName = name[len(prefix):]
                    break

            self.field = None
            if self.isFieldArgument:
                fieldsByVariableName = {symbolNameToVariableName(f.name): f
                                        for f in kernelFunctionNode.fieldsAccessed}
                self.field = fieldsByVariableName[self.fieldName]

        def __lt__(self, other):
            """Field pointers first, then shapes, then strides, then the rest; ties by name."""
            def rank(argument):
                if argument.isFieldPtrArgument:
                    return -4
                if argument.isFieldShapeArgument:
                    return -3
                if argument.isFieldStrideArgument:
                    return -2
                return 0

            return (rank(self), self.name) < (rank(other), other.name)

        def __repr__(self):
            return '<{0} {1}>'.format(self.dtype, self.name)

    def __init__(self, body, ghostLayers=None, functionName="kernel", backend=""):
        super(KernelFunction, self).__init__()
        self._body = body
        body.parent = self
        self._parameters = None
        self.functionName = functionName
        self.compile = None
        self.ghostLayers = ghostLayers
        # these variables are assumed to be global, so no automatic parameter is generated for them
        self.globalVariables = set()
        self.backend = backend

    @property
    def symbolsDefined(self):
        return set()

    @property
    def undefinedSymbols(self):
        return set()

    @property
    def parameters(self):
        """Sorted list of ``Argument`` objects, recomputed on every access."""
        self._updateParameters()
        return self._parameters

    @property
    def body(self):
        return self._body

    @property
    def args(self):
        return [self._body]

    @property
    def fieldsAccessed(self):
        """Set of Field instances: fields which are accessed inside this kernel function"""
        return {access.field for access in self.atoms(ResolvedFieldAccess)}

    def _updateParameters(self):
        freeSymbols = self._body.undefinedSymbols - self.globalVariables
        self._parameters = [KernelFunction.Argument(s.name, s.dtype, s, self) for s in freeSymbols]
        self._parameters.sort()

    def __str__(self):
        parameterList = self.parameters
        indentedBody = "\t" + "\t".join(str(self.body).splitlines(True))
        return '{0} {1}({2})\n{3}'.format(type(self).__name__, self.functionName, parameterList, indentedBody)

    def __repr__(self):
        parameterList = self.parameters
        return '{0} {1}({2})'.format(type(self).__name__, self.functionName, parameterList)


class Block(Node):
    """Ordered sequence of child nodes; every child gets this block as its ``parent``."""

    def __init__(self, listOfNodes):
        # Node.__init__ initializes self.parent to None
        super(Block, self).__init__()
        self._nodes = listOfNodes
        for n in self._nodes:
            n.parent = self

    @property
    def args(self):
        return self._nodes

    def insertFront(self, node):
        """Insert ``node`` as the first child."""
        node.parent = self
        self._nodes.insert(0, node)

    def insertBefore(self, newNode, insertBefore):
        """Insert ``newNode`` in front of the child ``insertBefore``.

        Declarations (``SympyAssignment`` with ``isDeclaration``) are additionally moved up past
        directly preceding loops and conditionals. Raises ValueError (from ``list.index``) if
        ``insertBefore`` is not a child of this block.
        """
        newNode.parent = self
        idx = self._nodes.index(insertBefore)

        # move declarations in front of directly preceding loops / conditionals
        if isinstance(newNode, SympyAssignment) and newNode.isDeclaration:
            while idx > 0:
                pn = self._nodes[idx - 1]
                if isinstance(pn, (LoopOverCoordinate, Conditional)):
                    idx -= 1
                else:
                    break
        self._nodes.insert(idx, newNode)

    def append(self, node):
        """Append a single node, or every node of a list/tuple, as children."""
        if isinstance(node, (list, tuple)):
            for n in node:
                n.parent = self
                self._nodes.append(n)
        else:
            node.parent = self
            self._nodes.append(node)

    def takeChildNodes(self):
        """Detach and return the current child list; the block is empty afterwards."""
        tmp = self._nodes
        self._nodes = []
        return tmp

    def replace(self, child, replacements):
        """Replace ``child`` by a single node or by the nodes of a list/tuple (kept in order).

        Raises ValueError (from ``list.index``) if ``child`` is not a child of this block.
        """
        idx = self._nodes.index(child)
        del self._nodes[idx]
        if isinstance(replacements, (list, tuple)):
            for e in replacements:
                e.parent = self
            self._nodes = self._nodes[:idx] + list(replacements) + self._nodes[idx:]
        else:
            replacements.parent = self
            self._nodes.insert(idx, replacements)

    @property
    def symbolsDefined(self):
        result = set()
        for a in self.args:
            result.update(a.symbolsDefined)
        return result

    @property
    def undefinedSymbols(self):
        """Symbols used by any child minus the symbols defined by any child of this block."""
        result = set()
        defined_symbols = set()
        for a in self.args:
            result.update(a.undefinedSymbols)
            defined_symbols.update(a.symbolsDefined)
        return result - defined_symbols

    def __str__(self):
        return "Block " + ''.join('{!s}\n'.format(node) for node in self._nodes)

    def __repr__(self):
        return "Block"


class PragmaBlock(Block):
    """Block annotated with a pragma line; the line is stored in ``pragmaLine`` and used as repr."""

    def __init__(self, pragmaLine, listOfNodes):
        # Block.__init__ already makes this instance the parent of every node in listOfNodes
        super(PragmaBlock, self).__init__(listOfNodes)
        self.pragmaLine = pragmaLine

    def __repr__(self):
        return self.pragmaLine


class LoopOverCoordinate(Node):
    """Loop ``for(ctr_<c>=start; ctr_<c><stop; ctr_<c>+=step)`` over coordinate ``c``."""
    LOOP_COUNTER_NAME_PREFIX = "ctr"

    def __init__(self, body, coordinateToLoopOver, start, stop, step=1):
        # initialize Node state (parent=None) so parent walks such as isOutermostLoop also work
        # for loops that have not been inserted into another node yet
        super(LoopOverCoordinate, self).__init__(parent=None)
        self.body = body
        body.parent = self
        self.coordinateToLoopOver = coordinateToLoopOver
        self.start = start
        self.stop = stop
        self.step = step
        self.prefixLines = []

    def newLoopWithDifferentBody(self, newBody):
        """Return a loop with the same bounds and a copy of the prefix lines, but ``newBody``."""
        result = LoopOverCoordinate(newBody, self.coordinateToLoopOver, self.start, self.stop, self.step)
        result.prefixLines = list(self.prefixLines)
        return result

    def subs(self, *args, **kwargs):
        """In-place substitution in the body and in all bounds that support ``subs``."""
        self.body.subs(*args, **kwargs)
        if hasattr(self.start, "subs"):
            self.start = self.start.subs(*args, **kwargs)
        if hasattr(self.stop, "subs"):
            self.stop = self.stop.subs(*args, **kwargs)
        if hasattr(self.step, "subs"):
            self.step = self.step.subs(*args, **kwargs)

    @property
    def args(self):
        # plain Python numbers have no ``args`` and are therefore not reported as children
        result = [self.body]
        for e in [self.start, self.stop, self.step]:
            if hasattr(e, "args"):
                result.append(e)
        return result

    def replace(self, child, replacement):
        """Replace the body or a bound equal to ``child``; does nothing if no child matches."""
        if child == self.body:
            replacement.parent = self
            self.body = replacement
        elif child == self.start:
            self.start = replacement
        elif child == self.step:
            self.step = replacement
        elif child == self.stop:
            self.stop = replacement

    @property
    def symbolsDefined(self):
        return {self.loopCounterSymbol}

    @property
    def undefinedSymbols(self):
        result = self.body.undefinedSymbols
        for possibleSymbol in [self.start, self.stop, self.step]:
            if isinstance(possibleSymbol, (Node, sp.Basic)):
                result.update(possibleSymbol.atoms(sp.Symbol))
        return result - {self.loopCounterSymbol}

    @staticmethod
    def getLoopCounterName(coordinateToLoopOver):
        return "%s_%s" % (LoopOverCoordinate.LOOP_COUNTER_NAME_PREFIX, coordinateToLoopOver)

    @property
    def loopCounterName(self):
        return LoopOverCoordinate.getLoopCounterName(self.coordinateToLoopOver)

    @staticmethod
    def isLoopCounterSymbol(symbol):
        """Return the coordinate if ``symbol`` is an int-typed loop counter ``ctr_<int>``, else None."""
        prefix = LoopOverCoordinate.LOOP_COUNTER_NAME_PREFIX + "_"
        if not symbol.name.startswith(prefix):
            return None
        # plain sympy symbols carry no dtype and therefore cannot be loop counters
        dtype = getattr(symbol, 'dtype', None)
        if dtype is None or dtype != createType('int'):
            return None
        try:
            return int(symbol.name[len(prefix):])
        except ValueError:
            # e.g. "ctr_x": right prefix but not a numeric coordinate
            return None

    @staticmethod
    def getLoopCounterSymbol(coordinateToLoopOver):
        return TypedSymbol(LoopOverCoordinate.getLoopCounterName(coordinateToLoopOver), 'int')

    @property
    def loopCounterSymbol(self):
        return LoopOverCoordinate.getLoopCounterSymbol(self.coordinateToLoopOver)

    @property
    def isOutermostLoop(self):
        from pystencils.transformations import getNextParentOfType
        return getNextParentOfType(self, LoopOverCoordinate) is None

    @property
    def isInnermostLoop(self):
        return len(self.atoms(LoopOverCoordinate)) == 0

    def __str__(self):
        return 'for({!s}={!s}; {!s}<{!s}; {!s}+={!s})\n{!s}'.format(self.loopCounterName, self.start,
                                                                    self.loopCounterName, self.stop,
                                                                    self.loopCounterName, self.step,
                                                                    ("\t" + "\t".join(str(self.body).splitlines(True))))

    def __repr__(self):
        return 'for({!s}={!s}; {!s}<{!s}; {!s}+={!s})'.format(self.loopCounterName, self.start,
                                                              self.loopCounterName, self.stop,
                                                              self.loopCounterName, self.step)


class SympyAssignment(Node):
    """Assignment ``lhs = rhs`` of sympy expressions.

    The assignment counts as a declaration (introduces a new variable) unless the left-hand
    side is a field access, an indexed access or a cast expression. The same rule is applied
    on construction and whenever ``lhs`` is reassigned.
    """

    def __init__(self, lhsSymbol, rhsTerm, isConst=True):
        # initialize Node state so ``parent`` exists before insertion into a Block
        super(SympyAssignment, self).__init__(parent=None)
        self._lhsSymbol = lhsSymbol
        self.rhs = rhsTerm
        self._isDeclaration = True
        self._updateLhsInfo()
        self._isConst = isConst

    def _updateLhsInfo(self):
        """Recompute ``_isDeclaration`` from the current left-hand side."""
        isCast = self._lhsSymbol.func == castFunc
        # ResolvedFieldAccess derives from sp.Indexed and is therefore covered here too
        isArrayAccess = isinstance(self._lhsSymbol, (Field.Access, sp.Indexed))
        self._isDeclaration = not (isArrayAccess or isCast)

    @property
    def lhs(self):
        return self._lhsSymbol

    @lhs.setter
    def lhs(self, newValue):
        self._lhsSymbol = newValue
        self._updateLhsInfo()

    def subs(self, *args, **kwargs):
        """In-place substitution on both sides (via ``fast_subs``); updates the declaration flag."""
        self.lhs = fast_subs(self.lhs, *args, **kwargs)
        self.rhs = fast_subs(self.rhs, *args, **kwargs)

    @property
    def args(self):
        return [self._lhsSymbol, self.rhs]

    @property
    def symbolsDefined(self):
        if not self._isDeclaration:
            return set()
        return {self._lhsSymbol}

    @property
    def undefinedSymbols(self):
        result = self.rhs.atoms(sp.Symbol)
        # field accesses on the rhs implicitly use the loop counters of their coordinates
        loopCounters = set()
        for symbol in result:
            if isinstance(symbol, Field.Access):
                for i in range(len(symbol.offsets)):
                    loopCounters.add(LoopOverCoordinate.getLoopCounterSymbol(i))
        result.update(loopCounters)
        result.update(self._lhsSymbol.atoms(sp.Symbol))
        return result

    @property
    def isDeclaration(self):
        return self._isDeclaration

    @property
    def isConst(self):
        return self._isConst

    def replace(self, child, replacement):
        """Replace the lhs or rhs equal to ``child``; raises ValueError if neither matches."""
        if child == self.lhs:
            replacement.parent = self
            self.lhs = replacement
        elif child == self.rhs:
            replacement.parent = self
            self.rhs = replacement
        else:
            # report the child that was looked for, not the replacement
            raise ValueError('%s is not in args of %s' % (child, self.__class__))

    def __repr__(self):
        return repr(self.lhs) + " = " + repr(self.rhs)

    def _repr_html_(self):
        printed_lhs = sp.latex(self.lhs)
        printed_rhs = sp.latex(self.rhs)
        return f"${printed_lhs} = {printed_rhs}$"


class ResolvedFieldAccess(sp.Indexed):
    """Indexed expression ``base[linearizedIndex]`` that remembers the field access it came from.

    Besides its sympy arguments (base and linearized index) every instance carries ``field``,
    ``offsets`` and ``idxCoordinateValues``; these are appended to the sympy hashable content.
    """

    def __new__(cls, base, linearizedIndex, field, offsets, idxCoordinateValues):
        # sympy.Indexed needs an IndexedBase as base, so wrap anything else
        indexedBase = base if isinstance(base, IndexedBase) else IndexedBase(base, shape=(1,))
        instance = super(ResolvedFieldAccess, cls).__new__(cls, indexedBase, linearizedIndex)
        instance.field = field
        instance.offsets = offsets
        instance.idxCoordinateValues = idxCoordinateValues
        return instance

    def _eval_subs(self, old, new):
        # sympy substitution only rewrites the linearized index; the base is kept unchanged
        base, linearIndex = self.args[0], self.args[1]
        return ResolvedFieldAccess(base, linearIndex.subs(old, new),
                                   self.field, self.offsets, self.idxCoordinateValues)

    def fast_subs(self, substitutions):
        """Mapping-based substitution; unlike ``_eval_subs`` this also substitutes in the base."""
        if self in substitutions:
            return substitutions[self]
        newBase = self.args[0].subs(substitutions)
        newIndex = self.args[1].subs(substitutions)
        return ResolvedFieldAccess(newBase, newIndex, self.field, self.offsets, self.idxCoordinateValues)

    def _hashable_content(self):
        indexedContent = super(ResolvedFieldAccess, self)._hashable_content()
        extraContent = tuple(self.offsets) + (repr(self.idxCoordinateValues), hash(self.field))
        return indexedContent + extraContent

    @property
    def typedSymbol(self):
        """Label of the IndexedBase, i.e. the symbol passed (or wrapped) as ``base``."""
        return self.base.label

    def __str__(self):
        indexedText = super(ResolvedFieldAccess, self).__str__()
        return "%s (%s)" % (indexedText, self.typedSymbol.dtype)

    def __getnewargs__(self):
        # constructor arguments used when unpickling
        return self.base, self.indices[0], self.field, self.offsets, self.idxCoordinateValues


class TemporaryMemoryAllocation(Node):
    """AST node for a temporary memory allocation bound to ``typedSymbol``.

    :param typedSymbol: symbol that is defined by this node
    :param size: allocation size; if it is a sympy expression, its symbols count as undefined
                 symbols of this node (unit — elements vs. bytes — TODO confirm with code generator)
    """

    def __init__(self, typedSymbol, size):
        # initialize Node state so ``parent`` exists before insertion into a Block
        super(TemporaryMemoryAllocation, self).__init__(parent=None)
        self.symbol = typedSymbol
        self.size = size

    @property
    def symbolsDefined(self):
        return {self.symbol}

    @property
    def undefinedSymbols(self):
        if isinstance(self.size, sp.Basic):
            return self.size.atoms(sp.Symbol)
        return set()

    @property
    def args(self):
        return [self.symbol]


class TemporaryMemoryFree(Node):
    """AST node releasing the temporary memory referenced by ``typedSymbol``.

    It neither defines nor uses any symbols from the AST's point of view and has no children.
    """

    def __init__(self, typedSymbol):
        # initialize Node state so ``parent`` exists before insertion into a Block
        super(TemporaryMemoryFree, self).__init__(parent=None)
        self.symbol = typedSymbol

    @property
    def symbolsDefined(self):
        return set()

    @property
    def undefinedSymbols(self):
        return set()

    @property
    def args(self):
        return []