From 6fffa85a61782acc7f5822515277879c56961fbc Mon Sep 17 00:00:00 2001 From: Martin Bauer <martin.bauer@fau.de> Date: Wed, 9 Nov 2016 12:12:30 +0100 Subject: [PATCH] Field.readOnly: no attribute any more -> passed as parameter where needed --- cpu/kernelcreation.py | 7 ++----- field.py | 8 -------- gpucuda/kernelcreation.py | 8 ++------ transformations.py | 9 +++++++-- 4 files changed, 11 insertions(+), 21 deletions(-) diff --git a/cpu/kernelcreation.py b/cpu/kernelcreation.py index c959afe09..dcc3b151b 100644 --- a/cpu/kernelcreation.py +++ b/cpu/kernelcreation.py @@ -42,10 +42,7 @@ def createKernel(listOfEquations, functionName="kernel", typeForSymbol=None, spl fieldsRead, fieldsWritten, assignments = typeAllEquations(listOfEquations, typeForSymbol) allFields = fieldsRead.union(fieldsWritten) - for field in allFields: - field.setReadOnly(False) - for field in fieldsRead - fieldsWritten: - field.setReadOnly() + readOnlyFields = set([f.name for f in fieldsRead - fieldsWritten]) body = ast.Block(assignments) code = makeLoopOverDomain(body, functionName, iterationSlice=iterationSlice, ghostLayers=ghostLayers) @@ -59,7 +56,7 @@ def createKernel(listOfEquations, functionName="kernel", typeForSymbol=None, spl basePointerInfo = [['spatialInner0'], ['spatialInner1']] basePointerInfos = {field.name: parseBasePointerInfo(basePointerInfo, loopOrder, field) for field in allFields} - resolveFieldAccesses(code, fieldToBasePointerInfo=basePointerInfos) + resolveFieldAccesses(code, readOnlyFields, fieldToBasePointerInfo=basePointerInfos) moveConstantsBeforeLoop(code) return code diff --git a/field.py b/field.py index 9c5dfba03..7fa21a000 100644 --- a/field.py +++ b/field.py @@ -100,7 +100,6 @@ class Field: self._layout = layout self._shape = shape self._strides = strides - self._readonly = False @property def spatialDimensions(self): @@ -146,13 +145,6 @@ class Field: def dtype(self): return self._dtype - @property - def readOnly(self): - return self._readonly - - def setReadOnly(self, value=True): - self._readonly = value - def __repr__(self): return self._fieldName diff --git a/gpucuda/kernelcreation.py b/gpucuda/kernelcreation.py index 93f4233c6..60676058b 100644 --- a/gpucuda/kernelcreation.py +++ b/gpucuda/kernelcreation.py @@ -33,11 +33,7 @@ def getLinewiseCoordinates(field, ghostLayers): def createCUDAKernel(listOfEquations, functionName="kernel", typeForSymbol=defaultdict(lambda: "double")): fieldsRead, fieldsWritten, assignments = typeAllEquations(listOfEquations, typeForSymbol) - allFields = fieldsRead.union(fieldsWritten) - for field in allFields: - field.setReadOnly(False) - for field in fieldsRead - fieldsWritten: - field.setReadOnly() + readOnlyFields = set([f.name for f in fieldsRead - fieldsWritten]) code = KernelFunction(Block(assignments), functionName) code.variablesToIgnore.update(BLOCK_IDX + THREAD_IDX) @@ -49,7 +45,7 @@ def createCUDAKernel(listOfEquations, functionName="kernel", typeForSymbol=defau allFields = fieldsRead.union(fieldsWritten) basePointerInfo = [['spatialInner0']] basePointerInfos = {f.name: parseBasePointerInfo(basePointerInfo, [0, 1, 2], f) for f in allFields} - resolveFieldAccesses(code, fieldToFixedCoordinates={'src': coordMapping, 'dst': coordMapping}, + resolveFieldAccesses(code, readOnlyFields, fieldToFixedCoordinates={'src': coordMapping, 'dst': coordMapping}, fieldToBasePointerInfo=basePointerInfos) # add the function which determines #blocks and #threads as additional member to KernelFunction node # this is used by the jit diff --git a/transformations.py b/transformations.py index 634d57760..a48487798 100644 --- a/transformations.py +++ b/transformations.py @@ -181,11 +181,12 @@ def parseBasePointerInfo(basePointerSpecification, loopOrder, field): return result -def resolveFieldAccesses(astNode, fieldToBasePointerInfo={}, fieldToFixedCoordinates={}): +def resolveFieldAccesses(astNode, readOnlyFieldNames=set(), fieldToBasePointerInfo={}, fieldToFixedCoordinates={}): """ Substitutes :class:`pystencils.field.Field.Access` nodes by array indexing :param astNode: the AST root + :param readOnlyFieldNames: set of field names which are considered read-only :param fieldToBasePointerInfo: a list of tuples indicating which intermediate base pointers should be created for details see :func:`parseBasePointerInfo` :param fieldToFixedCoordinates: map of field name to a tuple of coordinate symbols. Instead of using the loop @@ -202,7 +203,7 @@ def resolveFieldAccesses(astNode, fieldToBasePointerInfo={}, fieldToFixedCoordin basePointerInfo = [list(range(field.indexDimensions + field.spatialDimensions))] dtype = "%s * __restrict__" % field.dtype - if field.readOnly: + if field.name in readOnlyFieldNames: dtype = "const " + dtype fieldPtr = TypedSymbol("%s%s" % (Field.DATA_PREFIX, field.name), dtype) @@ -389,6 +390,8 @@ def typeAllEquations(eqs, typeForSymbol): if isinstance(term, Field.Access): fieldsRead.add(term.field) return term + elif isinstance(term, TypedSymbol): + return term elif isinstance(term, sp.Symbol): return TypedSymbol(term.name, typeForSymbol[term.name]) else: @@ -400,6 +403,8 @@ def typeAllEquations(eqs, typeForSymbol): if isinstance(term, Field.Access): fieldsWritten.add(term.field) return term + elif isinstance(term, TypedSymbol): + return term elif isinstance(term, sp.Symbol): return TypedSymbol(term.name, typeForSymbol[term.name]) else: -- GitLab