From 4465a59575a20b69b6140fa4f525387abecb1063 Mon Sep 17 00:00:00 2001 From: markus holzer <markus.holzer@fau.de> Date: Mon, 10 Aug 2020 10:01:59 +0200 Subject: [PATCH] Added test case for vectorisation --- pystencils/backends/cbackend.py | 6 +++--- pystencils/backends/cuda_backend.py | 1 - pystencils_tests/test_fast_approximation.py | 12 ++++++------ 3 files changed, 9 insertions(+), 10 deletions(-) diff --git a/pystencils/backends/cbackend.py b/pystencils/backends/cbackend.py index 9f576b4aa..aca5e5da1 100644 --- a/pystencils/backends/cbackend.py +++ b/pystencils/backends/cbackend.py @@ -298,7 +298,7 @@ class CBackend: return node.get_code(self._dialect, self._vector_instruction_set) def _print_SourceCodeComment(self, node): - return "/* " + node.text + " */" + return f"/* {node.text } */" def _print_EmptyLine(self, node): return "" @@ -316,7 +316,7 @@ class CBackend: result = f"if ({condition_expr})\n{true_block} " if node.false_block: false_block = self._print_Block(node.false_block) - result += "else " + false_block + result += f"else {false_block}" return result @@ -336,7 +336,7 @@ class CustomSympyPrinter(CCodePrinter): return self._typed_number(expr.evalf(), get_type_of_expression(expr)) if expr.exp.is_integer and expr.exp.is_number and 0 < expr.exp < 8: - return "(" + self._print(sp.Mul(*[expr.base] * expr.exp, evaluate=False)) + ")" + return f"({self._print(sp.Mul(*[expr.base] * expr.exp, evaluate=False))})" elif expr.exp.is_integer and expr.exp.is_number and - 8 < expr.exp < 0: return f"1 / ({self._print(sp.Mul(*([expr.base] * -expr.exp), evaluate=False))})" else: diff --git a/pystencils/backends/cuda_backend.py b/pystencils/backends/cuda_backend.py index b43cfb4ed..2d7dc579e 100644 --- a/pystencils/backends/cuda_backend.py +++ b/pystencils/backends/cuda_backend.py @@ -104,7 +104,6 @@ class CudaSympyPrinter(CustomSympyPrinter): assert len(expr.args) == 1, f"__fsqrt_rn has one argument, but {len(expr.args)} where given" return f"__fsqrt_rn({self._print(expr.args[0])})" elif isinstance(expr, fast_inv_sqrt): - print(len(expr.args) == 1) assert len(expr.args) == 1, f"__frsqrt_rn has one argument, but {len(expr.args)} where given" return f"__frsqrt_rn({self._print(expr.args[0])})" return super()._print_Function(expr) diff --git a/pystencils_tests/test_fast_approximation.py b/pystencils_tests/test_fast_approximation.py index f4d19fa19..ccd2d7b8e 100644 --- a/pystencils_tests/test_fast_approximation.py +++ b/pystencils_tests/test_fast_approximation.py @@ -11,9 +11,9 @@ def test_fast_sqrt(): assert len(insert_fast_sqrts(expr).atoms(fast_sqrt)) == 1 assert len(insert_fast_sqrts([expr])[0].atoms(fast_sqrt)) == 1 - ast = ps.create_kernel(ps.Assignment(g[0, 0], insert_fast_sqrts(expr)), target='gpu') - ast.compile() - code_str = ps.get_code_str(ast) + ast_gpu = ps.create_kernel(ps.Assignment(g[0, 0], insert_fast_sqrts(expr)), target='gpu') + ast_gpu.compile() + code_str = ps.get_code_str(ast_gpu) assert '__fsqrt_rn' in code_str expr = ps.Assignment(sp.Symbol("tmp"), 3 / sp.sqrt(f[0, 0] + f[1, 0])) @@ -21,9 +21,9 @@ def test_fast_sqrt(): ac = ps.AssignmentCollection([expr], []) assert len(insert_fast_sqrts(ac).main_assignments[0].atoms(fast_inv_sqrt)) == 1 - ast = ps.create_kernel(insert_fast_sqrts(ac), target='gpu') - ast.compile() - code_str = ps.get_code_str(ast) + ast_gpu = ps.create_kernel(insert_fast_sqrts(ac), target='gpu') + ast_gpu.compile() + code_str = ps.get_code_str(ast_gpu) assert '__frsqrt_rn' in code_str -- GitLab