Commit fd2df06b authored by Martin Bauer's avatar Martin Bauer
Browse files

Different method to substitute shape&stride accesses by symbols

- old method produced unused variables
- old method was not deterministic in the output code i.e. ordering of
  the introduced constants
- moveConstantsBeforeLoops transformation was also not deterministic
parent 198fd763
......@@ -12,6 +12,11 @@ class ResolvedFieldAccess(sp.Indexed):
obj.idxCoordinateValues = idxCoordinateValues
return obj
def _eval_subs(self, old, new):
return ResolvedFieldAccess(self.args[0],
self.args[1].subs(old, new),
self.field, self.offsets, self.idxCoordinateValues)
def _hashable_content(self):
superClassContents = super(ResolvedFieldAccess, self)._hashable_content()
return superClassContents + tuple(self.offsets) + (repr(self.idxCoordinateValues), hash(self.field))
......@@ -40,6 +45,11 @@ class Node(object):
"""Symbols which are used but are not defined inside this node"""
raise NotImplementedError()
def subs(self, *args, **kwargs):
"""Inplace! substitute, similar to sympys but modifies ast and returns None"""
for a in self.args:
a.subs(*args, **kwargs)
def atoms(self, argType):
Returns a set of all children which are an instance of the given argType
......@@ -271,56 +281,45 @@ class LoopOverCoordinate(Node):
def __init__(self, body, coordinateToLoopOver, start, stop, step=1):
self._body = body
self.body = body
body.parent = self
self._coordinateToLoopOver = coordinateToLoopOver
self._begin = start
self._end = stop
self._increment = step
self._body.parent = self
self.coordinateToLoopOver = coordinateToLoopOver
self.start = start
self.stop = stop
self.step = step
self.body.parent = self
self.prefixLines = []
def newLoopWithDifferentBody(self, newBody):
result = LoopOverCoordinate(newBody, self._coordinateToLoopOver, self._begin, self._end, self._increment)
result = LoopOverCoordinate(newBody, self.coordinateToLoopOver, self.start, self.stop, self.step)
result.prefixLines = [l for l in self.prefixLines]
return result
def subs(self, *args, **kwargs):
self.body.subs(*args, **kwargs)
if hasattr(self.start, "subs"):
self.start = self.start.subs(*args, **kwargs)
if hasattr(self.stop, "subs"):
self.stop = self.stop.subs(*args, **kwargs)
if hasattr(self.step, "subs"):
self.step = self.step.subs(*args, **kwargs)
def args(self):
result = [self._body]
for e in [self._begin, self._end, self._increment]:
result = [self.body]
for e in [self.start, self.stop, self.step]:
if hasattr(e, "args"):
return result
def body(self):
return self._body
def start(self):
return self._begin
def stop(self):
return self._end
def step(self):
return self._increment
def coordinateToLoopOver(self):
return self._coordinateToLoopOver
def symbolsDefined(self):
return set([self.loopCounterSymbol])
def undefinedSymbols(self):
result = self._body.undefinedSymbols
for possibleSymbol in [self._begin, self._end, self._increment]:
result = self.body.undefinedSymbols
for possibleSymbol in [self.start, self.stop, self.step]:
if isinstance(possibleSymbol, Node) or isinstance(possibleSymbol, sp.Basic):
return result - set([self.loopCounterSymbol])
......@@ -360,10 +359,6 @@ class LoopOverCoordinate(Node):
def isInnermostLoop(self):
return len(self.atoms(LoopOverCoordinate)) == 0
def coordinateToLoopOver(self):
return self._coordinateToLoopOver
def __str__(self):
return 'loop:{!s} in {!s}:{!s}:{!s}\n{!s}'.format(self.loopCounterName, self.start, self.stop, self.step,
("\t" + "\t".join(str(self.body).splitlines(True))))
......@@ -394,6 +389,10 @@ class SympyAssignment(Node):
if isinstance(self._lhsSymbol, Field.Access) or isinstance(self._lhsSymbol, sp.Indexed) or isCast:
self._isDeclaration = False
def subs(self, *args, **kwargs):
self.lhs = self.lhs.subs(*args, **kwargs)
self.rhs = self.rhs.subs(*args, **kwargs)
def args(self):
return [self._lhsSymbol, self.rhs]
......@@ -2,7 +2,8 @@ import sympy as sp
from pystencils.astnodes import SympyAssignment, Block, LoopOverCoordinate, KernelFunction
from pystencils.transformations import resolveFieldAccesses, makeLoopOverDomain, \
typeAllEquations, getOptimalLoopOrdering, parseBasePointerInfo, moveConstantsBeforeLoop, splitInnerLoop
typeAllEquations, getOptimalLoopOrdering, parseBasePointerInfo, moveConstantsBeforeLoop, splitInnerLoop, \
from pystencils.types import TypedSymbol, BasicType, StructType, createType
from pystencils.field import Field
import pystencils.astnodes as ast
......@@ -61,6 +62,7 @@ def createKernel(listOfEquations, functionName="kernel", typeForSymbol=None, spl
basePointerInfos = { parseBasePointerInfo(basePointerInfo, loopOrder, field) for field in allFields}
resolveFieldAccesses(code, readOnlyFields, fieldToBasePointerInfo=basePointerInfos)
return code
......@@ -122,6 +124,7 @@ def createIndexedKernel(listOfEquations, indexFields, functionName="kernel", typ
fixedCoordinateMapping = { coordinateTypedSymbols for f in nonIndexFields}
resolveFieldAccesses(ast, set(['indexField']), fieldToFixedCoordinates=fixedCoordinateMapping)
return ast
from pystencils.gpucuda.indexing import BlockIndexing
from pystencils.transformations import resolveFieldAccesses, typeAllEquations, parseBasePointerInfo, getCommonShape
from pystencils.transformations import resolveFieldAccesses, typeAllEquations, parseBasePointerInfo, getCommonShape, \
from pystencils.astnodes import Block, KernelFunction, SympyAssignment, LoopOverCoordinate
from pystencils.types import TypedSymbol, BasicType, StructType
from pystencils import Field
......@@ -44,6 +45,9 @@ def createCUDAKernel(listOfEquations, functionName="kernel", typeForSymbol=None,
coordMapping = { coordMapping for f in allFields}
resolveFieldAccesses(ast, readOnlyFields, fieldToFixedCoordinates=coordMapping,
# add the function which determines #blocks and #threads as additional member to KernelFunction node
# this is used by the jit
......@@ -102,6 +106,8 @@ def createdIndexedCUDAKernel(listOfEquations, indexFields, functionName="kernel"
coordMapping.update({ coordinateTypedSymbols for f in nonIndexFields})
resolveFieldAccesses(ast, readOnlyFields, fieldToFixedCoordinates=coordMapping,
# add the function which determines #blocks and #threads as additional member to KernelFunction node
# this is used by the jit
ast.indexing = indexing
from collections import defaultdict, OrderedDict
from operator import attrgetter
from copy import deepcopy
import sympy as sp
from sympy.logic.boolalg import Boolean
from sympy.tensor import IndexedBase, Indexed
from sympy.tensor import IndexedBase
from pystencils.field import Field, offsetComponentToDirectionString
from pystencils.types import TypedSymbol, createType, PointerType, StructType, getBaseType, createTypeFromString
......@@ -219,32 +220,57 @@ def parseBasePointerInfo(basePointerSpecification, loopOrder, field):
return result
def substituteShapeAndStrideWithConstants(kernelFunctionAstNode):
fieldAccesses = kernelFunctionAstNode.atoms(Field.Access)
fields = {fa.field for fa in fieldAccesses}
fields = list(fields)
fields.sort(key=lambda f:
def substituteArrayAccessesWithConstants(astNode):
"""Substitutes all instances of Indexed (array acceses) that are not field accesses with constants.
Benchmarks showed that using an array access as loop bound or in pointer computations cause some compilers to do
less optimizations.
This transformation should be after field accesses have been resolved (since they introduce array accesses) and
before constants are moved before the loops.
for field in fields:
if all(isinstance(e, Indexed) for e in field.shape):
newShape = []
shapeDtype = getBaseType(createTypeFromString(Field.SHAPE_DTYPE))
shapeDtype.const = False
for i, shape in enumerate(field.shape):
symbol = TypedSymbol("shapeConst_%s_%d" % (, i), shapeDtype)
kernelFunctionAstNode.body.insertFront(ast.SympyAssignment(symbol, shape))
field.shape = tuple(newShape)
if all(isinstance(e, Indexed) for e in field.strides):
newStrides = []
strideDtype = getBaseType(createTypeFromString(Field.STRIDE_DTYPE))
strideDtype.const = False
for i, stride in enumerate(field.strides):
symbol = TypedSymbol("strideConst_%s_%d" % (, i), strideDtype)
kernelFunctionAstNode.body.insertFront(ast.SympyAssignment(symbol, stride))
field.strides = tuple(newStrides)
def handleSympyExpression(expr, parentBlock):
"""Returns sympy expression where array accesses have been replaced with constants, together with a list
of assignments that define these constants"""
if not isinstance(expr, sp.Expr):
return expr
# get all indexed expressions that are not field accesses
indexedExprs = [e for e in expr.atoms(sp.Indexed) if not isinstance(e, ast.ResolvedFieldAccess)]
# special case: right hand side is a single indexed expression, then nothing has to be done
if len(indexedExprs) == 1 and expr == indexedExprs[0]:
return expr
constantsDefinitions = []
constantSubstitutions = {}
for indexedExpr in indexedExprs:
base, idx = indexedExpr.args
typedSymbol = base.args[0]
baseType = deepcopy(getBaseType(typedSymbol.dtype))
baseType.const = False
constantReplacingIndexed = TypedSymbol( + str(idx), baseType)
constantsDefinitions.append(ast.SympyAssignment(constantReplacingIndexed, indexedExpr))
constantSubstitutions[indexedExpr] = constantReplacingIndexed
constantsDefinitions.sort(key=lambda e:
alreadyDefined = parentBlock.symbolsDefined
for newAssignment in constantsDefinitions:
if newAssignment.lhs not in alreadyDefined:
parentBlock.insertBefore(newAssignment, astNode)
return expr.subs(constantSubstitutions)
if isinstance(astNode, ast.SympyAssignment):
astNode.rhs = handleSympyExpression(astNode.rhs, astNode.parent)
astNode.lhs = handleSympyExpression(astNode.lhs, astNode.parent)
elif isinstance(astNode, ast.LoopOverCoordinate):
astNode.start = handleSympyExpression(astNode.start, astNode.parent)
astNode.stop = handleSympyExpression(astNode.stop, astNode.parent)
astNode.step = handleSympyExpression(astNode.step, astNode.parent)
for a in astNode.args:
def resolveFieldAccesses(astNode, readOnlyFieldNames=set(), fieldToBasePointerInfo={}, fieldToFixedCoordinates={}):
......@@ -259,8 +285,6 @@ def resolveFieldAccesses(astNode, readOnlyFieldNames=set(), fieldToBasePointerIn
counters to index the field these symbols are used as coordinates
:return: transformed AST
fieldToBasePointerInfo = OrderedDict(sorted(fieldToBasePointerInfo.items(), key=lambda pair: pair[0]))
fieldToFixedCoordinates = OrderedDict(sorted(fieldToFixedCoordinates.items(), key=lambda pair: pair[0]))
......@@ -375,7 +399,16 @@ def moveConstantsBeforeLoop(astNode):
return arg
return None
for block in astNode.atoms(ast.Block):
def getBlocks(node, resultList):
if isinstance(node, ast.Block):
resultList.insert(0, node)
if isinstance(node, ast.Node):
for a in node.args:
getBlocks(a, resultList)
allBlocks = []
getBlocks(astNode, allBlocks)
for block in allBlocks:
children = block.takeChildNodes()
for child in children:
if not isinstance(child, ast.SympyAssignment):
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment