Skip to content
Snippets Groups Projects
Commit 9ded23ce authored by Markus Holzer's avatar Markus Holzer
Browse files

Fixed Wrong fString in Cuda Backend

parent 82af488a
Branches
Tags
No related merge requests found
......@@ -98,7 +98,7 @@ class CudaSympyPrinter(CustomSympyPrinter):
if isinstance(expr, fast_division):
return "__fdividef(%s, %s)" % tuple(self._print(a) for a in expr.args)
elif isinstance(expr, fast_sqrt):
return f"__fsqrt_rn({tuple(self._print(a) for a in expr.args)})"
return "__fsqrt_rn(%s)" % tuple(self._print(a) for a in expr.args)
elif isinstance(expr, fast_inv_sqrt):
return f"__frsqrt_rn({tuple(self._print(a) for a in expr.args)})"
return "__frsqrt_rn(%s)" % tuple(self._print(a) for a in expr.args)
return super()._print_Function(expr)
......@@ -12,6 +12,7 @@ def test_fast_sqrt():
assert len(insert_fast_sqrts(expr).atoms(fast_sqrt)) == 1
assert len(insert_fast_sqrts([expr])[0].atoms(fast_sqrt)) == 1
ast = ps.create_kernel(ps.Assignment(g[0, 0], insert_fast_sqrts(expr)), target='gpu')
ast.compile()
code_str = ps.get_code_str(ast)
assert '__fsqrt_rn' in code_str
......@@ -21,6 +22,7 @@ def test_fast_sqrt():
ac = ps.AssignmentCollection([expr], [])
assert len(insert_fast_sqrts(ac).main_assignments[0].atoms(fast_inv_sqrt)) == 1
ast = ps.create_kernel(insert_fast_sqrts(ac), target='gpu')
ast.compile()
code_str = ps.get_code_str(ast)
assert '__frsqrt_rn' in code_str
......@@ -34,5 +36,6 @@ def test_fast_divisions():
assert len(insert_fast_divisions(expr).atoms(fast_division)) == 1
ast = ps.create_kernel(ps.Assignment(g[0, 0], insert_fast_divisions(expr)), target='gpu')
ast.compile()
code_str = ps.get_code_str(ast)
assert '__fdividef' in code_str
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment