transformations.py 39.6 KB
Newer Older
1
from collections import defaultdict, OrderedDict
Jan Hoenig's avatar
Jan Hoenig committed
2
from operator import attrgetter
3
from copy import deepcopy
Martin Bauer's avatar
Martin Bauer committed
4
import functools
Jan Hoenig's avatar
Jan Hoenig committed
5

6
7
import sympy as sp
from sympy.logic.boolalg import Boolean
8
from sympy.tensor import IndexedBase
Martin Bauer's avatar
Martin Bauer committed
9

10
from pystencils.field import Field, FieldType, offsetComponentToDirectionString
11
from pystencils.data_types import TypedSymbol, createType, PointerType, StructType, getBaseType, castFunc
Martin Bauer's avatar
Martin Bauer committed
12
from pystencils.slicing import normalizeSlice
Martin Bauer's avatar
Martin Bauer committed
13
import pystencils.astnodes as ast
14
15


16
17
18
19
20
21
22
def filteredTreeIteration(node, nodeType):
    """Yields all descendants of ``node`` that are instances of ``nodeType``.

    The tree is walked depth-first through the ``args`` attribute of every node. A matching child is
    yielded before its own descendants are searched; ``node`` itself is never yielded.

    :param node: root of the (sub)tree, has to provide an ``args`` attribute
    :param nodeType: type or tuple of types, as accepted by ``isinstance``
    """
    for child in node.args:
        if isinstance(child, nodeType):
            yield child
        for match in filteredTreeIteration(child, nodeType):
            yield match


23
24
25
def fastSubs(term, subsDict):
    """Lightweight replacement for sympy's ``subs``, intended for big substitution dictionaries.

    Every node of the expression tree is looked up in ``subsDict`` exactly once; a node that is found is
    replaced as a whole and its children are not visited. All other nodes are rebuilt from their
    (substituted) arguments.

    :param term: sympy expression or ``sp.Matrix``; matrices are copied and substituted entry by entry
    :param subsDict: mapping from expressions to their replacements
    :return: the term with all substitutions applied
    """
    if type(term) is sp.Matrix:
        substituteEntry = functools.partial(fastSubs, subsDict=subsDict)
        return term.copy().applyfunc(substituteEntry)

    def substitute(node):
        if node in subsDict:
            return subsDict[node]
        # objects without 'args' cannot be descended into and are passed through unchanged
        if not hasattr(node, 'args'):
            return node
        rebuiltArgs = [substitute(child) for child in node.args]
        if rebuiltArgs:
            return node.func(*rebuiltArgs)
        return node

    return substitute(term)


39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
def getCommonShape(fieldSet):
    """Takes a collection of pystencils Fields and returns their common spatial shape if it exists.

    :param fieldSet: iterable of fields; each one has to provide ``name``, ``hasFixedShape`` and ``spatialShape``
    :return: the spatial shape shared by the fields. If all fields are variable-shaped and their (symbolic)
             shapes differ, the shape whose first entry has the smallest string representation is returned.
    :raises ValueError: if no field is given, if fixed-shaped and variable-shaped fields are mixed, or if
                        fixed-shaped fields have different shapes
    """
    fieldSet = list(fieldSet)  # the fields are traversed several times, so plain iterators are materialized
    if not fieldSet:
        # without this guard the shape selection below fails with an IndexError/ValueError that hides the cause
        raise ValueError("Cannot determine a common shape: no fields given")

    fixedFieldNames = [f.name for f in fieldSet if f.hasFixedShape]
    varFieldNames = [f.name for f in fieldSet if not f.hasFixedShape]
    if fixedFieldNames and varFieldNames:
        msg = "Mixing fixed-shaped and variable-shape fields in a single kernel is not possible\n"
        msg += "Variable shaped: %s \nFixed shaped:    %s" % (",".join(varFieldNames), ",".join(fixedFieldNames))
        raise ValueError(msg)

    shapeSet = {f.spatialShape for f in fieldSet}
    # at this point the fields are either all fixed-shaped or all variable-shaped
    if fixedFieldNames and len(shapeSet) != 1:
        raise ValueError("Differently sized field accesses in loop body: " + str(shapeSet))

    return min(shapeSet, key=lambda e: str(e[0]))


63
def makeLoopOverDomain(body, functionName, iterationSlice=None, ghostLayers=None, loopOrder=None):
    """
    Uses :class:`pystencils.field.Field.Access` to create (multiple) loops around given AST.

    :param body: AST block that ends up inside the innermost loop; the field accesses found in it determine
                 the iteration shape, the default loop order and the default number of ghost layers
    :param functionName: name of generated C function
    :param iterationSlice: if not None, iteration is done only over this slice of the field. Components that
                are not slices do not get a loop; instead the loop counter of that coordinate is assigned
                the given value
    :param ghostLayers: a sequence of pairs for each coordinate with lower and upper nr of ghost layers, or a
                single int that is used for both sides of all coordinates;
                if None, the number of ghost layers is determined automatically and assumed to be equal for
                all dimensions
    :param loopOrder: loop ordering from outer to inner loop (optimal ordering is same as layout)
    :return: tuple ``(astNode, loopStrides, loopVars)``: ``astNode`` is the :class:`KernelFunction` containing
             the nested loops, ``loopStrides`` holds ``(end - begin) / step`` for every created loop and
             ``loopVars`` the loop counter symbols multiplied by the number of buffer accesses in ``body``.
             Both lists are ordered from the innermost to the outermost loop.
    """
    # find correct ordering by inspecting participating FieldAccesses
    fieldAccesses = body.atoms(Field.Access)
    # exclude accesses to buffers from fieldList, because buffers are treated separately
    fieldList = [e.field for e in fieldAccesses if not FieldType.isBuffer(e.field)]
    fields = set(fieldList)
    numBufferAccesses = len(fieldAccesses) - len(fieldList)

    if loopOrder is None:
        loopOrder = getOptimalLoopOrdering(fields)

    shape = getCommonShape(list(fields))

    if iterationSlice is not None:
        iterationSlice = normalizeSlice(iterationSlice, shape)

    if ghostLayers is None:
        requiredGhostLayers = max([fa.requiredGhostLayers for fa in fieldAccesses])
        ghostLayers = [(requiredGhostLayers, requiredGhostLayers)] * len(loopOrder)
    if isinstance(ghostLayers, int):
        ghostLayers = [(ghostLayers, ghostLayers)] * len(loopOrder)

    loopStrides = []
    loopVars = []
    currentBody = body
    # loops are built from the inside out, so the coordinates are visited innermost first
    for loopCoordinate in reversed(loopOrder):
        if iterationSlice is None:
            begin = ghostLayers[loopCoordinate][0]
            end = shape[loopCoordinate] - ghostLayers[loopCoordinate][1]
            step = 1
        else:
            sliceComponent = iterationSlice[loopCoordinate]
            if type(sliceComponent) is not slice:
                # fixed coordinate: no loop is needed, the counter symbol is set to the given value instead
                assignment = ast.SympyAssignment(ast.LoopOverCoordinate.getLoopCounterSymbol(loopCoordinate),
                                                 sp.sympify(sliceComponent))
                currentBody.insertFront(assignment)
                continue
            begin, end, step = sliceComponent.start, sliceComponent.stop, sliceComponent.step

        newLoop = ast.LoopOverCoordinate(currentBody, loopCoordinate, begin, end, step)
        currentBody = ast.Block([newLoop])
        loopStrides.append((end - begin) / step)
        loopVars.append(newLoop.loopCounterSymbol)

    loopVars = [numBufferAccesses * var for var in loopVars]
    astNode = ast.KernelFunction(currentBody, ghostLayers=ghostLayers, functionName=functionName, backend='cpu')
    return (astNode, loopStrides, loopVars)
130
131
132


def createIntermediateBasePointer(fieldAccess, coordinates, previousPtr):
    r"""
    Elements of structured arrays are addressed with :math:`ptr\left[ \sum_i c_i \cdot s_i \right]`
    where :math:`c_i` is the coordinate value and :math:`s_i` the stride of a coordinate.
    This sum can be split into several partial sums, such that some of them can be computed before the loops.
    This function computes the partial sum for the coordinates :math:`i \in \mbox{coordinates}`; for spatial
    coordinates the offset of the field access is included.
    The name of the returned typed symbol encodes which coordinates have been resolved: integer offsets and
    integer index values are written out, all other values contribute to a hash that is appended to the name.

    :param fieldAccess: instance of :class:`pystencils.field.Field.Access` which provides strides and offsets
    :param coordinates: mapping of coordinate ids to its value, where stride*value is calculated
    :param previousPtr: the pointer which is dereferenced
    :return: tuple with the new pointer symbol and the calculated offset

    Example:
        >>> field = Field.createGeneric('myfield', spatialDimensions=2, indexDimensions=1)
        >>> x, y = sp.symbols("x y")
        >>> prevPointer = TypedSymbol("ptr", "double")
        >>> createIntermediateBasePointer(field[1,-2](5), {0: x}, prevPointer)
        (ptr_E, x*fstride_myfield[0] + fstride_myfield[0])
        >>> createIntermediateBasePointer(field[1,-2](5), {0: x, 1 : y }, prevPointer)
        (ptr_E_2S, x*fstride_myfield[0] + y*fstride_myfield[1] + fstride_myfield[0] - 2*fstride_myfield[1])
    """
    field = fieldAccess.field
    offset = 0
    nameParts = []
    nonIntegerValues = []  # values that cannot be written into the name, they are hashed instead

    for coordinateId, coordinateValue in coordinates.items():
        stride = field.strides[coordinateId]
        offset += stride * coordinateValue

        if coordinateId < field.spatialDimensions:
            accessOffset = fieldAccess.offsets[coordinateId]
            offset += stride * accessOffset
            if type(accessOffset) is int:
                direction = offsetComponentToDirectionString(coordinateId, accessOffset)
                nameParts.append("_" + (direction or "C"))
            else:
                nonIntegerValues.append(accessOffset)
        elif type(coordinateValue) is int:
            nameParts.append("_%d" % (coordinateValue,))
        else:
            nonIntegerValues.append(coordinateValue)

    if nonIntegerValues:
        nameParts.append("%0.6X" % (abs(hash(tuple(nonIntegerValues)))))

    newPtr = TypedSymbol(previousPtr.name + "".join(nameParts), previousPtr.dtype)
    return newPtr, offset


def parseBasePointerInfo(basePointerSpecification, loopOrder, field):
    """
    Creates base pointer specification for :func:`resolveFieldAccesses` function.

    Specification of how many and which intermediate pointers are created for a field access.
    For example [(0,), (2, 3)] creates one base pointer for coordinates 2 and 3 and writes the offset for
    coordinate zero directly in the field access. These specifications are more sensibly defined dependent on
    the loop ordering. This function translates the more readable version into the specification above.
    Coordinates of the field that are not mentioned in the specification are collected in an additional,
    sorted group that is appended at the end.

    Allowed specifications:
        - "spatialInner<int>" spatialInner0 is the innermost loop coordinate,
          spatialInner1 the loop enclosing the innermost
        - "spatialOuter<int>" spatialOuter0 is the outermost loop
        - "spatialall": all spatial coordinates
        - "index<int>": index coordinate
        - "<int>": specifying directly the coordinate

    :param basePointerSpecification: nested list with above specifications
    :param loopOrder: list with ordering of loops from outer to inner
    :param field: field the specification is for, provides ``spatialDimensions`` and ``indexDimensions``
    :return: list of coordinate groups (lists of ints) that can be passed to :func:`resolveFieldAccesses`
    :raises ValueError: if a coordinate does not exist, is specified twice or a specification cannot be parsed
    """
    nrOfCoordinates = field.spatialDimensions + field.indexDimensions
    outerToInner = list(loopOrder)
    innerToOuter = outerToInner[::-1]
    specifiedCoordinates = set()
    result = []

    def addNewElement(group, i):
        if i >= nrOfCoordinates:
            raise ValueError("Coordinate %d does not exist" % (i,))
        if i in specifiedCoordinates:
            raise ValueError("Coordinate %d specified two times" % (i,))
        group.append(i)
        specifiedCoordinates.add(i)

    for specGroup in basePointerSpecification:
        newGroup = []
        for element in specGroup:
            if type(element) is int:
                addNewElement(newGroup, element)
            elif element.startswith("spatial"):
                spatialSpec = element[len("spatial"):]
                if spatialSpec.startswith("Inner"):
                    index = int(spatialSpec[len("Inner"):])
                    addNewElement(newGroup, innerToOuter[index])
                elif spatialSpec.startswith("Outer"):
                    index = int(spatialSpec[len("Outer"):])
                    # Counting has to start at the outermost loop. Negative indexing into the inner-to-outer
                    # list does not work here, because -0 == 0 would select the innermost loop for Outer0.
                    addNewElement(newGroup, outerToInner[index])
                elif spatialSpec == "all":
                    for i in range(field.spatialDimensions):
                        addNewElement(newGroup, i)
                else:
                    raise ValueError("Could not parse " + spatialSpec)
            elif element.startswith("index"):
                index = int(element[len("index"):])
                addNewElement(newGroup, field.spatialDimensions + index)
            else:
                raise ValueError("Unknown specification %s" % (element,))

        result.append(newGroup)

    rest = set(range(nrOfCoordinates)) - specifiedCoordinates
    if rest:
        # sorted, so that the result does not depend on the iteration order of the set
        result.append(sorted(rest))

    return result


248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
def substituteArrayAccessesWithConstants(astNode):
    """Replaces all Indexed instances (array accesses) that are not field accesses by constants.

    According to benchmarks, some compilers optimize less when an array access is used as loop bound or in
    pointer computations. Therefore a symbol is defined for each such array access in front of the node that
    uses it, and the symbol is used instead of the access.
    Run this transformation after field accesses have been resolved (since that introduces array accesses) and
    before constants are moved in front of the loops. The AST is modified in place.

    :param astNode: AST (sub)tree to transform
    """

    def handleSympyExpression(expr, parentBlock):
        """Returns ``expr`` with the array accesses replaced by constants. The assignments that define these
        constants are inserted into ``parentBlock`` in front of ``astNode``, unless the block already
        defines a symbol with that name."""
        if not isinstance(expr, sp.Expr):
            return expr

        # array accesses of interest: everything indexed, except already resolved field accesses
        arrayAccesses = [e for e in expr.atoms(sp.Indexed) if not isinstance(e, ast.ResolvedFieldAccess)]

        # an expression that consists of nothing but a single array access is left untouched
        if len(arrayAccesses) == 1 and expr == arrayAccesses[0]:
            return expr

        replacements = {}
        definitions = []
        for access in arrayAccesses:
            base, idx = access.args
            arraySymbol = base.args[0]
            elementType = deepcopy(getBaseType(arraySymbol.dtype))
            elementType.const = False  # set on the copy only, the type of the array itself stays as it is
            constant = TypedSymbol(arraySymbol.name + str(idx), elementType)
            definitions.append(ast.SympyAssignment(constant, access))
            replacements[access] = constant

        definedSymbols = parentBlock.symbolsDefined
        for definition in sorted(definitions, key=lambda d: d.lhs.name):
            if definition.lhs not in definedSymbols:
                parentBlock.insertBefore(definition, astNode)

        return expr.subs(replacements)

    if isinstance(astNode, ast.SympyAssignment):
        astNode.rhs = handleSympyExpression(astNode.rhs, astNode.parent)
        astNode.lhs = handleSympyExpression(astNode.lhs, astNode.parent)
    elif isinstance(astNode, ast.LoopOverCoordinate):
        for boundName in ('start', 'stop', 'step'):
            setattr(astNode, boundName, handleSympyExpression(getattr(astNode, boundName), astNode.parent))
        substituteArrayAccessesWithConstants(astNode.body)
    else:
        for child in astNode.args:
            substituteArrayAccessesWithConstants(child)
299

Martin Bauer's avatar
Martin Bauer committed
300

301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
def resolveBufferAccesses(astNode, baseBufferIndex, readOnlyFieldNames=None):
    """
    Substitutes :class:`pystencils.field.Field.Access` nodes that access a buffer by array indexing.

    Accesses to fields that are not buffers are left untouched. The buffer is addressed with
    ``baseBufferIndex``; if the access has an index value, it is added to the base index.
    The left and right hand sides of all assignments in the AST are modified in place.

    :param astNode: the AST root
    :param baseBufferIndex: index expression used to address the buffer
    :param readOnlyFieldNames: set of field names which are considered read-only, their data pointer is declared
                               const. None (the default) means that no field is read-only.
    :return: None
    :raises RuntimeError: if a buffer is accessed with more than one index value
    """
    if readOnlyFieldNames is None:
        readOnlyFieldNames = set()

    def visitSympyExpr(expr):
        if isinstance(expr, Field.Access):
            fieldAccess = expr

            # Do not apply transformation if field is not a buffer
            if not FieldType.isBuffer(fieldAccess.field):
                return expr

            buffer = fieldAccess.field

            dtype = PointerType(buffer.dtype, const=buffer.name in readOnlyFieldNames, restrict=True)
            fieldPtr = TypedSymbol("%s%s" % (Field.DATA_PREFIX, symbolNameToVariableName(buffer.name)), dtype)

            if len(fieldAccess.index) > 1:
                raise RuntimeError('Only indexing dimensions up to 1 are currently supported in buffers!')

            bufferIndex = baseBufferIndex
            if len(fieldAccess.index) > 0:
                # a new object is created here, such that the index passed in by the caller is never altered
                bufferIndex = baseBufferIndex + fieldAccess.index[0]

            result = ast.ResolvedFieldAccess(fieldPtr, bufferIndex, fieldAccess.field, fieldAccess.offsets,
                                             fieldAccess.index)
            return visitSympyExpr(result)

        if isinstance(expr, ast.ResolvedFieldAccess):
            return expr

        newArgs = [visitSympyExpr(e) for e in expr.args]
        if not newArgs:
            return expr
        # keep the structure of these expressions as it is when they are rebuilt
        kwargs = {'evaluate': False} if type(expr) in (sp.Add, sp.Mul, sp.Piecewise) else {}
        return expr.func(*newArgs, **kwargs)

    def visitNode(subAst):
        if isinstance(subAst, ast.SympyAssignment):
            assert type(subAst.parent) is ast.Block
            subAst.lhs = visitSympyExpr(subAst.lhs)
            subAst.rhs = visitSympyExpr(subAst.rhs)
        else:
            for a in subAst.args:
                visitNode(a)

    return visitNode(astNode)

347

348
def resolveFieldAccesses(astNode, readOnlyFieldNames=None, fieldToBasePointerInfo=None, fieldToFixedCoordinates=None):
    """
    Substitutes :class:`pystencils.field.Field.Access` nodes by array indexing

    The left and right hand sides of all assignments in the AST are modified in place. Assignments that define
    the required intermediate base pointers are inserted in front of the assignment that needs them.

    :param astNode: the AST root
    :param readOnlyFieldNames: set of field names which are considered read-only, their data pointer is declared
                               const. None (the default) means that no field is read-only.
    :param fieldToBasePointerInfo: map of field name to a list of coordinate groups indicating which
                                   intermediate base pointers should be created, for details see
                                   :func:`parseBasePointerInfo`. For fields without entry all coordinates are
                                   resolved directly in the field access.
    :param fieldToFixedCoordinates: map of field name to a tuple of coordinate symbols. Instead of using the loop
                                    counters to index the field these symbols are used as coordinates
    :return: None
    """
    if readOnlyFieldNames is None:
        readOnlyFieldNames = set()
    if fieldToBasePointerInfo is None:
        fieldToBasePointerInfo = {}
    if fieldToFixedCoordinates is None:
        fieldToFixedCoordinates = {}

    fieldToBasePointerInfo = OrderedDict(sorted(fieldToBasePointerInfo.items(), key=lambda pair: pair[0]))
    fieldToFixedCoordinates = OrderedDict(sorted(fieldToFixedCoordinates.items(), key=lambda pair: pair[0]))

    def visitSympyExpr(expr, enclosingBlock, sympyAssignment):
        if isinstance(expr, Field.Access):
            fieldAccess = expr
            field = fieldAccess.field

            if field.name in fieldToBasePointerInfo:
                basePointerInfo = fieldToBasePointerInfo[field.name]
            else:
                basePointerInfo = [list(range(field.indexDimensions + field.spatialDimensions))]

            dtype = PointerType(field.dtype, const=field.name in readOnlyFieldNames, restrict=True)
            fieldPtr = TypedSymbol("%s%s" % (Field.DATA_PREFIX, symbolNameToVariableName(field.name)), dtype)

            def createCoordinateDict(group):
                """Maps every coordinate id of the group to the value this coordinate has in the access"""
                coordDict = {}
                for e in group:
                    if e < field.spatialDimensions:
                        if field.name in fieldToFixedCoordinates:
                            coordDict[e] = fieldToFixedCoordinates[field.name][e]
                        else:
                            ctrName = ast.LoopOverCoordinate.LOOP_COUNTER_NAME_PREFIX
                            coordDict[e] = TypedSymbol("%s_%d" % (ctrName, e), 'int')
                        coordDict[e] *= field.dtype.itemSize
                    else:
                        if isinstance(field.dtype, StructType):
                            # for struct fields the index is the name of the accessed struct member
                            assert field.indexDimensions == 1
                            accessedFieldName = fieldAccess.index[0]
                            assert isinstance(accessedFieldName, str)
                            coordDict[e] = field.dtype.getElementOffset(accessedFieldName)
                        else:
                            coordDict[e] = fieldAccess.index[e - field.spatialDimensions]

                return coordDict

            lastPointer = fieldPtr

            # all groups but the first one become intermediate pointers, starting with the last group
            for group in reversed(basePointerInfo[1:]):
                coordDict = createCoordinateDict(group)
                newPtr, offset = createIntermediateBasePointer(fieldAccess, coordDict, lastPointer)
                if newPtr not in enclosingBlock.symbolsDefined:
                    newAssignment = ast.SympyAssignment(newPtr, lastPointer + offset, isConst=False)
                    enclosingBlock.insertBefore(newAssignment, sympyAssignment)
                lastPointer = newPtr

            # the coordinates of the first group are resolved directly in the access
            coordDict = createCoordinateDict(basePointerInfo[0])
            _, offset = createIntermediateBasePointer(fieldAccess, coordDict, lastPointer)
            result = ast.ResolvedFieldAccess(lastPointer, offset, fieldAccess.field,
                                             fieldAccess.offsets, fieldAccess.index)

            if isinstance(getBaseType(fieldAccess.field.dtype), StructType):
                newType = fieldAccess.field.dtype.getElementType(fieldAccess.index[0])
                result = castFunc(result, newType)

            return visitSympyExpr(result, enclosingBlock, sympyAssignment)
        else:
            if isinstance(expr, ast.ResolvedFieldAccess):
                return expr

            newArgs = [visitSympyExpr(e, enclosingBlock, sympyAssignment) for e in expr.args]
            # keep the structure of these expressions as it is when they are rebuilt
            kwargs = {'evaluate': False} if type(expr) in (sp.Add, sp.Mul, sp.Piecewise) else {}
            return expr.func(*newArgs, **kwargs) if newArgs else expr

    def visitNode(subAst):
        if isinstance(subAst, ast.SympyAssignment):
            enclosingBlock = subAst.parent
            assert type(enclosingBlock) is ast.Block
            subAst.lhs = visitSympyExpr(subAst.lhs, enclosingBlock, subAst)
            subAst.rhs = visitSympyExpr(subAst.rhs, enclosingBlock, subAst)
        else:
            for a in subAst.args:
                visitNode(a)

    return visitNode(astNode)


def moveConstantsBeforeLoop(astNode):
    """
    Moves :class:`pystencils.ast.SympyAssignment` nodes out of loop body if they are iteration independent.
    Call this after creating the loop structure with :func:`makeLoopOverDomain`

    The AST is modified in place, nothing is returned. If the block an assignment is moved to already contains
    an assignment to the same symbol, the moved assignment is dropped; both right hand sides have to be equal.

    :param astNode: root of the AST to transform
    """
    def findBlockToMoveTo(node):
        """
        Traverses parents of node as long as the symbols are independent and returns a (parent) block
        the assignment can be safely moved to

        :param node: SympyAssignment inside a Block
        :return blockToInsertTo, childOfBlockToInsertBefore
        """
        assert isinstance(node, ast.SympyAssignment)
        assert isinstance(node.parent, ast.Block)

        targetBlock, childOfTarget = node.parent, node
        current, previous = node.parent, node
        while current:
            if isinstance(current, ast.Block):
                targetBlock, childOfTarget = current, previous

            # symbols that tie the assignment to 'current': the condition of a conditional,
            # or the symbols defined by any other kind of node
            criticalSymbols = (current.conditionExpr.atoms(sp.Symbol) if isinstance(current, ast.Conditional)
                               else current.symbolsDefined)
            if node.undefinedSymbols.intersection(criticalSymbols):
                break
            previous, current = current, current.parent
        return targetBlock, childOfTarget

    def checkIfAssignmentAlreadyInBlock(assignment, targetBlock):
        """Returns the first assignment of targetBlock with the same left hand side, or None"""
        candidates = (arg for arg in targetBlock.args if type(arg) is ast.SympyAssignment)
        return next((arg for arg in candidates if arg.lhs == assignment.lhs), None)

    def collectBlocks(node, found):
        if isinstance(node, ast.Block):
            found.append(node)
        if isinstance(node, ast.Node):
            for child in node.args:
                collectBlocks(child, found)

    blocksInVisitOrder = []
    collectBlocks(astNode, blocksInVisitOrder)
    # blocks are processed in reverse order of discovery
    for block in reversed(blocksInVisitOrder):
        # the block is emptied and refilled with everything that has to stay
        for child in block.takeChildNodes():
            if not isinstance(child, ast.SympyAssignment):
                block.append(child)
                continue

            target, childToInsertBefore = findBlockToMoveTo(child)
            if target == block:     # movement not possible
                target.append(child)
                continue

            existingAssignment = checkIfAssignmentAlreadyInBlock(child, target)
            if existingAssignment:
                assert existingAssignment.rhs == child.rhs, "Symbol with same name exists already"
            else:
                target.insertBefore(child, childToInsertBefore)


def splitInnerLoop(astNode, symbolGroups):
    """
    Splits inner loop into multiple loops to minimize the amount of simultaneous load/store streams

    The AST is modified in place and nothing is returned: the innermost loop is replaced by a block with one
    loop per symbol group. Symbols of a group that are not field accesses are stored in temporary arrays that
    are indexed with the inner loop counter; these arrays are allocated in front of the outermost loop and
    freed at the end of the block that contains it.

    :param astNode: AST root
    :param symbolGroups: sequence of symbol sequences: for each symbol sequence a new inner loop is created which
         updates these symbols and their dependent symbols. Symbols which are in none of the symbolGroups and which
         no symbol in a symbol group depends on, are not updated!
    """
    allLoops = astNode.atoms(ast.LoopOverCoordinate)
    innerLoop = [l for l in allLoops if l.isInnermostLoop]
    assert len(innerLoop) == 1, "Error in AST: multiple innermost loops. Was split transformation already called?"
    innerLoop = innerLoop[0]
    assert type(innerLoop.body) is ast.Block
    outerLoop = [l for l in allLoops if l.isOutermostLoop]
    assert len(outerLoop) == 1, "Error in AST, multiple outermost loops."
    outerLoop = outerLoop[0]

    def temporaryArrayAccess(symbol):
        """Array element, indexed by the inner loop counter, that takes the place of the typed symbol"""
        assert type(symbol) is TypedSymbol
        pointerSymbol = TypedSymbol(symbol.name, PointerType(symbol.dtype))
        return IndexedBase(pointerSymbol, shape=(1,))[innerLoop.loopCounterSymbol]

    symbolsWithTemporaryArray = OrderedDict()
    assignmentMap = OrderedDict((a.lhs, a) for a in innerLoop.body.args)

    assignmentGroups = []
    for symbolGroup in symbolGroups:
        # collect the symbols of the group together with everything they depend on
        pending = list(symbolGroup)
        symbolsResolved = set()
        while pending:
            current = pending.pop()
            if current in symbolsResolved:
                continue
            # symbols without assignment inside the loop body are independent already
            if current in assignmentMap:
                pending.extend(dependency for dependency in assignmentMap[current].rhs.atoms(sp.Symbol)
                               if type(dependency) is not Field.Access
                               and dependency not in symbolsWithTemporaryArray)
            symbolsResolved.add(current)

        for symbol in symbolGroup:
            if type(symbol) is not Field.Access:
                symbolsWithTemporaryArray[symbol] = temporaryArrayAccess(symbol)

        assignmentGroup = []
        for assignment in innerLoop.body.args:
            if assignment.lhs not in symbolsResolved:
                continue
            newRhs = assignment.rhs.subs(symbolsWithTemporaryArray.items())
            writesGroupSymbol = type(assignment.lhs) is not Field.Access and assignment.lhs in symbolGroup
            newLhs = temporaryArrayAccess(assignment.lhs) if writesGroupSymbol else assignment.lhs
            assignmentGroup.append(ast.SympyAssignment(newLhs, newRhs))
        assignmentGroups.append(assignmentGroup)

    newLoops = [innerLoop.newLoopWithDifferentBody(ast.Block(group)) for group in assignmentGroups]
    innerLoop.parent.replace(innerLoop, ast.Block(newLoops))

    for tmpArray in symbolsWithTemporaryArray:
        tmpArrayPointer = TypedSymbol(tmpArray.name, PointerType(tmpArray.dtype))
        enclosingNode = outerLoop.parent
        enclosingNode.insertFront(ast.TemporaryMemoryAllocation(tmpArrayPointer, innerLoop.stop))
        enclosingNode.append(ast.TemporaryMemoryFree(tmpArrayPointer))
573
574


575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
def cutLoop(loopNode, cuttingPoints):
    """Cuts a loop at the given cutting points: the loop is replaced in its parent by len(cuttingPoints)+1
    consecutive pieces that range from the old start to cuttingPoints[0], ..., cuttingPoints[-1] to the old stop.

    Every piece works on its own deep copy of the loop body. A piece that covers exactly one iteration is not
    turned into a loop: ``subs`` is called on the body copy to replace the loop counter by the start value of
    the piece (the return value of ``subs`` is not used) and the body copy itself is used as piece.

    :param loopNode: loop to cut, has to have a step of 1
    :param cuttingPoints: iterable of loop counter values where the loop is cut
    :raises NotImplementedError: if the step of the loop is not 1
    """
    if loopNode.step != 1:
        raise NotImplementedError("Can only split loops that have a step of 1")

    pieces = []
    lower = loopNode.start
    for upper in list(cuttingPoints) + [loopNode.stop]:
        bodyCopy = deepcopy(loopNode.body)
        if upper - lower == 1:
            bodyCopy.subs({loopNode.loopCounterSymbol: lower})
            pieces.append(bodyCopy)
        else:
            pieces.append(ast.LoopOverCoordinate(bodyCopy, loopNode.coordinateToLoopOver,
                                                 lower, upper, loopNode.step))
        lower = upper
    loopNode.parent.replace(loopNode, pieces)


def isConditionNecessary(condition, preCondition, symbol):
    """
    Determines if a logical condition of a single variable is already contained in a stronger preCondition
    so if from preCondition follows that condition is always true, then this condition is not necessary
    :param condition: sympy relational of one variable
    :param preCondition: logical expression that is known to be true
    :param symbol: the single symbol of interest
    :return: returns  not (preCondition => condition) where "=>" is logical implication
    """
    from sympy.solvers.inequalities import reduce_rational_inequalities
    from sympy.logic.boolalg import to_dnf

    def toDnfList(expr):
        """Converts expr to disjunctive normal form, returned as a list of conjunctions where each
        conjunction is itself a list of terms - the nested list format reduce_rational_inequalities expects"""
        dnf = to_dnf(expr)
        conjunctions = dnf.args if isinstance(dnf, sp.Or) else (dnf,)
        # a conjunction that consists of a single term is not wrapped in sp.And - taking '.args' of it
        # would yield the operands of the term instead of the term itself, so it is wrapped explicitly
        return [list(c.args) if isinstance(c, sp.And) else [c] for c in conjunctions]

    # condition is unnecessary iff adding it to the preCondition does not shrink the solution set
    t1 = reduce_rational_inequalities(toDnfList(sp.And(condition, preCondition)), symbol)
    t2 = reduce_rational_inequalities(toDnfList(preCondition), symbol)
    return t1 != t2


def simplifyBooleanExpression(expr, singleVariableRanges):
    """Simplifies a boolean expression by dropping relations that always hold in known variable ranges.

    The expression is first brought into disjunctive normal form. Every relational that contains exactly
    one symbol, for which a range is known and which follows from that range, is replaced by sp.true.

    :param expr: sympy boolean expression
    :param singleVariableRanges: dict mapping a symbol to a sympy logical expression that contains only this
                                 symbol and defines a range for it. For example with a being a symbol
                                 { a: sp.And(a >=0, a < 10) }
    :return: the simplified expression
    """
    from sympy.core.relational import Relational
    from sympy.logic.boolalg import to_dnf

    def visit(e):
        if not isinstance(e, Relational):
            # logical connective or atom: rebuild from simplified children
            children = [visit(child) for child in e.args]
            return e.func(*children) if children else e

        involvedSymbols = e.atoms(sp.Symbol)
        if len(involvedSymbols) != 1:
            return e
        (symbol,) = involvedSymbols
        if symbol in singleVariableRanges and not isConditionNecessary(e, singleVariableRanges[symbol], symbol):
            return sp.true
        return e

    return visit(to_dnf(expr))


def simplifyConditionals(node, loopConditionals=None):
    """Simplifies/Removes conditions inside loops that depend on the loop counter.

    :param node: AST node to process; loops, conditionals, blocks and assignments are supported
    :param loopConditionals: dict mapping loop counter symbols of the enclosing loops to the range condition
                             of that loop. Callers normally leave this out, it is filled during recursion.
    :raises ValueError: if a node of an unsupported type is encountered
    """
    # A fresh dict per top-level call: the dict is handed down explicitly, so no state is shared
    # between independent calls and a dict passed in by the caller is honoured on all levels.
    if loopConditionals is None:
        loopConditionals = {}

    if isinstance(node, ast.LoopOverCoordinate):
        ctrSym = node.loopCounterSymbol
        loopConditionals[ctrSym] = sp.And(ctrSym >= node.start, ctrSym < node.stop)
        try:
            simplifyConditionals(node.body, loopConditionals)
        finally:
            # the range is only valid inside this loop - remove it even if the traversal failed
            del loopConditionals[ctrSym]
    elif isinstance(node, ast.Conditional):
        node.conditionExpr = simplifyBooleanExpression(node.conditionExpr, loopConditionals)
        simplifyConditionals(node.trueBlock, loopConditionals)
        if node.falseBlock:
            simplifyConditionals(node.falseBlock, loopConditionals)
        # a condition with a constant value is replaced by the branch that is always taken
        if node.conditionExpr == sp.true:
            node.parent.replace(node, [node.trueBlock])
        elif node.conditionExpr == sp.false:
            node.parent.replace(node, [node.falseBlock] if node.falseBlock else [])
    elif isinstance(node, ast.Block):
        # iterate over a copy, since children may replace themselves in this block
        for a in list(node.args):
            simplifyConditionals(a, loopConditionals)
    elif isinstance(node, ast.SympyAssignment):
        return node
    else:
        raise ValueError("Can not handle node", type(node))


def cleanupBlocks(node):
    """Curly brace removal: a block with at most one child that is directly nested in another block
    is replaced in its parent by its children. Assignments are leaves and are not descended into."""
    if isinstance(node, ast.SympyAssignment):
        return

    isBlock = isinstance(node, ast.Block)
    # children of a block may remove themselves from it during the recursion, so a copy is iterated
    children = list(node.args) if isBlock else node.args
    for child in children:
        cleanupBlocks(child)

    if isBlock and len(node.args) <= 1 and isinstance(node.parent, ast.Block):
        node.parent.replace(node, node.args)


689
690
691
692
693
def symbolNameToVariableName(symbolName):
    """Returns the symbol name with every '^' exchanged for '_', since '^' is allowed in sympy symbol
    names but not in C/C++ variable names"""
    return "_".join(symbolName.split("^"))


694
def typeAllEquations(eqs, typeForSymbol):
    """
    Traverses AST and replaces every :class:`sympy.Symbol` by a :class:`pystencils.typedsymbol.TypedSymbol`.
    Additionally returns sets of all fields which are read/written

    :param eqs: list of equations
    :param typeForSymbol: dict mapping symbol names to types. Types are strings of C types like 'int' or 'double'.
                          If a string or another non-subscriptable object is passed instead, it is forwarded to
                          :func:`typingFromSympyInspection` as default type to build the mapping.
    :return: ``fieldsRead, fieldsWritten, typedEquations`` set of read fields, set of written fields, list of equations
               where symbols have been replaced by typed symbols
    """
    # strings are subscriptable as well, so they have to be checked for explicitly
    if isinstance(typeForSymbol, str) or not hasattr(typeForSymbol, '__getitem__'):
        typeForSymbol = typingFromSympyInspection(eqs, typeForSymbol)

    fieldsWritten = set()
    fieldsRead = set()

    def processRhs(term):
        """Replaces Symbols by:
            - TypedSymbol if symbol is not a field access
        Accessed fields are added to fieldsRead
        """
        if isinstance(term, Field.Access):
            fieldsRead.add(term.field)
            return term
        elif isinstance(term, TypedSymbol):
            return term
        elif isinstance(term, sp.Symbol):
            return TypedSymbol(symbolNameToVariableName(term.name), typeForSymbol[term.name])
        else:
            newArgs = [processRhs(arg) for arg in term.args]
            return term.func(*newArgs) if newArgs else term

    def processLhs(term):
        """Replaces symbol by TypedSymbol and adds field to fieldsWritten"""
        if isinstance(term, Field.Access):
            fieldsWritten.add(term.field)
            return term
        elif isinstance(term, TypedSymbol):
            return term
        elif isinstance(term, sp.Symbol):
            # the name is converted exactly as in processRhs, otherwise a symbol that is defined on a
            # left-hand-side and used on a right-hand-side would end up as two different typed symbols
            return TypedSymbol(symbolNameToVariableName(term.name), typeForSymbol[term.name])
        else:
            # raised explicitly instead of 'assert False' so the check is not stripped when running with -O
            raise AssertionError("Expected a symbol as left-hand-side")

    def visit(item):
        if isinstance(item, (list, tuple)):
            return [visit(e) for e in item]
        if isinstance(item, (sp.Eq, ast.SympyAssignment)):
            newLhs = processLhs(item.lhs)
            newRhs = processRhs(item.rhs)
            return ast.SympyAssignment(newLhs, newRhs)
        elif isinstance(item, ast.Conditional):
            falseBlock = None if item.falseBlock is None else visit(item.falseBlock)
            return ast.Conditional(processRhs(item.conditionExpr),
                                   trueBlock=visit(item.trueBlock), falseBlock=falseBlock)
        elif isinstance(item, ast.Block):
            return ast.Block([visit(e) for e in item.args])
        else:
            # everything else is passed through unchanged
            return item

    typedEquations = visit(eqs)

    return fieldsRead, fieldsWritten, typedEquations


Martin Bauer's avatar
Martin Bauer committed
758
759
760
# --------------------------------------- Helper Functions -------------------------------------------------------------


761
def typingFromSympyInspection(eqs, defaultType="double"):
    """
    Creates a default symbol name to type mapping.
    If a sympy Boolean is assigned to a symbol it is assumed to be 'bool' otherwise the default type, usually ('double')
    :param eqs: list of equations
    :param defaultType: the type for non-boolean symbols
    :return: defaultdict, mapping symbol name to type; names without an entry yield defaultType
    """
    result = defaultdict(lambda: defaultType)
    for eq in eqs:
        # AST nodes are not inspected
        if isinstance(eq, ast.Node):
            continue
        # problematic case here is when rhs is a symbol: then it is impossible to decide here without
        # further information what type the left hand side is - default fallback is the dict value then
        if isinstance(eq.rhs, Boolean) and not isinstance(eq.rhs, sp.Symbol):
            result[eq.lhs.name] = "bool"
    return result


def getNextParentOfType(node, parentType):
    """
    Traverses the AST nodes parents until a parent of given type was found. If no such parent is found, None is returned

    :param node: node where the search starts; the node itself is never returned, only its ancestors
    :param parentType: type (or tuple of types) that is passed to isinstance
    :return: closest ancestor that is an instance of parentType or None
    """
    ancestor = node.parent
    while ancestor is not None and not isinstance(ancestor, parentType):
        ancestor = ancestor.parent
    return ancestor


def getOptimalLoopOrdering(fields):
    """
    Determines the optimal loop order for a given set of fields.
    If the fields have a different number of spatial dimensions or a different memory layout
    an exception is thrown.
    :param fields: non-empty sequence of fields
    :return: list of coordinate ids, where the first list entry should be the outermost loop
    :raises ValueError: if the fields disagree in number of spatial dimensions or in layout
    """
    assert len(fields) > 0
    refDimensions = next(iter(fields)).spatialDimensions
    if any(field.spatialDimensions != refDimensions for field in fields):
        raise ValueError("All fields have to have the same number of spatial dimensions. Spatial field dimensions: "
                         + str({f.name: f.spatialShape for f in fields}))

    layouts = {field.layout for field in fields}
    if len(layouts) > 1:
        raise ValueError("Due to different layout of the fields no optimal loop ordering exists " +
                         str({f.name: f.layout for f in fields}))
    # exactly one layout is left here, it is returned as list
    return list(next(iter(layouts)))
812
813


Martin Bauer's avatar
Martin Bauer committed
814
815
816
817
818
def getLoopHierarchy(astNode):
    """Determines the loop structure around a given AST node.
    :param astNode: the AST node
    :return: reverse iterator (can be consumed only once) over the coordinate ids of the loops enclosing
             the node, where the first entry belongs to the outermost loop
    """
    result = []
    node = astNode
    # walks from the node upwards, so the coordinates are collected from the innermost to the outermost loop
    while node is not None:
        node = getNextParentOfType(node, ast.LoopOverCoordinate)
        if node:
            result.append(node.coordinateToLoopOver)
    return reversed(result)

Jan Hoenig's avatar
Jan Hoenig committed
827

Jan Hoenig's avatar
Jan Hoenig committed
828
829
830
831
832
833
834
835
def get_type(node):
    """Returns the data type of an AST node or of a sympy number.

    :param node: AST node or sympy number
    :return: dtype of the first argument for indexed nodes, the dtype attribute for other AST nodes,
             'double' type for sympy floats and 'int' type for sympy integers
    :raises NotImplementedError: for all other kinds of objects
    """
    if isinstance(node, ast.Indexed):
        return node.args[0].dtype
    if isinstance(node, ast.Node):
        return node.dtype
    # TODO sp.NumberSymbol
    if isinstance(node, sp.Float):
        return createType('double')
    if isinstance(node, sp.Integer):
        return createType('int')
    # NotImplemented is a constant that is not callable - the exception class is NotImplementedError
    raise NotImplementedError('Not yet supported: %s %s' % (node, type(node)))


Jan Hoenig's avatar
Jan Hoenig committed
845
def insert_casts(node):
    """
    Inserts casts and dtype where needed

    The children of the node are processed first. For expression nodes the arguments are sorted by dtype
    and the dtype of the first argument after sorting becomes the dtype of the expression. For assignments
    the right hand side is converted to the dtype of the left hand side if they differ.

    :param node: ast which should be traversed
    :return: node
    :raises ValueError: if something that is neither int nor uint is combined with a pointer
    """
    def conversion(args):
        """Returns the new argument list for arguments sorted by dtype, where args[0] determines the target"""
        target = args[0]
        if isinstance(target.dtype, PointerType):
            # Pointer arithmetic
            for arg in args[1:]:
                # Check validness
                if not arg.dtype.is_int() and not arg.dtype.is_uint():
                    raise ValueError("Impossible pointer arithmetic", target, arg)
            # all offsets are summed up and applied to the pointer by one single node
            pointer = ast.PointerArithmetic(ast.Add(args[1:]), target)
            return [pointer]

        else:
            for i, arg in enumerate(args):
                if arg.dtype != target.dtype:
                    args[i] = ast.Conversion(arg, target.dtype, node)
            return args

    for arg in node.args:
        insert_casts(arg)

    if isinstance(node, ast.Indexed):
        # TODO remove this
        # this check has to come before the ast.Expr check, so that indexed nodes are left untouched
        pass
    elif isinstance(node, ast.Expr):
        args = sorted(node.args, key=attrgetter('dtype'))
        target = args[0]
        node.args = conversion(args)
        node.dtype = target.dtype
    elif isinstance(node, ast.SympyAssignment):
        if node.lhs.dtype != node.rhs.dtype:
            node.replace(node.rhs, ast.Conversion(node.rhs, node.lhs.dtype))
    return node
Jan Hoenig's avatar
Jan Hoenig committed
889
890
891


def desympy_ast(node):
    """
    Remove Sympy Expressions, which have more then one argument.
    This is necessary for further changes in the tree.

    Sums, products, powers, numbers and indexed objects among the children are replaced by the corresponding
    AST nodes, an IndexedBase is replaced by its label. All other children are kept as they are.

    :param node: ast which should be traversed. Only node's children will be modified.
    :return: (modified) node
    """
    if node.args is None:
        return node

    # The children are accessed by index on purpose: node.replace() is called while walking over them,
    # so every position is read right before it is handled.
    # NOTE(review): presumably replace() keeps the number of children unchanged - confirm
    for i in range(len(node.args)):
        arg = node.args[i]
        if isinstance(arg, sp.Add):
            node.replace(arg, ast.Add(arg.args, node))
        elif isinstance(arg, sp.Number):
            node.replace(arg, ast.Number(arg, node))
        elif isinstance(arg, sp.Mul):
            node.replace(arg, ast.Mul(arg.args, node))
        elif isinstance(arg, sp.Pow):
            node.replace(arg, ast.Pow(arg.args, node))
        elif isinstance(arg, sp.tensor.Indexed):
            node.replace(arg, ast.Indexed(arg.args, arg.base, node))
        elif isinstance(arg, sp.tensor.IndexedBase):
            node.replace(arg, arg.label)

    # second pass: descend into the (possibly replaced) children
    for arg in node.args:
        desympy_ast(arg)
    return node
922
923
924
925
926
927
928
929
930
931
932
933


def check_dtype(node):
    """Recursively visits the node and all of its children.

    No check is carried out on the visited nodes yet, the function is only a traversal of the tree.

    :param node: root of the tree to traverse, every visited node has to provide ``args``
    """
    for arg in node.args:
        check_dtype(arg)