From b87eeadfa71011e33fc465bc938fdcac6f66f576 Mon Sep 17 00:00:00 2001 From: Stephan Seitz <stephan.seitz@fau.de> Date: Sun, 18 Aug 2019 14:37:15 +0200 Subject: [PATCH] Fix get_type_of_expression for constants like sympy.pi Problem some constant expressions are neither Float,Integer,Rational but don't have arguments. >>> isinstance(pi, Integer) False >>> isinstance(pi, Float) False >>> isinstance(pi, Rational) False >>> pi.args () --- pystencils/data_types.py | 17 ++++++++++++----- pystencils_tests/test_types.py | 15 +++++++++++++++ 2 files changed, 27 insertions(+), 5 deletions(-) diff --git a/pystencils/data_types.py b/pystencils/data_types.py index 23dcf4f12..86ce1747d 100644 --- a/pystencils/data_types.py +++ b/pystencils/data_types.py @@ -378,15 +378,15 @@ def collate_types(types): @memorycache(maxsize=2048) -def get_type_of_expression(expr): +def get_type_of_expression(expr, default_float_type='double', default_int_type='int'): from pystencils.astnodes import ResolvedFieldAccess from pystencils.cpu.vectorization import vec_all, vec_any expr = sp.sympify(expr) if isinstance(expr, sp.Integer): - return create_type("int") + return create_type(default_int_type) elif isinstance(expr, sp.Rational) or isinstance(expr, sp.Float): - return create_type("double") + return create_type(default_float_type) elif isinstance(expr, ResolvedFieldAccess): return expr.field.dtype elif isinstance(expr, TypedSymbol): @@ -416,8 +416,15 @@ def get_type_of_expression(expr): elif isinstance(expr, sp.Pow): return get_type_of_expression(expr.args[0]) elif isinstance(expr, sp.Expr): - types = tuple(get_type_of_expression(a) for a in expr.args) - return collate_types(types) + expr: sp.Expr + if expr.args: + types = tuple(get_type_of_expression(a) for a in expr.args) + return collate_types(types) + else: + if expr.is_integer: + return create_type(default_int_type) + else: + return create_type(default_float_type) raise NotImplementedError("Could not determine type for", expr, type(expr)) diff --git a/pystencils_tests/test_types.py b/pystencils_tests/test_types.py index 4b28c5a0b..887f802c9 100644 --- a/pystencils_tests/test_types.py +++ b/pystencils_tests/test_types.py @@ -1,5 +1,7 @@ from pystencils import data_types from pystencils.data_types import * +import sympy as sp + def test_parsing(): @@ -19,3 +21,16 @@ def test_collation(): assert collate_types([double_type, float_type]) == double_type assert collate_types([double4_type, float_type]) == double4_type assert collate_types([double4_type, float4_type]) == double4_type + +def test_dtype_of_constants(): + + # Some come constants are neither of type Integer,Float,Rational and don't have args + # >>> isinstance(pi, Integer) + # False + # >>> isinstance(pi, Float) + # False + # >>> isinstance(pi, Rational) + # False + # >>> pi.args + # () + get_type_of_expression(sp.pi) -- GitLab