From f1b61821b3a528fb624b8c26641d1d6b6a832147 Mon Sep 17 00:00:00 2001
From: Martin Bauer <martin.bauer@fau.de>
Date: Fri, 7 Jul 2017 17:09:56 +0200
Subject: [PATCH] Vectorized backend

---
 backends/cbackend.py            |   4 +-
 backends/cbackend_vectorized.py | 307 ++++++++++++++++++++++++++++++++
 2 files changed, 310 insertions(+), 1 deletion(-)
 create mode 100644 backends/cbackend_vectorized.py

diff --git a/backends/cbackend.py b/backends/cbackend.py
index 821a94733..a37cc9b1b 100644
--- a/backends/cbackend.py
+++ b/backends/cbackend.py
@@ -1,3 +1,4 @@
+import sympy as sp
 from sympy.utilities.codegen import CCodePrinter
 from pystencils.astnodes import Node
 from pystencils.types import createType, PointerType
@@ -134,7 +135,7 @@ class CustomSympyPrinter(CCodePrinter):
     def _print_Pow(self, expr):
         """Don't use std::pow function, for small integer exponents, write as multiplication"""
         if expr.exp.is_integer and expr.exp.is_number and 0 < expr.exp < 8:
-            return '(' + '*'.join(["(" + self._print(expr.base) + ")"] * expr.exp) + ')'
+            return self._print(sp.Mul(*[expr.base] * expr.exp, evaluate=False))
         else:
             return super(CustomSympyPrinter, self)._print_Pow(expr)
 
@@ -167,3 +168,4 @@ class CustomSympyPrinter(CCodePrinter):
             return "*((%s)(& %s))" % (PointerType(type), self._print(arg))
         else:
             return super(CustomSympyPrinter, self)._print_Function(expr)
+
diff --git a/backends/cbackend_vectorized.py b/backends/cbackend_vectorized.py
new file mode 100644
index 000000000..7c70df448
--- /dev/null
+++ b/backends/cbackend_vectorized.py
@@ -0,0 +1,307 @@
+from collections import namedtuple
+
+import sympy as sp
+from sympy.core import S
+from sympy.utilities.codegen import CCodePrinter
+from sympy.core.mul import _keep_coeff
+
+from pystencils.backends.cbackend import CustomSympyPrinter
+from pystencils.types import getBaseType, createTypeFromString
+
+
+def getInstructionSetInfoIntel(dataType='double', instructionSet='avx'):
+    baseNames = {
+        '+': 'add[0, 1]',
+        '-': 'sub[0, 1]',
+        '*': 'mul[0, 1]',
+        '/': 'div[0, 1]',
+
+        '==': 'cmp[0, 1, _CMP_EQ_UQ  ]',
+        '!=': 'cmp[0, 1, _CMP_NEQ_UQ ]',
+        '>=': 'cmp[0, 1, _CMP_GE_OQ  ]',
+        '<=': 'cmp[0, 1, _CMP_LE_OQ  ]',
+        '<': 'cmp[0, 1, _CMP_NGE_UQ ]',
+        '>': 'cmp[0, 1, _CMP_NLE_UQ ]',
+
+        'blendv': 'blendv[0, 1, 2]',
+
+        'sqrt': 'sqrt[0]',
+
+        'makeVec':  'set[0,0,0,0]',
+        'makeZero': 'setzero[]',
+
+        'loadU': 'loadu [0]',
+        'loadA': 'load [0]',
+        'storeU': 'storeu[0]',
+        'storeA': 'store [0]',
+    }
+
+    suffix = {
+        'double': 'pd',
+        'float': 'ps',
+    }
+    prefix = {
+        'sse': '_mm',
+        'avx': '_mm256',
+        'avx512': '_mm512',
+    }
+
+    width = {
+        ("double", "sse"): 2,
+        ("float", "sse"): 4,
+        ("double", "avx"): 4,
+        ("float", "avx"): 8,
+        ("double", "avx512"): 8,
+        ("float", "avx512"): 16,
+    }
+
+    result = {}
+    pre = prefix[instructionSet]
+    suf = suffix[dataType]
+    for intrinsicId, functionShortcut in baseNames.items():
+        functionShortcut = functionShortcut.strip()
+        name = functionShortcut[:functionShortcut.index('[')]
+        args = functionShortcut[functionShortcut.index('[') + 1: -1]
+        argString = "("
+        for arg in args.split(","):
+            arg = arg.strip()
+            if not arg:
+                continue
+            if arg in ('0', '1', '2', '3', '4', '5'):
+                argString += "{" + arg + "},"
+            else:
+                argString += arg
+        argString = argString[:-1] + ")"
+        result[intrinsicId] = pre + "_" + name + "_" + suf + argString
+
+    result['width'] = width[(dataType, instructionSet)]
+    result['dataTypePrefix'] = {
+        'double': "_" + pre + 'd',
+        'float': "_" + pre,
+    }
+
+    return result
+
+
+class VectorizedCBackend(object):
+
+    def __init__(self, astNode, instructionSet='avx'):
+        fieldTypes = set([getBaseType(f.dtype) for f in astNode.fieldsAccessed])
+        if len(fieldTypes) != 1:
+            raise ValueError("Vectorized backend does not support kernels with mixed field types")
+        fieldType = fieldTypes.pop()
+        assert fieldType.is_float
+        dtypeName = str(fieldType)
+
+        instructionSetInfo = getInstructionSetInfoIntel(dtypeName, instructionSet)
+
+        self.vectorizationWidth = instructionSetInfo['width']
+        self.sympyVecPrinter = CustomSympyPrinterVectorized(instructionSetInfo)
+        self.sympyPrinter = CustomSympyPrinter(constantsAsFloats=(dtypeName == 'float'))
+
+        self._indent = "   "
+        self._vecTypeName = instructionSetInfo['dataTypePrefix'][dtypeName]
+        self.dtypeName = dtypeName
+
+    def __call__(self, node):
+        return str(self._print(node))
+
+    def _print(self, node):
+        for cls in type(node).__mro__:
+            methodName = "_print_" + cls.__name__
+            if hasattr(self, methodName):
+                return getattr(self, methodName)(node)
+        raise NotImplementedError("CBackend does not support node of type " + cls.__name__)
+
+    def _print_KernelFunction(self, node):
+        blockContents = "\n".join([self._print(child) for child in node.body.args])
+        constantBlock = self.sympyVecPrinter.getConstantsBlock(self._vecTypeName)
+
+        body = "{\n%s\n%s\n}" % (constantBlock, self._indent + self._indent.join(blockContents.splitlines(True)))
+
+        functionArguments = ["%s %s" % (str(s.dtype), s.name) for s in node.parameters]
+        funcDeclaration = "FUNC_PREFIX void %s(%s)" % (node.functionName, ", ".join(functionArguments))
+        return funcDeclaration + "\n" + body
+
+    def _print_Block(self, node):
+        blockContents = "\n".join([self._print(child) for child in node.args])
+        return "{\n%s\n}" % (self._indent + self._indent.join(blockContents.splitlines(True)),)
+
+    def _print_PragmaBlock(self, node):
+        return "%s\n%s" % (node.pragmaLine, self._print_Block(node))
+
+    def _print_LoopOverCoordinate(self, node):
+        if node.isInnermostLoop:
+            iterRange = node.stop - node.start
+            if isinstance(iterRange, sp.Basic) and not iterRange.is_integer:
+                raise NotImplementedError("Vectorized backend currently only supports fixed size inner loops")
+            if iterRange % self.vectorizationWidth != 0 or node.step != 1:
+                raise NotImplementedError("Vectorized backend only supports loop bounds that are "
+                                          "multiples of vectorization width")
+            step = self.vectorizationWidth
+        else:
+            step = node.step
+
+        counterVar = node.loopCounterName
+        start = "int %s = %s" % (counterVar, self.sympyPrinter.doprint(node.start))
+        condition = "%s < %s" % (counterVar, self.sympyPrinter.doprint(node.stop))
+        update = "%s += %s" % (counterVar, self.sympyPrinter.doprint(step),)
+        loopStr = "for (%s; %s; %s)" % (start, condition, update)
+
+        prefix = "\n".join(node.prefixLines)
+        if prefix:
+            prefix += "\n"
+        return "%s%s\n%s" % (prefix, loopStr, self._print(node.body))
+
+    def _print_SympyAssignment(self, node):
+        dtype = ""
+        if node.isDeclaration:
+            assert str(getBaseType(node.lhs.dtype)) in (self.dtypeName, 'bool')
+            if node.lhs.dtype == createTypeFromString(self.dtypeName):
+                dtypeStr = self._vecTypeName
+                printer = self.sympyVecPrinter
+            else:
+                dtypeStr = str(node.lhs.dtype)
+                printer = self.sympyPrinter
+
+            if node.isConst:
+                dtype = "const " + dtypeStr + " "
+            else:
+                dtype = dtypeStr + " "
+        else:
+            printer = self.sympyVecPrinter
+        return "%s %s = %s;" % (str(dtype), printer.doprint(node.lhs), printer.doprint(node.rhs))
+
+    def _print_TemporaryMemoryAllocation(self, node):
+        return "%s * %s = new %s[%s];" % (node.symbol.dtype, self.sympyPrinter.doprint(node.symbol),
+                                         node.symbol.dtype, self.sympyPrinter.doprint(node.size))
+
+    def _print_TemporaryMemoryFree(self, node):
+        return "delete [] %s;" % (self.sympyPrinter.doprint(node.symbol),)
+
+    def _print_CustomCppCode(self, node):
+        return node.code
+
+
+class CustomSympyPrinterVectorized(CCodePrinter):
+    SummandInfo = namedtuple("SummandInfo", ['sign', 'term'])
+
+    def __init__(self, instructionSetInfo):
+        super(CustomSympyPrinterVectorized, self).__init__()
+        self.intrinsics = instructionSetInfo
+        self.constantsDict = {}
+
+    def getConstantsBlock(self, vecTypeStr):
+        result = ""
+        for value, symbol in self.constantsDict.items():
+            rhsStr = self.intrinsics['makeVec'].format(self._print(value))
+            result += "const %s %s = %s;\n" % (vecTypeStr, symbol.name, rhsStr)
+        return result
+
+    def _print_Add(self, expr, order=None):
+        summands = []
+        for term in expr.args:
+            if term.func == sp.Mul:
+                sign, t = self._print_Mul(term, insideAdd=True)
+            else:
+                t = self._print(term)
+                sign = 1
+            summands.append(self.SummandInfo(sign, t))
+        # Use positive terms first
+        summands.sort(key=lambda e: e.sign, reverse=True)
+        # if no positive term exists, prepend a zero
+        if summands[0].sign == -1:
+            summands.insert(0, self.SummandInfo(1, "0"))
+
+        assert len(summands) >= 2
+        processed = summands[0].term
+        for summand in summands[1:]:
+            func = self.intrinsics['-'] if summand.sign == -1 else self.intrinsics['+']
+            processed = func.format(processed, summand.term)
+        return processed
+
+    def _print_Mul(self, expr, insideAdd=False):
+
+        c, e = expr.as_coeff_Mul()
+        if c < 0:
+            expr = _keep_coeff(-c, e)
+            sign = -1
+        else:
+            sign = 1
+
+        a = []  # items in the numerator
+        b = []  # items that are in the denominator (if any)
+
+        # Gather args for numerator/denominator
+        for item in expr.as_ordered_factors():
+            if item.is_commutative and item.is_Pow and item.exp.is_Rational and item.exp.is_negative:
+                if item.exp != -1:
+                    b.append(sp.Pow(item.base, -item.exp, evaluate=False))
+                else:
+                    b.append(sp.Pow(item.base, -item.exp))
+            else:
+                a.append(item)
+
+        a = a or [S.One]
+
+        a_str = [self._print(x) for x in a]
+        b_str = [self._print(x) for x in b]
+
+        result = a_str[0]
+        for item in a_str[1:]:
+            result = self.intrinsics['*'].format(result, item)
+
+        if len(b) > 0:
+            denominator_str = b_str[0]
+            for item in b_str[1:]:
+                denominator_str = self.intrinsics['*'].format(denominator_str, item)
+            result = self.intrinsics['/'].format(result, denominator_str)
+
+        if insideAdd:
+            return sign, result
+        else:
+            if sign < 0:
+                return self.intrinsics['*'].format(self._print(S.NegativeOne), result)
+            else:
+                return result
+
+    def _print_Pow(self, expr):
+        """Don't use std::pow function, for small integer exponents, write as multiplication"""
+        if expr.exp.is_integer and expr.exp.is_number and 0 < expr.exp < 8:
+            return self._print(sp.Mul(*[expr.base] * expr.exp, evaluate=False))
+        else:
+            return super(CustomSympyPrinterVectorized, self)._print_Pow(expr)
+
+    def _print_Float(self, expr):
+        if expr not in self.constantsDict:
+            self.constantsDict[expr] = sp.Dummy()
+        symbol = self.constantsDict[expr]
+        return symbol.name
+
+    def _print_Rational(self, expr):
+        if expr not in self.constantsDict:
+            self.constantsDict[expr] = sp.Symbol("__value_%d_%d" % (expr.p, expr.q))
+        symbol = self.constantsDict[expr]
+        return symbol.name
+
+    def _print_Piecewise(self, expr):
+        if expr.args[-1].cond != True:
+            # We need the last conditional to be a True, otherwise the resulting
+            # function may not return a result.
+            raise ValueError("All Piecewise expressions must contain an "
+                             "(expr, True) statement to be used as a default "
+                             "condition. Without one, the generated "
+                             "expression may not evaluate to anything under "
+                             "some condition.")
+
+        result = self._print(expr.args[-1][0])
+        for trueExpr, condition in reversed(expr.args[:-1]):
+            result = self.intrinsics['blendv'].format(result, self._print(trueExpr), self._print(condition))
+        return result
+
+    def _print_Relational(self, expr):
+        return self.intrinsics[expr.rel_op].format(expr.lhs, expr.rhs)
+
+    def _print_Equality(self, expr):
+        """Equality operator is not printable in default printer"""
+        return self.intrinsics['=='].format(self._print(expr.lhs), self._print(expr.rhs))
-- 
GitLab