Commit ad5ebbe8 authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Update test_superresolution

parent f1caa29c
......@@ -426,3 +426,65 @@ def test_spline_diff(with_spline):
pyconrad.imshow(dh.gpu_arrays)
pyconrad.imshow(dh.gpu_arrays)
@pytest.mark.parametrize('scalar_experiment', (False,))
def test_rotation(scalar_experiment):
from pystencils_autodiff.framework_integration.datahandling import PyTorchDataHandling
from pystencils_reco.resampling import rotation_transform
import torch
lenna_file = join(dirname(__file__), "test_data", "lenna.png")
lenna = skimage.io.imread(lenna_file, as_gray=True).astype(np.float32)
GROUNDTRUTH_ANGLE = 0.3
target = np.zeros(lenna.shape)
rotation_transform(lenna, target, GROUNDTRUTH_ANGLE)()
target = torch.Tensor(target).cuda()
dh = PyTorchDataHandling(lenna.shape)
x, y, angle = dh.add_arrays('x, y, angle')
if scalar_experiment:
var_angle = torch.zeros((), requires_grad=True)
else:
var_angle = torch.zeros(lenna.shape, requires_grad=True)
var_lenna = torch.autograd.Variable(torch.from_numpy(
lenna + np.random.randn(*lenna.shape).astype(np.float32)), requires_grad=True)
assert var_lenna.requires_grad
learning_rate = 0.1
params = (var_angle, var_lenna)
optimizer = torch.optim.Adam(params, lr=learning_rate)
assignments = rotation_transform(x, y, angle)
kernel = assignments.create_pytorch_op()
print(kernel)
kernel = kernel().call
for i in range(100000):
if scalar_experiment:
dh.cpu_arrays.angle = torch.ones(lenna.shape) * (var_angle + 0.29)
else:
dh.cpu_arrays.angle = var_angle
dh.cpu_arrays.x = var_lenna
dh.all_to_gpu()
y = dh.run_kernel(kernel)
loss = (y - target).norm()
optimizer.zero_grad()
loss.backward(retain_graph=True)
assert y.requires_grad
optimizer.step()
print(loss.cpu().detach().numpy())
pyconrad.imshow(var_lenna)
pyconrad.show_everything()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment