astnodes.py 17.8 KB
Newer Older
1
import sympy as sp
2
from sympy.tensor import IndexedBase
3
from pystencils.field import Field
4
from pystencils.types import TypedSymbol, createType, get_type_from_sympy, createTypeFromString
5
6


Martin Bauer's avatar
Martin Bauer committed
7
8
9
10
11
12
13
14
15
16
17
18
19
class ResolvedFieldAccess(sp.Indexed):
    """Indexed access into a field's data after the access indices were linearized.

    On top of sympy's ``Indexed`` (``base`` plus the single linearized index) the object carries
    the originating ``field``, the access ``offsets`` and ``idxCoordinateValues``. All three take
    part in sympy equality/hashing through ``_hashable_content``.
    """

    def __new__(cls, base, linearizedIndex, field, offsets, idxCoordinateValues):
        access = super(ResolvedFieldAccess, cls).__new__(cls, base, linearizedIndex)
        # extra metadata is attached as plain attributes; sympy itself does not know about it
        access.field = field
        access.offsets = offsets
        access.idxCoordinateValues = idxCoordinateValues
        return access

    def _hashable_content(self):
        baseContent = super(ResolvedFieldAccess, self)._hashable_content()
        extraContent = tuple(self.offsets) + (repr(self.idxCoordinateValues), hash(self.field))
        return baseContent + extraContent

    def __getnewargs__(self):
        # arguments handed back to __new__ when the object is recreated (e.g. by pickle)
        linearizedIndex = self.indices[0]
        return self.base, linearizedIndex, self.field, self.offsets, self.idxCoordinateValues
Martin Bauer's avatar
Martin Bauer committed
21
22


23
class Node(object):
    """Base class for all AST nodes"""

    def __init__(self, parent=None):
        # enclosing node in the AST; None for a root / detached node
        self.parent = parent

    @property
    def args(self):
        """Returns all arguments/children of this node.

        A property (as in every subclass) so that ``atoms`` can iterate ``self.args`` directly.
        """
        return []

    @property
    def symbolsDefined(self):
        """Set of symbols which are defined by this node. """
        return set()

    @property
    def undefinedSymbols(self):
        """Symbols which are used but are not defined inside this node"""
        raise NotImplementedError()

    def atoms(self, argType):
        """
        Returns a set of all descendants (children, grandchildren, ...) which are an instance of argType.

        Children come from ``args``; the search recurses through each child's own ``atoms`` method,
        so children may be AST nodes as well as sympy expressions.
        """
        result = set()
        for arg in self.args:
            if isinstance(arg, argType):
                result.add(arg)
            result.update(arg.atoms(argType))
        return result


55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
class Conditional(Node):
    """If/else node: ``trueBlock`` runs when ``conditionExpr`` holds, otherwise the optional ``falseBlock``."""
    def __init__(self, conditionExpr, trueBlock, falseBlock=None):
        """
        Create a new conditional node

        :param conditionExpr: sympy relational expression
        :param trueBlock: block which is run if conditional is true
        :param falseBlock: block which is run if conditional is false, or None if not needed
        """
        super(Conditional, self).__init__(parent=None)
        assert conditionExpr.is_Boolean or conditionExpr.is_Relational
        self.conditionExpr = conditionExpr
        self.trueBlock = trueBlock
        self.falseBlock = falseBlock
        # link the branches back to this node, as the other container nodes do for their children
        self.trueBlock.parent = self
        if self.falseBlock is not None:
            self.falseBlock.parent = self

    @property
    def args(self):
        result = [self.conditionExpr, self.trueBlock]
        if self.falseBlock is not None:
            result.append(self.falseBlock)
        return result

    @property
    def symbolsDefined(self):
        return set()

    @property
    def undefinedSymbols(self):
        """Symbols used by either branch or by the condition."""
        # copy, so the set returned by the true branch is not modified in place
        result = set(self.trueBlock.undefinedSymbols)
        if self.falseBlock is not None:
            # update() works in place and returns None -- do not rebind result to its return value
            result.update(self.falseBlock.undefinedSymbols)
        result.update(self.conditionExpr.atoms(sp.Symbol))
        return result

    def __str__(self):
        return 'if:({!s}) '.format(self.conditionExpr)

    def __repr__(self):
        return 'if:({!r}) '.format(self.conditionExpr)


96
97
98
class KernelFunction(Node):
    """Root node of a kernel: a named function whose body is an AST.

    Parameters are not stored explicitly. They are derived on demand from the symbols the body uses
    but does not define, excluding ``globalVariables``.
    """

    class Argument:
        """One parameter of a KernelFunction, classified by its name prefix.

        Names starting with ``Field.DATA_PREFIX``, ``Field.SHAPE_PREFIX`` or ``Field.STRIDE_PREFIX``
        mark field data-pointer, shape or stride arguments. For those, ``field`` is the matching
        field among the ones accessed by the kernel; for all others it is None.
        """

        def __init__(self, name, dtype, symbol, kernelFunctionNode):
            from pystencils.transformations import symbolNameToVariableName
            self.name = name
            self.dtype = dtype
            self.isFieldPtrArgument = False
            self.isFieldShapeArgument = False
            self.isFieldStrideArgument = False
            self.isFieldArgument = False
            self.fieldName = ""
            self.coordinate = None
            self.symbol = symbol

            if name.startswith(Field.DATA_PREFIX):
                self.isFieldPtrArgument = True
                self.isFieldArgument = True
                self.fieldName = name[len(Field.DATA_PREFIX):]
            elif name.startswith(Field.SHAPE_PREFIX):
                self.isFieldShapeArgument = True
                self.isFieldArgument = True
                self.fieldName = name[len(Field.SHAPE_PREFIX):]
            elif name.startswith(Field.STRIDE_PREFIX):
                self.isFieldStrideArgument = True
                self.isFieldArgument = True
                self.fieldName = name[len(Field.STRIDE_PREFIX):]

            self.field = None
            if self.isFieldArgument:
                # fieldName is matched against symbolNameToVariableName(field.name) of every accessed field
                fieldMap = {symbolNameToVariableName(f.name): f for f in kernelFunctionNode.fieldsAccessed}
                self.field = fieldMap[self.fieldName]

        def __lt__(self, other):
            """Sort order: field pointers, field shapes, field strides, then the rest; ties by name."""
            def score(arg):
                if arg.isFieldPtrArgument:
                    return -4
                elif arg.isFieldShapeArgument:
                    return -3
                elif arg.isFieldStrideArgument:
                    return -2
                return 0

            return (score(self), self.name) < (score(other), other.name)

        def __repr__(self):
            return '<{0} {1}>'.format(self.dtype, self.name)

    def __init__(self, body, functionName="kernel"):
        super(KernelFunction, self).__init__()
        self._body = body
        body.parent = self
        self._parameters = None
        self.functionName = functionName
        # these variables are assumed to be global, so no automatic parameter is generated for them
        self.globalVariables = set()

    @property
    def symbolsDefined(self):
        return set()

    @property
    def undefinedSymbols(self):
        return set()

    @property
    def parameters(self):
        """Sorted list of Argument objects, recomputed from the body on every access."""
        self._updateParameters()
        return self._parameters

    @property
    def body(self):
        return self._body

    @property
    def args(self):
        return [self._body]

    @property
    def fieldsAccessed(self):
        """Set of Field instances: fields which are accessed inside this kernel function"""
        return set(o.field for o in self.atoms(ResolvedFieldAccess))

    def _updateParameters(self):
        undefinedSymbols = self._body.undefinedSymbols - self.globalVariables
        self._parameters = sorted(KernelFunction.Argument(s.name, s.dtype, s, self) for s in undefinedSymbols)

    def __str__(self):
        # ``self.parameters`` already refreshes the parameter list; no separate update call needed
        return '{0} {1}({2})\n{3}'.format(type(self).__name__, self.functionName, self.parameters,
                                          ("\t" + "\t".join(str(self.body).splitlines(True))))

    def __repr__(self):
        return '{0} {1}({2})'.format(type(self).__name__, self.functionName, self.parameters)
199

200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215

