From e4ed4efc779ff10aae5d3d09bac51f245b7f8549 Mon Sep 17 00:00:00 2001 From: Martin Bauer <martin.bauer@fau.de> Date: Tue, 10 Oct 2017 11:09:04 +0200 Subject: [PATCH] Improvements in vectorization to work also with split kernels - activated vectorization for LBM kernels --- astnodes.py | 2 +- backends/cbackend.py | 6 +++--- transformations.py | 12 ++++++++---- vectorization.py | 29 ++++++++++++++++++++--------- 4 files changed, 32 insertions(+), 17 deletions(-) diff --git a/astnodes.py b/astnodes.py index 001be76b6..0fd41535c 100644 --- a/astnodes.py +++ b/astnodes.py @@ -384,7 +384,7 @@ class SympyAssignment(Node): self.rhs = rhsTerm self._isDeclaration = True isCast = self._lhsSymbol.func == castFunc - if isinstance(self._lhsSymbol, Field.Access) or isinstance(self._lhsSymbol, sp.Indexed) or isCast: + if isinstance(self._lhsSymbol, Field.Access) or isinstance(self._lhsSymbol, ResolvedFieldAccess) or isCast: self._isDeclaration = False self._isConst = isConst diff --git a/backends/cbackend.py b/backends/cbackend.py index 2086973dd..79872b7ca 100644 --- a/backends/cbackend.py +++ b/backends/cbackend.py @@ -149,11 +149,11 @@ class CBackend(object): return "%s = %s;" % (self.sympyPrinter.doprint(node.lhs), self.sympyPrinter.doprint(node.rhs)) def _print_TemporaryMemoryAllocation(self, node): - return "%s * %s = new %s[%s];" % (node.symbol.dtype, self.sympyPrinter.doprint(node.symbol), - node.symbol.dtype, self.sympyPrinter.doprint(node.size)) + return "%s %s = new %s[%s];" % (node.symbol.dtype, self.sympyPrinter.doprint(node.symbol.name), + node.symbol.dtype.baseType, self.sympyPrinter.doprint(node.size)) def _print_TemporaryMemoryFree(self, node): - return "delete [] %s;" % (self.sympyPrinter.doprint(node.symbol),) + return "delete [] %s;" % (self.sympyPrinter.doprint(node.symbol.name),) def _print_CustomCppCode(self, node): return node.code diff --git a/transformations.py b/transformations.py index b7ac891f5..5c1fe5698 100644 --- a/transformations.py +++ b/transformations.py @@ -472,14 +472,17 @@ def splitInnerLoop(astNode, symbolGroups): for symbol in symbolGroup: if type(symbol) is not Field.Access: assert type(symbol) is TypedSymbol - symbolsWithTemporaryArray[symbol] = IndexedBase(symbol, shape=(1,))[innerLoop.loopCounterSymbol] + newTs = TypedSymbol(symbol.name, PointerType(symbol.dtype)) + symbolsWithTemporaryArray[symbol] = IndexedBase(newTs, shape=(1,))[innerLoop.loopCounterSymbol] assignmentGroup = [] for assignment in innerLoop.body.args: if assignment.lhs in symbolsResolved: newRhs = assignment.rhs.subs(symbolsWithTemporaryArray.items()) if type(assignment.lhs) is not Field.Access and assignment.lhs in symbolGroup: - newLhs = IndexedBase(assignment.lhs, shape=(1,))[innerLoop.loopCounterSymbol] + assert type(assignment.lhs) is TypedSymbol + newTs = TypedSymbol(assignment.lhs.name, PointerType(assignment.lhs.dtype)) + newLhs = IndexedBase(newTs, shape=(1,))[innerLoop.loopCounterSymbol] else: newLhs = assignment.lhs assignmentGroup.append(ast.SympyAssignment(newLhs, newRhs)) @@ -489,8 +492,9 @@ def splitInnerLoop(astNode, symbolGroups): innerLoop.parent.replace(innerLoop, ast.Block(newLoops)) for tmpArray in symbolsWithTemporaryArray: - outerLoop.parent.insertFront(ast.TemporaryMemoryAllocation(tmpArray, innerLoop.stop)) - outerLoop.parent.append(ast.TemporaryMemoryFree(tmpArray)) + tmpArrayPointer = TypedSymbol(tmpArray.name, PointerType(tmpArray.dtype)) + outerLoop.parent.insertFront(ast.TemporaryMemoryAllocation(tmpArrayPointer, innerLoop.stop)) + outerLoop.parent.append(ast.TemporaryMemoryFree(tmpArrayPointer)) def symbolNameToVariableName(symbolName): diff --git a/vectorization.py b/vectorization.py index 3f9a359c9..8aa6a01db 100644 --- a/vectorization.py +++ b/vectorization.py @@ -2,7 +2,8 @@ import sympy as sp import warnings from pystencils.transformations import filteredTreeIteration -from pystencils.data_types import TypedSymbol, VectorType, BasicType, getTypeOfExpression, castFunc, collateTypes +from pystencils.data_types import TypedSymbol, VectorType, BasicType, getTypeOfExpression, castFunc, collateTypes, \ + PointerType import pystencils.astnodes as ast @@ -31,11 +32,25 @@ def vectorizeInnerLoopsAndAdaptLoadStores(astNode, vectorWidth=4): "of vectorization width can be vectorized") continue - loopNode.step = vectorWidth + # Find all array accesses (indexed) that depend on the loop counter as offset + loopCounterSymbol = ast.LoopOverCoordinate.getLoopCounterSymbol(loopNode.coordinateToLoopOver) + substitutions = {} + successful = True + for indexed in loopNode.atoms(sp.Indexed): + base, index = indexed.args + if loopCounterSymbol in index.atoms(sp.Symbol): + loopCounterIsOffset = loopCounterSymbol not in (index - loopCounterSymbol).atoms() + if not loopCounterIsOffset: + successful = False + break + typedSymbol = base.label + assert type(typedSymbol.dtype) is PointerType, "Type of access is " + str(typedSymbol.dtype) + ", " + str(indexed) + substitutions[indexed] = castFunc(indexed, VectorType(typedSymbol.dtype.baseType, vectorWidth)) + if not successful: + warnings.warn("Could not vectorize loop because of non-consecutive memory access") + continue - # All field accesses depending on loop coordinate are changed to vector type - fieldAccesses = [n for n in loopNode.atoms(ast.ResolvedFieldAccess)] - substitutions = {fa: castFunc(fa, VectorType(BasicType(fa.field.dtype), vectorWidth)) for fa in fieldAccesses} + loopNode.step = vectorWidth loopNode.subs(substitutions) @@ -89,7 +104,3 @@ def insertVectorCasts(astNode): if type(lhsType) is VectorType and type(rhsType) is not VectorType: asmt.rhs = castFunc(asmt.rhs, lhsType) - - - - -- GitLab