diff --git a/__init__.py b/__init__.py index b78f17d66f89640855bc6950f7ddd752d2bdf838..27b0b97f8ba261d4d85ad9ace8876a92acf0eb9e 100644 --- a/__init__.py +++ b/__init__.py @@ -1,4 +1,4 @@ -from pystencils.field import Field, extractCommonSubexpressions +from pystencils.field import Field, FieldType, extractCommonSubexpressions from pystencils.data_types import TypedSymbol from pystencils.slicing import makeSlice from pystencils.kernelcreation import createKernel, createIndexedKernel diff --git a/cpu/cpujit.py b/cpu/cpujit.py index 69c3d8cd84b991ba86719733dfc44e80b813771f..28184d6f366d20904eb3d958e3c98af18339a8e5 100644 --- a/cpu/cpujit.py +++ b/cpu/cpujit.py @@ -74,6 +74,7 @@ from pystencils.backends.cbackend import generateC, getHeaders from collections import OrderedDict, Mapping from pystencils.transformations import symbolNameToVariableName from pystencils.data_types import toCtypes, getBaseType, StructType +from pystencils.field import FieldType def makePythonFunction(kernelFunctionNode, argumentDict={}): @@ -387,9 +388,9 @@ def buildCTypeArgumentList(parameterSpecification, argumentDict): raise ValueError("Passed array '%s' has strides %s which does not match expected strides %s" % (arg.fieldName, str(fieldArr.strides), str(symbolicFieldStrides))) - if symbolicField.isIndexField: + if FieldType.isIndexed(symbolicField): indexArrShapes.add(fieldArr.shape[:symbolicField.spatialDimensions]) - else: + elif not FieldType.isBuffer(symbolicField): arrayShapes.add(fieldArr.shape[:symbolicField.spatialDimensions]) elif arg.isFieldShapeArgument: diff --git a/cpu/kernelcreation.py b/cpu/kernelcreation.py index 298adea4d43b0428de94d74aaed6c8dac3891982..91b6dc3ebc296c3ba1ccd240a4081a0edc1019b6 100644 --- a/cpu/kernelcreation.py +++ b/cpu/kernelcreation.py @@ -4,11 +4,11 @@ from functools import partial from collections import defaultdict from pystencils.astnodes import SympyAssignment, Block, LoopOverCoordinate, KernelFunction -from pystencils.transformations import resolveFieldAccesses, makeLoopOverDomain, \ +from pystencils.transformations import resolveBufferAccesses, resolveFieldAccesses, makeLoopOverDomain, \ typeAllEquations, getOptimalLoopOrdering, parseBasePointerInfo, moveConstantsBeforeLoop, splitInnerLoop, \ substituteArrayAccessesWithConstants from pystencils.data_types import TypedSymbol, BasicType, StructType, createType -from pystencils.field import Field +from pystencils.field import Field, FieldType import pystencils.astnodes as ast from pystencils.cpu.cpujit import makePythonFunction @@ -51,10 +51,13 @@ def createKernel(listOfEquations, functionName="kernel", typeForSymbol='double', allFields = fieldsRead.union(fieldsWritten) readOnlyFields = set([f.name for f in fieldsRead - fieldsWritten]) + buffers = set([f for f in allFields if FieldType.isBuffer(f)]) + fieldsWithoutBuffers = allFields - buffers + body = ast.Block(assignments) - loopOrder = getOptimalLoopOrdering(allFields) - code = makeLoopOverDomain(body, functionName, iterationSlice=iterationSlice, - ghostLayers=ghostLayers, loopOrder=loopOrder) + loopOrder = getOptimalLoopOrdering(fieldsWithoutBuffers) + code, loopStrides, loopVars = makeLoopOverDomain(body, functionName, iterationSlice=iterationSlice, + ghostLayers=ghostLayers, loopOrder=loopOrder) code.target = 'cpu' if splitGroups: @@ -62,8 +65,20 @@ def createKernel(listOfEquations, functionName="kernel", typeForSymbol='double', splitInnerLoop(code, typedSplitGroups) basePointerInfo = [['spatialInner0'], ['spatialInner1']] if len(loopOrder) >= 2 else [['spatialInner0']] - basePointerInfos = {field.name: parseBasePointerInfo(basePointerInfo, loopOrder, field) for field in allFields} + basePointerInfos = {field.name: parseBasePointerInfo(basePointerInfo, loopOrder, field) + for field in fieldsWithoutBuffers} + + bufferBasePointerInfos = {field.name: parseBasePointerInfo([['spatialInner0']], [0], field) for field in buffers} + basePointerInfos.update(bufferBasePointerInfos) + + baseBufferIndex = loopVars[0] + stride = 1 + for idx, var in enumerate(loopVars[1:]): + curStride = loopStrides[idx] + stride *= int(curStride) if isinstance(curStride, float) else curStride + baseBufferIndex += var * stride + resolveBufferAccesses(code, baseBufferIndex, readOnlyFields) resolveFieldAccesses(code, readOnlyFields, fieldToBasePointerInfo=basePointerInfos) substituteArrayAccessesWithConstants(code) moveConstantsBeforeLoop(code) @@ -93,7 +108,8 @@ def createIndexedKernel(listOfEquations, indexFields, functionName="kernel", typ allFields = fieldsRead.union(fieldsWritten) for indexField in indexFields: - indexField.isIndexField = True + indexField.fieldType = FieldType.INDEXED + assert FieldType.isIndexed(indexField) assert indexField.spatialDimensions == 1, "Index fields have to be 1D" nonIndexFields = [f for f in allFields if f not in indexFields] diff --git a/field.py b/field.py index d545d15c8eb4852365a2a9ce88197f4dc26e0182..51d08df97501a4dd89fc641778ce5a845f5e7d49 100644 --- a/field.py +++ b/field.py @@ -1,3 +1,4 @@ +from enum import Enum from itertools import chain import numpy as np import sympy as sp @@ -7,6 +8,31 @@ from pystencils.data_types import TypedSymbol, createType, createCompositeTypeFr from pystencils.sympyextensions import isIntegerSequence +class FieldType(Enum): + # generic fields + GENERIC = 0 + # index fields are currently only used for boundary handling + # the coordinates are not the loop counters in that case, but are read from this index field + INDEXED = 1 + # communication buffer, used for (un)packing data in communication. + BUFFER = 2 + + @staticmethod + def isGeneric(field): + assert isinstance(field, Field) + return field.fieldType == FieldType.GENERIC + + @staticmethod + def isIndexed(field): + assert isinstance(field, Field) + return field.fieldType == FieldType.INDEXED + + @staticmethod + def isBuffer(field): + assert isinstance(field, Field) + return field.fieldType == FieldType.BUFFER + + class Field(object): """ With fields one can formulate stencil-like update rules on structured grids. @@ -51,7 +77,7 @@ class Field(object): @staticmethod def createGeneric(fieldName, spatialDimensions, dtype=np.float64, indexDimensions=0, layout='numpy', - indexShape=None): + indexShape=None, fieldType=FieldType.GENERIC): """ Creates a generic field where the field size is not fixed i.e. can be called with arrays of different sizes @@ -85,7 +111,7 @@ class Field(object): shape += (1,) strides += (1,) - return Field(fieldName, dtype, layout, shape, strides) + return Field(fieldName, fieldType, dtype, layout, shape, strides) @staticmethod def createFromNumpyArray(fieldName, npArray, indexDimensions=0): @@ -114,7 +140,7 @@ class Field(object): shape += (1,) strides += (1,) - return Field(fieldName, npArray.dtype, spatialLayout, shape, strides) + return Field(fieldName, FieldType.GENERIC, npArray.dtype, spatialLayout, shape, strides) @staticmethod def createFixedSize(fieldName, shape, indexDimensions=0, dtype=np.float64, layout='numpy'): @@ -146,21 +172,20 @@ class Field(object): spatialLayout = list(layout) for i in range(spatialDimensions, len(layout)): spatialLayout.remove(i) - return Field(fieldName, dtype, tuple(spatialLayout), shape, strides) + return Field(fieldName, FieldType.GENERIC, dtype, tuple(spatialLayout), shape, strides) - def __init__(self, fieldName, dtype, layout, shape, strides): + def __init__(self, fieldName, fieldType, dtype, layout, shape, strides): """Do not use directly. Use static create* methods""" self._fieldName = fieldName + assert isinstance(fieldType, FieldType) + self.fieldType = fieldType self._dtype = createType(dtype) self._layout = normalizeLayout(layout) self.shape = shape self.strides = strides - # index fields are currently only used for boundary handling - # the coordinates are not the loop counters in that case, but are read from this index field - self.isIndexField = False def newFieldWithDifferentName(self, newName): - return Field(newName, self._dtype, self._layout, self.shape, self.strides) + return Field(newName, self.fieldType, self._dtype, self._layout, self.shape, self.strides) @property def spatialDimensions(self): @@ -243,11 +268,11 @@ class Field(object): return Field.Access(self, center)(*args, **kwargs) def __hash__(self): - return hash((self._layout, self.shape, self.strides, self._dtype, self._fieldName)) + return hash((self._layout, self.shape, self.strides, self._dtype, self.fieldType, self._fieldName)) def __eq__(self, other): - selfTuple = (self.shape, self.strides, self.name, self.dtype) - otherTuple = (other.shape, other.strides, other.name, other.dtype) + selfTuple = (self.shape, self.strides, self.name, self.dtype, self.fieldType) + otherTuple = (other.shape, other.strides, other.name, other.dtype, other.fieldType) return selfTuple == otherTuple PREFIX = "f" diff --git a/gpucuda/cudajit.py b/gpucuda/cudajit.py index b815f5c33ef2f1f24ac0fb8c32cab68c964d18e6..3a0fe85f4ceec904438381b4104b29fcfb2f1d30 100644 --- a/gpucuda/cudajit.py +++ b/gpucuda/cudajit.py @@ -2,6 +2,7 @@ import numpy as np from pystencils.backends.cbackend import generateC from pystencils.transformations import symbolNameToVariableName from pystencils.data_types import StructType, getBaseType +from pystencils.field import FieldType def makePythonFunction(kernelFunctionNode, argumentDict={}): @@ -119,9 +120,9 @@ def _checkArguments(parameterSpecification, argumentDict): raise ValueError("Passed array '%s' has strides %s which does not match expected strides %s" % (arg.fieldName, str(fieldArr.strides), str(symbolicFieldStrides))) - if symbolicField.isIndexField: + if FieldType.isIndexed(symbolicField): indexArrShapes.add(fieldArr.shape[:symbolicField.spatialDimensions]) - else: + elif not FieldType.isBuffer(symbolicField): arrayShapes.add(fieldArr.shape[:symbolicField.spatialDimensions]) if len(arrayShapes) > 1: diff --git a/gpucuda/indexing.py b/gpucuda/indexing.py index 99db1efc974c11b0d3a5f911bf86bdf51094a4cc..aa05078e884b325a7c59a42aedb2d96745c3629d 100644 --- a/gpucuda/indexing.py +++ b/gpucuda/indexing.py @@ -11,6 +11,8 @@ AUTO_BLOCKSIZE_LIMITING = True BLOCK_IDX = [TypedSymbol("blockIdx." + coord, createType("int")) for coord in ('x', 'y', 'z')] THREAD_IDX = [TypedSymbol("threadIdx." + coord, createType("int")) for coord in ('x', 'y', 'z')] +BLOCK_DIM = [TypedSymbol("blockDim." + coord, createType("int")) for coord in ('x', 'y', 'z')] +GRID_DIM = [TypedSymbol("gridDim." + coord, createType("int")) for coord in ('x', 'y', 'z')] class AbstractIndexing(abc.ABCMeta('ABC', (object,), {})): @@ -28,8 +30,8 @@ class AbstractIndexing(abc.ABCMeta('ABC', (object,), {})): @property def indexVariables(self): - """Sympy symbols for CUDA's block and thread indices""" - return BLOCK_IDX + THREAD_IDX + """Sympy symbols for CUDA's block and thread indices, and block and grid dimensions. """ + return BLOCK_IDX + THREAD_IDX + BLOCK_DIM + GRID_DIM @abc.abstractmethod def getCallParameters(self, arrShape): diff --git a/gpucuda/kernelcreation.py b/gpucuda/kernelcreation.py index 7387093620757f4966f1c77f2e758834eba185ac..5bd2f5d7d96efda99bf4122483a50325cde7f2d3 100644 --- a/gpucuda/kernelcreation.py +++ b/gpucuda/kernelcreation.py @@ -2,10 +2,10 @@ from functools import partial from pystencils.gpucuda.indexing import BlockIndexing from pystencils.transformations import resolveFieldAccesses, typeAllEquations, parseBasePointerInfo, getCommonShape, \ - substituteArrayAccessesWithConstants + substituteArrayAccessesWithConstants, resolveBufferAccesses from pystencils.astnodes import Block, KernelFunction, SympyAssignment, LoopOverCoordinate from pystencils.data_types import TypedSymbol, BasicType, StructType -from pystencils import Field +from pystencils import Field, FieldType from pystencils.gpucuda.cudajit import makePythonFunction @@ -15,11 +15,18 @@ def createCUDAKernel(listOfEquations, functionName="kernel", typeForSymbol=None, allFields = fieldsRead.union(fieldsWritten) readOnlyFields = set([f.name for f in fieldsRead - fieldsWritten]) + buffers = set([f for f in allFields if FieldType.isBuffer(f)]) + fieldsWithoutBuffers = allFields - buffers + fieldAccesses = set() + numBufferAccesses = 0 for eq in listOfEquations: fieldAccesses.update(eq.atoms(Field.Access)) - commonShape = getCommonShape(allFields) + numBufferAccesses += sum([1 for access in eq.atoms(Field.Access) if FieldType.isBuffer(access.field)]) + + commonShape = getCommonShape(fieldsWithoutBuffers) + if iterationSlice is None: # determine iteration slice from ghost layers if ghostLayers is None: @@ -34,7 +41,7 @@ def createCUDAKernel(listOfEquations, functionName="kernel", typeForSymbol=None, for i in range(len(commonShape)): iterationSlice.append(slice(ghostLayers[i][0], -ghostLayers[i][1] if ghostLayers[i][1] > 0 else None)) - indexing = indexingCreator(field=list(allFields)[0], iterationSlice=iterationSlice) + indexing = indexingCreator(field=list(fieldsWithoutBuffers)[0], iterationSlice=iterationSlice) block = Block(assignments) block = indexing.guard(block, commonShape) @@ -46,8 +53,19 @@ def createCUDAKernel(listOfEquations, functionName="kernel", typeForSymbol=None, basePointerInfos = {f.name: parseBasePointerInfo(basePointerInfo, [2, 1, 0], f) for f in allFields} coordMapping = {f.name: coordMapping for f in allFields} - resolveFieldAccesses(ast, readOnlyFields, fieldToFixedCoordinates=coordMapping, - fieldToBasePointerInfo=basePointerInfos) + + loopVars = [numBufferAccesses * i for i in indexing.coordinates] + loopStrides = list(fieldsWithoutBuffers)[0].shape + + baseBufferIndex = loopVars[0] + stride = 1 + for idx, var in enumerate(loopVars[1:]): + stride *= loopStrides[idx] + baseBufferIndex += var * stride + + resolveBufferAccesses(ast, baseBufferIndex, readOnlyFields) + resolveFieldAccesses(ast, readOnlyFields, fieldToBasePointerInfo=basePointerInfos, + fieldToFixedCoordinates=coordMapping) substituteArrayAccessesWithConstants(ast) @@ -73,7 +91,8 @@ def createdIndexedCUDAKernel(listOfEquations, indexFields, functionName="kernel" readOnlyFields = set([f.name for f in fieldsRead - fieldsWritten]) for indexField in indexFields: - indexField.isIndexField = True + indexField.fieldType = FieldType.INDEXED + assert FieldType.isIndexed(indexField) assert indexField.spatialDimensions == 1, "Index fields have to be 1D" nonIndexFields = [f for f in allFields if f not in indexFields] diff --git a/llvm/kernelcreation.py b/llvm/kernelcreation.py index 78cd1166a3f4daa9a025784a9b109127079fc661..121dcdd924a84f90d94ca557f5ac19cf5a299802 100644 --- a/llvm/kernelcreation.py +++ b/llvm/kernelcreation.py @@ -1,7 +1,8 @@ from pystencils.astnodes import SympyAssignment, Block, LoopOverCoordinate, KernelFunction -from pystencils.transformations import resolveFieldAccesses, \ +from pystencils.transformations import resolveFieldAccesses, resolveBufferAccesses, \ typeAllEquations, moveConstantsBeforeLoop, insertCasts from pystencils.data_types import TypedSymbol, BasicType, StructType +from pystencils.field import Field, FieldType from functools import partial from pystencils.llvm.llvmjit import makePythonFunction @@ -57,7 +58,8 @@ def createIndexedKernel(listOfEquations, indexFields, functionName="kernel", typ allFields = fieldsRead.union(fieldsWritten) for indexField in indexFields: - indexField.isIndexField = True + indexField.fieldType = FieldType.INDEXED + assert FieldType.isIndexed(indexField) assert indexField.spatialDimensions == 1, "Index fields have to be 1D" nonIndexFields = [f for f in allFields if f not in indexFields] diff --git a/transformations/transformations.py b/transformations/transformations.py index d41cadc013a7ec630bc5e11a6fed02ddf90335ba..bf9426f326150b5d996305c3afc7b729f9bdd80e 100644 --- a/transformations/transformations.py +++ b/transformations/transformations.py @@ -6,7 +6,7 @@ import sympy as sp from sympy.logic.boolalg import Boolean from sympy.tensor import IndexedBase -from pystencils.field import Field, offsetComponentToDirectionString +from pystencils.field import Field, FieldType, offsetComponentToDirectionString from pystencils.data_types import TypedSymbol, createType, PointerType, StructType, getBaseType, castFunc from pystencils.slicing import normalizeSlice import pystencils.astnodes as ast @@ -71,13 +71,15 @@ def makeLoopOverDomain(body, functionName, iterationSlice=None, ghostLayers=None """ # find correct ordering by inspecting participating FieldAccesses fieldAccesses = body.atoms(Field.Access) - fieldList = [e.field for e in fieldAccesses] + # exclude accesses to buffers from fieldList, because buffers are treated separately + fieldList = [e.field for e in fieldAccesses if not FieldType.isBuffer(e.field)] fields = set(fieldList) + numBufferAccesses = len(fieldAccesses) - len(fieldList) if loopOrder is None: loopOrder = getOptimalLoopOrdering(fields) - shape = getCommonShape(fields) + shape = getCommonShape(list(fields)) if iterationSlice is not None: iterationSlice = normalizeSlice(iterationSlice, shape) @@ -88,6 +90,11 @@ def makeLoopOverDomain(body, functionName, iterationSlice=None, ghostLayers=None if isinstance(ghostLayers, int): ghostLayers = [(ghostLayers, ghostLayers)] * len(loopOrder) + def getLoopStride(begin, end, step): + return (end - begin) / step + + loopStrides = [] + loopVars = [] currentBody = body lastLoop = None for i, loopCoordinate in enumerate(reversed(loopOrder)): @@ -97,6 +104,8 @@ def makeLoopOverDomain(body, functionName, iterationSlice=None, ghostLayers=None newLoop = ast.LoopOverCoordinate(currentBody, loopCoordinate, begin, end, 1) lastLoop = newLoop currentBody = ast.Block([lastLoop]) + loopStrides.append(getLoopStride(begin, end, 1)) + loopVars.append(newLoop.loopCounterSymbol) else: sliceComponent = iterationSlice[loopCoordinate] if type(sliceComponent) is slice: @@ -104,11 +113,16 @@ def makeLoopOverDomain(body, functionName, iterationSlice=None, ghostLayers=None newLoop = ast.LoopOverCoordinate(currentBody, loopCoordinate, sc.start, sc.stop, sc.step) lastLoop = newLoop currentBody = ast.Block([lastLoop]) + loopStrides.append(getLoopStride(sc.start, sc.stop, sc.step)) + loopVars.append(newLoop.loopCounterSymbol) else: assignment = ast.SympyAssignment(ast.LoopOverCoordinate.getLoopCounterSymbol(loopCoordinate), sp.sympify(sliceComponent)) currentBody.insertFront(assignment) - return ast.KernelFunction(currentBody, ghostLayers=ghostLayers, functionName=functionName) + + loopVars = [numBufferAccesses * var for var in loopVars] + astNode = ast.KernelFunction(currentBody, ghostLayers=ghostLayers, functionName=functionName) + return (astNode, loopStrides, loopVars) def createIntermediateBasePointer(fieldAccess, coordinates, previousPtr): @@ -133,7 +147,6 @@ def createIntermediateBasePointer(fieldAccess, coordinates, previousPtr): (ptr_E_2S, x*fstride_myfield[0] + y*fstride_myfield[1] + fstride_myfield[0] - 2*fstride_myfield[1]) """ field = fieldAccess.field - offset = 0 name = "" listToHash = [] @@ -158,6 +171,7 @@ def createIntermediateBasePointer(fieldAccess, coordinates, previousPtr): name += "%0.6X" % (abs(hash(tuple(listToHash)))) newPtr = TypedSymbol(previousPtr.name + name, previousPtr.dtype) + return newPtr, offset @@ -223,6 +237,7 @@ def parseBasePointerInfo(basePointerSpecification, loopOrder, field): rest = allCoordinates - specifiedCoordinates if rest: result.append(list(rest)) + return result @@ -278,6 +293,52 @@ def substituteArrayAccessesWithConstants(astNode): for a in astNode.args: substituteArrayAccessesWithConstants(a) +def resolveBufferAccesses(astNode, baseBufferIndex, readOnlyFieldNames=set()): + def visitSympyExpr(expr, enclosingBlock, sympyAssignment): + if isinstance(expr, Field.Access): + fieldAccess = expr + + # Do not apply transformation if field is not a buffer + if not FieldType.isBuffer(fieldAccess.field): + return expr + + buffer = fieldAccess.field + + dtype = PointerType(buffer.dtype, const=buffer.name in readOnlyFieldNames, restrict=True) + fieldPtr = TypedSymbol("%s%s" % (Field.DATA_PREFIX, symbolNameToVariableName(buffer.name)), dtype) + + bufferIndex = baseBufferIndex + if len(fieldAccess.index) > 1: + raise RuntimeError('Only indexing dimensions up to 1 are currently supported in buffers!') + + if len(fieldAccess.index) > 0: + cellIndex = fieldAccess.index[0] + bufferIndex += cellIndex + + result = ast.ResolvedFieldAccess(fieldPtr, bufferIndex, fieldAccess.field, fieldAccess.offsets, + fieldAccess.index) + + return visitSympyExpr(result, enclosingBlock, sympyAssignment) + else: + if isinstance(expr, ast.ResolvedFieldAccess): + return expr + + newArgs = [visitSympyExpr(e, enclosingBlock, sympyAssignment) for e in expr.args] + kwargs = {'evaluate': False} if type(expr) in (sp.Add, sp.Mul, sp.Piecewise) else {} + return expr.func(*newArgs, **kwargs) if newArgs else expr + + def visitNode(subAst): + if isinstance(subAst, ast.SympyAssignment): + enclosingBlock = subAst.parent + assert type(enclosingBlock) is ast.Block + subAst.lhs = visitSympyExpr(subAst.lhs, enclosingBlock, subAst) + subAst.rhs = visitSympyExpr(subAst.rhs, enclosingBlock, subAst) + else: + for i, a in enumerate(subAst.args): + visitNode(a) + + return visitNode(astNode) + def resolveFieldAccesses(astNode, readOnlyFieldNames=set(), fieldToBasePointerInfo={}, fieldToFixedCoordinates={}): """ @@ -298,6 +359,7 @@ def resolveFieldAccesses(astNode, readOnlyFieldNames=set(), fieldToBasePointerIn if isinstance(expr, Field.Access): fieldAccess = expr field = fieldAccess.field + if field.name in fieldToBasePointerInfo: basePointerInfo = fieldToBasePointerInfo[field.name] else: @@ -324,6 +386,7 @@ def resolveFieldAccesses(astNode, readOnlyFieldNames=set(), fieldToBasePointerIn coordDict[e] = field.dtype.getElementOffset(accessedFieldName) else: coordDict[e] = fieldAccess.index[e - field.spatialDimensions] + return coordDict lastPointer = fieldPtr @@ -337,6 +400,7 @@ def resolveFieldAccesses(astNode, readOnlyFieldNames=set(), fieldToBasePointerIn lastPointer = newPtr coordDict = createCoordinateDict(basePointerInfo[0]) + _, offset = createIntermediateBasePointer(fieldAccess, coordDict, lastPointer) result = ast.ResolvedFieldAccess(lastPointer, offset, fieldAccess.field, fieldAccess.offsets, fieldAccess.index) @@ -344,6 +408,7 @@ def resolveFieldAccesses(astNode, readOnlyFieldNames=set(), fieldToBasePointerIn if isinstance(getBaseType(fieldAccess.field.dtype), StructType): newType = fieldAccess.field.dtype.getElementType(fieldAccess.index[0]) result = castFunc(result, newType) + return visitSympyExpr(result, enclosingBlock, sympyAssignment) else: if isinstance(expr, ast.ResolvedFieldAccess):