astnodes.py 18.6 KB
Newer Older
1
import sympy as sp
2
from sympy.tensor import IndexedBase
3
from pystencils.field import Field
4
from pystencils.types import TypedSymbol, createType, get_type_from_sympy, createTypeFromString
5
6


Martin Bauer's avatar
Martin Bauer committed
7
8
class ResolvedFieldAccess(sp.Indexed):
    """Field access whose index has already been linearized into a single index expression.

    Behaves like a sympy ``Indexed`` on a one-dimensional ``IndexedBase`` and additionally
    carries the originating ``field``, the access ``offsets`` and ``idxCoordinateValues``.
    """

    def __new__(cls, base, linearizedIndex, field, offsets, idxCoordinateValues):
        # sympy.Indexed needs an IndexedBase: wrap anything else (e.g. a plain symbol)
        indexedBase = base if isinstance(base, IndexedBase) else IndexedBase(base, shape=(1,))
        access = super(ResolvedFieldAccess, cls).__new__(cls, indexedBase, linearizedIndex)
        access.field = field
        access.offsets = offsets
        access.idxCoordinateValues = idxCoordinateValues
        return access

    def _eval_subs(self, old, new):
        # only the linearized index takes part in substitution; base and metadata are carried over
        base, index = self.args[0], self.args[1]
        return ResolvedFieldAccess(base, index.subs(old, new),
                                   self.field, self.offsets, self.idxCoordinateValues)

    def _hashable_content(self):
        # sympy derives hashes / structural comparison from this tuple, so the field metadata is included
        metadata = tuple(self.offsets) + (repr(self.idxCoordinateValues), hash(self.field))
        return super(ResolvedFieldAccess, self)._hashable_content() + metadata

    @property
    def typedSymbol(self):
        """Label symbol of the base (expected to carry a ``dtype``, which ``__str__`` prints)."""
        return self.base.label

    def __str__(self):
        plain = super(ResolvedFieldAccess, self).__str__()
        return "%s (%s)" % (plain, self.typedSymbol.dtype)

    def __getnewargs__(self):
        # constructor arguments for pickling/copying: rebuild through __new__ with the same metadata
        return (self.base, self.indices[0], self.field, self.offsets, self.idxCoordinateValues)


class Node(object):
    """Base class for all AST nodes"""

    def __init__(self, parent=None):
        # enclosing node in the AST, or None for a detached / root node
        self.parent = parent

    @property
    def args(self):
        """Returns all arguments/children of this node.

        A property (not a method) so that it matches every subclass and the iteration in
        ``subs``/``atoms`` below. The base node has no children.
        """
        return []

    @property
    def symbolsDefined(self):
        """Set of symbols which are defined by this node. """
        return set()

    @property
    def undefinedSymbols(self):
        """Symbols which are used but are not defined inside this node"""
        raise NotImplementedError()

    def subs(self, *args, **kwargs):
        """Inplace! substitute, similar to sympys but modifies ast and returns None

        Delegates to ``subs`` of every child. Note: for children that are plain sympy expressions
        the value returned by their ``subs`` is discarded here, so only AST nodes change in place.
        """
        for a in self.args:
            a.subs(*args, **kwargs)

    def atoms(self, argType):
        """
        Returns a set of all children which are an instance of the given argType.

        The whole subtree below this node is searched recursively; the node itself is not included.
        """
        result = set()
        for arg in self.args:
            if isinstance(arg, argType):
                result.add(arg)
            result.update(arg.atoms(argType))
        return result


class Conditional(Node):
    """If/else node: ``trueBlock`` runs when ``conditionExpr`` holds, otherwise the optional ``falseBlock``."""

    def __init__(self, conditionExpr, trueBlock, falseBlock=None):
        """
        Create a new conditional node

        :param conditionExpr: sympy relational (or boolean) expression
        :param trueBlock: block which is run if conditional is true
        :param falseBlock: block which is run if conditional is false, or None if not needed
        """
        super(Conditional, self).__init__(parent=None)
        assert conditionExpr.is_Boolean or conditionExpr.is_Relational
        self.conditionExpr = conditionExpr
        self.trueBlock = trueBlock
        self.falseBlock = falseBlock
        # register as parent of the branches, as LoopOverCoordinate / KernelFunction do for their body
        trueBlock.parent = self
        if falseBlock:
            falseBlock.parent = self

    @property
    def args(self):
        result = [self.conditionExpr, self.trueBlock]
        if self.falseBlock:
            result.append(self.falseBlock)
        return result

    @property
    def symbolsDefined(self):
        return set()

    @property
    def undefinedSymbols(self):
        """Symbols used by either branch or by the condition that are not defined inside a branch."""
        # copy, so the set returned by the true branch is never mutated here
        result = set(self.trueBlock.undefinedSymbols)
        if self.falseBlock:
            # set.update works in place and returns None - do not rebind ``result`` to its return value
            result.update(self.falseBlock.undefinedSymbols)
        result.update(self.conditionExpr.atoms(sp.Symbol))
        return result

    def __str__(self):
        return 'if:({!s}) '.format(self.conditionExpr)

    def __repr__(self):
        return 'if:({!r}) '.format(self.conditionExpr)


class KernelFunction(Node):
    """Root node of a kernel: wraps the body block and derives the function parameters from it."""

    class Argument:
        """One kernel parameter, classified by the prefix of its name.

        Names starting with ``Field.DATA_PREFIX``, ``Field.SHAPE_PREFIX`` or ``Field.STRIDE_PREFIX``
        mark field pointer, shape and stride arguments; for those ``field`` is the accessed Field.
        """

        def __init__(self, name, dtype, symbol, kernelFunctionNode):
            from pystencils.transformations import symbolNameToVariableName
            self.name = name
            self.dtype = dtype
            self.isFieldPtrArgument = False
            self.isFieldShapeArgument = False
            self.isFieldStrideArgument = False
            self.isFieldArgument = False
            self.fieldName = ""
            self.coordinate = None
            self.symbol = symbol

            if name.startswith(Field.DATA_PREFIX):
                self.isFieldPtrArgument = True
                self.isFieldArgument = True
                self.fieldName = name[len(Field.DATA_PREFIX):]
            elif name.startswith(Field.SHAPE_PREFIX):
                self.isFieldShapeArgument = True
                self.isFieldArgument = True
                self.fieldName = name[len(Field.SHAPE_PREFIX):]
            elif name.startswith(Field.STRIDE_PREFIX):
                self.isFieldStrideArgument = True
                self.isFieldArgument = True
                self.fieldName = name[len(Field.STRIDE_PREFIX):]

            self.field = None
            if self.isFieldArgument:
                # look the Field up by its variable-name form; raises KeyError if the kernel does not access it
                fieldMap = {symbolNameToVariableName(f.name): f for f in kernelFunctionNode.fieldsAccessed}
                self.field = fieldMap[self.fieldName]

        def __lt__(self, other):
            """Order: field pointers, then shapes, then strides, then all others; ties broken by name."""
            def score(arg):
                if arg.isFieldPtrArgument:
                    return -4
                elif arg.isFieldShapeArgument:
                    return -3
                elif arg.isFieldStrideArgument:
                    return -2
                return 0

            return (score(self), self.name) < (score(other), other.name)

        def __repr__(self):
            return '<{0} {1}>'.format(self.dtype, self.name)

    def __init__(self, body, ghostLayers=None, functionName="kernel"):
        super(KernelFunction, self).__init__()
        self._body = body
        body.parent = self
        self._parameters = None
        self.functionName = functionName
        self.ghostLayers = ghostLayers
        # these variables are assumed to be global, so no automatic parameter is generated for them
        self.globalVariables = set()

    @property
    def symbolsDefined(self):
        return set()

    @property
    def undefinedSymbols(self):
        """Always empty: symbols the body needs are turned into parameters instead."""
        return set()

    @property
    def parameters(self):
        """Sorted list of ``Argument`` objects, recomputed from the body on every access."""
        self._updateParameters()
        return self._parameters

    @property
    def body(self):
        return self._body

    @property
    def args(self):
        return [self._body]

    @property
    def fieldsAccessed(self):
        """Set of Field instances: fields which are accessed inside this kernel function"""
        return set(o.field for o in self.atoms(ResolvedFieldAccess))

    def _updateParameters(self):
        # every symbol used but not defined by the body becomes a parameter, except declared globals
        undefinedSymbols = self._body.undefinedSymbols - self.globalVariables
        self._parameters = sorted(KernelFunction.Argument(s.name, s.dtype, s, self) for s in undefinedSymbols)

    def __str__(self):
        # ``self.parameters`` refreshes the argument list itself; an extra _updateParameters() would redo the work
        return '{0} {1}({2})\n{3}'.format(type(self).__name__, self.functionName, self.parameters,
                                          ("\t" + "\t".join(str(self.body).splitlines(True))))

    def __repr__(self):
        return '{0} {1}({2})'.format(type(self).__name__, self.functionName, self.parameters)


class Block(Node):
    """Ordered sequence of child nodes; every child gets this block as its ``parent``."""

    def __init__(self, listOfNodes):
        # super(Block, ...) - not super(Node, ...) - so that Node.__init__ runs and ``parent`` exists
        super(Block, self).__init__()
        self._nodes = listOfNodes
        for n in self._nodes:
            n.parent = self

    @property
    def args(self):
        return self._nodes

    def insertFront(self, node):
        """Insert ``node`` as the first child."""
        node.parent = self
        self._nodes.insert(0, node)

    def insertBefore(self, newNode, insertBefore):
        """Insert ``newNode`` directly before the existing child ``insertBefore``.

        :raises ValueError: if ``insertBefore`` is not a child of this block (``newNode`` is left untouched)
        """
        idx = self._nodes.index(insertBefore)
        newNode.parent = self
        self._nodes.insert(idx, newNode)

    def append(self, node):
        """Append ``node`` as the last child."""
        node.parent = self
        self._nodes.append(node)

    def takeChildNodes(self):
        """Remove and return all children (their ``parent`` attribute is left unchanged)."""
        tmp = self._nodes
        self._nodes = []
        return tmp

    def replace(self, child, replacements):
        """Replace ``child`` by ``replacements`` (a single node or a list of nodes) at the same position."""
        idx = self._nodes.index(child)
        del self._nodes[idx]
        if isinstance(replacements, list):
            for e in replacements:
                e.parent = self
            self._nodes = self._nodes[:idx] + replacements + self._nodes[idx:]
        else:
            replacements.parent = self
            self._nodes.insert(idx, replacements)

    @property
    def symbolsDefined(self):
        result = set()
        for a in self.args:
            result.update(a.symbolsDefined)
        return result

    @property
    def undefinedSymbols(self):
        """Symbols used by some child that no child of this block defines.

        Definition order is not considered: a symbol defined by any child counts as defined.
        """
        result = set()
        definedSymbols = set()
        for a in self.args:
            result.update(a.undefinedSymbols)
            definedSymbols.update(a.symbolsDefined)
        return result - definedSymbols

    def __str__(self):
        return ''.join('{!s}\n'.format(node) for node in self._nodes)

    def __repr__(self):
        return ''.join('{!r}'.format(node) for node in self._nodes)


class PragmaBlock(Block):
    """Block that additionally stores a ``pragmaLine`` string alongside its child nodes."""

    def __init__(self, pragmaLine, listOfNodes):
        self.pragmaLine = pragmaLine
        super(PragmaBlock, self).__init__(listOfNodes)


class LoopOverCoordinate(Node):
    """Loop over one coordinate with bounds ``start``, ``stop`` and ``step`` around ``body``.

    The loop counter is a TypedSymbol of type int named ``<LOOP_COUNTER_NAME_PREFIX>_<coordinate>``.
    """
    LOOP_COUNTER_NAME_PREFIX = "ctr"

    def __init__(self, body, coordinateToLoopOver, start, stop, step=1):
        super(LoopOverCoordinate, self).__init__(parent=None)
        self.body = body
        body.parent = self
        self.coordinateToLoopOver = coordinateToLoopOver
        self.start = start
        self.stop = stop
        self.step = step
        # extra lines attached to the loop (presumably emitted before it by the code printer - verify)
        self.prefixLines = []

    def newLoopWithDifferentBody(self, newBody):
        """Copy of this loop (same coordinate, bounds and prefix lines) with ``newBody`` as its body."""
        result = LoopOverCoordinate(newBody, self.coordinateToLoopOver, self.start, self.stop, self.step)
        result.prefixLines = list(self.prefixLines)
        return result

    def subs(self, *args, **kwargs):
        """In-place substitution in the body and in every loop bound that supports ``subs``."""
        self.body.subs(*args, **kwargs)
        if hasattr(self.start, "subs"):
            self.start = self.start.subs(*args, **kwargs)
        if hasattr(self.stop, "subs"):
            self.stop = self.stop.subs(*args, **kwargs)
        if hasattr(self.step, "subs"):
            self.step = self.step.subs(*args, **kwargs)

    @property
    def args(self):
        result = [self.body]
        # plain Python numbers have no ``args`` and are not reported as children
        for e in [self.start, self.stop, self.step]:
            if hasattr(e, "args"):
                result.append(e)
        return result

    @property
    def symbolsDefined(self):
        return set([self.loopCounterSymbol])

    @property
    def undefinedSymbols(self):
        """Symbols used in the body or in the bounds, minus this loop's own counter."""
        result = set(self.body.undefinedSymbols)
        for possibleSymbol in [self.start, self.stop, self.step]:
            if isinstance(possibleSymbol, (Node, sp.Basic)):
                result.update(possibleSymbol.atoms(sp.Symbol))
        return result - set([self.loopCounterSymbol])

    @staticmethod
    def getLoopCounterName(coordinateToLoopOver):
        return "%s_%s" % (LoopOverCoordinate.LOOP_COUNTER_NAME_PREFIX, coordinateToLoopOver)

    @property
    def loopCounterName(self):
        return LoopOverCoordinate.getLoopCounterName(self.coordinateToLoopOver)

    @staticmethod
    def isLoopCounterSymbol(symbol):
        """Return the coordinate of a loop counter symbol, or None if ``symbol`` is not a loop counter.

        A loop counter is named as produced by ``getLoopCounterName`` (``ctr_<int>``) and has dtype int.
        Symbols without a ``dtype`` or with a non-integer suffix are reported as None instead of raising.
        """
        prefix = LoopOverCoordinate.LOOP_COUNTER_NAME_PREFIX + "_"
        if not symbol.name.startswith(prefix):
            return None
        dtype = getattr(symbol, 'dtype', None)
        if dtype is None or dtype != createTypeFromString('int'):
            return None
        try:
            return int(symbol.name[len(prefix):])
        except ValueError:
            return None

    @staticmethod
    def getLoopCounterSymbol(coordinateToLoopOver):
        return TypedSymbol(LoopOverCoordinate.getLoopCounterName(coordinateToLoopOver), 'int')

    @property
    def loopCounterSymbol(self):
        return LoopOverCoordinate.getLoopCounterSymbol(self.coordinateToLoopOver)

    @property
    def isOutermostLoop(self):
        from pystencils.transformations import getNextParentOfType
        return getNextParentOfType(self, LoopOverCoordinate) is None

    @property
    def isInnermostLoop(self):
        return len(self.atoms(LoopOverCoordinate)) == 0

    def __str__(self):
        return 'loop:{!s} in {!s}:{!s}:{!s}\n{!s}'.format(self.loopCounterName, self.start, self.stop, self.step,
                                                          ("\t" + "\t".join(str(self.body).splitlines(True))))

    def __repr__(self):
        return 'loop:{!s} in {!s}:{!s}:{!s}'.format(self.loopCounterName, self.start, self.stop, self.step)


class SympyAssignment(Node):
    """Assignment ``lhs = rhs`` of sympy expressions.

    The assignment is a declaration of ``lhs`` unless ``lhs`` is a field access, an indexed
    expression or a cast; only declarations report ``lhs`` in ``symbolsDefined``.
    """

    def __init__(self, lhsSymbol, rhsTerm, isConst=True):
        super(SympyAssignment, self).__init__(parent=None)
        self._lhsSymbol = lhsSymbol
        self.rhs = rhsTerm
        self._isDeclaration = SympyAssignment._isDeclarationTarget(lhsSymbol)
        self._isConst = isConst

    @staticmethod
    def _isDeclarationTarget(lhs):
        """True if assigning to ``lhs`` declares it (e.g. a plain symbol), False for field/indexed/cast targets.

        Shared by the constructor and the ``lhs`` setter so both classify a target identically.
        """
        isCast = str(lhs.func).lower() == 'cast' if hasattr(lhs, "func") else False
        return not (isCast or isinstance(lhs, (Field.Access, IndexedBase, sp.Indexed)))

    @property
    def lhs(self):
        return self._lhsSymbol

    @lhs.setter
    def lhs(self, newValue):
        self._lhsSymbol = newValue
        self._isDeclaration = SympyAssignment._isDeclarationTarget(newValue)

    def subs(self, *args, **kwargs):
        self.lhs = self.lhs.subs(*args, **kwargs)
        self.rhs = self.rhs.subs(*args, **kwargs)

    @property
    def args(self):
        return [self._lhsSymbol, self.rhs]

    @property
    def symbolsDefined(self):
        if not self._isDeclaration:
            return set()
        return set([self._lhsSymbol])

    @property
    def undefinedSymbols(self):
        """Symbols of rhs and lhs, plus one loop counter per offset dimension of each field access in rhs."""
        result = self.rhs.atoms(sp.Symbol)
        # collected separately: ``result`` must not grow while it is being iterated
        loopCounters = set()
        for symbol in result:
            if isinstance(symbol, Field.Access):
                loopCounters.update(LoopOverCoordinate.getLoopCounterSymbol(i) for i in range(len(symbol.offsets)))
        result.update(loopCounters)
        result.update(self._lhsSymbol.atoms(sp.Symbol))
        return result

    @property
    def isDeclaration(self):
        return self._isDeclaration

    @property
    def isConst(self):
        return self._isConst

    def replace(self, child, replacement):
        """Replace ``child`` (either lhs or rhs) by ``replacement``.

        :raises ValueError: if ``child`` is neither the lhs nor the rhs of this assignment
        """
        if child == self.lhs:
            replacement.parent = self
            self.lhs = replacement
        elif child == self.rhs:
            replacement.parent = self
            self.rhs = replacement
        else:
            raise ValueError('%s is not in args of %s' % (child, self.__class__))

    def __repr__(self):
        return repr(self.lhs) + " = " + repr(self.rhs)


class TemporaryMemoryAllocation(Node):
    """Node for allocating temporary memory for ``typedSymbol`` with the given ``size``.

    Defines ``typedSymbol``; uses the symbols appearing in ``size`` when it is a sympy expression.
    """

    def __init__(self, typedSymbol, size):
        # initialise ``parent`` like every other node
        super(TemporaryMemoryAllocation, self).__init__(parent=None)
        self.symbol = typedSymbol
        self.size = size

    @property
    def symbolsDefined(self):
        return set([self.symbol])

    @property
    def undefinedSymbols(self):
        if isinstance(self.size, sp.Basic):
            return self.size.atoms(sp.Symbol)
        else:
            return set()

    @property
    def args(self):
        return [self.symbol]


class TemporaryMemoryFree(Node):
    """Node for freeing the temporary memory of ``typedSymbol``; defines and uses no symbols, has no children."""

    def __init__(self, typedSymbol):
        # initialise ``parent`` like every other node
        super(TemporaryMemoryFree, self).__init__(parent=None)
        self.symbol = typedSymbol

    @property
    def symbolsDefined(self):
        return set()

    @property
    def undefinedSymbols(self):
        return set()

    @property
    def args(self):
        return []


# TODO implement defined & undefinedSymbols
class Conversion(Node):
    """Conversion of a single child node to ``dtype``."""

    def __init__(self, child, dtype, parent=None):
        super(Conversion, self).__init__(parent)
        self._args = [child]
        self.dtype = dtype

    @property
    def args(self):
        """Returns all arguments/children of this node"""
        return self._args

    @args.setter
    def args(self, value):
        self._args = value

    @property
    def symbolsDefined(self):
        """Set of symbols which are defined by this node. """
        return set()

    @property
    def undefinedSymbols(self):
        """Symbols which are used but are not defined inside this node.

        Placeholder until the TODO above is done: always empty, consistent with ``Expr.undefinedSymbols``.
        """
        return set()

    def __repr__(self):
        return '(%s(%s))' % (repr(self.dtype), repr(self.args[0].dtype)) + repr(self.args)


# TODO Pow
Jan Hoenig's avatar
Jan Hoenig committed
520

Jan Hoenig's avatar
Jan Hoenig committed
521
522
523
524

_expr_dict = {'Add': ' + ', 'Mul': ' * ', 'Pow': '**'}


Jan Hoenig's avatar
Jan Hoenig committed
525
526
527
528
class Expr(Node):
    """Generic n-ary expression node holding a mutable list of ``args``.

    ``__repr__`` joins the arguments with the separator registered in ``_expr_dict`` under the
    class name; subclasses not listed there must override ``__repr__``.
    """

    def __init__(self, args, parent=None):
        super(Expr, self).__init__(parent)
        self._args = list(args)
        self.dtype = None

    @property
    def args(self):
        return self._args

    @args.setter
    def args(self, value):
        self._args = value

    def replace(self, child, replacements):
        """Replace ``child`` by ``replacements`` (a single node or a list of nodes) at the same position."""
        position = self.args.index(child)
        del self.args[position]
        if type(replacements) is not list:
            replacements.parent = self
            self.args.insert(position, replacements)
            return
        for node in replacements:
            node.parent = self
        self.args = self.args[:position] + replacements + self.args[position:]

    @property
    def symbolsDefined(self):
        return set()  # Todo fix for symbol analysis

    @property
    def undefinedSymbols(self):
        return set()  # Todo fix for symbol analysis

    def __repr__(self):
        separator = _expr_dict[self.__class__.__name__]
        return separator.join(repr(arg) for arg in self.args)


class Mul(Expr):
    """Multiplication node; its ``args`` are printed joined by ``_expr_dict['Mul']``."""


class Add(Expr):
    """Addition node; its ``args`` are printed joined by ``_expr_dict['Add']``."""


class Pow(Expr):
    """Power node; its ``args`` are printed joined by ``_expr_dict['Pow']``."""


class Indexed(Expr):
    """Indexed access, printed as ``args[0][args[1]]``; ``dtype`` is the element type behind ``base.label``."""

    def __init__(self, args, base, parent=None):
        super(Indexed, self).__init__(args, parent)
        self.base = base
        # the label's dtype is expected to be a pointer type: strip one level of indirection
        pointerType = base.label.dtype
        self.dtype = createType(pointerType.baseType)

    def __repr__(self):
        return '%s[%s]' % (self.args[0], self.args[1])


class PointerArithmetic(Expr):
    """Offset ``args`` applied to ``pointer``; printed as ``*(pointer + offset)``.

    ``args`` of the node are ``[offset, pointer]``; ``dtype`` is the pointer's dtype.
    """

    def __init__(self, args, pointer, parent=None):
        super(PointerArithmetic, self).__init__([args, pointer], parent)
        self.pointer = pointer
        self.offset = args
        self.dtype = pointer.dtype

    def __repr__(self):
        # print the offset itself, not the whole ``args`` list (which also contains the pointer)
        return '*(%s + %s)' % (self.pointer, self.offset)


class Number(Node, sp.AtomicExpr):
    """Numeric literal leaf node; ``dtype`` and ``value`` are obtained from ``get_type_from_sympy(number)``."""

    def __init__(self, number, parent=None):
        super(Number, self).__init__(parent)
        self.dtype, self.value = get_type_from_sympy(number)
        self._args = tuple()

    @property
    def args(self):
        """Returns all arguments/children of this node (an empty tuple for a literal)"""
        return self._args

    @property
    def symbolsDefined(self):
        """Set of symbols which are defined by this node. """
        return set()

    @property
    def undefinedSymbols(self):
        """Symbols which are used but are not defined inside this node (none for a literal)"""
        return set()

    def __repr__(self):
        return repr(self.value)

    def __float__(self):
        return float(self.value)

    def __int__(self):
        return int(self.value)
Jan Hoenig's avatar
Jan Hoenig committed
626

Jan Hoenig's avatar
Jan Hoenig committed
627