Commit 7503c866 authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Add test for custom backends

parent 9bb1e142
# -*- coding: utf-8 -*-
# Copyright © 2019 Stephan Seitz <>
# Distributed under terms of the GPLv3 license.
from subprocess import CalledProcessError
import pycuda.driver
import pytest
import sympy
import pystencils
import pystencils.cpu.cpujit
import pystencils.gpucuda.cudajit
from pystencils.backends.cbackend import CBackend
from pystencils.backends.cuda_backend import CudaBackend
class ScreamingBackend(CBackend):
def _print(self, node):
normal_code = super()._print(node)
return normal_code.upper()
class ScreamingGpuBackend(CudaBackend):
def _print(self, node):
normal_code = super()._print(node)
return normal_code.upper()
def test_custom_backends():
z, x, y = pystencils.fields("z, y, x: [2d]")
normal_assignments = pystencils.AssignmentCollection([pystencils.Assignment(
z[0, 0], x[0, 0] * sympy.log(x[0, 0] * y[0, 0]))], [])
ast = pystencils.create_kernel(normal_assignments, target='cpu')
print(pystencils.show_code(ast, ScreamingBackend()))
with pytest.raises(CalledProcessError):
pystencils.cpu.cpujit.make_python_function(ast, custom_backend=ScreamingBackend())
ast = pystencils.create_kernel(normal_assignments, target='gpu')
print(pystencils.show_code(ast, ScreamingGpuBackend()))
with pytest.raises(pycuda.driver.CompileError):
pystencils.gpucuda.cudajit.make_python_function(ast, custom_backend=ScreamingGpuBackend())
def main():
if __name__ == "__main__":
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment