Skip to content
Snippets Groups Projects
Commit a2080a92 authored by Martin Bauer's avatar Martin Bauer
Browse files

Bugfix in resolveFieldIndex - base pointer are now inserted in correct order

parent d6cdaadf
No related merge requests found
......@@ -137,6 +137,7 @@ class Block(Node):
self._nodes.insert(0, node)
def insertBefore(self, newNode, insertBefore):
newNode.parent = self
idx = self._nodes.index(insertBefore)
self._nodes.insert(idx, newNode)
......
......@@ -193,7 +193,7 @@ def resolveFieldAccesses(astNode, readOnlyFieldNames=set(), fieldToBasePointerIn
counters to index the field these symbols are used as coordinates
:return: transformed AST
"""
def visitSympyExpr(expr, enclosingBlock):
def visitSympyExpr(expr, enclosingBlock, sympyAssignment):
if isinstance(expr, Field.Access):
fieldAccess = expr
field = fieldAccess.field
......@@ -227,7 +227,8 @@ def resolveFieldAccesses(astNode, readOnlyFieldNames=set(), fieldToBasePointerIn
coordDict = createCoordinateDict(group)
newPtr, offset = createIntermediateBasePointer(fieldAccess, coordDict, lastPointer)
if newPtr not in enclosingBlock.symbolsDefined:
enclosingBlock.insertFront(ast.SympyAssignment(newPtr, lastPointer + offset, isConst=False))
newAssignment = ast.SympyAssignment(newPtr, lastPointer + offset, isConst=False)
enclosingBlock.insertBefore(newAssignment, sympyAssignment)
lastPointer = newPtr
_, offset = createIntermediateBasePointer(fieldAccess, createCoordinateDict(basePointerInfo[0]),
......@@ -235,7 +236,7 @@ def resolveFieldAccesses(astNode, readOnlyFieldNames=set(), fieldToBasePointerIn
baseArr = IndexedBase(lastPointer, shape=(1,))
return baseArr[offset]
else:
newArgs = [visitSympyExpr(e, enclosingBlock) for e in expr.args]
newArgs = [visitSympyExpr(e, enclosingBlock, sympyAssignment) for e in expr.args]
kwargs = {'evaluate': False} if type(expr) is sp.Add or type(expr) is sp.Mul else {}
return expr.func(*newArgs, **kwargs) if newArgs else expr
......@@ -243,8 +244,8 @@ def resolveFieldAccesses(astNode, readOnlyFieldNames=set(), fieldToBasePointerIn
if isinstance(subAst, ast.SympyAssignment):
enclosingBlock = subAst.parent
assert type(enclosingBlock) is ast.Block
subAst.lhs = visitSympyExpr(subAst.lhs, enclosingBlock)
subAst.rhs = visitSympyExpr(subAst.rhs, enclosingBlock)
subAst.lhs = visitSympyExpr(subAst.lhs, enclosingBlock, subAst)
subAst.rhs = visitSympyExpr(subAst.rhs, enclosingBlock, subAst)
else:
for i, a in enumerate(subAst.args):
visitNode(a)
......@@ -470,7 +471,7 @@ def getOptimalLoopOrdering(fields):
layouts = set([field.layout for field in fields])
if len(layouts) > 1:
raise ValueError("Due to different layout of the fields no optimal loop ordering exists")
raise ValueError("Due to different layout of the fields no optimal loop ordering exists " + str(layouts))
layout = list(layouts)[0]
return list(reversed(layout))
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment