Commit 4ec6f70e authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Make stuff work in test_superresolution

parent af7bfff2
Pipeline #21054 failed with stage
in 6 minutes and 42 seconds
......@@ -10,6 +10,7 @@
from os.path import dirname, join
import numpy as np
import pytest
import skimage.io
import sympy
......@@ -309,7 +310,8 @@ def test_get_shift():
pyconrad.imshow(dh.gpu_arrays)
def test_get_shift_tensors():
@pytest.mark.parametrize('scalar_experiment', (False,))
def test_get_shift_tensors(scalar_experiment):
from pystencils_autodiff.framework_integration.datahandling import PyTorchDataHandling
import torch
......@@ -323,47 +325,102 @@ def test_get_shift_tensors():
dh.cpu_arrays['txw'][...] = 0.7
dh.cpu_arrays['tyw'][...] = -0.7
dh.all_to_gpu()
pyconrad.imshow(dh.gpu_arrays)
kernel = pystencils_reco.AssignmentCollection({
y.center: x.interpolated_access((tx.center + pystencils.x_, 2 * ty.center + pystencils.y_),
interpolation_mode='cubic_spline')
y.center: x.interpolated_access((tx.center + pystencils.x_, 2 * ty.center + pystencils.y_))
}).create_pytorch_op()().call
dh.run_kernel(kernel)
y_array = dh.gpu_arrays['yw']
y_array = dh.run_kernel(kernel)
dh = PyTorchDataHandling(lenna.shape)
x, y, tx, ty = dh.add_arrays('x, y, tx, ty')
dh.cpu_arrays['tx'] = torch.zeros(lenna.shape, requires_grad=True)
dh.cpu_arrays['ty'] = torch.zeros(lenna.shape, requires_grad=True)
dh.cpu_arrays['x'] = lenna
dh.all_to_gpu()
kernel = pystencils_reco.AssignmentCollection({
if scalar_experiment:
var_x = torch.zeros((), requires_grad=True)
var_y = torch.zeros((), requires_grad=True)
else:
var_x = torch.zeros(lenna.shape, requires_grad=True)
var_y = torch.zeros(lenna.shape, requires_grad=True)
dh.cpu_arrays.x = lenna
assignments = pystencils_reco.AssignmentCollection({
y.center: x.interpolated_access((tx.center + pystencils.x_, 2 * ty.center +
pystencils.y_), interpolation_mode='cubic_spline')
}).create_pytorch_op(**dh.gpu_arrays)
pystencils.y_))
})
print(pystencils.autodiff.create_backward_assignments(assignments))
kernel = assignments.create_pytorch_op()
print(kernel.ast)
kernel = kernel().call
learning_rate = 1e-4
params = (dh.cpu_arrays['tx'], dh.cpu_arrays['ty'])
learning_rate = 0.1
params = (var_x, var_y)
# assert all([p.is_leaf for p in params])
optimizer = torch.optim.Adam(params, lr=learning_rate)
for i in range(100):
if scalar_experiment:
dh.cpu_arrays.tx = torch.ones(lenna.shape) * var_x
dh.cpu_arrays.ty = torch.ones(lenna.shape) * var_y
else:
dh.cpu_arrays.tx = var_x
dh.cpu_arrays.ty = var_y
dh.all_to_gpu()
y = dh.run_kernel(kernel)
loss = (y - y_array).norm()
optimizer.zero_grad()
loss.backward()
loss.backward(retain_graph=True)
assert y.requires_grad
optimizer.step()
print(loss.cpu().detach().numpy())
pyconrad.imshow(y)
print("var_x: " + str(var_x.mean()))
pyconrad.imshow(var_x)
# pyconrad.imshow(dh.gpu_arrays)
pyconrad.imshow(dh.gpu_arrays, wait_window_close=True)
@pytest.mark.parametrize('with_spline', ('with_spline', False))
def test_spline_diff(with_spline):
from pystencils.fd import Diff
from pystencils.datahandling import SerialDataHandling
lenna_file = join(dirname(__file__), "test_data", "lenna.png")
lenna = skimage.io.imread(lenna_file, as_gray=True).astype(np.float32)
dh = SerialDataHandling(lenna.shape, default_target='gpu', default_ghost_layers=0, default_layout='numpy')
x, y, tx, ty = dh.add_arrays('x, y, tx, ty', dtype=np.float32)
dh.cpu_arrays['x'] = lenna
dh.cpu_arrays['tx'][...] = 0.7
dh.cpu_arrays['ty'][...] = -0.7
out = dh.add_array('out', dtype=np.float32)
dh.all_to_gpu()
kernel = pystencils_reco.AssignmentCollection({
y.center: Diff(x, 0).interpolated_access((tx.center + pystencils.x_,
ty.center + pystencils.y_),
interpolation_mode='cubic_spline' if with_spline else 'linear')
}).compile(target='gpu')
dh.run_kernel(kernel)
print(pystencils.show_code(kernel))
kernel = pystencils_reco.AssignmentCollection({
out.center: x.interpolated_access((tx.center + pystencils.x_, ty.center + pystencils.y_),
interpolation_mode='cubic_spline' if with_spline else 'linear')
}).compile(target='gpu')
dh.run_kernel(kernel)
print(pystencils.show_code(kernel))
pyconrad.imshow(dh.gpu_arrays)
pyconrad.imshow(dh.gpu_arrays)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment