From 679bf618009ec1c6ebf25d273a08750ef25521b7 Mon Sep 17 00:00:00 2001 From: Frederik Hennig <frederik.hennig@fau.de> Date: Mon, 15 Jan 2024 16:08:33 +0100 Subject: [PATCH] extend printing test --- pystencils/nbackend/c_printer.py | 2 +- pystencils_tests/nbackend/test_basic_printing.py | 7 ++++++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/pystencils/nbackend/c_printer.py b/pystencils/nbackend/c_printer.py index 4ca472a29..6872ad9df 100644 --- a/pystencils/nbackend/c_printer.py +++ b/pystencils/nbackend/c_printer.py @@ -29,7 +29,7 @@ class CPrinter: def function(self, func: PsKernelFunction) -> str: params = func.get_parameters() params_str = ", ".join(f"{p.dtype} {p.name}" for p in params) - decl = f"FUNC_PREFIX void {func.name} ( {params_str} )" + decl = f"FUNC_PREFIX void {func.name} ({params_str})" body = self.visit(func.body) return f"{decl}\n{body}" diff --git a/pystencils_tests/nbackend/test_basic_printing.py b/pystencils_tests/nbackend/test_basic_printing.py index 2394b9287..867921114 100644 --- a/pystencils_tests/nbackend/test_basic_printing.py +++ b/pystencils_tests/nbackend/test_basic_printing.py @@ -34,5 +34,10 @@ def test_basic_kernel(): printer = CPrinter() code = printer.print(func) - assert code.find("u_data[ctr] = u_data[ctr + 1] + u_data[ctr - 1]") >= 0 + paramlist = func.get_parameters() + params_str = ", ".join(f"{p.dtype} {p.name}" for p in paramlist) + + assert code.find("(" + params_str + ")") >= 0 + + assert code.find("u_data[ctr] = u_data[ctr - 1] + u_data[ctr + 1];") >= 0 -- GitLab