Commit 9a53b949 authored by Stephan Seitz's avatar Stephan Seitz
Browse files

test_superresolution

parent 38baf325
Pipeline #22364 failed with stage
in 2 minutes and 47 seconds
......@@ -16,7 +16,9 @@ import sympy
import pystencils
import pystencils_reco.transforms
from pystencils.data_types import create_type
from pystencils_reco import crazy
from pystencils_reco._projective_matrix import ProjectiveMatrix
from pystencils_reco.filters import gauss_filter
from pystencils_reco.resampling import (
downsample, resample, resample_to_shape, scale_transform, translate)
......@@ -46,7 +48,37 @@ def test_torch_simple():
pytest.importorskip("torch")
import torch
x, y = torch.zeros((20, 20)), torch.zeros((20, 20))
x, y = pystencils.fields('x,y: float32[2d]')
@crazy
def move(x, y):
h = pystencils.fields('h(8): float32[2d]')
A = sympy.Matrix([[h.center(0), h.center(1), h.center(2)],
[h.center(3), h.center(4), h.center(5)],
[h.center(6), h.center(7), 1]])
return {
y.center: x.interpolated_access(ProjectiveMatrix(A) @ pystencils.x_vector(2))
}
kernel = move(x, y).create_pytorch_op()
pystencils.autodiff.show_code(kernel.ast)
x = torch.ones((10, 40)).cuda()
h = torch.ones((10, 40, 8)).cuda()
kernel().forward(h, x)
# kernel().forward(*([1]*9), x, y)
def test_torch_matrix():
import pytest
pytest.importorskip("torch")
import torch
# x, y = torch.zeros((20, 20)), torch.zeros((20, 20))
x, y = pystencils.fields('x,y: float32[2d]')
a = sympy.Symbol('a')
@crazy
......@@ -55,9 +87,8 @@ def test_torch_simple():
y.center: x.interpolated_access((pystencils.x_, pystencils.y_ + a))
}
kernel = move(x, y, a).compile()
kernel = move(x, y, a).create_pytorch_op()
pystencils.autodiff.show_code(kernel.ast)
kernel().forward(x, y, 3)
def test_downsample():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment