Commit 4465a595 authored by Markus Holzer
Browse files

Added test case for vectorisation

parent cd3a2f3e
...@@ -298,7 +298,7 @@ class CBackend: ...@@ -298,7 +298,7 @@ class CBackend:
return node.get_code(self._dialect, self._vector_instruction_set) return node.get_code(self._dialect, self._vector_instruction_set)
def _print_SourceCodeComment(self, node): def _print_SourceCodeComment(self, node):
return "/* " + node.text + " */" return f"/* {node.text } */"
def _print_EmptyLine(self, node): def _print_EmptyLine(self, node):
return "" return ""
...@@ -316,7 +316,7 @@ class CBackend: ...@@ -316,7 +316,7 @@ class CBackend:
result = f"if ({condition_expr})\n{true_block} " result = f"if ({condition_expr})\n{true_block} "
if node.false_block: if node.false_block:
false_block = self._print_Block(node.false_block) false_block = self._print_Block(node.false_block)
result += "else " + false_block result += f"else {false_block}"
return result return result
...@@ -336,7 +336,7 @@ class CustomSympyPrinter(CCodePrinter): ...@@ -336,7 +336,7 @@ class CustomSympyPrinter(CCodePrinter):
return self._typed_number(expr.evalf(), get_type_of_expression(expr)) return self._typed_number(expr.evalf(), get_type_of_expression(expr))
if expr.exp.is_integer and expr.exp.is_number and 0 < expr.exp < 8: if expr.exp.is_integer and expr.exp.is_number and 0 < expr.exp < 8:
return "(" + self._print(sp.Mul(*[expr.base] * expr.exp, evaluate=False)) + ")" return f"({self._print(sp.Mul(*[expr.base] * expr.exp, evaluate=False))})"
elif expr.exp.is_integer and expr.exp.is_number and - 8 < expr.exp < 0: elif expr.exp.is_integer and expr.exp.is_number and - 8 < expr.exp < 0:
return f"1 / ({self._print(sp.Mul(*([expr.base] * -expr.exp), evaluate=False))})" return f"1 / ({self._print(sp.Mul(*([expr.base] * -expr.exp), evaluate=False))})"
else: else:
......
...@@ -104,7 +104,6 @@ class CudaSympyPrinter(CustomSympyPrinter): ...@@ -104,7 +104,6 @@ class CudaSympyPrinter(CustomSympyPrinter):
assert len(expr.args) == 1, f"__fsqrt_rn has one argument, but {len(expr.args)} where given" assert len(expr.args) == 1, f"__fsqrt_rn has one argument, but {len(expr.args)} where given"
return f"__fsqrt_rn({self._print(expr.args[0])})" return f"__fsqrt_rn({self._print(expr.args[0])})"
elif isinstance(expr, fast_inv_sqrt): elif isinstance(expr, fast_inv_sqrt):
print(len(expr.args) == 1)
assert len(expr.args) == 1, f"__frsqrt_rn has one argument, but {len(expr.args)} where given" assert len(expr.args) == 1, f"__frsqrt_rn has one argument, but {len(expr.args)} where given"
return f"__frsqrt_rn({self._print(expr.args[0])})" return f"__frsqrt_rn({self._print(expr.args[0])})"
return super()._print_Function(expr) return super()._print_Function(expr)
...@@ -11,9 +11,9 @@ def test_fast_sqrt(): ...@@ -11,9 +11,9 @@ def test_fast_sqrt():
assert len(insert_fast_sqrts(expr).atoms(fast_sqrt)) == 1 assert len(insert_fast_sqrts(expr).atoms(fast_sqrt)) == 1
assert len(insert_fast_sqrts([expr])[0].atoms(fast_sqrt)) == 1 assert len(insert_fast_sqrts([expr])[0].atoms(fast_sqrt)) == 1
ast = ps.create_kernel(ps.Assignment(g[0, 0], insert_fast_sqrts(expr)), target='gpu') ast_gpu = ps.create_kernel(ps.Assignment(g[0, 0], insert_fast_sqrts(expr)), target='gpu')
ast.compile() ast_gpu.compile()
code_str = ps.get_code_str(ast) code_str = ps.get_code_str(ast_gpu)
assert '__fsqrt_rn' in code_str assert '__fsqrt_rn' in code_str
expr = ps.Assignment(sp.Symbol("tmp"), 3 / sp.sqrt(f[0, 0] + f[1, 0])) expr = ps.Assignment(sp.Symbol("tmp"), 3 / sp.sqrt(f[0, 0] + f[1, 0]))
...@@ -21,9 +21,9 @@ def test_fast_sqrt(): ...@@ -21,9 +21,9 @@ def test_fast_sqrt():
ac = ps.AssignmentCollection([expr], []) ac = ps.AssignmentCollection([expr], [])
assert len(insert_fast_sqrts(ac).main_assignments[0].atoms(fast_inv_sqrt)) == 1 assert len(insert_fast_sqrts(ac).main_assignments[0].atoms(fast_inv_sqrt)) == 1
ast = ps.create_kernel(insert_fast_sqrts(ac), target='gpu') ast_gpu = ps.create_kernel(insert_fast_sqrts(ac), target='gpu')
ast.compile() ast_gpu.compile()
code_str = ps.get_code_str(ast) code_str = ps.get_code_str(ast_gpu)
assert '__frsqrt_rn' in code_str assert '__frsqrt_rn' in code_str
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment