From 979ee93bc30d2636241be6ff40b49dfcc766a292 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Victor=20Tozatti=20Risso?= <joaovictortr@protonmail.com> Date: Wed, 13 Dec 2017 12:36:13 +0100 Subject: [PATCH] Code generation for field serialization into buffers Concept: Generate code involving the (un)packing of fields (from)to linear (1D) arrays, i.e. (de)serialization of the field values for buffered communication. A linear index is generated for the buffer, by inferring the strides and variables of the loops over fields in the AST. On the CPU, this information is obtained through the makeLoopOverDomain function, in pystencils/transformations/transformations.py. On CUDA, the strides of the fields (excluding buffers) are combined with the indexing variables to infer the indexing of the buffer. What is supported: - code generation for both CPU and GPU - (un)packing of fields with all the memory layouts supported by pystencils - (un)packing slices of fields (from)into the buffer - (un)packing subsets of cell values from the fields (from)into the buffer Limitations: - assumes that only one buffer and one field are operated on within each kernel, however multiple equations involving the buffer and the field are supported. - (un)packing multiple cell values (from)into the buffer is supported, however it is limited to the fields with indexDimensions=1. The same applies to (un)packing a subset of cell values of each cell. Changes in this commit: - add the FieldType enumeration to pystencils/field.py, to mark fields of various types. This replaces and generalizes the isIndexField boolean flag of the Field class. For now, the types supported are: generic, indexed and buffer fields. - add the fieldType property to the Field class, which indicates the type of the field. Modifications were also performed to the member functions of the Field class to add this property. 
- add resolveBufferAccesses function, which replaces the fields marked as buffers with the actual field access in the AST traversal. Miscellaneous changes: - add blockDim and gridDim variables as CUDA indexing variables. --- __init__.py | 2 +- cpu/cpujit.py | 5 +- cpu/kernelcreation.py | 30 +++++++++--- field.py | 49 ++++++++++++++----- gpucuda/cudajit.py | 5 +- gpucuda/indexing.py | 6 ++- gpucuda/kernelcreation.py | 33 ++++++++++--- llvm/kernelcreation.py | 6 ++- transformations/transformations.py | 75 ++++++++++++++++++++++++++++-- 9 files changed, 171 insertions(+), 40 deletions(-) diff --git a/__init__.py b/__init__.py index b78f17d66..27b0b97f8 100644 --- a/__init__.py +++ b/__init__.py @@ -1,4 +1,4 @@ -from pystencils.field import Field, extractCommonSubexpressions +from pystencils.field import Field, FieldType, extractCommonSubexpressions from pystencils.data_types import TypedSymbol from pystencils.slicing import makeSlice from pystencils.kernelcreation import createKernel, createIndexedKernel diff --git a/cpu/cpujit.py b/cpu/cpujit.py index 69c3d8cd8..28184d6f3 100644 --- a/cpu/cpujit.py +++ b/cpu/cpujit.py @@ -74,6 +74,7 @@ from pystencils.backends.cbackend import generateC, getHeaders from collections import OrderedDict, Mapping from pystencils.transformations import symbolNameToVariableName from pystencils.data_types import toCtypes, getBaseType, StructType +from pystencils.field import FieldType def makePythonFunction(kernelFunctionNode, argumentDict={}): @@ -387,9 +388,9 @@ def buildCTypeArgumentList(parameterSpecification, argumentDict): raise ValueError("Passed array '%s' has strides %s which does not match expected strides %s" % (arg.fieldName, str(fieldArr.strides), str(symbolicFieldStrides))) - if symbolicField.isIndexField: + if FieldType.isIndexed(symbolicField): indexArrShapes.add(fieldArr.shape[:symbolicField.spatialDimensions]) - else: + elif not FieldType.isBuffer(symbolicField): 
arrayShapes.add(fieldArr.shape[:symbolicField.spatialDimensions]) elif arg.isFieldShapeArgument: diff --git a/cpu/kernelcreation.py b/cpu/kernelcreation.py index 298adea4d..91b6dc3eb 100644 --- a/cpu/kernelcreation.py +++ b/cpu/kernelcreation.py @@ -4,11 +4,11 @@ from functools import partial from collections import defaultdict from pystencils.astnodes import SympyAssignment, Block, LoopOverCoordinate, KernelFunction -from pystencils.transformations import resolveFieldAccesses, makeLoopOverDomain, \ +from pystencils.transformations import resolveBufferAccesses, resolveFieldAccesses, makeLoopOverDomain, \ typeAllEquations, getOptimalLoopOrdering, parseBasePointerInfo, moveConstantsBeforeLoop, splitInnerLoop, \ substituteArrayAccessesWithConstants from pystencils.data_types import TypedSymbol, BasicType, StructType, createType -from pystencils.field import Field +from pystencils.field import Field, FieldType import pystencils.astnodes as ast from pystencils.cpu.cpujit import makePythonFunction @@ -51,10 +51,13 @@ def createKernel(listOfEquations, functionName="kernel", typeForSymbol='double', allFields = fieldsRead.union(fieldsWritten) readOnlyFields = set([f.name for f in fieldsRead - fieldsWritten]) + buffers = set([f for f in allFields if FieldType.isBuffer(f)]) + fieldsWithoutBuffers = allFields - buffers + body = ast.Block(assignments) - loopOrder = getOptimalLoopOrdering(allFields) - code = makeLoopOverDomain(body, functionName, iterationSlice=iterationSlice, - ghostLayers=ghostLayers, loopOrder=loopOrder) + loopOrder = getOptimalLoopOrdering(fieldsWithoutBuffers) + code, loopStrides, loopVars = makeLoopOverDomain(body, functionName, iterationSlice=iterationSlice, + ghostLayers=ghostLayers, loopOrder=loopOrder) code.target = 'cpu' if splitGroups: @@ -62,8 +65,20 @@ def createKernel(listOfEquations, functionName="kernel", typeForSymbol='double', splitInnerLoop(code, typedSplitGroups) basePointerInfo = [['spatialInner0'], ['spatialInner1']] if len(loopOrder) >= 2 
else [['spatialInner0']] - basePointerInfos = {field.name: parseBasePointerInfo(basePointerInfo, loopOrder, field) for field in allFields} + basePointerInfos = {field.name: parseBasePointerInfo(basePointerInfo, loopOrder, field) + for field in fieldsWithoutBuffers} + + bufferBasePointerInfos = {field.name: parseBasePointerInfo([['spatialInner0']], [0], field) for field in buffers} + basePointerInfos.update(bufferBasePointerInfos) + + baseBufferIndex = loopVars[0] + stride = 1 + for idx, var in enumerate(loopVars[1:]): + curStride = loopStrides[idx] + stride *= int(curStride) if isinstance(curStride, float) else curStride + baseBufferIndex += var * stride + resolveBufferAccesses(code, baseBufferIndex, readOnlyFields) resolveFieldAccesses(code, readOnlyFields, fieldToBasePointerInfo=basePointerInfos) substituteArrayAccessesWithConstants(code) moveConstantsBeforeLoop(code) @@ -93,7 +108,8 @@ def createIndexedKernel(listOfEquations, indexFields, functionName="kernel", typ allFields = fieldsRead.union(fieldsWritten) for indexField in indexFields: - indexField.isIndexField = True + indexField.fieldType = FieldType.INDEXED + assert FieldType.isIndexed(indexField) assert indexField.spatialDimensions == 1, "Index fields have to be 1D" nonIndexFields = [f for f in allFields if f not in indexFields] diff --git a/field.py b/field.py index d545d15c8..51d08df97 100644 --- a/field.py +++ b/field.py @@ -1,3 +1,4 @@ +from enum import Enum from itertools import chain import numpy as np import sympy as sp @@ -7,6 +8,31 @@ from pystencils.data_types import TypedSymbol, createType, createCompositeTypeFr from pystencils.sympyextensions import isIntegerSequence +class FieldType(Enum): + # generic fields + GENERIC = 0 + # index fields are currently only used for boundary handling + # the coordinates are not the loop counters in that case, but are read from this index field + INDEXED = 1 + # communication buffer, used for (un)packing data in communication. 
+ BUFFER = 2 + + @staticmethod + def isGeneric(field): + assert isinstance(field, Field) + return field.fieldType == FieldType.GENERIC + + @staticmethod + def isIndexed(field): + assert isinstance(field, Field) + return field.fieldType == FieldType.INDEXED + + @staticmethod + def isBuffer(field): + assert isinstance(field, Field) + return field.fieldType == FieldType.BUFFER + + class Field(object): """ With fields one can formulate stencil-like update rules on structured grids. @@ -51,7 +77,7 @@ class Field(object): @staticmethod def createGeneric(fieldName, spatialDimensions, dtype=np.float64, indexDimensions=0, layout='numpy', - indexShape=None): + indexShape=None, fieldType=FieldType.GENERIC): """ Creates a generic field where the field size is not fixed i.e. can be called with arrays of different sizes @@ -85,7 +111,7 @@ class Field(object): shape += (1,) strides += (1,) - return Field(fieldName, dtype, layout, shape, strides) + return Field(fieldName, fieldType, dtype, layout, shape, strides) @staticmethod def createFromNumpyArray(fieldName, npArray, indexDimensions=0): @@ -114,7 +140,7 @@ class Field(object): shape += (1,) strides += (1,) - return Field(fieldName, npArray.dtype, spatialLayout, shape, strides) + return Field(fieldName, FieldType.GENERIC, npArray.dtype, spatialLayout, shape, strides) @staticmethod def createFixedSize(fieldName, shape, indexDimensions=0, dtype=np.float64, layout='numpy'): @@ -146,21 +172,20 @@ class Field(object): spatialLayout = list(layout) for i in range(spatialDimensions, len(layout)): spatialLayout.remove(i) - return Field(fieldName, dtype, tuple(spatialLayout), shape, strides) + return Field(fieldName, FieldType.GENERIC, dtype, tuple(spatialLayout), shape, strides) - def __init__(self, fieldName, dtype, layout, shape, strides): + def __init__(self, fieldName, fieldType, dtype, layout, shape, strides): """Do not use directly. 
Use static create* methods""" self._fieldName = fieldName + assert isinstance(fieldType, FieldType) + self.fieldType = fieldType self._dtype = createType(dtype) self._layout = normalizeLayout(layout) self.shape = shape self.strides = strides - # index fields are currently only used for boundary handling - # the coordinates are not the loop counters in that case, but are read from this index field - self.isIndexField = False def newFieldWithDifferentName(self, newName): - return Field(newName, self._dtype, self._layout, self.shape, self.strides) + return Field(newName, self.fieldType, self._dtype, self._layout, self.shape, self.strides) @property def spatialDimensions(self): @@ -243,11 +268,11 @@ class Field(object): return Field.Access(self, center)(*args, **kwargs) def __hash__(self): - return hash((self._layout, self.shape, self.strides, self._dtype, self._fieldName)) + return hash((self._layout, self.shape, self.strides, self._dtype, self.fieldType, self._fieldName)) def __eq__(self, other): - selfTuple = (self.shape, self.strides, self.name, self.dtype) - otherTuple = (other.shape, other.strides, other.name, other.dtype) + selfTuple = (self.shape, self.strides, self.name, self.dtype, self.fieldType) + otherTuple = (other.shape, other.strides, other.name, other.dtype, other.fieldType) return selfTuple == otherTuple PREFIX = "f" diff --git a/gpucuda/cudajit.py b/gpucuda/cudajit.py index b815f5c33..3a0fe85f4 100644 --- a/gpucuda/cudajit.py +++ b/gpucuda/cudajit.py @@ -2,6 +2,7 @@ import numpy as np from pystencils.backends.cbackend import generateC from pystencils.transformations import symbolNameToVariableName from pystencils.data_types import StructType, getBaseType +from pystencils.field import FieldType def makePythonFunction(kernelFunctionNode, argumentDict={}): @@ -119,9 +120,9 @@ def _checkArguments(parameterSpecification, argumentDict): raise ValueError("Passed array '%s' has strides %s which does not match expected strides %s" % (arg.fieldName, 
str(fieldArr.strides), str(symbolicFieldStrides))) - if symbolicField.isIndexField: + if FieldType.isIndexed(symbolicField): indexArrShapes.add(fieldArr.shape[:symbolicField.spatialDimensions]) - else: + elif not FieldType.isBuffer(symbolicField): arrayShapes.add(fieldArr.shape[:symbolicField.spatialDimensions]) if len(arrayShapes) > 1: diff --git a/gpucuda/indexing.py b/gpucuda/indexing.py index 99db1efc9..aa05078e8 100644 --- a/gpucuda/indexing.py +++ b/gpucuda/indexing.py @@ -11,6 +11,8 @@ AUTO_BLOCKSIZE_LIMITING = True BLOCK_IDX = [TypedSymbol("blockIdx." + coord, createType("int")) for coord in ('x', 'y', 'z')] THREAD_IDX = [TypedSymbol("threadIdx." + coord, createType("int")) for coord in ('x', 'y', 'z')] +BLOCK_DIM = [TypedSymbol("blockDim." + coord, createType("int")) for coord in ('x', 'y', 'z')] +GRID_DIM = [TypedSymbol("gridDim." + coord, createType("int")) for coord in ('x', 'y', 'z')] class AbstractIndexing(abc.ABCMeta('ABC', (object,), {})): @@ -28,8 +30,8 @@ class AbstractIndexing(abc.ABCMeta('ABC', (object,), {})): @property def indexVariables(self): - """Sympy symbols for CUDA's block and thread indices""" - return BLOCK_IDX + THREAD_IDX + """Sympy symbols for CUDA's block and thread indices, and block and grid dimensions. 
""" + return BLOCK_IDX + THREAD_IDX + BLOCK_DIM + GRID_DIM @abc.abstractmethod def getCallParameters(self, arrShape): diff --git a/gpucuda/kernelcreation.py b/gpucuda/kernelcreation.py index 738709362..5bd2f5d7d 100644 --- a/gpucuda/kernelcreation.py +++ b/gpucuda/kernelcreation.py @@ -2,10 +2,10 @@ from functools import partial from pystencils.gpucuda.indexing import BlockIndexing from pystencils.transformations import resolveFieldAccesses, typeAllEquations, parseBasePointerInfo, getCommonShape, \ - substituteArrayAccessesWithConstants + substituteArrayAccessesWithConstants, resolveBufferAccesses from pystencils.astnodes import Block, KernelFunction, SympyAssignment, LoopOverCoordinate from pystencils.data_types import TypedSymbol, BasicType, StructType -from pystencils import Field +from pystencils import Field, FieldType from pystencils.gpucuda.cudajit import makePythonFunction @@ -15,11 +15,18 @@ def createCUDAKernel(listOfEquations, functionName="kernel", typeForSymbol=None, allFields = fieldsRead.union(fieldsWritten) readOnlyFields = set([f.name for f in fieldsRead - fieldsWritten]) + buffers = set([f for f in allFields if FieldType.isBuffer(f)]) + fieldsWithoutBuffers = allFields - buffers + fieldAccesses = set() + numBufferAccesses = 0 for eq in listOfEquations: fieldAccesses.update(eq.atoms(Field.Access)) - commonShape = getCommonShape(allFields) + numBufferAccesses += sum([1 for access in eq.atoms(Field.Access) if FieldType.isBuffer(access.field)]) + + commonShape = getCommonShape(fieldsWithoutBuffers) + if iterationSlice is None: # determine iteration slice from ghost layers if ghostLayers is None: @@ -34,7 +41,7 @@ def createCUDAKernel(listOfEquations, functionName="kernel", typeForSymbol=None, for i in range(len(commonShape)): iterationSlice.append(slice(ghostLayers[i][0], -ghostLayers[i][1] if ghostLayers[i][1] > 0 else None)) - indexing = indexingCreator(field=list(allFields)[0], iterationSlice=iterationSlice) + indexing = 
indexingCreator(field=list(fieldsWithoutBuffers)[0], iterationSlice=iterationSlice) block = Block(assignments) block = indexing.guard(block, commonShape) @@ -46,8 +53,19 @@ def createCUDAKernel(listOfEquations, functionName="kernel", typeForSymbol=None, basePointerInfos = {f.name: parseBasePointerInfo(basePointerInfo, [2, 1, 0], f) for f in allFields} coordMapping = {f.name: coordMapping for f in allFields} - resolveFieldAccesses(ast, readOnlyFields, fieldToFixedCoordinates=coordMapping, - fieldToBasePointerInfo=basePointerInfos) + + loopVars = [numBufferAccesses * i for i in indexing.coordinates] + loopStrides = list(fieldsWithoutBuffers)[0].shape + + baseBufferIndex = loopVars[0] + stride = 1 + for idx, var in enumerate(loopVars[1:]): + stride *= loopStrides[idx] + baseBufferIndex += var * stride + + resolveBufferAccesses(ast, baseBufferIndex, readOnlyFields) + resolveFieldAccesses(ast, readOnlyFields, fieldToBasePointerInfo=basePointerInfos, + fieldToFixedCoordinates=coordMapping) substituteArrayAccessesWithConstants(ast) @@ -73,7 +91,8 @@ def createdIndexedCUDAKernel(listOfEquations, indexFields, functionName="kernel" readOnlyFields = set([f.name for f in fieldsRead - fieldsWritten]) for indexField in indexFields: - indexField.isIndexField = True + indexField.fieldType = FieldType.INDEXED + assert FieldType.isIndexed(indexField) assert indexField.spatialDimensions == 1, "Index fields have to be 1D" nonIndexFields = [f for f in allFields if f not in indexFields] diff --git a/llvm/kernelcreation.py b/llvm/kernelcreation.py index 78cd1166a..121dcdd92 100644 --- a/llvm/kernelcreation.py +++ b/llvm/kernelcreation.py @@ -1,7 +1,8 @@ from pystencils.astnodes import SympyAssignment, Block, LoopOverCoordinate, KernelFunction -from pystencils.transformations import resolveFieldAccesses, \ +from pystencils.transformations import resolveFieldAccesses, resolveBufferAccesses, \ typeAllEquations, moveConstantsBeforeLoop, insertCasts from pystencils.data_types import 
TypedSymbol, BasicType, StructType +from pystencils.field import Field, FieldType from functools import partial from pystencils.llvm.llvmjit import makePythonFunction @@ -57,7 +58,8 @@ def createIndexedKernel(listOfEquations, indexFields, functionName="kernel", typ allFields = fieldsRead.union(fieldsWritten) for indexField in indexFields: - indexField.isIndexField = True + indexField.fieldType = FieldType.INDEXED + assert FieldType.isIndexed(indexField) assert indexField.spatialDimensions == 1, "Index fields have to be 1D" nonIndexFields = [f for f in allFields if f not in indexFields] diff --git a/transformations/transformations.py b/transformations/transformations.py index d41cadc01..bf9426f32 100644 --- a/transformations/transformations.py +++ b/transformations/transformations.py @@ -6,7 +6,7 @@ import sympy as sp from sympy.logic.boolalg import Boolean from sympy.tensor import IndexedBase -from pystencils.field import Field, offsetComponentToDirectionString +from pystencils.field import Field, FieldType, offsetComponentToDirectionString from pystencils.data_types import TypedSymbol, createType, PointerType, StructType, getBaseType, castFunc from pystencils.slicing import normalizeSlice import pystencils.astnodes as ast @@ -71,13 +71,15 @@ def makeLoopOverDomain(body, functionName, iterationSlice=None, ghostLayers=None """ # find correct ordering by inspecting participating FieldAccesses fieldAccesses = body.atoms(Field.Access) - fieldList = [e.field for e in fieldAccesses] + # exclude accesses to buffers from fieldList, because buffers are treated separately + fieldList = [e.field for e in fieldAccesses if not FieldType.isBuffer(e.field)] fields = set(fieldList) + numBufferAccesses = len(fieldAccesses) - len(fieldList) if loopOrder is None: loopOrder = getOptimalLoopOrdering(fields) - shape = getCommonShape(fields) + shape = getCommonShape(list(fields)) if iterationSlice is not None: iterationSlice = normalizeSlice(iterationSlice, shape) @@ -88,6 +90,11 @@ def 
makeLoopOverDomain(body, functionName, iterationSlice=None, ghostLayers=None if isinstance(ghostLayers, int): ghostLayers = [(ghostLayers, ghostLayers)] * len(loopOrder) + def getLoopStride(begin, end, step): + return (end - begin) / step + + loopStrides = [] + loopVars = [] currentBody = body lastLoop = None for i, loopCoordinate in enumerate(reversed(loopOrder)): @@ -97,6 +104,8 @@ def makeLoopOverDomain(body, functionName, iterationSlice=None, ghostLayers=None newLoop = ast.LoopOverCoordinate(currentBody, loopCoordinate, begin, end, 1) lastLoop = newLoop currentBody = ast.Block([lastLoop]) + loopStrides.append(getLoopStride(begin, end, 1)) + loopVars.append(newLoop.loopCounterSymbol) else: sliceComponent = iterationSlice[loopCoordinate] if type(sliceComponent) is slice: @@ -104,11 +113,16 @@ def makeLoopOverDomain(body, functionName, iterationSlice=None, ghostLayers=None newLoop = ast.LoopOverCoordinate(currentBody, loopCoordinate, sc.start, sc.stop, sc.step) lastLoop = newLoop currentBody = ast.Block([lastLoop]) + loopStrides.append(getLoopStride(sc.start, sc.stop, sc.step)) + loopVars.append(newLoop.loopCounterSymbol) else: assignment = ast.SympyAssignment(ast.LoopOverCoordinate.getLoopCounterSymbol(loopCoordinate), sp.sympify(sliceComponent)) currentBody.insertFront(assignment) - return ast.KernelFunction(currentBody, ghostLayers=ghostLayers, functionName=functionName) + + loopVars = [numBufferAccesses * var for var in loopVars] + astNode = ast.KernelFunction(currentBody, ghostLayers=ghostLayers, functionName=functionName) + return (astNode, loopStrides, loopVars) def createIntermediateBasePointer(fieldAccess, coordinates, previousPtr): @@ -133,7 +147,6 @@ def createIntermediateBasePointer(fieldAccess, coordinates, previousPtr): (ptr_E_2S, x*fstride_myfield[0] + y*fstride_myfield[1] + fstride_myfield[0] - 2*fstride_myfield[1]) """ field = fieldAccess.field - offset = 0 name = "" listToHash = [] @@ -158,6 +171,7 @@ def 
createIntermediateBasePointer(fieldAccess, coordinates, previousPtr): name += "%0.6X" % (abs(hash(tuple(listToHash)))) newPtr = TypedSymbol(previousPtr.name + name, previousPtr.dtype) + return newPtr, offset @@ -223,6 +237,7 @@ def parseBasePointerInfo(basePointerSpecification, loopOrder, field): rest = allCoordinates - specifiedCoordinates if rest: result.append(list(rest)) + return result @@ -278,6 +293,52 @@ def substituteArrayAccessesWithConstants(astNode): for a in astNode.args: substituteArrayAccessesWithConstants(a) +def resolveBufferAccesses(astNode, baseBufferIndex, readOnlyFieldNames=set()): + def visitSympyExpr(expr, enclosingBlock, sympyAssignment): + if isinstance(expr, Field.Access): + fieldAccess = expr + + # Do not apply transformation if field is not a buffer + if not FieldType.isBuffer(fieldAccess.field): + return expr + + buffer = fieldAccess.field + + dtype = PointerType(buffer.dtype, const=buffer.name in readOnlyFieldNames, restrict=True) + fieldPtr = TypedSymbol("%s%s" % (Field.DATA_PREFIX, symbolNameToVariableName(buffer.name)), dtype) + + bufferIndex = baseBufferIndex + if len(fieldAccess.index) > 1: + raise RuntimeError('Only indexing dimensions up to 1 are currently supported in buffers!') + + if len(fieldAccess.index) > 0: + cellIndex = fieldAccess.index[0] + bufferIndex += cellIndex + + result = ast.ResolvedFieldAccess(fieldPtr, bufferIndex, fieldAccess.field, fieldAccess.offsets, + fieldAccess.index) + + return visitSympyExpr(result, enclosingBlock, sympyAssignment) + else: + if isinstance(expr, ast.ResolvedFieldAccess): + return expr + + newArgs = [visitSympyExpr(e, enclosingBlock, sympyAssignment) for e in expr.args] + kwargs = {'evaluate': False} if type(expr) in (sp.Add, sp.Mul, sp.Piecewise) else {} + return expr.func(*newArgs, **kwargs) if newArgs else expr + + def visitNode(subAst): + if isinstance(subAst, ast.SympyAssignment): + enclosingBlock = subAst.parent + assert type(enclosingBlock) is ast.Block + subAst.lhs = 
visitSympyExpr(subAst.lhs, enclosingBlock, subAst) + subAst.rhs = visitSympyExpr(subAst.rhs, enclosingBlock, subAst) + else: + for i, a in enumerate(subAst.args): + visitNode(a) + + return visitNode(astNode) + def resolveFieldAccesses(astNode, readOnlyFieldNames=set(), fieldToBasePointerInfo={}, fieldToFixedCoordinates={}): """ @@ -298,6 +359,7 @@ def resolveFieldAccesses(astNode, readOnlyFieldNames=set(), fieldToBasePointerIn if isinstance(expr, Field.Access): fieldAccess = expr field = fieldAccess.field + if field.name in fieldToBasePointerInfo: basePointerInfo = fieldToBasePointerInfo[field.name] else: @@ -324,6 +386,7 @@ def resolveFieldAccesses(astNode, readOnlyFieldNames=set(), fieldToBasePointerIn coordDict[e] = field.dtype.getElementOffset(accessedFieldName) else: coordDict[e] = fieldAccess.index[e - field.spatialDimensions] + return coordDict lastPointer = fieldPtr @@ -337,6 +400,7 @@ def resolveFieldAccesses(astNode, readOnlyFieldNames=set(), fieldToBasePointerIn lastPointer = newPtr coordDict = createCoordinateDict(basePointerInfo[0]) + _, offset = createIntermediateBasePointer(fieldAccess, coordDict, lastPointer) result = ast.ResolvedFieldAccess(lastPointer, offset, fieldAccess.field, fieldAccess.offsets, fieldAccess.index) @@ -344,6 +408,7 @@ def resolveFieldAccesses(astNode, readOnlyFieldNames=set(), fieldToBasePointerIn if isinstance(getBaseType(fieldAccess.field.dtype), StructType): newType = fieldAccess.field.dtype.getElementType(fieldAccess.index[0]) result = castFunc(result, newType) + return visitSympyExpr(result, enclosingBlock, sympyAssignment) else: if isinstance(expr, ast.ResolvedFieldAccess): -- GitLab