# astnodes.py
import sympy as sp
2
from sympy.tensor import IndexedBase
3
from pystencils.field import Field
# blame: Martin Bauer
from pystencils.data_types import TypedSymbol, createType, castFunc
5
from pystencils.sympyextensions import fastSubs
6
7


8
class Node(object):
    """Base class for all AST nodes.

    Subclasses expose their children through the ``args`` property; the generic
    tree helpers ``subs`` and ``atoms`` walk those children.
    """

    def __init__(self, parent=None):
        self.parent = parent

    @property
    def args(self):
        """List of all arguments/children of this node (empty for the base class).

        This is a property (like every subclass override) so that ``self.args``
        can be iterated directly in ``subs`` and ``atoms``.
        """
        return []

    @property
    def symbolsDefined(self):
        """Set of symbols which are defined by this node."""
        return set()

    @property
    def undefinedSymbols(self):
        """Symbols which are used but are not defined inside this node.

        :raises NotImplementedError: always; subclasses must provide this
        """
        raise NotImplementedError()

    def subs(self, *args, **kwargs):
        """In-place substitution, similar to sympy's ``subs`` but modifies the AST and returns None.

        The base implementation forwards the call to every child in ``args``.
        """
        for a in self.args:
            a.subs(*args, **kwargs)

    @property
    def func(self):
        """The class of this node (mirrors sympy's ``func`` attribute)."""
        return self.__class__

    def atoms(self, argType):
        """Return the set of all descendants (searched recursively) that are instances of ``argType``.

        :param argType: type, or tuple of types, as accepted by ``isinstance``
        """
        result = set()
        for arg in self.args:
            if isinstance(arg, argType):
                result.add(arg)
            result.update(arg.atoms(argType))
        return result


class Conditional(Node):
    """Conditional node: runs ``trueBlock`` if ``conditionExpr`` holds, otherwise the optional ``falseBlock``."""

    def __init__(self, conditionExpr, trueBlock, falseBlock=None):
        """
        Create a new conditional node

        :param conditionExpr: sympy boolean or relational expression
        :param trueBlock: node which is run if the condition is true; a non-Block node is wrapped in a Block
        :param falseBlock: node which is run if the condition is false, or None if not needed
        """
        # initialize ``parent`` like every other node (it stays None until inserted into a Block)
        super(Conditional, self).__init__()
        assert conditionExpr.is_Boolean or conditionExpr.is_Relational
        self.conditionExpr = conditionExpr

        def handleChild(c):
            # both branches are always Blocks owned by this conditional
            if c is None:
                return None
            if not isinstance(c, Block):
                c = Block([c])
            c.parent = self
            return c

        self.trueBlock = handleChild(trueBlock)
        self.falseBlock = handleChild(falseBlock)

    def subs(self, *args, **kwargs):
        """In-place substitution in both branches and in the condition expression."""
        self.trueBlock.subs(*args, **kwargs)
        if self.falseBlock is not None:
            self.falseBlock.subs(*args, **kwargs)
        self.conditionExpr = self.conditionExpr.subs(*args, **kwargs)

    @property
    def args(self):
        """Condition expression, true branch and (if present) false branch."""
        result = [self.conditionExpr, self.trueBlock]
        if self.falseBlock is not None:
            result.append(self.falseBlock)
        return result

    @property
    def symbolsDefined(self):
        return set()

    @property
    def undefinedSymbols(self):
        """Undefined symbols of both branches plus all symbols of the condition."""
        # copy so the set returned by the child is never mutated
        result = set(self.trueBlock.undefinedSymbols)
        if self.falseBlock is not None:
            result.update(self.falseBlock.undefinedSymbols)
        result.update(self.conditionExpr.atoms(sp.Symbol))
        return result

    def __str__(self):
        return 'if:({!s}) '.format(self.conditionExpr)

    def __repr__(self):
        return 'if:({!r}) '.format(self.conditionExpr)


class KernelFunction(Node):
    """Root node of a kernel: a named function whose parameters are the body's undefined symbols."""

    class Argument:
        """One parameter of a kernel function.

        Names starting with ``Field.DATA_PREFIX``, ``Field.SHAPE_PREFIX`` or ``Field.STRIDE_PREFIX``
        are flagged as field arguments and linked to the accessed field with the matching name.
        """

        def __init__(self, name, dtype, symbol, kernelFunctionNode):
            from pystencils.transformations import symbolNameToVariableName
            self.name = name
            self.dtype = dtype
            self.isFieldPtrArgument = False
            self.isFieldShapeArgument = False
            self.isFieldStrideArgument = False
            self.isFieldArgument = False
            self.fieldName = ""
            self.coordinate = None
            self.symbol = symbol

            if name.startswith(Field.DATA_PREFIX):
                self.isFieldPtrArgument = True
                self.isFieldArgument = True
                self.fieldName = name[len(Field.DATA_PREFIX):]
            elif name.startswith(Field.SHAPE_PREFIX):
                self.isFieldShapeArgument = True
                self.isFieldArgument = True
                self.fieldName = name[len(Field.SHAPE_PREFIX):]
            elif name.startswith(Field.STRIDE_PREFIX):
                self.isFieldStrideArgument = True
                self.isFieldArgument = True
                self.fieldName = name[len(Field.STRIDE_PREFIX):]

            self.field = None
            if self.isFieldArgument:
                # raises KeyError if no accessed field matches the argument name
                fieldMap = {symbolNameToVariableName(f.name): f for f in kernelFunctionNode.fieldsAccessed}
                self.field = fieldMap[self.fieldName]

        def __lt__(self, other):
            """Order: field pointers, then shapes, then strides, then all others; ties broken by name."""
            def score(arg):
                if arg.isFieldPtrArgument:
                    return -4
                elif arg.isFieldShapeArgument:
                    return -3
                elif arg.isFieldStrideArgument:
                    return -2
                return 0

            return (score(self), self.name) < (score(other), other.name)

        def __repr__(self):
            return '<{0} {1}>'.format(self.dtype, self.name)

    def __init__(self, body, ghostLayers=None, functionName="kernel", backend=""):
        """
        :param body: root node of the function body; its parent is set to this function
        :param ghostLayers: stored as-is in ``self.ghostLayers``
        :param functionName: name of the generated function
        :param backend: stored as-is in ``self.backend``
        """
        super(KernelFunction, self).__init__()
        self._body = body
        body.parent = self
        self._parameters = None
        self.functionName = functionName
        # initially None; presumably filled in with a compiled callable by external code
        self.compile = None
        self.ghostLayers = ghostLayers
        # these variables are assumed to be global, so no automatic parameter is generated for them
        self.globalVariables = set()
        self.backend = backend

    @property
    def symbolsDefined(self):
        return set()

    @property
    def undefinedSymbols(self):
        """Always empty: the body's undefined symbols become ``parameters`` instead."""
        return set()

    @property
    def parameters(self):
        """Sorted list of ``Argument``s, recomputed from the current body on every access."""
        self._updateParameters()
        return self._parameters

    @property
    def body(self):
        return self._body

    @property
    def args(self):
        return [self._body]

    @property
    def fieldsAccessed(self):
        """Set of Field instances: fields which are accessed inside this kernel function"""
        return set(o.field for o in self.atoms(ResolvedFieldAccess))

    def _updateParameters(self):
        """Rebuild ``self._parameters`` from the body's undefined symbols, excluding global variables."""
        undefinedSymbols = self._body.undefinedSymbols - self.globalVariables
        self._parameters = sorted(KernelFunction.Argument(s.name, s.dtype, s, self) for s in undefinedSymbols)

    def __str__(self):
        indentedBody = "\t" + "\t".join(str(self.body).splitlines(True))
        return '{0} {1}({2})\n{3}'.format(type(self).__name__, self.functionName, self.parameters, indentedBody)

    def __repr__(self):
        return '{0} {1}({2})'.format(type(self).__name__, self.functionName, self.parameters)


class Block(Node):
    """Ordered sequence of AST nodes; every child's ``parent`` is set to the block."""

    def __init__(self, listOfNodes):
        """
        :param listOfNodes: list of child nodes; the list object itself is stored (not copied)
        """
        super(Block, self).__init__()
        self._nodes = listOfNodes
        for n in self._nodes:
            n.parent = self

    @property
    def args(self):
        return self._nodes

    def insertFront(self, node):
        """Insert ``node`` as the first child."""
        node.parent = self
        self._nodes.insert(0, node)

    def insertBefore(self, newNode, insertBefore):
        """Insert ``newNode`` before the child ``insertBefore``.

        A declaring SympyAssignment is additionally moved up past any LoopOverCoordinate or
        Conditional siblings directly preceding the insertion point.

        :raises ValueError: if ``insertBefore`` is not a child of this block
        """
        newNode.parent = self
        idx = self._nodes.index(insertBefore)

        # move all assignment (definitions to the top)
        if isinstance(newNode, SympyAssignment) and newNode.isDeclaration:
            while idx > 0 and isinstance(self._nodes[idx - 1], (LoopOverCoordinate, Conditional)):
                idx -= 1
        self._nodes.insert(idx, newNode)

    def append(self, node):
        """Append a node, or every node of a list/tuple, at the end of the block."""
        if isinstance(node, (list, tuple)):
            for n in node:
                n.parent = self
                self._nodes.append(n)
        else:
            node.parent = self
            self._nodes.append(node)

    def takeChildNodes(self):
        """Remove and return all children, leaving the block empty."""
        tmp = self._nodes
        self._nodes = []
        return tmp

    def replace(self, child, replacements):
        """Replace ``child`` by a single node or by the nodes of a list/tuple.

        :raises ValueError: if ``child`` is not a child of this block
        """
        idx = self._nodes.index(child)
        del self._nodes[idx]
        if isinstance(replacements, (list, tuple)):
            for e in replacements:
                e.parent = self
            self._nodes = self._nodes[:idx] + list(replacements) + self._nodes[idx:]
        else:
            replacements.parent = self
            self._nodes.insert(idx, replacements)

    @property
    def symbolsDefined(self):
        """Union of the symbols defined by all children."""
        result = set()
        for a in self.args:
            result.update(a.symbolsDefined)
        return result

    @property
    def undefinedSymbols(self):
        """Symbols used by the children minus those defined by any child of this block."""
        result = set()
        definedSymbols = set()
        for a in self.args:
            result.update(a.undefinedSymbols)
            definedSymbols.update(a.symbolsDefined)
        return result - definedSymbols

    def __str__(self):
        return "Block " + ''.join('{!s}\n'.format(node) for node in self._nodes)

    def __repr__(self):
        return "Block"


class PragmaBlock(Block):
    """Block annotated with a pragma line; ``repr`` returns the pragma text."""

    def __init__(self, pragmaLine, listOfNodes):
        # Block.__init__ already sets the parent of every node in listOfNodes to this block
        super(PragmaBlock, self).__init__(listOfNodes)
        self.pragmaLine = pragmaLine

    def __repr__(self):
        return self.pragmaLine


class LoopOverCoordinate(Node):
    """Loop ``for(ctr_<c> = start; ctr_<c> < stop; ctr_<c> += step) body`` over coordinate ``c``."""
    LOOP_COUNTER_NAME_PREFIX = "ctr"

    def __init__(self, body, coordinateToLoopOver, start, stop, step=1):
        super(LoopOverCoordinate, self).__init__()
        self.body = body
        body.parent = self
        self.coordinateToLoopOver = coordinateToLoopOver
        self.start = start
        self.stop = stop
        self.step = step
        self.prefixLines = []

    def newLoopWithDifferentBody(self, newBody):
        """Return a copy of this loop (same range and a copy of prefixLines) around ``newBody``."""
        result = LoopOverCoordinate(newBody, self.coordinateToLoopOver, self.start, self.stop, self.step)
        result.prefixLines = list(self.prefixLines)
        return result

    def subs(self, *args, **kwargs):
        """In-place substitution in the body and in start/stop/step (where they support ``subs``)."""
        self.body.subs(*args, **kwargs)
        if hasattr(self.start, "subs"):
            self.start = self.start.subs(*args, **kwargs)
        if hasattr(self.stop, "subs"):
            self.stop = self.stop.subs(*args, **kwargs)
        if hasattr(self.step, "subs"):
            self.step = self.step.subs(*args, **kwargs)

    @property
    def args(self):
        """Body followed by those of start/stop/step that have ``args`` (i.e. are not plain numbers)."""
        result = [self.body]
        result.extend(e for e in (self.start, self.stop, self.step) if hasattr(e, "args"))
        return result

    def replace(self, child, replacement):
        """Replace the body or one of start/step/stop; unknown children are silently ignored."""
        if child == self.body:
            # keep the parent link consistent, as Block.replace does
            replacement.parent = self
            self.body = replacement
        elif child == self.start:
            self.start = replacement
        elif child == self.step:
            self.step = replacement
        elif child == self.stop:
            self.stop = replacement

    @property
    def symbolsDefined(self):
        return set([self.loopCounterSymbol])

    @property
    def undefinedSymbols(self):
        """Undefined symbols of the body and the loop bounds, excluding this loop's counter."""
        result = set(self.body.undefinedSymbols)
        for possibleSymbol in (self.start, self.stop, self.step):
            if isinstance(possibleSymbol, (Node, sp.Basic)):
                result.update(possibleSymbol.atoms(sp.Symbol))
        return result - set([self.loopCounterSymbol])

    @staticmethod
    def getLoopCounterName(coordinateToLoopOver):
        return "%s_%s" % (LoopOverCoordinate.LOOP_COUNTER_NAME_PREFIX, coordinateToLoopOver)

    @property
    def loopCounterName(self):
        return LoopOverCoordinate.getLoopCounterName(self.coordinateToLoopOver)

    @staticmethod
    def isLoopCounterSymbol(symbol):
        """Return the coordinate of a loop counter symbol, or None if ``symbol`` is not one.

        A loop counter has an 'int' dtype and a name of the form produced by ``getLoopCounterName``,
        i.e. ``<LOOP_COUNTER_NAME_PREFIX>_<integer>``.
        """
        prefix = LoopOverCoordinate.LOOP_COUNTER_NAME_PREFIX + "_"
        if not symbol.name.startswith(prefix):
            return None
        dtype = getattr(symbol, "dtype", None)
        if dtype is None or dtype != createType('int'):
            return None
        try:
            return int(symbol.name[len(prefix):])
        except ValueError:
            # name only shares the prefix (e.g. "ctr_x"); not a loop counter
            return None

    @staticmethod
    def getLoopCounterSymbol(coordinateToLoopOver):
        return TypedSymbol(LoopOverCoordinate.getLoopCounterName(coordinateToLoopOver), 'int')

    @property
    def loopCounterSymbol(self):
        return LoopOverCoordinate.getLoopCounterSymbol(self.coordinateToLoopOver)

    @property
    def isOutermostLoop(self):
        """True if no enclosing LoopOverCoordinate is found via getNextParentOfType."""
        from pystencils.transformations import getNextParentOfType
        return getNextParentOfType(self, LoopOverCoordinate) is None

    @property
    def isInnermostLoop(self):
        """True if no LoopOverCoordinate occurs anywhere below this loop."""
        return len(self.atoms(LoopOverCoordinate)) == 0

    def __str__(self):
        return 'for({!s}={!s}; {!s}<{!s}; {!s}+={!s})\n{!s}'.format(self.loopCounterName, self.start,
                                                                    self.loopCounterName, self.stop,
                                                                    self.loopCounterName, self.step,
                                                                    ("\t" + "\t".join(str(self.body).splitlines(True))))

    def __repr__(self):
        return 'for({!s}={!s}; {!s}<{!s}; {!s}+={!s})'.format(self.loopCounterName, self.start,
                                                              self.loopCounterName, self.stop,
                                                              self.loopCounterName, self.step)


class SympyAssignment(Node):
    """Assignment ``lhs = rhs`` between sympy expressions.

    The assignment is a declaration unless its lhs is a ``Field.Access``, an ``sp.Indexed``
    (which includes ``ResolvedFieldAccess``) or a cast (``lhs.func == castFunc``).
    """

    def __init__(self, lhsSymbol, rhsTerm, isConst=True):
        super(SympyAssignment, self).__init__()
        self.rhs = rhsTerm
        # the lhs setter also computes _isDeclaration, so constructor and later updates agree
        self.lhs = lhsSymbol
        self._isConst = isConst

    @property
    def lhs(self):
        return self._lhsSymbol

    @lhs.setter
    def lhs(self, newValue):
        self._lhsSymbol = newValue
        isCast = newValue.func == castFunc
        # field accesses, indexed expressions and casts are assignment targets, not declarations
        self._isDeclaration = not (isinstance(newValue, (Field.Access, sp.Indexed)) or isCast)

    def subs(self, *args, **kwargs):
        """In-place substitution on both sides (via fastSubs); re-evaluates isDeclaration."""
        self.lhs = fastSubs(self.lhs, *args, **kwargs)
        self.rhs = fastSubs(self.rhs, *args, **kwargs)

    @property
    def args(self):
        return [self._lhsSymbol, self.rhs]

    @property
    def symbolsDefined(self):
        if not self._isDeclaration:
            return set()
        return set([self._lhsSymbol])

    @property
    def undefinedSymbols(self):
        """Symbols of rhs and lhs, plus the loop counters needed by Field.Access symbols on the rhs."""
        result = self.rhs.atoms(sp.Symbol)
        # Add loop counters if there a field accesses: one counter per offset dimension
        loopCounters = set()
        for symbol in result:
            if isinstance(symbol, Field.Access):
                for i in range(len(symbol.offsets)):
                    loopCounters.add(LoopOverCoordinate.getLoopCounterSymbol(i))
        result.update(loopCounters)
        result.update(self._lhsSymbol.atoms(sp.Symbol))
        return result

    @property
    def isDeclaration(self):
        return self._isDeclaration

    @property
    def isConst(self):
        return self._isConst

    def replace(self, child, replacement):
        """Replace lhs or rhs.

        :raises ValueError: if ``child`` is neither the lhs nor the rhs
        """
        if child == self.lhs:
            replacement.parent = self
            self.lhs = replacement
        elif child == self.rhs:
            replacement.parent = self
            self.rhs = replacement
        else:
            # report the child that was looked up, not the replacement
            raise ValueError('%s is not in args of %s' % (child, self.__class__))

    def __repr__(self):
        return repr(self.lhs) + " = " + repr(self.rhs)

    def _repr_html_(self):
        printed_lhs = sp.latex(self.lhs)
        printed_rhs = sp.latex(self.rhs)
        return f"${printed_lhs} = {printed_rhs}$"


class ResolvedFieldAccess(sp.Indexed):
    """Field access resolved to the indexed expression ``base[linearizedIndex]``.

    The instance additionally carries ``field``, ``offsets`` and ``idxCoordinateValues``.
    These attributes are part of ``_hashable_content`` and are passed through
    ``__getnewargs__`` for pickling.
    """

    def __new__(cls, base, linearizedIndex, field, offsets, idxCoordinateValues):
        indexedBase = base if isinstance(base, IndexedBase) else IndexedBase(base, shape=(1,))
        instance = super(ResolvedFieldAccess, cls).__new__(cls, indexedBase, linearizedIndex)
        instance.field = field
        instance.offsets = offsets
        instance.idxCoordinateValues = idxCoordinateValues
        return instance

    def _eval_subs(self, old, new):
        # sympy substitution hook: only the linearized index is substituted, the base is kept
        indexedBase, linearIndex = self.args[0], self.args[1]
        return ResolvedFieldAccess(indexedBase,
                                   linearIndex.subs(old, new),
                                   self.field, self.offsets, self.idxCoordinateValues)

    def fastSubs(self, subsDict):
        """Dict-based substitution: a direct hit replaces the whole access, else base and index are substituted."""
        if self in subsDict:
            return subsDict[self]
        indexedBase, linearIndex = self.args[0], self.args[1]
        return ResolvedFieldAccess(indexedBase.subs(subsDict),
                                   linearIndex.subs(subsDict),
                                   self.field, self.offsets, self.idxCoordinateValues)

    def _hashable_content(self):
        baseContent = super(ResolvedFieldAccess, self)._hashable_content()
        extraContent = tuple(self.offsets) + (repr(self.idxCoordinateValues), hash(self.field))
        return baseContent + extraContent

    @property
    def typedSymbol(self):
        """Label of the underlying IndexedBase."""
        return self.base.label

    def __str__(self):
        indexedText = super(ResolvedFieldAccess, self).__str__()
        return "%s (%s)" % (indexedText, self.typedSymbol.dtype)

    def __getnewargs__(self):
        return self.base, self.indices[0], self.field, self.offsets, self.idxCoordinateValues


class TemporaryMemoryAllocation(Node):
    """Allocation node for the temporary ``typedSymbol`` with the given ``size``."""

    def __init__(self, typedSymbol, size):
        # initialize ``parent`` like every other node
        super(TemporaryMemoryAllocation, self).__init__()
        self.symbol = typedSymbol
        self.size = size

    @property
    def symbolsDefined(self):
        return set([self.symbol])

    @property
    def undefinedSymbols(self):
        """Symbols appearing in ``size`` if it is a sympy expression, otherwise nothing."""
        if isinstance(self.size, sp.Basic):
            return self.size.atoms(sp.Symbol)
        return set()

    @property
    def args(self):
        return [self.symbol]


class TemporaryMemoryFree(Node):
    """Node marking the release of the temporary ``typedSymbol``."""

    def __init__(self, typedSymbol):
        # initialize ``parent`` like every other node
        super(TemporaryMemoryFree, self).__init__()
        self.symbol = typedSymbol

    @property
    def symbolsDefined(self):
        return set()

    @property
    def undefinedSymbols(self):
        return set()

    @property
    def args(self):
        return []