From 6fffa85a61782acc7f5822515277879c56961fbc Mon Sep 17 00:00:00 2001
From: Martin Bauer <martin.bauer@fau.de>
Date: Wed, 9 Nov 2016 12:12:30 +0100
Subject: [PATCH] Field.readOnly: no attribute any more -> passed as parameter
 where needed

---
 cpu/kernelcreation.py     | 7 ++-----
 field.py                  | 8 --------
 gpucuda/kernelcreation.py | 8 ++------
 transformations.py        | 9 +++++++--
 4 files changed, 11 insertions(+), 21 deletions(-)

diff --git a/cpu/kernelcreation.py b/cpu/kernelcreation.py
index c959afe09..dcc3b151b 100644
--- a/cpu/kernelcreation.py
+++ b/cpu/kernelcreation.py
@@ -42,10 +42,7 @@ def createKernel(listOfEquations, functionName="kernel", typeForSymbol=None, spl
     fieldsRead, fieldsWritten, assignments = typeAllEquations(listOfEquations, typeForSymbol)
     allFields = fieldsRead.union(fieldsWritten)
 
-    for field in allFields:
-        field.setReadOnly(False)
-    for field in fieldsRead - fieldsWritten:
-        field.setReadOnly()
+    readOnlyFields = set([f.name for f in fieldsRead - fieldsWritten])
 
     body = ast.Block(assignments)
     code = makeLoopOverDomain(body, functionName, iterationSlice=iterationSlice, ghostLayers=ghostLayers)
@@ -59,7 +56,7 @@ def createKernel(listOfEquations, functionName="kernel", typeForSymbol=None, spl
     basePointerInfo = [['spatialInner0'], ['spatialInner1']]
     basePointerInfos = {field.name: parseBasePointerInfo(basePointerInfo, loopOrder, field) for field in allFields}
 
-    resolveFieldAccesses(code, fieldToBasePointerInfo=basePointerInfos)
+    resolveFieldAccesses(code, readOnlyFields, fieldToBasePointerInfo=basePointerInfos)
     moveConstantsBeforeLoop(code)
 
     return code
diff --git a/field.py b/field.py
index 9c5dfba03..7fa21a000 100644
--- a/field.py
+++ b/field.py
@@ -100,7 +100,6 @@ class Field:
         self._layout = layout
         self._shape = shape
         self._strides = strides
-        self._readonly = False
 
     @property
     def spatialDimensions(self):
@@ -146,13 +145,6 @@ class Field:
     def dtype(self):
         return self._dtype
 
-    @property
-    def readOnly(self):
-        return self._readonly
-
-    def setReadOnly(self, value=True):
-        self._readonly = value
-
     def __repr__(self):
         return self._fieldName
 
diff --git a/gpucuda/kernelcreation.py b/gpucuda/kernelcreation.py
index 93f4233c6..60676058b 100644
--- a/gpucuda/kernelcreation.py
+++ b/gpucuda/kernelcreation.py
@@ -33,11 +33,7 @@ def getLinewiseCoordinates(field, ghostLayers):
 
 def createCUDAKernel(listOfEquations, functionName="kernel", typeForSymbol=defaultdict(lambda: "double")):
     fieldsRead, fieldsWritten, assignments = typeAllEquations(listOfEquations, typeForSymbol)
-    allFields = fieldsRead.union(fieldsWritten)
-    for field in allFields:
-        field.setReadOnly(False)
-    for field in fieldsRead - fieldsWritten:
-        field.setReadOnly()
+    readOnlyFields = set([f.name for f in fieldsRead - fieldsWritten])
 
     code = KernelFunction(Block(assignments), functionName)
     code.variablesToIgnore.update(BLOCK_IDX + THREAD_IDX)
@@ -49,7 +45,7 @@ def createCUDAKernel(listOfEquations, functionName="kernel", typeForSymbol=defau
     allFields = fieldsRead.union(fieldsWritten)
     basePointerInfo = [['spatialInner0']]
     basePointerInfos = {f.name: parseBasePointerInfo(basePointerInfo, [0, 1, 2], f) for f in allFields}
-    resolveFieldAccesses(code, fieldToFixedCoordinates={'src': coordMapping, 'dst': coordMapping},
+    resolveFieldAccesses(code, readOnlyFields, fieldToFixedCoordinates={'src': coordMapping, 'dst': coordMapping},
                          fieldToBasePointerInfo=basePointerInfos)
     # add the function which determines #blocks and #threads as additional member to KernelFunction node
     # this is used by the jit
diff --git a/transformations.py b/transformations.py
index 634d57760..a48487798 100644
--- a/transformations.py
+++ b/transformations.py
@@ -181,11 +181,12 @@ def parseBasePointerInfo(basePointerSpecification, loopOrder, field):
     return result
 
 
-def resolveFieldAccesses(astNode, fieldToBasePointerInfo={}, fieldToFixedCoordinates={}):
+def resolveFieldAccesses(astNode, readOnlyFieldNames=set(), fieldToBasePointerInfo={}, fieldToFixedCoordinates={}):
     """
     Substitutes :class:`pystencils.field.Field.Access` nodes by array indexing
 
     :param astNode: the AST root
+    :param readOnlyFieldNames: set of field names which are considered read-only
     :param fieldToBasePointerInfo: a list of tuples indicating which intermediate base pointers should be created
                                    for details see :func:`parseBasePointerInfo`
     :param fieldToFixedCoordinates: map of field name to a tuple of coordinate symbols. Instead of using the loop
@@ -202,7 +203,7 @@ def resolveFieldAccesses(astNode, fieldToBasePointerInfo={}, fieldToFixedCoordin
                 basePointerInfo = [list(range(field.indexDimensions + field.spatialDimensions))]
 
             dtype = "%s * __restrict__" % field.dtype
-            if field.readOnly:
+            if field.name in readOnlyFieldNames:
                 dtype = "const " + dtype
 
             fieldPtr = TypedSymbol("%s%s" % (Field.DATA_PREFIX, field.name), dtype)
@@ -389,6 +390,8 @@ def typeAllEquations(eqs, typeForSymbol):
         if isinstance(term, Field.Access):
             fieldsRead.add(term.field)
             return term
+        elif isinstance(term, TypedSymbol):
+            return term
         elif isinstance(term, sp.Symbol):
             return TypedSymbol(term.name, typeForSymbol[term.name])
         else:
@@ -400,6 +403,8 @@ def typeAllEquations(eqs, typeForSymbol):
         if isinstance(term, Field.Access):
             fieldsWritten.add(term.field)
             return term
+        elif isinstance(term, TypedSymbol):
+            return term
         elif isinstance(term, sp.Symbol):
             return TypedSymbol(term.name, typeForSymbol[term.name])
         else:
-- 
GitLab