From ffd7b240916ae034bfcef611777c6678369ab8e8 Mon Sep 17 00:00:00 2001
From: Martin Bauer <martin.bauer@fau.de>
Date: Thu, 12 Oct 2017 11:11:49 +0200
Subject: [PATCH] Vectorization bugfixes and improvements

- support for the logical operators And / Or
- both phase field kernels can be vectorized now
---
 backends/cbackend.py              | 80 +++++++++++++++++++------------
 backends/simd_instruction_sets.py |  3 +-
 data_types.py                     |  2 +-
 vectorization.py                  |  3 +-
 4 files changed, 55 insertions(+), 33 deletions(-)

diff --git a/backends/cbackend.py b/backends/cbackend.py
index 86f5597cc..f19dec05f 100644
--- a/backends/cbackend.py
+++ b/backends/cbackend.py
@@ -221,6 +221,14 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
         super(VectorizedCustomSympyPrinter, self).__init__(constantsAsFloats)
         self.instructionSet = instructionSet
 
+    def _scalarFallback(self, funcName, expr, *args, **kwargs):
+        exprType = getTypeOfExpression(expr)
+        if type(exprType) is not VectorType:
+            return getattr(super(VectorizedCustomSympyPrinter, self), funcName)(expr, *args, **kwargs)
+        else:
+            assert self.instructionSet['width'] == exprType.width
+            return None
+
     def _print_Function(self, expr):
         if expr.func == castFunc:
             arg, dtype = expr.args
@@ -232,11 +240,34 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
 
         return super(VectorizedCustomSympyPrinter, self)._print_Function(expr)
 
+    def _print_And(self, expr):
+        result = self._scalarFallback('_print_And', expr)
+        if result:
+            return result
+
+        argStrings = [self._print(a) for a in expr.args]
+        assert len(argStrings) > 0
+        result = argStrings[0]
+        for item in argStrings[1:]:
+            result = self.instructionSet['&'].format(result, item)
+        return result
+
+    def _print_Or(self, expr):
+        result = self._scalarFallback('_print_Or', expr)
+        if result:
+            return result
+
+        argStrings = [self._print(a) for a in expr.args]
+        assert len(argStrings) > 0
+        result = argStrings[0]
+        for item in argStrings[1:]:
+            result = self.instructionSet['|'].format(result, item)
+        return result
+
     def _print_Add(self, expr, order=None):
-        exprType = getTypeOfExpression(expr)
-        if type(exprType) is not VectorType:
-            return super(VectorizedCustomSympyPrinter, self)._print_Add(expr, order)
-        assert self.instructionSet['width'] == exprType.width
+        result = self._scalarFallback('_print_Add', expr)
+        if result:
+            return result
 
         summands = []
         for term in expr.args:
@@ -260,11 +291,9 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
         return processed
 
     def _print_Pow(self, expr):
-        """Don't use std::pow function, for small integer exponents, write as multiplication"""
-        exprType = getTypeOfExpression(expr)
-        if type(exprType) is not VectorType:
-            return super(VectorizedCustomSympyPrinter, self)._print_Pow(expr)
-        assert self.instructionSet['width'] == exprType.width
+        result = self._scalarFallback('_print_Pow', expr)
+        if result:
+            return result
 
         if expr.exp.is_integer and expr.exp.is_number and 0 < expr.exp < 8:
             return "(" + self._print(sp.Mul(*[expr.base] * expr.exp, evaluate=False)) + ")"
@@ -278,10 +307,9 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
                 raise ValueError("Generic exponential not supported")
 
     def _print_Mul(self, expr, insideAdd=False):
-        exprType = getTypeOfExpression(expr)
-        if type(exprType) is not VectorType:
-            return super(VectorizedCustomSympyPrinter, self)._print_Mul(expr)
-        assert self.instructionSet['width'] == exprType.width
+        result = self._scalarFallback('_print_Mul', expr)
+        if result:
+            return result
 
         c, e = expr.as_coeff_Mul()
         if c < 0:
@@ -328,26 +356,21 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
                 return result
 
     def _print_Relational(self, expr):
-        exprType = getTypeOfExpression(expr)
-        if type(exprType) is not VectorType:
-            return super(VectorizedCustomSympyPrinter, self)._print_Relational(expr)
-        assert self.instructionSet['width'] == exprType.width
-
+        result = self._scalarFallback('_print_Relational', expr)
+        if result:
+            return result
         return self.instructionSet[expr.rel_op].format(self._print(expr.lhs), self._print(expr.rhs))
 
     def _print_Equality(self, expr):
-        exprType = getTypeOfExpression(expr)
-        if type(exprType) is not VectorType:
-            return super(VectorizedCustomSympyPrinter, self)._print_Equality(expr)
-        assert self.instructionSet['width'] == exprType.width
-
+        result = self._scalarFallback('_print_Equality', expr)
+        if result:
+            return result
         return self.instructionSet['=='].format(self._print(expr.lhs), self._print(expr.rhs))
 
     def _print_Piecewise(self, expr):
-        exprType = getTypeOfExpression(expr)
-        if type(exprType) is not VectorType:
-            return super(VectorizedCustomSympyPrinter, self)._print_Piecewise(expr)
-        assert self.instructionSet['width'] == exprType.width
+        result = self._scalarFallback('_print_Piecewise', expr)
+        if result:
+            return result
 
         if expr.args[-1].cond != True:
             # We need the last conditional to be a True, otherwise the resulting
@@ -362,6 +385,3 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
         for trueExpr, condition in reversed(expr.args[:-1]):
             result = self.instructionSet['blendv'].format(result, self._print(trueExpr), self._print(condition))
         return result
-
-
-
diff --git a/backends/simd_instruction_sets.py b/backends/simd_instruction_sets.py
index 23109165b..6760837c8 100644
--- a/backends/simd_instruction_sets.py
+++ b/backends/simd_instruction_sets.py
@@ -13,7 +13,8 @@ def x86VectorInstructionSet(dataType='double', instructionSet='avx'):
         '<=': 'cmp[0, 1, _CMP_LE_OQ  ]',
         '<': 'cmp[0, 1, _CMP_NGE_UQ ]',
         '>': 'cmp[0, 1, _CMP_NLE_UQ ]',
-
+        '&': 'and[0, 1]',
+        '|': 'or[0, 1]',
         'blendv': 'blendv[0, 1, 2]',
 
         'sqrt': 'sqrt[0]',
diff --git a/data_types.py b/data_types.py
index 75bc673e4..d2a8a52b2 100644
--- a/data_types.py
+++ b/data_types.py
@@ -292,7 +292,7 @@ def getTypeOfExpression(expr):
     elif isinstance(expr, sp.Indexed):
         typedSymbol = expr.base.label
         return typedSymbol.dtype.baseType
-    elif isinstance(expr, sp.boolalg.Boolean):
+    elif isinstance(expr, sp.boolalg.Boolean) or isinstance(expr, sp.boolalg.BooleanFunction):
         # if any arg is of vector type return a vector boolean, else return a normal scalar boolean
         result = createTypeFromString("bool")
         vecArgs = [getTypeOfExpression(a) for a in expr.args if isinstance(getTypeOfExpression(a), VectorType)]
diff --git a/vectorization.py b/vectorization.py
index 9cf113c5e..291f8a96a 100644
--- a/vectorization.py
+++ b/vectorization.py
@@ -60,7 +60,8 @@ def insertVectorCasts(astNode):
     Inserts necessary casts from scalar values to vector values
     """
     def visitExpr(expr):
-        if expr.func in (sp.Add, sp.Mul) or (isinstance(expr, sp.Rel) and not expr.func == castFunc):
+        if expr.func in (sp.Add, sp.Mul) or (isinstance(expr, sp.Rel) and not expr.func == castFunc) or \
+                isinstance(expr, sp.boolalg.BooleanFunction):
             newArgs = [visitExpr(a) for a in expr.args]
             argTypes = [getTypeOfExpression(a) for a in newArgs]
             if not any(type(t) is VectorType for t in argTypes):
-- 
GitLab