astnodes.py 18.6 KB
Newer Older
1
import sympy as sp
2
from sympy.tensor import IndexedBase
3
from pystencils.field import Field
4
from pystencils.data_types import TypedSymbol, createType, get_type_from_sympy, createTypeFromString, castFunc
5
from pystencils.sympyextensions import fastSubs
6
7


Martin Bauer's avatar
Martin Bauer committed
8
9
class ResolvedFieldAccess(sp.Indexed):
    """sympy ``Indexed`` with a single (linearized) index that also remembers where it came from.

    Besides the two sympy arguments (base, linearized index) the instance carries the originating
    ``field``, the access ``offsets`` and ``idxCoordinateValues`` as plain Python attributes.
    """

    def __new__(cls, base, linearizedIndex, field, offsets, idxCoordinateValues):
        # sympy.Indexed needs an IndexedBase -- wrap a bare label
        indexedBase = base if isinstance(base, IndexedBase) else IndexedBase(base, shape=(1,))
        instance = super(ResolvedFieldAccess, cls).__new__(cls, indexedBase, linearizedIndex)
        instance.field, instance.offsets, instance.idxCoordinateValues = field, offsets, idxCoordinateValues
        return instance

    def _eval_subs(self, old, new):
        # only the index expression takes part in substitution; base and metadata are carried over
        baseArg = self.args[0]
        newIndex = self.args[1].subs(old, new)
        return ResolvedFieldAccess(baseArg, newIndex, self.field, self.offsets, self.idxCoordinateValues)

    def _hashable_content(self):
        # include the metadata so that accesses differing only in field/offsets hash and compare differently
        baseContent = super(ResolvedFieldAccess, self)._hashable_content()
        metadata = tuple(self.offsets) + (repr(self.idxCoordinateValues), hash(self.field))
        return baseContent + metadata

    @property
    def typedSymbol(self):
        """Label of the underlying IndexedBase; its ``dtype`` is shown by ``__str__``."""
        return self.base.label

    def __str__(self):
        return "%s (%s)" % (super(ResolvedFieldAccess, self).__str__(), self.typedSymbol.dtype)

    def __getnewargs__(self):
        # arguments for re-creating the object (pickling/copying); mirrors the __new__ signature
        return self.base, self.indices[0], self.field, self.offsets, self.idxCoordinateValues


class Node(object):
    """Base class for all AST nodes.

    Subclasses expose their children through the ``args`` property and report symbol usage through
    ``symbolsDefined`` / ``undefinedSymbols``.
    """

    def __init__(self, parent=None):
        self.parent = parent

    @property
    def args(self):
        """Returns all arguments/children of this node (a leaf node has none).

        A property, like in every subclass, because ``subs`` and ``atoms`` iterate ``self.args`` directly.
        """
        return []

    @property
    def symbolsDefined(self):
        """Set of symbols which are defined by this node. """
        return set()

    @property
    def undefinedSymbols(self):
        """Symbols which are used but are not defined inside this node.

        :raises NotImplementedError: subclasses must override this
        """
        raise NotImplementedError()

    def subs(self, *args, **kwargs):
        """Inplace! substitute, similar to sympys but modifies ast and returns None"""
        for a in self.args:
            a.subs(*args, **kwargs)

    def atoms(self, argType):
        """
        Returns a set of all children (recursively) which are an instance of the given argType
        """
        result = set()
        for arg in self.args:
            if isinstance(arg, argType):
                result.add(arg)
            result.update(arg.atoms(argType))
        return result


class Conditional(Node):
    """If/else node: runs ``trueBlock`` when ``conditionExpr`` holds, otherwise the optional ``falseBlock``."""
    def __init__(self, conditionExpr, trueBlock, falseBlock=None):
        """
        Create a new conditional node

        :param conditionExpr: sympy relational (or boolean) expression
        :param trueBlock: block which is run if conditional is true
        :param falseBlock: block which is run if conditional is false, or None if not needed
        """
        super(Conditional, self).__init__()
        assert conditionExpr.is_Boolean or conditionExpr.is_Relational
        self.conditionExpr = conditionExpr
        self.trueBlock = trueBlock
        self.falseBlock = falseBlock
        # register as parent of the branches, like Block/KernelFunction/LoopOverCoordinate do for children
        trueBlock.parent = self
        if falseBlock is not None:
            falseBlock.parent = self

    @property
    def args(self):
        result = [self.conditionExpr, self.trueBlock]
        if self.falseBlock is not None:
            result.append(self.falseBlock)
        return result

    @property
    def symbolsDefined(self):
        return set()

    @property
    def undefinedSymbols(self):
        """Symbols used by either branch or by the condition itself."""
        result = set(self.trueBlock.undefinedSymbols)
        if self.falseBlock is not None:
            # set.update returns None -- must not be assigned back to ``result``
            result.update(self.falseBlock.undefinedSymbols)
        result.update(self.conditionExpr.atoms(sp.Symbol))
        return result

    def __str__(self):
        return 'if:({!s}) '.format(self.conditionExpr)

    def __repr__(self):
        return 'if:({!r}) '.format(self.conditionExpr)


class KernelFunction(Node):
    """Root node of a kernel.

    Wraps the kernel ``body``. The function parameters are derived on demand from the symbols the body
    uses but does not define, minus ``globalVariables``.
    """

    class Argument:
        """One kernel parameter, classified by the Field prefix of its name (data pointer, shape or stride)."""

        def __init__(self, name, dtype, symbol, kernelFunctionNode):
            from pystencils.transformations import symbolNameToVariableName
            self.name = name
            self.dtype = dtype
            self.isFieldPtrArgument = False
            self.isFieldShapeArgument = False
            self.isFieldStrideArgument = False
            self.isFieldArgument = False
            self.fieldName = ""
            self.coordinate = None
            self.symbol = symbol

            if name.startswith(Field.DATA_PREFIX):
                self.isFieldPtrArgument = True
                self.isFieldArgument = True
                self.fieldName = name[len(Field.DATA_PREFIX):]
            elif name.startswith(Field.SHAPE_PREFIX):
                self.isFieldShapeArgument = True
                self.isFieldArgument = True
                self.fieldName = name[len(Field.SHAPE_PREFIX):]
            elif name.startswith(Field.STRIDE_PREFIX):
                self.isFieldStrideArgument = True
                self.isFieldArgument = True
                self.fieldName = name[len(Field.STRIDE_PREFIX):]

            self.field = None
            if self.isFieldArgument:
                # look up the accessed Field whose converted name equals the suffix extracted above
                fieldMap = {symbolNameToVariableName(f.name): f for f in kernelFunctionNode.fieldsAccessed}
                self.field = fieldMap[self.fieldName]

        def __lt__(self, other):
            # order: data pointers, shapes, strides, then all other parameters; ties broken by name
            def score(arg):
                if arg.isFieldPtrArgument:
                    return -4
                if arg.isFieldShapeArgument:
                    return -3
                if arg.isFieldStrideArgument:
                    return -2
                return 0

            return (score(self), self.name) < (score(other), other.name)

        def __repr__(self):
            return '<{0} {1}>'.format(self.dtype, self.name)

    def __init__(self, body, ghostLayers=None, functionName="kernel"):
        super(KernelFunction, self).__init__()
        self._body = body
        body.parent = self
        self._parameters = None
        self.functionName = functionName
        self.ghostLayers = ghostLayers
        # these variables are assumed to be global, so no automatic parameter is generated for them
        self.globalVariables = set()

    @property
    def symbolsDefined(self):
        return set()

    @property
    def undefinedSymbols(self):
        # the body's free symbols become parameters (see _updateParameters), so nothing is left undefined
        return set()

    @property
    def parameters(self):
        """Sorted list of Argument objects, recomputed from the body on every access."""
        self._updateParameters()
        return self._parameters

    @property
    def body(self):
        return self._body

    @property
    def args(self):
        return [self._body]

    @property
    def fieldsAccessed(self):
        """Set of Field instances: fields which are accessed inside this kernel function"""
        return set(o.field for o in self.atoms(ResolvedFieldAccess))

    def _updateParameters(self):
        undefinedSymbols = self._body.undefinedSymbols - self.globalVariables
        self._parameters = [KernelFunction.Argument(s.name, s.dtype, s, self) for s in undefinedSymbols]
        self._parameters.sort()

    def __str__(self):
        # the ``parameters`` property already refreshes the list; no separate update needed
        params = self.parameters
        indentedBody = "\t" + "\t".join(str(self.body).splitlines(True))
        return '{0} {1}({2})\n{3}'.format(type(self).__name__, self.functionName, params, indentedBody)

    def __repr__(self):
        return '{0} {1}({2})'.format(type(self).__name__, self.functionName, self.parameters)


class Block(Node):
    """Ordered sequence of child nodes; each child's ``parent`` is set to the block when added."""

    def __init__(self, listOfNodes):
        # initialise Node itself (super(Node, ...) skipped it), so ``parent`` exists even for a detached block
        super(Block, self).__init__()
        self._nodes = listOfNodes
        for n in self._nodes:
            n.parent = self

    @property
    def args(self):
        return self._nodes

    def insertFront(self, node):
        """Insert ``node`` as the first child."""
        node.parent = self
        self._nodes.insert(0, node)

    def insertBefore(self, newNode, insertBefore):
        """Insert ``newNode`` directly before the existing child ``insertBefore``.

        :raises ValueError: if ``insertBefore`` is not a child; ``newNode`` is then left unmodified
        """
        idx = self._nodes.index(insertBefore)
        newNode.parent = self
        self._nodes.insert(idx, newNode)

    def append(self, node):
        """Append ``node`` as the last child."""
        node.parent = self
        self._nodes.append(node)

    def takeChildNodes(self):
        """Return the current child list and leave the block empty."""
        tmp = self._nodes
        self._nodes = []
        return tmp

    def replace(self, child, replacements):
        """Replace ``child`` by a single node or by a list of nodes at the same position.

        :raises ValueError: if ``child`` is not a child of this block
        """
        idx = self._nodes.index(child)
        del self._nodes[idx]
        if type(replacements) is list:
            for e in replacements:
                e.parent = self
            self._nodes = self._nodes[:idx] + replacements + self._nodes[idx:]
        else:
            replacements.parent = self
            self._nodes.insert(idx, replacements)

    @property
    def symbolsDefined(self):
        result = set()
        for a in self.args:
            result.update(a.symbolsDefined)
        return result

    @property
    def undefinedSymbols(self):
        """Symbols used by any child that no child of this block defines."""
        result = set()
        definedSymbols = set()
        for a in self.args:
            result.update(a.undefinedSymbols)
            definedSymbols.update(a.symbolsDefined)
        return result - definedSymbols

    def __str__(self):
        return ''.join('{!s}\n'.format(node) for node in self._nodes)

    def __repr__(self):
        return ''.join('{!r}'.format(node) for node in self._nodes)


class PragmaBlock(Block):
    """Block that additionally stores a ``pragmaLine`` string."""

    def __init__(self, pragmaLine, listOfNodes):
        self.pragmaLine = pragmaLine
        super(PragmaBlock, self).__init__(listOfNodes)


class LoopOverCoordinate(Node):
    """Loop over one coordinate from ``start`` to ``stop`` with ``step``.

    The loop counter is the ``int`` TypedSymbol ``ctr_<coordinate>``. ``start``, ``stop`` and ``step``
    may be plain numbers or sympy expressions.
    """
    LOOP_COUNTER_NAME_PREFIX = "ctr"

    def __init__(self, body, coordinateToLoopOver, start, stop, step=1):
        # initialise Node state so ``parent`` exists before the loop is inserted anywhere
        super(LoopOverCoordinate, self).__init__()
        self.body = body
        body.parent = self
        self.coordinateToLoopOver = coordinateToLoopOver
        self.start = start
        self.stop = stop
        self.step = step
        # extra lines attached to this loop; copied by newLoopWithDifferentBody
        self.prefixLines = []

    def newLoopWithDifferentBody(self, newBody):
        """Copy of this loop (same range and prefixLines) around ``newBody``."""
        result = LoopOverCoordinate(newBody, self.coordinateToLoopOver, self.start, self.stop, self.step)
        result.prefixLines = list(self.prefixLines)
        return result

    def subs(self, *args, **kwargs):
        self.body.subs(*args, **kwargs)
        # bounds may be plain ints, which have no ``subs``
        if hasattr(self.start, "subs"):
            self.start = self.start.subs(*args, **kwargs)
        if hasattr(self.stop, "subs"):
            self.stop = self.stop.subs(*args, **kwargs)
        if hasattr(self.step, "subs"):
            self.step = self.step.subs(*args, **kwargs)

    @property
    def args(self):
        result = [self.body]
        result.extend(e for e in (self.start, self.stop, self.step) if hasattr(e, "args"))
        return result

    @property
    def symbolsDefined(self):
        return {self.loopCounterSymbol}

    @property
    def undefinedSymbols(self):
        """Symbols used by the body or the loop bounds, excluding this loop's own counter."""
        result = set(self.body.undefinedSymbols)
        for possibleSymbol in (self.start, self.stop, self.step):
            if isinstance(possibleSymbol, (Node, sp.Basic)):
                result.update(possibleSymbol.atoms(sp.Symbol))
        return result - {self.loopCounterSymbol}

    @staticmethod
    def getLoopCounterName(coordinateToLoopOver):
        return "%s_%s" % (LoopOverCoordinate.LOOP_COUNTER_NAME_PREFIX, coordinateToLoopOver)

    @property
    def loopCounterName(self):
        return LoopOverCoordinate.getLoopCounterName(self.coordinateToLoopOver)

    @staticmethod
    def isLoopCounterSymbol(symbol):
        """Return the coordinate if ``symbol`` is an ``int`` symbol named ``ctr_<integer>``, otherwise None."""
        prefix = LoopOverCoordinate.LOOP_COUNTER_NAME_PREFIX + "_"
        if not symbol.name.startswith(prefix):
            return None
        # plain sympy symbols carry no dtype -> not a loop counter
        if getattr(symbol, 'dtype', None) != createTypeFromString('int'):
            return None
        try:
            return int(symbol.name[len(prefix):])
        except ValueError:
            # prefix matches but the suffix is not an integer, e.g. "ctr_x"
            return None

    @staticmethod
    def getLoopCounterSymbol(coordinateToLoopOver):
        return TypedSymbol(LoopOverCoordinate.getLoopCounterName(coordinateToLoopOver), 'int')

    @property
    def loopCounterSymbol(self):
        return LoopOverCoordinate.getLoopCounterSymbol(self.coordinateToLoopOver)

    @property
    def isOutermostLoop(self):
        from pystencils.transformations import getNextParentOfType
        return getNextParentOfType(self, LoopOverCoordinate) is None

    @property
    def isInnermostLoop(self):
        return not self.atoms(LoopOverCoordinate)

    def __str__(self):
        return 'loop:{!s} in {!s}:{!s}:{!s}\n{!s}'.format(self.loopCounterName, self.start, self.stop, self.step,
                                                          ("\t" + "\t".join(str(self.body).splitlines(True))))

    def __repr__(self):
        return 'loop:{!s} in {!s}:{!s}:{!s}'.format(self.loopCounterName, self.start, self.stop, self.step)


class SympyAssignment(Node):
    """Assignment ``lhs = rhs`` where both sides are sympy expressions.

    The assignment is a declaration (``symbolsDefined`` returns the lhs) unless the lhs is a
    ``Field.Access``, any ``sp.Indexed`` (this includes ``ResolvedFieldAccess``) or a ``castFunc`` expression.
    """
    def __init__(self, lhsSymbol, rhsTerm, isConst=True):
        super(SympyAssignment, self).__init__()
        self._lhsSymbol = lhsSymbol
        self.rhs = rhsTerm
        # shared with the ``lhs`` setter so both paths classify the lhs identically
        self._isDeclaration = SympyAssignment._isDeclarationLhs(lhsSymbol)
        self._isConst = isConst

    @staticmethod
    def _isDeclarationLhs(lhs):
        """Return True if assigning to ``lhs`` declares a new variable (plain symbol lhs)."""
        isCast = lhs.func == castFunc
        return not (isCast or isinstance(lhs, (Field.Access, sp.Indexed)))

    @property
    def lhs(self):
        return self._lhsSymbol

    @lhs.setter
    def lhs(self, newValue):
        self._lhsSymbol = newValue
        self._isDeclaration = SympyAssignment._isDeclarationLhs(newValue)

    def subs(self, *args, **kwargs):
        self.lhs = fastSubs(self.lhs, *args, **kwargs)
        self.rhs = fastSubs(self.rhs, *args, **kwargs)

    @property
    def args(self):
        return [self._lhsSymbol, self.rhs]

    @property
    def symbolsDefined(self):
        if not self._isDeclaration:
            return set()
        return {self._lhsSymbol}

    @property
    def undefinedSymbols(self):
        result = self.rhs.atoms(sp.Symbol)
        # Add loop counters if there are field accesses
        loopCounters = set()
        for symbol in result:
            if isinstance(symbol, Field.Access):
                for i in range(len(symbol.offsets)):
                    loopCounters.add(LoopOverCoordinate.getLoopCounterSymbol(i))
        result.update(loopCounters)
        result.update(self._lhsSymbol.atoms(sp.Symbol))
        return result

    @property
    def isDeclaration(self):
        return self._isDeclaration

    @property
    def isConst(self):
        return self._isConst

    def replace(self, child, replacement):
        """Replace ``child`` (the lhs or the rhs) by ``replacement``.

        :raises ValueError: if ``child`` is neither the lhs nor the rhs
        """
        if child == self.lhs:
            replacement.parent = self
            self.lhs = replacement
        elif child == self.rhs:
            replacement.parent = self
            self.rhs = replacement
        else:
            # report the child that was not found (the message used to show the replacement)
            raise ValueError('%s is not in args of %s' % (child, self.__class__))

    def __repr__(self):
        return repr(self.lhs) + " = " + repr(self.rhs)


class TemporaryMemoryAllocation(Node):
    """Allocation of a temporary buffer bound to ``typedSymbol``; ``size`` may be a number or sympy expression."""
    def __init__(self, typedSymbol, size):
        # initialise Node state so ``parent`` exists before the node is inserted into a Block
        super(TemporaryMemoryAllocation, self).__init__()
        self.symbol = typedSymbol
        self.size = size

    @property
    def symbolsDefined(self):
        return {self.symbol}

    @property
    def undefinedSymbols(self):
        """Free symbols of ``size`` when it is a sympy expression, otherwise nothing."""
        if isinstance(self.size, sp.Basic):
            return self.size.atoms(sp.Symbol)
        return set()

    @property
    def args(self):
        return [self.symbol]


class TemporaryMemoryFree(Node):
    """Release of the temporary buffer bound to ``typedSymbol``; reports no defined or undefined symbols."""
    def __init__(self, typedSymbol):
        # initialise Node state so ``parent`` exists before the node is inserted into a Block
        super(TemporaryMemoryFree, self).__init__()
        self.symbol = typedSymbol

    @property
    def symbolsDefined(self):
        return set()

    @property
    def undefinedSymbols(self):
        return set()

    @property
    def args(self):
        return []


# TODO implement defined & undefinedSymbols
class Conversion(Node):
    """Conversion of a single child expression to ``dtype``."""
    def __init__(self, child, dtype, parent=None):
        super(Conversion, self).__init__(parent)
        self._args = [child]
        self.dtype = dtype

    @property
    def args(self):
        """Returns all arguments/children of this node"""
        return self._args

    @args.setter
    def args(self, value):
        self._args = value

    @property
    def symbolsDefined(self):
        """Set of symbols which are defined by this node. """
        return set()

    @property
    def undefinedSymbols(self):
        """Symbols which are used but are not defined inside this node.

        Placeholder until symbol analysis is implemented (see TODO above); ``raise set()`` was a TypeError.
        """
        return set()

    def __repr__(self):
        return '(%s(%s))' % (repr(self.dtype), repr(self.args[0].dtype)) + repr(self.args)


# TODO Pow

# Separator that Expr.__repr__ places between the children of an expression, keyed by the node's class name.
_expr_dict = dict(Add=' + ', Mul=' * ', Pow='**')


class Expr(Node):
    """Generic n-ary expression node holding a mutable list of children.

    ``__repr__`` joins the children with ``_expr_dict[<class name>]``, so subclasses such as Add, Mul
    and Pow differ only by name. ``dtype`` starts out as None.
    """

    def __init__(self, args, parent=None):
        super(Expr, self).__init__(parent)
        self.dtype = None
        self._args = list(args)

    @property
    def args(self):
        return self._args

    @args.setter
    def args(self, value):
        self._args = value

    def replace(self, child, replacements):
        """Replace ``child`` by one node, or splice in a list of nodes, at the child's position."""
        position = self.args.index(child)
        del self.args[position]
        if type(replacements) is not list:
            replacements.parent = self
            self.args.insert(position, replacements)
            return
        for replacement in replacements:
            replacement.parent = self
        self.args = self.args[:position] + replacements + self.args[position:]

    @property
    def symbolsDefined(self):
        # Todo fix for symbol analysis
        return set()

    @property
    def undefinedSymbols(self):
        # Todo fix for symbol analysis
        return set()

    def __repr__(self):
        separator = _expr_dict[self.__class__.__name__]
        return separator.join(map(repr, self.args))


class Mul(Expr):
    """Product of the children in ``args``."""


class Add(Expr):
    """Sum of the children in ``args``."""


class Pow(Expr):
    """Power expression over the children in ``args``."""


class Indexed(Expr):
    """Subscript expression printed as ``args[0][args[1]]``.

    ``base`` must have a ``label`` whose ``dtype`` exposes ``baseType`` (the pointed-to element type).
    """

    def __init__(self, args, base, parent=None):
        super(Indexed, self).__init__(args, parent)
        self.base = base
        # element dtype = the label's pointer dtype with the pointer removed
        elementType = base.label.dtype.baseType
        self.dtype = createType(elementType)

    def __repr__(self):
        return '{!s}[{!s}]'.format(self.args[0], self.args[1])


class PointerArithmetic(Expr):
    """Dereference of ``pointer`` shifted by ``offset``, printed as ``*(pointer + offset)``.

    The ``args`` constructor parameter is the offset; the node's children are ``[offset, pointer]``.
    """
    def __init__(self, args, pointer, parent=None):
        super(PointerArithmetic, self).__init__([args] + [pointer], parent)
        self.pointer = pointer
        self.offset = args
        self.dtype = pointer.dtype

    def __repr__(self):
        # print the offset; self.args is the whole child list [offset, pointer]
        return '*(%s + %s)' % (self.pointer, self.offset)


class Number(Node, sp.AtomicExpr):
    """Numeric literal leaf; ``dtype`` and ``value`` are obtained from ``get_type_from_sympy(number)``."""
    def __init__(self, number, parent=None):
        super(Number, self).__init__(parent)
        self.dtype, self.value = get_type_from_sympy(number)
        self._args = tuple()

    @property
    def args(self):
        """Returns all arguments/children of this node (a literal has none)."""
        return self._args

    @property
    def symbolsDefined(self):
        """Set of symbols which are defined by this node. """
        return set()

    @property
    def undefinedSymbols(self):
        """Symbols which are used but are not defined inside this node (none for a literal).

        Returned, not raised: ``raise set()`` always failed with a TypeError.
        """
        return set()

    def __repr__(self):
        return repr(self.value)

    def __float__(self):
        return float(self.value)

    def __int__(self):
        return int(self.value)