From 2e015cf51e822cb419a4f4e4faa8a8a4b0ae9e0d Mon Sep 17 00:00:00 2001
From: Martin Bauer <martin.bauer@fau.de>
Date: Mon, 9 Oct 2017 14:47:45 +0200
Subject: [PATCH] Vectorization: Piecewise/blend support

- sympy's piecewise-defined functions are mapped to blend instructions
- the cast function is now a class
- several bug fixes
---
 astnodes.py                       |  8 ++--
 backends/cbackend.py              | 75 +++++++++++++++++++------------
 backends/simd_instruction_sets.py |  3 +-
 cpu/cpujit.py                     |  2 +-
 types.py                          | 22 +++++++--
 vectorization.py                  | 43 +++++++++---------
 6 files changed, 94 insertions(+), 59 deletions(-)

diff --git a/astnodes.py b/astnodes.py
index 16c1e9829..cdaf8fbc2 100644
--- a/astnodes.py
+++ b/astnodes.py
@@ -1,7 +1,7 @@
 import sympy as sp
 from sympy.tensor import IndexedBase
 from pystencils.field import Field
-from pystencils.types import TypedSymbol, createType, get_type_from_sympy, createTypeFromString
+from pystencils.types import TypedSymbol, createType, get_type_from_sympy, createTypeFromString, castFunc
 
 
 class ResolvedFieldAccess(sp.Indexed):
@@ -383,8 +383,8 @@ class SympyAssignment(Node):
         self._lhsSymbol = lhsSymbol
         self.rhs = rhsTerm
         self._isDeclaration = True
-        isCast = str(self._lhsSymbol.func).lower() == 'cast' if hasattr(self._lhsSymbol, "func") else False
-        if isinstance(self._lhsSymbol, Field.Access) or isinstance(self._lhsSymbol, IndexedBase) or isCast:
+        isCast = self._lhsSymbol.func == castFunc
+        if isinstance(self._lhsSymbol, Field.Access) or isinstance(self._lhsSymbol, sp.Indexed) or isCast:
             self._isDeclaration = False
         self._isConst = isConst
 
@@ -396,7 +396,7 @@ class SympyAssignment(Node):
     def lhs(self, newValue):
         self._lhsSymbol = newValue
         self._isDeclaration = True
-        isCast = str(self._lhsSymbol.func).lower() == 'cast' if hasattr(self._lhsSymbol, "func") else False
+        isCast = self._lhsSymbol.func == castFunc
         if isinstance(self._lhsSymbol, Field.Access) or isinstance(self._lhsSymbol, sp.Indexed) or isCast:
             self._isDeclaration = False
 
diff --git a/backends/cbackend.py b/backends/cbackend.py
index 5bbb9b787..98d4b9c23 100644
--- a/backends/cbackend.py
+++ b/backends/cbackend.py
@@ -39,6 +39,8 @@ def getHeaders(astNode):
             headers.update(getHeaders(a))
 
     return headers
+
+
 # --------------------------------------- Backend Specific Nodes -------------------------------------------------------
 
 
@@ -75,6 +77,7 @@ class PrintNode(CustomCppCode):
 
 # ------------------------------------------- Printer ------------------------------------------------------------------
 
+
 class CBackend(object):
 
     def __init__(self, constantsAsFloats=False, sympyPrinter=None, signatureOnly=False, vectorInstructionSet=None):
@@ -204,8 +207,7 @@ class CustomSympyPrinter(CCodePrinter):
         return res
 
     def _print_Function(self, expr):
-        name = str(expr.func).lower()
-        if name == 'cast':
+        if expr.func == castFunc:
             arg, type = expr.args
             return "*((%s)(& %s))" % (PointerType(type), self._print(arg))
         else:
@@ -220,8 +222,7 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
         self.instructionSet = instructionSet
 
     def _print_Function(self, expr):
-        name = str(expr.func).lower()
-        if name == 'cast':
+        if expr.func == castFunc:
             arg, dtype = expr.args
             if type(dtype) is VectorType:
                 if type(arg) is ResolvedFieldAccess:
@@ -291,41 +292,57 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
 
         result = a_str[0]
         for item in a_str[1:]:
-            result = self.intrinsics['*'].format(result, item)
+            result = self.instructionSet['*'].format(result, item)
 
         if len(b) > 0:
             denominator_str = b_str[0]
             for item in b_str[1:]:
-                denominator_str = self.intrinsics['*'].format(denominator_str, item)
-            result = self.intrinsics['/'].format(result, denominator_str)
+                denominator_str = self.instructionSet['*'].format(denominator_str, item)
+            result = self.instructionSet['/'].format(result, denominator_str)
 
         if insideAdd:
             return sign, result
         else:
             if sign < 0:
-                return self.intrinsics['*'].format(self._print(S.NegativeOne), result)
+                return self.instructionSet['*'].format(self._print(S.NegativeOne), result)
             else:
                 return result
 
-#    def _print_Piecewise(self, expr):
-#        if expr.args[-1].cond != True:
-#            # We need the last conditional to be a True, otherwise the resulting
-#            # function may not return a result.
-#            raise ValueError("All Piecewise expressions must contain an "
-#                             "(expr, True) statement to be used as a default "
-#                             "condition. Without one, the generated "
-#                             "expression may not evaluate to anything under "
-#                             "some condition.")
-#
-#        result = self._print(expr.args[-1][0])
-#        for trueExpr, condition in reversed(expr.args[:-1]):
-#            result = self.intrinsics['blendv'].format(result, self._print(trueExpr), self._print(condition))
-#        return result
-#
-#    def _print_Relational(self, expr):
-#        return self.intrinsics[expr.rel_op].format(expr.lhs, expr.rhs)
-#
-#    def _print_Equality(self, expr):
-#        return self.intrinsics['=='].format(self._print(expr.lhs), self._print(expr.rhs))
-#
+    def _print_Relational(self, expr):
+        exprType = getTypeOfExpression(expr)
+        if type(exprType) is not VectorType:
+            return super(VectorizedCustomSympyPrinter, self)._print_Relational(expr)
+        assert self.instructionSet['width'] == exprType.width
+
+        return self.instructionSet[expr.rel_op].format(self._print(expr.lhs), self._print(expr.rhs))
+
+    def _print_Equality(self, expr):
+        exprType = getTypeOfExpression(expr)
+        if type(exprType) is not VectorType:
+            return super(VectorizedCustomSympyPrinter, self)._print_Equality(expr)
+        assert self.instructionSet['width'] == exprType.width
+
+        return self.instructionSet['=='].format(self._print(expr.lhs), self._print(expr.rhs))
+
+    def _print_Piecewise(self, expr):
+        exprType = getTypeOfExpression(expr)
+        if type(exprType) is not VectorType:
+            return super(VectorizedCustomSympyPrinter, self)._print_Piecewise(expr)
+        assert self.instructionSet['width'] == exprType.width
+
+        if expr.args[-1].cond != True:
+            # The last condition must be True; otherwise the resulting
+            # expression may not produce a result.
+            raise ValueError("All Piecewise expressions must contain an "
+                             "(expr, True) statement to be used as a default "
+                             "condition. Without one, the generated "
+                             "expression may not evaluate to anything under "
+                             "some condition.")
+
+        result = self._print(expr.args[-1][0])
+        for trueExpr, condition in reversed(expr.args[:-1]):
+            result = self.instructionSet['blendv'].format(result, self._print(trueExpr), self._print(condition))
+        return result
+
+
 
diff --git a/backends/simd_instruction_sets.py b/backends/simd_instruction_sets.py
index 213b4c481..23109165b 100644
--- a/backends/simd_instruction_sets.py
+++ b/backends/simd_instruction_sets.py
@@ -66,7 +66,7 @@ def x86VectorInstructionSet(dataType='double', instructionSet='avx'):
             if arg in ('0', '1', '2', '3', '4', '5'):
                 argString += "{" + arg + "},"
             else:
-                argString += arg
+                argString += arg + ","
         argString = argString[:-1] + ")"
         result[intrinsicId] = pre + "_" + name + "_" + suf + argString
 
@@ -80,6 +80,7 @@ def x86VectorInstructionSet(dataType='double', instructionSet='avx'):
     result['double'] = "__m%dd" % (bitWidth,)
     result['float'] = "__m%d" % (bitWidth,)
     result['int'] = "__m%di" % (bitWidth,)
+    result['bool'] = "__m%dd" % (bitWidth,)
 
     result['headers'] = headers[instructionSet]
     return result
diff --git a/cpu/cpujit.py b/cpu/cpujit.py
index c5919fc6c..e893a7771 100644
--- a/cpu/cpujit.py
+++ b/cpu/cpujit.py
@@ -302,7 +302,7 @@ def runCompileStep(command):
         subprocess.check_output(command, env=compileEnvironment, stderr=subprocess.STDOUT, shell=shell)
     except subprocess.CalledProcessError as e:
         print(" ".join(command))
-        print(e.output)
+        print(e.output.decode('utf8'))
         raise e
 
 
diff --git a/types.py b/types.py
index 23c487501..6442147cd 100644
--- a/types.py
+++ b/types.py
@@ -6,7 +6,16 @@ from sympy.core.cache import cacheit
 from pystencils.cache import memorycache
 from pystencils.utils import allEqual
 
-castFunc = sp.Function("cast")
+
+# To be usable as a condition of sp.Piecewise, castFunc must also be a Relational (sp.Rel) subclass.
+class castFunc(sp.Function, sp.Rel):
+
+    @property
+    def canonical(self):
+        if hasattr(self.args[0], 'canonical'):
+            return self.args[0].canonical
+        else:
+            raise NotImplementedError()
 
 
 class TypedSymbol(sp.Symbol):
@@ -202,11 +211,16 @@ def getTypeOfExpression(expr):
     elif isinstance(expr, sp.Indexed):
         typedSymbol = expr.base.label
         return typedSymbol.dtype
+    elif isinstance(expr, sp.boolalg.Boolean):
+        # If any argument has a vector type, return a vector boolean; otherwise return a scalar boolean.
+        result = createTypeFromString("bool")
+        vecArgs = [getTypeOfExpression(a) for a in expr.args if isinstance(getTypeOfExpression(a), VectorType)]
+        if vecArgs:
+            result = VectorType(result, width=vecArgs[0].width)
+        return result
     elif isinstance(expr, sp.Expr):
         types = tuple(getTypeOfExpression(a) for a in expr.args)
         return collateTypes(types)
-    elif isinstance(expr, sp.boolalg.Boolean):
-        return createTypeFromString("bool")
 
     raise NotImplementedError("Could not determine type for " + str(expr))
 
@@ -344,6 +358,8 @@ class VectorType(Type):
                 return self.instructionSet['double']
             elif self.baseType == createTypeFromString("float"):
                 return self.instructionSet['float']
+            elif self.baseType == createTypeFromString("bool"):
+                return self.instructionSet['bool']
             else:
                 raise NotImplementedError()
 
diff --git a/vectorization.py b/vectorization.py
index 54a9819f7..c933a5fe0 100644
--- a/vectorization.py
+++ b/vectorization.py
@@ -2,26 +2,16 @@ import sympy as sp
 import warnings
 
 from pystencils.transformations import filteredTreeIteration
-from pystencils.types import TypedSymbol, VectorType, PointerType, BasicType, getTypeOfExpression, castFunc
+from pystencils.types import TypedSymbol, VectorType, BasicType, getTypeOfExpression, castFunc, collateTypes
 import pystencils.astnodes as ast
-from pystencils.utils import allEqual
 
 
-def asVectorType(resolvedFieldAccess, vectorizationWidth):
-    """Returns a new ResolvedFieldAccess that has a vector type"""
-    dtype = resolvedFieldAccess.typedSymbol.dtype
-    assert type(dtype) is PointerType
-    basicType = dtype.baseType
-    assert type(basicType) is BasicType, "Structs are not supported"
-
-    newDtype = VectorType(basicType, vectorizationWidth)
-    newDtype = PointerType(newDtype, dtype.const, dtype.restrict)
-    newTypedSymbol = TypedSymbol(resolvedFieldAccess.typedSymbol.name, newDtype)
-    return ast.ResolvedFieldAccess(newTypedSymbol, resolvedFieldAccess.args[1], resolvedFieldAccess.field,
-                                   resolvedFieldAccess.offsets, resolvedFieldAccess.idxCoordinateValues)
+def vectorize(astNode, vectorWidth=4):
+    vectorizeInnerLoopsAndAdaptLoadStores(astNode, vectorWidth)
+    insertVectorCasts(astNode)
 
 
-def vectorize(astNode, vectorWidth=4):
+def vectorizeInnerLoopsAndAdaptLoadStores(astNode, vectorWidth=4):
     """
     Goes over all innermost loops, changes increment to vector width and replaces field accesses by vector type if
         - loop bounds are constant
@@ -54,20 +44,31 @@ def insertVectorCasts(astNode):
     Inserts necessary casts from scalar values to vector values
     """
     def visitExpr(expr):
-        if expr.func in (sp.Add, sp.Mul):
+        if expr.func in (sp.Add, sp.Mul) or (isinstance(expr, sp.Rel) and not expr.func == castFunc):
             newArgs = [visitExpr(a) for a in expr.args]
             argTypes = [getTypeOfExpression(a) for a in newArgs]
             if not any(type(t) is VectorType for t in argTypes):
                 return expr
             else:
-                vectorWidths = [d.width for d in argTypes if type(d) is VectorType]
-                assert allEqual(vectorWidths), "Incompatible vector type widths"
-                vectorWidth = vectorWidths[0]
-                castedArgs = [castFunc(a, VectorType(t, vectorWidth)) if type(t) is not VectorType else a
+                targetType = collateTypes(argTypes)
+                castedArgs = [castFunc(a, targetType) if t != targetType else a
                               for a, t in zip(newArgs, argTypes)]
                 return expr.func(*castedArgs)
         elif expr.func == sp.Piecewise:
-            raise NotImplementedError()
+            newResults = [visitExpr(a[0]) for a in expr.args]
+            newConditions = [visitExpr(a[1]) for a in expr.args]
+            typesOfResults = [getTypeOfExpression(a) for a in newResults]
+            typesOfConditions = [getTypeOfExpression(a) for a in newConditions]
+
+            resultTargetType = getTypeOfExpression(expr)
+            castedResults = [castFunc(a, resultTargetType) if t != resultTargetType else a
+                             for a, t in zip(newResults, typesOfResults)]
+
+            conditionTargetType = collateTypes(typesOfConditions)
+            castedConditions = [castFunc(a, conditionTargetType) if t != conditionTargetType and a != True else a
+                                for a, t in zip(newConditions, typesOfConditions)]
+
+            return sp.Piecewise(*[(r, c) for r, c in zip(castedResults, castedConditions)])
         else:
             return expr
 
-- 
GitLab