Skip to content
Snippets Groups Projects
Commit 9ded23ce authored by Markus Holzer's avatar Markus Holzer
Browse files

Fixed Wrong fString in Cuda Backend

parent 82af488a
Branches
Tags
......@@ -98,7 +98,7 @@ class CudaSympyPrinter(CustomSympyPrinter):
if isinstance(expr, fast_division):
return "__fdividef(%s, %s)" % tuple(self._print(a) for a in expr.args)
elif isinstance(expr, fast_sqrt):
return f"__fsqrt_rn({tuple(self._print(a) for a in expr.args)})"
return "__fsqrt_rn(%s)" % tuple(self._print(a) for a in expr.args)
elif isinstance(expr, fast_inv_sqrt):
return f"__frsqrt_rn({tuple(self._print(a) for a in expr.args)})"
return "__frsqrt_rn(%s)" % tuple(self._print(a) for a in expr.args)
return super()._print_Function(expr)
......@@ -12,6 +12,7 @@ def test_fast_sqrt():
assert len(insert_fast_sqrts(expr).atoms(fast_sqrt)) == 1
assert len(insert_fast_sqrts([expr])[0].atoms(fast_sqrt)) == 1
ast = ps.create_kernel(ps.Assignment(g[0, 0], insert_fast_sqrts(expr)), target='gpu')
ast.compile()
code_str = ps.get_code_str(ast)
assert '__fsqrt_rn' in code_str
......@@ -21,6 +22,7 @@ def test_fast_sqrt():
ac = ps.AssignmentCollection([expr], [])
assert len(insert_fast_sqrts(ac).main_assignments[0].atoms(fast_inv_sqrt)) == 1
ast = ps.create_kernel(insert_fast_sqrts(ac), target='gpu')
ast.compile()
code_str = ps.get_code_str(ast)
assert '__frsqrt_rn' in code_str
......@@ -34,5 +36,6 @@ def test_fast_divisions():
assert len(insert_fast_divisions(expr).atoms(fast_division)) == 1
ast = ps.create_kernel(ps.Assignment(g[0, 0], insert_fast_divisions(expr)), target='gpu')
ast.compile()
code_str = ps.get_code_str(ast)
assert '__fdividef' in code_str
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment