From c744ab1d36051acbeb7e9c08d3a20a4f29a03939 Mon Sep 17 00:00:00 2001 From: Martin Bauer <martin.bauer@fau.de> Date: Tue, 15 Nov 2016 08:16:50 +0100 Subject: [PATCH] Added accessed fields to Ast Function node --- ast.py | 9 +++++++-- gpucuda/kernelcreation.py | 2 +- transformations.py | 2 +- 3 files changed, 9 insertions(+), 4 deletions(-) diff --git a/ast.py b/ast.py index 8c8fc9f8b..d5aa77482 100644 --- a/ast.py +++ b/ast.py @@ -69,13 +69,14 @@ class KernelFunction(Node): def __repr__(self): return '<{0} {1}>'.format(self.dtype, self.name) - def __init__(self, body, functionName="kernel"): + def __init__(self, body, fieldsAccessed, functionName="kernel"): super(KernelFunction, self).__init__() self._body = body body.parent = self self._parameters = None self._functionName = functionName self._body.parent = self + self._fieldsAccessed = fieldsAccessed # these variables are assumed to be global, so no automatic parameter is generated for them self.globalVariables = set() @@ -104,6 +105,11 @@ class KernelFunction(Node): def functionName(self): return self._functionName + @property + def fieldsAccessed(self): + """Set of Field instances: fields which are accessed inside this kernel function""" + return self._fieldsAccessed + def _updateParameters(self): undefinedSymbols = self._body.undefinedSymbols - self.globalVariables self._parameters = [KernelFunction.Argument(s.name, s.dtype) for s in undefinedSymbols] @@ -325,7 +331,6 @@ class SympyAssignment(Node): return set() return set([self._lhsSymbol]) - @property def undefinedSymbols(self): result = self.rhs.atoms(sp.Symbol) diff --git a/gpucuda/kernelcreation.py b/gpucuda/kernelcreation.py index a0905a82e..2313e518d 100644 --- a/gpucuda/kernelcreation.py +++ b/gpucuda/kernelcreation.py @@ -35,7 +35,7 @@ def createCUDAKernel(listOfEquations, functionName="kernel", typeForSymbol=defau fieldsRead, fieldsWritten, assignments = typeAllEquations(listOfEquations, typeForSymbol) readOnlyFields = set([f.name for f in fieldsRead - fieldsWritten]) - code = KernelFunction(Block(assignments), functionName) + code = KernelFunction(Block(assignments), fieldsRead.union(fieldsWritten), functionName) code.globalVariables.update(BLOCK_IDX + THREAD_IDX) fieldAccesses = code.atoms(Field.Access) diff --git a/transformations.py b/transformations.py index 73c8e31a6..6f92f08e2 100644 --- a/transformations.py +++ b/transformations.py @@ -63,7 +63,7 @@ def makeLoopOverDomain(body, functionName, iterationSlice=None, ghostLayers=None assignment = ast.SympyAssignment(ast.LoopOverCoordinate.getLoopCounterSymbol(loopCoordinate), sp.sympify(sliceComponent)) currentBody.insertFront(assignment) - return ast.KernelFunction(currentBody, functionName) + return ast.KernelFunction(currentBody, fields, functionName) def createIntermediateBasePointer(fieldAccess, coordinates, previousPtr): -- GitLab