diff --git a/pystencils/data_types.py b/pystencils/data_types.py index 23dcf4f1260ac75bf6c01afd037fd3cdda51f545..86ce1747d2b83cea80d1974121dbe79959c0ba8c 100644 --- a/pystencils/data_types.py +++ b/pystencils/data_types.py @@ -378,15 +378,15 @@ def collate_types(types): @memorycache(maxsize=2048) -def get_type_of_expression(expr): +def get_type_of_expression(expr, default_float_type='double', default_int_type='int'): from pystencils.astnodes import ResolvedFieldAccess from pystencils.cpu.vectorization import vec_all, vec_any expr = sp.sympify(expr) if isinstance(expr, sp.Integer): - return create_type("int") + return create_type(default_int_type) elif isinstance(expr, sp.Rational) or isinstance(expr, sp.Float): - return create_type("double") + return create_type(default_float_type) elif isinstance(expr, ResolvedFieldAccess): return expr.field.dtype elif isinstance(expr, TypedSymbol): @@ -416,8 +416,15 @@ def get_type_of_expression(expr): elif isinstance(expr, sp.Pow): return get_type_of_expression(expr.args[0]) elif isinstance(expr, sp.Expr): - types = tuple(get_type_of_expression(a) for a in expr.args) - return collate_types(types) + expr: sp.Expr + if expr.args: + types = tuple(get_type_of_expression(a) for a in expr.args) + return collate_types(types) + else: + if expr.is_integer: + return create_type(default_int_type) + else: + return create_type(default_float_type) raise NotImplementedError("Could not determine type for", expr, type(expr)) diff --git a/pystencils_tests/test_types.py b/pystencils_tests/test_types.py index 4b28c5a0b9bfe4af94d8368b2d06bcc513aface9..887f802c91eb3a82ebd8ea43f6fbc17d18d18cef 100644 --- a/pystencils_tests/test_types.py +++ b/pystencils_tests/test_types.py @@ -1,5 +1,7 @@ from pystencils import data_types from pystencils.data_types import * +import sympy as sp + def test_parsing(): @@ -19,3 +21,16 @@ def test_collation(): assert collate_types([double_type, float_type]) == double_type assert collate_types([double4_type, float_type]) == double4_type assert collate_types([double4_type, float4_type]) == double4_type + +def test_dtype_of_constants(): + + # Some come constants are neither of type Integer,Float,Rational and don't have args + # >>> isinstance(pi, Integer) + # False + # >>> isinstance(pi, Float) + # False + # >>> isinstance(pi, Rational) + # False + # >>> pi.args + # () + get_type_of_expression(sp.pi)