From ea847bc5e3b820794cdc3fdb2d8affa0c36e8883 Mon Sep 17 00:00:00 2001
From: Martin Bauer <martin.bauer@fau.de>
Date: Fri, 6 Oct 2017 16:04:25 +0200
Subject: [PATCH] Vectorization & Type system overhaul

- first vectorization tests are running
- type system: use memoized getTypeOfExpression
- casts are done using sp.Function('cast')
- C backend adapted for vectorization support
- AST nodes can required optional headers
---
 astnodes.py                       |  10 +
 backends/cbackend.py              | 174 +++++++++++++++--
 backends/cbackend_vectorized.py   | 312 ------------------------------
 backends/simd_instruction_sets.py |  91 +++++++++
 cpu/cpujit.py                     |  15 +-
 gpucuda/indexing.py               |   6 +-
 transformations.py                |  14 +-
 types.py                          | 118 ++++++++++-
 utils.py                          |   9 +
 vectorization.py                  |  94 +++++++++
 10 files changed, 504 insertions(+), 339 deletions(-)
 delete mode 100644 backends/cbackend_vectorized.py
 create mode 100644 backends/simd_instruction_sets.py
 create mode 100644 vectorization.py

diff --git a/astnodes.py b/astnodes.py
index 25d4ba4a0..16c1e9829 100644
--- a/astnodes.py
+++ b/astnodes.py
@@ -6,6 +6,8 @@ from pystencils.types import TypedSymbol, createType, get_type_from_sympy, creat
 
 class ResolvedFieldAccess(sp.Indexed):
     def __new__(cls, base, linearizedIndex, field, offsets, idxCoordinateValues):
+        if not isinstance(base, IndexedBase):
+            base = IndexedBase(base, shape=(1,))
         obj = super(ResolvedFieldAccess, cls).__new__(cls, base, linearizedIndex)
         obj.field = field
         obj.offsets = offsets
@@ -21,6 +23,14 @@ class ResolvedFieldAccess(sp.Indexed):
         superClassContents = super(ResolvedFieldAccess, self)._hashable_content()
         return superClassContents + tuple(self.offsets) + (repr(self.idxCoordinateValues), hash(self.field))
 
+    @property
+    def typedSymbol(self):
+        return self.base.label
+
+    def __str__(self):
+        top = super(ResolvedFieldAccess, self).__str__()
+        return "%s (%s)" % (top, self.typedSymbol.dtype)
+
     def __getnewargs__(self):
         return self.base, self.indices[0], self.field, self.offsets, self.idxCoordinateValues
 
diff --git a/backends/cbackend.py b/backends/cbackend.py
index c18a0544a..5bbb9b787 100644
--- a/backends/cbackend.py
+++ b/backends/cbackend.py
@@ -4,8 +4,13 @@ try:
 except ImportError:
     from sympy.printing.ccode import C99CodePrinter as CCodePrinter
 
-from pystencils.astnodes import Node
-from pystencils.types import createType, PointerType
+from collections import namedtuple
+from sympy.core.mul import _keep_coeff
+from sympy.core import S
+
+from pystencils.astnodes import Node, ResolvedFieldAccess, SympyAssignment
+from pystencils.types import createType, PointerType, getTypeOfExpression, VectorType, castFunc
+from pystencils.backends.simd_instruction_sets import selectedInstructionSet
 
 
 def generateC(astNode, signatureOnly=False):
@@ -14,10 +19,26 @@ def generateC(astNode, signatureOnly=False):
     """
     fieldTypes = set([f.dtype for f in astNode.fieldsAccessed])
     useFloatConstants = createType("double") not in fieldTypes
-    printer = CBackend(constantsAsFloats=useFloatConstants, signatureOnly=signatureOnly)
+
+    vectorIS = selectedInstructionSet['double']
+    printer = CBackend(constantsAsFloats=useFloatConstants, signatureOnly=signatureOnly, vectorInstructionSet=vectorIS)
     return printer(astNode)
 
 
+def getHeaders(astNode):
+    headers = set()
+
+    if hasattr(astNode, 'headers'):
+        headers.update(astNode.headers)
+    elif isinstance(astNode, SympyAssignment):
+        if type(getTypeOfExpression(astNode.rhs)) is VectorType:
+            headers.update(selectedInstructionSet['double']['headers'])
+
+    for a in astNode.args:
+        if isinstance(a, Node):
+            headers.update(getHeaders(a))
+
+    return headers
 # --------------------------------------- Backend Specific Nodes -------------------------------------------------------
 
 
@@ -26,6 +47,7 @@ class CustomCppCode(Node):
         self._code = "\n" + code
         self._symbolsRead = set(symbolsRead)
         self._symbolsDefined = set(symbolsDefined)
+        self.headers = []
 
     @property
     def code(self):
@@ -48,24 +70,33 @@ class PrintNode(CustomCppCode):
     def __init__(self, symbolToPrint):
         code = '\nstd::cout << "%s  =  " << %s << std::endl; \n' % (symbolToPrint.name, symbolToPrint.name)
         super(PrintNode, self).__init__(code, symbolsRead=[symbolToPrint], symbolsDefined=set())
+        self.headers.append("<iostream>")
 
 
 # ------------------------------------------- Printer ------------------------------------------------------------------
 
-
 class CBackend(object):
 
-    def __init__(self, constantsAsFloats=False, sympyPrinter=None, signatureOnly=False):
+    def __init__(self, constantsAsFloats=False, sympyPrinter=None, signatureOnly=False, vectorInstructionSet=None):
         if sympyPrinter is None:
             self.sympyPrinter = CustomSympyPrinter(constantsAsFloats)
+            if vectorInstructionSet is not None:
+                self.sympyPrinter = VectorizedCustomSympyPrinter(vectorInstructionSet, constantsAsFloats)
+            else:
+                self.sympyPrinter = CustomSympyPrinter(constantsAsFloats)
         else:
             self.sympyPrinter = sympyPrinter
 
+        self._vectorInstructionSet = vectorInstructionSet
         self._indent = "   "
         self._signatureOnly = signatureOnly
 
     def __call__(self, node):
-        return str(self._print(node))
+        prevIs = VectorType.instructionSet
+        VectorType.instructionSet = self._vectorInstructionSet
+        result = str(self._print(node))
+        VectorType.instructionSet = prevIs
+        return result
 
     def _print(self, node):
         for cls in type(node).__mro__:
@@ -103,13 +134,16 @@ class CBackend(object):
         return "%s%s\n%s" % (prefix, loopStr, self._print(node.body))
 
     def _print_SympyAssignment(self, node):
-        dtype = ""
         if node.isDeclaration:
-            if node.isConst:
-                dtype = "const " + str(node.lhs.dtype) + " "
+            dtype = "const " + str(node.lhs.dtype) + " " if node.isConst else str(node.lhs.dtype) + " "
+            return "%s %s = %s;" % (dtype, self.sympyPrinter.doprint(node.lhs), self.sympyPrinter.doprint(node.rhs))
+        else:
+            lhsType = getTypeOfExpression(node.lhs)
+            if type(lhsType) is VectorType and node.lhs.func == castFunc:
+                return self._vectorInstructionSet['storeU'].format("&" + self.sympyPrinter.doprint(node.lhs.args[0]),
+                                                                   self.sympyPrinter.doprint(node.rhs)) + ';'
             else:
-                dtype = str(node.lhs.dtype) + " "
-        return "%s %s = %s;" % (str(dtype), self.sympyPrinter.doprint(node.lhs), self.sympyPrinter.doprint(node.rhs))
+                return "%s = %s;" % (self.sympyPrinter.doprint(node.lhs), self.sympyPrinter.doprint(node.rhs))
 
     def _print_TemporaryMemoryAllocation(self, node):
         return "%s * %s = new %s[%s];" % (node.symbol.dtype, self.sympyPrinter.doprint(node.symbol),
@@ -177,3 +211,121 @@ class CustomSympyPrinter(CCodePrinter):
         else:
             return super(CustomSympyPrinter, self)._print_Function(expr)
 
+
+class VectorizedCustomSympyPrinter(CustomSympyPrinter):
+    SummandInfo = namedtuple("SummandInfo", ['sign', 'term'])
+
+    def __init__(self, instructionSet, constantsAsFloats=False):
+        super(VectorizedCustomSympyPrinter, self).__init__(constantsAsFloats)
+        self.instructionSet = instructionSet
+
+    def _print_Function(self, expr):
+        name = str(expr.func).lower()
+        if name == 'cast':
+            arg, dtype = expr.args
+            if type(dtype) is VectorType:
+                if type(arg) is ResolvedFieldAccess:
+                    return self.instructionSet['loadU'].format("& " + self._print(arg))
+                else:
+                    return self.instructionSet['makeVec'].format(self._print(arg))
+
+        return super(VectorizedCustomSympyPrinter, self)._print_Function(expr)
+
+    def _print_Add(self, expr, order=None):
+        exprType = getTypeOfExpression(expr)
+        if type(exprType) is not VectorType:
+            return super(VectorizedCustomSympyPrinter, self)._print_Add(expr, order)
+        assert self.instructionSet['width'] == exprType.width
+
+        summands = []
+        for term in expr.args:
+            if term.func == sp.Mul:
+                sign, t = self._print_Mul(term, insideAdd=True)
+            else:
+                t = self._print(term)
+                sign = 1
+            summands.append(self.SummandInfo(sign, t))
+        # Use positive terms first
+        summands.sort(key=lambda e: e.sign, reverse=True)
+        # if no positive term exists, prepend a zero
+        if summands[0].sign == -1:
+            summands.insert(0, self.SummandInfo(1, "0"))
+
+        assert len(summands) >= 2
+        processed = summands[0].term
+        for summand in summands[1:]:
+            func = self.instructionSet['-'] if summand.sign == -1 else self.instructionSet['+']
+            processed = func.format(processed, summand.term)
+        return processed
+
+    def _print_Mul(self, expr, insideAdd=False):
+        exprType = getTypeOfExpression(expr)
+        if type(exprType) is not VectorType:
+            return super(VectorizedCustomSympyPrinter, self)._print_Mul(expr)
+        assert self.instructionSet['width'] == exprType.width
+
+        c, e = expr.as_coeff_Mul()
+        if c < 0:
+            expr = _keep_coeff(-c, e)
+            sign = -1
+        else:
+            sign = 1
+
+        a = []  # items in the numerator
+        b = []  # items that are in the denominator (if any)
+
+        # Gather args for numerator/denominator
+        for item in expr.as_ordered_factors():
+            if item.is_commutative and item.is_Pow and item.exp.is_Rational and item.exp.is_negative:
+                if item.exp != -1:
+                    b.append(sp.Pow(item.base, -item.exp, evaluate=False))
+                else:
+                    b.append(sp.Pow(item.base, -item.exp))
+            else:
+                a.append(item)
+
+        a = a or [S.One]
+
+        a_str = [self._print(x) for x in a]
+        b_str = [self._print(x) for x in b]
+
+        result = a_str[0]
+        for item in a_str[1:]:
+            result = self.intrinsics['*'].format(result, item)
+
+        if len(b) > 0:
+            denominator_str = b_str[0]
+            for item in b_str[1:]:
+                denominator_str = self.intrinsics['*'].format(denominator_str, item)
+            result = self.intrinsics['/'].format(result, denominator_str)
+
+        if insideAdd:
+            return sign, result
+        else:
+            if sign < 0:
+                return self.intrinsics['*'].format(self._print(S.NegativeOne), result)
+            else:
+                return result
+
+#    def _print_Piecewise(self, expr):
+#        if expr.args[-1].cond != True:
+#            # We need the last conditional to be a True, otherwise the resulting
+#            # function may not return a result.
+#            raise ValueError("All Piecewise expressions must contain an "
+#                             "(expr, True) statement to be used as a default "
+#                             "condition. Without one, the generated "
+#                             "expression may not evaluate to anything under "
+#                             "some condition.")
+#
+#        result = self._print(expr.args[-1][0])
+#        for trueExpr, condition in reversed(expr.args[:-1]):
+#            result = self.intrinsics['blendv'].format(result, self._print(trueExpr), self._print(condition))
+#        return result
+#
+#    def _print_Relational(self, expr):
+#        return self.intrinsics[expr.rel_op].format(expr.lhs, expr.rhs)
+#
+#    def _print_Equality(self, expr):
+#        return self.intrinsics['=='].format(self._print(expr.lhs), self._print(expr.rhs))
+#
+
diff --git a/backends/cbackend_vectorized.py b/backends/cbackend_vectorized.py
deleted file mode 100644
index 62a16b217..000000000
--- a/backends/cbackend_vectorized.py
+++ /dev/null
@@ -1,312 +0,0 @@
-from collections import namedtuple
-
-import sympy as sp
-from sympy.core import S
-
-try:
-    from sympy.utilities.codegen import CCodePrinter
-except ImportError:
-    from sympy.printing.ccode import C99CodePrinter as CCodePrinter
-
-from sympy.core.mul import _keep_coeff
-
-from pystencils.backends.cbackend import CustomSympyPrinter
-from pystencils.types import getBaseType, createTypeFromString
-
-
-def getInstructionSetInfoIntel(dataType='double', instructionSet='avx'):
-    baseNames = {
-        '+': 'add[0, 1]',
-        '-': 'sub[0, 1]',
-        '*': 'mul[0, 1]',
-        '/': 'div[0, 1]',
-
-        '==': 'cmp[0, 1, _CMP_EQ_UQ  ]',
-        '!=': 'cmp[0, 1, _CMP_NEQ_UQ ]',
-        '>=': 'cmp[0, 1, _CMP_GE_OQ  ]',
-        '<=': 'cmp[0, 1, _CMP_LE_OQ  ]',
-        '<': 'cmp[0, 1, _CMP_NGE_UQ ]',
-        '>': 'cmp[0, 1, _CMP_NLE_UQ ]',
-
-        'blendv': 'blendv[0, 1, 2]',
-
-        'sqrt': 'sqrt[0]',
-
-        'makeVec':  'set[0,0,0,0]',
-        'makeZero': 'setzero[]',
-
-        'loadU': 'loadu [0]',
-        'loadA': 'load [0]',
-        'storeU': 'storeu[0]',
-        'storeA': 'store [0]',
-    }
-
-    suffix = {
-        'double': 'pd',
-        'float': 'ps',
-    }
-    prefix = {
-        'sse': '_mm',
-        'avx': '_mm256',
-        'avx512': '_mm512',
-    }
-
-    width = {
-        ("double", "sse"): 2,
-        ("float", "sse"): 4,
-        ("double", "avx"): 4,
-        ("float", "avx"): 8,
-        ("double", "avx512"): 8,
-        ("float", "avx512"): 16,
-    }
-
-    result = {}
-    pre = prefix[instructionSet]
-    suf = suffix[dataType]
-    for intrinsicId, functionShortcut in baseNames.items():
-        functionShortcut = functionShortcut.strip()
-        name = functionShortcut[:functionShortcut.index('[')]
-        args = functionShortcut[functionShortcut.index('[') + 1: -1]
-        argString = "("
-        for arg in args.split(","):
-            arg = arg.strip()
-            if not arg:
-                continue
-            if arg in ('0', '1', '2', '3', '4', '5'):
-                argString += "{" + arg + "},"
-            else:
-                argString += arg
-        argString = argString[:-1] + ")"
-        result[intrinsicId] = pre + "_" + name + "_" + suf + argString
-
-    result['width'] = width[(dataType, instructionSet)]
-    result['dataTypePrefix'] = {
-        'double': "_" + pre + 'd',
-        'float': "_" + pre,
-    }
-
-    return result
-
-
-class VectorizedCBackend(object):
-
-    def __init__(self, astNode, instructionSet='avx'):
-        fieldTypes = set([getBaseType(f.dtype) for f in astNode.fieldsAccessed])
-        if len(fieldTypes) != 1:
-            raise ValueError("Vectorized backend does not support kernels with mixed field types")
-        fieldType = fieldTypes.pop()
-        assert fieldType.is_float
-        dtypeName = str(fieldType)
-
-        instructionSetInfo = getInstructionSetInfoIntel(dtypeName, instructionSet)
-
-        self.vectorizationWidth = instructionSetInfo['width']
-        self.sympyVecPrinter = CustomSympyPrinterVectorized(instructionSetInfo)
-        self.sympyPrinter = CustomSympyPrinter(constantsAsFloats=(dtypeName == 'float'))
-
-        self._indent = "   "
-        self._vecTypeName = instructionSetInfo['dataTypePrefix'][dtypeName]
-        self.dtypeName = dtypeName
-
-    def __call__(self, node):
-        return str(self._print(node))
-
-    def _print(self, node):
-        for cls in type(node).__mro__:
-            methodName = "_print_" + cls.__name__
-            if hasattr(self, methodName):
-                return getattr(self, methodName)(node)
-        raise NotImplementedError("CBackend does not support node of type " + cls.__name__)
-
-    def _print_KernelFunction(self, node):
-        blockContents = "\n".join([self._print(child) for child in node.body.args])
-        constantBlock = self.sympyVecPrinter.getConstantsBlock(self._vecTypeName)
-
-        body = "{\n%s\n%s\n}" % (constantBlock, self._indent + self._indent.join(blockContents.splitlines(True)))
-
-        functionArguments = ["%s %s" % (str(s.dtype), s.name) for s in node.parameters]
-        funcDeclaration = "FUNC_PREFIX void %s(%s)" % (node.functionName, ", ".join(functionArguments))
-        return funcDeclaration + "\n" + body
-
-    def _print_Block(self, node):
-        blockContents = "\n".join([self._print(child) for child in node.args])
-        return "{\n%s\n}" % (self._indent + self._indent.join(blockContents.splitlines(True)),)
-
-    def _print_PragmaBlock(self, node):
-        return "%s\n%s" % (node.pragmaLine, self._print_Block(node))
-
-    def _print_LoopOverCoordinate(self, node):
-        if node.isInnermostLoop:
-            iterRange = node.stop - node.start
-            if isinstance(iterRange, sp.Basic) and not iterRange.is_integer:
-                raise NotImplementedError("Vectorized backend currently only supports fixed size inner loops")
-            if iterRange % self.vectorizationWidth != 0 or node.step != 1:
-                raise NotImplementedError("Vectorized backend only supports loop bounds that are "
-                                          "multiples of vectorization width")
-            step = self.vectorizationWidth
-        else:
-            step = node.step
-
-        counterVar = node.loopCounterName
-        start = "int %s = %s" % (counterVar, self.sympyPrinter.doprint(node.start))
-        condition = "%s < %s" % (counterVar, self.sympyPrinter.doprint(node.stop))
-        update = "%s += %s" % (counterVar, self.sympyPrinter.doprint(step),)
-        loopStr = "for (%s; %s; %s)" % (start, condition, update)
-
-        prefix = "\n".join(node.prefixLines)
-        if prefix:
-            prefix += "\n"
-        return "%s%s\n%s" % (prefix, loopStr, self._print(node.body))
-
-    def _print_SympyAssignment(self, node):
-        dtype = ""
-        if node.isDeclaration:
-            assert str(getBaseType(node.lhs.dtype)) in (self.dtypeName, 'bool')
-            if node.lhs.dtype == createTypeFromString(self.dtypeName):
-                dtypeStr = self._vecTypeName
-                printer = self.sympyVecPrinter
-            else:
-                dtypeStr = str(node.lhs.dtype)
-                printer = self.sympyPrinter
-
-            if node.isConst:
-                dtype = "const " + dtypeStr + " "
-            else:
-                dtype = dtypeStr + " "
-        else:
-            printer = self.sympyVecPrinter
-        return "%s %s = %s;" % (str(dtype), printer.doprint(node.lhs), printer.doprint(node.rhs))
-
-    def _print_TemporaryMemoryAllocation(self, node):
-        return "%s * %s = new %s[%s];" % (node.symbol.dtype, self.sympyPrinter.doprint(node.symbol),
-                                         node.symbol.dtype, self.sympyPrinter.doprint(node.size))
-
-    def _print_TemporaryMemoryFree(self, node):
-        return "delete [] %s;" % (self.sympyPrinter.doprint(node.symbol),)
-
-    def _print_CustomCppCode(self, node):
-        return node.code
-
-
-class CustomSympyPrinterVectorized(CCodePrinter):
-    SummandInfo = namedtuple("SummandInfo", ['sign', 'term'])
-
-    def __init__(self, instructionSetInfo):
-        super(CustomSympyPrinterVectorized, self).__init__()
-        self.intrinsics = instructionSetInfo
-        self.constantsDict = {}
-
-    def getConstantsBlock(self, vecTypeStr):
-        result = ""
-        for value, symbol in self.constantsDict.items():
-            rhsStr = self.intrinsics['makeVec'].format(self._print(value))
-            result += "const %s %s = %s;\n" % (vecTypeStr, symbol.name, rhsStr)
-        return result
-
-    def _print_Add(self, expr, order=None):
-        summands = []
-        for term in expr.args:
-            if term.func == sp.Mul:
-                sign, t = self._print_Mul(term, insideAdd=True)
-            else:
-                t = self._print(term)
-                sign = 1
-            summands.append(self.SummandInfo(sign, t))
-        # Use positive terms first
-        summands.sort(key=lambda e: e.sign, reverse=True)
-        # if no positive term exists, prepend a zero
-        if summands[0].sign == -1:
-            summands.insert(0, self.SummandInfo(1, "0"))
-
-        assert len(summands) >= 2
-        processed = summands[0].term
-        for summand in summands[1:]:
-            func = self.intrinsics['-'] if summand.sign == -1 else self.intrinsics['+']
-            processed = func.format(processed, summand.term)
-        return processed
-
-    def _print_Mul(self, expr, insideAdd=False):
-
-        c, e = expr.as_coeff_Mul()
-        if c < 0:
-            expr = _keep_coeff(-c, e)
-            sign = -1
-        else:
-            sign = 1
-
-        a = []  # items in the numerator
-        b = []  # items that are in the denominator (if any)
-
-        # Gather args for numerator/denominator
-        for item in expr.as_ordered_factors():
-            if item.is_commutative and item.is_Pow and item.exp.is_Rational and item.exp.is_negative:
-                if item.exp != -1:
-                    b.append(sp.Pow(item.base, -item.exp, evaluate=False))
-                else:
-                    b.append(sp.Pow(item.base, -item.exp))
-            else:
-                a.append(item)
-
-        a = a or [S.One]
-
-        a_str = [self._print(x) for x in a]
-        b_str = [self._print(x) for x in b]
-
-        result = a_str[0]
-        for item in a_str[1:]:
-            result = self.intrinsics['*'].format(result, item)
-
-        if len(b) > 0:
-            denominator_str = b_str[0]
-            for item in b_str[1:]:
-                denominator_str = self.intrinsics['*'].format(denominator_str, item)
-            result = self.intrinsics['/'].format(result, denominator_str)
-
-        if insideAdd:
-            return sign, result
-        else:
-            if sign < 0:
-                return self.intrinsics['*'].format(self._print(S.NegativeOne), result)
-            else:
-                return result
-
-    def _print_Pow(self, expr):
-        """Don't use std::pow function, for small integer exponents, write as multiplication"""
-        if expr.exp.is_integer and expr.exp.is_number and 0 < expr.exp < 8:
-            return self._print(sp.Mul(*[expr.base] * expr.exp, evaluate=False))
-        else:
-            return super(CustomSympyPrinterVectorized, self)._print_Pow(expr)
-
-    def _print_Float(self, expr):
-        if expr not in self.constantsDict:
-            self.constantsDict[expr] = sp.Dummy()
-        symbol = self.constantsDict[expr]
-        return symbol.name
-
-    def _print_Rational(self, expr):
-        if expr not in self.constantsDict:
-            self.constantsDict[expr] = sp.Symbol("__value_%d_%d" % (expr.p, expr.q))
-        symbol = self.constantsDict[expr]
-        return symbol.name
-
-    def _print_Piecewise(self, expr):
-        if expr.args[-1].cond != True:
-            # We need the last conditional to be a True, otherwise the resulting
-            # function may not return a result.
-            raise ValueError("All Piecewise expressions must contain an "
-                             "(expr, True) statement to be used as a default "
-                             "condition. Without one, the generated "
-                             "expression may not evaluate to anything under "
-                             "some condition.")
-
-        result = self._print(expr.args[-1][0])
-        for trueExpr, condition in reversed(expr.args[:-1]):
-            result = self.intrinsics['blendv'].format(result, self._print(trueExpr), self._print(condition))
-        return result
-
-    def _print_Relational(self, expr):
-        return self.intrinsics[expr.rel_op].format(expr.lhs, expr.rhs)
-
-    def _print_Equality(self, expr):
-        """Equality operator is not printable in default printer"""
-        return self.intrinsics['=='].format(self._print(expr.lhs), self._print(expr.rhs))
diff --git a/backends/simd_instruction_sets.py b/backends/simd_instruction_sets.py
new file mode 100644
index 000000000..213b4c481
--- /dev/null
+++ b/backends/simd_instruction_sets.py
@@ -0,0 +1,91 @@
+
+
+def x86VectorInstructionSet(dataType='double', instructionSet='avx'):
+    baseNames = {
+        '+': 'add[0, 1]',
+        '-': 'sub[0, 1]',
+        '*': 'mul[0, 1]',
+        '/': 'div[0, 1]',
+
+        '==': 'cmp[0, 1, _CMP_EQ_UQ  ]',
+        '!=': 'cmp[0, 1, _CMP_NEQ_UQ ]',
+        '>=': 'cmp[0, 1, _CMP_GE_OQ  ]',
+        '<=': 'cmp[0, 1, _CMP_LE_OQ  ]',
+        '<': 'cmp[0, 1, _CMP_NGE_UQ ]',
+        '>': 'cmp[0, 1, _CMP_NLE_UQ ]',
+
+        'blendv': 'blendv[0, 1, 2]',
+
+        'sqrt': 'sqrt[0]',
+
+        'makeVec':  'set[0,0,0,0]',
+        'makeZero': 'setzero[]',
+
+        'loadU': 'loadu[0]',
+        'loadA': 'load[0]',
+        'storeU': 'storeu[0,1]',
+        'storeA': 'store [0,1]',
+    }
+
+    headers = {
+        'avx': ['<immintrin.h>'],
+        'sse': ['<xmmintrin.h>', '<emmintrin.h>', '<pmmintrin.h>', '<tmmintrin.h>', '<smmintrin.h>', '<nmmintrin.h>']
+    }
+
+    suffix = {
+        'double': 'pd',
+        'float': 'ps',
+    }
+    prefix = {
+        'sse': '_mm',
+        'avx': '_mm256',
+        'avx512': '_mm512',
+    }
+
+    width = {
+        ("double", "sse"): 2,
+        ("float", "sse"): 4,
+        ("double", "avx"): 4,
+        ("float", "avx"): 8,
+        ("double", "avx512"): 8,
+        ("float", "avx512"): 16,
+    }
+
+    result = {}
+    pre = prefix[instructionSet]
+    suf = suffix[dataType]
+    for intrinsicId, functionShortcut in baseNames.items():
+        functionShortcut = functionShortcut.strip()
+        name = functionShortcut[:functionShortcut.index('[')]
+        args = functionShortcut[functionShortcut.index('[') + 1: -1]
+        argString = "("
+        for arg in args.split(","):
+            arg = arg.strip()
+            if not arg:
+                continue
+            if arg in ('0', '1', '2', '3', '4', '5'):
+                argString += "{" + arg + "},"
+            else:
+                argString += arg
+        argString = argString[:-1] + ")"
+        result[intrinsicId] = pre + "_" + name + "_" + suf + argString
+
+    result['width'] = width[(dataType, instructionSet)]
+    result['dataTypePrefix'] = {
+        'double': "_" + pre + 'd',
+        'float': "_" + pre,
+    }
+
+    bitWidth = result['width'] * 64
+    result['double'] = "__m%dd" % (bitWidth,)
+    result['float'] = "__m%d" % (bitWidth,)
+    result['int'] = "__m%di" % (bitWidth,)
+
+    result['headers'] = headers[instructionSet]
+    return result
+
+
+selectedInstructionSet = {
+    'float': x86VectorInstructionSet('float', 'avx'),
+    'double': x86VectorInstructionSet('double', 'avx'),
+}
diff --git a/cpu/cpujit.py b/cpu/cpujit.py
index f853fa51b..c5919fc6c 100644
--- a/cpu/cpujit.py
+++ b/cpu/cpujit.py
@@ -33,7 +33,7 @@ Then 'cl.exe' is used to compile.
                       where Visual Studio is installed. This path has to contain a file called 'vcvarsall.bat'
 - **'arch'**: 'x86' or 'x64'
 - **'flags'**: flags passed to 'cl.exe', make sure OpenMP is activated
-- **'restrictQualifier'**: the restrict qualifier is not standardized accross compilers.
+- **'restrictQualifier'**: the restrict qualifier is not standardized across compilers.
   For Windows compilers the qualifier should be ``__restrict``
 
 
@@ -70,7 +70,7 @@ import glob
 import atexit
 import shutil
 from ctypes import cdll
-from pystencils.backends.cbackend import generateC
+from pystencils.backends.cbackend import generateC, getHeaders
 from collections import OrderedDict, Mapping
 from pystencils.transformations import symbolNameToVariableName
 from pystencils.types import toCtypes, getBaseType, StructType
@@ -276,10 +276,13 @@ def compileObjectCacheToSharedLibrary():
 atexit.register(compileObjectCacheToSharedLibrary)
 
 
-def generateCode(ast, includes, restrictQualifier, functionPrefix, targetFile):
+def generateCode(ast, restrictQualifier, functionPrefix, targetFile):
+    headers = getHeaders(ast)
+    headers.update(['<cmath>', '<cstdint>'])
+
     with open(targetFile, 'w') as sourceFile:
         code = generateC(ast)
-        includes = "\n".join(["#include <%s>" % (includeFile,) for includeFile in includes])
+        includes = "\n".join(["#include %s" % (includeFile,) for includeFile in headers])
         print(includes, file=sourceFile)
         print("#define RESTRICT %s" % (restrictQualifier,), file=sourceFile)
         print("#define FUNC_PREFIX %s" % (functionPrefix,), file=sourceFile)
@@ -310,7 +313,7 @@ def compileLinux(ast, codeHashStr, srcFile, libFile):
     objectFile = os.path.join(cacheConfig['objectCache'], codeHashStr + '.o')
     # Compilation
     if not os.path.exists(objectFile):
-        generateCode(ast, ['iostream', 'cmath', 'cstdint'], compilerConfig['restrictQualifier'], '', srcFile)
+        generateCode(ast, compilerConfig['restrictQualifier'], '', srcFile)
         compileCmd = [compilerConfig['command'], '-c'] + compilerConfig['flags'].split()
         compileCmd += ['-o', objectFile, srcFile]
         runCompileStep(compileCmd)
@@ -326,7 +329,7 @@ def compileWindows(ast, codeHashStr, srcFile, libFile):
     objectFile = os.path.join(cacheConfig['objectCache'], codeHashStr + '.obj')
     # Compilation
     if not os.path.exists(objectFile):
-        generateCode(ast, ['iostream', 'cmath', 'cstdint'], compilerConfig['restrictQualifier'],
+        generateCode(ast, compilerConfig['restrictQualifier'],
                      '__declspec(dllexport)', srcFile)
 
         # /c compiles only, /EHsc turns of exception handling in c code
diff --git a/gpucuda/indexing.py b/gpucuda/indexing.py
index 8291ebf4a..9c7f1be6b 100644
--- a/gpucuda/indexing.py
+++ b/gpucuda/indexing.py
@@ -1,15 +1,15 @@
 import abc
 
 import sympy as sp
-import math
 import pycuda.driver as cuda
 import pycuda.autoinit
 
 from pystencils.astnodes import Conditional, Block
 from pystencils.slicing import normalizeSlice
+from pystencils.types import TypedSymbol, createTypeFromString
 
-BLOCK_IDX = list(sp.symbols("blockIdx.x blockIdx.y blockIdx.z"))
-THREAD_IDX = list(sp.symbols("threadIdx.x threadIdx.y threadIdx.z"))
+BLOCK_IDX = [TypedSymbol("blockIdx." + coord, createTypeFromString("int")) for coord in ('x', 'y', 'z')]
+THREAD_IDX = [TypedSymbol("threadIdx." + coord, createTypeFromString("int")) for coord in ('x', 'y', 'z')]
 
 
 class AbstractIndexing(abc.ABCMeta('ABC', (object,), {})):
diff --git a/transformations.py b/transformations.py
index 0975e96aa..38c6717b0 100644
--- a/transformations.py
+++ b/transformations.py
@@ -7,11 +7,18 @@ from sympy.logic.boolalg import Boolean
 from sympy.tensor import IndexedBase
 
 from pystencils.field import Field, offsetComponentToDirectionString
-from pystencils.types import TypedSymbol, createType, PointerType, StructType, getBaseType, createTypeFromString
+from pystencils.types import TypedSymbol, createType, PointerType, StructType, getBaseType, castFunc
 from pystencils.slicing import normalizeSlice
 import pystencils.astnodes as ast
 
 
+def filteredTreeIteration(node, nodeType):
+    for arg in node.args:
+        if isinstance(arg, nodeType):
+            yield arg
+        yield from filteredTreeIteration(arg, nodeType)
+
+
 def fastSubs(term, subsDict):
     """Similar to sympy subs function.
     This version is much faster for big substitution dictionaries than sympy version"""
@@ -332,9 +339,8 @@ def resolveFieldAccesses(astNode, readOnlyFieldNames=set(), fieldToBasePointerIn
 
             coordDict = createCoordinateDict(basePointerInfo[0])
             _, offset = createIntermediateBasePointer(fieldAccess, coordDict, lastPointer)
-            baseArr = IndexedBase(lastPointer, shape=(1,))
-            result = ast.ResolvedFieldAccess(baseArr, offset, fieldAccess.field, fieldAccess.offsets, fieldAccess.index)
-            castFunc = sp.Function("cast")
+            result = ast.ResolvedFieldAccess(lastPointer, offset, fieldAccess.field, fieldAccess.offsets, fieldAccess.index)
+
             if isinstance(getBaseType(fieldAccess.field.dtype), StructType):
                 newType = fieldAccess.field.dtype.getElementType(fieldAccess.index[0])
                 result = castFunc(result, newType)
diff --git a/types.py b/types.py
index 86ad051b3..23c487501 100644
--- a/types.py
+++ b/types.py
@@ -2,7 +2,11 @@ import ctypes
 import sympy as sp
 import numpy as np
 from sympy.core.cache import cacheit
+
 from pystencils.cache import memorycache
+from pystencils.utils import allEqual
+
+castFunc = sp.Function("cast")
 
 
 class TypedSymbol(sp.Symbol):
@@ -28,7 +32,7 @@ class TypedSymbol(sp.Symbol):
 
     def _hashable_content(self):
         superClassContents = list(super(TypedSymbol, self)._hashable_content())
-        return tuple(superClassContents + [hash(repr(self._dtype))])
+        return tuple(superClassContents + [hash(str(self._dtype))])
 
     def __getnewargs__(self):
         return self.name, self.dtype
@@ -52,6 +56,7 @@ def createType(specification):
             return StructType(npDataType, const=False)
 
 
+@memorycache(maxsize=64)
 def createTypeFromString(specification):
     """
     Creates a new Type object from a c-like string specification
@@ -131,10 +136,79 @@ toCtypes.map = {
 }
 
 
+def peelOffType(dtype, typeToPeelOff):
+    while type(dtype) is typeToPeelOff:
+        dtype = dtype.baseType
+    return dtype
+
+
+def collateTypes(types):
+    """
+    Takes a sequence of types and returns their "common type" e.g. (float, double, float) -> double
+    Uses the collation rules from numpy.
+    """
+
+    # Pointer arithmetic case i.e. pointer + integer is allowed
+    if any(type(t) is PointerType for t in types):
+        pointerType = None
+        for t in types:
+            if type(t) is PointerType:
+                if pointerType is not None:
+                    raise ValueError("Cannot collate the combination of two pointer types")
+                pointerType = t
+            elif type(t) is BasicType:
+                if not (t.is_int() or t.is_uint()):
+                    raise ValueError("Invalid pointer arithmetic")
+            else:
+                raise ValueError("Invalid pointer arithmetic")
+        return pointerType
+
+    # peel of vector types, if at least one vector type occurred the result will also be the vector type
+    vectorType = [t for t in types if type(t) is VectorType]
+    if not allEqual(t.width for t in vectorType):
+        raise ValueError("Collation failed because of vector types with different width")
+    types = [peelOffType(t, VectorType) for t in types]
+
+    # now we should have a list of basic types - struct types are not yet supported
+    assert all(type(t) is BasicType for t in types)
+
+    # use numpy collation -> create type from numpy type -> and, put vector type around if necessary
+    resultNumpyType = np.result_type(*(t.numpyDtype for t in types))
+    result = BasicType(resultNumpyType)
+    if vectorType:
+        result = VectorType(result, vectorType[0].width)
+    return result
+
+
+@memorycache(maxsize=2048)
 def getTypeOfExpression(expr):
-    if isinstance(expr, TypedSymbol):
+    from pystencils.astnodes import ResolvedFieldAccess
+    expr = sp.sympify(expr)
+    if isinstance(expr, sp.Integer):
+        return createTypeFromString("int")
+    elif isinstance(expr, sp.Rational) or isinstance(expr, sp.Float):
+        return createTypeFromString("double")
+    elif isinstance(expr, ResolvedFieldAccess):
+        return expr.field.dtype
+    elif isinstance(expr, TypedSymbol):
         return expr.dtype
-
+    elif isinstance(expr, sp.Symbol):
+        raise ValueError("All symbols inside this expression have to be typed!")
+    elif hasattr(expr, 'func') and expr.func == castFunc:
+        return expr.args[1]
+    elif hasattr(expr, 'func') and expr.func == sp.Piecewise:
+        branchResults = [a[0] for a in expr.args]
+        return collateTypes(tuple(getTypeOfExpression(a) for a in branchResults))
+    elif isinstance(expr, sp.Indexed):
+        typedSymbol = expr.base.label
+        return typedSymbol.dtype
+    elif isinstance(expr, sp.Expr):
+        types = tuple(getTypeOfExpression(a) for a in expr.args)
+        return collateTypes(types)
+    elif isinstance(expr, sp.boolalg.Boolean):
+        return createTypeFromString("bool")
+
+    raise NotImplementedError("Could not determine type for " + str(expr))
 
 
 class Type(sp.Basic):
@@ -239,6 +313,44 @@ class BasicType(Type):
         return hash(str(self))
 
 
+class VectorType(Type):
+    instructionSet = None
+
+    def __init__(self, baseType, width=4):
+        self._baseType = baseType
+        self.width = width
+
+    @property
+    def baseType(self):
+        return self._baseType
+
+    @property
+    def itemSize(self):
+        return self.width * self.baseType.itemSize
+
+    def __eq__(self, other):
+        if not isinstance(other, VectorType):
+            return False
+        else:
+            return (self.baseType, self.width) == (other.baseType, other.width)
+
+    def __str__(self):
+        if self.instructionSet is None:
+            return "%s[%d]" % (self.baseType, self.width)
+        else:
+            if self.baseType == createTypeFromString("int64"):
+                return self.instructionSet['int']
+            elif self.baseType == createTypeFromString("double"):
+                return self.instructionSet['double']
+            elif self.baseType == createTypeFromString("float"):
+                return self.instructionSet['float']
+            else:
+                raise NotImplementedError()
+
+    def __hash__(self):
+        return hash(str(self))
+
+
 class PointerType(Type):
     def __init__(self, baseType, const=False, restrict=True):
         self._baseType = baseType
diff --git a/utils.py b/utils.py
index a48a963fa..2ea6b7dd9 100644
--- a/utils.py
+++ b/utils.py
@@ -4,3 +4,12 @@ class DotDict(dict):
     __getattr__ = dict.get
     __setattr__ = dict.__setitem__
     __delattr__ = dict.__delitem__
+
+
+def allEqual(iterator):
+    iterator = iter(iterator)
+    try:
+        first = next(iterator)
+    except StopIteration:
+        return True
+    return all(first == rest for rest in iterator)
diff --git a/vectorization.py b/vectorization.py
new file mode 100644
index 000000000..54a9819f7
--- /dev/null
+++ b/vectorization.py
@@ -0,0 +1,94 @@
+import sympy as sp
+import warnings
+
+from pystencils.transformations import filteredTreeIteration
+from pystencils.types import TypedSymbol, VectorType, PointerType, BasicType, getTypeOfExpression, castFunc
+import pystencils.astnodes as ast
+from pystencils.utils import allEqual
+
+
+def asVectorType(resolvedFieldAccess, vectorizationWidth):
+    """Returns a new ResolvedFieldAccess that has a vector type"""
+    dtype = resolvedFieldAccess.typedSymbol.dtype
+    assert type(dtype) is PointerType
+    basicType = dtype.baseType
+    assert type(basicType) is BasicType, "Structs are not supported"
+
+    newDtype = VectorType(basicType, vectorizationWidth)
+    newDtype = PointerType(newDtype, dtype.const, dtype.restrict)
+    newTypedSymbol = TypedSymbol(resolvedFieldAccess.typedSymbol.name, newDtype)
+    return ast.ResolvedFieldAccess(newTypedSymbol, resolvedFieldAccess.args[1], resolvedFieldAccess.field,
+                                   resolvedFieldAccess.offsets, resolvedFieldAccess.idxCoordinateValues)
+
+
+def vectorize(astNode, vectorWidth=4):
+    """
+    Goes over all innermost loops, changes increment to vector width and replaces field accesses by vector type if
+        - loop bounds are constant
+        - loop range is a multiple of vector width
+    """
+    innerLoops = [n for n in astNode.atoms(ast.LoopOverCoordinate) if n.isInnermostLoop]
+
+    for loopNode in innerLoops:
+        loopRange = loopNode.stop - loopNode.start
+
+        # Check restrictions
+        if isinstance(loopRange, sp.Basic) and not loopRange.is_integer:
+            warnings.warn("Currently only loops with fixed ranges can be vectorized - skipping loop")
+            continue
+        if loopRange % vectorWidth != 0 or loopNode.step != 1:
+            warnings.warn("Currently only loops with loop bounds that are multiples "
+                          "of vectorization width can be vectorized")
+            continue
+
+        loopNode.step = vectorWidth
+
+        # All field accesses depending on loop coordinate are changed to vector type
+        fieldAccesses = [n for n in loopNode.atoms(ast.ResolvedFieldAccess)]
+        substitutions = {fa: castFunc(fa, VectorType(BasicType(fa.field.dtype), vectorWidth)) for fa in fieldAccesses}
+        loopNode.subs(substitutions)
+
+
+def insertVectorCasts(astNode):
+    """
+    Inserts necessary casts from scalar values to vector values
+    """
+    def visitExpr(expr):
+        if expr.func in (sp.Add, sp.Mul):
+            newArgs = [visitExpr(a) for a in expr.args]
+            argTypes = [getTypeOfExpression(a) for a in newArgs]
+            if not any(type(t) is VectorType for t in argTypes):
+                return expr
+            else:
+                vectorWidths = [d.width for d in argTypes if type(d) is VectorType]
+                assert allEqual(vectorWidths), "Incompatible vector type widths"
+                vectorWidth = vectorWidths[0]
+                castedArgs = [castFunc(a, VectorType(t, vectorWidth)) if type(t) is not VectorType else a
+                              for a, t in zip(newArgs, argTypes)]
+                return expr.func(*castedArgs)
+        elif expr.func == sp.Piecewise:
+            raise NotImplementedError()
+        else:
+            return expr
+
+    substitutionDict = {}
+    for asmt in filteredTreeIteration(astNode, ast.SympyAssignment):
+        subsExpr = asmt.rhs.subs(substitutionDict)
+        asmt.rhs = visitExpr(subsExpr)
+        rhsType = getTypeOfExpression(asmt.rhs)
+        if isinstance(asmt.lhs, TypedSymbol):
+            lhsType = asmt.lhs.dtype
+            if type(rhsType) is VectorType and type(lhsType) is not VectorType:
+                newLhsType = VectorType(lhsType, rhsType.width)
+                newLhs = TypedSymbol(asmt.lhs.name, newLhsType)
+                substitutionDict[asmt.lhs] = newLhs
+                asmt.lhs = newLhs
+        elif asmt.lhs.func == castFunc:
+            lhsType = asmt.lhs.args[1]
+            if type(lhsType) is VectorType and type(rhsType) is not VectorType:
+                asmt.rhs = castFunc(asmt.rhs, lhsType)
+
+
+
+
+
-- 
GitLab