Commit 7644ab1f authored by Jan Hönig's avatar Jan Hönig
Browse files

Merge branch 'IntegerSquareRoot' into 'master'

Fixed integer square root

See merge request pycodegen/pystencils!274
parents 43393627 3bb88ad1
......@@ -444,7 +444,7 @@ class CustomSympyPrinter(CCodePrinter):
def _print_Pow(self, expr):
"""Don't use std::pow function, for small integer exponents, write as multiplication"""
if not expr.free_symbols:
return self._typed_number(expr.evalf(), get_type_of_expression(expr))
return self._typed_number(expr.evalf(), get_type_of_expression(expr.base))
if expr.exp.is_integer and expr.exp.is_number and 0 < expr.exp < 8:
return f"({self._print(sp.Mul(*[expr.base] * expr.exp, evaluate=False))})"
......
......@@ -87,10 +87,25 @@ def test_sqrt_of_integer():
assignments = [ps.Assignment(tmp, sp.sqrt(3)),
ps.Assignment(f[0], tmp)]
arr = np.array([1], dtype=np.float64)
arr_double = np.array([1], dtype=np.float64)
kernel = ps.create_kernel(assignments).compile()
kernel(f=arr)
assert 1.7 < arr[0] < 1.8
kernel(f=arr_double)
assert 1.7 < arr_double[0] < 1.8
f = ps.fields("f: float32[1D]")
tmp = sp.symbols("tmp")
assignments = [ps.Assignment(tmp, sp.sqrt(3)),
ps.Assignment(f[0], tmp)]
arr_single = np.array([1], dtype=np.float32)
config = ps.CreateKernelConfig(data_type="float32")
kernel = ps.create_kernel(assignments, config=config).compile()
kernel(f=arr_single)
code = ps.get_code_str(kernel.ast)
assert "1.7320508075688772f" in code
assert 1.7 < arr_single[0] < 1.8
def test_integer_comparision():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment