astnodes.py 16.8 KB
Newer Older
1
import sympy as sp
2
from sympy.tensor import IndexedBase
3
from pystencils.field import Field
4
from pystencils.types import TypedSymbol, createType, get_type_from_sympy, createTypeFromString
5
6


7
class Node(object):
    """Base class for all AST nodes.

    Subclasses override ``args``, ``symbolsDefined`` and ``undefinedSymbols``.
    ``args`` is a property here, as in every subclass, so generic code such as
    ``atoms`` can always iterate ``self.args``.
    """

    def __init__(self, parent=None):
        """
        :param parent: enclosing AST node, or None for a detached / root node
        """
        self.parent = parent

    @property
    def args(self):
        """Returns all arguments/children of this node (none for the base class)"""
        return []

    @property
    def symbolsDefined(self):
        """Set of symbols which are defined by this node. """
        return set()

    @property
    def undefinedSymbols(self):
        """Symbols which are used but are not defined inside this node"""
        raise NotImplementedError()

    def atoms(self, argType):
        """
        Returns a set of all children which are an instance of the given argType.

        The search is recursive: every child's own ``atoms(argType)`` is merged in,
        so each child must provide an ``atoms`` method (AST nodes and sympy
        expressions both do).
        """
        result = set()
        for arg in self.args:
            if isinstance(arg, argType):
                result.add(arg)
            result.update(arg.atoms(argType))
        return result


39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
class Conditional(Node):
    """If/else node: ``trueBlock`` runs when ``conditionExpr`` holds, otherwise the optional ``falseBlock``."""
    def __init__(self, conditionExpr, trueBlock, falseBlock=None):
        """
        Create a new conditional node

        :param conditionExpr: sympy relational expression
        :param trueBlock: block which is run if conditional is true
        :param falseBlock: block which is run if conditional is false, or None if not needed
        """
        assert conditionExpr.is_Boolean or conditionExpr.is_Relational
        super(Conditional, self).__init__(parent=None)
        self.conditionExpr = conditionExpr
        self.trueBlock = trueBlock
        self.falseBlock = falseBlock
        # keep the tree consistent, like KernelFunction / LoopOverCoordinate do for their bodies
        self.trueBlock.parent = self
        if self.falseBlock is not None:
            self.falseBlock.parent = self

    @property
    def args(self):
        result = [self.conditionExpr, self.trueBlock]
        if self.falseBlock is not None:
            result.append(self.falseBlock)
        return result

    @property
    def symbolsDefined(self):
        return set()

    @property
    def undefinedSymbols(self):
        """Union of the undefined symbols of both branches and the symbols in the condition."""
        # copy so the branch's own result set is never mutated
        result = set(self.trueBlock.undefinedSymbols)
        if self.falseBlock is not None:
            # set.update works in place and returns None - do not rebind result
            result.update(self.falseBlock.undefinedSymbols)
        result.update(self.conditionExpr.atoms(sp.Symbol))
        return result

    def __str__(self):
        return 'if:({!s}) '.format(self.conditionExpr)

    def __repr__(self):
        return 'if:({!r}) '.format(self.conditionExpr)


80
81
82
class KernelFunction(Node):
    """Root node of a kernel.

    Holds the body block. The parameter list is derived on demand from the
    body's undefined symbols, excluding ``globalVariables``.
    """

    class Argument:
        """One kernel parameter, built from an undefined symbol of the kernel body.

        Field pointer/shape/stride arguments are recognised by the
        ``Field.DATA_PREFIX`` / ``SHAPE_PREFIX`` / ``STRIDE_PREFIX`` name prefixes
        and linked to the matching accessed field via ``self.field``.
        """

        def __init__(self, name, dtype, kernelFunctionNode):
            # local import, presumably to avoid a circular import with pystencils.transformations
            from pystencils.transformations import symbolNameToVariableName
            self.name = name
            self.dtype = dtype
            self.isFieldPtrArgument = False
            self.isFieldShapeArgument = False
            self.isFieldStrideArgument = False
            self.isFieldArgument = False
            self.fieldName = ""
            self.coordinate = None

            if name.startswith(Field.DATA_PREFIX):
                self.isFieldPtrArgument = True
                self.isFieldArgument = True
                self.fieldName = name[len(Field.DATA_PREFIX):]
            elif name.startswith(Field.SHAPE_PREFIX):
                self.isFieldShapeArgument = True
                self.isFieldArgument = True
                self.fieldName = name[len(Field.SHAPE_PREFIX):]
            elif name.startswith(Field.STRIDE_PREFIX):
                self.isFieldStrideArgument = True
                self.isFieldArgument = True
                self.fieldName = name[len(Field.STRIDE_PREFIX):]

            self.field = None
            if self.isFieldArgument:
                fieldMap = {symbolNameToVariableName(f.name): f for f in kernelFunctionNode.fieldsAccessed}
                self.field = fieldMap[self.fieldName]

        def __repr__(self):
            return '<{0} {1}>'.format(self.dtype, self.name)

    def __init__(self, body, fieldsAccessed, functionName="kernel"):
        """
        :param body: block forming the kernel body; its parent is set to this node
        :param fieldsAccessed: Field instances accessed inside the kernel
        :param functionName: name of the generated function
        """
        super(KernelFunction, self).__init__()
        self._body = body
        body.parent = self
        self._parameters = None
        self.functionName = functionName
        self._fieldsAccessed = fieldsAccessed
        # these variables are assumed to be global, so no automatic parameter is generated for them
        self.globalVariables = set()

    @property
    def symbolsDefined(self):
        return set()

    @property
    def undefinedSymbols(self):
        return set()

    @property
    def parameters(self):
        """List of Argument objects, recomputed from the current body on every access."""
        self._updateParameters()
        return self._parameters

    @property
    def body(self):
        return self._body

    @property
    def args(self):
        return [self._body]

    @property
    def fieldsAccessed(self):
        """Set of Field instances: fields which are accessed inside this kernel function"""
        return self._fieldsAccessed

    def _updateParameters(self):
        """Rebuild ``_parameters`` from the body's undefined symbols, minus global variables."""
        undefinedSymbols = self._body.undefinedSymbols - self.globalVariables
        self._parameters = [KernelFunction.Argument(s.name, s.dtype, self) for s in undefinedSymbols]
        # sort so that the parameter order is deterministic (sets are unordered)
        self._parameters.sort(key=lambda l: (l.fieldName, l.isFieldPtrArgument, l.isFieldShapeArgument,
                                             l.isFieldStrideArgument, l.name),
                              reverse=True)

    def __str__(self):
        # self.parameters already refreshes the parameter list
        return '{0} {1}({2})\n{3}'.format(type(self).__name__, self.functionName, self.parameters,
                                          ("\t" + "\t".join(str(self.body).splitlines(True))))

    def __repr__(self):
        return '{0} {1}({2})'.format(type(self).__name__, self.functionName, self.parameters)
167

168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183

class Block(Node):
    """Ordered sequence of child nodes; every child's ``parent`` points to the block."""
    def __init__(self, listOfNodes):
        """
        :param listOfNodes: list of child nodes; it is stored (not copied) and mutated by the insert methods
        """
        # must be super(Block, ...) so that Node.__init__ runs and ``parent`` exists
        super(Block, self).__init__()
        self._nodes = listOfNodes
        for n in self._nodes:
            n.parent = self

    @property
    def args(self):
        return self._nodes

    def insertFront(self, node):
        """Insert ``node`` as first child."""
        node.parent = self
        self._nodes.insert(0, node)

    def insertBefore(self, newNode, insertBefore):
        """Insert ``newNode`` directly before the existing child ``insertBefore`` (ValueError if absent)."""
        newNode.parent = self
        idx = self._nodes.index(insertBefore)
        self._nodes.insert(idx, newNode)

    def append(self, node):
        """Append ``node`` as last child."""
        node.parent = self
        self._nodes.append(node)

    def takeChildNodes(self):
        """Remove and return all children; the children's ``parent`` attributes are left unchanged."""
        tmp = self._nodes
        self._nodes = []
        return tmp

    def replace(self, child, replacements):
        """Replace ``child`` by a single node or by a list of nodes, keeping its position."""
        idx = self._nodes.index(child)
        del self._nodes[idx]
        if isinstance(replacements, list):
            for e in replacements:
                e.parent = self
            # the list case rebinds _nodes to a new list instead of mutating in place
            self._nodes = self._nodes[:idx] + replacements + self._nodes[idx:]
        else:
            replacements.parent = self
            self._nodes.insert(idx, replacements)

    @property
    def symbolsDefined(self):
        result = set()
        for a in self.args:
            result.update(a.symbolsDefined)
        return result

    @property
    def undefinedSymbols(self):
        """Symbols used by any child that no child of this block defines."""
        result = set()
        definedSymbols = set()
        for a in self.args:
            result.update(a.undefinedSymbols)
            definedSymbols.update(a.symbolsDefined)
        return result - definedSymbols

    def __str__(self):
        return ''.join('{!s}\n'.format(node) for node in self._nodes)

    def __repr__(self):
        return ''.join('{!r}'.format(node) for node in self._nodes)
230

231
232
233
234
235
236
237
238
239
240

class PragmaBlock(Block):
    """A Block that additionally stores a ``pragmaLine`` string (presumably emitted in front of it by the code generator)."""
    def __init__(self, pragmaLine, listOfNodes):
        # Block.__init__ does not touch pragmaLine, so it can be set first
        self.pragmaLine = pragmaLine
        super(PragmaBlock, self).__init__(listOfNodes)


class LoopOverCoordinate(Node):
    """Loop over one coordinate with an int counter named ``ctr_<coordinate>``, from ``start`` to ``stop`` by ``step``."""
    LOOP_COUNTER_NAME_PREFIX = "ctr"

    def __init__(self, body, coordinateToLoopOver, start, stop, step=1):
        """
        :param body: loop body node; its parent is set to this loop
        :param coordinateToLoopOver: coordinate index used to name the loop counter
        :param start: loop begin (number or sympy expression)
        :param stop: loop end (number or sympy expression)
        :param step: loop increment
        """
        super(LoopOverCoordinate, self).__init__(parent=None)
        self._body = body
        body.parent = self
        self._coordinateToLoopOver = coordinateToLoopOver
        self._begin = start
        self._end = stop
        self._increment = step
        self.prefixLines = []

    def newLoopWithDifferentBody(self, newBody):
        """Copy of this loop (same range and prefix lines) around ``newBody``."""
        result = LoopOverCoordinate(newBody, self._coordinateToLoopOver, self._begin, self._end, self._increment)
        result.prefixLines = list(self.prefixLines)
        return result

    @property
    def args(self):
        result = [self._body]
        # plain numbers have no children; only expression-like bounds count as args
        for e in [self._begin, self._end, self._increment]:
            if hasattr(e, "args"):
                result.append(e)
        return result

    @property
    def body(self):
        return self._body

    @property
    def start(self):
        return self._begin

    @property
    def stop(self):
        return self._end

    @property
    def step(self):
        return self._increment

    @property
    def coordinateToLoopOver(self):
        return self._coordinateToLoopOver

    @property
    def symbolsDefined(self):
        return {self.loopCounterSymbol}

    @property
    def undefinedSymbols(self):
        """Symbols used in the body or in the loop bounds, except the loop counter itself."""
        result = set(self._body.undefinedSymbols)
        for possibleSymbol in [self._begin, self._end, self._increment]:
            if isinstance(possibleSymbol, (Node, sp.Basic)):
                result.update(possibleSymbol.atoms(sp.Symbol))
        return result - {self.loopCounterSymbol}

    @staticmethod
    def getLoopCounterName(coordinateToLoopOver):
        return "%s_%s" % (LoopOverCoordinate.LOOP_COUNTER_NAME_PREFIX, coordinateToLoopOver)

    @property
    def loopCounterName(self):
        return LoopOverCoordinate.getLoopCounterName(self.coordinateToLoopOver)

    @staticmethod
    def isLoopCounterSymbol(symbol):
        """
        Returns the coordinate of a loop counter symbol as created by getLoopCounterSymbol, or None
        if ``symbol`` is not such a counter (wrong name pattern, no/other dtype, non-integer suffix).
        """
        prefix = LoopOverCoordinate.LOOP_COUNTER_NAME_PREFIX + "_"
        if not symbol.name.startswith(prefix):
            return None
        dtype = getattr(symbol, 'dtype', None)  # plain sympy symbols carry no dtype
        if dtype is None or dtype != createTypeFromString('int'):
            return None
        try:
            return int(symbol.name[len(prefix):])
        except ValueError:
            # e.g. "ctr_x": matches the prefix but is not a coordinate counter
            return None

    @staticmethod
    def getLoopCounterSymbol(coordinateToLoopOver):
        return TypedSymbol(LoopOverCoordinate.getLoopCounterName(coordinateToLoopOver), 'int')

    @property
    def loopCounterSymbol(self):
        return LoopOverCoordinate.getLoopCounterSymbol(self.coordinateToLoopOver)

    @property
    def isOutermostLoop(self):
        from pystencils.transformations import getNextParentOfType
        return getNextParentOfType(self, LoopOverCoordinate) is None

    @property
    def isInnermostLoop(self):
        return len(self.atoms(LoopOverCoordinate)) == 0

    def __str__(self):
        return 'loop:{!s} in {!s}:{!s}:{!s}\n{!s}'.format(self.loopCounterName, self.start, self.stop, self.step,
                                                          ("\t" + "\t".join(str(self.body).splitlines(True))))

    def __repr__(self):
        return 'loop:{!s} in {!s}:{!s}:{!s}'.format(self.loopCounterName, self.start, self.stop, self.step)
341

342
343
344
345
346
347

class SympyAssignment(Node):
    """Assignment ``lhs = rhs`` between sympy expressions.

    The assignment counts as a declaration of ``lhs`` unless the lhs is a field
    access, an indexed expression or a ``cast`` call.
    """
    def __init__(self, lhsSymbol, rhsTerm, isConst=True):
        """
        :param lhsSymbol: assigned symbol / field access / indexed expression
        :param rhsTerm: sympy expression assigned to the lhs
        :param isConst: stored as ``isConst`` (presumably marks a const declaration - confirm with the printer)
        """
        super(SympyAssignment, self).__init__(parent=None)
        self._lhsSymbol = lhsSymbol
        self.rhs = rhsTerm
        self._isDeclaration = True
        self._updateIsDeclaration()
        self._isConst = isConst

    def _updateIsDeclaration(self):
        """Recompute ``_isDeclaration`` from the current lhs; shared by __init__ and the lhs setter."""
        lhs = self._lhsSymbol
        isCast = str(lhs.func).lower() == 'cast' if hasattr(lhs, "func") else False
        # both Indexed (a[i]) and IndexedBase are accepted, since __init__ and the setter previously disagreed
        isWriteToExisting = isinstance(lhs, (Field.Access, sp.Indexed, IndexedBase)) or isCast
        self._isDeclaration = not isWriteToExisting

    @property
    def lhs(self):
        return self._lhsSymbol

    @lhs.setter
    def lhs(self, newValue):
        self._lhsSymbol = newValue
        self._updateIsDeclaration()

    @property
    def args(self):
        return [self._lhsSymbol, self.rhs]

    @property
    def symbolsDefined(self):
        if not self._isDeclaration:
            return set()
        return {self._lhsSymbol}

    @property
    def undefinedSymbols(self):
        """Symbols of rhs and lhs, plus loop counters implied by field accesses on the rhs."""
        result = self.rhs.atoms(sp.Symbol)
        # Add loop counters if there are field accesses
        loopCounters = set()
        for symbol in result:
            if isinstance(symbol, Field.Access):
                for i in range(len(symbol.offsets)):
                    loopCounters.add(LoopOverCoordinate.getLoopCounterSymbol(i))
        result.update(loopCounters)
        result.update(self._lhsSymbol.atoms(sp.Symbol))
        return result

    @property
    def isDeclaration(self):
        return self._isDeclaration

    @property
    def isConst(self):
        return self._isConst

    def replace(self, child, replacement):
        """Replace lhs or rhs (whichever equals ``child``); ValueError if ``child`` is neither."""
        if child == self.lhs:
            replacement.parent = self
            self.lhs = replacement
        elif child == self.rhs:
            replacement.parent = self
            self.rhs = replacement
        else:
            raise ValueError('%s is not in args of %s' % (child, self.__class__))

    def __repr__(self):
        return repr(self.lhs) + " = " + repr(self.rhs)


class TemporaryMemoryAllocation(Node):
    """Node marking the allocation of temporary memory of ``size`` bound to ``symbol``; it defines ``symbol``."""
    def __init__(self, typedSymbol, size):
        """
        :param typedSymbol: symbol that receives the allocated memory
        :param size: number or sympy expression; its symbols are reported as undefined
        """
        super(TemporaryMemoryAllocation, self).__init__(parent=None)
        self.symbol = typedSymbol
        self.size = size

    @property
    def symbolsDefined(self):
        return {self.symbol}

    @property
    def undefinedSymbols(self):
        if isinstance(self.size, sp.Basic):
            return self.size.atoms(sp.Symbol)
        else:
            return set()

    @property
    def args(self):
        return [self.symbol]
429
430
431
432


class TemporaryMemoryFree(Node):
    """Node marking the release of temporary memory bound to ``symbol``; defines and uses no symbols."""
    def __init__(self, typedSymbol):
        """
        :param typedSymbol: symbol whose temporary memory is freed
        """
        super(TemporaryMemoryFree, self).__init__(parent=None)
        self.symbol = typedSymbol

    @property
    def symbolsDefined(self):
        return set()

    @property
    def undefinedSymbols(self):
        return set()

    @property
    def args(self):
        return []

Jan Hoenig's avatar
Jan Hoenig committed
447

Jan Hoenig's avatar
Jan Hoenig committed
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
# TODO implement defined & undefinedSymbols
class Conversion(Node):
    """Conversion of a single child expression to ``dtype``."""
    def __init__(self, child, dtype, parent=None):
        """
        :param child: expression being converted; it becomes the only entry of ``args``
        :param dtype: target type
        :param parent: enclosing node
        """
        super(Conversion, self).__init__(parent)
        self._args = [child]
        self.dtype = dtype

    @property
    def args(self):
        """Returns all arguments/children of this node"""
        return self._args

    @args.setter
    def args(self, value):
        self._args = value

    @property
    def symbolsDefined(self):
        """Set of symbols which are defined by this node. """
        return set()

    @property
    def undefinedSymbols(self):
        """Symbols which are used but are not defined inside this node"""
        # was `raise set()`, which raised TypeError; empty placeholder like Expr until symbol analysis exists
        return set()

    def __repr__(self):
        return '(%s(%s))' % (repr(self.dtype), repr(self.args[0].dtype)) + repr(self.args)
Jan Hoenig's avatar
Jan Hoenig committed
476

Jan Hoenig's avatar
Jan Hoenig committed
477
# TODO Pow
Jan Hoenig's avatar
Jan Hoenig committed
478

Jan Hoenig's avatar
Jan Hoenig committed
479
480
481
482

_expr_dict = {'Add': ' + ', 'Mul': ' * ', 'Pow': '**'}


Jan Hoenig's avatar
Jan Hoenig committed
483
484
485
486
class Expr(Node):
    """Generic expression node whose children live in the mutable list ``args``."""
    def __init__(self, args, parent=None):
        """
        :param args: iterable of child expressions (copied into a list)
        :param parent: enclosing node
        """
        super(Expr, self).__init__(parent)
        self._args = list(args)
        self.dtype = None

    @property
    def args(self):
        return self._args

    @args.setter
    def args(self, value):
        self._args = value

    def replace(self, child, replacements):
        """Replace ``child`` by one node, or by a list of nodes spliced in at its position."""
        position = self.args.index(child)
        del self.args[position]
        if type(replacements) is not list:
            replacements.parent = self
            self.args.insert(position, replacements)
            return
        for replacement in replacements:
            replacement.parent = self
        self.args = self.args[:position] + replacements + self.args[position:]

    @property
    def symbolsDefined(self):
        return set()  # Todo fix for symbol analysis

    @property
    def undefinedSymbols(self):
        return set()  # Todo fix for symbol analysis

    def __repr__(self):
        # the separator is looked up by class name, so only Add / Mul / Pow (and overriders) can be printed
        separator = _expr_dict[self.__class__.__name__]
        return separator.join(map(repr, self.args))
Jan Hoenig's avatar
Jan Hoenig committed
518

Jan Hoenig's avatar
Jan Hoenig committed
519
520
521
522
523
524
525
526
527
528
529
530
531
532

class Mul(Expr):
    """Product node; its repr joins the children with ' * '."""


class Add(Expr):
    """Sum node; its repr joins the children with ' + '."""


class Pow(Expr):
    """Power node; its repr joins the children with '**'."""


class Indexed(Expr):
    """Subscript node printed as ``args[0][args[1]]``; ``dtype`` is the element type of ``base``'s label."""
    def __init__(self, args, base, parent=None):
        super(Indexed, self).__init__(args, parent)
        self.base = base
        # the label's dtype is a pointer type; its base type is the element type
        elementType = base.label.dtype.baseType
        self.dtype = createType(elementType)

    def __repr__(self):
        container, subscript = self.args[0], self.args[1]
        return '%s[%s]' % (container, subscript)
Jan Hoenig's avatar
Jan Hoenig committed
541

542

543
544
545
546
547
548
549
550
551
552
553
class PointerArithmetic(Expr):
    """Pointer-plus-offset node, printed as ``*(pointer + offset)``; children are ``[offset, pointer]``."""
    def __init__(self, args, pointer, parent=None):
        """
        :param args: the offset expression (despite the name, a single child)
        :param pointer: pointer expression; its dtype becomes this node's dtype
        :param parent: enclosing node
        """
        super(PointerArithmetic, self).__init__([args, pointer], parent)
        self.pointer = pointer
        self.offset = args
        self.dtype = pointer.dtype

    def __repr__(self):
        # print the offset, not the whole children list (which also contains the pointer)
        return '*(%s + %s)' % (self.pointer, self.offset)


Jan Hoenig's avatar
Jan Hoenig committed
554
class Number(Node, sp.AtomicExpr):
    """Numeric literal leaf; ``dtype`` and ``value`` are obtained from ``get_type_from_sympy``."""
    def __init__(self, number, parent=None):
        """
        :param number: sympy number (or value accepted by get_type_from_sympy)
        :param parent: enclosing node
        """
        super(Number, self).__init__(parent)

        self.dtype, self.value = get_type_from_sympy(number)
        self._args = tuple()

    @property
    def args(self):
        """Returns all arguments/children of this node"""
        return self._args

    @property
    def symbolsDefined(self):
        """Set of symbols which are defined by this node. """
        return set()

    @property
    def undefinedSymbols(self):
        """Symbols which are used but are not defined inside this node"""
        # a literal uses no symbols (was `raise set()`, which raised TypeError)
        return set()

    def __repr__(self):
        return repr(self.value)

    def __float__(self):
        return float(self.value)

    def __int__(self):
        return int(self.value)
Jan Hoenig's avatar
Jan Hoenig committed
584

Jan Hoenig's avatar
Jan Hoenig committed
585