From fc198a1751f8f27c2982b4bcaf2df6100c7e7466 Mon Sep 17 00:00:00 2001
From: Martin Bauer <martin.bauer@fau.de>
Date: Wed, 11 Oct 2017 16:06:03 +0200
Subject: [PATCH] Further vectorization tests & bugfixes

- phasefield phi sweep vectorizes successfully
---
 backends/cbackend.py | 19 +++++++++++++++++++
 data_types.py        | 15 ++++++++++++---
 vectorization.py     |  8 +++++++-
 3 files changed, 38 insertions(+), 4 deletions(-)

diff --git a/backends/cbackend.py b/backends/cbackend.py
index 79872b7ca..86f5597cc 100644
--- a/backends/cbackend.py
+++ b/backends/cbackend.py
@@ -259,6 +259,24 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
             processed = func.format(processed, summand.term)
         return processed
 
+    def _print_Pow(self, expr):
+        """Don't use std::pow function, for small integer exponents, write as multiplication"""
+        exprType = getTypeOfExpression(expr)
+        if type(exprType) is not VectorType:
+            return super(VectorizedCustomSympyPrinter, self)._print_Pow(expr)
+        assert self.instructionSet['width'] == exprType.width
+
+        if expr.exp.is_integer and expr.exp.is_number and 0 < expr.exp < 8:
+            return "(" + self._print(sp.Mul(*[expr.base] * expr.exp, evaluate=False)) + ")"
+        else:
+            if expr.exp == -1:
+                one = self.instructionSet['makeVec'].format(1.0)
+                return self.instructionSet['/'].format(one, self._print(expr.base))
+            elif expr.exp == 0.5:
+                return self.instructionSet['sqrt'].format(self._print(expr.base))
+            else:
+                raise ValueError("Generic exponential not supported")
+
     def _print_Mul(self, expr, insideAdd=False):
         exprType = getTypeOfExpression(expr)
         if type(exprType) is not VectorType:
@@ -286,6 +304,7 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
                 a.append(item)
 
         a = a or [S.One]
+        # a = a or [castFunc(S.One, VectorType(createTypeFromString("double"), exprType.width))]
 
         a_str = [self._print(x) for x in a]
         b_str = [self._print(x) for x in b]
diff --git a/data_types.py b/data_types.py
index 6f766b7aa..75bc673e4 100644
--- a/data_types.py
+++ b/data_types.py
@@ -10,7 +10,6 @@ from pystencils.utils import allEqual
 
 # to work in conditions of sp.Piecewise castFunc has to be of type Relational as well
 class castFunc(sp.Function, sp.Rel):
-
     @property
     def canonical(self):
         if hasattr(self.args[0], 'canonical'):
@@ -18,6 +17,10 @@ class castFunc(sp.Function, sp.Rel):
         else:
             raise NotImplementedError()
 
+    @property
+    def is_commutative(self):
+        return self.args[0].is_commutative
+
 
 class pointerArithmeticFunc(sp.Function, sp.Rel):
 
@@ -281,8 +284,11 @@ def getTypeOfExpression(expr):
     elif hasattr(expr, 'func') and expr.func == castFunc:
         return expr.args[1]
     elif hasattr(expr, 'func') and expr.func == sp.Piecewise:
-        branchResults = [a[0] for a in expr.args]
-        return collateTypes(tuple(getTypeOfExpression(a) for a in branchResults))
+        collatedResultType = collateTypes(tuple(getTypeOfExpression(a[0]) for a in expr.args))
+        collatedConditionType = collateTypes(tuple(getTypeOfExpression(a[1]) for a in expr.args))
+        if type(collatedConditionType) is VectorType and type(collatedResultType) is not VectorType:
+            collatedResultType = VectorType(collatedResultType, width=collatedConditionType.width)
+        return collatedResultType
     elif isinstance(expr, sp.Indexed):
         typedSymbol = expr.base.label
         return typedSymbol.dtype.baseType
@@ -328,6 +334,9 @@ class Type(sp.Basic):
     def _sympystr(self, *args, **kwargs):
         return str(self)
 
+    def _sympystr(self, *args, **kwargs):
+        return str(self)
+
 
 class BasicType(Type):
     @staticmethod
diff --git a/vectorization.py b/vectorization.py
index 979666d79..9cf113c5e 100644
--- a/vectorization.py
+++ b/vectorization.py
@@ -70,6 +70,9 @@ def insertVectorCasts(astNode):
                 castedArgs = [castFunc(a, targetType) if t != targetType else a
                               for a, t in zip(newArgs, argTypes)]
                 return expr.func(*castedArgs)
+        elif expr.func is sp.Pow:
+            newArg = visitExpr(expr.args[0])
+            return sp.Pow(newArg, expr.args[1])
         elif expr.func == sp.Piecewise:
             newResults = [visitExpr(a[0]) for a in expr.args]
             newConditions = [visitExpr(a[1]) for a in expr.args]
@@ -77,10 +80,13 @@ def insertVectorCasts(astNode):
             typesOfConditions = [getTypeOfExpression(a) for a in newConditions]
 
             resultTargetType = getTypeOfExpression(expr)
+            conditionTargetType = collateTypes(typesOfConditions)
+            if type(conditionTargetType) is VectorType and type(resultTargetType) is not VectorType:
+                resultTargetType = VectorType(resultTargetType, width=conditionTargetType.width)
+
             castedResults = [castFunc(a, resultTargetType) if t != resultTargetType else a
                              for a, t in zip(newResults, typesOfResults)]
 
-            conditionTargetType = collateTypes(typesOfConditions)
             castedConditions = [castFunc(a, conditionTargetType) if t != conditionTargetType and a != True else a
                                 for a, t in zip(newConditions, typesOfConditions)]
 
-- 
GitLab