class Block(Node):
    """Ordered sequence of AST nodes; the block sets itself as ``parent`` of every node it holds."""

    def __init__(self, listOfNodes):
        # super(Block, ...) so that Node.__init__ actually runs and ``parent`` is initialized;
        # super(Node, ...) skipped Node.__init__ and left the block without a ``parent`` attribute
        super(Block, self).__init__()
        self._nodes = listOfNodes
        for n in self._nodes:
            n.parent = self

    @property
    def args(self):
        return self._nodes

    def insertFront(self, node):
        """Insert ``node`` as the first child."""
        node.parent = self
        self._nodes.insert(0, node)

    def insertBefore(self, newNode, insertBefore):
        """Insert ``newNode`` directly before the existing child ``insertBefore``.

        :raises ValueError: if ``insertBefore`` is not a child of this block; nothing is modified then
        """
        idx = self._nodes.index(insertBefore)
        newNode.parent = self
        self._nodes.insert(idx, newNode)

    def append(self, node):
        """Append ``node`` as the last child."""
        node.parent = self
        self._nodes.append(node)

    def takeChildNodes(self):
        """Remove all children from the block and return them as a list."""
        tmp = self._nodes
        self._nodes = []
        return tmp

    def replace(self, child, replacements):
        """Replace ``child`` by a single node, or by a list of nodes spliced in at its position.

        :raises ValueError: if ``child`` is not a child of this block
        """
        idx = self._nodes.index(child)
        del self._nodes[idx]
        if isinstance(replacements, list):
            for e in replacements:
                e.parent = self
            self._nodes = self._nodes[:idx] + replacements + self._nodes[idx:]
        else:
            replacements.parent = self
            self._nodes.insert(idx, replacements)

    @property
    def symbolsDefined(self):
        result = set()
        for a in self.args:
            result.update(a.symbolsDefined)
        return result

    @property
    def undefinedSymbols(self):
        """Symbols used by any child minus symbols defined by any child (statement order is not considered)."""
        result = set()
        definedSymbols = set()
        for a in self.args:
            result.update(a.undefinedSymbols)
            definedSymbols.update(a.symbolsDefined)
        return result - definedSymbols

    def __str__(self):
        return ''.join('{!s}\n'.format(node) for node in self._nodes)

    def __repr__(self):
        return ''.join('{!r}'.format(node) for node in self._nodes)
262

263
264
265
266
267
268
269
270
271
272

class PragmaBlock(Block):
    """Block that additionally stores a pragma line.

    Only ``pragmaLine`` is added here; how it is emitted is decided by the code printer.
    """

    def __init__(self, pragmaLine, listOfNodes):
        self.pragmaLine = pragmaLine
        super(PragmaBlock, self).__init__(listOfNodes)


class LoopOverCoordinate(Node):
    """Loop over a single coordinate with bounds ``start``/``stop`` and increment ``step``.

    The loop counter is the int-typed symbol named ``<LOOP_COUNTER_NAME_PREFIX>_<coordinate>``
    (see ``getLoopCounterSymbol``). ``prefixLines`` is a list of extra lines attached to the loop;
    it is copied by ``newLoopWithDifferentBody``.
    """
    LOOP_COUNTER_NAME_PREFIX = "ctr"

    def __init__(self, body, coordinateToLoopOver, start, stop, step=1):
        super(LoopOverCoordinate, self).__init__()
        self._body = body
        body.parent = self
        self._coordinateToLoopOver = coordinateToLoopOver
        self._begin = start
        self._end = stop
        self._increment = step
        self.prefixLines = []

    def newLoopWithDifferentBody(self, newBody):
        """Return a loop with the same coordinate, bounds and prefix lines, but around ``newBody``."""
        result = LoopOverCoordinate(newBody, self._coordinateToLoopOver, self._begin, self._end, self._increment)
        result.prefixLines = list(self.prefixLines)
        return result

    @property
    def args(self):
        """The body, followed by those bounds that are expressions (i.e. have ``args``)."""
        result = [self._body]
        for e in [self._begin, self._end, self._increment]:
            if hasattr(e, "args"):
                result.append(e)
        return result

    @property
    def body(self):
        return self._body

    @property
    def start(self):
        return self._begin

    @property
    def stop(self):
        return self._end

    @property
    def step(self):
        return self._increment

    @property
    def coordinateToLoopOver(self):
        return self._coordinateToLoopOver

    @property
    def symbolsDefined(self):
        return {self.loopCounterSymbol}

    @property
    def undefinedSymbols(self):
        """Symbols used in the body or in the bounds, except the loop counter itself."""
        result = set(self._body.undefinedSymbols)
        for possibleSymbol in [self._begin, self._end, self._increment]:
            if isinstance(possibleSymbol, (Node, sp.Basic)):
                result.update(possibleSymbol.atoms(sp.Symbol))
        return result - {self.loopCounterSymbol}

    @staticmethod
    def getLoopCounterName(coordinateToLoopOver):
        return "%s_%s" % (LoopOverCoordinate.LOOP_COUNTER_NAME_PREFIX, coordinateToLoopOver)

    @property
    def loopCounterName(self):
        return LoopOverCoordinate.getLoopCounterName(self.coordinateToLoopOver)

    @staticmethod
    def isLoopCounterSymbol(symbol):
        """Return the coordinate of ``symbol`` if it is a loop counter symbol, otherwise None.

        A loop counter is named ``<prefix>_<int>`` and has the 'int' dtype.
        """
        prefix = LoopOverCoordinate.LOOP_COUNTER_NAME_PREFIX + "_"
        if not symbol.name.startswith(prefix):
            return None
        # plain sympy symbols have no dtype and therefore cannot be (typed) loop counters
        dtype = getattr(symbol, 'dtype', None)
        if dtype is None or dtype != createTypeFromString('int'):
            return None
        try:
            return int(symbol.name[len(prefix):])
        except ValueError:
            # name matches the prefix but carries no integer coordinate, e.g. "ctr_x"
            return None

    @staticmethod
    def getLoopCounterSymbol(coordinateToLoopOver):
        return TypedSymbol(LoopOverCoordinate.getLoopCounterName(coordinateToLoopOver), 'int')

    @property
    def loopCounterSymbol(self):
        return LoopOverCoordinate.getLoopCounterSymbol(self.coordinateToLoopOver)

    @property
    def isOutermostLoop(self):
        from pystencils.transformations import getNextParentOfType
        return getNextParentOfType(self, LoopOverCoordinate) is None

    @property
    def isInnermostLoop(self):
        return len(self.atoms(LoopOverCoordinate)) == 0

    def __str__(self):
        return 'loop:{!s} in {!s}:{!s}:{!s}\n{!s}'.format(self.loopCounterName, self.start, self.stop, self.step,
                                                          ("\t" + "\t".join(str(self.body).splitlines(True))))

    def __repr__(self):
        return 'loop:{!s} in {!s}:{!s}:{!s}'.format(self.loopCounterName, self.start, self.stop, self.step)
373

374
375
376
377
378
379

class SympyAssignment(Node):
    """Assignment ``lhs = rhs`` where both sides are sympy expressions.

    The assignment counts as a declaration of ``lhs`` (see ``isDeclaration`` and ``symbolsDefined``)
    unless the left-hand side is a field access, an indexed expression or a cast.
    """

    def __init__(self, lhsSymbol, rhsTerm, isConst=True):
        super(SympyAssignment, self).__init__(parent=None)
        self._lhsSymbol = lhsSymbol
        self.rhs = rhsTerm
        self._isDeclaration = self._isDeclarationLhs(lhsSymbol)
        self._isConst = isConst

    @staticmethod
    def _isDeclarationLhs(lhs):
        """Return whether assigning to ``lhs`` declares a new variable.

        Field accesses, ``sp.Indexed`` / ``IndexedBase`` expressions and casts are not declarations.
        Shared by the constructor and the ``lhs`` setter so both classify a left-hand side the same way.
        """
        isCast = str(lhs.func).lower() == 'cast' if hasattr(lhs, "func") else False
        if isCast:
            return False
        return not isinstance(lhs, (Field.Access, sp.Indexed, IndexedBase))

    @property
    def lhs(self):
        return self._lhsSymbol

    @lhs.setter
    def lhs(self, newValue):
        self._lhsSymbol = newValue
        self._isDeclaration = self._isDeclarationLhs(newValue)

    @property
    def args(self):
        return [self._lhsSymbol, self.rhs]

    @property
    def symbolsDefined(self):
        if not self._isDeclaration:
            return set()
        return set([self._lhsSymbol])

    @property
    def undefinedSymbols(self):
        result = self.rhs.atoms(sp.Symbol)
        # Add loop counters if there a field accesses
        loopCounters = set()
        for symbol in result:
            if isinstance(symbol, Field.Access):
                for i in range(len(symbol.offsets)):
                    loopCounters.add(LoopOverCoordinate.getLoopCounterSymbol(i))
        result.update(loopCounters)
        result.update(self._lhsSymbol.atoms(sp.Symbol))
        return result

    @property
    def isDeclaration(self):
        return self._isDeclaration

    @property
    def isConst(self):
        return self._isConst

    def replace(self, child, replacement):
        """Replace ``child`` (either ``lhs`` or ``rhs``) by ``replacement``.

        :raises ValueError: if ``child`` is neither the lhs nor the rhs
        """
        if child == self.lhs:
            replacement.parent = self
            self.lhs = replacement
        elif child == self.rhs:
            replacement.parent = self
            self.rhs = replacement
        else:
            # report the node that was looked for, not the replacement
            raise ValueError('%s is not in args of %s' % (child, self.__class__))

    def __repr__(self):
        return repr(self.lhs) + " = " + repr(self.rhs)


class TemporaryMemoryAllocation(Node):
    """Allocation of temporary memory referenced by ``typedSymbol``.

    ``size`` may be a plain value or a sympy expression; in the latter case its symbols are reported
    as undefined symbols of this node.
    """

    def __init__(self, typedSymbol, size):
        # initialize ``parent`` like every other node
        super(TemporaryMemoryAllocation, self).__init__(parent=None)
        self.symbol = typedSymbol
        self.size = size

    @property
    def symbolsDefined(self):
        return {self.symbol}

    @property
    def undefinedSymbols(self):
        if isinstance(self.size, sp.Basic):
            return self.size.atoms(sp.Symbol)
        return set()

    @property
    def args(self):
        return [self.symbol]
461
462
463
464


class TemporaryMemoryFree(Node):
    """Release of the temporary memory referenced by ``typedSymbol``."""

    def __init__(self, typedSymbol):
        # initialize ``parent`` like every other node
        super(TemporaryMemoryFree, self).__init__(parent=None)
        self.symbol = typedSymbol

    @property
    def symbolsDefined(self):
        return set()

    @property
    def undefinedSymbols(self):
        # NOTE(review): ``symbol`` is not reported as used here -- confirm this is intended
        return set()

    @property
    def args(self):
        return []

Jan Hoenig's avatar
Jan Hoenig committed
479

Jan Hoenig's avatar
Jan Hoenig committed
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
# TODO implement proper symbol analysis (symbolsDefined / undefinedSymbols) for Conversion
class Conversion(Node):
    """Conversion of a single child node to ``dtype``."""

    def __init__(self, child, dtype, parent=None):
        super(Conversion, self).__init__(parent)
        self._args = [child]
        self.dtype = dtype

    @property
    def args(self):
        """Returns all arguments/children of this node"""
        return self._args

    @args.setter
    def args(self, value):
        self._args = value

    @property
    def symbolsDefined(self):
        """Set of symbols which are defined by this node. """
        return set()

    @property
    def undefinedSymbols(self):
        """Symbols which are used but are not defined inside this node.

        Symbol analysis is not implemented yet: an empty set is returned, so symbols used inside
        the converted child are not reported.
        """
        # must be *returned*: ``raise set()`` fails with TypeError (a set is not an exception)
        return set()

    def __repr__(self):
        return '(%s(%s))' % (repr(self.dtype), repr(self.args[0].dtype)) + repr(self.args)
Jan Hoenig's avatar
Jan Hoenig committed
508

Jan Hoenig's avatar
Jan Hoenig committed
509
# TODO Pow
Jan Hoenig's avatar
Jan Hoenig committed
510

Jan Hoenig's avatar
Jan Hoenig committed
511
512
513
514

# separator placed between the reprs of an expression node's args, keyed by the node's class name
_expr_dict = dict(Add=' + ', Mul=' * ', Pow='**')


Jan Hoenig's avatar
Jan Hoenig committed
515
516
517
518
class Expr(Node):
    """Generic n-ary expression node; the operator is selected by the subclass name in ``__repr__``."""

    def __init__(self, args, parent=None):
        super(Expr, self).__init__(parent)
        self._args = list(args)
        self.dtype = None  # not determined here; subclasses may set it

    @property
    def args(self):
        return self._args

    @args.setter
    def args(self, value):
        self._args = value

    def replace(self, child, replacements):
        """Replace ``child`` by a single node, or by a list of nodes spliced in at its position."""
        position = self.args.index(child)
        del self.args[position]
        if type(replacements) is not list:
            replacements.parent = self
            self.args.insert(position, replacements)
            return
        for replacement in replacements:
            replacement.parent = self
        self.args = self.args[:position] + replacements + self.args[position:]

    @property
    def symbolsDefined(self):
        return set()  # Todo fix for symbol analysis

    @property
    def undefinedSymbols(self):
        return set()  # Todo fix for symbol analysis

    def __repr__(self):
        separator = _expr_dict[self.__class__.__name__]
        return separator.join(repr(arg) for arg in self.args)
Jan Hoenig's avatar
Jan Hoenig committed
550

Jan Hoenig's avatar
Jan Hoenig committed
551
552
553
554
555
556
557
558
559
560
561
562
563
564

class Mul(Expr):
    """Multiplication node; its repr joins ``args`` with ``' * '``."""


class Add(Expr):
    """Addition node; its repr joins ``args`` with ``' + '``."""


class Pow(Expr):
    """Power node; its repr joins ``args`` with ``'**'``."""


class Indexed(Expr):
    """Element access printed as ``args[0][args[1]]``; ``dtype`` is the element type of ``base``."""

    def __init__(self, args, base, parent=None):
        super(Indexed, self).__init__(args, parent)
        self.base = base
        # the label's dtype is a pointer type: strip the pointer to get the element type
        pointerType = base.label.dtype
        self.dtype = createType(pointerType.baseType)

    def __repr__(self):
        array, index = self.args[0], self.args[1]
        return '%s[%s]' % (array, index)
Jan Hoenig's avatar
Jan Hoenig committed
573

574

575
576
577
578
579
580
581
582
583
584
585
class PointerArithmetic(Expr):
    """Dereference of ``pointer`` shifted by an offset, printed as ``*(pointer + offset)``.

    :param args: the offset expression (stored as ``offset`` and as ``args[0]``)
    :param pointer: the pointer expression (stored as ``pointer`` and as ``args[1]``); its dtype
                    becomes the dtype of this node
    """

    def __init__(self, args, pointer, parent=None):
        super(PointerArithmetic, self).__init__([args, pointer], parent)
        self.pointer = pointer
        self.offset = args
        self.dtype = pointer.dtype

    def __repr__(self):
        # print the offset only; formatting ``self.args`` would print the whole [offset, pointer] list
        return '*(%s + %s)' % (self.pointer, self.offset)


Jan Hoenig's avatar
Jan Hoenig committed
586
class Number(Node, sp.AtomicExpr):
    """Numeric literal leaf node; ``dtype`` and ``value`` are obtained via ``get_type_from_sympy``."""

    def __init__(self, number, parent=None):
        super(Number, self).__init__(parent)
        self.dtype, self.value = get_type_from_sympy(number)
        self._args = tuple()

    @property
    def args(self):
        """Returns all arguments/children of this node"""
        return self._args

    @property
    def symbolsDefined(self):
        """Set of symbols which are defined by this node. """
        return set()

    @property
    def undefinedSymbols(self):
        """Symbols which are used but are not defined inside this node (a literal uses none)."""
        # must be *returned*: ``raise set()`` fails with TypeError (a set is not an exception)
        return set()

    def __repr__(self):
        return repr(self.value)

    def __float__(self):
        return float(self.value)

    def __int__(self):
        return int(self.value)
Jan Hoenig's avatar
Jan Hoenig committed
616

Jan Hoenig's avatar
Jan Hoenig committed
617