From 866e9fc0019e733191c37533ae46ae0f3d366c29 Mon Sep 17 00:00:00 2001
From: Martin Bauer <martin.bauer@fau.de>
Date: Mon, 14 May 2018 15:51:38 +0200
Subject: [PATCH] Fixes in vectorization to also support float kernels

---
 backends/cbackend.py              | 25 +++++++++++++--------
 backends/simd_instruction_sets.py | 36 ++++++++++++++++++-------------
 cpu/vectorization.py              |  6 +++---
 data_types.py                     |  7 +++++-
 kernelcreation.py                 |  2 +-
 llvm/llvm.py                      |  7 ++++++
 transformations.py                | 26 ++++++++++++++++------
 7 files changed, 73 insertions(+), 36 deletions(-)

diff --git a/backends/cbackend.py b/backends/cbackend.py
index 188679d47..76eb3ee74 100644
--- a/backends/cbackend.py
+++ b/backends/cbackend.py
@@ -213,6 +213,7 @@ class CustomSympyPrinter(CCodePrinter):
     def __init__(self, constants_as_floats=False):
         self._constantsAsFloats = constants_as_floats
         super(CustomSympyPrinter, self).__init__()
+        self._float_type = create_type("float32")
 
     def _print_Pow(self, expr):
         """Don't use std::pow function, for small integer exponents, write as multiplication"""
@@ -224,8 +225,6 @@ class CustomSympyPrinter(CCodePrinter):
     def _print_Rational(self, expr):
         """Evaluate all rationals i.e. print 0.25 instead of 1.0/4.0"""
         res = str(expr.evalf().num)
-        if self._constantsAsFloats:
-            res += "f"
         return res
 
     def _print_Equality(self, expr):
@@ -237,12 +236,6 @@ class CustomSympyPrinter(CCodePrinter):
         result = super(CustomSympyPrinter, self)._print_Piecewise(expr)
         return result.replace("\n", "")
 
-    def _print_Float(self, expr):
-        res = str(expr)
-        if self._constantsAsFloats:
-            res += "f"
-        return res
-
     def _print_Function(self, expr):
         function_map = {
             bitwise_xor: '^',
@@ -255,7 +248,10 @@ class CustomSympyPrinter(CCodePrinter):
             return expr.to_c(self._print)
         if expr.func == cast_func:
             arg, data_type = expr.args
-            return "*((%s)(& %s))" % (PointerType(data_type), self._print(arg))
+            if isinstance(arg, sp.Number):
+                return self._typed_number(arg, data_type)
+            else:
+                return "*((%s)(& %s))" % (PointerType(data_type), self._print(arg))
         elif expr.func == modulo_floor:
             assert all(get_type_of_expression(e).is_int() for e in expr.args)
             return "({dtype})({0} / {1}) * {1}".format(*expr.args, dtype=get_type_of_expression(expr.args[0]))
@@ -264,6 +260,17 @@ class CustomSympyPrinter(CCodePrinter):
         else:
             return super(CustomSympyPrinter, self)._print_Function(expr)
 
+    def _typed_number(self, number, dtype):
+        res = self._print(number)
+        if dtype.is_float:
+            if dtype == self._float_type:
+                if '.' not in res:
+                    res += ".0f"
+                else:
+                    res += "f"
+            return res
+        else:
+            return res
 
 # noinspection PyPep8Naming
 class VectorizedCustomSympyPrinter(CustomSympyPrinter):
diff --git a/backends/simd_instruction_sets.py b/backends/simd_instruction_sets.py
index 518e6a59a..28e1cc8b2 100644
--- a/backends/simd_instruction_sets.py
+++ b/backends/simd_instruction_sets.py
@@ -20,7 +20,7 @@ def get_vector_instruction_set(data_type='double', instruction_set='avx'):
 
         'sqrt': 'sqrt[0]',
 
-        'makeVec': 'set[0,0,0,0]',
+        'makeVec': 'set[]',
         'makeZero': 'setzero[]',
 
         'loadU': 'loadu[0]',
@@ -31,6 +31,7 @@ def get_vector_instruction_set(data_type='double', instruction_set='avx'):
     }
 
     headers = {
+        'avx512': ['<immintrin.h>'],
         'avx': ['<immintrin.h>'],
         'sse': ['<xmmintrin.h>', '<emmintrin.h>', '<pmmintrin.h>', '<tmmintrin.h>', '<smmintrin.h>', '<nmmintrin.h>']
     }
@@ -54,32 +55,37 @@ def get_vector_instruction_set(data_type='double', instruction_set='avx'):
         ("float", "avx512"): 16,
     }
 
-    result = {}
+    result = {
+        'width': width[(data_type, instruction_set)],
+    }
     pre = prefix[instruction_set]
     suf = suffix[data_type]
     for intrinsic_id, function_shortcut in base_names.items():
         function_shortcut = function_shortcut.strip()
         name = function_shortcut[:function_shortcut.index('[')]
-        args = function_shortcut[function_shortcut.index('[') + 1: -1]
-        arg_string = "("
-        for arg in args.split(","):
-            arg = arg.strip()
-            if not arg:
-                continue
-            if arg in ('0', '1', '2', '3', '4', '5'):
-                arg_string += "{" + arg + "},"
-            else:
-                arg_string += arg + ","
-        arg_string = arg_string[:-1] + ")"
+
+        if intrinsic_id == 'makeVec':
+            arg_string = "({})".format(",".join(["{0}"] * result['width']))
+        else:
+            args = function_shortcut[function_shortcut.index('[') + 1: -1]
+            arg_string = "("
+            for arg in args.split(","):
+                arg = arg.strip()
+                if not arg:
+                    continue
+                if arg in ('0', '1', '2', '3', '4', '5'):
+                    arg_string += "{" + arg + "},"
+                else:
+                    arg_string += arg + ","
+            arg_string = arg_string[:-1] + ")"
         result[intrinsic_id] = pre + "_" + name + "_" + suf + arg_string
 
-    result['width'] = width[(data_type, instruction_set)]
     result['dataTypePrefix'] = {
         'double': "_" + pre + 'd',
         'float': "_" + pre,
     }
 
-    bit_width = result['width'] * 64
+    bit_width = result['width'] * (64 if data_type == 'double' else 32)
     result['double'] = "__m%dd" % (bit_width,)
     result['float'] = "__m%d" % (bit_width,)
     result['int'] = "__m%di" % (bit_width,)
diff --git a/cpu/vectorization.py b/cpu/vectorization.py
index 3745bb31c..a8542d0af 100644
--- a/cpu/vectorization.py
+++ b/cpu/vectorization.py
@@ -13,13 +13,13 @@ from pystencils.transformations import cut_loop, filtered_tree_iteration
 from pystencils.field import Field
 
 
-def vectorize(kernel_ast: ast.KernelFunction, vector_instruction_set: str = 'avx',
+def vectorize(kernel_ast: ast.KernelFunction, instruction_set: str = 'avx',
               assume_aligned: bool = False, nontemporal: Union[bool, Container[Union[str, Field]]] = False):
     """Explicit vectorization using SIMD vectorization via intrinsics.
 
     Args:
         kernel_ast: abstract syntax tree (KernelFunction node)
-        vector_instruction_set: one of the supported vector instruction sets, currently ('sse', 'avx' and 'avx512')
+        instruction_set: one of the supported vector instruction sets, currently ('sse', 'avx' and 'avx512')
         assume_aligned: assume that the first inner cell of each line is aligned. If false, only unaligned-loads are
                         used. If true, some of the loads are assumed to be from aligned memory addresses.
                         For example if x is the fastest coordinate, the access to center can be fetched via an
@@ -42,7 +42,7 @@ def vectorize(kernel_ast: ast.KernelFunction, vector_instruction_set: str = 'avx
     float_size = field_float_dtypes.pop().numpy_dtype.itemsize
     assert float_size in (8, 4)
     vector_is = get_vector_instruction_set('double' if float_size == 8 else 'float',
-                                           instruction_set=vector_instruction_set)
+                                           instruction_set=instruction_set)
     vector_width = vector_is['width']
     kernel_ast.instruction_set = vector_is
 
diff --git a/data_types.py b/data_types.py
index 318a24752..d3dad765d 100644
--- a/data_types.py
+++ b/data_types.py
@@ -289,7 +289,10 @@ def get_type_of_expression(expr):
     from pystencils.astnodes import ResolvedFieldAccess
     expr = sp.sympify(expr)
     if isinstance(expr, sp.Integer):
-        return create_type("int")
+        if expr == 1 or expr == -1:
+            return create_type("int16")
+        else:
+            return create_type("int")
     elif isinstance(expr, sp.Rational) or isinstance(expr, sp.Float):
         return create_type("double")
     elif isinstance(expr, ResolvedFieldAccess):
@@ -316,6 +319,8 @@ def get_type_of_expression(expr):
         if vec_args:
             result = VectorType(result, width=vec_args[0].width)
         return result
+    elif isinstance(expr, sp.Pow):
+        return get_type_of_expression(expr.args[0])
     elif isinstance(expr, sp.Expr):
         types = tuple(get_type_of_expression(a) for a in expr.args)
         return collate_types(types)
diff --git a/kernelcreation.py b/kernelcreation.py
index f79865f2e..0a948f195 100644
--- a/kernelcreation.py
+++ b/kernelcreation.py
@@ -73,7 +73,7 @@ def create_kernel(assignments, target='cpu', data_type="double", iteration_slice
             add_openmp(ast, num_threads=cpu_openmp)
         if cpu_vectorize_info:
             if cpu_vectorize_info is True:
-                vectorize(ast, vector_instruction_set='avx', assume_aligned=False, nontemporal=None)
+                vectorize(ast, instruction_set='avx', assume_aligned=False, nontemporal=None)
             elif isinstance(cpu_vectorize_info, dict):
                 vectorize(ast, **cpu_vectorize_info)
             else:
diff --git a/llvm/llvm.py b/llvm/llvm.py
index b6a0f5895..1b165a2f9 100644
--- a/llvm/llvm.py
+++ b/llvm/llvm.py
@@ -205,10 +205,17 @@ class LLVMPrinter(Printer):
         node = self._print(conversion.args[0])
         to_dtype = get_type_of_expression(conversion)
         from_dtype = get_type_of_expression(conversion.args[0])
+        if from_dtype == to_dtype:
+            return self._print(conversion.args[0])
+
         # (From, to)
         decision = {
+            (create_composite_type_from_string("int16"),
+             create_composite_type_from_string("int64")): lambda: ir.Constant(self.integer, node),
             (create_composite_type_from_string("int"),
              create_composite_type_from_string("double")): functools.partial(self.builder.sitofp, node, self.fp_type),
+            (create_composite_type_from_string("int16"),
+             create_composite_type_from_string("double")): functools.partial(self.builder.sitofp, node, self.fp_type),
             (create_composite_type_from_string("double"),
              create_composite_type_from_string("int")): functools.partial(self.builder.fptosi, node, self.integer),
             (create_composite_type_from_string("double *"),
diff --git a/transformations.py b/transformations.py
index b2c6f5e27..ab7d98ab5 100644
--- a/transformations.py
+++ b/transformations.py
@@ -8,7 +8,7 @@ from sympy.tensor import IndexedBase
 from pystencils.assignment import Assignment
 from pystencils.field import Field, FieldType
 from pystencils.data_types import TypedSymbol, PointerType, StructType, get_base_type, cast_func, \
-    pointer_arithmetic_func, get_type_of_expression, collate_types
+    pointer_arithmetic_func, get_type_of_expression, collate_types, create_type
 from pystencils.slicing import normalize_slice
 import pystencils.astnodes as ast
 
@@ -716,9 +716,18 @@ class KernelConstraintsCheck:
             return rhs
         elif isinstance(rhs, sp.Symbol):
             return TypedSymbol(symbol_name_to_variable_name(rhs.name), self._type_for_symbol[rhs.name])
-        else:
-            new_args = [self.process_expression(arg) for arg in rhs.args]
+        elif isinstance(rhs, sp.Number):
+            return cast_func(rhs, create_type(self._type_for_symbol['_constant']))
+        elif isinstance(rhs, sp.Mul):
+            new_args = [self.process_expression(arg) if arg not in (-1, 1) else arg for arg in rhs.args]
             return rhs.func(*new_args) if new_args else rhs
+        else:
+            if isinstance(rhs, sp.Pow):
+                # don't process exponents -> they should remain integers
+                return sp.Pow(self.process_expression(rhs.args[0]), rhs.args[1])
+            else:
+                new_args = [self.process_expression(arg) for arg in rhs.args]
+                return rhs.func(*new_args) if new_args else rhs
 
     @property
     def fields_written(self):
@@ -800,10 +809,13 @@ def add_types(eqs, type_for_symbol, check_independence_condition):
 
 
 def insert_casts(node):
-    """Checks the types and inserts casts and pointer arithmetic where necessary
+    """Checks the types and inserts casts and pointer arithmetic where necessary.
 
-    :param node: the head node of the ast
-    :return: modified ast
+    Args:
+        node: the head node of the ast
+
+    Returns:
+        modified AST
     """
     def cast(zipped_args_types, target_dtype):
         """
@@ -839,7 +851,7 @@ def insert_casts(node):
         new_args = sp.Add(*new_args) if len(new_args) > 0 else new_args
         return pointer_arithmetic_func(pointer, new_args)
 
-    if isinstance(node, sp.AtomicExpr):
+    if isinstance(node, sp.AtomicExpr) or isinstance(node, cast_func):
         return node
     args = []
     for arg in node.args:
-- 
GitLab