# astnodes.py
1
import sympy as sp
2
from sympy.tensor import IndexedBase
3
from pystencils.field import Field
4
from pystencils.types import TypedSymbol, createType, get_type_from_sympy, createTypeFromString
5
6


class ResolvedFieldAccess(sp.Indexed):
    """Field access that has been resolved to an indexed access ``base[linearizedIndex]``.

    In addition to the sympy ``Indexed`` data it stores the originating ``field``, the access
    ``offsets`` and the ``idxCoordinateValues``. All three are part of the hashable content, so two
    accesses that differ only in these attributes are distinct sympy objects.
    """

    def __new__(cls, base, linearizedIndex, field, offsets, idxCoordinateValues):
        obj = super(ResolvedFieldAccess, cls).__new__(cls, base, linearizedIndex)
        obj.field = field
        obj.offsets = offsets
        obj.idxCoordinateValues = idxCoordinateValues
        return obj

    def _hashable_content(self):
        superClassContents = super(ResolvedFieldAccess, self)._hashable_content()
        return superClassContents + tuple(self.offsets) + (repr(self.idxCoordinateValues), hash(self.field))

    def __getnewargs__(self):
        # Must mirror the ``__new__`` signature. The first argument is the IndexedBase itself;
        # ``self.name`` (previously used) is only the string form of the whole access.
        return self.base, self.indices[0], self.field, self.offsets, self.idxCoordinateValues


class Node(object):
    """Base class for all AST nodes"""

    def __init__(self, parent=None):
        self.parent = parent

    @property
    def args(self):
        """Returns all arguments/children of this node (a property, as in all subclasses)"""
        return []

    @property
    def symbolsDefined(self):
        """Set of symbols which are defined by this node. """
        return set()

    @property
    def undefinedSymbols(self):
        """Symbols which are used but are not defined inside this node"""
        raise NotImplementedError()

    def atoms(self, argType):
        """
        Returns a set of all children (searched recursively) which are an instance of the given argType
        """
        result = set()
        for arg in self.args:
            if isinstance(arg, argType):
                result.add(arg)
            result.update(arg.atoms(argType))
        return result


class Conditional(Node):
    """Conditional: ``trueBlock`` guarded by ``conditionExpr``, with an optional ``falseBlock``"""
    def __init__(self, conditionExpr, trueBlock, falseBlock=None):
        """
        Create a new conditional node

        :param conditionExpr: sympy relational expression
        :param trueBlock: block which is run if conditional is true
        :param falseBlock: block which is run if conditional is false, or None if not needed
        """
        super(Conditional, self).__init__(parent=None)
        assert conditionExpr.is_Boolean or conditionExpr.is_Relational
        self.conditionExpr = conditionExpr
        self.trueBlock = trueBlock
        self.falseBlock = falseBlock
        # the conditional is the parent of its sub-blocks
        trueBlock.parent = self
        if falseBlock is not None:
            falseBlock.parent = self

    @property
    def args(self):
        result = [self.conditionExpr, self.trueBlock]
        if self.falseBlock is not None:
            result.append(self.falseBlock)
        return result

    @property
    def symbolsDefined(self):
        return set()

    @property
    def undefinedSymbols(self):
        # copy, so the set returned by the true block is not mutated
        result = set(self.trueBlock.undefinedSymbols)
        if self.falseBlock is not None:
            # previously ``result = result.update(...)``, which rebound result to None
            result.update(self.falseBlock.undefinedSymbols)
        result.update(self.conditionExpr.atoms(sp.Symbol))
        return result

    def __str__(self):
        return 'if:({!s}) '.format(self.conditionExpr)

    def __repr__(self):
        return 'if:({!r}) '.format(self.conditionExpr)


class KernelFunction(Node):
    """Root node of a kernel: wraps a body and derives the kernel's parameter list from it.

    Parameters are all symbols the body uses but does not define, minus ``globalVariables``.
    """

    class Argument:
        """One kernel parameter, classified by the Field data/shape/stride name prefix of its name."""

        def __init__(self, name, dtype, kernelFunctionNode):
            from pystencils.transformations import symbolNameToVariableName
            self.name = name
            self.dtype = dtype
            self.isFieldPtrArgument = False
            self.isFieldShapeArgument = False
            self.isFieldStrideArgument = False
            self.isFieldArgument = False
            self.fieldName = ""
            self.coordinate = None

            # Checked in order; the first matching prefix wins
            prefixToFlag = ((Field.DATA_PREFIX, 'isFieldPtrArgument'),
                            (Field.SHAPE_PREFIX, 'isFieldShapeArgument'),
                            (Field.STRIDE_PREFIX, 'isFieldStrideArgument'))
            for prefix, flagAttribute in prefixToFlag:
                if name.startswith(prefix):
                    setattr(self, flagAttribute, True)
                    self.isFieldArgument = True
                    self.fieldName = name[len(prefix):]
                    break

            # Field arguments are linked to the accessed Field whose variable name matches
            self.field = None
            if self.isFieldArgument:
                fieldsByVariableName = {symbolNameToVariableName(f.name): f
                                        for f in kernelFunctionNode.fieldsAccessed}
                self.field = fieldsByVariableName[self.fieldName]

        def __repr__(self):
            return '<{0} {1}>'.format(self.dtype, self.name)

    def __init__(self, body, functionName="kernel"):
        super(KernelFunction, self).__init__()
        self._body = body
        self._body.parent = self
        self._parameters = None
        self.functionName = functionName
        # these variables are assumed to be global, so no automatic parameter is generated for them
        self.globalVariables = set()

    @property
    def symbolsDefined(self):
        return set()

    @property
    def undefinedSymbols(self):
        return set()

    @property
    def parameters(self):
        """Freshly computed list of Argument objects (recomputed on every access)."""
        self._updateParameters()
        return self._parameters

    @property
    def body(self):
        return self._body

    @property
    def args(self):
        return [self._body]

    @property
    def fieldsAccessed(self):
        """Set of Field instances: fields which are accessed inside this kernel function"""
        return {access.field for access in self.atoms(ResolvedFieldAccess)}

    def _updateParameters(self):
        freeSymbols = self._body.undefinedSymbols - self.globalVariables
        arguments = [KernelFunction.Argument(s.name, s.dtype, self) for s in freeSymbols]
        # descending order by field name, then pointer/shape/stride flags, then parameter name
        arguments.sort(key=lambda a: (a.fieldName, a.isFieldPtrArgument, a.isFieldShapeArgument,
                                      a.isFieldStrideArgument, a.name),
                       reverse=True)
        self._parameters = arguments

    def __str__(self):
        params = self.parameters
        indentedBody = "\t" + "\t".join(str(self.body).splitlines(True))
        return '{0} {1}({2})\n{3}'.format(type(self).__name__, self.functionName, params, indentedBody)

    def __repr__(self):
        params = self.parameters
        return '{0} {1}({2})'.format(type(self).__name__, self.functionName, params)

class Block(Node):
    """Ordered sequence of child nodes; the block sets itself as parent of every child it holds."""

    def __init__(self, listOfNodes):
        # previously super(Node, self), which skipped Node.__init__ and left ``parent`` unset
        super(Block, self).__init__()
        self._nodes = listOfNodes
        for n in self._nodes:
            n.parent = self

    @property
    def args(self):
        return self._nodes

    def insertFront(self, node):
        """Insert ``node`` as the first child."""
        node.parent = self
        self._nodes.insert(0, node)

    def insertBefore(self, newNode, insertBefore):
        """Insert ``newNode`` directly before the existing child ``insertBefore``.

        :raises ValueError: if ``insertBefore`` is not a child of this block (nothing is modified)
        """
        # look up the position first, so a failed lookup leaves newNode untouched
        idx = self._nodes.index(insertBefore)
        newNode.parent = self
        self._nodes.insert(idx, newNode)

    def append(self, node):
        """Append ``node`` as the last child."""
        node.parent = self
        self._nodes.append(node)

    def takeChildNodes(self):
        """Remove and return all children; the block is empty afterwards."""
        tmp = self._nodes
        self._nodes = []
        return tmp

    def replace(self, child, replacements):
        """Replace ``child`` with a single node or with a list of nodes spliced in at its position.

        :raises ValueError: if ``child`` is not a child of this block
        """
        idx = self._nodes.index(child)
        del self._nodes[idx]
        if isinstance(replacements, list):
            for e in replacements:
                e.parent = self
            self._nodes = self._nodes[:idx] + replacements + self._nodes[idx:]
        else:
            replacements.parent = self
            self._nodes.insert(idx, replacements)

    @property
    def symbolsDefined(self):
        result = set()
        for a in self.args:
            result.update(a.symbolsDefined)
        return result

    @property
    def undefinedSymbols(self):
        """Symbols used by any child, minus the symbols defined by any child of this block."""
        result = set()
        definedSymbols = set()
        for a in self.args:
            result.update(a.undefinedSymbols)
            definedSymbols.update(a.symbolsDefined)
        return result - definedSymbols

    def __str__(self):
        return ''.join('{!s}\n'.format(node) for node in self._nodes)

    def __repr__(self):
        return ''.join('{!r}'.format(node) for node in self._nodes)


class PragmaBlock(Block):
    """Block that additionally carries a ``pragmaLine`` string.

    NOTE(review): presumably the printer emits the pragma in front of the block; confirm in the backend.
    """

    def __init__(self, pragmaLine, listOfNodes):
        self.pragmaLine = pragmaLine
        super(PragmaBlock, self).__init__(listOfNodes)


class LoopOverCoordinate(Node):
    """Loop over one coordinate, from ``start`` to ``stop`` with increment ``step``.

    The loop counter is an ``int`` TypedSymbol named ``ctr_<coordinate>``. The loop defines that
    counter for its body.
    """
    LOOP_COUNTER_NAME_PREFIX = "ctr"

    def __init__(self, body, coordinateToLoopOver, start, stop, step=1):
        super(LoopOverCoordinate, self).__init__(parent=None)
        self._body = body
        body.parent = self
        self._coordinateToLoopOver = coordinateToLoopOver
        self._begin = start
        self._end = stop
        self._increment = step
        self.prefixLines = []

    def newLoopWithDifferentBody(self, newBody):
        """Return a copy of this loop (same bounds and prefix lines) around ``newBody``."""
        result = LoopOverCoordinate(newBody, self._coordinateToLoopOver, self._begin, self._end, self._increment)
        result.prefixLines = list(self.prefixLines)
        return result

    @property
    def args(self):
        result = [self._body]
        # bounds may be plain numbers (no ``args``) or expressions, which are children
        for e in [self._begin, self._end, self._increment]:
            if hasattr(e, "args"):
                result.append(e)
        return result

    @property
    def body(self):
        return self._body

    @property
    def start(self):
        return self._begin

    @property
    def stop(self):
        return self._end

    @property
    def step(self):
        return self._increment

    @property
    def coordinateToLoopOver(self):
        return self._coordinateToLoopOver

    @property
    def symbolsDefined(self):
        return set([self.loopCounterSymbol])

    @property
    def undefinedSymbols(self):
        # copy, so the set returned by the body is not mutated
        result = set(self._body.undefinedSymbols)
        for possibleSymbol in [self._begin, self._end, self._increment]:
            if isinstance(possibleSymbol, (Node, sp.Basic)):
                result.update(possibleSymbol.atoms(sp.Symbol))
        return result - set([self.loopCounterSymbol])

    @staticmethod
    def getLoopCounterName(coordinateToLoopOver):
        return "%s_%s" % (LoopOverCoordinate.LOOP_COUNTER_NAME_PREFIX, coordinateToLoopOver)

    @property
    def loopCounterName(self):
        return LoopOverCoordinate.getLoopCounterName(self.coordinateToLoopOver)

    @staticmethod
    def isLoopCounterSymbol(symbol):
        """Return the coordinate of ``symbol`` if it is a loop counter symbol, otherwise None."""
        prefix = LoopOverCoordinate.LOOP_COUNTER_NAME_PREFIX
        if not symbol.name.startswith(prefix):
            return None
        # untyped sympy symbols cannot be loop counters
        if getattr(symbol, 'dtype', None) != createTypeFromString('int'):
            return None
        try:
            return int(symbol.name[len(prefix) + 1:])
        except ValueError:
            # e.g. an int symbol named "ctrl" shares the prefix but carries no coordinate
            return None

    @staticmethod
    def getLoopCounterSymbol(coordinateToLoopOver):
        return TypedSymbol(LoopOverCoordinate.getLoopCounterName(coordinateToLoopOver), 'int')

    @property
    def loopCounterSymbol(self):
        return LoopOverCoordinate.getLoopCounterSymbol(self.coordinateToLoopOver)

    @property
    def isOutermostLoop(self):
        from pystencils.transformations import getNextParentOfType
        return getNextParentOfType(self, LoopOverCoordinate) is None

    @property
    def isInnermostLoop(self):
        return not self.atoms(LoopOverCoordinate)

    def __str__(self):
        return 'loop:{!s} in {!s}:{!s}:{!s}\n{!s}'.format(self.loopCounterName, self.start, self.stop, self.step,
                                                          ("\t" + "\t".join(str(self.body).splitlines(True))))

    def __repr__(self):
        return 'loop:{!s} in {!s}:{!s}:{!s}'.format(self.loopCounterName, self.start, self.stop, self.step)


class SympyAssignment(Node):
    """Assignment ``lhs = rhs`` of sympy expressions.

    Assigning to a plain symbol is a declaration (the node then defines that symbol). Assigning to a
    Field.Access, an indexed expression or a cast is a store and defines nothing.
    ``isConst`` is stored as given (presumably used by the printer for ``const`` declarations — confirm).
    """
    def __init__(self, lhsSymbol, rhsTerm, isConst=True):
        super(SympyAssignment, self).__init__(parent=None)
        self._lhsSymbol = lhsSymbol
        self.rhs = rhsTerm
        self._isDeclaration = SympyAssignment._lhsIsDeclaration(lhsSymbol)
        self._isConst = isConst

    @staticmethod
    def _lhsIsDeclaration(lhs):
        """Single rule used by both __init__ and the lhs setter (they previously checked different types)."""
        isCast = str(lhs.func).lower() == 'cast' if hasattr(lhs, "func") else False
        if isCast:
            return False
        return not isinstance(lhs, (Field.Access, IndexedBase, sp.Indexed))

    @property
    def lhs(self):
        return self._lhsSymbol

    @lhs.setter
    def lhs(self, newValue):
        self._lhsSymbol = newValue
        self._isDeclaration = SympyAssignment._lhsIsDeclaration(newValue)

    @property
    def args(self):
        return [self._lhsSymbol, self.rhs]

    @property
    def symbolsDefined(self):
        if not self._isDeclaration:
            return set()
        return set([self._lhsSymbol])

    @property
    def undefinedSymbols(self):
        result = self.rhs.atoms(sp.Symbol)
        # Field accesses on the rhs depend on the loop counters of all their coordinates
        loopCounters = set()
        for symbol in result:
            if isinstance(symbol, Field.Access):
                for i in range(len(symbol.offsets)):
                    loopCounters.add(LoopOverCoordinate.getLoopCounterSymbol(i))
        result.update(loopCounters)
        result.update(self._lhsSymbol.atoms(sp.Symbol))
        return result

    @property
    def isDeclaration(self):
        return self._isDeclaration

    @property
    def isConst(self):
        return self._isConst

    def replace(self, child, replacement):
        """Replace ``child`` (lhs or rhs) by ``replacement``.

        :raises ValueError: if ``child`` is neither the lhs nor the rhs
        """
        if child == self.lhs:
            replacement.parent = self
            self.lhs = replacement
        elif child == self.rhs:
            replacement.parent = self
            self.rhs = replacement
        else:
            # report the child that was not found (previously reported the replacement)
            raise ValueError('%s is not in args of %s' % (child, self.__class__))

    def __repr__(self):
        return repr(self.lhs) + " = " + repr(self.rhs)


class TemporaryMemoryAllocation(Node):
    """Allocation of temporary memory of ``size`` for ``symbol``; the node defines ``symbol``."""
    def __init__(self, typedSymbol, size):
        # initialize ``parent`` like every other Node
        super(TemporaryMemoryAllocation, self).__init__(parent=None)
        self.symbol = typedSymbol
        self.size = size

    @property
    def symbolsDefined(self):
        return set([self.symbol])

    @property
    def undefinedSymbols(self):
        # a symbolic size depends on its free symbols; a numeric size depends on nothing
        if isinstance(self.size, sp.Basic):
            return self.size.atoms(sp.Symbol)
        else:
            return set()

    @property
    def args(self):
        return [self.symbol]


class TemporaryMemoryFree(Node):
    """Release of the temporary memory referenced by ``symbol``; uses and defines no symbols."""
    def __init__(self, typedSymbol):
        # initialize ``parent`` like every other Node
        super(TemporaryMemoryFree, self).__init__(parent=None)
        self.symbol = typedSymbol

    @property
    def symbolsDefined(self):
        return set()

    @property
    def undefinedSymbols(self):
        return set()

    @property
    def args(self):
        return []


# TODO implement defined & undefinedSymbols
class Conversion(Node):
    """Node wrapping a single child together with a target ``dtype``.

    NOTE(review): presumably a type conversion of the child to ``dtype``; confirm against the printer.
    """
    def __init__(self, child, dtype, parent=None):
        super(Conversion, self).__init__(parent)
        self._args = [child]
        self.dtype = dtype

    @property
    def args(self):
        """Returns all arguments/children of this node"""
        return self._args

    @args.setter
    def args(self, value):
        self._args = value

    @property
    def symbolsDefined(self):
        """Set of symbols which are defined by this node. """
        return set()

    @property
    def undefinedSymbols(self):
        """Symbols which are used but are not defined inside this node (placeholder, see TODO)"""
        # previously ``raise set()``, which fails with TypeError instead of returning
        return set()

    def __repr__(self):
        return '(%s(%s))' % (repr(self.dtype), repr(self.args[0].dtype)) + repr(self.args)
# TODO Pow
Jan Hoenig's avatar
Jan Hoenig committed
493

Jan Hoenig's avatar
Jan Hoenig committed
494
495
496
497

_expr_dict = {'Add': ' + ', 'Mul': ' * ', 'Pow': '**'}


Jan Hoenig's avatar
Jan Hoenig committed
498
499
500
501
class Expr(Node):
    """Generic n-ary expression node holding a mutable list of children and a ``dtype``.

    ``__repr__`` joins the child reprs with the separator registered for the subclass name.
    """
    def __init__(self, args, parent=None):
        super(Expr, self).__init__(parent)
        self._args = list(args)
        self.dtype = None

    @property
    def args(self):
        return self._args

    @args.setter
    def args(self, value):
        self._args = value

    def replace(self, child, replacements):
        """Replace ``child`` with one node, or splice in a list of nodes at its position."""
        position = self.args.index(child)
        del self.args[position]
        if type(replacements) is not list:
            replacements.parent = self
            self.args.insert(position, replacements)
            return
        for replacement in replacements:
            replacement.parent = self
        self.args = self.args[:position] + replacements + self.args[position:]

    @property
    def symbolsDefined(self):
        return set()  # Todo fix for symbol analysis

    @property
    def undefinedSymbols(self):
        return set()  # Todo fix for symbol analysis

    def __repr__(self):
        separator = _expr_dict[self.__class__.__name__]
        return separator.join(map(repr, self.args))


class Mul(Expr):
    """Product of all children."""


class Add(Expr):
    """Sum of all children."""


class Pow(Expr):
    """Power expression over its children (repr joins them with ``**``)."""


class Indexed(Expr):
    """Indexed access whose repr is ``args[0][args[1]]``; ``base`` is the indexed base object."""
    def __init__(self, args, base, parent=None):
        super(Indexed, self).__init__(args, parent)
        self.base = base
        # The label's dtype is a pointer type; the element dtype is its base type
        pointerType = base.label.dtype
        self.dtype = createType(pointerType.baseType)

    def __repr__(self):
        arrayPart, indexPart = self.args[0], self.args[1]
        return '%s[%s]' % (arrayPart, indexPart)


class PointerArithmetic(Expr):
    """Dereferenced pointer offset, represented as ``*(pointer + offset)``.

    :param args: the offset expression (first child, also stored as ``offset``)
    :param pointer: the pointer expression (second child); its ``dtype`` becomes this node's dtype
    """
    def __init__(self, args, pointer, parent=None):
        super(PointerArithmetic, self).__init__([args] + [pointer], parent)
        self.pointer = pointer
        self.offset = args
        self.dtype = pointer.dtype

    def __repr__(self):
        # previously formatted ``self.args`` (the whole [offset, pointer] list) instead of the offset
        return '*(%s + %s)' % (self.pointer, self.offset)


class Number(Node, sp.AtomicExpr):
    """Numeric literal leaf node; ``dtype`` and ``value`` come from ``get_type_from_sympy(number)``."""
    def __init__(self, number, parent=None):
        super(Number, self).__init__(parent)
        self.dtype, self.value = get_type_from_sympy(number)
        self._args = tuple()

    @property
    def args(self):
        """Returns all arguments/children of this node"""
        return self._args

    @property
    def symbolsDefined(self):
        """Set of symbols which are defined by this node. """
        return set()

    @property
    def undefinedSymbols(self):
        """Symbols which are used but are not defined inside this node (a literal uses none)"""
        # previously ``raise set()``, which fails with TypeError instead of returning
        return set()

    def __repr__(self):
        return repr(self.value)

    def __float__(self):
        return float(self.value)

    def __int__(self):
        return int(self.value)