diff --git a/pystencils/backends/cbackend.py b/pystencils/backends/cbackend.py index d590fd87c5b0c2cf8fb25739bb45f3371bd52daa..cbcc62f8654840f2da3370b8c2206c98dc8b9fa3 100644 --- a/pystencils/backends/cbackend.py +++ b/pystencils/backends/cbackend.py @@ -444,7 +444,7 @@ class CustomSympyPrinter(CCodePrinter): def _print_Pow(self, expr): """Don't use std::pow function, for small integer exponents, write as multiplication""" if not expr.free_symbols: - return self._typed_number(expr.evalf(), get_type_of_expression(expr)) + return self._typed_number(expr.evalf(), get_type_of_expression(expr.base)) if expr.exp.is_integer and expr.exp.is_number and 0 < expr.exp < 8: return f"({self._print(sp.Mul(*[expr.base] * expr.exp, evaluate=False))})" diff --git a/pystencils_tests/test_types.py b/pystencils_tests/test_types.py index cb8f80fd7c0ba1346cf945cfaf4428dfa673aa89..c63ab6923fe73668a5666eca227cd73c7b327e57 100644 --- a/pystencils_tests/test_types.py +++ b/pystencils_tests/test_types.py @@ -87,10 +87,25 @@ def test_sqrt_of_integer(): assignments = [ps.Assignment(tmp, sp.sqrt(3)), ps.Assignment(f[0], tmp)] - arr = np.array([1], dtype=np.float64) + arr_double = np.array([1], dtype=np.float64) kernel = ps.create_kernel(assignments).compile() - kernel(f=arr) - assert 1.7 < arr[0] < 1.8 + kernel(f=arr_double) + assert 1.7 < arr_double[0] < 1.8 + + f = ps.fields("f: float32[1D]") + tmp = sp.symbols("tmp") + + assignments = [ps.Assignment(tmp, sp.sqrt(3)), + ps.Assignment(f[0], tmp)] + arr_single = np.array([1], dtype=np.float32) + config = ps.CreateKernelConfig(data_type="float32") + kernel = ps.create_kernel(assignments, config=config).compile() + kernel(f=arr_single) + + code = ps.get_code_str(kernel.ast) + + assert "1.7320508075688772f" in code + assert 1.7 < arr_single[0] < 1.8 def test_integer_comparision():