From 7503c866ef0468ac3699d23f371cd67305987285 Mon Sep 17 00:00:00 2001 From: Stephan Seitz <stephan.seitz@fau.de> Date: Mon, 15 Jul 2019 10:55:51 +0200 Subject: [PATCH] Add test for custom backends --- pystencils_tests/test_custom_backends.py | 60 ++++++++++++++++++++++++ 1 file changed, 60 insertions(+) create mode 100644 pystencils_tests/test_custom_backends.py diff --git a/pystencils_tests/test_custom_backends.py b/pystencils_tests/test_custom_backends.py new file mode 100644 index 000000000..f68696f13 --- /dev/null +++ b/pystencils_tests/test_custom_backends.py @@ -0,0 +1,60 @@ +# -*- coding: utf-8 -*- +# +# Copyright © 2019 Stephan Seitz <stephan.seitz@fau.de> +# +# Distributed under terms of the GPLv3 license. + +""" + +""" +from subprocess import CalledProcessError + +import pycuda.driver +import pytest +import sympy + +import pystencils +import pystencils.cpu.cpujit +import pystencils.gpucuda.cudajit +from pystencils.backends.cbackend import CBackend +from pystencils.backends.cuda_backend import CudaBackend + + +class ScreamingBackend(CBackend): + + def _print(self, node): + normal_code = super()._print(node) + return normal_code.upper() + + +class ScreamingGpuBackend(CudaBackend): + + def _print(self, node): + normal_code = super()._print(node) + return normal_code.upper() + + +def test_custom_backends(): + z, x, y = pystencils.fields("z, y, x: [2d]") + + normal_assignments = pystencils.AssignmentCollection([pystencils.Assignment( + z[0, 0], x[0, 0] * sympy.log(x[0, 0] * y[0, 0]))], []) + + ast = pystencils.create_kernel(normal_assignments, target='cpu') + print(pystencils.show_code(ast, ScreamingBackend())) + with pytest.raises(CalledProcessError): + pystencils.cpu.cpujit.make_python_function(ast, custom_backend=ScreamingBackend()) + + ast = pystencils.create_kernel(normal_assignments, target='gpu') + print(pystencils.show_code(ast, ScreamingGpuBackend())) + with pytest.raises(pycuda.driver.CompileError): + pystencils.gpucuda.cudajit.make_python_function(ast, custom_backend=ScreamingGpuBackend()) + + +def main(): + + test_custom_backends() + + +if __name__ == "__main__": + main() -- GitLab