From b207d071982367983a122a0759387dbaedf2d98f Mon Sep 17 00:00:00 2001
From: Martin Bauer <martin.bauer@fau.de>
Date: Fri, 11 Nov 2016 16:23:25 +0100
Subject: [PATCH] Changed symbolsRead/ symbolsDefined semantics

problem in moveConstantBeforeLoops transformation:

--> a should end up here
{
for() {
  const int a = 5;
}
for() {
  const int a = 5
}
}

the "a" of the lower loop was not moved up, since it could not move across first loop (which is wrong)
---
 ast.py                    | 56 +++++++++++++++++++++------------------
 backends/cbackend.py      |  4 +--
 gpucuda/kernelcreation.py |  2 +-
 transformations.py        |  2 +-
 4 files changed, 34 insertions(+), 30 deletions(-)

diff --git a/ast.py b/ast.py
index 4669f3b3e..2ed8732a9 100644
--- a/ast.py
+++ b/ast.py
@@ -17,13 +17,13 @@ class Node(object):
 
     @property
     def symbolsDefined(self):
-        """Set of symbols which are defined in this node or its children"""
+        """Set of symbols which are defined by this node. """
         return set()
 
     @property
-    def symbolsRead(self):
-        """Set of symbols which are accessed/read in this node or its children"""
-        return set()
+    def undefinedSymbols(self):
+        """Symbols which are use but are not defined inside this node"""
+        raise NotImplementedError()
 
     def atoms(self, argType):
         """
@@ -78,14 +78,15 @@ class KernelFunction(Node):
         self._parameters = None
         self._functionName = functionName
         self._body.parent = self
-        self.variablesToIgnore = set()
+        # these variables are assumed to be global, so no automatic parameter is generated for them
+        self.globalVariables = set()
 
     @property
     def symbolsDefined(self):
         return set()
 
     @property
-    def symbolsRead(self):
+    def undefinedSymbols(self):
         return set()
 
     @property
@@ -106,7 +107,7 @@ class KernelFunction(Node):
         return self._functionName
 
     def _updateParameters(self):
-        undefinedSymbols = self._body.symbolsRead - self._body.symbolsDefined - self.variablesToIgnore
+        undefinedSymbols = self._body.undefinedSymbols - self.globalVariables
         self._parameters = [KernelFunction.Argument(s.name, s.dtype) for s in undefinedSymbols]
         self._parameters.sort(key=lambda l: (l.fieldName, l.isFieldPtrArgument, l.isFieldShapeArgument,
                                              l.isFieldStrideArgument, l.name),
@@ -169,11 +170,13 @@ class Block(Node):
         return result
 
     @property
-    def symbolsRead(self):
+    def undefinedSymbols(self):
         result = set()
+        definedSymbols = set()
         for a in self.args:
-            result.update(a.symbolsRead)
-        return result
+            result.update(a.undefinedSymbols)
+            definedSymbols.update(a.symbolsDefined)
+        return result - definedSymbols
 
     def children(self):
         yield self._nodes
@@ -235,9 +238,15 @@ class LoopOverCoordinate(Node):
 
     @property
     def symbolsDefined(self):
-        result = self._body.symbolsDefined
-        result.add(self.loopCounterSymbol)
-        return result
+        return set([self.loopCounterSymbol])
+
+    @property
+    def undefinedSymbols(self):
+        result = self._body.undefinedSymbols
+        for possibleSymbol in [self._begin, self._end, self._increment]:
+            if isinstance(possibleSymbol, Node) or isinstance(possibleSymbol, sp.Basic):
+                result.update(possibleSymbol.atoms(sp.Symbol))
+        return result - set([self.loopCounterSymbol])
 
     @staticmethod
     def getLoopCounterName(coordinateToLoopOver):
@@ -255,15 +264,6 @@ class LoopOverCoordinate(Node):
     def loopCounterSymbol(self):
         return LoopOverCoordinate.getLoopCounterSymbol(self.coordinateToLoopOver)
 
-    @property
-    def symbolsRead(self):
-        loopBoundSymbols = set()
-        for possibleSymbol in [self._begin, self._end, self._increment]:
-            if isinstance(possibleSymbol, Node) or isinstance(possibleSymbol, sp.Basic):
-                loopBoundSymbols.update(possibleSymbol.atoms(sp.Symbol))
-        result = self._body.symbolsRead.union(loopBoundSymbols)
-        return result
-
     @property
     def isOutermostLoop(self):
         from pystencils.transformations import getNextParentOfType
@@ -316,8 +316,9 @@ class SympyAssignment(Node):
             return set()
         return set([self._lhsSymbol])
 
+
     @property
-    def symbolsRead(self):
+    def undefinedSymbols(self):
         result = self.rhs.atoms(sp.Symbol)
         result.update(self._lhsSymbol.atoms(sp.Symbol))
         return result
@@ -344,8 +345,11 @@ class TemporaryMemoryAllocation(Node):
         return set([self.symbol])
 
     @property
-    def symbolsRead(self):
-        return set()
+    def undefinedSymbols(self):
+        if isinstance(self.size, sp.Basic):
+            return self.size.atoms(sp.Symbol)
+        else:
+            return set()
 
     @property
     def args(self):
@@ -361,7 +365,7 @@ class TemporaryMemoryFree(Node):
         return set()
 
     @property
-    def symbolsRead(self):
+    def undefinedSymbols(self):
         return set()
 
     @property
diff --git a/backends/cbackend.py b/backends/cbackend.py
index 314c301bf..96c9a8584 100644
--- a/backends/cbackend.py
+++ b/backends/cbackend.py
@@ -37,8 +37,8 @@ class CustomCppCode(Node):
         return self._symbolsDefined
 
     @property
-    def symbolsRead(self):
-        return self._symbolsRead
+    def undefinedSymbols(self):
+        return self.symbolsDefined - self._symbolsRead
 
 
 class PrintNode(CustomCppCode):
diff --git a/gpucuda/kernelcreation.py b/gpucuda/kernelcreation.py
index 60676058b..a0905a82e 100644
--- a/gpucuda/kernelcreation.py
+++ b/gpucuda/kernelcreation.py
@@ -36,7 +36,7 @@ def createCUDAKernel(listOfEquations, functionName="kernel", typeForSymbol=defau
     readOnlyFields = set([f.name for f in fieldsRead - fieldsWritten])
 
     code = KernelFunction(Block(assignments), functionName)
-    code.variablesToIgnore.update(BLOCK_IDX + THREAD_IDX)
+    code.globalVariables.update(BLOCK_IDX + THREAD_IDX)
 
     fieldAccesses = code.atoms(Field.Access)
     requiredGhostLayers = max([fa.requiredGhostLayers for fa in fieldAccesses])
diff --git a/transformations.py b/transformations.py
index db2e2e5c4..73c8e31a6 100644
--- a/transformations.py
+++ b/transformations.py
@@ -278,7 +278,7 @@ def moveConstantsBeforeLoop(astNode):
             if isinstance(element, ast.Block):
                 lastBlock = element
                 lastBlockChild = prevElement
-            if node.symbolsRead.intersection(element.symbolsDefined):
+            if node.undefinedSymbols.intersection(element.symbolsDefined):
                 break
             prevElement = element
             element = element.parent
-- 
GitLab