diff --git a/ast.py b/ast.py index b3936036263f4aaede689a552b94d679665a2902..9f027ac1dd8ad0e9f1a0670e1320dfbd040b8b2f 100644 --- a/ast.py +++ b/ast.py @@ -137,6 +137,7 @@ class Block(Node): self._nodes.insert(0, node) def insertBefore(self, newNode, insertBefore): + newNode.parent = self idx = self._nodes.index(insertBefore) self._nodes.insert(idx, newNode) diff --git a/transformations.py b/transformations.py index a48487798212b35515950bd24793d9cb24512c67..2b47cbde3b7b742759455c365c236168b9ccecc6 100644 --- a/transformations.py +++ b/transformations.py @@ -193,7 +193,7 @@ def resolveFieldAccesses(astNode, readOnlyFieldNames=set(), fieldToBasePointerIn counters to index the field these symbols are used as coordinates :return: transformed AST """ - def visitSympyExpr(expr, enclosingBlock): + def visitSympyExpr(expr, enclosingBlock, sympyAssignment): if isinstance(expr, Field.Access): fieldAccess = expr field = fieldAccess.field @@ -227,7 +227,8 @@ def resolveFieldAccesses(astNode, readOnlyFieldNames=set(), fieldToBasePointerIn coordDict = createCoordinateDict(group) newPtr, offset = createIntermediateBasePointer(fieldAccess, coordDict, lastPointer) if newPtr not in enclosingBlock.symbolsDefined: - enclosingBlock.insertFront(ast.SympyAssignment(newPtr, lastPointer + offset, isConst=False)) + newAssignment = ast.SympyAssignment(newPtr, lastPointer + offset, isConst=False) + enclosingBlock.insertBefore(newAssignment, sympyAssignment) lastPointer = newPtr _, offset = createIntermediateBasePointer(fieldAccess, createCoordinateDict(basePointerInfo[0]), @@ -235,7 +236,7 @@ def resolveFieldAccesses(astNode, readOnlyFieldNames=set(), fieldToBasePointerIn baseArr = IndexedBase(lastPointer, shape=(1,)) return baseArr[offset] else: - newArgs = [visitSympyExpr(e, enclosingBlock) for e in expr.args] + newArgs = [visitSympyExpr(e, enclosingBlock, sympyAssignment) for e in expr.args] kwargs = {'evaluate': False} if type(expr) is sp.Add or type(expr) is sp.Mul else {} return expr.func(*newArgs, **kwargs) if newArgs else expr @@ -243,8 +244,8 @@ def resolveFieldAccesses(astNode, readOnlyFieldNames=set(), fieldToBasePointerIn if isinstance(subAst, ast.SympyAssignment): enclosingBlock = subAst.parent assert type(enclosingBlock) is ast.Block - subAst.lhs = visitSympyExpr(subAst.lhs, enclosingBlock) - subAst.rhs = visitSympyExpr(subAst.rhs, enclosingBlock) + subAst.lhs = visitSympyExpr(subAst.lhs, enclosingBlock, subAst) + subAst.rhs = visitSympyExpr(subAst.rhs, enclosingBlock, subAst) else: for i, a in enumerate(subAst.args): visitNode(a) @@ -470,7 +471,7 @@ def getOptimalLoopOrdering(fields): layouts = set([field.layout for field in fields]) if len(layouts) > 1: - raise ValueError("Due to different layout of the fields no optimal loop ordering exists") + raise ValueError("Due to different layout of the fields no optimal loop ordering exists " + str(layouts)) layout = list(layouts)[0] return list(reversed(layout))