Skip to content
Snippets Groups Projects
Commit c744ab1d authored by Martin Bauer's avatar Martin Bauer
Browse files

Added accessed fields to Ast Function node

parent 05c0275c
No related merge requests found
......@@ -69,13 +69,14 @@ class KernelFunction(Node):
def __repr__(self):
return '<{0} {1}>'.format(self.dtype, self.name)
def __init__(self, body, functionName="kernel"):
def __init__(self, body, fieldsAccessed, functionName="kernel"):
super(KernelFunction, self).__init__()
self._body = body
body.parent = self
self._parameters = None
self._functionName = functionName
self._body.parent = self
self._fieldsAccessed = fieldsAccessed
# these variables are assumed to be global, so no automatic parameter is generated for them
self.globalVariables = set()
......@@ -104,6 +105,11 @@ class KernelFunction(Node):
def functionName(self):
return self._functionName
@property
def fieldsAccessed(self):
"""Set of Field instances: fields which are accessed inside this kernel function"""
return self._fieldsAccessed
def _updateParameters(self):
undefinedSymbols = self._body.undefinedSymbols - self.globalVariables
self._parameters = [KernelFunction.Argument(s.name, s.dtype) for s in undefinedSymbols]
......@@ -325,7 +331,6 @@ class SympyAssignment(Node):
return set()
return set([self._lhsSymbol])
@property
def undefinedSymbols(self):
result = self.rhs.atoms(sp.Symbol)
......
......@@ -35,7 +35,7 @@ def createCUDAKernel(listOfEquations, functionName="kernel", typeForSymbol=defau
fieldsRead, fieldsWritten, assignments = typeAllEquations(listOfEquations, typeForSymbol)
readOnlyFields = set([f.name for f in fieldsRead - fieldsWritten])
code = KernelFunction(Block(assignments), functionName)
code = KernelFunction(Block(assignments), fieldsRead.union(fieldsWritten), functionName)
code.globalVariables.update(BLOCK_IDX + THREAD_IDX)
fieldAccesses = code.atoms(Field.Access)
......
......@@ -63,7 +63,7 @@ def makeLoopOverDomain(body, functionName, iterationSlice=None, ghostLayers=None
assignment = ast.SympyAssignment(ast.LoopOverCoordinate.getLoopCounterSymbol(loopCoordinate),
sp.sympify(sliceComponent))
currentBody.insertFront(assignment)
return ast.KernelFunction(currentBody, functionName)
return ast.KernelFunction(currentBody, fields, functionName)
def createIntermediateBasePointer(fieldAccess, coordinates, previousPtr):
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment