From 7503c866ef0468ac3699d23f371cd67305987285 Mon Sep 17 00:00:00 2001
From: Stephan Seitz <stephan.seitz@fau.de>
Date: Mon, 15 Jul 2019 10:55:51 +0200
Subject: [PATCH] Add test for custom backends

---
 pystencils_tests/test_custom_backends.py | 60 ++++++++++++++++++++++++
 1 file changed, 60 insertions(+)
 create mode 100644 pystencils_tests/test_custom_backends.py

diff --git a/pystencils_tests/test_custom_backends.py b/pystencils_tests/test_custom_backends.py
new file mode 100644
index 000000000..f68696f13
--- /dev/null
+++ b/pystencils_tests/test_custom_backends.py
@@ -0,0 +1,60 @@
+# -*- coding: utf-8 -*-
+#
+# Copyright © 2019 Stephan Seitz <stephan.seitz@fau.de>
+#
+# Distributed under terms of the GPLv3 license.
+
+"""
+
+"""
+from subprocess import CalledProcessError
+
+import pycuda.driver
+import pytest
+import sympy
+
+import pystencils
+import pystencils.cpu.cpujit
+import pystencils.gpucuda.cudajit
+from pystencils.backends.cbackend import CBackend
+from pystencils.backends.cuda_backend import CudaBackend
+
+
+class ScreamingBackend(CBackend):
+
+    def _print(self, node):
+        normal_code = super()._print(node)
+        return normal_code.upper()
+
+
+class ScreamingGpuBackend(CudaBackend):
+
+    def _print(self, node):
+        normal_code = super()._print(node)
+        return normal_code.upper()
+
+
+def test_custom_backends():
+    z, x, y = pystencils.fields("z, y, x: [2d]")
+
+    normal_assignments = pystencils.AssignmentCollection([pystencils.Assignment(
+        z[0, 0], x[0, 0] * sympy.log(x[0, 0] * y[0, 0]))], [])
+
+    ast = pystencils.create_kernel(normal_assignments, target='cpu')
+    print(pystencils.show_code(ast, ScreamingBackend()))
+    with pytest.raises(CalledProcessError):
+        pystencils.cpu.cpujit.make_python_function(ast, custom_backend=ScreamingBackend())
+
+    ast = pystencils.create_kernel(normal_assignments, target='gpu')
+    print(pystencils.show_code(ast, ScreamingGpuBackend()))
+    with pytest.raises(pycuda.driver.CompileError):
+        pystencils.gpucuda.cudajit.make_python_function(ast, custom_backend=ScreamingGpuBackend())
+
+
+def main():
+
+    test_custom_backends()
+
+
+if __name__ == "__main__":
+    main()
-- 
GitLab