From e4ed4efc779ff10aae5d3d09bac51f245b7f8549 Mon Sep 17 00:00:00 2001
From: Martin Bauer <martin.bauer@fau.de>
Date: Tue, 10 Oct 2017 11:09:04 +0200
Subject: [PATCH] Improvements in vectorization to work also with split kernels

- activated vectorization for LBM kernels
---
 astnodes.py          |  2 +-
 backends/cbackend.py |  6 +++---
 transformations.py   | 12 ++++++++----
 vectorization.py     | 29 ++++++++++++++++++++---------
 4 files changed, 32 insertions(+), 17 deletions(-)

diff --git a/astnodes.py b/astnodes.py
index 001be76b6..0fd41535c 100644
--- a/astnodes.py
+++ b/astnodes.py
@@ -384,7 +384,7 @@ class SympyAssignment(Node):
         self.rhs = rhsTerm
         self._isDeclaration = True
         isCast = self._lhsSymbol.func == castFunc
-        if isinstance(self._lhsSymbol, Field.Access) or isinstance(self._lhsSymbol, sp.Indexed) or isCast:
+        if isinstance(self._lhsSymbol, Field.Access) or isinstance(self._lhsSymbol, ResolvedFieldAccess) or isCast:
             self._isDeclaration = False
         self._isConst = isConst
 
diff --git a/backends/cbackend.py b/backends/cbackend.py
index 2086973dd..79872b7ca 100644
--- a/backends/cbackend.py
+++ b/backends/cbackend.py
@@ -149,11 +149,11 @@ class CBackend(object):
                 return "%s = %s;" % (self.sympyPrinter.doprint(node.lhs), self.sympyPrinter.doprint(node.rhs))
 
     def _print_TemporaryMemoryAllocation(self, node):
-        return "%s * %s = new %s[%s];" % (node.symbol.dtype, self.sympyPrinter.doprint(node.symbol),
-                                         node.symbol.dtype, self.sympyPrinter.doprint(node.size))
+        return "%s %s = new %s[%s];" % (node.symbol.dtype, self.sympyPrinter.doprint(node.symbol.name),
+                                        node.symbol.dtype.baseType, self.sympyPrinter.doprint(node.size))
 
     def _print_TemporaryMemoryFree(self, node):
-        return "delete [] %s;" % (self.sympyPrinter.doprint(node.symbol),)
+        return "delete [] %s;" % (self.sympyPrinter.doprint(node.symbol.name),)
 
     def _print_CustomCppCode(self, node):
         return node.code
diff --git a/transformations.py b/transformations.py
index b7ac891f5..5c1fe5698 100644
--- a/transformations.py
+++ b/transformations.py
@@ -472,14 +472,17 @@ def splitInnerLoop(astNode, symbolGroups):
         for symbol in symbolGroup:
             if type(symbol) is not Field.Access:
                 assert type(symbol) is TypedSymbol
-                symbolsWithTemporaryArray[symbol] = IndexedBase(symbol, shape=(1,))[innerLoop.loopCounterSymbol]
+                newTs = TypedSymbol(symbol.name, PointerType(symbol.dtype))
+                symbolsWithTemporaryArray[symbol] = IndexedBase(newTs, shape=(1,))[innerLoop.loopCounterSymbol]
 
         assignmentGroup = []
         for assignment in innerLoop.body.args:
             if assignment.lhs in symbolsResolved:
                 newRhs = assignment.rhs.subs(symbolsWithTemporaryArray.items())
                 if type(assignment.lhs) is not Field.Access and assignment.lhs in symbolGroup:
-                    newLhs = IndexedBase(assignment.lhs, shape=(1,))[innerLoop.loopCounterSymbol]
+                    assert type(assignment.lhs) is TypedSymbol
+                    newTs = TypedSymbol(assignment.lhs.name, PointerType(assignment.lhs.dtype))
+                    newLhs = IndexedBase(newTs, shape=(1,))[innerLoop.loopCounterSymbol]
                 else:
                     newLhs = assignment.lhs
                 assignmentGroup.append(ast.SympyAssignment(newLhs, newRhs))
@@ -489,8 +492,9 @@ def splitInnerLoop(astNode, symbolGroups):
     innerLoop.parent.replace(innerLoop, ast.Block(newLoops))
 
     for tmpArray in symbolsWithTemporaryArray:
-        outerLoop.parent.insertFront(ast.TemporaryMemoryAllocation(tmpArray, innerLoop.stop))
-        outerLoop.parent.append(ast.TemporaryMemoryFree(tmpArray))
+        tmpArrayPointer = TypedSymbol(tmpArray.name, PointerType(tmpArray.dtype))
+        outerLoop.parent.insertFront(ast.TemporaryMemoryAllocation(tmpArrayPointer, innerLoop.stop))
+        outerLoop.parent.append(ast.TemporaryMemoryFree(tmpArrayPointer))
 
 
 def symbolNameToVariableName(symbolName):
diff --git a/vectorization.py b/vectorization.py
index 3f9a359c9..8aa6a01db 100644
--- a/vectorization.py
+++ b/vectorization.py
@@ -2,7 +2,8 @@ import sympy as sp
 import warnings
 
 from pystencils.transformations import filteredTreeIteration
-from pystencils.data_types import TypedSymbol, VectorType, BasicType, getTypeOfExpression, castFunc, collateTypes
+from pystencils.data_types import TypedSymbol, VectorType, BasicType, getTypeOfExpression, castFunc, collateTypes, \
+    PointerType
 import pystencils.astnodes as ast
 
 
@@ -31,11 +32,25 @@ def vectorizeInnerLoopsAndAdaptLoadStores(astNode, vectorWidth=4):
                           "of vectorization width can be vectorized")
             continue
 
-        loopNode.step = vectorWidth
+        # Find all array accesses (indexed) that depend on the loop counter as offset
+        loopCounterSymbol = ast.LoopOverCoordinate.getLoopCounterSymbol(loopNode.coordinateToLoopOver)
+        substitutions = {}
+        successful = True
+        for indexed in loopNode.atoms(sp.Indexed):
+            base, index = indexed.args
+            if loopCounterSymbol in index.atoms(sp.Symbol):
+                loopCounterIsOffset = loopCounterSymbol not in (index - loopCounterSymbol).atoms()
+                if not loopCounterIsOffset:
+                    successful = False
+                    break
+                typedSymbol = base.label
+                assert type(typedSymbol.dtype) is PointerType, "Type of access is " + str(typedSymbol.dtype) + ", " + str(indexed)
+                substitutions[indexed] = castFunc(indexed, VectorType(typedSymbol.dtype.baseType, vectorWidth))
+        if not successful:
+            warnings.warn("Could not vectorize loop because of non-consecutive memory access")
+            continue
 
-        # All field accesses depending on loop coordinate are changed to vector type
-        fieldAccesses = [n for n in loopNode.atoms(ast.ResolvedFieldAccess)]
-        substitutions = {fa: castFunc(fa, VectorType(BasicType(fa.field.dtype), vectorWidth)) for fa in fieldAccesses}
+        loopNode.step = vectorWidth
         loopNode.subs(substitutions)
 
 
@@ -89,7 +104,3 @@ def insertVectorCasts(astNode):
             if type(lhsType) is VectorType and type(rhsType) is not VectorType:
                 asmt.rhs = castFunc(asmt.rhs, lhsType)
 
-
-
-
-
-- 
GitLab