astnodes.py 16 KB
Newer Older
1
import sympy as sp
2
from sympy.tensor import IndexedBase
3
from pystencils.field import Field
4
from pystencils.data_types import TypedSymbol, createType, get_type_from_sympy, createTypeFromString, castFunc
5
from pystencils.sympyextensions import fastSubs
6
7


Martin Bauer's avatar
Martin Bauer committed
8
9
class ResolvedFieldAccess(sp.Indexed):
    """Indexed access into a field's data with an already linearized index.

    On top of the sympy ``Indexed`` arguments (base and linearized index) each instance
    carries the accessed ``field``, the access ``offsets`` and the ``idxCoordinateValues``
    as plain attributes.
    """

    def __new__(cls, base, linearizedIndex, field, offsets, idxCoordinateValues):
        # a base that is not yet an IndexedBase (e.g. a plain symbol) is wrapped into one
        indexedBase = base if isinstance(base, IndexedBase) else IndexedBase(base, shape=(1,))
        instance = super(ResolvedFieldAccess, cls).__new__(cls, indexedBase, linearizedIndex)
        instance.field = field
        instance.offsets = offsets
        instance.idxCoordinateValues = idxCoordinateValues
        return instance

    def _eval_subs(self, old, new):
        # only the linearized index takes part in substitution; base and metadata are kept
        baseArg = self.args[0]
        indexArg = self.args[1]
        return ResolvedFieldAccess(baseArg, indexArg.subs(old, new),
                                   self.field, self.offsets, self.idxCoordinateValues)

    def _hashable_content(self):
        # extend sympy's hashable content so equality/hashing also see field, offsets and
        # index coordinate values
        inherited = super(ResolvedFieldAccess, self)._hashable_content()
        extra = tuple(self.offsets) + (repr(self.idxCoordinateValues), hash(self.field))
        return inherited + extra

    @property
    def typedSymbol(self):
        """Label of the underlying IndexedBase."""
        return self.base.label

    def __str__(self):
        plain = super(ResolvedFieldAccess, self).__str__()
        return "%s (%s)" % (plain, self.typedSymbol.dtype)

    def __getnewargs__(self):
        # the extra constructor arguments are needed to rebuild the object (e.g. for pickling)
        return self.base, self.indices[0], self.field, self.offsets, self.idxCoordinateValues
Martin Bauer's avatar
Martin Bauer committed
37
38


39
class Node(object):
    """Base class for all AST nodes"""

    def __init__(self, parent=None):
        self.parent = parent

    @property
    def args(self):
        """Returns all arguments/children of this node"""
        # Must be a property: subs() and atoms() iterate over self.args directly, and
        # every subclass overrides it as a property as well.
        return []

    @property
    def symbolsDefined(self):
        """Set of symbols which are defined by this node. """
        return set()

    @property
    def undefinedSymbols(self):
        """Symbols which are used but are not defined inside this node"""
        raise NotImplementedError()

    def subs(self, *args, **kwargs):
        """Inplace! substitute, similar to sympys but modifies ast and returns None

        The substitution is forwarded to every child in ``args``. Children that are
        immutable sympy expressions return a new object, which is discarded here - nodes
        that store sympy expressions have to override this method and reassign them.
        """
        for a in self.args:
            a.subs(*args, **kwargs)

    @property
    def func(self):
        """The node's class (analogous to sympy's ``func`` attribute)."""
        return self.__class__

    def atoms(self, argType):
        """
        Returns a set of all children which are an instance of the given argType

        The search is recursive, i.e. descendants of children are included as well.
        """
        result = set()
        for arg in self.args:
            if isinstance(arg, argType):
                result.add(arg)
            result.update(arg.atoms(argType))
        return result


80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
class Conditional(Node):
    """Conditional"""
    def __init__(self, conditionExpr, trueBlock, falseBlock=None):
        """
        Create a new conditional node

        :param conditionExpr: sympy relational expression
        :param trueBlock: block which is run if conditional is true
        :param falseBlock: block which is run if conditional is false, or None if not needed
        """
        assert conditionExpr.is_Boolean or conditionExpr.is_Relational
        super(Conditional, self).__init__(parent=None)
        self.conditionExpr = conditionExpr
        self.trueBlock = trueBlock
        self.falseBlock = falseBlock
        # link the branches back to this node, like Block does for its children
        trueBlock.parent = self
        if falseBlock is not None:
            falseBlock.parent = self

    @property
    def args(self):
        result = [self.conditionExpr, self.trueBlock]
        if self.falseBlock is not None:
            result.append(self.falseBlock)
        return result

    def subs(self, *args, **kwargs):
        """Inplace substitution in both branches and in the condition expression.

        The condition is an immutable sympy expression, so its substituted version has to be
        reassigned (the generic Node.subs would discard it).
        """
        self.trueBlock.subs(*args, **kwargs)
        if self.falseBlock is not None:
            self.falseBlock.subs(*args, **kwargs)
        self.conditionExpr = self.conditionExpr.subs(*args, **kwargs)

    @property
    def symbolsDefined(self):
        return set()

    @property
    def undefinedSymbols(self):
        # copy, so the set returned by the child is never modified
        result = set(self.trueBlock.undefinedSymbols)
        if self.falseBlock is not None:
            result.update(self.falseBlock.undefinedSymbols)
        result.update(self.conditionExpr.atoms(sp.Symbol))
        return result

    def __str__(self):
        return 'if:({!s}) '.format(self.conditionExpr)

    def __repr__(self):
        return 'if:({!r}) '.format(self.conditionExpr)


121
122
123
class KernelFunction(Node):
    """Root node of a kernel: owns the body and derives the parameter list from it."""

    class Argument:
        """One kernel parameter, built from a symbol that is undefined inside the kernel body."""

        def __init__(self, name, dtype, symbol, kernelFunctionNode):
            from pystencils.transformations import symbolNameToVariableName
            self.name = name
            self.dtype = dtype
            self.symbol = symbol

            self.isFieldPtrArgument = False
            self.isFieldShapeArgument = False
            self.isFieldStrideArgument = False
            self.isFieldArgument = False
            self.fieldName = ""
            self.coordinate = None

            # the first matching prefix determines which kind of field argument this is
            prefixFlags = ((Field.DATA_PREFIX, 'isFieldPtrArgument'),
                           (Field.SHAPE_PREFIX, 'isFieldShapeArgument'),
                           (Field.STRIDE_PREFIX, 'isFieldStrideArgument'))
            for prefix, flagName in prefixFlags:
                if name.startswith(prefix):
                    setattr(self, flagName, True)
                    self.isFieldArgument = True
                    self.fieldName = name[len(prefix):]
                    break

            self.field = None
            if self.isFieldArgument:
                fieldsByVariableName = {symbolNameToVariableName(f.name): f
                                        for f in kernelFunctionNode.fieldsAccessed}
                self.field = fieldsByVariableName[self.fieldName]

        def __lt__(self, other):
            def rank(arg):
                # data pointers first, then shapes, then strides, then all other arguments
                if arg.isFieldPtrArgument:
                    return -4
                if arg.isFieldShapeArgument:
                    return -3
                if arg.isFieldStrideArgument:
                    return -2
                return 0

            # equal ranks are ordered by name
            return (rank(self), self.name) < (rank(other), other.name)

        def __repr__(self):
            return '<{0} {1}>'.format(self.dtype, self.name)

    def __init__(self, body, ghostLayers=None, functionName="kernel"):
        super(KernelFunction, self).__init__()
        self._body = body
        body.parent = self
        self._parameters = None
        self.functionName = functionName
        self.ghostLayers = ghostLayers
        # these variables are assumed to be global, so no automatic parameter is generated for them
        self.globalVariables = set()

    @property
    def symbolsDefined(self):
        return set()

    @property
    def undefinedSymbols(self):
        return set()

    @property
    def parameters(self):
        """Sorted list of Argument objects; recomputed from the body on every access."""
        self._updateParameters()
        return self._parameters

    @property
    def body(self):
        return self._body

    @property
    def args(self):
        return [self._body]

    @property
    def fieldsAccessed(self):
        """Set of Field instances: fields which are accessed inside this kernel function"""
        return set(o.field for o in self.atoms(ResolvedFieldAccess))

    def _updateParameters(self):
        # every symbol used but not defined in the body becomes a parameter, except globals
        freeSymbols = self._body.undefinedSymbols - self.globalVariables
        self._parameters = sorted(KernelFunction.Argument(s.name, s.dtype, s, self) for s in freeSymbols)

    def __str__(self):
        params = self.parameters
        indentedBody = "\t" + "\t".join(str(self.body).splitlines(True))
        return '{0} {1}({2})\n{3}'.format(type(self).__name__, self.functionName, params, indentedBody)

    def __repr__(self):
        params = self.parameters
        return '{0} {1}({2})'.format(type(self).__name__, self.functionName, params)
225

226
227
228
229
230

class Block(Node):
    """Ordered sequence of AST nodes; every contained node gets this block as its parent."""

    def __init__(self, listOfNodes):
        # go through Node.__init__ (which sets self.parent = None) instead of skipping it
        super(Block, self).__init__()
        self._nodes = listOfNodes
        for n in self._nodes:
            n.parent = self

    @property
    def args(self):
        return self._nodes

    def insertFront(self, node):
        node.parent = self
        self._nodes.insert(0, node)

    def insertBefore(self, newNode, insertBefore):
        """Insert newNode directly in front of the existing child insertBefore.

        :raises ValueError: if insertBefore is not a child of this block (newNode is left untouched)
        """
        idx = self._nodes.index(insertBefore)
        newNode.parent = self
        self._nodes.insert(idx, newNode)

    def append(self, node):
        node.parent = self
        self._nodes.append(node)

    def takeChildNodes(self):
        """Remove all children from this block and return them."""
        tmp = self._nodes
        self._nodes = []
        return tmp

    def replace(self, child, replacements):
        """Replace child by a single node or by a list of nodes (spliced in at child's position)."""
        idx = self._nodes.index(child)
        del self._nodes[idx]
        if isinstance(replacements, list):
            for e in replacements:
                e.parent = self
            self._nodes = self._nodes[:idx] + replacements + self._nodes[idx:]
        else:
            replacements.parent = self
            self._nodes.insert(idx, replacements)

    @property
    def symbolsDefined(self):
        result = set()
        for a in self.args:
            result.update(a.symbolsDefined)
        return result

    @property
    def undefinedSymbols(self):
        # symbols defined by any child of the block are not reported as undefined
        result = set()
        definedSymbols = set()
        for a in self.args:
            result.update(a.undefinedSymbols)
            definedSymbols.update(a.symbolsDefined)
        return result - definedSymbols

    def __str__(self):
        return ''.join('{!s}\n'.format(node) for node in self._nodes)

    def __repr__(self):
        # one child per line, like __str__, so consecutive children stay distinguishable
        return ''.join('{!r}\n'.format(node) for node in self._nodes)
289

290
291
292
293
294
295
296
297
298
299

class PragmaBlock(Block):
    """A Block that additionally stores a pragma line."""

    def __init__(self, pragmaLine, listOfNodes):
        self.pragmaLine = pragmaLine
        super(PragmaBlock, self).__init__(listOfNodes)


class LoopOverCoordinate(Node):
    """Loop ``for(ctr=start; ctr<stop; ctr+=step) body`` over one coordinate.

    The loop counter is an ``int`` TypedSymbol named ``ctr_<coordinateToLoopOver>``.
    """
    LOOP_COUNTER_NAME_PREFIX = "ctr"

    def __init__(self, body, coordinateToLoopOver, start, stop, step=1):
        # go through Node.__init__ so self.parent exists even before the loop is inserted somewhere
        super(LoopOverCoordinate, self).__init__(parent=None)
        self.body = body
        body.parent = self
        self.coordinateToLoopOver = coordinateToLoopOver
        self.start = start
        self.stop = stop
        self.step = step
        # NOTE(review): extra lines attached to this loop; presumably emitted in front of it
        # by the code generator - confirm
        self.prefixLines = []

    def newLoopWithDifferentBody(self, newBody):
        """Return a copy of this loop (same bounds and prefix lines) with newBody as its body."""
        result = LoopOverCoordinate(newBody, self.coordinateToLoopOver, self.start, self.stop, self.step)
        result.prefixLines = list(self.prefixLines)
        return result

    def subs(self, *args, **kwargs):
        """Inplace substitution in the body and in start/stop/step where they support subs."""
        self.body.subs(*args, **kwargs)
        if hasattr(self.start, "subs"):
            self.start = self.start.subs(*args, **kwargs)
        if hasattr(self.stop, "subs"):
            self.stop = self.stop.subs(*args, **kwargs)
        if hasattr(self.step, "subs"):
            self.step = self.step.subs(*args, **kwargs)

    @property
    def args(self):
        # bounds without an 'args' attribute (e.g. plain python ints) are not reported as children
        result = [self.body]
        for e in [self.start, self.stop, self.step]:
            if hasattr(e, "args"):
                result.append(e)
        return result

    def replace(self, child, replacement):
        """Replace the body or one of the bounds; does nothing if child is none of them."""
        if child == self.body:
            replacement.parent = self
            self.body = replacement
        elif child == self.start:
            self.start = replacement
        elif child == self.step:
            self.step = replacement
        elif child == self.stop:
            self.stop = replacement

    @property
    def symbolsDefined(self):
        return set([self.loopCounterSymbol])

    @property
    def undefinedSymbols(self):
        # copy, so the set returned by the body is never modified
        result = set(self.body.undefinedSymbols)
        for possibleSymbol in [self.start, self.stop, self.step]:
            if isinstance(possibleSymbol, Node) or isinstance(possibleSymbol, sp.Basic):
                result.update(possibleSymbol.atoms(sp.Symbol))
        return result - set([self.loopCounterSymbol])

    @staticmethod
    def getLoopCounterName(coordinateToLoopOver):
        return "%s_%s" % (LoopOverCoordinate.LOOP_COUNTER_NAME_PREFIX, coordinateToLoopOver)

    @property
    def loopCounterName(self):
        return LoopOverCoordinate.getLoopCounterName(self.coordinateToLoopOver)

    @staticmethod
    def isLoopCounterSymbol(symbol):
        """Return the coordinate of a loop counter symbol, or None if symbol is not one.

        A loop counter is named ``ctr_<int>`` and has dtype ``int``. Symbols without a
        dtype, or whose name only resembles the prefix (e.g. ``ctrl``, ``ctr_x``), give None.
        """
        prefix = LoopOverCoordinate.LOOP_COUNTER_NAME_PREFIX + "_"
        if not symbol.name.startswith(prefix):
            return None
        if getattr(symbol, 'dtype', None) != createTypeFromString('int'):
            return None
        try:
            return int(symbol.name[len(prefix):])
        except ValueError:
            return None

    @staticmethod
    def getLoopCounterSymbol(coordinateToLoopOver):
        return TypedSymbol(LoopOverCoordinate.getLoopCounterName(coordinateToLoopOver), 'int')

    @property
    def loopCounterSymbol(self):
        return LoopOverCoordinate.getLoopCounterSymbol(self.coordinateToLoopOver)

    @property
    def isOutermostLoop(self):
        from pystencils.transformations import getNextParentOfType
        return getNextParentOfType(self, LoopOverCoordinate) is None

    @property
    def isInnermostLoop(self):
        return len(self.atoms(LoopOverCoordinate)) == 0

    def __str__(self):
        return 'for({!s}={!s}; {!s}<{!s}; {!s}+={!s})\n{!s}'.format(self.loopCounterName, self.start,
                                                                     self.loopCounterName, self.stop,
                                                                     self.loopCounterName, self.step,
                                                                     ("\t" + "\t".join(str(self.body).splitlines(True))))

    def __repr__(self):
        return 'for({!s}={!s}; {!s}<{!s}; {!s}+={!s})'.format(self.loopCounterName, self.start,
                                                               self.loopCounterName, self.stop,
                                                               self.loopCounterName, self.step)
400

401
402
403
404
405
406

class SympyAssignment(Node):
    """Assignment ``lhs = rhs`` of sympy expressions.

    The assignment counts as a declaration unless the left-hand side is a Field.Access,
    an sp.Indexed (which includes ResolvedFieldAccess) or a castFunc expression.
    """
    def __init__(self, lhsSymbol, rhsTerm, isConst=True):
        super(SympyAssignment, self).__init__(parent=None)
        self._lhsSymbol = lhsSymbol
        self.rhs = rhsTerm
        self._isConst = isConst
        self._isDeclaration = self._lhsIsDeclaration(lhsSymbol)

    @staticmethod
    def _lhsIsDeclaration(lhs):
        """Shared declaration rule for the constructor and the lhs setter, so both always agree."""
        if lhs.func == castFunc:
            return False
        # sp.Indexed also covers ResolvedFieldAccess, which derives from it
        return not isinstance(lhs, (Field.Access, sp.Indexed))

    @property
    def lhs(self):
        return self._lhsSymbol

    @lhs.setter
    def lhs(self, newValue):
        self._lhsSymbol = newValue
        self._isDeclaration = self._lhsIsDeclaration(newValue)

    def subs(self, *args, **kwargs):
        self.lhs = fastSubs(self.lhs, *args, **kwargs)
        self.rhs = fastSubs(self.rhs, *args, **kwargs)

    @property
    def args(self):
        return [self._lhsSymbol, self.rhs]

    @property
    def symbolsDefined(self):
        if not self._isDeclaration:
            return set()
        return set([self._lhsSymbol])

    @property
    def undefinedSymbols(self):
        result = self.rhs.atoms(sp.Symbol)
        # Add loop counters if there are field accesses
        loopCounters = set()
        for symbol in result:
            if isinstance(symbol, Field.Access):
                for i in range(len(symbol.offsets)):
                    loopCounters.add(LoopOverCoordinate.getLoopCounterSymbol(i))
        result.update(loopCounters)
        result.update(self._lhsSymbol.atoms(sp.Symbol))
        return result

    @property
    def isDeclaration(self):
        return self._isDeclaration

    @property
    def isConst(self):
        return self._isConst

    def replace(self, child, replacement):
        """Replace lhs or rhs by replacement.

        :raises ValueError: if child is neither the lhs nor the rhs
        """
        if child == self.lhs:
            replacement.parent = self
            self.lhs = replacement
        elif child == self.rhs:
            replacement.parent = self
            self.rhs = replacement
        else:
            raise ValueError('%s is not in args of %s' % (child, self.__class__))

    def __repr__(self):
        return repr(self.lhs) + " = " + repr(self.rhs)


class TemporaryMemoryAllocation(Node):
    """Node for a temporary memory allocation of ``size`` for ``typedSymbol``; defines that symbol."""
    def __init__(self, typedSymbol, size):
        # go through Node.__init__ so self.parent exists, like the other node types
        super(TemporaryMemoryAllocation, self).__init__(parent=None)
        self.symbol = typedSymbol
        self.size = size

    @property
    def symbolsDefined(self):
        return set([self.symbol])

    @property
    def undefinedSymbols(self):
        # a symbolic size may depend on other symbols; a plain number does not
        if isinstance(self.size, sp.Basic):
            return self.size.atoms(sp.Symbol)
        else:
            return set()

    @property
    def args(self):
        return [self.symbol]
492
493
494
495


class TemporaryMemoryFree(Node):
    """Node for freeing the temporary memory associated with ``typedSymbol``."""
    def __init__(self, typedSymbol):
        # go through Node.__init__ so self.parent exists, like the other node types
        super(TemporaryMemoryFree, self).__init__(parent=None)
        self.symbol = typedSymbol

    @property
    def symbolsDefined(self):
        return set()

    @property
    def undefinedSymbols(self):
        return set()

    @property
    def args(self):
        return []