Commit c744ab1d authored by Martin Bauer's avatar Martin Bauer
Browse files

Added accessed fields to Ast Function node

parent 05c0275c
......@@ -69,13 +69,14 @@ class KernelFunction(Node):
def __repr__(self):
return '<{0} {1}>'.format(self.dtype, self.name)
def __init__(self, body, functionName="kernel"):
def __init__(self, body, fieldsAccessed, functionName="kernel"):
super(KernelFunction, self).__init__()
self._body = body
body.parent = self
self._parameters = None
self._functionName = functionName
self._body.parent = self
self._fieldsAccessed = fieldsAccessed
# these variables are assumed to be global, so no automatic parameter is generated for them
self.globalVariables = set()
......@@ -104,6 +105,11 @@ class KernelFunction(Node):
def functionName(self):
return self._functionName
@property
def fieldsAccessed(self):
"""Set of Field instances: fields which are accessed inside this kernel function"""
return self._fieldsAccessed
def _updateParameters(self):
undefinedSymbols = self._body.undefinedSymbols - self.globalVariables
self._parameters = [KernelFunction.Argument(s.name, s.dtype) for s in undefinedSymbols]
......@@ -325,7 +331,6 @@ class SympyAssignment(Node):
return set()
return set([self._lhsSymbol])
@property
def undefinedSymbols(self):
result = self.rhs.atoms(sp.Symbol)
......
......@@ -35,7 +35,7 @@ def createCUDAKernel(listOfEquations, functionName="kernel", typeForSymbol=defau
fieldsRead, fieldsWritten, assignments = typeAllEquations(listOfEquations, typeForSymbol)
readOnlyFields = set([f.name for f in fieldsRead - fieldsWritten])
code = KernelFunction(Block(assignments), functionName)
code = KernelFunction(Block(assignments), fieldsRead.union(fieldsWritten), functionName)
code.globalVariables.update(BLOCK_IDX + THREAD_IDX)
fieldAccesses = code.atoms(Field.Access)
......
......@@ -63,7 +63,7 @@ def makeLoopOverDomain(body, functionName, iterationSlice=None, ghostLayers=None
assignment = ast.SympyAssignment(ast.LoopOverCoordinate.getLoopCounterSymbol(loopCoordinate),
sp.sympify(sliceComponent))
currentBody.insertFront(assignment)
return ast.KernelFunction(currentBody, functionName)
return ast.KernelFunction(currentBody, fields, functionName)
def createIntermediateBasePointer(fieldAccess, coordinates, previousPtr):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment