From c378ca19e00b7f8dd7f68d885f02bb666d3ddef8 Mon Sep 17 00:00:00 2001
From: Martin Bauer <martin.bauer@fau.de>
Date: Mon, 14 May 2018 17:44:50 +0200
Subject: [PATCH] Fixes in vectorization to also support float kernels

---
 backends/cbackend.py | 2 +-
 cpu/vectorization.py | 2 +-
 data_types.py        | 7 +++----
 3 files changed, 5 insertions(+), 6 deletions(-)

diff --git a/backends/cbackend.py b/backends/cbackend.py
index 00da70f27..cfa8a4a94 100644
--- a/backends/cbackend.py
+++ b/backends/cbackend.py
@@ -262,7 +262,7 @@ class CustomSympyPrinter(CCodePrinter):
 
     def _typed_number(self, number, dtype):
         res = self._print(number)
-        if dtype.is_float:
+        if dtype.is_float():
             if dtype == self._float_type:
                 if '.' not in res:
                     res += ".0f"
diff --git a/cpu/vectorization.py b/cpu/vectorization.py
index a8542d0af..96486b3ea 100644
--- a/cpu/vectorization.py
+++ b/cpu/vectorization.py
@@ -35,7 +35,7 @@ def vectorize(kernel_ast: ast.KernelFunction, instruction_set: str = 'avx',
     elif nontemporal is True:
         nontemporal = all_fields
 
-    field_float_dtypes = set(f.dtype for f in all_fields if f.dtype.is_float)
+    field_float_dtypes = set(f.dtype for f in all_fields if f.dtype.is_float())
     if len(field_float_dtypes) != 1:
         raise NotImplementedError("Cannot vectorize kernels that contain accesses "
                                   "to differently typed floating point fields")
diff --git a/data_types.py b/data_types.py
index d3dad765d..4782234d5 100644
--- a/data_types.py
+++ b/data_types.py
@@ -276,6 +276,8 @@ def collate_types(types):
     # now we should have a list of basic types - struct types are not yet supported
     assert all(type(t) is BasicType for t in types)
 
+    if any(t.is_float() for t in types):
+        types = tuple(t for t in types if t.is_float())
     # use numpy collation -> create type from numpy type -> and, put vector type around if necessary
     result_numpy_type = np.result_type(*(t.numpy_dtype for t in types))
     result = BasicType(result_numpy_type)
@@ -289,10 +291,7 @@ def get_type_of_expression(expr):
     from pystencils.astnodes import ResolvedFieldAccess
     expr = sp.sympify(expr)
     if isinstance(expr, sp.Integer):
-        if expr == 1 or expr == -1:
-            return create_type("int16")
-        else:
-            return create_type("int")
+        return create_type("int")
     elif isinstance(expr, sp.Rational) or isinstance(expr, sp.Float):
         return create_type("double")
     elif isinstance(expr, ResolvedFieldAccess):
-- 
GitLab