Commit c7f0ad48 authored by Stephan Seitz

Extend test_homography to also support constant H

parent 585ac16b
Pipeline #22395 failed in 60 minutes and 1 second
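
Context for the diff below (editor's note, not part of the commit): move() warps image x into y through a homography, i.e. a 3x3 matrix applied to pixel coordinates in homogeneous form; ProjectiveMatrix(A) @ pystencils.x_vector(2) builds that mapping symbolically and interpolated_access samples x at the resulting non-integer position. Passing constant_fields=h presumably fixes the eight coefficient fields as constants of the generated PyTorch op rather than ordinary tensor inputs. A minimal NumPy sketch of the same coordinate mapping, with made-up coefficients:

import numpy as np

# Hypothetical homography coefficients h0..h7; the bottom-right entry is fixed
# to 1, matching the 3x3 matrix assembled in the test.
A = np.array([[1.0,   0.1,  5.0],
              [0.0,   1.0, -3.0],
              [0.001, 0.0,  1.0]])

def project(A, point):
    # Apply the homography in homogeneous coordinates, then dehomogenize.
    px, py, pw = A @ np.array([point[0], point[1], 1.0])
    return px / pw, py / pw

# Every output pixel y[i, j] reads x at the projected (interpolated) position.
print(project(A, (10.0, 20.0)))
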
@@ -16,7 +16,6 @@ import sympy
import pystencils
import pystencils_reco.transforms
from pystencils.data_types import create_type
from pystencils_reco import crazy
from pystencils_reco._projective_matrix import ProjectiveMatrix
from pystencils_reco.filters import gauss_filter
@@ -42,7 +41,8 @@ def test_superresolution():
     pyconrad.show_everything()
 
-def test_torch_simple():
+@pytest.mark.parametrize('constant_h', ('constant_h', False))
+def test_torch_simple(constant_h):
     import pytest
     pytest.importorskip("torch")
@@ -50,17 +50,20 @@ def test_torch_simple():
     x, y = pystencils.fields('x,y: float32[2d]')
+    h = pystencils.fields('h0,h1,h2,h3,h4,h5,h6,h7: float32[2d]')
 
     @crazy
     def move(x, y):
-        h = pystencils.fields('h(8): float32[2d]')
-        A = sympy.Matrix([[h.center(0), h.center(1), h.center(2)],
-                          [h.center(3), h.center(4), h.center(5)],
-                          [h.center(6), h.center(7), 1]])
+        A = sympy.Matrix([[h[0].center, h[1].center, h[2].center],
+                          [h[3].center, h[4].center, h[5].center],
+                          [h[6].center, h[7].center, 1]])
 
         return {
             y.center: x.interpolated_access(ProjectiveMatrix(A) @ pystencils.x_vector(2))
         }
 
-    kernel = move(x, y).create_pytorch_op()
+    if constant_h:
+        kernel = move(x, y).create_pytorch_op(constant_fields=h)
+    else:
+        kernel = move(x, y).create_pytorch_op()
 
     pystencils.autodiff.show_code(kernel.ast)
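
Aside (editor's note): parametrizing with ('constant_h', False) makes the truthy case a non-empty string, so the generated test IDs read test_torch_simple[constant_h] and test_torch_simple[False]. The same pattern in isolation, with hypothetical names:

import pytest

@pytest.mark.parametrize('use_constant', ('use_constant', False))
def test_flag_pattern(use_constant):
    # 'use_constant' (a non-empty string) is truthy, False is not,
    # so the two parametrizations exercise both branches.
    if use_constant:
        assert isinstance(use_constant, str)
    else:
        assert use_constant is False
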
@@ -83,11 +86,6 @@ def test_torch_simple():
 def test_torch_matrix():
     import pytest
     pytest.importorskip("torch")
     import torch
 
     # x, y = torch.zeros((20, 20)), torch.zeros((20, 20))
     x, y = pystencils.fields('x,y: float32[2d]')
     a = sympy.Symbol('a')
......