astnodes.py 17.2 KB
Newer Older
1
import sympy as sp
2
from sympy.tensor import IndexedBase
3
from pystencils.field import Field
4
from pystencils.data_types import TypedSymbol, createType, get_type_from_sympy, createTypeFromString, castFunc
5
from pystencils.sympyextensions import fastSubs
6
7


Martin Bauer's avatar
Martin Bauer committed
8
9
class ResolvedFieldAccess(sp.Indexed):
    """
    Indexed access (base[linearizedIndex]) that additionally remembers the field,
    the offsets and the index coordinate values it was created with.
    """

    def __new__(cls, base, linearizedIndex, field, offsets, idxCoordinateValues):
        # sp.Indexed needs an IndexedBase; wrap anything else in one with shape (1,)
        indexedBase = base if isinstance(base, IndexedBase) else IndexedBase(base, shape=(1,))

        instance = super(ResolvedFieldAccess, cls).__new__(cls, indexedBase, linearizedIndex)
        instance.field = field
        instance.offsets = offsets
        instance.idxCoordinateValues = idxCoordinateValues
        return instance

    def _eval_subs(self, old, new):
        """Substitute only inside the index expression; base and extra attributes are kept."""
        baseArg = self.args[0]
        substitutedIndex = self.args[1].subs(old, new)
        return ResolvedFieldAccess(baseArg, substitutedIndex,
                                   self.field, self.offsets, self.idxCoordinateValues)

    def fastSubs(self, subsDict):
        """
        Return the replacement stored for this access in subsDict, if there is one.
        Otherwise build a new access with subsDict applied to both base and index.
        """
        if self in subsDict:
            return subsDict[self]

        newBase = self.args[0].subs(subsDict)
        newIndex = self.args[1].subs(subsDict)
        return ResolvedFieldAccess(newBase, newIndex,
                                   self.field, self.offsets, self.idxCoordinateValues)

    def _hashable_content(self):
        """Extend the inherited hashable content by offsets, index coordinate values and field hash."""
        inherited = super(ResolvedFieldAccess, self)._hashable_content()
        extra = (repr(self.idxCoordinateValues), hash(self.field))
        return inherited + tuple(self.offsets) + extra

    @property
    def typedSymbol(self):
        """Label of the underlying indexed base."""
        return self.base.label

    def __str__(self):
        inheritedStr = super(ResolvedFieldAccess, self).__str__()
        return "%s (%s)" % (inheritedStr, self.typedSymbol.dtype)

    def __getnewargs__(self):
        """Arguments in the order expected by __new__."""
        return (self.base, self.indices[0], self.field, self.offsets, self.idxCoordinateValues)
Martin Bauer's avatar
Martin Bauer committed
44
45


46
class Node(object):
    """Base class for all AST nodes"""

    def __init__(self, parent=None):
        """
        :param parent: node this node is attached to, or None
        """
        self.parent = parent

    @property
    def args(self):
        """Returns all arguments/children of this node"""
        # Must be a property: subs() and atoms() below iterate over ``self.args``,
        # and all subclasses in this file expose ``args`` as a property as well.
        return []

    @property
    def symbolsDefined(self):
        """Set of symbols which are defined by this node. """
        return set()

    @property
    def undefinedSymbols(self):
        """Symbols which are used but are not defined inside this node"""
        raise NotImplementedError()

    def subs(self, *args, **kwargs):
        """Inplace! substitute, similar to sympys but modifies ast and returns None"""
        for a in self.args:
            a.subs(*args, **kwargs)

    @property
    def func(self):
        """Class of this node."""
        return self.__class__

    def atoms(self, argType):
        """
        Returns a set of all children which are an instance of the given argType

        The search is recursive: every child is also asked for its own atoms.
        """
        result = set()
        for arg in self.args:
            if isinstance(arg, argType):
                result.add(arg)
            result.update(arg.atoms(argType))
        return result


87
88
89
90
91
92
93
94
95
96
97
98
class Conditional(Node):
    """Conditional node: a condition expression, a true block and an optional false block"""

    def __init__(self, conditionExpr, trueBlock, falseBlock=None):
        """
        Create a new conditional node

        :param conditionExpr: sympy relational expression
        :param trueBlock: block which is run if conditional is true
        :param falseBlock: block which is run if conditional is false, or None if not needed
        """
        # Initialise the Node base class so that ``self.parent`` always exists
        super(Conditional, self).__init__(parent=None)
        assert conditionExpr.is_Boolean or conditionExpr.is_Relational
        self.conditionExpr = conditionExpr

        def handleChild(c):
            """Wrap a non-Block child into a Block and attach it to this node; None stays None."""
            if c is None:
                return None
            if not isinstance(c, Block):
                c = Block([c])
            c.parent = self
            return c

        self.trueBlock = handleChild(trueBlock)
        self.falseBlock = handleChild(falseBlock)

    def subs(self, *args, **kwargs):
        """Inplace substitution in both blocks and in the condition expression"""
        self.trueBlock.subs(*args, **kwargs)
        if self.falseBlock is not None:
            self.falseBlock.subs(*args, **kwargs)
        self.conditionExpr = self.conditionExpr.subs(*args, **kwargs)

    @property
    def args(self):
        """Condition expression, true block and - if present - false block"""
        result = [self.conditionExpr, self.trueBlock]
        if self.falseBlock is not None:
            result.append(self.falseBlock)
        return result

    @property
    def symbolsDefined(self):
        """A conditional defines no symbols itself"""
        return set()

    @property
    def undefinedSymbols(self):
        """Undefined symbols of both blocks plus all symbols of the condition expression"""
        # copy, so that the set handed out by the child block is never modified
        result = set(self.trueBlock.undefinedSymbols)
        if self.falseBlock is not None:
            result.update(self.falseBlock.undefinedSymbols)
        result.update(self.conditionExpr.atoms(sp.Symbol))
        return result

    def __str__(self):
        return 'if:({!s}) '.format(self.conditionExpr)

    def __repr__(self):
        return 'if:({!r}) '.format(self.conditionExpr)


143
144
145
class KernelFunction(Node):
    """AST node for a named function that wraps a body node; parameters are derived from the body"""

    class Argument:
        """
        Description of one function parameter.

        The name prefix decides whether the argument is a field pointer, field shape or
        field stride argument; for those the matching Field object is looked up as well.
        """

        def __init__(self, name, dtype, symbol, kernelFunctionNode):
            """
            :param name: parameter name, checked against the Field.*_PREFIX constants
            :param dtype: data type of the parameter
            :param symbol: symbol this parameter was created from
            :param kernelFunctionNode: kernel function whose accessed fields are used for the field lookup
            """
            # imported locally - presumably to avoid a circular import, TODO confirm
            from pystencils.transformations import symbolNameToVariableName
            self.name = name
            self.dtype = dtype
            self.isFieldPtrArgument = False
            self.isFieldShapeArgument = False
            self.isFieldStrideArgument = False
            self.isFieldArgument = False
            self.fieldName = ""
            self.coordinate = None
            self.symbol = symbol

            if name.startswith(Field.DATA_PREFIX):
                self.isFieldPtrArgument = True
                self.isFieldArgument = True
                self.fieldName = name[len(Field.DATA_PREFIX):]
            elif name.startswith(Field.SHAPE_PREFIX):
                self.isFieldShapeArgument = True
                self.isFieldArgument = True
                self.fieldName = name[len(Field.SHAPE_PREFIX):]
            elif name.startswith(Field.STRIDE_PREFIX):
                self.isFieldStrideArgument = True
                self.isFieldArgument = True
                self.fieldName = name[len(Field.STRIDE_PREFIX):]

            self.field = None
            if self.isFieldArgument:
                fieldMap = {symbolNameToVariableName(f.name): f for f in kernelFunctionNode.fieldsAccessed}
                self.field = fieldMap[self.fieldName]

        def __lt__(self, other):
            """Order: field pointers, then shapes, then strides, then the rest; ties are broken by name"""
            def score(l):
                if l.isFieldPtrArgument:
                    return -4
                elif l.isFieldShapeArgument:
                    return -3
                elif l.isFieldStrideArgument:
                    return -2
                return 0

            # tuple comparison: score first, name only if the scores are equal
            return (score(self), self.name) < (score(other), other.name)

        def __repr__(self):
            return '<{0} {1}>'.format(self.dtype, self.name)

    def __init__(self, body, ghostLayers=None, functionName="kernel"):
        """
        :param body: node that forms the function body; its parent is set to this node
        :param ghostLayers: stored unchanged on the node
        :param functionName: name used in the string representations
        """
        super(KernelFunction, self).__init__()
        self._body = body
        body.parent = self
        self._parameters = None
        self.functionName = functionName
        self.ghostLayers = ghostLayers
        # these variables are assumed to be global, so no automatic parameter is generated for them
        self.globalVariables = set()

    @property
    def symbolsDefined(self):
        return set()

    @property
    def undefinedSymbols(self):
        return set()

    @property
    def parameters(self):
        """Sorted list of Argument objects, recomputed from the body on every access"""
        self._updateParameters()
        return self._parameters

    @property
    def body(self):
        return self._body

    @property
    def args(self):
        return [self._body]

    @property
    def fieldsAccessed(self):
        """Set of Field instances: fields which are accessed inside this kernel function"""
        return set(o.field for o in self.atoms(ResolvedFieldAccess))

    def _updateParameters(self):
        """Rebuild the parameter list from the undefined symbols of the body, minus the global variables"""
        undefinedSymbols = self._body.undefinedSymbols - self.globalVariables
        self._parameters = [KernelFunction.Argument(s.name, s.dtype, s, self) for s in undefinedSymbols]
        self._parameters.sort()

    def __str__(self):
        # ``self.parameters`` already refreshes the parameter list, no explicit update required
        return '{0} {1}({2})\n{3}'.format(type(self).__name__, self.functionName, self.parameters,
                                          ("\t" + "\t".join(str(self.body).splitlines(True))))

    def __repr__(self):
        # ``self.parameters`` already refreshes the parameter list, no explicit update required
        return '{0} {1}({2})'.format(type(self).__name__, self.functionName, self.parameters)
247

248
249
250
251
252

class Block(Node):
    """Node holding an ordered list of child nodes"""

    def __init__(self, listOfNodes):
        """
        :param listOfNodes: list of child nodes; the list object itself is stored (not copied)
                            and every entry gets this block as parent
        """
        # super(Block, ...) so that Node.__init__ really runs and sets ``self.parent``
        super(Block, self).__init__(parent=None)
        self._nodes = listOfNodes
        for n in self._nodes:
            n.parent = self

    @property
    def args(self):
        return self._nodes

    def insertFront(self, node):
        """Insert node as first child"""
        node.parent = self
        self._nodes.insert(0, node)

    def insertBefore(self, newNode, insertBefore):
        """
        Insert newNode in front of the existing child insertBefore.

        A declaring SympyAssignment is moved further up: it ends up directly behind the
        closest preceding declaration, or at the front of the block if there is none.
        """
        newNode.parent = self
        idx = self._nodes.index(insertBefore)

        # move all assignment (definitions to the top)
        if isinstance(newNode, SympyAssignment) and newNode.isDeclaration:
            while idx > 0 and not (isinstance(self._nodes[idx-1], SympyAssignment) and self._nodes[idx-1].isDeclaration):
                idx -= 1
        self._nodes.insert(idx, newNode)

    def append(self, node):
        """Append a single node, or every node of a list/tuple, at the end of the block"""
        if isinstance(node, (list, tuple)):
            for n in node:
                n.parent = self
                self._nodes.append(n)
        else:
            node.parent = self
            self._nodes.append(node)

    def takeChildNodes(self):
        """Return the current child list and leave this block empty (parents are not changed)"""
        tmp = self._nodes
        self._nodes = []
        return tmp

    def replace(self, child, replacements):
        """
        Replace child by a single node or by a list/tuple of nodes, at the same position.

        :raises ValueError: if child is not in this block (raised by list.index)
        """
        idx = self._nodes.index(child)
        del self._nodes[idx]
        # accept tuples as well, consistent with append()
        if isinstance(replacements, (list, tuple)):
            for e in replacements:
                e.parent = self
            self._nodes = self._nodes[:idx] + list(replacements) + self._nodes[idx:]
        else:
            replacements.parent = self
            self._nodes.insert(idx, replacements)

    @property
    def symbolsDefined(self):
        """Union of the symbols defined by all children"""
        result = set()
        for a in self.args:
            result.update(a.symbolsDefined)
        return result

    @property
    def undefinedSymbols(self):
        """Undefined symbols of all children, minus everything defined by any child of this block"""
        result = set()
        definedSymbols = set()
        for a in self.args:
            result.update(a.undefinedSymbols)
            definedSymbols.update(a.symbolsDefined)
        return result - definedSymbols

    def __str__(self):
        return ''.join('{!s}\n'.format(node) for node in self._nodes)

    def __repr__(self):
        return ''.join('{!r}'.format(node) for node in self._nodes)
321

322
323
324
325
326
327
328
329
330
331

class PragmaBlock(Block):
    """Block that additionally stores a pragma line"""

    def __init__(self, pragmaLine, listOfNodes):
        """
        :param pragmaLine: stored unchanged as ``self.pragmaLine``
        :param listOfNodes: child nodes, handed on to Block
        """
        self.pragmaLine = pragmaLine
        super(PragmaBlock, self).__init__(listOfNodes)


class LoopOverCoordinate(Node):
    """Loop node: runs body with a counter going from start to stop (exclusive in the printed form) in steps of step"""

    LOOP_COUNTER_NAME_PREFIX = "ctr"

    def __init__(self, body, coordinateToLoopOver, start, stop, step=1):
        """
        :param body: node executed in every iteration; its parent is set to this loop
        :param coordinateToLoopOver: coordinate identifier, used to build the loop counter name
        :param start: first counter value
        :param stop: bound printed as ``counter < stop``
        :param step: counter increment
        """
        # Initialise the Node base class so that ``self.parent`` always exists,
        # isOutermostLoop relies on walking up the parents
        super(LoopOverCoordinate, self).__init__(parent=None)
        self.body = body
        body.parent = self
        self.coordinateToLoopOver = coordinateToLoopOver
        self.start = start
        self.stop = stop
        self.step = step
        self.prefixLines = []

    def newLoopWithDifferentBody(self, newBody):
        """New loop with identical coordinate, bounds and step, a copy of the prefix lines, and newBody as body"""
        result = LoopOverCoordinate(newBody, self.coordinateToLoopOver, self.start, self.stop, self.step)
        result.prefixLines = list(self.prefixLines)
        return result

    def subs(self, *args, **kwargs):
        """Inplace substitution in the body and in start/stop/step as far as they offer ``subs``"""
        self.body.subs(*args, **kwargs)
        if hasattr(self.start, "subs"):
            self.start = self.start.subs(*args, **kwargs)
        if hasattr(self.stop, "subs"):
            self.stop = self.stop.subs(*args, **kwargs)
        if hasattr(self.step, "subs"):
            self.step = self.step.subs(*args, **kwargs)

    @property
    def args(self):
        """Body, followed by those of start/stop/step that have an ``args`` attribute"""
        result = [self.body]
        for e in [self.start, self.stop, self.step]:
            if hasattr(e, "args"):
                result.append(e)
        return result

    def replace(self, child, replacement):
        """Replace the first of body/start/step/stop that equals child; no effect if none matches"""
        if child == self.body:
            self.body = replacement
        elif child == self.start:
            self.start = replacement
        elif child == self.step:
            self.step = replacement
        elif child == self.stop:
            self.stop = replacement

    @property
    def symbolsDefined(self):
        """The loop defines exactly its own loop counter symbol"""
        return {self.loopCounterSymbol}

    @property
    def undefinedSymbols(self):
        """Undefined symbols of the body plus symbols in start/stop/step, without the loop counter"""
        # copy, so that the set handed out by the body is never modified
        result = set(self.body.undefinedSymbols)
        for possibleSymbol in [self.start, self.stop, self.step]:
            if isinstance(possibleSymbol, (Node, sp.Basic)):
                result.update(possibleSymbol.atoms(sp.Symbol))
        return result - {self.loopCounterSymbol}

    @staticmethod
    def getLoopCounterName(coordinateToLoopOver):
        """Counter name: prefix, underscore, coordinate"""
        return "%s_%s" % (LoopOverCoordinate.LOOP_COUNTER_NAME_PREFIX, coordinateToLoopOver)

    @property
    def loopCounterName(self):
        return LoopOverCoordinate.getLoopCounterName(self.coordinateToLoopOver)

    @staticmethod
    def isLoopCounterSymbol(symbol):
        """
        Returns the coordinate encoded in the name of a loop counter symbol, otherwise None.

        None is returned if the name does not start with the counter prefix, if the dtype
        is not 'int', or if the part behind the prefix is not an integer.
        """
        prefix = LoopOverCoordinate.LOOP_COUNTER_NAME_PREFIX
        if not symbol.name.startswith(prefix):
            return None
        if symbol.dtype != createTypeFromString('int'):
            return None
        try:
            # skip the prefix and the separator character behind it
            return int(symbol.name[len(prefix)+1:])
        except ValueError:
            # e.g. an unrelated int symbol whose name merely starts with the prefix
            return None

    @staticmethod
    def getLoopCounterSymbol(coordinateToLoopOver):
        """Typed 'int' symbol carrying the loop counter name of the given coordinate"""
        return TypedSymbol(LoopOverCoordinate.getLoopCounterName(coordinateToLoopOver), 'int')

    @property
    def loopCounterSymbol(self):
        return LoopOverCoordinate.getLoopCounterSymbol(self.coordinateToLoopOver)

    @property
    def isOutermostLoop(self):
        """True if no parent of type LoopOverCoordinate is found"""
        from pystencils.transformations import getNextParentOfType
        return getNextParentOfType(self, LoopOverCoordinate) is None

    @property
    def isInnermostLoop(self):
        """True if no LoopOverCoordinate is found among the children (recursively)"""
        return len(self.atoms(LoopOverCoordinate)) == 0

    def __str__(self):
        return 'for({!s}={!s}; {!s}<{!s}; {!s}+={!s})\n{!s}'.format(self.loopCounterName, self.start,
                                                                    self.loopCounterName, self.stop,
                                                                    self.loopCounterName, self.step,
                                                                    ("\t" + "\t".join(str(self.body).splitlines(True))))

    def __repr__(self):
        return 'for({!s}={!s}; {!s}<{!s}; {!s}+={!s})'.format(self.loopCounterName, self.start,
                                                              self.loopCounterName, self.stop,
                                                              self.loopCounterName, self.step)
431

432
433
434
435
436
437

class SympyAssignment(Node):
    """Assignment node ``lhs = rhs``, which may at the same time be the declaration of lhs"""

    def __init__(self, lhsSymbol, rhsTerm, isConst=True):
        """
        :param lhsSymbol: left hand side; field accesses, indexed objects and casts are
                          treated as plain stores, everything else as a declaration
        :param rhsTerm: right hand side expression
        :param isConst: stored unchanged, available as ``isConst``
        """
        # Initialise the Node base class so that ``self.parent`` always exists
        super(SympyAssignment, self).__init__(parent=None)
        self._lhsSymbol = lhsSymbol
        self.rhs = rhsTerm
        self._isDeclaration = self._isDeclarationOf(lhsSymbol)
        self._isConst = isConst

    @staticmethod
    def _isDeclarationOf(lhs):
        """False for field accesses, indexed objects and casts, True for every other left hand side"""
        # Single rule for constructor and setter. ResolvedFieldAccess derives from sp.Indexed,
        # so the sp.Indexed check covers it.
        isCast = lhs.func == castFunc
        return not (isinstance(lhs, (Field.Access, sp.Indexed)) or isCast)

    @property
    def lhs(self):
        return self._lhsSymbol

    @lhs.setter
    def lhs(self, newValue):
        self._lhsSymbol = newValue
        self._isDeclaration = self._isDeclarationOf(newValue)

    def subs(self, *args, **kwargs):
        """Inplace substitution on both sides; the declaration flag is re-evaluated through the lhs setter"""
        self.lhs = fastSubs(self.lhs, *args, **kwargs)
        self.rhs = fastSubs(self.rhs, *args, **kwargs)

    @property
    def args(self):
        return [self._lhsSymbol, self.rhs]

    @property
    def symbolsDefined(self):
        """The left hand side if this assignment is a declaration, otherwise nothing"""
        if not self._isDeclaration:
            return set()
        return {self._lhsSymbol}

    @property
    def undefinedSymbols(self):
        """Symbols of both sides, plus one loop counter per offset of every field access on the rhs"""
        result = self.rhs.atoms(sp.Symbol)
        # Add loop counters if there a field accesses
        loopCounters = set()
        for symbol in result:
            if isinstance(symbol, Field.Access):
                for i in range(len(symbol.offsets)):
                    loopCounters.add(LoopOverCoordinate.getLoopCounterSymbol(i))
        result.update(loopCounters)
        result.update(self._lhsSymbol.atoms(sp.Symbol))
        return result

    @property
    def isDeclaration(self):
        return self._isDeclaration

    @property
    def isConst(self):
        return self._isConst

    def replace(self, child, replacement):
        """
        Replace lhs or rhs by replacement.

        :raises ValueError: if child is neither the lhs nor the rhs of this assignment
        """
        if child == self.lhs:
            replacement.parent = self
            self.lhs = replacement
        elif child == self.rhs:
            replacement.parent = self
            self.rhs = replacement
        else:
            # report the object that was searched for, not the replacement
            raise ValueError('%s is not in args of %s' % (child, self.__class__))

    def __repr__(self):
        return repr(self.lhs) + " = " + repr(self.rhs)


class TemporaryMemoryAllocation(Node):
    """Node that defines a symbol together with a size"""

    def __init__(self, typedSymbol, size):
        """
        :param typedSymbol: symbol defined by this node
        :param size: size value; if it is a sympy expression its symbols count as undefined symbols
        """
        # Initialise the Node base class so that ``self.parent`` always exists
        super(TemporaryMemoryAllocation, self).__init__(parent=None)
        self.symbol = typedSymbol
        self.size = size

    @property
    def symbolsDefined(self):
        return {self.symbol}

    @property
    def undefinedSymbols(self):
        """Symbols occurring in the size if it is a sympy expression, otherwise the empty set"""
        if isinstance(self.size, sp.Basic):
            return self.size.atoms(sp.Symbol)
        else:
            return set()

    @property
    def args(self):
        return [self.symbol]
523
524
525
526


class TemporaryMemoryFree(Node):
    """Node that refers to a symbol; it neither defines nor uses symbols and has no children"""

    def __init__(self, typedSymbol):
        """
        :param typedSymbol: symbol this node refers to
        """
        # Initialise the Node base class so that ``self.parent`` always exists
        super(TemporaryMemoryFree, self).__init__(parent=None)
        self.symbol = typedSymbol

    @property
    def symbolsDefined(self):
        return set()

    @property
    def undefinedSymbols(self):
        return set()

    @property
    def args(self):
        return []