Skip to content
Snippets Groups Projects
Commit c744ab1d authored by Martin Bauer's avatar Martin Bauer
Browse files

Added accessed fields to Ast Function node

parent 05c0275c
Branches
Tags
No related merge requests found
...@@ -69,13 +69,14 @@ class KernelFunction(Node): ...@@ -69,13 +69,14 @@ class KernelFunction(Node):
def __repr__(self): def __repr__(self):
return '<{0} {1}>'.format(self.dtype, self.name) return '<{0} {1}>'.format(self.dtype, self.name)
def __init__(self, body, functionName="kernel"): def __init__(self, body, fieldsAccessed, functionName="kernel"):
super(KernelFunction, self).__init__() super(KernelFunction, self).__init__()
self._body = body self._body = body
body.parent = self body.parent = self
self._parameters = None self._parameters = None
self._functionName = functionName self._functionName = functionName
self._body.parent = self self._body.parent = self
self._fieldsAccessed = fieldsAccessed
# these variables are assumed to be global, so no automatic parameter is generated for them # these variables are assumed to be global, so no automatic parameter is generated for them
self.globalVariables = set() self.globalVariables = set()
...@@ -104,6 +105,11 @@ class KernelFunction(Node): ...@@ -104,6 +105,11 @@ class KernelFunction(Node):
def functionName(self): def functionName(self):
return self._functionName return self._functionName
@property
def fieldsAccessed(self):
"""Set of Field instances: fields which are accessed inside this kernel function"""
return self._fieldsAccessed
def _updateParameters(self): def _updateParameters(self):
undefinedSymbols = self._body.undefinedSymbols - self.globalVariables undefinedSymbols = self._body.undefinedSymbols - self.globalVariables
self._parameters = [KernelFunction.Argument(s.name, s.dtype) for s in undefinedSymbols] self._parameters = [KernelFunction.Argument(s.name, s.dtype) for s in undefinedSymbols]
...@@ -325,7 +331,6 @@ class SympyAssignment(Node): ...@@ -325,7 +331,6 @@ class SympyAssignment(Node):
return set() return set()
return set([self._lhsSymbol]) return set([self._lhsSymbol])
@property @property
def undefinedSymbols(self): def undefinedSymbols(self):
result = self.rhs.atoms(sp.Symbol) result = self.rhs.atoms(sp.Symbol)
......
...@@ -35,7 +35,7 @@ def createCUDAKernel(listOfEquations, functionName="kernel", typeForSymbol=defau ...@@ -35,7 +35,7 @@ def createCUDAKernel(listOfEquations, functionName="kernel", typeForSymbol=defau
fieldsRead, fieldsWritten, assignments = typeAllEquations(listOfEquations, typeForSymbol) fieldsRead, fieldsWritten, assignments = typeAllEquations(listOfEquations, typeForSymbol)
readOnlyFields = set([f.name for f in fieldsRead - fieldsWritten]) readOnlyFields = set([f.name for f in fieldsRead - fieldsWritten])
code = KernelFunction(Block(assignments), functionName) code = KernelFunction(Block(assignments), fieldsRead.union(fieldsWritten), functionName)
code.globalVariables.update(BLOCK_IDX + THREAD_IDX) code.globalVariables.update(BLOCK_IDX + THREAD_IDX)
fieldAccesses = code.atoms(Field.Access) fieldAccesses = code.atoms(Field.Access)
......
...@@ -63,7 +63,7 @@ def makeLoopOverDomain(body, functionName, iterationSlice=None, ghostLayers=None ...@@ -63,7 +63,7 @@ def makeLoopOverDomain(body, functionName, iterationSlice=None, ghostLayers=None
assignment = ast.SympyAssignment(ast.LoopOverCoordinate.getLoopCounterSymbol(loopCoordinate), assignment = ast.SympyAssignment(ast.LoopOverCoordinate.getLoopCounterSymbol(loopCoordinate),
sp.sympify(sliceComponent)) sp.sympify(sliceComponent))
currentBody.insertFront(assignment) currentBody.insertFront(assignment)
return ast.KernelFunction(currentBody, functionName) return ast.KernelFunction(currentBody, fields, functionName)
def createIntermediateBasePointer(fieldAccess, coordinates, previousPtr): def createIntermediateBasePointer(fieldAccess, coordinates, previousPtr):
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment