Skip to content
Snippets Groups Projects
Commit e7a8d3ce authored by Martin Bauer's avatar Martin Bauer
Browse files

Merge branch 'fix-get_type_of_expression-for-constants' into 'master'

Fix get_type_of_expression for constants like sympy.pi

See merge request !35
parents e3d8f12b b87eeadf
No related merge requests found
......@@ -378,15 +378,15 @@ def collate_types(types):
@memorycache(maxsize=2048)
def get_type_of_expression(expr):
def get_type_of_expression(expr, default_float_type='double', default_int_type='int'):
from pystencils.astnodes import ResolvedFieldAccess
from pystencils.cpu.vectorization import vec_all, vec_any
expr = sp.sympify(expr)
if isinstance(expr, sp.Integer):
return create_type("int")
return create_type(default_int_type)
elif isinstance(expr, sp.Rational) or isinstance(expr, sp.Float):
return create_type("double")
return create_type(default_float_type)
elif isinstance(expr, ResolvedFieldAccess):
return expr.field.dtype
elif isinstance(expr, TypedSymbol):
......@@ -416,8 +416,15 @@ def get_type_of_expression(expr):
elif isinstance(expr, sp.Pow):
return get_type_of_expression(expr.args[0])
elif isinstance(expr, sp.Expr):
types = tuple(get_type_of_expression(a) for a in expr.args)
return collate_types(types)
expr: sp.Expr
if expr.args:
types = tuple(get_type_of_expression(a) for a in expr.args)
return collate_types(types)
else:
if expr.is_integer:
return create_type(default_int_type)
else:
return create_type(default_float_type)
raise NotImplementedError("Could not determine type for", expr, type(expr))
......
from pystencils import data_types
from pystencils.data_types import *
import sympy as sp
def test_parsing():
......@@ -19,3 +21,16 @@ def test_collation():
assert collate_types([double_type, float_type]) == double_type
assert collate_types([double4_type, float_type]) == double4_type
assert collate_types([double4_type, float4_type]) == double4_type
def test_dtype_of_constants():
# Some come constants are neither of type Integer,Float,Rational and don't have args
# >>> isinstance(pi, Integer)
# False
# >>> isinstance(pi, Float)
# False
# >>> isinstance(pi, Rational)
# False
# >>> pi.args
# ()
get_type_of_expression(sp.pi)
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment