From b87eeadfa71011e33fc465bc938fdcac6f66f576 Mon Sep 17 00:00:00 2001
From: Stephan Seitz <stephan.seitz@fau.de>
Date: Sun, 18 Aug 2019 14:37:15 +0200
Subject: [PATCH] Fix get_type_of_expression for constants like sympy.pi

Problem some constant expressions are neither Float,Integer,Rational but
don't have arguments.

>>> isinstance(pi, Integer)
False
>>> isinstance(pi, Float)
False
>>> isinstance(pi, Rational)
False
>>> pi.args
()
---
 pystencils/data_types.py       | 17 ++++++++++++-----
 pystencils_tests/test_types.py | 15 +++++++++++++++
 2 files changed, 27 insertions(+), 5 deletions(-)

diff --git a/pystencils/data_types.py b/pystencils/data_types.py
index 23dcf4f..86ce174 100644
--- a/pystencils/data_types.py
+++ b/pystencils/data_types.py
@@ -378,15 +378,15 @@ def collate_types(types):
 
 
 @memorycache(maxsize=2048)
-def get_type_of_expression(expr):
+def get_type_of_expression(expr, default_float_type='double', default_int_type='int'):
     from pystencils.astnodes import ResolvedFieldAccess
     from pystencils.cpu.vectorization import vec_all, vec_any
 
     expr = sp.sympify(expr)
     if isinstance(expr, sp.Integer):
-        return create_type("int")
+        return create_type(default_int_type)
     elif isinstance(expr, sp.Rational) or isinstance(expr, sp.Float):
-        return create_type("double")
+        return create_type(default_float_type)
     elif isinstance(expr, ResolvedFieldAccess):
         return expr.field.dtype
     elif isinstance(expr, TypedSymbol):
@@ -416,8 +416,15 @@ def get_type_of_expression(expr):
     elif isinstance(expr, sp.Pow):
         return get_type_of_expression(expr.args[0])
     elif isinstance(expr, sp.Expr):
-        types = tuple(get_type_of_expression(a) for a in expr.args)
-        return collate_types(types)
+        expr: sp.Expr
+        if expr.args:
+            types = tuple(get_type_of_expression(a) for a in expr.args)
+            return collate_types(types)
+        else:
+            if expr.is_integer:
+                return create_type(default_int_type)
+            else:
+                return create_type(default_float_type)
 
     raise NotImplementedError("Could not determine type for", expr, type(expr))
 
diff --git a/pystencils_tests/test_types.py b/pystencils_tests/test_types.py
index 4b28c5a..887f802 100644
--- a/pystencils_tests/test_types.py
+++ b/pystencils_tests/test_types.py
@@ -1,5 +1,7 @@
 from pystencils import data_types
 from pystencils.data_types import *
+import sympy as sp
+
 
 
 def test_parsing():
@@ -19,3 +21,16 @@ def test_collation():
     assert collate_types([double_type, float_type]) == double_type
     assert collate_types([double4_type, float_type]) == double4_type
     assert collate_types([double4_type, float4_type]) == double4_type
+
+def test_dtype_of_constants():
+
+    # Some come constants are neither of type Integer,Float,Rational and don't have args
+    # >>> isinstance(pi, Integer)
+    # False
+    # >>> isinstance(pi, Float)
+    # False
+    # >>> isinstance(pi, Rational)
+    # False
+    # >>> pi.args
+    # ()
+    get_type_of_expression(sp.pi)
-- 
GitLab