transformations.py 31.6 KB
Newer Older
1
from collections import defaultdict, OrderedDict
Jan Hoenig's avatar
Jan Hoenig committed
2
from operator import attrgetter
3
from copy import deepcopy
Jan Hoenig's avatar
Jan Hoenig committed
4

5
6
import sympy as sp
from sympy.logic.boolalg import Boolean
7
from sympy.tensor import IndexedBase
Martin Bauer's avatar
Martin Bauer committed
8

9
from pystencils.field import Field, offsetComponentToDirectionString
10
from pystencils.data_types import TypedSymbol, createType, PointerType, StructType, getBaseType, castFunc
Martin Bauer's avatar
Martin Bauer committed
11
from pystencils.slicing import normalizeSlice
Martin Bauer's avatar
Martin Bauer committed
12
import pystencils.astnodes as ast
13
14


15
16
17
18
19
20
21
def filteredTreeIteration(node, nodeType):
    """Yields every descendant of ``node`` that is an instance of ``nodeType`` in depth-first pre-order.

    The passed node itself is never yielded, only nodes that are reachable through the ``args`` attribute.

    :param node: root of the traversal, every visited node has to provide an ``args`` sequence
    :param nodeType: type (or tuple of types) that is passed on to ``isinstance``
    """
    for child in node.args:
        if isinstance(child, nodeType):
            yield child
        # descend into the child regardless of its own type
        for descendant in filteredTreeIteration(child, nodeType):
            yield descendant


22
23
24
25
26
27
def fastSubs(term, subsDict):
    """Similar to sympy subs function.
    This version is much faster for big substitution dictionaries than sympy version

    :param term: expression in which the replacement takes place
    :param subsDict: dictionary mapping (sub-)expressions to their replacement
    :return: expression where every sub-expression that is a key of ``subsDict`` is replaced by the dict value
    """
    def substitute(node):
        # a direct hit is replaced as a whole, its children are not inspected any more
        if node in subsDict:
            return subsDict[node]
        # objects without 'args' are leaves that can not be traversed further
        if not hasattr(node, 'args'):
            return node

        children = [substitute(child) for child in node.args]
        if children:
            # rebuild the node from its (possibly replaced) children
            return node.func(*children)
        return node

    return substitute(term)


35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
def getCommonShape(fieldSet):
    """Takes a set of pystencils Fields and returns their common spatial shape if it exists. Otherwise
    ValueError is raised

    :param fieldSet: collection of fields, each providing ``name``, ``hasFixedShape`` and ``spatialShape``
    :return: the common spatial shape. If all fields are variable-shaped and their shapes differ, the shape
             whose first entry has the smallest string representation is returned
    :raises ValueError: if ``fieldSet`` is empty, if fixed- and variable-shaped fields are mixed, or if
                        fixed-shaped fields have different spatial shapes
    """
    if not fieldSet:
        # without this check the shape lookup below fails with an IndexError instead of the documented ValueError
        raise ValueError("Cannot determine the common shape of an empty set of fields")

    fixedFieldNames = [f.name for f in fieldSet if f.hasFixedShape]
    varFieldNames = [f.name for f in fieldSet if not f.hasFixedShape]

    if fixedFieldNames and varFieldNames:
        msg = "Mixing fixed-shaped and variable-shape fields in a single kernel is not possible\n"
        msg += "Variable shaped: %s \nFixed shaped:    %s" % (",".join(varFieldNames), ",".join(fixedFieldNames))
        raise ValueError(msg)

    shapeSet = {f.spatialShape for f in fieldSet}
    # at this point the fields are either all fixed-shaped or all variable-shaped
    if fixedFieldNames and len(shapeSet) != 1:
        raise ValueError("Differently sized field accesses in loop body: " + str(shapeSet))

    return min(shapeSet, key=lambda shape: str(shape[0]))


59
def makeLoopOverDomain(body, functionName, iterationSlice=None, ghostLayers=None, loopOrder=None):
    """
    Uses :class:`pystencils.field.Field.Access` to create (multiple) loops around given AST.

    :param body: AST node forming the loop body, it is queried with ``atoms`` for field accesses
    :param functionName: name of generated C function
    :param iterationSlice: if not None, iteration is done only over this slice of the field
    :param ghostLayers: a sequence of pairs for each coordinate with lower and upper nr of ghost layers,
                a single integer that is used for all coordinates as lower and upper value, or
                None: then the number of ghost layers is determined from the field accesses and assumed to
                be equal for all dimensions
    :param loopOrder: loop ordering from outer to inner loop (optimal ordering is same as layout)
    :return: :class:`pystencils.astnodes.KernelFunction` whose body consists of the nested loops
    """
    # the participating fields determine loop ordering and iteration bounds
    fieldAccesses = body.atoms(Field.Access)
    fields = {access.field for access in fieldAccesses}

    if loopOrder is None:
        loopOrder = getOptimalLoopOrdering(fields)

    shape = getCommonShape(fields)

    if iterationSlice is not None:
        iterationSlice = normalizeSlice(iterationSlice, shape)

    if ghostLayers is None:
        nrOfGhostLayers = max(access.requiredGhostLayers for access in fieldAccesses)
        ghostLayers = [(nrOfGhostLayers, nrOfGhostLayers)] * len(loopOrder)
    elif isinstance(ghostLayers, int):
        ghostLayers = [(ghostLayers, ghostLayers)] * len(loopOrder)

    # loops are created from the innermost to the outermost one, each new loop wraps what was built so far
    currentBody = body
    for loopCoordinate in reversed(loopOrder):
        if iterationSlice is None:
            lowerBound = ghostLayers[loopCoordinate][0]
            upperBound = shape[loopCoordinate] - ghostLayers[loopCoordinate][1]
            loopRange = (lowerBound, upperBound, 1)
        else:
            component = iterationSlice[loopCoordinate]
            if type(component) is not slice:
                # coordinate is fixed to a single value: no loop, the loop counter is assigned directly
                counterSymbol = ast.LoopOverCoordinate.getLoopCounterSymbol(loopCoordinate)
                currentBody.insertFront(ast.SympyAssignment(counterSymbol, sp.sympify(component)))
                continue
            loopRange = (component.start, component.stop, component.step)

        loop = ast.LoopOverCoordinate(currentBody, loopCoordinate, *loopRange)
        currentBody = ast.Block([loop])

    return ast.KernelFunction(currentBody, ghostLayers=ghostLayers, functionName=functionName)
112
113
114


def createIntermediateBasePointer(fieldAccess, coordinates, previousPtr):
    r"""
    Addressing elements in structured arrays are done with :math:`ptr\left[ \sum_i c_i \cdot s_i \right]`
    where :math:`c_i` is the coordinate value and :math:`s_i` the stride of a coordinate.
    The sum can be split up into multiple parts, such that parts of it can be pulled before loops.
    This function creates such an access for coordinates :math:`i \in \mbox{coordinates}`.
    Returns a new typed symbol, where the name encodes which coordinates have been resolved.

    :param fieldAccess: instance of :class:`pystencils.field.Field.Access` which provides strides and offsets
    :param coordinates: mapping of coordinate ids to its value, where stride*value is calculated
    :param previousPtr: the pointer which is dereferenced
    :return: tuple with the new pointer symbol and the calculated offset

    Example:
        >>> field = Field.createGeneric('myfield', spatialDimensions=2, indexDimensions=1)
        >>> x, y = sp.symbols("x y")
        >>> prevPointer = TypedSymbol("ptr", "double")
        >>> createIntermediateBasePointer(field[1,-2](5), {0: x}, prevPointer)
        (ptr_E, x*fstride_myfield[0] + fstride_myfield[0])
        >>> createIntermediateBasePointer(field[1,-2](5), {0: x, 1 : y }, prevPointer)
        (ptr_E_2S, x*fstride_myfield[0] + y*fstride_myfield[1] + fstride_myfield[0] - 2*fstride_myfield[1])
    """
    field = fieldAccess.field

    offset = 0
    nameSuffix = ""
    # values that are no plain ints can not be written into the name, they enter the name through a hash
    nonIntegerParts = []
    for coordId, coordValue in coordinates.items():
        stride = field.strides[coordId]
        offset += stride * coordValue

        if coordId < field.spatialDimensions:
            # spatial coordinate: the relative offset of the access is added as well
            accessOffset = fieldAccess.offsets[coordId]
            offset += stride * accessOffset
            if type(accessOffset) is int:
                direction = offsetComponentToDirectionString(coordId, accessOffset)
                nameSuffix += "_" + (direction if direction else "C")
            else:
                nonIntegerParts.append(accessOffset)
        elif type(coordValue) is int:
            # index coordinate with a fixed integer value
            nameSuffix += "_%d" % (coordValue,)
        else:
            nonIntegerParts.append(coordValue)

    if nonIntegerParts:
        nameSuffix += "%0.6X" % (abs(hash(tuple(nonIntegerParts))))

    return TypedSymbol(previousPtr.name + nameSuffix, previousPtr.dtype), offset


def parseBasePointerInfo(basePointerSpecification, loopOrder, field):
    """
    Creates base pointer specification for :func:`resolveFieldAccesses` function.

    Specification of how many and which intermediate pointers are created for a field access.
    For example [ (0), (2,3,)]  creates one base pointer for coordinates 2 and 3 and writes the offset for coordinate
    zero directly in the field access. These specifications are more sensible defined dependent on the loop ordering.
    This function translates more readable version into the specification above.

    Allowed specifications:
        - "spatialInner<int>" spatialInner0 is the innermost loop coordinate,
          spatialInner1 the loop enclosing the innermost
        - "spatialOuter<int>" spatialOuter0 is the outermost loop,
          spatialOuter1 the loop directly inside the outermost
        - "spatialall": all spatial coordinates of the field
        - "index<int>": index coordinate
        - "<int>": specifying directly the coordinate

    :param basePointerSpecification: nested list with above specifications
    :param loopOrder: list with ordering of loops from outer to inner
    :param field: object providing ``spatialDimensions`` and ``indexDimensions``
    :return: list of coordinate groups that can be passed to :func:`resolveFieldAccesses`. Coordinates that
             are not mentioned in the specification are collected (sorted) in an additional last group
    :raises ValueError: if a coordinate does not exist, is specified twice or a specification can not be parsed
    """
    nrOfCoordinates = field.spatialDimensions + field.indexDimensions
    outerToInner = list(loopOrder)
    innerToOuter = outerToInner[::-1]

    result = []
    specifiedCoordinates = set()

    def addCoordinate(group, coordinate):
        """Appends the coordinate to the group after checking that it exists and was not used before"""
        if coordinate >= nrOfCoordinates:
            raise ValueError("Coordinate %d does not exist" % (coordinate,))
        if coordinate in specifiedCoordinates:
            raise ValueError("Coordinate %d specified two times" % (coordinate,))
        specifiedCoordinates.add(coordinate)
        group.append(coordinate)

    for specGroup in basePointerSpecification:
        newGroup = []
        for element in specGroup:
            if type(element) is int:
                addCoordinate(newGroup, element)
            elif element.startswith("spatial"):
                element = element[len("spatial"):]
                if element.startswith("Inner"):
                    index = int(element[len("Inner"):])
                    addCoordinate(newGroup, innerToOuter[index])
                elif element.startswith("Outer"):
                    index = int(element[len("Outer"):])
                    # index 0 has to be the outermost loop: indexing the reversed list with [-index]
                    # would return the innermost loop for index 0
                    addCoordinate(newGroup, outerToInner[index])
                elif element == "all":
                    for i in range(field.spatialDimensions):
                        addCoordinate(newGroup, i)
                else:
                    raise ValueError("Could not parse " + element)
            elif element.startswith("index"):
                index = int(element[len("index"):])
                addCoordinate(newGroup, field.spatialDimensions + index)
            else:
                raise ValueError("Unknown specification %s" % (element,))

        result.append(newGroup)

    # all coordinates that have not been mentioned form the last group, sorted to get a deterministic result
    rest = set(range(nrOfCoordinates)) - specifiedCoordinates
    if rest:
        result.append(sorted(rest))
    return result


229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
def substituteArrayAccessesWithConstants(astNode):
    """Substitutes all instances of Indexed (array accesses) that are not field accesses with constants.
    Benchmarks showed that using an array access as loop bound or in pointer computations cause some compilers to do
    less optimizations.
    This transformation should be after field accesses have been resolved (since they introduce array accesses) and
    before constants are moved before the loops.

    The AST is modified in place, nothing is returned.

    :param astNode: node where the substitution starts, child nodes are handled recursively
    """

    def replaceIndexedByConstants(expr, parentBlock):
        """Returns the expression where array accesses have been replaced by constants. The assignments that
        define these constants are inserted into ``parentBlock`` in front of ``astNode``"""
        if not isinstance(expr, sp.Expr):
            return expr

        # resolved field accesses are array accesses as well, but they have to stay untouched
        arrayAccesses = [e for e in expr.atoms(sp.Indexed) if not isinstance(e, ast.ResolvedFieldAccess)]

        # an expression consisting only of a single array access is left as it is
        if len(arrayAccesses) == 1 and expr == arrayAccesses[0]:
            return expr

        definitions = []
        substitutions = {}
        for access in arrayAccesses:
            base, idx = access.args
            arraySymbol = base.args[0]
            elementType = deepcopy(getBaseType(arraySymbol.dtype))
            elementType.const = False
            constant = TypedSymbol(arraySymbol.name + str(idx), elementType)
            definitions.append(ast.SympyAssignment(constant, access))
            substitutions[access] = constant
        definitions.sort(key=lambda e: e.lhs.name)

        # constants that the block defines already are not defined a second time
        alreadyDefined = parentBlock.symbolsDefined
        for definition in definitions:
            if definition.lhs not in alreadyDefined:
                parentBlock.insertBefore(definition, astNode)

        return expr.subs(substitutions)

    if isinstance(astNode, ast.SympyAssignment):
        for attribute in ('rhs', 'lhs'):
            setattr(astNode, attribute, replaceIndexedByConstants(getattr(astNode, attribute), astNode.parent))
    elif isinstance(astNode, ast.LoopOverCoordinate):
        for attribute in ('start', 'stop', 'step'):
            setattr(astNode, attribute, replaceIndexedByConstants(getattr(astNode, attribute), astNode.parent))
        substituteArrayAccessesWithConstants(astNode.body)
    else:
        for child in astNode.args:
            substituteArrayAccessesWithConstants(child)
280
281


282
def resolveFieldAccesses(astNode, readOnlyFieldNames=None, fieldToBasePointerInfo=None, fieldToFixedCoordinates=None):
    """
    Substitutes :class:`pystencils.field.Field.Access` nodes by array indexing

    The AST is modified in place: the left and right hand sides of all assignments are rewritten and assignments
    for intermediate base pointers are inserted into the enclosing blocks. Nothing is returned.

    :param astNode: the AST root
    :param readOnlyFieldNames: set of field names which are considered read-only, their data pointers are
                               created const. None is treated as empty set
    :param fieldToBasePointerInfo: dict mapping field name to a list of coordinate groups, indicating which
                                   intermediate base pointers should be created
                                   for details see :func:`parseBasePointerInfo`. None is treated as empty dict,
                                   fields without entry get all coordinates resolved directly in the access
    :param fieldToFixedCoordinates: map of field name to a tuple of coordinate symbols. Instead of using the loop
                                    counters to index the field these symbols are used as coordinates.
                                    None is treated as empty dict
    """
    # None instead of mutable default arguments, which would be shared between all calls
    if readOnlyFieldNames is None:
        readOnlyFieldNames = set()
    if fieldToBasePointerInfo is None:
        fieldToBasePointerInfo = {}
    if fieldToFixedCoordinates is None:
        fieldToFixedCoordinates = {}

    fieldToBasePointerInfo = OrderedDict(sorted(fieldToBasePointerInfo.items(), key=lambda pair: pair[0]))
    fieldToFixedCoordinates = OrderedDict(sorted(fieldToFixedCoordinates.items(), key=lambda pair: pair[0]))

    def visitSympyExpr(expr, enclosingBlock, sympyAssignment):
        if isinstance(expr, Field.Access):
            fieldAccess = expr
            field = fieldAccess.field
            if field.name in fieldToBasePointerInfo:
                basePointerInfo = fieldToBasePointerInfo[field.name]
            else:
                # default: a single group with all coordinates, i.e. no intermediate base pointers
                basePointerInfo = [list(range(field.indexDimensions + field.spatialDimensions))]

            dtype = PointerType(field.dtype, const=field.name in readOnlyFieldNames, restrict=True)
            fieldPtr = TypedSymbol("%s%s" % (Field.DATA_PREFIX, symbolNameToVariableName(field.name)), dtype)

            lastPointer = fieldPtr

            def createCoordinateDict(group):
                """Maps every coordinate id of the group to the value used to index that coordinate"""
                coordDict = {}
                for e in group:
                    if e < field.spatialDimensions:
                        if field.name in fieldToFixedCoordinates:
                            coordDict[e] = fieldToFixedCoordinates[field.name][e]
                        else:
                            ctrName = ast.LoopOverCoordinate.LOOP_COUNTER_NAME_PREFIX
                            coordDict[e] = TypedSymbol("%s_%d" % (ctrName, e), 'int')
                        # NOTE(review): spatial coordinates are scaled by the item size of the field data type,
                        # the unit of the strides is not visible here - confirm
                        coordDict[e] *= field.dtype.itemSize
                    else:
                        if isinstance(field.dtype, StructType):
                            # struct fields are indexed by the name of the struct member
                            assert field.indexDimensions == 1
                            accessedFieldName = fieldAccess.index[0]
                            assert isinstance(accessedFieldName, str)
                            coordDict[e] = field.dtype.getElementOffset(accessedFieldName)
                        else:
                            coordDict[e] = fieldAccess.index[e - field.spatialDimensions]
                return coordDict

            # every group except the first one gets its own intermediate pointer,
            # which is defined in front of the assignment that is currently processed
            for group in reversed(basePointerInfo[1:]):
                coordDict = createCoordinateDict(group)
                newPtr, offset = createIntermediateBasePointer(fieldAccess, coordDict, lastPointer)
                if newPtr not in enclosingBlock.symbolsDefined:
                    newAssignment = ast.SympyAssignment(newPtr, lastPointer + offset, isConst=False)
                    enclosingBlock.insertBefore(newAssignment, sympyAssignment)
                lastPointer = newPtr

            # the coordinates of the first group are written directly into the array access
            coordDict = createCoordinateDict(basePointerInfo[0])
            _, offset = createIntermediateBasePointer(fieldAccess, coordDict, lastPointer)
            result = ast.ResolvedFieldAccess(lastPointer, offset, fieldAccess.field, fieldAccess.offsets,
                                             fieldAccess.index)

            if isinstance(getBaseType(fieldAccess.field.dtype), StructType):
                newType = fieldAccess.field.dtype.getElementType(fieldAccess.index[0])
                result = castFunc(result, newType)
            return visitSympyExpr(result, enclosingBlock, sympyAssignment)
        else:
            if isinstance(expr, ast.ResolvedFieldAccess):
                return expr

            newArgs = [visitSympyExpr(e, enclosingBlock, sympyAssignment) for e in expr.args]
            # Add, Mul and Piecewise are rebuilt without evaluation
            kwargs = {'evaluate': False} if type(expr) in (sp.Add, sp.Mul, sp.Piecewise) else {}
            return expr.func(*newArgs, **kwargs) if newArgs else expr

    def visitNode(subAst):
        if isinstance(subAst, ast.SympyAssignment):
            enclosingBlock = subAst.parent
            assert type(enclosingBlock) is ast.Block
            subAst.lhs = visitSympyExpr(subAst.lhs, enclosingBlock, subAst)
            subAst.rhs = visitSympyExpr(subAst.rhs, enclosingBlock, subAst)
        else:
            for a in subAst.args:
                visitNode(a)

    return visitNode(astNode)


def moveConstantsBeforeLoop(astNode):
    """
    Moves :class:`pystencils.ast.SympyAssignment` nodes out of loop body if they are iteration independent.
    Call this after creating the loop structure with :func:`makeLoopOverDomain`

    The AST is modified in place, nothing is returned.

    :param astNode: AST node, all blocks found below it (and the node itself if it is a block) are processed
    """
    def findBlockToMoveTo(node):
        """
        Traverses parents of node as long as the symbols are independent and returns a (parent) block
        the assignment can be safely moved to

        :param node: SympyAssignment inside a Block
        :return blockToInsertTo, childOfBlockToInsertBefore
        """
        assert isinstance(node, ast.SympyAssignment)
        assert isinstance(node.parent, ast.Block)

        targetBlock = node.parent
        childOfTargetBlock = node
        previous = node
        current = node.parent
        while current:
            if isinstance(current, ast.Block):
                targetBlock = current
                childOfTargetBlock = previous
            # stop as soon as this level defines a symbol the assignment depends on
            if node.undefinedSymbols.intersection(current.symbolsDefined):
                break
            previous = current
            current = current.parent
        return targetBlock, childOfTargetBlock

    def checkIfAssignmentAlreadyInBlock(assignment, targetBlock):
        """Returns the assignment of targetBlock that has the same left hand side, or None if there is none"""
        for candidate in targetBlock.args:
            if type(candidate) is ast.SympyAssignment and candidate.lhs == assignment.lhs:
                return candidate
        return None

    def collectBlocks(node, collected):
        """Appends all blocks below (and including) node to the list in pre-order"""
        if isinstance(node, ast.Block):
            collected.append(node)
        if isinstance(node, ast.Node):
            for child in node.args:
                collectBlocks(child, collected)

    blocksInPreOrder = []
    collectBlocks(astNode, blocksInPreOrder)

    # blocks are processed in reversed pre-order, i.e. the block that was found last comes first
    for block in reversed(blocksInPreOrder):
        for child in block.takeChildNodes():
            if not isinstance(child, ast.SympyAssignment):
                block.append(child)
                continue

            target, childToInsertBefore = findBlockToMoveTo(child)
            if target == block:     # movement not possible
                target.append(child)
                continue

            existingAssignment = checkIfAssignmentAlreadyInBlock(child, target)
            if existingAssignment:
                # the target defines the symbol already, the assignment is dropped
                assert existingAssignment.rhs == child.rhs, "Symbol with same name exists already"
            else:
                target.insertBefore(child, childToInsertBefore)


def splitInnerLoop(astNode, symbolGroups):
    """
    Splits inner loop into multiple loops to minimize the amount of simultaneous load/store streams

    The AST is modified in place, nothing is returned.

    :param astNode: AST root
    :param symbolGroups: sequence of symbol sequences: for each symbol sequence a new inner loop is created which
         updates these symbols and their dependent symbols. Symbols which are in none of the symbolGroups and which
         no symbol in a symbol group depends on, are not updated!
    """
    allLoops = astNode.atoms(ast.LoopOverCoordinate)

    innermostLoops = [loop for loop in allLoops if loop.isInnermostLoop]
    assert len(innermostLoops) == 1, "Error in AST: multiple innermost loops. Was split transformation already called?"
    innerLoop = innermostLoops[0]
    assert type(innerLoop.body) is ast.Block

    outermostLoops = [loop for loop in allLoops if loop.isOutermostLoop]
    assert len(outermostLoops) == 1, "Error in AST, multiple outermost loops."
    outerLoop = outermostLoops[0]

    def temporaryArrayAccess(symbol):
        """Access to the temporary array of the symbol, indexed with the counter of the inner loop"""
        assert type(symbol) is TypedSymbol
        pointerSymbol = TypedSymbol(symbol.name, PointerType(symbol.dtype))
        return IndexedBase(pointerSymbol, shape=(1,))[innerLoop.loopCounterSymbol]

    symbolsWithTemporaryArray = OrderedDict()
    assignmentMap = OrderedDict((a.lhs, a) for a in innerLoop.body.args)

    assignmentGroups = []
    for symbolGroup in symbolGroups:
        # collect all symbols the group depends on
        pending = list(symbolGroup)
        resolved = set()
        while pending:
            current = pending.pop()
            if current in resolved:
                continue

            if current in assignmentMap:  # if there is no assignment inside the loop body it is independent already
                for dependency in assignmentMap[current].rhs.atoms(sp.Symbol):
                    if type(dependency) is not Field.Access and dependency not in symbolsWithTemporaryArray:
                        pending.append(dependency)
            resolved.add(current)

        # symbols of the group that are no field accesses are stored in a temporary array
        for symbol in symbolGroup:
            if type(symbol) is not Field.Access:
                symbolsWithTemporaryArray[symbol] = temporaryArrayAccess(symbol)

        assignmentGroup = []
        for assignment in innerLoop.body.args:
            if assignment.lhs not in resolved:
                continue
            newRhs = assignment.rhs.subs(symbolsWithTemporaryArray.items())
            lhs = assignment.lhs
            writesToTemporaryArray = type(lhs) is not Field.Access and lhs in symbolGroup
            newLhs = temporaryArrayAccess(lhs) if writesToTemporaryArray else lhs
            assignmentGroup.append(ast.SympyAssignment(newLhs, newRhs))
        assignmentGroups.append(assignmentGroup)

    # the original inner loop is replaced by a block with one loop per group
    newLoops = [innerLoop.newLoopWithDifferentBody(ast.Block(group)) for group in assignmentGroups]
    innerLoop.parent.replace(innerLoop, ast.Block(newLoops))

    # temporary arrays are allocated at the front and freed at the end of the block containing the outer loop
    for tmpArray in symbolsWithTemporaryArray:
        tmpArrayPointer = TypedSymbol(tmpArray.name, PointerType(tmpArray.dtype))
        outerLoop.parent.insertFront(ast.TemporaryMemoryAllocation(tmpArrayPointer, innerLoop.stop))
        outerLoop.parent.append(ast.TemporaryMemoryFree(tmpArrayPointer))
497
498


499
500
501
502
503
def symbolNameToVariableName(symbolName):
    """Replaces characters which are allowed in sympy symbol names but not in C/C++ variable names

    Currently every ``^`` is replaced by an underscore.
    """
    return "_".join(symbolName.split("^"))


504
def typeAllEquations(eqs, typeForSymbol):
    """
    Traverses AST and replaces every :class:`sympy.Symbol` by a :class:`pystencils.typedsymbol.TypedSymbol`.
    Additionally returns sets of all fields which are read/written

    :param eqs: list of equations, entries that are AST nodes but no assignments are passed through unchanged
    :param typeForSymbol: dict mapping symbol names to types. Types are strings of C types like 'int' or 'double'.
                          If it is empty/None or one of the strings 'double' or 'float', the mapping is created
                          with :func:`typingFromSympyInspection` using 'double' respectively 'float' as default
    :return: ``fieldsRead, fieldsWritten, typedEquations`` set of read fields, set of written fields, list of equations
               where symbols have been replaced by typed symbols
    """
    if not typeForSymbol or typeForSymbol == 'double':
        typeForSymbol = typingFromSympyInspection(eqs, "double")
    elif typeForSymbol == 'float':
        typeForSymbol = typingFromSympyInspection(eqs, "float")

    fieldsWritten = set()
    fieldsRead = set()

    def toTypedSymbol(symbol):
        """Typed version of the symbol. The name is converted to a valid variable name on both sides of an
        equation, otherwise a symbol that is written and read later would end up with two different names"""
        return TypedSymbol(symbolNameToVariableName(symbol.name), typeForSymbol[symbol.name])

    def processRhs(term):
        """Replaces Symbols by:
            - TypedSymbol if symbol is not a field access
        """
        if isinstance(term, Field.Access):
            fieldsRead.add(term.field)
            return term
        elif isinstance(term, TypedSymbol):
            return term
        elif isinstance(term, sp.Symbol):
            return toTypedSymbol(term)
        else:
            newArgs = [processRhs(arg) for arg in term.args]
            return term.func(*newArgs) if newArgs else term

    def processLhs(term):
        """Replaces symbol by TypedSymbol and adds field to fieldsWritten"""
        if isinstance(term, Field.Access):
            fieldsWritten.add(term.field)
            return term
        elif isinstance(term, TypedSymbol):
            return term
        elif isinstance(term, sp.Symbol):
            return toTypedSymbol(term)
        else:
            assert False, "Expected a symbol as left-hand-side"

    typedEquations = []
    for eq in eqs:
        if isinstance(eq, sp.Eq) or isinstance(eq, ast.SympyAssignment):
            newLhs = processLhs(eq.lhs)
            newRhs = processRhs(eq.rhs)
            typedEquations.append(ast.SympyAssignment(newLhs, newRhs))
        else:
            assert isinstance(eq, ast.Node), "Only equations and ast nodes are allowed in input"
            typedEquations.append(eq)

    return fieldsRead, fieldsWritten, typedEquations


Martin Bauer's avatar
Martin Bauer committed
564
565
566
# --------------------------------------- Helper Functions -------------------------------------------------------------


567
def typingFromSympyInspection(eqs, defaultType="double"):
    """
    Creates a default symbol name to type mapping.
    If a sympy Boolean is assigned to a symbol it is assumed to be 'bool' otherwise the default type, usually ('double')

    :param eqs: list of equations, entries that are AST nodes are skipped
    :param defaultType: the type for non-boolean symbols
    :return: dictionary, mapping symbol name to type. Names without entry are mapped to ``defaultType``
    """
    symbolTypes = defaultdict(lambda: defaultType)
    for eq in eqs:
        if isinstance(eq, ast.Node):
            continue
        rhs = eq.rhs
        # problematic case here is when rhs is a symbol: then it is impossible to decide here without
        # further information what type the left hand side is - default fallback is the dict value then
        rhsIsBooleanExpression = isinstance(rhs, Boolean) and not isinstance(rhs, sp.Symbol)
        if rhsIsBooleanExpression:
            symbolTypes[eq.lhs.name] = "bool"
    return symbolTypes


def getNextParentOfType(node, parentType):
    """
    Traverses the AST nodes parents until a parent of given type was found. If no such parent is found, None is returned

    :param node: node where the search starts, the node itself is never returned
    :param parentType: type (or tuple of types) that is passed on to ``isinstance``
    :return: the closest ancestor that is an instance of ``parentType`` or None
    """
    ancestor = node.parent
    # walk upwards until the type matches or the root (parent is None) was passed
    while ancestor is not None and not isinstance(ancestor, parentType):
        ancestor = ancestor.parent
    return ancestor


def getOptimalLoopOrdering(fields):
    """
    Determines the optimal loop order for a given set of fields.
    If the fields have a different number of spatial dimensions or different memory layouts an exception is thrown.

    :param fields: non-empty collection of fields
    :return: list of coordinate ids, where the first list entry should be the outermost loop
    :raises ValueError: if the fields differ in their number of spatial dimensions or in their layout
    """
    assert len(fields) > 0
    refField = next(iter(fields))
    if any(field.spatialDimensions != refField.spatialDimensions for field in fields):
        raise ValueError("All fields have to have the same number of spatial dimensions. Spatial field dimensions: "
                         + str({f.name: f.spatialShape for f in fields}))

    layouts = {field.layout for field in fields}
    if len(layouts) > 1:
        raise ValueError("Due to different layout of the fields no optimal loop ordering exists " +
                         str({f.name: f.layout for f in fields}))

    # exactly one layout is left, the loop order is identical to it
    commonLayout = next(iter(layouts))
    return list(commonLayout)
618
619


Martin Bauer's avatar
Martin Bauer committed
620
621
622
623
624
def getLoopHierarchy(astNode):
    """Determines the loop structure around a given AST node.

    The enclosing ``ast.LoopOverCoordinate`` nodes are collected while walking from the node towards the root
    (innermost loop first) and the collected sequence is reversed before it is returned.

    :param astNode: the AST node
    :return: reverse iterator over the coordinate ids of the enclosing loops, where the first entry belongs
             to the outermost loop
    """
    coordinates = []
    enclosing = astNode
    while enclosing is not None:
        enclosing = getNextParentOfType(enclosing, ast.LoopOverCoordinate)
        if enclosing:
            coordinates.append(enclosing.coordinateToLoopOver)
    return reversed(coordinates)

Jan Hoenig's avatar
Jan Hoenig committed
633

Jan Hoenig's avatar
Jan Hoenig committed
634
635
636
637
638
639
640
641
def get_type(node):
    """
    Returns the data type of an AST node or of a sympy number.

    :param node: ``ast.Indexed`` (dtype of its first argument is used), any other ``ast.Node`` (its own dtype
                 is used), ``sp.Float`` (mapped to 'double') or ``sp.Integer`` (mapped to 'int')
    :return: the dtype of the node, or the type created by ``createType`` for sympy numbers
    :raises NotImplementedError: for every other kind of object, including other ``sp.Number`` subclasses
    """
    if isinstance(node, ast.Indexed):
        return node.args[0].dtype
    if isinstance(node, ast.Node):
        return node.dtype
    # TODO sp.NumberSymbol
    if isinstance(node, sp.Float):
        return createType('double')
    if isinstance(node, sp.Integer):
        return createType('int')
    # NotImplemented is a non-callable constant meant for binary operator methods;
    # the exception class to raise here is NotImplementedError
    raise NotImplementedError('Not yet supported: %s %s' % (node, type(node)))


Jan Hoenig's avatar
Jan Hoenig committed
651
def insert_casts(node):
    """
    Inserts casts and dtypes where needed.

    The children of the node are processed first. Afterwards the operands of an ``ast.Expr`` are sorted by
    dtype and brought to the dtype of the first operand in that order, and the right hand side of an
    ``ast.SympyAssignment`` is converted to the dtype of the left hand side if the two differ.

    :param node: ast which should be traversed
    :return: node
    """
    def unifyOperands(operands):
        # The first operand defines the type all other operands are brought to
        reference = operands[0]
        if isinstance(reference.dtype, PointerType):
            # Pointer arithmetic: all remaining operands have to be signed or unsigned integers
            for offset in operands[1:]:
                if not (offset.dtype.is_int() or offset.dtype.is_uint()):
                    raise ValueError("Impossible pointer arithmetic", reference, offset)
            return [ast.PointerArithmetic(ast.Add(operands[1:]), reference)]

        for index, operand in enumerate(operands):
            if operand.dtype != reference.dtype:
                operands[index] = ast.Conversion(operand, reference.dtype, node)
        return operands

    for child in node.args:
        insert_casts(child)

    if isinstance(node, ast.Indexed):
        # TODO remove this special case
        pass
    elif isinstance(node, ast.Expr):
        ordered = sorted(node.args, key=attrgetter('dtype'))
        leading = ordered[0]
        node.args = unifyOperands(ordered)
        node.dtype = leading.dtype
    elif isinstance(node, ast.SympyAssignment):
        if node.lhs.dtype != node.rhs.dtype:
            converted = ast.Conversion(node.rhs, node.lhs.dtype)
            node.replace(node.rhs, converted)
    # every other node kind (e.g. ast.LoopOverCoordinate) is left untouched
    return node
Jan Hoenig's avatar
Jan Hoenig committed
695
696
697


def desympy_ast(node):
    """
    Replaces sympy expressions among the children of a node by their AST counterparts.
    This is necessary for further changes in the tree.

    Children of type sympy Add, Number, Mul, Pow and Indexed are replaced by the corresponding ``ast`` node,
    an IndexedBase is replaced by its label. All other children are kept. Afterwards the function recurses
    into the (possibly replaced) children.

    :param node: ast which should be traversed. Only node's children will be modified.
    :return: (modified) node
    """
    if node.args is None:
        return node

    # Children are looked up by position on every step, because node.replace may change node.args
    for position in range(len(node.args)):
        child = node.args[position]
        if isinstance(child, sp.Add):
            substitute = ast.Add(child.args, node)
        elif isinstance(child, sp.Number):
            substitute = ast.Number(child, node)
        elif isinstance(child, sp.Mul):
            substitute = ast.Mul(child.args, node)
        elif isinstance(child, sp.Pow):
            substitute = ast.Pow(child.args, node)
        elif isinstance(child, (sp.tensor.Indexed, sp.tensor.indexed.Indexed)):
            substitute = ast.Indexed(child.args, child.base, node)
        elif isinstance(child, sp.tensor.IndexedBase):
            substitute = child.label
        else:
            # not a sympy type that is transformed
            continue
        node.replace(child, substitute)

    for child in node.args:
        desympy_ast(child)
    return node
728
729
730
731
732
733
734
735
736
737
738
739


def check_dtype(node):
    """
    Recursively visits the given node and all nodes reachable through ``args``.

    No check is performed at the moment: neither the structural nodes (kernel function, block, loop,
    assignment) nor any other node is inspected, the function only walks the tree.

    :param node: root of the (sub)tree to traverse
    """
    structuralTypes = (ast.KernelFunction, ast.Block, ast.LoopOverCoordinate, ast.SympyAssignment)
    if not isinstance(node, structuralTypes):
        # placeholder for a dtype check of all non-structural nodes
        pass
    for child in node.args:
        check_dtype(child)