-
Martin Bauer authored
- support for logical operators, and/or - both phase field kernels can be vectorized now
ffd7b240
vectorization.py 5.28 KiB
import sympy as sp
import warnings
from pystencils.sympyextensions import fastSubs
from pystencils.transformations import filteredTreeIteration
from pystencils.data_types import TypedSymbol, VectorType, BasicType, getTypeOfExpression, castFunc, collateTypes, \
PointerType
import pystencils.astnodes as ast
def vectorize(astNode, vectorWidth=4):
    """
    Vectorize the innermost loops of ``astNode`` in place.

    Two passes are run in sequence: ``vectorizeInnerLoopsAndAdaptLoadStores`` widens the step of
    every qualifying innermost loop to ``vectorWidth`` and replaces its counter-dependent field
    accesses by vector-typed casts; ``insertVectorCasts`` then adds the casts needed wherever
    scalar and vector values are combined.

    :param astNode: root of the AST to transform (modified in place)
    :param vectorWidth: number of elements handled per iteration of a vectorized loop
    """
    vectorizeInnerLoopsAndAdaptLoadStores(astNode, vectorWidth=vectorWidth)
    insertVectorCasts(astNode)
def vectorizeInnerLoopsAndAdaptLoadStores(astNode, vectorWidth=4):
    """
    Vectorize innermost loops in place by widening their step and turning field accesses into vector accesses.

    An innermost ``ast.LoopOverCoordinate`` is vectorized only if
      - its range (``stop - start``) is a fixed integer,
      - that range is a multiple of ``vectorWidth``,
      - its step is 1, and
      - every indexed access that depends on the loop counter uses it as a pure offset, i.e. the
        index has the form ``counter + <expression not containing counter>``.
    Loops that do not qualify are left unchanged and a warning describing the reason is emitted.

    For a qualifying loop the step is set to ``vectorWidth`` and each counter-dependent access
    ``a[idx]`` is substituted by ``castFunc(a[idx], VectorType(<element type>, vectorWidth))``.

    :param astNode: root of the AST to transform (modified in place)
    :param vectorWidth: number of elements handled per iteration of a vectorized loop
    """
    innerLoops = [n for n in astNode.atoms(ast.LoopOverCoordinate) if n.isInnermostLoop]
    for loopNode in innerLoops:
        loopRange = loopNode.stop - loopNode.start

        # Check restrictions
        if isinstance(loopRange, sp.Basic) and not loopRange.is_integer:
            warnings.warn("Currently only loops with fixed ranges can be vectorized - skipping loop")
            continue
        if loopRange % vectorWidth != 0:
            warnings.warn("Currently only loops with loop bounds that are multiples "
                          "of vectorization width can be vectorized")
            continue
        if loopNode.step != 1:
            # Reported separately: a non-unit step is unrelated to the range being a multiple
            # of the vector width.
            warnings.warn("Currently only loops with step 1 can be vectorized - skipping loop")
            continue

        substitutions = _vectorAccessSubstitutions(loopNode, vectorWidth)
        if substitutions is None:
            warnings.warn("Could not vectorize loop because of non-consecutive memory access")
            continue

        loopNode.step = vectorWidth
        loopNode.subs(substitutions)


def _vectorAccessSubstitutions(loopNode, vectorWidth):
    """
    Map every loop-counter-dependent indexed access in ``loopNode`` to its vector-cast form.

    Returns ``None`` as soon as an access uses the loop counter other than as a pure offset
    (e.g. ``a[2*i]``), because such an access does not touch consecutive memory.
    """
    loopCounterSymbol = ast.LoopOverCoordinate.getLoopCounterSymbol(loopNode.coordinateToLoopOver)
    substitutions = {}
    for indexed in loopNode.atoms(sp.Indexed):
        base, index = indexed.args
        if loopCounterSymbol not in index.atoms(sp.Symbol):
            continue  # access does not move with the loop counter - stays scalar
        # index must be "counter + offset" where offset does not contain the counter
        if loopCounterSymbol in (index - loopCounterSymbol).atoms():
            return None
        typedSymbol = base.label
        assert type(typedSymbol.dtype) is PointerType, "Type of access is " + str(typedSymbol.dtype) + ", " + str(indexed)
        substitutions[indexed] = castFunc(indexed, VectorType(typedSymbol.dtype.baseType, vectorWidth))
    return substitutions
def insertVectorCasts(astNode):
    """
    Insert the casts required where scalar and vector values meet, updating assignments in place.

    The ``ast.SympyAssignment`` nodes below ``astNode`` are visited in the order produced by
    ``filteredTreeIteration``:

    - In sums, products, relations (other than ``castFunc``) and boolean functions that have at
      least one vector-typed operand, every operand whose type differs from the collated operand
      type is wrapped in ``castFunc``; purely scalar expressions are returned unchanged.
    - In powers only the base is processed; the exponent is kept as is.
    - In ``Piecewise`` expressions the results are cast to the type of the expression (widened to
      a vector type when the collated condition type is a vector type) and the conditions are cast
      to the collated condition type; a literal ``True`` condition is never cast.
    - When an assignment's right-hand side ends up vector-typed and its left-hand side is a scalar
      ``TypedSymbol``, the left-hand side becomes a vector-typed symbol of the same name and later
      right-hand sides are rewritten to use it (``ast.ResolvedFieldAccess`` nodes are skipped).
    - When the left-hand side is a ``castFunc`` to a vector type and the right-hand side is not
      vector-typed, the right-hand side is wrapped in that cast.

    :param astNode: root of the AST to process (modified in place)
    """

    def castUnlessMatching(term, termType, targetType):
        # only wrap the term if its type differs from the requested one
        return castFunc(term, targetType) if termType != targetType else term

    def unifyOperands(expr):
        operands = [visitExpr(arg) for arg in expr.args]
        operandTypes = [getTypeOfExpression(arg) for arg in operands]
        if all(type(t) is not VectorType for t in operandTypes):
            return expr  # nothing vector-typed involved: keep the original expression
        commonType = collateTypes(operandTypes)
        castedOperands = [castUnlessMatching(op, t, commonType) for op, t in zip(operands, operandTypes)]
        return expr.func(*castedOperands)

    def unifyPiecewise(expr):
        results = [visitExpr(pair[0]) for pair in expr.args]
        conditions = [visitExpr(pair[1]) for pair in expr.args]
        resultTypes = [getTypeOfExpression(r) for r in results]
        conditionTypes = [getTypeOfExpression(c) for c in conditions]
        resultTarget = getTypeOfExpression(expr)
        conditionTarget = collateTypes(conditionTypes)
        # a vector-typed condition forces the results to the vector type of the same width
        if type(conditionTarget) is VectorType and type(resultTarget) is not VectorType:
            resultTarget = VectorType(resultTarget, width=conditionTarget.width)
        castedResults = [castUnlessMatching(r, t, resultTarget) for r, t in zip(results, resultTypes)]
        castedConditions = [castFunc(c, conditionTarget) if t != conditionTarget and c != True else c
                            for c, t in zip(conditions, conditionTypes)]
        return sp.Piecewise(*zip(castedResults, castedConditions))

    def visitExpr(expr):
        isRelation = isinstance(expr, sp.Rel) and not expr.func == castFunc
        if expr.func in (sp.Add, sp.Mul) or isRelation or isinstance(expr, sp.boolalg.BooleanFunction):
            return unifyOperands(expr)
        if expr.func is sp.Pow:
            return sp.Pow(visitExpr(expr.args[0]), expr.args[1])
        if expr.func == sp.Piecewise:
            return unifyPiecewise(expr)
        return expr

    promotedSymbols = {}
    for assignment in filteredTreeIteration(astNode, ast.SympyAssignment):
        # use already-promoted vector symbols in this right-hand side before casting
        rewrittenRhs = fastSubs(assignment.rhs, promotedSymbols,
                                skip=lambda e: isinstance(e, ast.ResolvedFieldAccess))
        assignment.rhs = visitExpr(rewrittenRhs)
        rhsType = getTypeOfExpression(assignment.rhs)
        target = assignment.lhs

        if isinstance(target, TypedSymbol):
            scalarType = target.dtype
            if type(rhsType) is VectorType and type(scalarType) is not VectorType:
                vectorTarget = TypedSymbol(target.name, VectorType(scalarType, rhsType.width))
                promotedSymbols[target] = vectorTarget
                assignment.lhs = vectorTarget
        elif target.func == castFunc:
            targetType = target.args[1]
            if type(targetType) is VectorType and type(rhsType) is not VectorType:
                assignment.rhs = castFunc(assignment.rhs, targetType)