From c744ab1d36051acbeb7e9c08d3a20a4f29a03939 Mon Sep 17 00:00:00 2001
From: Martin Bauer <martin.bauer@fau.de>
Date: Tue, 15 Nov 2016 08:16:50 +0100
Subject: [PATCH] Added accessed fields to Ast Function node

---
 ast.py                    | 9 +++++++--
 gpucuda/kernelcreation.py | 2 +-
 transformations.py        | 2 +-
 3 files changed, 9 insertions(+), 4 deletions(-)

diff --git a/ast.py b/ast.py
index 8c8fc9f8b..d5aa77482 100644
--- a/ast.py
+++ b/ast.py
@@ -69,13 +69,14 @@ class KernelFunction(Node):
         def __repr__(self):
             return '<{0} {1}>'.format(self.dtype, self.name)
 
-    def __init__(self, body, functionName="kernel"):
+    def __init__(self, body, fieldsAccessed, functionName="kernel"):
         super(KernelFunction, self).__init__()
         self._body = body
         body.parent = self
         self._parameters = None
         self._functionName = functionName
         self._body.parent = self
+        self._fieldsAccessed = fieldsAccessed
         # these variables are assumed to be global, so no automatic parameter is generated for them
         self.globalVariables = set()
 
@@ -104,6 +105,11 @@ class KernelFunction(Node):
     def functionName(self):
         return self._functionName
 
+    @property
+    def fieldsAccessed(self):
+        """Set of Field instances: fields which are accessed inside this kernel function"""
+        return self._fieldsAccessed
+
     def _updateParameters(self):
         undefinedSymbols = self._body.undefinedSymbols - self.globalVariables
         self._parameters = [KernelFunction.Argument(s.name, s.dtype) for s in undefinedSymbols]
@@ -325,7 +331,6 @@ class SympyAssignment(Node):
             return set()
         return set([self._lhsSymbol])
 
-
     @property
     def undefinedSymbols(self):
         result = self.rhs.atoms(sp.Symbol)
diff --git a/gpucuda/kernelcreation.py b/gpucuda/kernelcreation.py
index a0905a82e..2313e518d 100644
--- a/gpucuda/kernelcreation.py
+++ b/gpucuda/kernelcreation.py
@@ -35,7 +35,7 @@ def createCUDAKernel(listOfEquations, functionName="kernel", typeForSymbol=defau
     fieldsRead, fieldsWritten, assignments = typeAllEquations(listOfEquations, typeForSymbol)
     readOnlyFields = set([f.name for f in fieldsRead - fieldsWritten])
 
-    code = KernelFunction(Block(assignments), functionName)
+    code = KernelFunction(Block(assignments), fieldsRead.union(fieldsWritten), functionName)
     code.globalVariables.update(BLOCK_IDX + THREAD_IDX)
 
     fieldAccesses = code.atoms(Field.Access)
diff --git a/transformations.py b/transformations.py
index 73c8e31a6..6f92f08e2 100644
--- a/transformations.py
+++ b/transformations.py
@@ -63,7 +63,7 @@ def makeLoopOverDomain(body, functionName, iterationSlice=None, ghostLayers=None
                 assignment = ast.SympyAssignment(ast.LoopOverCoordinate.getLoopCounterSymbol(loopCoordinate),
                                                  sp.sympify(sliceComponent))
                 currentBody.insertFront(assignment)
-    return ast.KernelFunction(currentBody, functionName)
+    return ast.KernelFunction(currentBody, fields, functionName)
 
 
 def createIntermediateBasePointer(fieldAccess, coordinates, previousPtr):
-- 
GitLab