transformations.py 27.7 KB
Newer Older
1
from collections import defaultdict, OrderedDict
Jan Hoenig's avatar
Jan Hoenig committed
2
3
from operator import attrgetter

4
5
6
import sympy as sp
from sympy.logic.boolalg import Boolean
from sympy.tensor import IndexedBase
Martin Bauer's avatar
Martin Bauer committed
7

8
from pystencils.field import Field, offsetComponentToDirectionString
9
from pystencils.types import TypedSymbol, createType, PointerType, StructType, getBaseType
Martin Bauer's avatar
Martin Bauer committed
10
from pystencils.slicing import normalizeSlice
Martin Bauer's avatar
Martin Bauer committed
11
import pystencils.astnodes as ast
12
13


14
15
16
17
18
19
def fastSubs(term, subsDict):
    """Substitute sub-expressions of ``term`` according to ``subsDict``.

    Similar to the sympy ``subs`` function. This version is much faster for big substitution
    dictionaries than the sympy version: the expression tree is walked once and every node costs
    a single dictionary lookup.

    :param term: expression to transform; objects without an ``args`` attribute are treated as leaves
    :param subsDict: mapping of (sub-)expression to its replacement
    :return: the rebuilt expression; replacement values are inserted as-is and not visited again
    """
    def visit(expr):
        if expr in subsDict:
            return subsDict[expr]
        if not hasattr(expr, 'args'):
            return expr
        paramList = [visit(a) for a in expr.args]
        # nodes without arguments are leaves and are returned unchanged instead of being rebuilt
        return expr.func(*paramList) if paramList else expr
    return visit(term)


27
def makeLoopOverDomain(body, functionName, iterationSlice=None, ghostLayers=None, loopOrder=None):
    """
    Uses :class:`pystencils.field.Field.Access` to create (multiple) loops around given AST.

    :param body: AST node holding the loop body; it has to provide ``atoms`` and ``insertFront``
                 (presumably an ``ast.Block`` - confirm against callers)
    :param functionName: name of generated C function
    :param iterationSlice: if not None, iteration is done only over this slice of the field
    :param ghostLayers: a sequence of pairs for each coordinate with lower and upper nr of ghost layers
                if None, the number of ghost layers is determined automatically and assumed to be equal for
                all dimensions
    :param loopOrder: loop ordering from outer to inner loop (optimal ordering is same as layout)
    :return: :class:`ast.KernelFunction` wrapping the nested loops
    :raises ValueError: if fixed-shaped and variable-shaped fields are mixed, or if fixed-shaped fields
                        have different spatial shapes
    """
    # find correct ordering by inspecting participating FieldAccesses
    fieldAccesses = body.atoms(Field.Access)
    fields = set(e.field for e in fieldAccesses)

    if loopOrder is None:
        loopOrder = getOptimalLoopOrdering(fields)

    nrOfFixedShapedFields = sum(1 for f in fields if f.hasFixedShape)

    if nrOfFixedShapedFields > 0 and nrOfFixedShapedFields != len(fields):
        fixedFieldNames = ",".join([f.name for f in fields if f.hasFixedShape])
        varFieldNames = ",".join([f.name for f in fields if not f.hasFixedShape])
        msg = "Mixing fixed-shaped and variable-shape fields in a single kernel is not possible\n"
        msg += "Variable shaped: %s \nFixed shaped:    %s" % (varFieldNames, fixedFieldNames)
        raise ValueError(msg)

    shapeSet = set([f.spatialShape for f in fields])
    if nrOfFixedShapedFields == len(fields):
        if len(shapeSet) != 1:
            raise ValueError("Differently sized field accesses in loop body: " + str(shapeSet))

    # several (symbolic) shapes may be left here: pick one in a reproducible way
    shape = sorted(shapeSet, key=lambda e: str(e[0]))[0]

    if iterationSlice is not None:
        iterationSlice = normalizeSlice(iterationSlice, shape)

    if ghostLayers is None:
        requiredGhostLayers = max([fa.requiredGhostLayers for fa in fieldAccesses])
        ghostLayers = [(requiredGhostLayers, requiredGhostLayers)] * len(loopOrder)

    # loops are built from the innermost one outwards, each new loop wraps the body built so far
    currentBody = body
    for loopCoordinate in reversed(loopOrder):
        if iterationSlice is None:
            begin = ghostLayers[loopCoordinate][0]
            end = shape[loopCoordinate] - ghostLayers[loopCoordinate][1]
            newLoop = ast.LoopOverCoordinate(currentBody, loopCoordinate, begin, end, 1)
            currentBody = ast.Block([newLoop])
        else:
            sliceComponent = iterationSlice[loopCoordinate]
            if isinstance(sliceComponent, slice):
                sc = sliceComponent
                newLoop = ast.LoopOverCoordinate(currentBody, loopCoordinate, sc.start, sc.stop, sc.step)
                currentBody = ast.Block([newLoop])
            else:
                # a single value instead of a slice: no loop, the loop counter is assigned that value
                assignment = ast.SympyAssignment(ast.LoopOverCoordinate.getLoopCounterSymbol(loopCoordinate),
                                                 sp.sympify(sliceComponent))
                currentBody.insertFront(assignment)
    return ast.KernelFunction(currentBody, fields, functionName)
95
96
97


def createIntermediateBasePointer(fieldAccess, coordinates, previousPtr):
    r"""
    Addressing elements in structured arrays is done with :math:`ptr\left[ \sum_i c_i \cdot s_i \right]`
    where :math:`c_i` is the coordinate value and :math:`s_i` the stride of a coordinate.
    The sum can be split up into multiple parts, such that parts of it can be pulled before loops.
    This function creates such an access for coordinates :math:`i \in \mbox{coordinates}`.
    Returns a new typed symbol, where the name encodes which coordinates have been resolved.

    :param fieldAccess: instance of :class:`pystencils.field.Field.Access` which provides strides and offsets
    :param coordinates: mapping of coordinate ids to its value, where stride*value is calculated
    :param previousPtr: the pointer which is dereferenced
    :return: tuple with the new pointer symbol and the calculated offset

    Example:
        >>> field = Field.createGeneric('myfield', spatialDimensions=2, indexDimensions=1)
        >>> x, y = sp.symbols("x y")
        >>> prevPointer = TypedSymbol("ptr", "double")
        >>> createIntermediateBasePointer(field[1,-2](5), {0: x}, prevPointer)
        (ptr_E, x*fstride_myfield[0] + fstride_myfield[0])
        >>> createIntermediateBasePointer(field[1,-2](5), {0: x, 1 : y }, prevPointer)
        (ptr_E_2S, x*fstride_myfield[0] + y*fstride_myfield[1] + fstride_myfield[0] - 2*fstride_myfield[1])
    """
    field = fieldAccess.field

    offset = 0
    name = ""
    # values that cannot be spelled out in the pointer name; they enter the name as a hash
    listToHash = []
    for coordinateId, coordinateValue in coordinates.items():
        stride = field.strides[coordinateId]
        offset += stride * coordinateValue

        if coordinateId < field.spatialDimensions:
            # spatial coordinate: the relative offset of the access is added as well
            accessOffset = fieldAccess.offsets[coordinateId]
            offset += stride * accessOffset
            # exact type check: only plain Python ints are written into the name
            if type(accessOffset) is int:
                offsetComp = offsetComponentToDirectionString(coordinateId, accessOffset)
                name += "_"
                name += offsetComp if offsetComp else "C"
            else:
                listToHash.append(accessOffset)
        else:
            # index coordinate: the value itself identifies the pointer
            if type(coordinateValue) is int:
                name += "_%d" % (coordinateValue,)
            else:
                listToHash.append(coordinateValue)

    if listToHash:
        name += "%0.6X" % (abs(hash(tuple(listToHash))))

    newPtr = TypedSymbol(previousPtr.name + name, previousPtr.dtype)
    return newPtr, offset


def parseBasePointerInfo(basePointerSpecification, loopOrder, field):
    """
    Creates base pointer specification for :func:`resolveFieldAccesses` function.

    Specification of how many and which intermediate pointers are created for a field access.
    For example [ (0), (2,3,)]  creates one base pointer for coordinates 2 and 3 and writes the offset for
    coordinate zero directly in the field access. These specifications are more sensibly defined dependent on
    the loop ordering. This function translates the more readable version into the specification above.

    Allowed specifications:
        - "spatialInner<int>" spatialInner0 is the innermost loop coordinate,
          spatialInner1 the loop enclosing the innermost
        - "spatialOuter<int>" spatialOuter0 is the outermost loop
        - "spatialall": all spatial coordinates of the field
        - "index<int>": index coordinate
        - "<int>": specifying directly the coordinate

    :param basePointerSpecification: nested list with above specifications
    :param loopOrder: list with ordering of loops from outer to inner
    :param field: object providing ``spatialDimensions`` and ``indexDimensions``
    :return: list of coordinate groups that can be passed to :func:`resolveFieldAccesses`; coordinates which
             are not mentioned in the specification are appended as a last group in ascending order
    :raises ValueError: for unknown specifications, non-existing coordinates or coordinates specified twice
    """
    nrOfCoordinates = field.spatialDimensions + field.indexDimensions
    outerToInner = list(loopOrder)
    innerToOuter = outerToInner[::-1]
    specifiedCoordinates = set()
    result = []

    def addNewElement(group, i):
        if i >= nrOfCoordinates:
            raise ValueError("Coordinate %d does not exist" % (i,))
        if i in specifiedCoordinates:
            raise ValueError("Coordinate %d specified two times" % (i,))
        group.append(i)
        specifiedCoordinates.add(i)

    for specGroup in basePointerSpecification:
        newGroup = []
        for element in specGroup:
            if isinstance(element, int):
                addNewElement(newGroup, element)
            elif element.startswith("spatial"):
                element = element[len("spatial"):]
                if element.startswith("Inner"):
                    index = int(element[len("Inner"):])
                    addNewElement(newGroup, innerToOuter[index])
                elif element.startswith("Outer"):
                    index = int(element[len("Outer"):])
                    # counted from the outermost loop; indexing the reversed list with -index
                    # would map "spatialOuter0" to the innermost loop because -0 == 0
                    addNewElement(newGroup, outerToInner[index])
                elif element == "all":
                    for i in range(field.spatialDimensions):
                        addNewElement(newGroup, i)
                else:
                    raise ValueError("Could not parse " + element)
            elif element.startswith("index"):
                index = int(element[len("index"):])
                addNewElement(newGroup, field.spatialDimensions + index)
            else:
                raise ValueError("Unknown specification %s" % (element,))

        result.append(newGroup)

    rest = set(range(nrOfCoordinates)) - specifiedCoordinates
    if rest:
        # sorted to make the generated pointers independent of set iteration order
        result.append(sorted(rest))
    return result


213
def resolveFieldAccesses(astNode, readOnlyFieldNames=None, fieldToBasePointerInfo=None, fieldToFixedCoordinates=None):
    """
    Substitutes :class:`pystencils.field.Field.Access` nodes by array indexing

    The AST is modified in place: left and right hand side of every ``SympyAssignment`` are rewritten and
    assignments for intermediate base pointers are inserted before the assignment that needs them.

    :param astNode: the AST root
    :param readOnlyFieldNames: set of field names which are considered read-only (default: no field)
    :param fieldToBasePointerInfo: map of field name to a list of coordinate groups indicating which
                                   intermediate base pointers should be created,
                                   for details see :func:`parseBasePointerInfo`
    :param fieldToFixedCoordinates: map of field name to a tuple of coordinate symbols. Instead of using the loop
                                    counters to index the field these symbols are used as coordinates
    :return: None
    """
    # None defaults instead of mutable default arguments shared between calls
    if readOnlyFieldNames is None:
        readOnlyFieldNames = set()
    if fieldToBasePointerInfo is None:
        fieldToBasePointerInfo = {}
    if fieldToFixedCoordinates is None:
        fieldToFixedCoordinates = {}

    fieldToBasePointerInfo = OrderedDict(sorted(fieldToBasePointerInfo.items(), key=lambda pair: pair[0]))
    fieldToFixedCoordinates = OrderedDict(sorted(fieldToFixedCoordinates.items(), key=lambda pair: pair[0]))

    def visitSympyExpr(expr, enclosingBlock, sympyAssignment):
        if isinstance(expr, Field.Access):
            fieldAccess = expr
            field = fieldAccess.field
            if field.name in fieldToBasePointerInfo:
                basePointerInfo = fieldToBasePointerInfo[field.name]
            else:
                # default: a single group with all coordinates, i.e. no intermediate pointers
                basePointerInfo = [list(range(field.indexDimensions + field.spatialDimensions))]

            dtype = PointerType(field.dtype, const=field.name in readOnlyFieldNames, restrict=True)
            fieldPtr = TypedSymbol("%s%s" % (Field.DATA_PREFIX, symbolNameToVariableName(field.name)), dtype)

            lastPointer = fieldPtr

            def createCoordinateDict(group):
                coordDict = {}
                for e in group:
                    if e < field.spatialDimensions:
                        if field.name in fieldToFixedCoordinates:
                            coordDict[e] = fieldToFixedCoordinates[field.name][e]
                        else:
                            ctrName = ast.LoopOverCoordinate.LOOP_COUNTER_NAME_PREFIX
                            coordDict[e] = TypedSymbol("%s_%d" % (ctrName, e), 'int')
                        coordDict[e] *= field.dtype.itemSize
                    else:
                        if isinstance(field.dtype, StructType):
                            assert field.indexDimensions == 1
                            accessedFieldName = fieldAccess.index[0]
                            assert isinstance(accessedFieldName, str)
                            coordDict[e] = field.dtype.getElementOffset(accessedFieldName)
                        else:
                            coordDict[e] = fieldAccess.index[e - field.spatialDimensions]
                return coordDict

            # all groups but the first become intermediate pointers, defined once per enclosing block
            for group in reversed(basePointerInfo[1:]):
                coordDict = createCoordinateDict(group)
                newPtr, offset = createIntermediateBasePointer(fieldAccess, coordDict, lastPointer)
                if newPtr not in enclosingBlock.symbolsDefined:
                    newAssignment = ast.SympyAssignment(newPtr, lastPointer + offset, isConst=False)
                    enclosingBlock.insertBefore(newAssignment, sympyAssignment)
                lastPointer = newPtr

            # the first group is resolved directly in the array index of the access
            coordDict = createCoordinateDict(basePointerInfo[0])
            _, offset = createIntermediateBasePointer(fieldAccess, coordDict, lastPointer)
            baseArr = IndexedBase(lastPointer, shape=(1,))
            result = baseArr[offset]
            castFunc = sp.Function("cast")
            if isinstance(getBaseType(fieldAccess.field.dtype), StructType):
                newType = fieldAccess.field.dtype.getElementType(fieldAccess.index[0])
                result = castFunc(result, newType)

            return visitSympyExpr(result, enclosingBlock, sympyAssignment)
        else:
            newArgs = [visitSympyExpr(e, enclosingBlock, sympyAssignment) for e in expr.args]
            # Add, Mul and Piecewise are rebuilt without evaluation
            kwargs = {'evaluate': False} if type(expr) in (sp.Add, sp.Mul, sp.Piecewise) else {}
            return expr.func(*newArgs, **kwargs) if newArgs else expr

    def visitNode(subAst):
        if isinstance(subAst, ast.SympyAssignment):
            enclosingBlock = subAst.parent
            assert type(enclosingBlock) is ast.Block
            subAst.lhs = visitSympyExpr(subAst.lhs, enclosingBlock, subAst)
            subAst.rhs = visitSympyExpr(subAst.rhs, enclosingBlock, subAst)
        else:
            for a in subAst.args:
                visitNode(a)

    visitNode(astNode)


def moveConstantsBeforeLoop(astNode):
    """
    Moves :class:`pystencils.ast.SympyAssignment` nodes out of loop body if they are iteration independent.
    Call this after creating the loop structure with :func:`makeLoopOverDomain`

    The AST is modified in place. If the target block already contains an assignment to the same left hand
    side with an equal right hand side, the moved assignment is dropped instead of being inserted twice.

    :param astNode: AST root; all blocks found via ``astNode.atoms(ast.Block)`` are processed
    :return: None
    """
    def findBlockToMoveTo(node):
        """
        Traverses parents of node as long as the symbols are independent and returns a (parent) block
        the assignment can be safely moved to

        :param node: SympyAssignment inside a Block
        :return blockToInsertTo, childOfBlockToInsertBefore
        """
        assert isinstance(node, ast.SympyAssignment)
        assert isinstance(node.parent, ast.Block)

        lastBlock = node.parent
        lastBlockChild = node
        element = node.parent
        prevElement = node
        while element is not None:
            if isinstance(element, ast.Block):
                lastBlock = element
                lastBlockChild = prevElement
            # stop at the first parent that defines a symbol the assignment depends on
            if node.undefinedSymbols.intersection(element.symbolsDefined):
                break
            prevElement = element
            element = element.parent
        return lastBlock, lastBlockChild

    def checkIfAssignmentAlreadyInBlock(assignment, targetBlock):
        for arg in targetBlock.args:
            if type(arg) is not ast.SympyAssignment:
                continue
            if arg.lhs == assignment.lhs:
                return arg
        return None

    for block in astNode.atoms(ast.Block):
        # the block is emptied first and refilled with every child that cannot be moved
        children = block.takeChildNodes()
        for child in children:
            if not isinstance(child, ast.SympyAssignment):
                block.append(child)
                continue

            target, childToInsertBefore = findBlockToMoveTo(child)
            if target == block:     # movement not possible
                target.append(child)
                continue

            existingAssignment = checkIfAssignmentAlreadyInBlock(child, target)
            if existingAssignment is None:
                target.insertBefore(child, childToInsertBefore)
            else:
                assert existingAssignment.rhs == child.rhs, "Symbol with same name exists already"


def splitInnerLoop(astNode, symbolGroups):
    """
    Splits inner loop into multiple loops to minimize the amount of simultaneous load/store streams

    The AST is modified in place: the innermost loop is replaced by a block with one loop per symbol group,
    and allocation/free nodes for the temporary arrays are added around the outermost loop.

    :param astNode: AST root
    :param symbolGroups: sequence of symbol sequences: for each symbol sequence a new inner loop is created which
         updates these symbols and their dependent symbols. Symbols which are in none of the symbolGroups and which
         no symbol in a symbol group depends on, are not updated!
    :return: None
    """
    allLoops = astNode.atoms(ast.LoopOverCoordinate)
    innerLoop = [loop for loop in allLoops if loop.isInnermostLoop]
    assert len(innerLoop) == 1, "Error in AST: multiple innermost loops. Was split transformation already called?"
    innerLoop = innerLoop[0]
    assert type(innerLoop.body) is ast.Block
    outerLoop = [loop for loop in allLoops if loop.isOutermostLoop]
    assert len(outerLoop) == 1, "Error in AST, multiple outermost loops."
    outerLoop = outerLoop[0]

    # symbols of earlier groups; later groups read them from a temporary array indexed by the loop counter
    symbolsWithTemporaryArray = OrderedDict()
    assignmentMap = OrderedDict((a.lhs, a) for a in innerLoop.body.args)

    assignmentGroups = []
    for symbolGroup in symbolGroups:
        # get all dependent symbols
        symbolsToProcess = list(symbolGroup)
        symbolsResolved = set()
        while symbolsToProcess:
            s = symbolsToProcess.pop()
            if s in symbolsResolved:
                continue

            if s in assignmentMap:  # if there is no assignment inside the loop body it is independent already
                for newSymbol in assignmentMap[s].rhs.atoms(sp.Symbol):
                    if type(newSymbol) is not Field.Access and newSymbol not in symbolsWithTemporaryArray:
                        symbolsToProcess.append(newSymbol)
            symbolsResolved.add(s)

        for symbol in symbolGroup:
            if type(symbol) is not Field.Access:
                assert type(symbol) is TypedSymbol
                symbolsWithTemporaryArray[symbol] = IndexedBase(symbol, shape=(1,))[innerLoop.loopCounterSymbol]

        assignmentGroup = []
        for assignment in innerLoop.body.args:
            if assignment.lhs in symbolsResolved:
                newRhs = assignment.rhs.subs(symbolsWithTemporaryArray.items())
                if type(assignment.lhs) is not Field.Access and assignment.lhs in symbolGroup:
                    newLhs = IndexedBase(assignment.lhs, shape=(1,))[innerLoop.loopCounterSymbol]
                else:
                    newLhs = assignment.lhs
                assignmentGroup.append(ast.SympyAssignment(newLhs, newRhs))
        assignmentGroups.append(assignmentGroup)

    newLoops = [innerLoop.newLoopWithDifferentBody(ast.Block(group)) for group in assignmentGroups]
    innerLoop.parent.replace(innerLoop, ast.Block(newLoops))

    for tmpArray in symbolsWithTemporaryArray:
        outerLoop.parent.insertFront(ast.TemporaryMemoryAllocation(tmpArray, innerLoop.stop))
        outerLoop.parent.append(ast.TemporaryMemoryFree(tmpArray))


417
418
419
420
421
def symbolNameToVariableName(symbolName):
    """Return ``symbolName`` with every ``^`` replaced by ``_``.

    ``^`` is allowed in sympy symbol names but not in C/C++ variable names.
    """
    caret, underscore = "^", "_"
    return underscore.join(symbolName.split(caret))


422
def typeAllEquations(eqs, typeForSymbol):
    """
    Traverses AST and replaces every :class:`sympy.Symbol` by a :class:`pystencils.typedsymbol.TypedSymbol`.
    Additionally returns sets of all fields which are read/written

    :param eqs: list of equations; entries may be sympy equations or ast nodes, the latter are passed through
    :param typeForSymbol: dict mapping symbol names to types. Types are strings of C types like 'int' or 'double'.
                          If empty/None or one of the strings 'double' / 'float', the mapping is created by
                          :func:`typingFromSympyInspection` with that default type
    :return: ``fieldsRead, fieldsWritten, typedEquations`` set of read fields, set of written fields, list of equations
               where symbols have been replaced by typed symbols
    """
    if not typeForSymbol or typeForSymbol == 'double':
        typeForSymbol = typingFromSympyInspection(eqs, "double")
    elif typeForSymbol == 'float':
        typeForSymbol = typingFromSympyInspection(eqs, "float")

    fieldsWritten = set()
    fieldsRead = set()

    def typedSymbolFor(symbol):
        # the type is looked up with the original name, the typed symbol gets the C compatible name;
        # used for both sides so that a symbol has the same name wherever it occurs
        return TypedSymbol(symbolNameToVariableName(symbol.name), typeForSymbol[symbol.name])

    def processRhs(term):
        """Replaces Symbols by:
            - TypedSymbol if symbol is not a field access
        """
        if isinstance(term, Field.Access):
            fieldsRead.add(term.field)
            return term
        elif isinstance(term, TypedSymbol):
            return term
        elif isinstance(term, sp.Symbol):
            return typedSymbolFor(term)
        else:
            newArgs = [processRhs(arg) for arg in term.args]
            return term.func(*newArgs) if newArgs else term

    def processLhs(term):
        """Replaces symbol by TypedSymbol and adds field to fieldsWritten"""
        if isinstance(term, Field.Access):
            fieldsWritten.add(term.field)
            return term
        elif isinstance(term, TypedSymbol):
            return term
        elif isinstance(term, sp.Symbol):
            return typedSymbolFor(term)
        else:
            assert False, "Expected a symbol as left-hand-side"

    typedEquations = []
    for eq in eqs:
        if isinstance(eq, sp.Eq):
            newLhs = processLhs(eq.lhs)
            newRhs = processRhs(eq.rhs)
            typedEquations.append(ast.SympyAssignment(newLhs, newRhs))
        else:
            assert isinstance(eq, ast.Node), "Only equations and ast nodes are allowed in input"
            typedEquations.append(eq)

    return fieldsRead, fieldsWritten, typedEquations


Martin Bauer's avatar
Martin Bauer committed
482
483
484
# --------------------------------------- Helper Functions -------------------------------------------------------------


485
def typingFromSympyInspection(eqs, defaultType="double"):
    """
    Creates a default symbol name to type mapping.
    If a sympy Boolean is assigned to a symbol it is assumed to be 'bool', otherwise the default type
    (usually 'double') is used.

    :param eqs: list of equations; ast nodes in the list are skipped
    :param defaultType: the type for non-boolean symbols
    :return: dictionary, mapping symbol name to type; unknown names yield ``defaultType``
    """
    result = defaultdict(lambda: defaultType)
    for eq in eqs:
        if isinstance(eq, ast.Node):
            continue
        # problematic case here is when rhs is a symbol: then it is impossible to decide here without
        # further information what type the left hand side is - default fallback is the dict value then
        if isinstance(eq.rhs, Boolean) and not isinstance(eq.rhs, sp.Symbol):
            result[eq.lhs.name] = "bool"
    return result


def getNextParentOfType(node, parentType):
    """
    Traverses the AST nodes parents until a parent of given type was found.

    :param node: node whose ``parent`` chain is followed; the node itself is never returned
    :param parentType: type (or tuple of types) passed to ``isinstance``
    :return: the closest matching parent, or None if no such parent is found
    """
    parent = node.parent
    while parent is not None:
        if isinstance(parent, parentType):
            return parent
        parent = parent.parent
    return None


def getOptimalLoopOrdering(fields):
    """
    Determines the optimal loop order for a given set of fields.
    If the fields have a different memory layout or a different number of spatial dimensions
    an exception is thrown.

    :param fields: non-empty collection of fields
    :return: list of coordinate ids, where the first list entry should be the outermost loop
    :raises ValueError: if ``fields`` is empty, or the fields disagree in spatial dimensions or layout
    """
    # explicit check instead of an assert, which would be stripped when running with -O
    if len(fields) == 0:
        raise ValueError("At least one field is required to determine a loop ordering")

    refField = next(iter(fields))
    for field in fields:
        if field.spatialDimensions != refField.spatialDimensions:
            raise ValueError("All fields have to have the same number of spatial dimensions")

    layouts = set([field.layout for field in fields])
    if len(layouts) > 1:
        raise ValueError("Due to different layout of the fields no optimal loop ordering exists " + str(layouts))
    layout = next(iter(layouts))
    return list(layout)
534
535


Martin Bauer's avatar
Martin Bauer committed
536
537
538
539
540
def getLoopHierarchy(astNode):
    """Determines the loop structure around a given AST node.

    :param astNode: the AST node
    :return: list of the coordinate ids of all enclosing loops, where the first list entry belongs to the
             outermost loop and the last one to the loop closest to ``astNode``
    """
    result = []
    node = astNode
    while node is not None:
        node = getNextParentOfType(node, ast.LoopOverCoordinate)
        if node is not None:
            result.append(node.coordinateToLoopOver)
    # parents were collected from the inside out; a real list is returned (not a one-shot iterator)
    # so that the result can be indexed, reused and passed to reversed()
    result.reverse()
    return result

Jan Hoenig's avatar
Jan Hoenig committed
549

Jan Hoenig's avatar
Jan Hoenig committed
550
551
552
553
554
555
556
557
def get_type(node):
    """Returns the data type of an AST node or sympy number.

    :param node: ``ast.Indexed`` (type of its first argument is used), any other ``ast.Node``
                 (its ``dtype`` is used), ``sp.Float`` ('double') or ``sp.Integer`` ('int')
    :return: the type of the node
    :raises NotImplementedError: for every other kind of node
    """
    if isinstance(node, ast.Indexed):
        return node.args[0].dtype
    if isinstance(node, ast.Node):
        return node.dtype
    # TODO sp.NumberSymbol
    if isinstance(node, sp.Float):
        return createType('double')
    if isinstance(node, sp.Integer):
        return createType('int')
    # NotImplemented is a constant and not callable, NotImplementedError is the exception class
    raise NotImplementedError('Not yet supported: %s %s' % (node, type(node)))


Jan Hoenig's avatar
Jan Hoenig committed
567
def insert_casts(node):
    """
    Inserts casts and dtype where needed

    Children are processed before the node itself. For ``ast.Expr`` nodes the arguments are sorted by dtype
    and the dtype of the first argument after sorting becomes the dtype of the expression. For
    ``ast.SympyAssignment`` nodes the right hand side is converted to the type of the left hand side.

    :param node: ast which should be traversed
    :return: node
    :raises ValueError: if something that is neither int nor uint is added to a pointer
    """
    def conversion(args):
        target = args[0]
        if isinstance(target.dtype, PointerType):
            # Pointer arithmetic
            for arg in args[1:]:
                # Check validness
                if not arg.dtype.is_int() and not arg.dtype.is_uint():
                    raise ValueError("Impossible pointer arithmetic", target, arg)
            return [ast.PointerArithmetic(ast.Add(args[1:]), target)]

        for i, arg in enumerate(args):
            if arg.dtype != target.dtype:
                args[i] = ast.Conversion(arg, target.dtype, node)
        return args

    for arg in node.args:
        insert_casts(arg)

    if isinstance(node, ast.Indexed):
        # TODO remove this
        # checked before ast.Expr so that Indexed nodes are never handled by the Expr branch
        pass
    elif isinstance(node, ast.Expr):
        args = sorted(node.args, key=attrgetter('dtype'))
        target = args[0]
        node.args = conversion(args)
        node.dtype = target.dtype
    elif isinstance(node, ast.SympyAssignment):
        if node.lhs.dtype != node.rhs.dtype:
            node.replace(node.rhs, ast.Conversion(node.rhs, node.lhs.dtype))
    return node
Jan Hoenig's avatar
Jan Hoenig committed
611
612
613


def desympy_ast(node):
    """
    Replaces sympy expressions among the arguments of ``node`` by their ast counterparts.
    This is necessary for further changes in the tree.

    Handled are ``Add``, ``Number``, ``Mul``, ``Pow``, ``Indexed`` and ``IndexedBase`` (replaced by its
    label); every other argument is left untouched. Afterwards the arguments are processed recursively.

    :param node: ast which should be traversed. Only node's children will be modified.
    :return: (modified) node
    """
    if node.args is None:
        return node

    # index based access: node.replace may exchange entries of node.args while the loop is running
    for i in range(len(node.args)):
        arg = node.args[i]
        if isinstance(arg, sp.Add):
            node.replace(arg, ast.Add(arg.args, node))
        elif isinstance(arg, sp.Number):
            node.replace(arg, ast.Number(arg, node))
        elif isinstance(arg, sp.Mul):
            node.replace(arg, ast.Mul(arg.args, node))
        elif isinstance(arg, sp.Pow):
            node.replace(arg, ast.Pow(arg.args, node))
        elif isinstance(arg, sp.tensor.Indexed):
            node.replace(arg, ast.Indexed(arg.args, arg.base, node))
        elif isinstance(arg, sp.tensor.IndexedBase):
            node.replace(arg, arg.label)

    for arg in node.args:
        desympy_ast(arg)
    return node
644
645
646
647
648
649
650
651
652
653
654
655


def check_dtype(node):
    """Walks the tree below ``node`` without performing any check yet.

    Placeholder: no node type is inspected so far, only the traversal of ``node.args`` is in place.

    :param node: ast which should be traversed
    :return: None
    """
    for arg in node.args:
        check_dtype(arg)