Commit c7f0ad48 authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Extend test_homography to support also constant H

parent 585ac16b
Pipeline #22395 failed with stage
in 60 minutes and 1 second
...@@ -16,7 +16,6 @@ import sympy ...@@ -16,7 +16,6 @@ import sympy
import pystencils import pystencils
import pystencils_reco.transforms import pystencils_reco.transforms
from pystencils.data_types import create_type
from pystencils_reco import crazy from pystencils_reco import crazy
from pystencils_reco._projective_matrix import ProjectiveMatrix from pystencils_reco._projective_matrix import ProjectiveMatrix
from pystencils_reco.filters import gauss_filter from pystencils_reco.filters import gauss_filter
...@@ -42,7 +41,8 @@ def test_superresolution(): ...@@ -42,7 +41,8 @@ def test_superresolution():
pyconrad.show_everything() pyconrad.show_everything()
def test_torch_simple(): @pytest.mark.parametrize('constant_h', ('constant_h', False))
def test_torch_simple(constant_h):
import pytest import pytest
pytest.importorskip("torch") pytest.importorskip("torch")
...@@ -50,18 +50,21 @@ def test_torch_simple(): ...@@ -50,18 +50,21 @@ def test_torch_simple():
x, y = pystencils.fields('x,y: float32[2d]') x, y = pystencils.fields('x,y: float32[2d]')
h = pystencils.fields('h0,h1,h2,h3,h4,h5,h6,h7: float32[2d]')
@crazy @crazy
def move(x, y): def move(x, y):
h = pystencils.fields('h(8): float32[2d]') A = sympy.Matrix([[h[0].center, h[1].center, h[2].center],
A = sympy.Matrix([[h.center(0), h.center(1), h.center(2)], [h[3].center, h[4].center, h[5].center],
[h.center(3), h.center(4), h.center(5)], [h[6].center, h[7].center, 1]])
[h.center(6), h.center(7), 1]])
return { return {
y.center: x.interpolated_access(ProjectiveMatrix(A) @ pystencils.x_vector(2)) y.center: x.interpolated_access(ProjectiveMatrix(A) @ pystencils.x_vector(2))
} }
kernel = move(x, y).create_pytorch_op() if constant_h:
kernel = move(x, y).create_pytorch_op(constant_fields=h)
else:
kernel = move(x, y).create_pytorch_op()
pystencils.autodiff.show_code(kernel.ast) pystencils.autodiff.show_code(kernel.ast)
x = torch.ones((10, 40)).cuda() x = torch.ones((10, 40)).cuda()
...@@ -83,11 +86,6 @@ def test_torch_simple(): ...@@ -83,11 +86,6 @@ def test_torch_simple():
def test_torch_matrix(): def test_torch_matrix():
import pytest
pytest.importorskip("torch")
import torch
# x, y = torch.zeros((20, 20)), torch.zeros((20, 20)) # x, y = torch.zeros((20, 20)), torch.zeros((20, 20))
x, y = pystencils.fields('x,y: float32[2d]') x, y = pystencils.fields('x,y: float32[2d]')
a = sympy.Symbol('a') a = sympy.Symbol('a')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment