from collections import defaultdict
import sympy as sp
from sympy.logic.boolalg import Boolean
from sympy.tensor import IndexedBase
from pystencils.field import Field, offsetComponentToDirectionString
from pystencils.typedsymbol import TypedSymbol
from pystencils.slicing import normalizeSlice
import pystencils.ast as ast


def makeLoopOverDomain(body, functionName, iterationSlice=None, ghostLayers=None):
    """
    Wraps the given AST in nested loops over the spatial coordinates and packs the result into a kernel function.

    The participating fields are found by collecting all :class:`pystencils.field.Field.Access` atoms of ``body``.
    The loop nesting follows :func:`getOptimalLoopOrdering`: the first coordinate of that ordering is wrapped
    around the body first and therefore becomes the innermost loop.

    :param body: AST node containing the field accesses, has to provide ``atoms`` and ``insertFront``
    :param functionName: name of generated C function
    :param iterationSlice: if not None, iteration is done only over this slice of the field. Components that are
                           ``slice`` objects become loops, every other component is assigned to the loop counter
                           symbol of its coordinate at the front of the current body
    :param ghostLayers: a sequence of pairs for each coordinate with lower and upper nr of ghost layers
                        if None, the maximum of ``requiredGhostLayers`` over all field accesses is used as lower
                        and upper value for all coordinates. Only used when ``iterationSlice`` is None
    :return: :class:`pystencils.ast.KernelFunction` containing the nested loops
    """
    # the loop ordering is derived from the fields that are accessed in the body
    accesses = body.atoms(Field.Access)
    fields = set(access.field for access in accesses)
    loopOrder = getOptimalLoopOrdering(fields)

    shapes = set(f.spatialShape for f in fields)
    if len(shapes) > 1:
        # shapes whose first entry is not a sympy object are counted as fixed size, at most one is allowed
        nrOfFixedSizedFields = sum(1 for s in shapes if not isinstance(s[0], sp.Basic))
        assert nrOfFixedSizedFields <= 1, "Differently sized field accesses in loop body: " + str(shapes)
    shape = list(shapes)[0]

    if iterationSlice is not None:
        iterationSlice = normalizeSlice(iterationSlice, shape)

    if ghostLayers is None:
        width = max(access.requiredGhostLayers for access in accesses)
        ghostLayers = [(width, width)] * len(loopOrder)

    wrapped = body
    for coordinate in loopOrder:
        if iterationSlice is None:
            # full domain: iterate from the lower ghost layer to (shape - upper ghost layer) with step 1
            start = ghostLayers[coordinate][0]
            stop = shape[coordinate] - ghostLayers[coordinate][1]
            loop = ast.LoopOverCoordinate(wrapped, coordinate, start, stop, 1)
            wrapped = ast.Block([loop])
            continue

        component = iterationSlice[coordinate]
        if type(component) is slice:
            loop = ast.LoopOverCoordinate(wrapped, coordinate, component.start, component.stop, component.step)
            wrapped = ast.Block([loop])
        else:
            # fixed coordinate: no loop is created, the loop counter symbol is set to the given value instead
            counter = ast.LoopOverCoordinate.getLoopCounterSymbol(coordinate)
            wrapped.insertFront(ast.SympyAssignment(counter, sp.sympify(component)))

    return ast.KernelFunction(wrapped, fields, functionName)


def createIntermediateBasePointer(fieldAccess, coordinates, previousPtr):
    r"""
    Creates a pointer symbol and the offset for a partially resolved field access.

    Elements of structured arrays are addressed with :math:`ptr\left[ \sum_i c_i \cdot s_i \right]`
    where :math:`c_i` is the coordinate value and :math:`s_i` the stride of a coordinate.
    This sum can be split into several parts, so that some of them can be computed in front of loops.
    This function handles the part of the sum with :math:`i \in \mbox{coordinates}`. For spatial coordinates
    the offset of the field access is added to the coordinate value.

    The name of the returned symbol is the name of ``previousPtr`` followed by a suffix encoding the resolved
    coordinates: a direction string (or "C") for integer spatial offsets, the number for integer index values,
    and a hash of all remaining (non-integer) offsets and values.

    :param fieldAccess: instance of :class:`pystencils.field.Field.Access` which provides strides and offsets
    :param coordinates: mapping of coordinate ids to its value, where stride*value is calculated
    :param previousPtr: the pointer which is dereferenced, provides name and dtype for the new symbol
    :return: tuple with the new pointer symbol and the calculated offset

    Example:
        >>> field = Field.createGeneric('myfield', spatialDimensions=2, indexDimensions=1)
        >>> x, y = sp.symbols("x y")
        >>> prevPointer = TypedSymbol("ptr", "double")
        >>> createIntermediateBasePointer(field[1,-2](5), {0: x}, prevPointer)
        (ptr_E, x*fstride_myfield[0] + fstride_myfield[0])
        >>> createIntermediateBasePointer(field[1,-2](5), {0: x, 1 : y }, prevPointer)
        (ptr_E_2S, x*fstride_myfield[0] + y*fstride_myfield[1] + fstride_myfield[0] - 2*fstride_myfield[1])
    """
    field = fieldAccess.field

    offset = 0
    nameParts = []
    valuesToHash = []  # offsets / index values that are not plain ints and thus have no readable name part
    for coordId, coordValue in coordinates.items():
        stride = field.strides[coordId]
        offset = offset + stride * coordValue

        if not coordId < field.spatialDimensions:
            # index coordinate: only the value itself contributes
            if type(coordValue) is int:
                nameParts.append("_%d" % (coordValue,))
            else:
                valuesToHash.append(coordValue)
            continue

        # spatial coordinate: the relative offset of the access is added as well
        accessOffset = fieldAccess.offsets[coordId]
        offset = offset + stride * accessOffset
        if type(accessOffset) is int:
            direction = offsetComponentToDirectionString(coordId, accessOffset)
            nameParts.append("_")
            nameParts.append(direction if direction else "C")
        else:
            valuesToHash.append(accessOffset)

    if valuesToHash:
        nameParts.append("%0.6X" % (abs(hash(tuple(valuesToHash)))))

    pointerSymbol = TypedSymbol(previousPtr.name + "".join(nameParts), previousPtr.dtype)
    return pointerSymbol, offset


def parseBasePointerInfo(basePointerSpecification, loopOrder, field):
    """
    Creates base pointer specification for :func:`resolveFieldAccesses` function.

    Specification of how many and which intermediate pointers are created for a field access.
    For example [ (0), (2,3,)]  creates one base pointer for coordinates 2 and 3 and writes the offset for coordinate
    zero directly in the field access. These specifications are more sensibly defined dependent on the loop ordering.
    This function translates a more readable version into the specification above.

    Allowed specifications:
        - "spatialInner<int>" spatialInner0 is the innermost loop coordinate,
          spatialInner1 the loop enclosing the innermost
        - "spatialOuter<int>" spatialOuter0 is the outermost loop,
          spatialOuter1 the loop directly inside the outermost
        - "spatialall": all spatial coordinates of the field
        - "index<int>": index coordinate
        - "<int>": specifying directly the coordinate

    All coordinates of the field that are not mentioned in the specification are collected, in ascending
    order, in one additional group at the end of the result.

    Example for a field with 3 spatial and 1 index dimension and ``loopOrder=[0, 1, 2]``:
    ``[['spatialOuter0'], ['index0']]`` is translated to ``[[2], [3], [0, 1]]``

    :param basePointerSpecification: nested list with above specifications
    :param loopOrder: list with ordering of loops from inner to outer
    :param field: object providing ``spatialDimensions`` and ``indexDimensions``
    :return: list of coordinate groups (lists of ints) that can be passed to :func:`resolveFieldAccesses`
    :raises ValueError: if a coordinate does not exist, is specified twice, or a specification cannot be parsed
    """
    nrOfCoordinates = field.spatialDimensions + field.indexDimensions
    result = []
    specifiedCoordinates = set()

    def addNewElement(group, i):
        # negative values would silently alias other coordinates through Python's negative indexing
        if not 0 <= i < nrOfCoordinates:
            raise ValueError("Coordinate %d does not exist" % (i,))
        if i in specifiedCoordinates:
            raise ValueError("Coordinate %d specified two times" % (i,))
        group.append(i)
        specifiedCoordinates.add(i)

    for specGroup in basePointerSpecification:
        newGroup = []
        for element in specGroup:
            if type(element) is int:
                addNewElement(newGroup, element)
            elif element.startswith("spatial"):
                spatialSpec = element[len("spatial"):]
                if spatialSpec.startswith("Inner"):
                    index = int(spatialSpec[len("Inner"):])
                    addNewElement(newGroup, loopOrder[index])
                elif spatialSpec.startswith("Outer"):
                    index = int(spatialSpec[len("Outer"):])
                    # loopOrder runs from inner to outer, so the outermost loop is the last entry.
                    # Indexing with -index would map "Outer0" to loopOrder[0], i.e. the innermost loop.
                    addNewElement(newGroup, loopOrder[-1 - index])
                elif spatialSpec == "all":
                    for i in range(field.spatialDimensions):
                        addNewElement(newGroup, i)
                else:
                    raise ValueError("Could not parse " + spatialSpec)
            elif element.startswith("index"):
                index = int(element[len("index"):])
                addNewElement(newGroup, field.spatialDimensions + index)
            else:
                raise ValueError("Unknown specification %s" % (element,))

        result.append(newGroup)

    # all coordinates that were not mentioned form the last group; sorted to get a reproducible result
    rest = set(range(nrOfCoordinates)) - specifiedCoordinates
    if rest:
        result.append(sorted(rest))
    return result


def resolveFieldAccesses(astNode, readOnlyFieldNames=None, fieldToBasePointerInfo=None, fieldToFixedCoordinates=None):
    """
    Substitutes :class:`pystencils.field.Field.Access` nodes by array indexing

    The AST is modified in place: left- and right-hand side of every :class:`pystencils.ast.SympyAssignment`
    are rewritten, and assignments defining intermediate base pointers are inserted in front of the
    assignment that needs them.

    :param astNode: the AST root
    :param readOnlyFieldNames: set of field names which are considered read-only, their data pointers are
                               declared ``const``. Defaults to an empty set
    :param fieldToBasePointerInfo: map of field name to a list of coordinate groups indicating which intermediate
                                   base pointers should be created, for details see :func:`parseBasePointerInfo`.
                                   For fields not contained in the map a single group with all coordinates is used,
                                   i.e. no intermediate pointer is created. Defaults to an empty map
    :param fieldToFixedCoordinates: map of field name to a tuple of coordinate symbols. Instead of using the loop
                                    counters to index the field these symbols are used as coordinates.
                                    Defaults to an empty map
    :return: the transformed AST, i.e. the (modified) ``astNode``
    """
    # None is used as default to avoid sharing mutable default objects between calls
    if readOnlyFieldNames is None:
        readOnlyFieldNames = set()
    if fieldToBasePointerInfo is None:
        fieldToBasePointerInfo = {}
    if fieldToFixedCoordinates is None:
        fieldToFixedCoordinates = {}

    def visitSympyExpr(expr, enclosingBlock, sympyAssignment):
        if isinstance(expr, Field.Access):
            fieldAccess = expr
            field = fieldAccess.field
            if field.name in fieldToBasePointerInfo:
                basePointerInfo = fieldToBasePointerInfo[field.name]
            else:
                basePointerInfo = [list(range(field.indexDimensions + field.spatialDimensions))]

            dtype = "%s * __restrict__" % field.dtype
            if field.name in readOnlyFieldNames:
                dtype = "const " + dtype

            fieldPtr = TypedSymbol("%s%s" % (Field.DATA_PREFIX, field.name), dtype)

            lastPointer = fieldPtr

            def createCoordinateDict(group):
                """Maps every coordinate id of the group to the value used for indexing this access"""
                coordDict = {}
                for e in group:
                    if e < field.spatialDimensions:
                        # spatial coordinate: fixed coordinate symbol if given, otherwise the loop counter
                        if field.name in fieldToFixedCoordinates:
                            coordDict[e] = fieldToFixedCoordinates[field.name][e]
                        else:
                            ctrName = ast.LoopOverCoordinate.LOOP_COUNTER_NAME_PREFIX
                            coordDict[e] = TypedSymbol("%s_%d" % (ctrName, e), "int")
                    else:
                        # index coordinate: value is taken from the index of the field access
                        coordDict[e] = fieldAccess.index[e-field.spatialDimensions]
                return coordDict

            # all groups except the first one become intermediate pointers, starting with the last group
            for group in reversed(basePointerInfo[1:]):
                coordDict = createCoordinateDict(group)
                newPtr, offset = createIntermediateBasePointer(fieldAccess, coordDict, lastPointer)
                if newPtr not in enclosingBlock.symbolsDefined:
                    newAssignment = ast.SympyAssignment(newPtr, lastPointer + offset, isConst=False)
                    enclosingBlock.insertBefore(newAssignment, sympyAssignment)
                lastPointer = newPtr

            # the offset of the first group is written directly into the array access
            _, offset = createIntermediateBasePointer(fieldAccess, createCoordinateDict(basePointerInfo[0]),
                                                      lastPointer)
            baseArr = IndexedBase(lastPointer, shape=(1,))
            return baseArr[offset]
        else:
            newArgs = [visitSympyExpr(e, enclosingBlock, sympyAssignment) for e in expr.args]
            # sums and products are rebuilt with evaluate=False
            kwargs = {'evaluate': False} if type(expr) is sp.Add or type(expr) is sp.Mul else {}
            return expr.func(*newArgs, **kwargs) if newArgs else expr

    def visitNode(subAst):
        if isinstance(subAst, ast.SympyAssignment):
            enclosingBlock = subAst.parent
            assert type(enclosingBlock) is ast.Block
            subAst.lhs = visitSympyExpr(subAst.lhs, enclosingBlock, subAst)
            subAst.rhs = visitSympyExpr(subAst.rhs, enclosingBlock, subAst)
        else:
            for a in subAst.args:
                visitNode(a)

    # visitNode works in place and returns nothing, so the root is returned explicitly
    visitNode(astNode)
    return astNode


def moveConstantsBeforeLoop(astNode):
    """
    Moves :class:`pystencils.ast.SympyAssignment` nodes out of loop body if they are iteration independent.
    Call this after creating the loop structure with :func:`makeLoopOverDomain`

    Works in place on all :class:`pystencils.ast.Block` nodes found in the AST and has no return value.
    If the target block already contains an assignment to the same left-hand side, the moved assignment is
    dropped; both right-hand sides are asserted to be equal in this case.

    :param astNode: AST root
    """
    def findBlockToMoveTo(node):
        """
        Walks up the parents of node until one of them defines a symbol the assignment depends on.

        :param node: SympyAssignment inside a Block
        :return: ``blockToInsertTo, childOfBlockToInsertBefore``: the last block seen on the way up and
                 the child of that block through which it was reached
        """
        assert isinstance(node, ast.SympyAssignment)
        assert isinstance(node.parent, ast.Block)

        targetBlock, anchor = node.parent, node
        previous, current = node, node.parent
        while current:
            if isinstance(current, ast.Block):
                targetBlock, anchor = current, previous
            # the assignment must not be moved beyond a node that defines one of its dependencies
            if node.undefinedSymbols.intersection(current.symbolsDefined):
                break
            previous, current = current, current.parent
        return targetBlock, anchor

    def checkIfAssignmentAlreadyInBlock(assignment, targetBlock):
        """Returns the first SympyAssignment in targetBlock with the same left-hand side, or None"""
        sameLhs = (candidate for candidate in targetBlock.args
                   if type(candidate) is ast.SympyAssignment and candidate.lhs == assignment.lhs)
        return next(sameLhs, None)

    for block in astNode.atoms(ast.Block):
        # the block is emptied and re-filled with everything that can not be moved
        for child in block.takeChildNodes():
            if not isinstance(child, ast.SympyAssignment):
                block.append(child)
                continue

            target, childToInsertBefore = findBlockToMoveTo(child)
            if target == block:     # movement not possible
                target.append(child)
                continue

            existingAssignment = checkIfAssignmentAlreadyInBlock(child, target)
            if existingAssignment:
                assert existingAssignment.rhs == child.rhs, "Symbol with same name exists already"
            else:
                target.insertBefore(child, childToInsertBefore)


def splitInnerLoop(astNode, symbolGroups):
    """
    Splits inner loop into multiple loops to minimize the amount of simultaneous load/store streams

    The AST is modified in place, nothing is returned: the innermost loop is replaced by a block with one loop
    per symbol group. Symbols of a group that are not field accesses have to be TypedSymbols; they are stored
    in temporary arrays indexed by the inner loop counter. Allocation of these arrays is inserted at the front
    of the parent of the outermost loop, the corresponding free is appended to it.

    :param astNode: AST root
    :param symbolGroups: sequence of symbol sequences: for each symbol sequence a new inner loop is created which
        updates these symbols and their dependent symbols. Symbols which are in none of the symbolGroups and which
        no symbol in a symbol group depends on, are not updated!
    """
    loops = astNode.atoms(ast.LoopOverCoordinate)
    innermost = [loop for loop in loops if loop.isInnermostLoop]
    assert len(innermost) == 1, "Error in AST: multiple innermost loops. Was split transformation already called?"
    innerLoop = innermost[0]
    assert type(innerLoop.body) is ast.Block
    outermost = [loop for loop in loops if loop.isOutermostLoop]
    assert len(outermost) == 1, "Error in AST, multiple outermost loops."
    outerLoop = outermost[0]

    temporaryArrayAccesses = dict()  # symbol -> access into its temporary array, grows from group to group
    assignmentOfSymbol = {a.lhs: a for a in innerLoop.body.args}

    def asTemporaryArrayAccess(symbol):
        """Array access with the symbol as base, indexed by the counter of the inner loop"""
        return IndexedBase(symbol, shape=(1,))[innerLoop.loopCounterSymbol]

    def collectDependencies(group):
        """All symbols of the group together with the symbols they (transitively) depend on"""
        pending = list(group)
        resolved = set()
        while pending:
            current = pending.pop()
            if current in resolved:
                continue
            # if there is no assignment inside the loop body it is independent already
            if current in assignmentOfSymbol:
                pending.extend(dep for dep in assignmentOfSymbol[current].rhs.atoms(sp.Symbol)
                               if type(dep) is not Field.Access and dep not in temporaryArrayAccesses)
            resolved.add(current)
        return resolved

    assignmentGroups = []
    for symbolGroup in symbolGroups:
        resolved = collectDependencies(symbolGroup)

        for symbol in symbolGroup:
            if type(symbol) is Field.Access:
                continue
            assert type(symbol) is TypedSymbol
            temporaryArrayAccesses[symbol] = asTemporaryArrayAccess(symbol)

        group = []
        for assignment in innerLoop.body.args:
            if assignment.lhs not in resolved:
                continue
            rhs = assignment.rhs.subs(temporaryArrayAccesses.items())
            lhs = assignment.lhs
            if type(lhs) is not Field.Access and lhs in symbolGroup:
                lhs = asTemporaryArrayAccess(lhs)
            group.append(ast.SympyAssignment(lhs, rhs))
        assignmentGroups.append(group)

    splitLoops = [innerLoop.newLoopWithDifferentBody(ast.Block(group)) for group in assignmentGroups]
    innerLoop.parent.replace(innerLoop, ast.Block(splitLoops))

    for tmpArray in temporaryArrayAccesses:
        outerLoop.parent.insertFront(ast.TemporaryMemoryAllocation(tmpArray, innerLoop.stop))
        outerLoop.parent.append(ast.TemporaryMemoryFree(tmpArray))


def typeAllEquations(eqs, typeForSymbol):
    """
    Converts equations to assignments where every :class:`sympy.Symbol` is replaced by a
    :class:`pystencils.typedsymbol.TypedSymbol`. Additionally returns sets of all fields which are read/written

    Entries of ``eqs`` that are no :class:`sympy.Eq` have to be AST nodes and are passed through unchanged.

    :param eqs: list of equations
    :param typeForSymbol: dict mapping symbol names to types. Types are strings of C types like 'int' or 'double'
    :return: ``fieldsRead, fieldsWritten, typedEquations`` set of read fields, set of written fields, list of equations
               where symbols have been replaced by typed symbols
    """
    fieldsWritten = set()
    fieldsRead = set()

    def makeTyped(symbol):
        return TypedSymbol(symbol.name, typeForSymbol[symbol.name])

    def processRhs(term):
        """Recursively replaces plain symbols by TypedSymbols, field accesses are recorded in fieldsRead"""
        if isinstance(term, Field.Access):
            fieldsRead.add(term.field)
            return term
        if isinstance(term, TypedSymbol):
            return term
        if isinstance(term, sp.Symbol):
            return makeTyped(term)
        typedArgs = [processRhs(arg) for arg in term.args]
        if not typedArgs:
            return term
        return term.func(*typedArgs)

    def processLhs(term):
        """Replaces symbol by TypedSymbol, a field access is recorded in fieldsWritten"""
        if isinstance(term, Field.Access):
            fieldsWritten.add(term.field)
            return term
        if isinstance(term, TypedSymbol):
            return term
        if isinstance(term, sp.Symbol):
            return makeTyped(term)
        assert False, "Expected a symbol as left-hand-side"

    typedEquations = []
    for eq in eqs:
        if not isinstance(eq, sp.Eq):
            assert isinstance(eq, ast.Node), "Only equations and ast nodes are allowed in input"
            typedEquations.append(eq)
            continue
        typedLhs = processLhs(eq.lhs)
        typedRhs = processRhs(eq.rhs)
        typedEquations.append(ast.SympyAssignment(typedLhs, typedRhs))

    return fieldsRead, fieldsWritten, typedEquations


# --------------------------------------- Helper Functions -------------------------------------------------------------


def typingFromSympyInspection(eqs, defaultType="double"):
    """
    Creates a default symbol name to type mapping.
    Symbols that get a sympy Boolean assigned are mapped to 'bool', all other names (including names that
    never occur in the equations) are mapped to the default type.

    :param eqs: list of equations, each has to provide ``lhs`` and ``rhs``
    :param defaultType: the type for non-boolean symbols
    :return: defaultdict, mapping symbol name to type
    """
    typeOfSymbol = defaultdict(lambda: defaultType)
    typeOfSymbol.update((eq.lhs.name, "bool") for eq in eqs if isinstance(eq.rhs, Boolean))
    return typeOfSymbol


def getNextParentOfType(node, parentType):
    """
    Returns the closest ancestor of the node that is an instance of the given type.

    :param node: AST node, the search starts at ``node.parent``
    :param parentType: type (or tuple of types) that is passed to ``isinstance``
    :return: the first matching ancestor, or None if the chain of parents ends without a match
    """
    ancestor = node.parent
    while ancestor is not None and not isinstance(ancestor, parentType):
        ancestor = ancestor.parent
    return ancestor


def getOptimalLoopOrdering(fields):
    """
    Determines the optimal loop order for a given set of fields.
    The result is the reversed layout that is common to all fields.

    :param fields: non-empty collection of fields (asserted), each provides ``spatialDimensions`` and ``layout``
    :return: list of coordinate ids, where the first list entry should be the innermost loop
    :raises ValueError: if the fields differ in their number of spatial dimensions or in their layout
    """
    assert len(fields) > 0
    reference = next(iter(fields))
    if any(f.spatialDimensions != reference.spatialDimensions for f in fields):
        raise ValueError("All fields have to have the same number of spatial dimensions")

    layouts = set(f.layout for f in fields)
    if len(layouts) > 1:
        raise ValueError("Due to different layout of the fields no optimal loop ordering exists " + str(layouts))

    # exactly one layout is left at this point
    (commonLayout,) = layouts
    ordering = list(commonLayout)
    ordering.reverse()
    return ordering


def getLoopHierarchy(astNode):
    """Determines the loop structure around a given AST node.

    Repeatedly looks up the next enclosing :class:`pystencils.ast.LoopOverCoordinate` with
    :func:`getNextParentOfType` and records the coordinate it loops over.

    :param astNode: the AST node
    :return: list of coordinate ids, where the first list entry is the innermost loop
    """
    hierarchy = []
    current = astNode
    while current is not None:
        current = getNextParentOfType(current, ast.LoopOverCoordinate)
        if not current:
            continue
        hierarchy.append(current.coordinateToLoopOver)
    return hierarchy