diff --git a/ast.py b/ast.py index 8c8fc9f8bda28a52b3af77eef2d087f459863266..d5aa7748281d704c5142e3448742f338ac525a4a 100644 --- a/ast.py +++ b/ast.py @@ -69,13 +69,14 @@ class KernelFunction(Node): def __repr__(self): return '<{0} {1}>'.format(self.dtype, self.name) - def __init__(self, body, functionName="kernel"): + def __init__(self, body, fieldsAccessed, functionName="kernel"): super(KernelFunction, self).__init__() self._body = body body.parent = self self._parameters = None self._functionName = functionName self._body.parent = self + self._fieldsAccessed = fieldsAccessed # these variables are assumed to be global, so no automatic parameter is generated for them self.globalVariables = set() @@ -104,6 +105,11 @@ class KernelFunction(Node): def functionName(self): return self._functionName + @property + def fieldsAccessed(self): + """Set of Field instances: fields which are accessed inside this kernel function""" + return self._fieldsAccessed + def _updateParameters(self): undefinedSymbols = self._body.undefinedSymbols - self.globalVariables self._parameters = [KernelFunction.Argument(s.name, s.dtype) for s in undefinedSymbols] @@ -325,7 +331,6 @@ class SympyAssignment(Node): return set() return set([self._lhsSymbol]) - @property def undefinedSymbols(self): result = self.rhs.atoms(sp.Symbol) diff --git a/gpucuda/kernelcreation.py b/gpucuda/kernelcreation.py index a0905a82eba1d53f86610d814fc49d2dcd7c8955..2313e518d9e53f6cf97f51f755b0d74e4e9a2e08 100644 --- a/gpucuda/kernelcreation.py +++ b/gpucuda/kernelcreation.py @@ -35,7 +35,7 @@ def createCUDAKernel(listOfEquations, functionName="kernel", typeForSymbol=defau fieldsRead, fieldsWritten, assignments = typeAllEquations(listOfEquations, typeForSymbol) readOnlyFields = set([f.name for f in fieldsRead - fieldsWritten]) - code = KernelFunction(Block(assignments), functionName) + code = KernelFunction(Block(assignments), fieldsRead.union(fieldsWritten), functionName) code.globalVariables.update(BLOCK_IDX + THREAD_IDX) fieldAccesses = code.atoms(Field.Access) diff --git a/transformations.py b/transformations.py index 73c8e31a697eec6db14a781dd29190ac0d942812..6f92f08e2fffdcdf4044313b977da642575bc9e7 100644 --- a/transformations.py +++ b/transformations.py @@ -63,7 +63,7 @@ def makeLoopOverDomain(body, functionName, iterationSlice=None, ghostLayers=None assignment = ast.SympyAssignment(ast.LoopOverCoordinate.getLoopCounterSymbol(loopCoordinate), sp.sympify(sliceComponent)) currentBody.insertFront(assignment) - return ast.KernelFunction(currentBody, functionName) + return ast.KernelFunction(currentBody, fields, functionName) def createIntermediateBasePointer(fieldAccess, coordinates, previousPtr):