From c4e92d45d73537fe837b9e52073c88712585065a Mon Sep 17 00:00:00 2001 From: Martin Bauer <martin.bauer@fau.de> Date: Wed, 27 Nov 2019 18:19:23 +0100 Subject: [PATCH] Fix: type of sqrt(int) was int not floating point type --- pystencils/data_types.py | 8 +++++++- pystencils_tests/test_types.py | 25 ++++++++++++++++++------- 2 files changed, 25 insertions(+), 8 deletions(-) diff --git a/pystencils/data_types.py b/pystencils/data_types.py index 0bed536e0..4899bff2d 100644 --- a/pystencils/data_types.py +++ b/pystencils/data_types.py @@ -548,7 +548,13 @@ def get_type_of_expression(expr, if vec_args: result = VectorType(result, width=vec_args[0].width) return result - elif isinstance(expr, (sp.Pow, sp.Sum, sp.Product)): + elif isinstance(expr, sp.Pow): + base_type = get_type(expr.args[0]) + if expr.exp.is_integer: + return base_type + else: + return collate_types([create_type(default_float_type), base_type]) + elif isinstance(expr, (sp.Sum, sp.Product)): return get_type(expr.args[0]) elif isinstance(expr, sp.Expr): expr: sp.Expr diff --git a/pystencils_tests/test_types.py b/pystencils_tests/test_types.py index 322e04db1..4deb69acd 100644 --- a/pystencils_tests/test_types.py +++ b/pystencils_tests/test_types.py @@ -1,8 +1,8 @@ import sympy as sp - +import numpy as np +import pystencils as ps from pystencils import data_types -from pystencils.data_types import * -from pystencils.kernelparameters import FieldShapeSymbol +from pystencils.data_types import TypedSymbol, get_type_of_expression, VectorType, collate_types, create_type def test_parsing(): @@ -25,7 +25,6 @@ def test_collation(): def test_dtype_of_constants(): - # Some come constants are neither of type Integer,Float,Rational and don't have args # >>> isinstance(pi, Integer) # False @@ -39,13 +38,25 @@ def test_dtype_of_constants(): def test_assumptions(): - - x = pystencils.fields('x: float32[3d]') + x = ps.fields('x: float32[3d]') assert x.shape[0].is_nonnegative assert (2 * x.shape[0]).is_nonnegative assert (2 * x.shape[0]).is_integer - assert(TypedSymbol('a', create_type('uint64'))).is_nonnegative + assert (TypedSymbol('a', create_type('uint64'))).is_nonnegative assert (TypedSymbol('a', create_type('uint64'))).is_positive is None assert (TypedSymbol('a', create_type('uint64')) + 1).is_positive assert (x.shape[0] + 1).is_real + + +def test_sqrt_of_integer(): + """Regression test for bug where sqrt(3) was classified as integer""" + f = ps.fields("f: [1D]") + tmp = sp.symbols("tmp") + + assignments = [ps.Assignment(tmp, sp.sqrt(3)), + ps.Assignment(f[0], tmp)] + arr = np.array([1], dtype=np.float64) + kernel = ps.create_kernel(assignments).compile() + kernel(f=arr) + assert 1.7 < arr[0] < 1.8 -- GitLab