diff --git a/ast.py b/ast.py index 4669f3b3efa34f83d46c9dd814aa52e5dd3370a5..2ed8732a9f5dd5d78dbde6270da47ccdb45e437a 100644 --- a/ast.py +++ b/ast.py @@ -17,13 +17,13 @@ class Node(object): @property def symbolsDefined(self): - """Set of symbols which are defined in this node or its children""" + """Set of symbols which are defined by this node. """ return set() @property - def symbolsRead(self): - """Set of symbols which are accessed/read in this node or its children""" - return set() + def undefinedSymbols(self): + """Symbols which are use but are not defined inside this node""" + raise NotImplementedError() def atoms(self, argType): """ @@ -78,14 +78,15 @@ class KernelFunction(Node): self._parameters = None self._functionName = functionName self._body.parent = self - self.variablesToIgnore = set() + # these variables are assumed to be global, so no automatic parameter is generated for them + self.globalVariables = set() @property def symbolsDefined(self): return set() @property - def symbolsRead(self): + def undefinedSymbols(self): return set() @property @@ -106,7 +107,7 @@ class KernelFunction(Node): return self._functionName def _updateParameters(self): - undefinedSymbols = self._body.symbolsRead - self._body.symbolsDefined - self.variablesToIgnore + undefinedSymbols = self._body.undefinedSymbols - self.globalVariables self._parameters = [KernelFunction.Argument(s.name, s.dtype) for s in undefinedSymbols] self._parameters.sort(key=lambda l: (l.fieldName, l.isFieldPtrArgument, l.isFieldShapeArgument, l.isFieldStrideArgument, l.name), @@ -169,11 +170,13 @@ class Block(Node): return result @property - def symbolsRead(self): + def undefinedSymbols(self): result = set() + definedSymbols = set() for a in self.args: - result.update(a.symbolsRead) - return result + result.update(a.undefinedSymbols) + definedSymbols.update(a.symbolsDefined) + return result - definedSymbols def children(self): yield self._nodes @@ -235,9 +238,15 @@ class LoopOverCoordinate(Node): @property def symbolsDefined(self): - result = self._body.symbolsDefined - result.add(self.loopCounterSymbol) - return result + return set([self.loopCounterSymbol]) + + @property + def undefinedSymbols(self): + result = self._body.undefinedSymbols + for possibleSymbol in [self._begin, self._end, self._increment]: + if isinstance(possibleSymbol, Node) or isinstance(possibleSymbol, sp.Basic): + result.update(possibleSymbol.atoms(sp.Symbol)) + return result - set([self.loopCounterSymbol]) @staticmethod def getLoopCounterName(coordinateToLoopOver): @@ -255,15 +264,6 @@ class LoopOverCoordinate(Node): def loopCounterSymbol(self): return LoopOverCoordinate.getLoopCounterSymbol(self.coordinateToLoopOver) - @property - def symbolsRead(self): - loopBoundSymbols = set() - for possibleSymbol in [self._begin, self._end, self._increment]: - if isinstance(possibleSymbol, Node) or isinstance(possibleSymbol, sp.Basic): - loopBoundSymbols.update(possibleSymbol.atoms(sp.Symbol)) - result = self._body.symbolsRead.union(loopBoundSymbols) - return result - @property def isOutermostLoop(self): from pystencils.transformations import getNextParentOfType @@ -316,8 +316,9 @@ class SympyAssignment(Node): return set() return set([self._lhsSymbol]) + @property - def symbolsRead(self): + def undefinedSymbols(self): result = self.rhs.atoms(sp.Symbol) result.update(self._lhsSymbol.atoms(sp.Symbol)) return result @@ -344,8 +345,11 @@ class TemporaryMemoryAllocation(Node): return set([self.symbol]) @property - def symbolsRead(self): - return set() + def undefinedSymbols(self): + if isinstance(self.size, sp.Basic): + return self.size.atoms(sp.Symbol) + else: + return set() @property def args(self): @@ -361,7 +365,7 @@ class TemporaryMemoryFree(Node): return set() @property - def symbolsRead(self): + def undefinedSymbols(self): return set() @property diff --git a/backends/cbackend.py b/backends/cbackend.py index 314c301bfd84d45e75f7657bdb6518afdfad6dc2..96c9a8584b93a83514e38018ccd747cdd57e1e15 100644 --- a/backends/cbackend.py +++ b/backends/cbackend.py @@ -37,8 +37,8 @@ class CustomCppCode(Node): return self._symbolsDefined @property - def symbolsRead(self): - return self._symbolsRead + def undefinedSymbols(self): + return self.symbolsDefined - self._symbolsRead class PrintNode(CustomCppCode): diff --git a/gpucuda/kernelcreation.py b/gpucuda/kernelcreation.py index 60676058b86e19ecfbc407c2fffb636465c583a3..a0905a82eba1d53f86610d814fc49d2dcd7c8955 100644 --- a/gpucuda/kernelcreation.py +++ b/gpucuda/kernelcreation.py @@ -36,7 +36,7 @@ def createCUDAKernel(listOfEquations, functionName="kernel", typeForSymbol=defau readOnlyFields = set([f.name for f in fieldsRead - fieldsWritten]) code = KernelFunction(Block(assignments), functionName) - code.variablesToIgnore.update(BLOCK_IDX + THREAD_IDX) + code.globalVariables.update(BLOCK_IDX + THREAD_IDX) fieldAccesses = code.atoms(Field.Access) requiredGhostLayers = max([fa.requiredGhostLayers for fa in fieldAccesses]) diff --git a/transformations.py b/transformations.py index db2e2e5c46d49c94fac3fd6921df93f874b7e8ec..73c8e31a697eec6db14a781dd29190ac0d942812 100644 --- a/transformations.py +++ b/transformations.py @@ -278,7 +278,7 @@ def moveConstantsBeforeLoop(astNode): if isinstance(element, ast.Block): lastBlock = element lastBlockChild = prevElement - if node.symbolsRead.intersection(element.symbolsDefined): + if node.undefinedSymbols.intersection(element.symbolsDefined): break prevElement = element element = element.parent