Commit af7bfff2 authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Refactor interpolation

parent 189cc59e
Pipeline #21028 failed with stage
in 5 minutes and 4 seconds
......@@ -181,6 +181,9 @@ class AssignmentCollection(pystencils.AssignmentCollection):
if hasattr(t, 'requires_grad') and not t.requires_grad]
constant_fields = {f for f in self.free_fields if f.name in constant_field_names}
for n in [f for f, t in kwargs.items() if hasattr(t, 'requires_grad')]:
kwargs.pop(n)
if not self._autodiff:
if hasattr(self, '_create_autodiff'):
self._create_autodiff(constant_fields, **kwargs)
......
......@@ -259,11 +259,29 @@ def test_spatial_derivative():
tx, ty = pystencils.fields('t_x, t_y: float32[2d]')
assignments = pystencils.AssignmentCollection({
y.center: x.interpolated_access((tx.center, 2 * ty.center))
y.center: x.interpolated_access((tx.center + pystencils.x_, 2 * ty.center + pystencils.y_))
})
backward_assignments = pystencils.autodiff.create_backward_assignments(assignments)
print("assignments: " + str(assignments))
print("backward_assignments: " + str(backward_assignments))
def test_spatial_derivative2():
import pystencils.interpolation_astnodes
x, y = pystencils.fields('x, y: float32[2d]')
tx, ty = pystencils.fields('t_x, t_y: float32[2d]')
assignments = pystencils.AssignmentCollection({
y.center: x.interpolated_access((tx.center + pystencils.x_, ty.center + 2 * pystencils.y_))
})
backward_assignments = pystencils.autodiff.create_backward_assignments(assignments)
assert backward_assignments.atoms(pystencils.interpolation_astnodes.DiffInterpolatorAccess)
print("assignments: " + str(assignments))
print("backward_assignments: " + str(backward_assignments))
......@@ -282,9 +300,70 @@ def test_get_shift():
dh.all_to_gpu()
kernel = pystencils_reco.AssignmentCollection({
y.center: x.interpolated_access((tx.center + pystencils.x_, 2 * ty.center + pystencils.y_))
y.center: x.interpolated_access((tx.center + pystencils.x_, 2 * ty.center +
pystencils.y_), interpolation_mode='cubic_spline')
}).create_pytorch_op()().forward
dh.run_kernel(kernel)
pyconrad.imshow(dh.gpu_arrays)
def test_get_shift_tensors():
from pystencils_autodiff.framework_integration.datahandling import PyTorchDataHandling
import torch
lenna_file = join(dirname(__file__), "test_data", "lenna.png")
lenna = skimage.io.imread(lenna_file, as_gray=True).astype(np.float32)
dh = PyTorchDataHandling(lenna.shape)
x, y, tx, ty = dh.add_arrays('xw, yw, txw, tyw')
dh.cpu_arrays['xw'] = lenna
dh.cpu_arrays['txw'][...] = 0.7
dh.cpu_arrays['tyw'][...] = -0.7
dh.all_to_gpu()
kernel = pystencils_reco.AssignmentCollection({
y.center: x.interpolated_access((tx.center + pystencils.x_, 2 * ty.center + pystencils.y_),
interpolation_mode='cubic_spline')
}).create_pytorch_op()().call
dh.run_kernel(kernel)
y_array = dh.gpu_arrays['yw']
dh = PyTorchDataHandling(lenna.shape)
x, y, tx, ty = dh.add_arrays('x, y, tx, ty')
dh.cpu_arrays['tx'] = torch.zeros(lenna.shape, requires_grad=True)
dh.cpu_arrays['ty'] = torch.zeros(lenna.shape, requires_grad=True)
dh.cpu_arrays['x'] = lenna
dh.all_to_gpu()
kernel = pystencils_reco.AssignmentCollection({
y.center: x.interpolated_access((tx.center + pystencils.x_, 2 * ty.center +
pystencils.y_), interpolation_mode='cubic_spline')
}).create_pytorch_op(**dh.gpu_arrays)
print(kernel.ast)
kernel = kernel().call
learning_rate = 1e-4
params = (dh.cpu_arrays['tx'], dh.cpu_arrays['ty'])
# assert all([p.is_leaf for p in params])
optimizer = torch.optim.Adam(params, lr=learning_rate)
for i in range(100):
y = dh.run_kernel(kernel)
loss = (y - y_array).norm()
optimizer.zero_grad()
loss.backward()
assert y.requires_grad
optimizer.step()
print(loss.cpu().detach().numpy())
pyconrad.imshow(y)
pyconrad.imshow(dh.gpu_arrays)
pyconrad.imshow(dh.gpu_arrays)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment