Skip to content
Snippets Groups Projects
Commit dbc890ac authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Add test_cuda_known_functions.py

parent e5700eb4
No related merge requests found
# -*- coding: utf-8 -*-
#
# Copyright © 2019 Stephan Seitz <stephan.seitz@fau.de>
#
# Distributed under terms of the GPLv3 license.
"""
"""
import sympy
import pystencils
from pystencils.astnodes import get_dummy_symbol
from pystencils.backends.cuda_backend import CudaSympyPrinter
from pystencils.data_types import address_of
def test_cuda_known_functions():
printer = CudaSympyPrinter()
print(printer.known_functions)
x, y = pystencils.fields('x,y: float32 [2d]')
assignments = pystencils.AssignmentCollection({
get_dummy_symbol(): sympy.Function('atomicAdd')(address_of(y.center()), 2),
y.center(): sympy.Function('rsqrtf')(x[0, 0])
})
ast = pystencils.create_kernel(assignments, 'gpu')
print(pystencils.show_code(ast))
kernel = ast.compile()
assert(kernel is not None)
def test_cuda_but_not_c():
x, y = pystencils.fields('x,y: float32 [2d]')
assignments = pystencils.AssignmentCollection({
get_dummy_symbol(): sympy.Function('atomicAdd')(address_of(y.center()), 2),
y.center(): sympy.Function('rsqrtf')(x[0, 0])
})
ast = pystencils.create_kernel(assignments, 'cpu')
code = str(pystencils.show_code(ast))
assert "Not supported" in code
def test_cuda_unknown():
x, y = pystencils.fields('x,y: float32 [2d]')
assignments = pystencils.AssignmentCollection({
get_dummy_symbol(): sympy.Function('wtf')(address_of(y.center()), 2),
})
ast = pystencils.create_kernel(assignments, 'gpu')
code = str(pystencils.show_code(ast))
print(code)
assert "Not supported" in code
def main():
test_cuda_known_functions()
test_cuda_but_not_c()
test_cuda_unknown()
if __name__ == '__main__':
main()
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment