diff --git a/transformations.py b/transformations.py index fb4683dda2a803e83c0ca2e0ab532713486949d2..2d4d1af24979c4590da3d920272dcff1a0e0431c 100644 --- a/transformations.py +++ b/transformations.py @@ -3,10 +3,10 @@ from operator import attrgetter import sympy as sp from sympy.logic.boolalg import Boolean -from sympy.tensor import IndexedBase +from sympy.tensor import IndexedBase, Indexed from pystencils.field import Field, offsetComponentToDirectionString -from pystencils.types import TypedSymbol, createType, PointerType, StructType, getBaseType +from pystencils.types import TypedSymbol, createType, PointerType, StructType, getBaseType, createTypeFromString from pystencils.slicing import normalizeSlice import pystencils.astnodes as ast @@ -219,6 +219,34 @@ def parseBasePointerInfo(basePointerSpecification, loopOrder, field): return result +def substituteShapeAndStrideWithConstants(kernelFunctionAstNode): + fieldAccesses = kernelFunctionAstNode.atoms(Field.Access) + fields = {fa.field for fa in fieldAccesses} + fields = list(fields) + fields.sort(key=lambda f: f.name) + + for field in fields: + if all(isinstance(e, Indexed) for e in field.shape): + newShape = [] + shapeDtype = getBaseType(createTypeFromString(Field.SHAPE_DTYPE)) + shapeDtype.const = False + for i, shape in enumerate(field.shape): + symbol = TypedSymbol("shapeConst_%s_%d" % (field.name, i), shapeDtype) + kernelFunctionAstNode.body.insertFront(ast.SympyAssignment(symbol, shape)) + newShape.append(symbol) + field.shape = tuple(newShape) + + if all(isinstance(e, Indexed) for e in field.strides): + newStrides = [] + strideDtype = getBaseType(createTypeFromString(Field.STRIDE_DTYPE)) + strideDtype.const = False + for i, stride in enumerate(field.strides): + symbol = TypedSymbol("strideConst_%s_%d" % (field.name, i), strideDtype) + kernelFunctionAstNode.body.insertFront(ast.SympyAssignment(symbol, stride)) + newStrides.append(symbol) + field.strides = tuple(newStrides) + + def resolveFieldAccesses(astNode, readOnlyFieldNames=set(), fieldToBasePointerInfo={}, fieldToFixedCoordinates={}): """ Substitutes :class:`pystencils.field.Field.Access` nodes by array indexing @@ -231,6 +259,7 @@ def resolveFieldAccesses(astNode, readOnlyFieldNames=set(), fieldToBasePointerIn counters to index the field these symbols are used as coordinates :return: transformed AST """ + substituteShapeAndStrideWithConstants(astNode) fieldToBasePointerInfo = OrderedDict(sorted(fieldToBasePointerInfo.items(), key=lambda pair: pair[0])) fieldToFixedCoordinates = OrderedDict(sorted(fieldToFixedCoordinates.items(), key=lambda pair: pair[0]))