Commit b0635d2c authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Implement upsample/downsample and adjoint for translate

parent 3c0a92e1
Pipeline #20677 failed with stage
in 1 minute and 24 seconds
......@@ -8,7 +8,6 @@
Implements common resampling operations like rotations and scalings
"""
import itertools
import types
from collections.abc import Iterable
......@@ -17,6 +16,7 @@ import sympy
import pystencils
import pystencils.autodiff
from pystencils.autodiff import AdjointField
from pystencils.data_types import cast_func, create_type
from pystencils_reco import AssignmentCollection, crazy
......@@ -122,10 +122,46 @@ def translate(input_field: pystencils.Field,
translation,
interpolation_mode='linear'):
return {
output_field.center: input_field.interpolated_access(
input_field.physical_to_index(output_field.physical_coordinates - translation), interpolation_mode)
}
def create_autodiff(self, constant_fields=None, **kwargs):
backward_assignments = translate(AdjointField(output_field), AdjointField(input_field), -translation)
self._autodiff = pystencils.autodiff.AutoDiffOp(
assignments, "", backward_assignments=backward_assignments, **kwargs)
if isinstance(translation, pystencils.Field):
translation = translation.center_vector
assignments = AssignmentCollection(
{
output_field.center: input_field.interpolated_access(
input_field.physical_to_index(output_field.physical_coordinates - translation), interpolation_mode)
})
assignments._create_autodiff = types.MethodType(create_autodiff, assignments)
return assignments
@crazy
def upsample(input: {'field_type': pystencils.field.FieldType.CUSTOM},
result,
factor):
ndim = input.spatial_dimensions
here = pystencils.x_vector(ndim)
assignments = AssignmentCollection(
{result.center:
pystencils.astnodes.ConditionalFieldAccess(
input.absolute_access(tuple(cast_func(sympy.S(1) / factor * h,
create_type('int64')) for h in here), ()),
sympy.Or(*[s % cast_func(factor, 'int64') > 0 for s in here]))
})
def create_autodiff(self, constant_fields=None, **kwargs):
backward_assignments = downsample(AdjointField(result), AdjointField(input), factor)
self._autodiff = pystencils.autodiff.AutoDiffOp(
assignments, "", backward_assignments=backward_assignments, **kwargs)
assignments._create_autodiff = types.MethodType(create_autodiff, assignments)
return assignments
@crazy
......@@ -137,5 +173,14 @@ def downsample(input: {'field_type': pystencils.field.FieldType.CUSTOM},
ndim = input.spatial_dimensions
return {result.center,
input.absolute_access(factor * pystencils.x_vector(ndim), ())}
assignments = AssignmentCollection({result.center:
input.absolute_access(factor * pystencils.x_vector(ndim), ())})
def create_autodiff(self, constant_fields=None, **kwargs):
backward_assignments = upsample(AdjointField(result), AdjointField(input), factor)
self._autodiff = pystencils.autodiff.AutoDiffOp(
assignments, "", backward_assignments=backward_assignments, **kwargs)
assignments._create_autodiff = types.MethodType(create_autodiff, assignments)
return assignments
......@@ -7,11 +7,13 @@
"""
"""
from os.path import dirname, join
import numpy as np
import skimage.io
import pystencils
from pystencils_reco.resampling import downsample, scale_transform
from pystencils_reco.unet import max_pooling
from pystencils_reco.resampling import downsample, scale_transform, translate, upsample
try:
import pyconrad.autoinit
......@@ -22,7 +24,7 @@ except Exception:
def test_superresolution():
x, y = np.random.rand(20, 10), np.zeros((20, 10))
x, y = np.random.rand(20, 10), np.zeros((20, 10))
kernel = scale_transform(x, y, 0.5).compile()
print(pystencils.show_code(kernel))
......@@ -34,10 +36,32 @@ def test_superresolution():
def test_downsample():
shape = (20, 10)
x, y = np.random.rand(*shape), np.zeros(shape)
x, y = np.random.rand(*shape), np.zeros(tuple(s // 2 for s in shape))
kernel = downsample(x, y, 2).compile()
print(pystencils.show_code(kernel))
kernel()
pyconrad.show_everything()
def test_warp():
import torch
NUM_LENNAS = 5
perturbation = 0.1
lenna_file = join(dirname(__file__), "test_data", "lenna.png")
lenna = skimage.io.imread(lenna_file, as_gray=True).astype(np.float32)
warp_vectors = list(perturbation * torch.randn(lenna.shape + (2,)) for _ in range(NUM_LENNAS))
warped = [torch.zeros(lenna.shape) for _ in range(NUM_LENNAS)]
warp_kernel = translate(lenna, warped[0], pystencils.autodiff.ArrayWrapper(
warp_vectors[0], index_dimensions=1), interpolation_mode='linear').compile()
for i in range(len(warped)):
warp_kernel(lenna[i], warped[i], warp_vectors[i])
test_warp()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment