Commit e4ed4efc authored by Martin Bauer
Browse files

Improvements in vectorization to work also with split kernels

- activated vectorization for LBM kernels
parent 26cac6b4
......@@ -384,7 +384,7 @@ class SympyAssignment(Node):
self.rhs = rhsTerm
self._isDeclaration = True
isCast = self._lhsSymbol.func == castFunc
if isinstance(self._lhsSymbol, Field.Access) or isinstance(self._lhsSymbol, sp.Indexed) or isCast:
if isinstance(self._lhsSymbol, Field.Access) or isinstance(self._lhsSymbol, ResolvedFieldAccess) or isCast:
self._isDeclaration = False
self._isConst = isConst
......
......@@ -149,11 +149,11 @@ class CBackend(object):
return "%s = %s;" % (self.sympyPrinter.doprint(node.lhs), self.sympyPrinter.doprint(node.rhs))
# Print a TemporaryMemoryAllocation node as a C++ `new[]` statement.
# NOTE(review): this text is a rendered diff without +/- markers, so two return
# statements appear back to back. Presumably the first is the removed
# (pre-commit) version and the second the added (post-commit) version --
# confirm against the actual commit before relying on either.
def _print_TemporaryMemoryAllocation(self, node):
# Variant 1: the pointer star is part of the format string; the symbol itself
# is printed and node.symbol.dtype is used as the element type of `new`.
return "%s * %s = new %s[%s];" % (node.symbol.dtype, self.sympyPrinter.doprint(node.symbol),
node.symbol.dtype, self.sympyPrinter.doprint(node.size))
# Variant 2: no star in the format string; node.symbol.name is printed and
# node.symbol.dtype.baseType is used as the element type of `new`.
# NOTE(review): presumably node.symbol.dtype is a pointer type here whose
# string form already carries the `*` -- TODO confirm.
return "%s %s = new %s[%s];" % (node.symbol.dtype, self.sympyPrinter.doprint(node.symbol.name),
node.symbol.dtype.baseType, self.sympyPrinter.doprint(node.size))
# Print a TemporaryMemoryFree node as a C++ `delete []` statement.
# NOTE(review): rendered diff without +/- markers -- the first return is
# presumably the removed line and the second the added line; confirm.
def _print_TemporaryMemoryFree(self, node):
# Variant 1: passes the symbol object itself to the sympy printer.
return "delete [] %s;" % (self.sympyPrinter.doprint(node.symbol),)
# Variant 2: passes only the symbol's name to the sympy printer.
return "delete [] %s;" % (self.sympyPrinter.doprint(node.symbol.name),)
# Print a CustomCppCode node by returning its `code` attribute unchanged.
def _print_CustomCppCode(self, node):
return node.code
......
......@@ -472,14 +472,17 @@ def splitInnerLoop(astNode, symbolGroups):
for symbol in symbolGroup:
if type(symbol) is not Field.Access:
assert type(symbol) is TypedSymbol
symbolsWithTemporaryArray[symbol] = IndexedBase(symbol, shape=(1,))[innerLoop.loopCounterSymbol]
newTs = TypedSymbol(symbol.name, PointerType(symbol.dtype))
symbolsWithTemporaryArray[symbol] = IndexedBase(newTs, shape=(1,))[innerLoop.loopCounterSymbol]
assignmentGroup = []
for assignment in innerLoop.body.args:
if assignment.lhs in symbolsResolved:
newRhs = assignment.rhs.subs(symbolsWithTemporaryArray.items())
if type(assignment.lhs) is not Field.Access and assignment.lhs in symbolGroup:
newLhs = IndexedBase(assignment.lhs, shape=(1,))[innerLoop.loopCounterSymbol]
assert type(assignment.lhs) is TypedSymbol
newTs = TypedSymbol(assignment.lhs.name, PointerType(assignment.lhs.dtype))
newLhs = IndexedBase(newTs, shape=(1,))[innerLoop.loopCounterSymbol]
else:
newLhs = assignment.lhs
assignmentGroup.append(ast.SympyAssignment(newLhs, newRhs))
......@@ -489,8 +492,9 @@ def splitInnerLoop(astNode, symbolGroups):
innerLoop.parent.replace(innerLoop, ast.Block(newLoops))
for tmpArray in symbolsWithTemporaryArray:
outerLoop.parent.insertFront(ast.TemporaryMemoryAllocation(tmpArray, innerLoop.stop))
outerLoop.parent.append(ast.TemporaryMemoryFree(tmpArray))
tmpArrayPointer = TypedSymbol(tmpArray.name, PointerType(tmpArray.dtype))
outerLoop.parent.insertFront(ast.TemporaryMemoryAllocation(tmpArrayPointer, innerLoop.stop))
outerLoop.parent.append(ast.TemporaryMemoryFree(tmpArrayPointer))
def symbolNameToVariableName(symbolName):
......
......@@ -2,7 +2,8 @@ import sympy as sp
import warnings
from pystencils.transformations import filteredTreeIteration
from pystencils.data_types import TypedSymbol, VectorType, BasicType, getTypeOfExpression, castFunc, collateTypes
from pystencils.data_types import TypedSymbol, VectorType, BasicType, getTypeOfExpression, castFunc, collateTypes, \
PointerType
import pystencils.astnodes as ast
......@@ -31,11 +32,25 @@ def vectorizeInnerLoopsAndAdaptLoadStores(astNode, vectorWidth=4):
"of vectorization width can be vectorized")
continue
loopNode.step = vectorWidth
# Find all array accesses (indexed) that depend on the loop counter as offset
loopCounterSymbol = ast.LoopOverCoordinate.getLoopCounterSymbol(loopNode.coordinateToLoopOver)
substitutions = {}
successful = True
for indexed in loopNode.atoms(sp.Indexed):
base, index = indexed.args
if loopCounterSymbol in index.atoms(sp.Symbol):
loopCounterIsOffset = loopCounterSymbol not in (index - loopCounterSymbol).atoms()
if not loopCounterIsOffset:
successful = False
break
typedSymbol = base.label
assert type(typedSymbol.dtype) is PointerType, "Type of access is " + str(typedSymbol.dtype) + ", " + str(indexed)
substitutions[indexed] = castFunc(indexed, VectorType(typedSymbol.dtype.baseType, vectorWidth))
if not successful:
warnings.warn("Could not vectorize loop because of non-consecutive memory access")
continue
# All field accesses depending on loop coordinate are changed to vector type
fieldAccesses = [n for n in loopNode.atoms(ast.ResolvedFieldAccess)]
substitutions = {fa: castFunc(fa, VectorType(BasicType(fa.field.dtype), vectorWidth)) for fa in fieldAccesses}
loopNode.step = vectorWidth
loopNode.subs(substitutions)
......@@ -89,7 +104,3 @@ def insertVectorCasts(astNode):
if type(lhsType) is VectorType and type(rhsType) is not VectorType:
asmt.rhs = castFunc(asmt.rhs, lhsType)
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment