Commit c378ca19 authored by Martin Bauer's avatar Martin Bauer
Browse files

Fixes in vectorization to also support float kernels

parent 27cf4f19
......@@ -262,7 +262,7 @@ class CustomSympyPrinter(CCodePrinter):
def _typed_number(self, number, dtype):
res = self._print(number)
if dtype.is_float:
if dtype.is_float():
if dtype == self._float_type:
if '.' not in res:
res += ".0f"
......@@ -35,7 +35,7 @@ def vectorize(kernel_ast: ast.KernelFunction, instruction_set: str = 'avx',
elif nontemporal is True:
nontemporal = all_fields
field_float_dtypes = set(f.dtype for f in all_fields if f.dtype.is_float)
field_float_dtypes = set(f.dtype for f in all_fields if f.dtype.is_float())
if len(field_float_dtypes) != 1:
raise NotImplementedError("Cannot vectorize kernels that contain accesses "
"to differently typed floating point fields")
......@@ -276,6 +276,8 @@ def collate_types(types):
# now we should have a list of basic types - struct types are not yet supported
assert all(type(t) is BasicType for t in types)
if any(t.is_float() for t in types):
types = tuple(t for t in types if t.is_float())
# use numpy collation -> create type from numpy type -> and, put vector type around if necessary
result_numpy_type = np.result_type(*(t.numpy_dtype for t in types))
result = BasicType(result_numpy_type)
......@@ -289,10 +291,7 @@ def get_type_of_expression(expr):
from pystencils.astnodes import ResolvedFieldAccess
expr = sp.sympify(expr)
if isinstance(expr, sp.Integer):
if expr == 1 or expr == -1:
return create_type("int16")
return create_type("int")
return create_type("int")
elif isinstance(expr, sp.Rational) or isinstance(expr, sp.Float):
return create_type("double")
elif isinstance(expr, ResolvedFieldAccess):
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment