# test_fast_approximation.py
import sympy as sp
import pystencils as ps
from pystencils.fast_approximation import (
    fast_division, fast_inv_sqrt, fast_sqrt, insert_fast_divisions, insert_fast_sqrts)


def test_fast_sqrt():
    """Check that ``insert_fast_sqrts`` introduces fast sqrt / inverse-sqrt nodes.

    Also checks that the kernels generated for ``target='gpu'`` contain the
    ``__fsqrt_rn`` and ``__frsqrt_rn`` intrinsics. Both kernels are compiled,
    so this test presumably needs a working GPU toolchain.
    """
    f, g = ps.fields("f, g: double[2D]")
    expr = sp.sqrt(f[0, 0] + f[1, 0])

    # The substitution must work on a bare expression and on a list of expressions.
    assert len(insert_fast_sqrts(expr).atoms(fast_sqrt)) == 1
    assert len(insert_fast_sqrts([expr])[0].atoms(fast_sqrt)) == 1

    ast_gpu = ps.create_kernel(ps.Assignment(g[0, 0], insert_fast_sqrts(expr)), target='gpu')
    ast_gpu.compile()
    code_str = ps.get_code_str(ast_gpu)
    assert '__fsqrt_rn' in code_str

    # A division by a square root is expected to become a fast inverse sqrt,
    # both for a single Assignment and inside an AssignmentCollection.
    expr = ps.Assignment(sp.Symbol("tmp"), 3 / sp.sqrt(f[0, 0] + f[1, 0]))
    assert len(insert_fast_sqrts(expr).atoms(fast_inv_sqrt)) == 1

    ac = ps.AssignmentCollection([expr], [])
    assert len(insert_fast_sqrts(ac).main_assignments[0].atoms(fast_inv_sqrt)) == 1

    ast_gpu = ps.create_kernel(insert_fast_sqrts(ac), target='gpu')
    ast_gpu.compile()
    code_str = ps.get_code_str(ast_gpu)
    assert '__frsqrt_rn' in code_str


def test_fast_divisions():
    """Check that ``insert_fast_divisions`` introduces ``fast_division`` nodes.

    Also checks that the kernel generated for ``target='gpu'`` contains the
    ``__fdividef`` intrinsic. The kernel is compiled, so this test presumably
    needs a working GPU toolchain.
    """
    f, g = ps.fields("f, g: double[2D]")
    expr = f[0, 0] / f[1, 0]
    assert len(insert_fast_divisions(expr).atoms(fast_division)) == 1

    # Two chained divisions in one product are expected to collapse into a
    # single fast_division node.
    expr = 1 / f[0, 0] * 2 / f[0, 1]
    assert len(insert_fast_divisions(expr).atoms(fast_division)) == 1

    # Named ast_gpu (as in test_fast_sqrt) so the stdlib `ast` module name is not shadowed.
    ast_gpu = ps.create_kernel(ps.Assignment(g[0, 0], insert_fast_divisions(expr)), target='gpu')
    ast_gpu.compile()
    code_str = ps.get_code_str(ast_gpu)
    assert '__fdividef' in code_str