From fd2df06b83ef0a22f4fefc465344585738b78c0e Mon Sep 17 00:00:00 2001 From: Martin Bauer <martin.bauer@fau.de> Date: Tue, 26 Sep 2017 11:15:43 +0200 Subject: [PATCH] Different method to substitute shape&stride accesses by symbols - old method produced unused variables - old method was not deterministic in the output code i.e. ordering of the introduced constants - moveConstantsBeforeLoops transformation was also not deterministic --- astnodes.py | 69 +++++++++++++++-------------- cpu/kernelcreation.py | 5 ++- gpucuda/kernelcreation.py | 8 +++- transformations.py | 91 ++++++++++++++++++++++++++------------- 4 files changed, 107 insertions(+), 66 deletions(-) diff --git a/astnodes.py b/astnodes.py index fd8292c62..f5489b62f 100644 --- a/astnodes.py +++ b/astnodes.py @@ -12,6 +12,11 @@ class ResolvedFieldAccess(sp.Indexed): obj.idxCoordinateValues = idxCoordinateValues return obj + def _eval_subs(self, old, new): + return ResolvedFieldAccess(self.args[0], + self.args[1].subs(old, new), + self.field, self.offsets, self.idxCoordinateValues) + def _hashable_content(self): superClassContents = super(ResolvedFieldAccess, self)._hashable_content() return superClassContents + tuple(self.offsets) + (repr(self.idxCoordinateValues), hash(self.field)) @@ -40,6 +45,11 @@ class Node(object): """Symbols which are used but are not defined inside this node""" raise NotImplementedError() + def subs(self, *args, **kwargs): + """Inplace! 
substitute, similar to sympy's subs but modifies ast and returns None""" + for a in self.args: + a.subs(*args, **kwargs) + def atoms(self, argType): """ Returns a set of all children which are an instance of the given argType @@ -271,56 +281,45 @@ class LoopOverCoordinate(Node): LOOP_COUNTER_NAME_PREFIX = "ctr" def __init__(self, body, coordinateToLoopOver, start, stop, step=1): - self._body = body + self.body = body body.parent = self - self._coordinateToLoopOver = coordinateToLoopOver - self._begin = start - self._end = stop - self._increment = step - self._body.parent = self + self.coordinateToLoopOver = coordinateToLoopOver + self.start = start + self.stop = stop + self.step = step + self.body.parent = self self.prefixLines = [] def newLoopWithDifferentBody(self, newBody): - result = LoopOverCoordinate(newBody, self._coordinateToLoopOver, self._begin, self._end, self._increment) + result = LoopOverCoordinate(newBody, self.coordinateToLoopOver, self.start, self.stop, self.step) result.prefixLines = [l for l in self.prefixLines] return result + def subs(self, *args, **kwargs): + self.body.subs(*args, **kwargs) + if hasattr(self.start, "subs"): + self.start = self.start.subs(*args, **kwargs) + if hasattr(self.stop, "subs"): + self.stop = self.stop.subs(*args, **kwargs) + if hasattr(self.step, "subs"): + self.step = self.step.subs(*args, **kwargs) + @property def args(self): - result = [self._body] - for e in [self._begin, self._end, self._increment]: + result = [self.body] + for e in [self.start, self.stop, self.step]: if hasattr(e, "args"): result.append(e) return result - @property - def body(self): - return self._body - - @property - def start(self): - return self._begin - - @property - def stop(self): - return self._end - - @property - def step(self): - return self._increment - - @property - def coordinateToLoopOver(self): - return self._coordinateToLoopOver - @property def symbolsDefined(self): return set([self.loopCounterSymbol]) @property def 
undefinedSymbols(self): - result = self._body.undefinedSymbols - for possibleSymbol in [self._begin, self._end, self._increment]: + result = self.body.undefinedSymbols + for possibleSymbol in [self.start, self.stop, self.step]: if isinstance(possibleSymbol, Node) or isinstance(possibleSymbol, sp.Basic): result.update(possibleSymbol.atoms(sp.Symbol)) return result - set([self.loopCounterSymbol]) @@ -360,10 +359,6 @@ class LoopOverCoordinate(Node): def isInnermostLoop(self): return len(self.atoms(LoopOverCoordinate)) == 0 - @property - def coordinateToLoopOver(self): - return self._coordinateToLoopOver - def __str__(self): return 'loop:{!s} in {!s}:{!s}:{!s}\n{!s}'.format(self.loopCounterName, self.start, self.stop, self.step, ("\t" + "\t".join(str(self.body).splitlines(True)))) @@ -394,6 +389,10 @@ class SympyAssignment(Node): if isinstance(self._lhsSymbol, Field.Access) or isinstance(self._lhsSymbol, sp.Indexed) or isCast: self._isDeclaration = False + def subs(self, *args, **kwargs): + self.lhs = self.lhs.subs(*args, **kwargs) + self.rhs = self.rhs.subs(*args, **kwargs) + @property def args(self): return [self._lhsSymbol, self.rhs] diff --git a/cpu/kernelcreation.py b/cpu/kernelcreation.py index d466f2e91..18d382d72 100644 --- a/cpu/kernelcreation.py +++ b/cpu/kernelcreation.py @@ -2,7 +2,8 @@ import sympy as sp from pystencils.astnodes import SympyAssignment, Block, LoopOverCoordinate, KernelFunction from pystencils.transformations import resolveFieldAccesses, makeLoopOverDomain, \ - typeAllEquations, getOptimalLoopOrdering, parseBasePointerInfo, moveConstantsBeforeLoop, splitInnerLoop + typeAllEquations, getOptimalLoopOrdering, parseBasePointerInfo, moveConstantsBeforeLoop, splitInnerLoop, \ + substituteArrayAccessesWithConstants from pystencils.types import TypedSymbol, BasicType, StructType, createType from pystencils.field import Field import pystencils.astnodes as ast @@ -61,6 +62,7 @@ def createKernel(listOfEquations, functionName="kernel", 
typeForSymbol=None, spl basePointerInfos = {field.name: parseBasePointerInfo(basePointerInfo, loopOrder, field) for field in allFields} resolveFieldAccesses(code, readOnlyFields, fieldToBasePointerInfo=basePointerInfos) + substituteArrayAccessesWithConstants(code) moveConstantsBeforeLoop(code) return code @@ -122,6 +124,7 @@ def createIndexedKernel(listOfEquations, indexFields, functionName="kernel", typ fixedCoordinateMapping = {f.name: coordinateTypedSymbols for f in nonIndexFields} resolveFieldAccesses(ast, set(['indexField']), fieldToFixedCoordinates=fixedCoordinateMapping) + substituteArrayAccessesWithConstants(ast) moveConstantsBeforeLoop(ast) return ast diff --git a/gpucuda/kernelcreation.py b/gpucuda/kernelcreation.py index 84c00ba26..fd0c17564 100644 --- a/gpucuda/kernelcreation.py +++ b/gpucuda/kernelcreation.py @@ -1,5 +1,6 @@ from pystencils.gpucuda.indexing import BlockIndexing -from pystencils.transformations import resolveFieldAccesses, typeAllEquations, parseBasePointerInfo, getCommonShape +from pystencils.transformations import resolveFieldAccesses, typeAllEquations, parseBasePointerInfo, getCommonShape, \ + substituteArrayAccessesWithConstants from pystencils.astnodes import Block, KernelFunction, SympyAssignment, LoopOverCoordinate from pystencils.types import TypedSymbol, BasicType, StructType from pystencils import Field @@ -44,6 +45,9 @@ def createCUDAKernel(listOfEquations, functionName="kernel", typeForSymbol=None, coordMapping = {f.name: coordMapping for f in allFields} resolveFieldAccesses(ast, readOnlyFields, fieldToFixedCoordinates=coordMapping, fieldToBasePointerInfo=basePointerInfos) + + substituteArrayAccessesWithConstants(ast) + # add the function which determines #blocks and #threads as additional member to KernelFunction node # this is used by the jit @@ -102,6 +106,8 @@ def createdIndexedCUDAKernel(listOfEquations, indexFields, functionName="kernel" coordMapping.update({f.name: coordinateTypedSymbols for f in nonIndexFields}) 
resolveFieldAccesses(ast, readOnlyFields, fieldToFixedCoordinates=coordMapping, fieldToBasePointerInfo=basePointerInfos) + substituteArrayAccessesWithConstants(ast) + # add the function which determines #blocks and #threads as additional member to KernelFunction node # this is used by the jit ast.indexing = indexing diff --git a/transformations.py b/transformations.py index 74250d658..917a705b5 100644 --- a/transformations.py +++ b/transformations.py @@ -1,9 +1,10 @@ from collections import defaultdict, OrderedDict from operator import attrgetter +from copy import deepcopy import sympy as sp from sympy.logic.boolalg import Boolean -from sympy.tensor import IndexedBase, Indexed +from sympy.tensor import IndexedBase from pystencils.field import Field, offsetComponentToDirectionString from pystencils.types import TypedSymbol, createType, PointerType, StructType, getBaseType, createTypeFromString @@ -219,32 +220,57 @@ def parseBasePointerInfo(basePointerSpecification, loopOrder, field): return result -def substituteShapeAndStrideWithConstants(kernelFunctionAstNode): - fieldAccesses = kernelFunctionAstNode.atoms(Field.Access) - fields = {fa.field for fa in fieldAccesses} - fields = list(fields) - fields.sort(key=lambda f: f.name) +def substituteArrayAccessesWithConstants(astNode): + """Substitutes all instances of Indexed (array accesses) that are not field accesses with constants. + Benchmarks showed that using an array access as loop bound or in pointer computations causes some compilers to do + fewer optimizations. + This transformation should be after field accesses have been resolved (since they introduce array accesses) and + before constants are moved before the loops. 
+ """ - for field in fields: - if all(isinstance(e, Indexed) for e in field.shape): - newShape = [] - shapeDtype = getBaseType(createTypeFromString(Field.SHAPE_DTYPE)) - shapeDtype.const = False - for i, shape in enumerate(field.shape): - symbol = TypedSymbol("shapeConst_%s_%d" % (field.name, i), shapeDtype) - kernelFunctionAstNode.body.insertFront(ast.SympyAssignment(symbol, shape)) - newShape.append(symbol) - field.shape = tuple(newShape) - - if all(isinstance(e, Indexed) for e in field.strides): - newStrides = [] - strideDtype = getBaseType(createTypeFromString(Field.STRIDE_DTYPE)) - strideDtype.const = False - for i, stride in enumerate(field.strides): - symbol = TypedSymbol("strideConst_%s_%d" % (field.name, i), strideDtype) - kernelFunctionAstNode.body.insertFront(ast.SympyAssignment(symbol, stride)) - newStrides.append(symbol) - field.strides = tuple(newStrides) + def handleSympyExpression(expr, parentBlock): + """Returns sympy expression where array accesses have been replaced with constants, together with a list + of assignments that define these constants""" + if not isinstance(expr, sp.Expr): + return expr + + # get all indexed expressions that are not field accesses + indexedExprs = [e for e in expr.atoms(sp.Indexed) if not isinstance(e, ast.ResolvedFieldAccess)] + + # special case: right hand side is a single indexed expression, then nothing has to be done + if len(indexedExprs) == 1 and expr == indexedExprs[0]: + return expr + + constantsDefinitions = [] + constantSubstitutions = {} + for indexedExpr in indexedExprs: + base, idx = indexedExpr.args + typedSymbol = base.args[0] + baseType = deepcopy(getBaseType(typedSymbol.dtype)) + baseType.const = False + constantReplacingIndexed = TypedSymbol(typedSymbol.name + str(idx), baseType) + constantsDefinitions.append(ast.SympyAssignment(constantReplacingIndexed, indexedExpr)) + constantSubstitutions[indexedExpr] = constantReplacingIndexed + constantsDefinitions.sort(key=lambda e: e.lhs.name) + + 
alreadyDefined = parentBlock.symbolsDefined + for newAssignment in constantsDefinitions: + if newAssignment.lhs not in alreadyDefined: + parentBlock.insertBefore(newAssignment, astNode) + + return expr.subs(constantSubstitutions) + + if isinstance(astNode, ast.SympyAssignment): + astNode.rhs = handleSympyExpression(astNode.rhs, astNode.parent) + astNode.lhs = handleSympyExpression(astNode.lhs, astNode.parent) + elif isinstance(astNode, ast.LoopOverCoordinate): + astNode.start = handleSympyExpression(astNode.start, astNode.parent) + astNode.stop = handleSympyExpression(astNode.stop, astNode.parent) + astNode.step = handleSympyExpression(astNode.step, astNode.parent) + substituteArrayAccessesWithConstants(astNode.body) + else: + for a in astNode.args: + substituteArrayAccessesWithConstants(a) def resolveFieldAccesses(astNode, readOnlyFieldNames=set(), fieldToBasePointerInfo={}, fieldToFixedCoordinates={}): @@ -259,8 +285,6 @@ def resolveFieldAccesses(astNode, readOnlyFieldNames=set(), fieldToBasePointerIn counters to index the field these symbols are used as coordinates :return: transformed AST """ - substituteShapeAndStrideWithConstants(astNode) - fieldToBasePointerInfo = OrderedDict(sorted(fieldToBasePointerInfo.items(), key=lambda pair: pair[0])) fieldToFixedCoordinates = OrderedDict(sorted(fieldToFixedCoordinates.items(), key=lambda pair: pair[0])) @@ -375,7 +399,16 @@ def moveConstantsBeforeLoop(astNode): return arg return None - for block in astNode.atoms(ast.Block): + def getBlocks(node, resultList): + if isinstance(node, ast.Block): + resultList.insert(0, node) + if isinstance(node, ast.Node): + for a in node.args: + getBlocks(a, resultList) + + allBlocks = [] + getBlocks(astNode, allBlocks) + for block in allBlocks: children = block.takeChildNodes() for child in children: if not isinstance(child, ast.SympyAssignment): -- GitLab