diff --git a/pystencils/data_types.py b/pystencils/data_types.py index 6c2ab0102c8f25fe277cf16d674b87655e03212d..ccc59b60ab1b015af6c51535ce4a66a79baf9490 100644 --- a/pystencils/data_types.py +++ b/pystencils/data_types.py @@ -4,14 +4,14 @@ from functools import partial from typing import Tuple import numpy as np -import sympy as sp -import sympy.codegen.ast -from sympy.core.cache import cacheit -from sympy.logic.boolalg import Boolean import pystencils +import sympy as sp +import sympy.codegen.ast from pystencils.cache import memorycache, memorycache_if_hashable from pystencils.utils import all_equal +from sympy.core.cache import cacheit +from sympy.logic.boolalg import Boolean try: import llvmlite.ir as ir @@ -432,7 +432,9 @@ def peel_off_type(dtype, type_to_peel_off): def collate_types(types, forbid_collation_to_complex=False, - forbid_collation_to_float=False): + forbid_collation_to_float=False, + default_float_type='float64', + default_int_type='int64'): """ Takes a sequence of types and returns their "common type" e.g. (float, double, float) -> double Uses the collation rules from numpy. @@ -443,14 +445,14 @@ def collate_types(types, if not np.issubdtype(t.numpy_dtype, np.complexfloating) ] if not types: - return create_type(np.float64) + return create_type(default_float_type) if forbid_collation_to_float: types = [ t for t in types if not np.issubdtype(t.numpy_dtype, np.floating) ] if not types: - return create_type(np.int32) + return create_type(default_int_type) # Pointer arithmetic case i.e. pointer + integer is allowed if any(type(t) is PointerType for t in types): @@ -549,7 +551,9 @@ def get_type_of_expression(expr, return collate_types( types, forbid_collation_to_complex=expr.is_real is True, - forbid_collation_to_float=expr.is_integer is True) + forbid_collation_to_float=expr.is_integer is True, + default_float_type=default_float_type, + default_int_type=default_int_type) else: if expr.is_integer: return create_type(default_int_type)