Skip to content
Snippets Groups Projects
Commit b207d071 authored by Martin Bauer's avatar Martin Bauer
Browse files

Changed symbolsRead/ symbolsDefined semantics

problem in moveConstantBeforeLoops transformation:

--> a should end up here
{
for() {
  const int a = 5;
}
for() {
  const int a = 5
}
}

the "a" of the lower loop was not moved up, since it could not move across first loop (which is wrong)
parent 87749599
No related merge requests found
......@@ -17,13 +17,13 @@ class Node(object):
@property
def symbolsDefined(self):
"""Set of symbols which are defined in this node or its children"""
"""Set of symbols which are defined by this node. """
return set()
@property
def symbolsRead(self):
"""Set of symbols which are accessed/read in this node or its children"""
return set()
def undefinedSymbols(self):
"""Symbols which are use but are not defined inside this node"""
raise NotImplementedError()
def atoms(self, argType):
"""
......@@ -78,14 +78,15 @@ class KernelFunction(Node):
self._parameters = None
self._functionName = functionName
self._body.parent = self
self.variablesToIgnore = set()
# these variables are assumed to be global, so no automatic parameter is generated for them
self.globalVariables = set()
@property
def symbolsDefined(self):
return set()
@property
def symbolsRead(self):
def undefinedSymbols(self):
return set()
@property
......@@ -106,7 +107,7 @@ class KernelFunction(Node):
return self._functionName
def _updateParameters(self):
undefinedSymbols = self._body.symbolsRead - self._body.symbolsDefined - self.variablesToIgnore
undefinedSymbols = self._body.undefinedSymbols - self.globalVariables
self._parameters = [KernelFunction.Argument(s.name, s.dtype) for s in undefinedSymbols]
self._parameters.sort(key=lambda l: (l.fieldName, l.isFieldPtrArgument, l.isFieldShapeArgument,
l.isFieldStrideArgument, l.name),
......@@ -169,11 +170,13 @@ class Block(Node):
return result
@property
def symbolsRead(self):
def undefinedSymbols(self):
result = set()
definedSymbols = set()
for a in self.args:
result.update(a.symbolsRead)
return result
result.update(a.undefinedSymbols)
definedSymbols.update(a.symbolsDefined)
return result - definedSymbols
def children(self):
yield self._nodes
......@@ -235,9 +238,15 @@ class LoopOverCoordinate(Node):
@property
def symbolsDefined(self):
result = self._body.symbolsDefined
result.add(self.loopCounterSymbol)
return result
return set([self.loopCounterSymbol])
@property
def undefinedSymbols(self):
result = self._body.undefinedSymbols
for possibleSymbol in [self._begin, self._end, self._increment]:
if isinstance(possibleSymbol, Node) or isinstance(possibleSymbol, sp.Basic):
result.update(possibleSymbol.atoms(sp.Symbol))
return result - set([self.loopCounterSymbol])
@staticmethod
def getLoopCounterName(coordinateToLoopOver):
......@@ -255,15 +264,6 @@ class LoopOverCoordinate(Node):
def loopCounterSymbol(self):
return LoopOverCoordinate.getLoopCounterSymbol(self.coordinateToLoopOver)
@property
def symbolsRead(self):
loopBoundSymbols = set()
for possibleSymbol in [self._begin, self._end, self._increment]:
if isinstance(possibleSymbol, Node) or isinstance(possibleSymbol, sp.Basic):
loopBoundSymbols.update(possibleSymbol.atoms(sp.Symbol))
result = self._body.symbolsRead.union(loopBoundSymbols)
return result
@property
def isOutermostLoop(self):
from pystencils.transformations import getNextParentOfType
......@@ -316,8 +316,9 @@ class SympyAssignment(Node):
return set()
return set([self._lhsSymbol])
@property
def symbolsRead(self):
def undefinedSymbols(self):
result = self.rhs.atoms(sp.Symbol)
result.update(self._lhsSymbol.atoms(sp.Symbol))
return result
......@@ -344,8 +345,11 @@ class TemporaryMemoryAllocation(Node):
return set([self.symbol])
@property
def symbolsRead(self):
return set()
def undefinedSymbols(self):
if isinstance(self.size, sp.Basic):
return self.size.atoms(sp.Symbol)
else:
return set()
@property
def args(self):
......@@ -361,7 +365,7 @@ class TemporaryMemoryFree(Node):
return set()
@property
def symbolsRead(self):
def undefinedSymbols(self):
return set()
@property
......
......@@ -37,8 +37,8 @@ class CustomCppCode(Node):
return self._symbolsDefined
@property
def symbolsRead(self):
return self._symbolsRead
def undefinedSymbols(self):
return self.symbolsDefined - self._symbolsRead
class PrintNode(CustomCppCode):
......
......@@ -36,7 +36,7 @@ def createCUDAKernel(listOfEquations, functionName="kernel", typeForSymbol=defau
readOnlyFields = set([f.name for f in fieldsRead - fieldsWritten])
code = KernelFunction(Block(assignments), functionName)
code.variablesToIgnore.update(BLOCK_IDX + THREAD_IDX)
code.globalVariables.update(BLOCK_IDX + THREAD_IDX)
fieldAccesses = code.atoms(Field.Access)
requiredGhostLayers = max([fa.requiredGhostLayers for fa in fieldAccesses])
......
......@@ -278,7 +278,7 @@ def moveConstantsBeforeLoop(astNode):
if isinstance(element, ast.Block):
lastBlock = element
lastBlockChild = prevElement
if node.symbolsRead.intersection(element.symbolsDefined):
if node.undefinedSymbols.intersection(element.symbolsDefined):
break
prevElement = element
element = element.parent
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment