astnodes.py 16.4 KB
Newer Older
1
import sympy as sp
2
from sympy.tensor import IndexedBase
3
from pystencils.field import Field
Jan Hoenig's avatar
Jan Hoenig committed
4
from pystencils.types import TypedSymbol, createType, get_type_from_sympy
5
6


7
class Node(object):
    """Base class for all AST nodes"""

    def __init__(self, parent=None):
        """
        :param parent: node that contains this node, or None if it has no parent (yet)
        """
        self.parent = parent

    def args(self):
        """Returns all arguments/children of this node"""
        return []

    @property
    def symbolsDefined(self):
        """Set of symbols which are defined by this node. """
        return set()

    @property
    def undefinedSymbols(self):
        """Symbols which are used but are not defined inside this node"""
        raise NotImplementedError()

    def atoms(self, argType):
        """
        Returns a set of all children which are an instance of the given argType

        The search is recursive: every child is asked for its own ``atoms(argType)``.
        """
        children = self.args
        # The base class exposes ``args`` as a plain method while the subclasses
        # in this module override it with a property; accept both forms so that
        # calling atoms() on a bare Node does not try to iterate a bound method.
        if callable(children):
            children = children()

        result = set()
        for arg in children:
            if isinstance(arg, argType):
                result.add(arg)
            result.update(arg.atoms(argType))
        return result


39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
class Conditional(Node):
    """AST node for a branch: a condition, a true block and an optional false block."""

    def __init__(self, conditionExpr, trueBlock, falseBlock=None):
        """
        Create a new conditional node

        :param conditionExpr: sympy relational expression
        :param trueBlock: block which is run if conditional is true
        :param falseBlock: block which is run if conditional is false, or None if not needed
        """
        # Initialise the base class so that ``self.parent`` always exists.
        super(Conditional, self).__init__()
        assert conditionExpr.is_Boolean or conditionExpr.is_Relational
        self.conditionExpr = conditionExpr
        self.trueBlock = trueBlock
        self.falseBlock = falseBlock

    @property
    def args(self):
        """Children of this node: condition, true block and, if present, false block."""
        result = [self.conditionExpr, self.trueBlock]
        if self.falseBlock:
            result.append(self.falseBlock)
        return result

    @property
    def symbolsDefined(self):
        """A conditional reports no defined symbols of its own."""
        return set()

    @property
    def undefinedSymbols(self):
        """Union of the undefined symbols of both branches and the symbols of the condition."""
        # Copy first: the set handed out by the true block must not be modified.
        result = set(self.trueBlock.undefinedSymbols)
        if self.falseBlock:
            # set.update() returns None, so its result must not be assigned back.
            result.update(self.falseBlock.undefinedSymbols)
        result.update(self.conditionExpr.atoms(sp.Symbol))
        return result

    def __str__(self):
        return 'if:({!s}) '.format(self.conditionExpr)

    def __repr__(self):
        return 'if:({!r}) '.format(self.conditionExpr)


80
81
82
class KernelFunction(Node):
    """AST node for a whole kernel: a function name, a body node and the fields it accesses."""

    class Argument:
        """One kernel parameter; names carrying a field prefix are linked to their field."""

        def __init__(self, name, dtype, kernelFunctionNode):
            """
            :param name: parameter name, possibly starting with one of the Field prefixes
            :param dtype: data type of the parameter
            :param kernelFunctionNode: kernel whose ``fieldsAccessed`` is searched for the field
            """
            from pystencils.transformations import symbolNameToVariableName

            self.name = name
            self.dtype = dtype
            self.coordinate = None
            self.fieldName = ""
            self.isFieldArgument = False
            self.isFieldPtrArgument = False
            self.isFieldShapeArgument = False
            self.isFieldStrideArgument = False

            # Prefixes are tried in this order and only the first match is applied.
            prefixFlags = (
                (Field.DATA_PREFIX, 'isFieldPtrArgument'),
                (Field.SHAPE_PREFIX, 'isFieldShapeArgument'),
                (Field.STRIDE_PREFIX, 'isFieldStrideArgument'),
            )
            for prefix, flagName in prefixFlags:
                if name.startswith(prefix):
                    setattr(self, flagName, True)
                    self.isFieldArgument = True
                    self.fieldName = name[len(prefix):]
                    break

            self.field = None
            if self.isFieldArgument:
                # Look the field up by the variable name derived from its own name.
                fieldsByVariableName = {}
                for accessedField in kernelFunctionNode.fieldsAccessed:
                    fieldsByVariableName[symbolNameToVariableName(accessedField.name)] = accessedField
                self.field = fieldsByVariableName[self.fieldName]

        def __repr__(self):
            return '<{0} {1}>'.format(self.dtype, self.name)

    def __init__(self, body, fieldsAccessed, functionName="kernel"):
        """
        :param body: node forming the function body; its parent is set to this kernel
        :param fieldsAccessed: fields which are accessed inside this kernel function
        :param functionName: name of the function
        """
        super(KernelFunction, self).__init__()
        self._body = body
        self._body.parent = self
        self._fieldsAccessed = fieldsAccessed
        self._parameters = None
        self.functionName = functionName
        # these variables are assumed to be global, so no automatic parameter is generated for them
        self.globalVariables = set()

    @property
    def symbolsDefined(self):
        """A kernel function reports no defined symbols."""
        return set()

    @property
    def undefinedSymbols(self):
        """A kernel function reports no undefined symbols."""
        return set()

    @property
    def parameters(self):
        """List of Argument objects, recomputed from the body on every access."""
        self._updateParameters()
        return self._parameters

    @property
    def body(self):
        """The body node passed to the constructor."""
        return self._body

    @property
    def args(self):
        """Children of this node: only the body."""
        return [self._body]

    @property
    def fieldsAccessed(self):
        """Set of Field instances: fields which are accessed inside this kernel function"""
        return self._fieldsAccessed

    def _updateParameters(self):
        """Rebuild the parameter list from the body's undefined symbols minus the globals."""
        def sortKey(argument):
            return (argument.fieldName, argument.isFieldPtrArgument, argument.isFieldShapeArgument,
                    argument.isFieldStrideArgument, argument.name)

        freeSymbols = self._body.undefinedSymbols - self.globalVariables
        arguments = [KernelFunction.Argument(symbol.name, symbol.dtype, self) for symbol in freeSymbols]
        self._parameters = sorted(arguments, key=sortKey, reverse=True)

    def __str__(self):
        header = repr(self)
        # Indent every line of the body by one tab.
        indentedBody = "\t" + "\t".join(str(self.body).splitlines(True))
        return header + '\n' + indentedBody

    def __repr__(self):
        currentParameters = self.parameters
        return '{0} {1}({2})'.format(type(self).__name__, self.functionName, currentParameters)
167

168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183

class Block(Node):
    """AST node holding an ordered list of child nodes."""

    def __init__(self, listOfNodes):
        """
        :param listOfNodes: list of child nodes; the list object is stored as-is and
                            every element gets this block as its parent
        """
        # super(Node, self) would skip Node.__init__ and leave ``self.parent`` unset.
        super(Block, self).__init__()
        self._nodes = listOfNodes
        for n in self._nodes:
            n.parent = self

    @property
    def args(self):
        """The list of child nodes (the internal list itself, not a copy)."""
        return self._nodes

    def insertFront(self, node):
        """Insert node as the first child of this block."""
        node.parent = self
        self._nodes.insert(0, node)

    def insertBefore(self, newNode, insertBefore):
        """
        Insert newNode directly in front of the existing child insertBefore.

        :raises ValueError: if insertBefore is not a child of this block
        """
        newNode.parent = self
        idx = self._nodes.index(insertBefore)
        self._nodes.insert(idx, newNode)

    def append(self, node):
        """Add node as the last child of this block."""
        node.parent = self
        self._nodes.append(node)

    def takeChildNodes(self):
        """Return the current child list and leave this block empty."""
        tmp = self._nodes
        self._nodes = []
        return tmp

    def replace(self, child, replacements):
        """
        Replace child by a single node or by a list of nodes, at the same position.

        :raises ValueError: if child is not a child of this block
        """
        idx = self._nodes.index(child)
        del self._nodes[idx]
        if isinstance(replacements, list):
            for e in replacements:
                e.parent = self
            self._nodes = self._nodes[:idx] + replacements + self._nodes[idx:]
        else:
            replacements.parent = self
            self._nodes.insert(idx, replacements)

    @property
    def symbolsDefined(self):
        """Union of the symbols defined by all children."""
        result = set()
        for a in self.args:
            result.update(a.symbolsDefined)
        return result

    @property
    def undefinedSymbols(self):
        """Symbols undefined in some child and not defined by any child of this block."""
        result = set()
        definedSymbols = set()
        for a in self.args:
            result.update(a.undefinedSymbols)
            definedSymbols.update(a.symbolsDefined)
        return result - definedSymbols

    def __str__(self):
        return ''.join('{!s}\n'.format(node) for node in self._nodes)

    def __repr__(self):
        return ''.join('{!r}'.format(node) for node in self._nodes)
230

231
232
233
234
235
236
237
238
239
240

class PragmaBlock(Block):
    """Block that additionally stores a pragma line."""

    def __init__(self, pragmaLine, listOfNodes):
        """
        :param pragmaLine: value stored unchanged in ``self.pragmaLine``
        :param listOfNodes: child nodes, handled by the Block constructor
        """
        self.pragmaLine = pragmaLine
        super(PragmaBlock, self).__init__(listOfNodes)


class LoopOverCoordinate(Node):
    """AST node for a loop over one coordinate, with start, stop and step around a body node."""
    LOOP_COUNTER_NAME_PREFIX = "ctr"

    def __init__(self, body, coordinateToLoopOver, start, stop, step=1):
        """
        :param body: node forming the loop body; its parent is set to this loop
        :param coordinateToLoopOver: coordinate identifier, used to build the loop counter name
        :param start: first value of the loop counter
        :param stop: end value of the loop counter
        :param step: increment of the loop counter
        """
        # Initialise the base class so that ``self.parent`` always exists.
        super(LoopOverCoordinate, self).__init__()
        self._body = body
        body.parent = self
        self._coordinateToLoopOver = coordinateToLoopOver
        self._begin = start
        self._end = stop
        self._increment = step
        self.prefixLines = []

    def newLoopWithDifferentBody(self, newBody):
        """Return a new loop with the same coordinate, bounds and prefix lines around newBody."""
        result = LoopOverCoordinate(newBody, self._coordinateToLoopOver, self._begin, self._end, self._increment)
        result.prefixLines = list(self.prefixLines)
        return result

    @property
    def args(self):
        """The body, followed by those of start/stop/step that have an ``args`` attribute."""
        result = [self._body]
        for e in [self._begin, self._end, self._increment]:
            if hasattr(e, "args"):
                result.append(e)
        return result

    @property
    def body(self):
        return self._body

    @property
    def start(self):
        return self._begin

    @property
    def stop(self):
        return self._end

    @property
    def step(self):
        return self._increment

    @property
    def coordinateToLoopOver(self):
        return self._coordinateToLoopOver

    @property
    def symbolsDefined(self):
        """The loop defines exactly its loop counter symbol."""
        return {self.loopCounterSymbol}

    @property
    def undefinedSymbols(self):
        """Undefined symbols of the body and of the loop bounds, without the loop counter."""
        # Copy first: the set handed out by the body must not be modified.
        result = set(self._body.undefinedSymbols)
        for possibleSymbol in [self._begin, self._end, self._increment]:
            if isinstance(possibleSymbol, (Node, sp.Basic)):
                result.update(possibleSymbol.atoms(sp.Symbol))
        return result - {self.loopCounterSymbol}

    @staticmethod
    def getLoopCounterName(coordinateToLoopOver):
        """Name of the loop counter for a coordinate: '<prefix>_<coordinate>'."""
        return "%s_%s" % (LoopOverCoordinate.LOOP_COUNTER_NAME_PREFIX, coordinateToLoopOver)

    @property
    def loopCounterName(self):
        return LoopOverCoordinate.getLoopCounterName(self.coordinateToLoopOver)

    @staticmethod
    def getLoopCounterSymbol(coordinateToLoopOver):
        """TypedSymbol of type 'int' carrying the loop counter name for a coordinate."""
        return TypedSymbol(LoopOverCoordinate.getLoopCounterName(coordinateToLoopOver), 'int')

    @property
    def loopCounterSymbol(self):
        return LoopOverCoordinate.getLoopCounterSymbol(self.coordinateToLoopOver)

    @property
    def isOutermostLoop(self):
        """True if no parent of type LoopOverCoordinate is found above this loop."""
        from pystencils.transformations import getNextParentOfType
        return getNextParentOfType(self, LoopOverCoordinate) is None

    @property
    def isInnermostLoop(self):
        """True if no LoopOverCoordinate is found among the children, recursively."""
        return len(self.atoms(LoopOverCoordinate)) == 0

    def __str__(self):
        return 'loop:{!s} in {!s}:{!s}:{!s}\n{!s}'.format(self.loopCounterName, self.start, self.stop, self.step,
                                                          ("\t" + "\t".join(str(self.body).splitlines(True))))

    def __repr__(self):
        return 'loop:{!s} in {!s}:{!s}:{!s}'.format(self.loopCounterName, self.start, self.stop, self.step)
331

332
333
334
335
336
337

class SympyAssignment(Node):
    """AST node for an assignment of a right hand side term to a left hand side symbol."""

    def __init__(self, lhsSymbol, rhsTerm, isConst=True):
        """
        :param lhsSymbol: left hand side of the assignment
        :param rhsTerm: right hand side of the assignment
        :param isConst: flag stored and exposed through the ``isConst`` property
        """
        # Initialise the base class so that ``self.parent`` always exists.
        super(SympyAssignment, self).__init__()
        self._lhsSymbol = lhsSymbol
        self.rhs = rhsTerm
        self._isDeclaration = self._declaresLhs(lhsSymbol)
        self._isConst = isConst

    @staticmethod
    def _declaresLhs(lhs):
        """
        Decide whether assigning to lhs counts as a declaration.

        Field accesses, indexed objects and casts are not declarations. The
        constructor used to test ``IndexedBase`` and the ``lhs`` setter
        ``sp.Indexed``; both are accepted here so that the two paths agree.
        """
        isCast = str(lhs.func).lower() == 'cast' if hasattr(lhs, "func") else False
        return not (isCast or isinstance(lhs, (Field.Access, sp.Indexed, IndexedBase)))

    @property
    def lhs(self):
        return self._lhsSymbol

    @lhs.setter
    def lhs(self, newValue):
        self._lhsSymbol = newValue
        self._isDeclaration = self._declaresLhs(newValue)

    @property
    def args(self):
        """Children of this node: left hand side and right hand side."""
        return [self._lhsSymbol, self.rhs]

    @property
    def symbolsDefined(self):
        """The left hand side symbol if this assignment is a declaration, otherwise nothing."""
        if not self._isDeclaration:
            return set()
        return {self._lhsSymbol}

    @property
    def undefinedSymbols(self):
        """Symbols of both sides, plus one loop counter per offset of each field access on the rhs."""
        result = self.rhs.atoms(sp.Symbol)
        # Add loop counters if there a field accesses
        loopCounters = set()
        for symbol in result:
            if isinstance(symbol, Field.Access):
                for i in range(len(symbol.offsets)):
                    loopCounters.add(LoopOverCoordinate.getLoopCounterSymbol(i))
        result.update(loopCounters)
        result.update(self._lhsSymbol.atoms(sp.Symbol))
        return result

    @property
    def isDeclaration(self):
        return self._isDeclaration

    @property
    def isConst(self):
        return self._isConst

    def replace(self, child, replacement):
        """
        Replace the left or right hand side by replacement.

        :raises ValueError: if child is neither the lhs nor the rhs of this assignment
        """
        if child == self.lhs:
            replacement.parent = self
            self.lhs = replacement
        elif child == self.rhs:
            replacement.parent = self
            self.rhs = replacement
        else:
            # Report the node that was searched for, not the one offered as replacement.
            raise ValueError('%s is not in args of %s' % (child, self.__class__))

    def __repr__(self):
        return repr(self.lhs) + " = " + repr(self.rhs)


class TemporaryMemoryAllocation(Node):
    """AST node that defines a symbol together with a size."""

    def __init__(self, typedSymbol, size):
        """
        :param typedSymbol: symbol defined by this node
        :param size: size value; a sympy expression contributes its symbols as undefined symbols
        """
        # Initialise the base class so that ``self.parent`` always exists.
        super(TemporaryMemoryAllocation, self).__init__()
        self.symbol = typedSymbol
        self.size = size

    @property
    def symbolsDefined(self):
        """The allocated symbol."""
        return {self.symbol}

    @property
    def undefinedSymbols(self):
        """Symbols occurring in the size if it is a sympy expression, otherwise nothing."""
        if isinstance(self.size, sp.Basic):
            return self.size.atoms(sp.Symbol)
        else:
            return set()

    @property
    def args(self):
        """Children of this node: only the symbol."""
        return [self.symbol]
419
420
421
422


class TemporaryMemoryFree(Node):
    """AST node referring to a symbol; it neither defines nor requires any symbols."""

    def __init__(self, typedSymbol):
        """
        :param typedSymbol: symbol this node refers to
        """
        # Initialise the base class so that ``self.parent`` always exists.
        super(TemporaryMemoryFree, self).__init__()
        self.symbol = typedSymbol

    @property
    def symbolsDefined(self):
        return set()

    @property
    def undefinedSymbols(self):
        return set()

    @property
    def args(self):
        """This node has no children."""
        return []

Jan Hoenig's avatar
Jan Hoenig committed
437

Jan Hoenig's avatar
Jan Hoenig committed
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
# TODO implement defined & undefinedSymbols
class Conversion(Node):
    """AST node wrapping a single child together with a target dtype."""

    def __init__(self, child, dtype, parent=None):
        """
        :param child: the node being converted; stored as the only element of args
        :param dtype: target data type
        :param parent: parent node, or None
        """
        super(Conversion, self).__init__(parent)
        self._args = [child]
        self.dtype = dtype

    @property
    def args(self):
        """Returns all arguments/children of this node"""
        return self._args

    @args.setter
    def args(self, value):
        self._args = value

    @property
    def symbolsDefined(self):
        """Set of symbols which are defined by this node. """
        return set()

    @property
    def undefinedSymbols(self):
        """Symbols which are used but are not defined inside this node"""
        # ``raise set()`` is a TypeError (a set is not an exception); an empty
        # set is returned, matching symbolsDefined, until symbol analysis exists.
        return set()

    def __repr__(self):
        return '(%s(%s))' % (repr(self.dtype), repr(self.args[0].dtype)) + repr(self.args)
Jan Hoenig's avatar
Jan Hoenig committed
466

Jan Hoenig's avatar
Jan Hoenig committed
467
# TODO Pow
Jan Hoenig's avatar
Jan Hoenig committed
468

Jan Hoenig's avatar
Jan Hoenig committed
469
470
471
472

# Separator text keyed by Expr subclass name; Expr.__repr__ looks up the entry
# for the node's class name and joins the reprs of the node's args with it.
_expr_dict = {'Add': ' + ', 'Mul': ' * ', 'Pow': '**'}


Jan Hoenig's avatar
Jan Hoenig committed
473
474
475
476
class Expr(Node):
    """Expression node with a mutable list of arguments and a dtype that starts out as None."""

    def __init__(self, args, parent=None):
        """
        :param args: iterable of arguments; copied into a new list
        :param parent: parent node, or None
        """
        super(Expr, self).__init__(parent)
        self.dtype = None
        self._args = list(args)

    @property
    def args(self):
        """The argument list of this expression."""
        return self._args

    @args.setter
    def args(self, value):
        self._args = value

    def replace(self, child, replacements):
        """
        Replace child by a single node or by a list of nodes, at the same position.

        :raises ValueError: if child is not among the arguments
        """
        position = self.args.index(child)
        del self.args[position]

        if type(replacements) is not list:
            replacements.parent = self
            self.args.insert(position, replacements)
            return

        for newChild in replacements:
            newChild.parent = self
        self.args = self.args[:position] + replacements + self.args[position:]

    @property
    def symbolsDefined(self):
        return set()  # Todo fix for symbol analysis

    @property
    def undefinedSymbols(self):
        return set()  # Todo fix for symbol analysis

    def __repr__(self):
        # The separator is selected by the name of the concrete class.
        separator = _expr_dict[self.__class__.__name__]
        return separator.join(map(repr, self.args))
Jan Hoenig's avatar
Jan Hoenig committed
508

Jan Hoenig's avatar
Jan Hoenig committed
509
510
511
512
513
514
515
516
517
518
519
520
521
522

class Mul(Expr):
    """Expr subclass without own behaviour; its repr joins the args with ' * ' (see _expr_dict)."""
    pass


class Add(Expr):
    """Expr subclass without own behaviour; its repr joins the args with ' + ' (see _expr_dict)."""
    pass


class Pow(Expr):
    """Expr subclass without own behaviour; its repr joins the args with '**' (see _expr_dict)."""
    pass


class Indexed(Expr):
    """Expression node for an indexed access; the dtype is derived from the base's label."""

    def __init__(self, args, base, parent=None):
        """
        :param args: arguments of the expression; the repr uses the first two
        :param base: object whose ``label.dtype.baseType`` determines the dtype of this node
        :param parent: parent node, or None
        """
        super(Indexed, self).__init__(args, parent)
        self.base = base
        # The label carries the pointer type; the element type is its baseType.
        labelType = base.label.dtype
        self.dtype = createType(labelType.baseType)

    def __repr__(self):
        indexedObject, index = self.args[0], self.args[1]
        return '{!s}[{!s}]'.format(indexedObject, index)
Jan Hoenig's avatar
Jan Hoenig committed
531

532

533
534
535
536
537
538
539
540
541
542
543
class PointerArithmetic(Expr):
    """Expression node for a pointer plus an offset; the dtype is taken from the pointer."""

    def __init__(self, args, pointer, parent=None):
        """
        :param args: the offset; stored as ``self.offset`` and as first child
        :param pointer: the pointer; stored as ``self.pointer`` and as second child
        :param parent: parent node, or None
        """
        super(PointerArithmetic, self).__init__([args] + [pointer], parent)
        self.pointer = pointer
        self.offset = args
        self.dtype = pointer.dtype

    def __repr__(self):
        # Show the offset only: ``self.args`` is the child list [offset, pointer]
        # and would print the pointer a second time inside list brackets.
        return '*(%s + %s)' % (self.pointer, self.offset)


Jan Hoenig's avatar
Jan Hoenig committed
544
class Number(Node, sp.AtomicExpr):
    """Leaf node for a numeric value, with dtype and value obtained from get_type_from_sympy."""

    def __init__(self, number, parent=None):
        """
        :param number: number passed to get_type_from_sympy to obtain (dtype, value)
        :param parent: parent node, or None
        """
        super(Number, self).__init__(parent)

        self.dtype, self.value = get_type_from_sympy(number)
        self._args = tuple()

    @property
    def args(self):
        """Returns all arguments/children of this node"""
        return self._args

    @property
    def symbolsDefined(self):
        """Set of symbols which are defined by this node. """
        return set()

    @property
    def undefinedSymbols(self):
        """Symbols which are used but are not defined inside this node"""
        # ``raise set()`` is a TypeError (a set is not an exception); a number
        # uses no symbols, so the empty set is returned.
        return set()

    def __repr__(self):
        return repr(self.value)

    def __float__(self):
        return float(self.value)

    def __int__(self):
        return int(self.value)
Jan Hoenig's avatar
Jan Hoenig committed
574

Jan Hoenig's avatar
Jan Hoenig committed
575