Skip to content
Snippets Groups Projects
Commit 9ded23ce authored by Markus Holzer
Browse files

Fixed wrong f-string in CUDA backend

parent 82af488a
No related merge requests found
...@@ -98,7 +98,7 @@ class CudaSympyPrinter(CustomSympyPrinter): ...@@ -98,7 +98,7 @@ class CudaSympyPrinter(CustomSympyPrinter):
if isinstance(expr, fast_division): if isinstance(expr, fast_division):
return "__fdividef(%s, %s)" % tuple(self._print(a) for a in expr.args) return "__fdividef(%s, %s)" % tuple(self._print(a) for a in expr.args)
elif isinstance(expr, fast_sqrt): elif isinstance(expr, fast_sqrt):
return f"__fsqrt_rn({tuple(self._print(a) for a in expr.args)})" return "__fsqrt_rn(%s)" % tuple(self._print(a) for a in expr.args)
elif isinstance(expr, fast_inv_sqrt): elif isinstance(expr, fast_inv_sqrt):
return f"__frsqrt_rn({tuple(self._print(a) for a in expr.args)})" return "__frsqrt_rn(%s)" % tuple(self._print(a) for a in expr.args)
return super()._print_Function(expr) return super()._print_Function(expr)
...@@ -12,6 +12,7 @@ def test_fast_sqrt(): ...@@ -12,6 +12,7 @@ def test_fast_sqrt():
assert len(insert_fast_sqrts(expr).atoms(fast_sqrt)) == 1 assert len(insert_fast_sqrts(expr).atoms(fast_sqrt)) == 1
assert len(insert_fast_sqrts([expr])[0].atoms(fast_sqrt)) == 1 assert len(insert_fast_sqrts([expr])[0].atoms(fast_sqrt)) == 1
ast = ps.create_kernel(ps.Assignment(g[0, 0], insert_fast_sqrts(expr)), target='gpu') ast = ps.create_kernel(ps.Assignment(g[0, 0], insert_fast_sqrts(expr)), target='gpu')
ast.compile()
code_str = ps.get_code_str(ast) code_str = ps.get_code_str(ast)
assert '__fsqrt_rn' in code_str assert '__fsqrt_rn' in code_str
...@@ -21,6 +22,7 @@ def test_fast_sqrt(): ...@@ -21,6 +22,7 @@ def test_fast_sqrt():
ac = ps.AssignmentCollection([expr], []) ac = ps.AssignmentCollection([expr], [])
assert len(insert_fast_sqrts(ac).main_assignments[0].atoms(fast_inv_sqrt)) == 1 assert len(insert_fast_sqrts(ac).main_assignments[0].atoms(fast_inv_sqrt)) == 1
ast = ps.create_kernel(insert_fast_sqrts(ac), target='gpu') ast = ps.create_kernel(insert_fast_sqrts(ac), target='gpu')
ast.compile()
code_str = ps.get_code_str(ast) code_str = ps.get_code_str(ast)
assert '__frsqrt_rn' in code_str assert '__frsqrt_rn' in code_str
...@@ -34,5 +36,6 @@ def test_fast_divisions(): ...@@ -34,5 +36,6 @@ def test_fast_divisions():
assert len(insert_fast_divisions(expr).atoms(fast_division)) == 1 assert len(insert_fast_divisions(expr).atoms(fast_division)) == 1
ast = ps.create_kernel(ps.Assignment(g[0, 0], insert_fast_divisions(expr)), target='gpu') ast = ps.create_kernel(ps.Assignment(g[0, 0], insert_fast_divisions(expr)), target='gpu')
ast.compile()
code_str = ps.get_code_str(ast) code_str = ps.get_code_str(ast)
assert '__fdividef' in code_str assert '__fdividef' in code_str
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment