test_dtype_check.py 1.09 KB
Newer Older
1
2
3
import numpy as np
import pytest

Martin Bauer's avatar
Martin Bauer committed
4
5
import pystencils

6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32

def test_dtype_check_wrong_type():
    """Calling a compiled kernel with an array of the wrong dtype raises ValueError.

    The fields are declared via ``pystencils.fields`` without an explicit dtype.
    The companion test passes float64 arrays successfully, so float32 input is
    expected to be rejected with a "Wrong data type" message.
    """
    array = np.ones((10, 20), dtype=np.float32)
    output = np.zeros_like(array)
    x, y = pystencils.fields('x,y: [2D]')
    stencil = [[1, 1, 1],
               [1, 1, 1],
               [1, 1, 1]]
    # Normalised 3x3 box filter: each output cell is the mean of its neighbourhood.
    assignment = pystencils.assignment.assignment_from_stencil(stencil, x, y, normalization_factor=1 / np.sum(stencil))
    kernel = pystencils.create_kernel([assignment]).compile()

    with pytest.raises(ValueError) as excinfo:
        kernel(x=array, y=output)
    # Inspect the raised exception itself, not the ExceptionInfo wrapper:
    # str(ExceptionInfo) is its repr (pytest >= 5.0), not the error message.
    assert 'Wrong data type' in str(excinfo.value)


def test_dtype_check_correct_type():
    """A kernel accepts float64 arrays and applies the normalised box filter.

    With an all-ones input, every interior cell of the output must be 1.
    The outermost ring is not checked because the 3x3 stencil does not write it.
    """
    src_arr = np.full((10, 20), 1.0, dtype=np.float64)
    dst_arr = np.zeros_like(src_arr)
    src_field, dst_field = pystencils.fields('x,y: [2D]')

    box_stencil = [[1] * 3 for _ in range(3)]
    weight = 1 / np.sum(box_stencil)
    box_filter = pystencils.assignment.assignment_from_stencil(
        box_stencil, src_field, dst_field, normalization_factor=weight)

    compiled = pystencils.create_kernel([box_filter]).compile()
    compiled(x=src_arr, y=dst_arr)

    interior = dst_arr[1:-1, 1:-1]
    assert np.allclose(interior, 1.0)