diff --git a/pystencils/data_types.py b/pystencils/data_types.py index 0bed536e0ff93ac963a3202840c942e17d5a2832..4899bff2d4d26839796f00b6b6e49b9b0c71c9bb 100644 --- a/pystencils/data_types.py +++ b/pystencils/data_types.py @@ -548,7 +548,13 @@ def get_type_of_expression(expr, if vec_args: result = VectorType(result, width=vec_args[0].width) return result - elif isinstance(expr, (sp.Pow, sp.Sum, sp.Product)): + elif isinstance(expr, sp.Pow): + base_type = get_type(expr.args[0]) + if expr.exp.is_integer: + return base_type + else: + return collate_types([create_type(default_float_type), base_type]) + elif isinstance(expr, (sp.Sum, sp.Product)): return get_type(expr.args[0]) elif isinstance(expr, sp.Expr): expr: sp.Expr diff --git a/pystencils_tests/test_types.py b/pystencils_tests/test_types.py index 322e04db1d47b93efa89d9d024f4dfda325b5dc4..4deb69acd3ffc2051d2eff742479d5130b0fd148 100644 --- a/pystencils_tests/test_types.py +++ b/pystencils_tests/test_types.py @@ -1,8 +1,8 @@ import sympy as sp - +import numpy as np +import pystencils as ps from pystencils import data_types -from pystencils.data_types import * -from pystencils.kernelparameters import FieldShapeSymbol +from pystencils.data_types import TypedSymbol, get_type_of_expression, VectorType, collate_types, create_type def test_parsing(): @@ -25,7 +25,6 @@ def test_collation(): def test_dtype_of_constants(): - # Some come constants are neither of type Integer,Float,Rational and don't have args # >>> isinstance(pi, Integer) # False @@ -39,13 +38,25 @@ def test_dtype_of_constants(): def test_assumptions(): - - x = pystencils.fields('x: float32[3d]') + x = ps.fields('x: float32[3d]') assert x.shape[0].is_nonnegative assert (2 * x.shape[0]).is_nonnegative assert (2 * x.shape[0]).is_integer - assert(TypedSymbol('a', create_type('uint64'))).is_nonnegative + assert (TypedSymbol('a', create_type('uint64'))).is_nonnegative assert (TypedSymbol('a', create_type('uint64'))).is_positive is None assert (TypedSymbol('a', create_type('uint64')) + 1).is_positive assert (x.shape[0] + 1).is_real + + +def test_sqrt_of_integer(): + """Regression test for bug where sqrt(3) was classified as integer""" + f = ps.fields("f: [1D]") + tmp = sp.symbols("tmp") + + assignments = [ps.Assignment(tmp, sp.sqrt(3)), + ps.Assignment(f[0], tmp)] + arr = np.array([1], dtype=np.float64) + kernel = ps.create_kernel(assignments).compile() + kernel(f=arr) + assert 1.7 < arr[0] < 1.8