From c0f31ce6c4faedec4d8a7cd74e256154370a2c39 Mon Sep 17 00:00:00 2001 From: Martin Bauer <martin.bauer@fau.de> Date: Wed, 20 Sep 2017 17:11:30 +0200 Subject: [PATCH] Storing shape & strides in constants, not reading from array -> reading from arrays was slower with some compilers --- transformations.py | 33 +++++++++++++++++++++++++++++++-- 1 file changed, 31 insertions(+), 2 deletions(-) diff --git a/transformations.py b/transformations.py index fb4683dda..2d4d1af24 100644 --- a/transformations.py +++ b/transformations.py @@ -3,10 +3,10 @@ from operator import attrgetter import sympy as sp from sympy.logic.boolalg import Boolean -from sympy.tensor import IndexedBase +from sympy.tensor import IndexedBase, Indexed from pystencils.field import Field, offsetComponentToDirectionString -from pystencils.types import TypedSymbol, createType, PointerType, StructType, getBaseType +from pystencils.types import TypedSymbol, createType, PointerType, StructType, getBaseType, createTypeFromString from pystencils.slicing import normalizeSlice import pystencils.astnodes as ast @@ -219,6 +219,34 @@ def parseBasePointerInfo(basePointerSpecification, loopOrder, field): return result +def substituteShapeAndStrideWithConstants(kernelFunctionAstNode): + fieldAccesses = kernelFunctionAstNode.atoms(Field.Access) + fields = {fa.field for fa in fieldAccesses} + fields = list(fields) + fields.sort(key=lambda f: f.name) + + for field in fields: + if all(isinstance(e, Indexed) for e in field.shape): + newShape = [] + shapeDtype = getBaseType(createTypeFromString(Field.SHAPE_DTYPE)) + shapeDtype.const = False + for i, shape in enumerate(field.shape): + symbol = TypedSymbol("shapeConst_%s_%d" % (field.name, i), shapeDtype) + kernelFunctionAstNode.body.insertFront(ast.SympyAssignment(symbol, shape)) + newShape.append(symbol) + field.shape = tuple(newShape) + + if all(isinstance(e, Indexed) for e in field.strides): + newStrides = [] + strideDtype = getBaseType(createTypeFromString(Field.STRIDE_DTYPE)) + strideDtype.const = False + for i, stride in enumerate(field.strides): + symbol = TypedSymbol("strideConst_%s_%d" % (field.name, i), strideDtype) + kernelFunctionAstNode.body.insertFront(ast.SympyAssignment(symbol, stride)) + newStrides.append(symbol) + field.strides = tuple(newStrides) + + def resolveFieldAccesses(astNode, readOnlyFieldNames=set(), fieldToBasePointerInfo={}, fieldToFixedCoordinates={}): """ Substitutes :class:`pystencils.field.Field.Access` nodes by array indexing @@ -231,6 +259,7 @@ def resolveFieldAccesses(astNode, readOnlyFieldNames=set(), fieldToBasePointerIn counters to index the field these symbols are used as coordinates :return: transformed AST """ + substituteShapeAndStrideWithConstants(astNode) fieldToBasePointerInfo = OrderedDict(sorted(fieldToBasePointerInfo.items(), key=lambda pair: pair[0])) fieldToFixedCoordinates = OrderedDict(sorted(fieldToFixedCoordinates.items(), key=lambda pair: pair[0])) -- GitLab