diff --git a/pystencils_tests/test_cudagpu.py b/pystencils_tests/test_cudagpu.py index 35cd8cf9004533873330519d42e1a67b5d449f48..8de52ba490bc2fdc7914d8a194776937562e9481 100644 --- a/pystencils_tests/test_cudagpu.py +++ b/pystencils_tests/test_cudagpu.py @@ -4,7 +4,7 @@ from pystencils import Field, Assignment from pystencils.simp import sympy_cse_on_assignment_list from pystencils.gpucuda.indexing import LineIndexing from pystencils.slicing import remove_ghost_layers, add_ghost_layers, make_slice -from pystencils.gpucuda import make_python_function, create_cuda_kernel +from pystencils.gpucuda import make_python_function, create_cuda_kernel, BlockIndexing import pycuda.gpuarray as gpuarray from scipy.ndimage import convolve @@ -145,3 +145,8 @@ def test_periodicity(): periodic_gpu_kernel(pdfs=arr_gpu) arr_gpu.get(gpu_result) np.testing.assert_equal(cpu_result, gpu_result) + + +def test_block_size_limiting(): + res = BlockIndexing.limit_block_size_to_device_maximum((4096, 4096, 4096)) + assert all(r < 4096 for r in res) diff --git a/pystencils_tests/test_fast_approximation.py b/pystencils_tests/test_fast_approximation.py index d0cf33d8a042072caad39a8a62c111688fbe26bb..bc07ae9b8c133fb136fcca25aeea00c256e181d1 100644 --- a/pystencils_tests/test_fast_approximation.py +++ b/pystencils_tests/test_fast_approximation.py @@ -10,11 +10,18 @@ def test_fast_sqrt(): assert len(insert_fast_sqrts(expr).atoms(fast_sqrt)) == 1 assert len(insert_fast_sqrts([expr])[0].atoms(fast_sqrt)) == 1 + ast = ps.create_kernel(ps.Assignment(g[0, 0], insert_fast_sqrts(expr)), target='gpu') + code_str = str(ps.show_code(ast)) + assert '__fsqrt_rn' in code_str - expr = 3 / sp.sqrt(f[0, 0] + f[1, 0]) + expr = ps.Assignment(sp.Symbol("tmp"), 3 / sp.sqrt(f[0, 0] + f[1, 0])) assert len(insert_fast_sqrts(expr).atoms(fast_inv_sqrt)) == 1 + ac = ps.AssignmentCollection([expr], []) assert len(insert_fast_sqrts(ac).main_assignments[0].atoms(fast_inv_sqrt)) == 1 + ast = ps.create_kernel(insert_fast_sqrts(ac), target='gpu') + code_str = str(ps.show_code(ast)) + assert '__frsqrt_rn' in code_str def test_fast_divisions():