Skip to content
Snippets Groups Projects
Commit b207d071 authored by Martin Bauer's avatar Martin Bauer
Browse files

Changed symbolsRead/ symbolsDefined semantics

problem in moveConstantBeforeLoops transformation:

--> a should end up here
{
for() {
  const int a = 5;
}
for() {
  const int a = 5
}
}

the "a" of the lower loop was not moved up, since it could not move across first loop (which is wrong)
parent 87749599
No related merge requests found
...@@ -17,13 +17,13 @@ class Node(object): ...@@ -17,13 +17,13 @@ class Node(object):
@property @property
def symbolsDefined(self): def symbolsDefined(self):
"""Set of symbols which are defined in this node or its children""" """Set of symbols which are defined by this node. """
return set() return set()
@property @property
def symbolsRead(self): def undefinedSymbols(self):
"""Set of symbols which are accessed/read in this node or its children""" """Symbols which are use but are not defined inside this node"""
return set() raise NotImplementedError()
def atoms(self, argType): def atoms(self, argType):
""" """
...@@ -78,14 +78,15 @@ class KernelFunction(Node): ...@@ -78,14 +78,15 @@ class KernelFunction(Node):
self._parameters = None self._parameters = None
self._functionName = functionName self._functionName = functionName
self._body.parent = self self._body.parent = self
self.variablesToIgnore = set() # these variables are assumed to be global, so no automatic parameter is generated for them
self.globalVariables = set()
@property @property
def symbolsDefined(self): def symbolsDefined(self):
return set() return set()
@property @property
def symbolsRead(self): def undefinedSymbols(self):
return set() return set()
@property @property
...@@ -106,7 +107,7 @@ class KernelFunction(Node): ...@@ -106,7 +107,7 @@ class KernelFunction(Node):
return self._functionName return self._functionName
def _updateParameters(self): def _updateParameters(self):
undefinedSymbols = self._body.symbolsRead - self._body.symbolsDefined - self.variablesToIgnore undefinedSymbols = self._body.undefinedSymbols - self.globalVariables
self._parameters = [KernelFunction.Argument(s.name, s.dtype) for s in undefinedSymbols] self._parameters = [KernelFunction.Argument(s.name, s.dtype) for s in undefinedSymbols]
self._parameters.sort(key=lambda l: (l.fieldName, l.isFieldPtrArgument, l.isFieldShapeArgument, self._parameters.sort(key=lambda l: (l.fieldName, l.isFieldPtrArgument, l.isFieldShapeArgument,
l.isFieldStrideArgument, l.name), l.isFieldStrideArgument, l.name),
...@@ -169,11 +170,13 @@ class Block(Node): ...@@ -169,11 +170,13 @@ class Block(Node):
return result return result
@property @property
def symbolsRead(self): def undefinedSymbols(self):
result = set() result = set()
definedSymbols = set()
for a in self.args: for a in self.args:
result.update(a.symbolsRead) result.update(a.undefinedSymbols)
return result definedSymbols.update(a.symbolsDefined)
return result - definedSymbols
def children(self): def children(self):
yield self._nodes yield self._nodes
...@@ -235,9 +238,15 @@ class LoopOverCoordinate(Node): ...@@ -235,9 +238,15 @@ class LoopOverCoordinate(Node):
@property @property
def symbolsDefined(self): def symbolsDefined(self):
result = self._body.symbolsDefined return set([self.loopCounterSymbol])
result.add(self.loopCounterSymbol)
return result @property
def undefinedSymbols(self):
result = self._body.undefinedSymbols
for possibleSymbol in [self._begin, self._end, self._increment]:
if isinstance(possibleSymbol, Node) or isinstance(possibleSymbol, sp.Basic):
result.update(possibleSymbol.atoms(sp.Symbol))
return result - set([self.loopCounterSymbol])
@staticmethod @staticmethod
def getLoopCounterName(coordinateToLoopOver): def getLoopCounterName(coordinateToLoopOver):
...@@ -255,15 +264,6 @@ class LoopOverCoordinate(Node): ...@@ -255,15 +264,6 @@ class LoopOverCoordinate(Node):
def loopCounterSymbol(self): def loopCounterSymbol(self):
return LoopOverCoordinate.getLoopCounterSymbol(self.coordinateToLoopOver) return LoopOverCoordinate.getLoopCounterSymbol(self.coordinateToLoopOver)
@property
def symbolsRead(self):
loopBoundSymbols = set()
for possibleSymbol in [self._begin, self._end, self._increment]:
if isinstance(possibleSymbol, Node) or isinstance(possibleSymbol, sp.Basic):
loopBoundSymbols.update(possibleSymbol.atoms(sp.Symbol))
result = self._body.symbolsRead.union(loopBoundSymbols)
return result
@property @property
def isOutermostLoop(self): def isOutermostLoop(self):
from pystencils.transformations import getNextParentOfType from pystencils.transformations import getNextParentOfType
...@@ -316,8 +316,9 @@ class SympyAssignment(Node): ...@@ -316,8 +316,9 @@ class SympyAssignment(Node):
return set() return set()
return set([self._lhsSymbol]) return set([self._lhsSymbol])
@property @property
def symbolsRead(self): def undefinedSymbols(self):
result = self.rhs.atoms(sp.Symbol) result = self.rhs.atoms(sp.Symbol)
result.update(self._lhsSymbol.atoms(sp.Symbol)) result.update(self._lhsSymbol.atoms(sp.Symbol))
return result return result
...@@ -344,8 +345,11 @@ class TemporaryMemoryAllocation(Node): ...@@ -344,8 +345,11 @@ class TemporaryMemoryAllocation(Node):
return set([self.symbol]) return set([self.symbol])
@property @property
def symbolsRead(self): def undefinedSymbols(self):
return set() if isinstance(self.size, sp.Basic):
return self.size.atoms(sp.Symbol)
else:
return set()
@property @property
def args(self): def args(self):
...@@ -361,7 +365,7 @@ class TemporaryMemoryFree(Node): ...@@ -361,7 +365,7 @@ class TemporaryMemoryFree(Node):
return set() return set()
@property @property
def symbolsRead(self): def undefinedSymbols(self):
return set() return set()
@property @property
......
...@@ -37,8 +37,8 @@ class CustomCppCode(Node): ...@@ -37,8 +37,8 @@ class CustomCppCode(Node):
return self._symbolsDefined return self._symbolsDefined
@property @property
def symbolsRead(self): def undefinedSymbols(self):
return self._symbolsRead return self.symbolsDefined - self._symbolsRead
class PrintNode(CustomCppCode): class PrintNode(CustomCppCode):
......
...@@ -36,7 +36,7 @@ def createCUDAKernel(listOfEquations, functionName="kernel", typeForSymbol=defau ...@@ -36,7 +36,7 @@ def createCUDAKernel(listOfEquations, functionName="kernel", typeForSymbol=defau
readOnlyFields = set([f.name for f in fieldsRead - fieldsWritten]) readOnlyFields = set([f.name for f in fieldsRead - fieldsWritten])
code = KernelFunction(Block(assignments), functionName) code = KernelFunction(Block(assignments), functionName)
code.variablesToIgnore.update(BLOCK_IDX + THREAD_IDX) code.globalVariables.update(BLOCK_IDX + THREAD_IDX)
fieldAccesses = code.atoms(Field.Access) fieldAccesses = code.atoms(Field.Access)
requiredGhostLayers = max([fa.requiredGhostLayers for fa in fieldAccesses]) requiredGhostLayers = max([fa.requiredGhostLayers for fa in fieldAccesses])
......
...@@ -278,7 +278,7 @@ def moveConstantsBeforeLoop(astNode): ...@@ -278,7 +278,7 @@ def moveConstantsBeforeLoop(astNode):
if isinstance(element, ast.Block): if isinstance(element, ast.Block):
lastBlock = element lastBlock = element
lastBlockChild = prevElement lastBlockChild = prevElement
if node.symbolsRead.intersection(element.symbolsDefined): if node.undefinedSymbols.intersection(element.symbolsDefined):
break break
prevElement = element prevElement = element
element = element.parent element = element.parent
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment