Skip to content
Snippets Groups Projects

Fix get_type_of_expression for constants like sympy.pi

Compare and
2 files
+ 27
5
Preferences
Compare changes
Files
2
+ 12
5
@@ -378,15 +378,15 @@ def collate_types(types):
@@ -378,15 +378,15 @@ def collate_types(types):
@memorycache(maxsize=2048)
@memorycache(maxsize=2048)
def get_type_of_expression(expr):
def get_type_of_expression(expr, default_float_type='double', default_int_type='int'):
from pystencils.astnodes import ResolvedFieldAccess
from pystencils.astnodes import ResolvedFieldAccess
from pystencils.cpu.vectorization import vec_all, vec_any
from pystencils.cpu.vectorization import vec_all, vec_any
expr = sp.sympify(expr)
expr = sp.sympify(expr)
if isinstance(expr, sp.Integer):
if isinstance(expr, sp.Integer):
return create_type("int")
return create_type(default_int_type)
elif isinstance(expr, sp.Rational) or isinstance(expr, sp.Float):
elif isinstance(expr, sp.Rational) or isinstance(expr, sp.Float):
return create_type("double")
return create_type(default_float_type)
elif isinstance(expr, ResolvedFieldAccess):
elif isinstance(expr, ResolvedFieldAccess):
return expr.field.dtype
return expr.field.dtype
elif isinstance(expr, TypedSymbol):
elif isinstance(expr, TypedSymbol):
@@ -416,8 +416,15 @@ def get_type_of_expression(expr):
@@ -416,8 +416,15 @@ def get_type_of_expression(expr):
elif isinstance(expr, sp.Pow):
elif isinstance(expr, sp.Pow):
return get_type_of_expression(expr.args[0])
return get_type_of_expression(expr.args[0])
elif isinstance(expr, sp.Expr):
elif isinstance(expr, sp.Expr):
types = tuple(get_type_of_expression(a) for a in expr.args)
expr: sp.Expr
return collate_types(types)
if expr.args:
 
types = tuple(get_type_of_expression(a) for a in expr.args)
 
return collate_types(types)
 
else:
 
if expr.is_integer:
 
return create_type(default_int_type)
 
else:
 
return create_type(default_float_type)
raise NotImplementedError("Could not determine type for", expr, type(expr))
raise NotImplementedError("Could not determine type for", expr, type(expr))