From c378ca19e00b7f8dd7f68d885f02bb666d3ddef8 Mon Sep 17 00:00:00 2001 From: Martin Bauer <martin.bauer@fau.de> Date: Mon, 14 May 2018 17:44:50 +0200 Subject: [PATCH] Fixes in vectorization to also support float kernels --- backends/cbackend.py | 2 +- cpu/vectorization.py | 2 +- data_types.py | 7 +++---- 3 files changed, 5 insertions(+), 6 deletions(-) diff --git a/backends/cbackend.py b/backends/cbackend.py index 00da70f27..cfa8a4a94 100644 --- a/backends/cbackend.py +++ b/backends/cbackend.py @@ -262,7 +262,7 @@ class CustomSympyPrinter(CCodePrinter): def _typed_number(self, number, dtype): res = self._print(number) - if dtype.is_float: + if dtype.is_float(): if dtype == self._float_type: if '.' not in res: res += ".0f" diff --git a/cpu/vectorization.py b/cpu/vectorization.py index a8542d0af..96486b3ea 100644 --- a/cpu/vectorization.py +++ b/cpu/vectorization.py @@ -35,7 +35,7 @@ def vectorize(kernel_ast: ast.KernelFunction, instruction_set: str = 'avx', elif nontemporal is True: nontemporal = all_fields - field_float_dtypes = set(f.dtype for f in all_fields if f.dtype.is_float) + field_float_dtypes = set(f.dtype for f in all_fields if f.dtype.is_float()) if len(field_float_dtypes) != 1: raise NotImplementedError("Cannot vectorize kernels that contain accesses " "to differently typed floating point fields") diff --git a/data_types.py b/data_types.py index d3dad765d..4782234d5 100644 --- a/data_types.py +++ b/data_types.py @@ -276,6 +276,8 @@ def collate_types(types): # now we should have a list of basic types - struct types are not yet supported assert all(type(t) is BasicType for t in types) + if any(t.is_float() for t in types): + types = tuple(t for t in types if t.is_float()) # use numpy collation -> create type from numpy type -> and, put vector type around if necessary result_numpy_type = np.result_type(*(t.numpy_dtype for t in types)) result = BasicType(result_numpy_type) @@ -289,10 +291,7 @@ def get_type_of_expression(expr): from pystencils.astnodes import ResolvedFieldAccess expr = sp.sympify(expr) if isinstance(expr, sp.Integer): - if expr == 1 or expr == -1: - return create_type("int16") - else: - return create_type("int") + return create_type("int") elif isinstance(expr, sp.Rational) or isinstance(expr, sp.Float): return create_type("double") elif isinstance(expr, ResolvedFieldAccess): -- GitLab