From 3bb88ad133614c9e1d87e73d5271a06077aaf770 Mon Sep 17 00:00:00 2001
From: Markus Holzer <markus.holzer@fau.de>
Date: Wed, 17 Nov 2021 14:57:24 +0100
Subject: [PATCH] Fixed integer square root

---
 pystencils/backends/cbackend.py |  2 +-
 pystencils_tests/test_types.py  | 21 ++++++++++++++++++---
 2 files changed, 19 insertions(+), 4 deletions(-)

diff --git a/pystencils/backends/cbackend.py b/pystencils/backends/cbackend.py
index d590fd87c..cbcc62f86 100644
--- a/pystencils/backends/cbackend.py
+++ b/pystencils/backends/cbackend.py
@@ -444,7 +444,7 @@ class CustomSympyPrinter(CCodePrinter):
     def _print_Pow(self, expr):
         """Don't use std::pow function, for small integer exponents, write as multiplication"""
         if not expr.free_symbols:
-            return self._typed_number(expr.evalf(), get_type_of_expression(expr))
+            return self._typed_number(expr.evalf(), get_type_of_expression(expr.base))
 
         if expr.exp.is_integer and expr.exp.is_number and 0 < expr.exp < 8:
             return f"({self._print(sp.Mul(*[expr.base] * expr.exp, evaluate=False))})"
diff --git a/pystencils_tests/test_types.py b/pystencils_tests/test_types.py
index cb8f80fd7..c63ab6923 100644
--- a/pystencils_tests/test_types.py
+++ b/pystencils_tests/test_types.py
@@ -87,10 +87,25 @@ def test_sqrt_of_integer():
 
     assignments = [ps.Assignment(tmp, sp.sqrt(3)),
                    ps.Assignment(f[0], tmp)]
-    arr = np.array([1], dtype=np.float64)
+    arr_double = np.array([1], dtype=np.float64)
     kernel = ps.create_kernel(assignments).compile()
-    kernel(f=arr)
-    assert 1.7 < arr[0] < 1.8
+    kernel(f=arr_double)
+    assert 1.7 < arr_double[0] < 1.8
+
+    f = ps.fields("f: float32[1D]")
+    tmp = sp.symbols("tmp")
+
+    assignments = [ps.Assignment(tmp, sp.sqrt(3)),
+                   ps.Assignment(f[0], tmp)]
+    arr_single = np.array([1], dtype=np.float32)
+    config = ps.CreateKernelConfig(data_type="float32")
+    kernel = ps.create_kernel(assignments, config=config).compile()
+    kernel(f=arr_single)
+
+    code = ps.get_code_str(kernel.ast)
+
+    assert "1.7320508075688772f" in code
+    assert 1.7 < arr_single[0] < 1.8
 
 
 def test_integer_comparision():
-- 
GitLab