Commit 0ab6f7b0 authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Implement downsample

parent 9218f117
Pipeline #20639 failed with stage
in 60 minutes and 1 second
......@@ -8,6 +8,7 @@
Implements common resampling operations like rotations and scalings
"""
import itertools
import types
from collections.abc import Iterable
......@@ -125,3 +126,21 @@ def translate(input_field: pystencils.Field,
output_field.center: input_field.interpolated_access(
input_field.physical_to_index(output_field.physical_coordinates - translation), interpolation_mode)
}
@crazy
def downsample(input: {'field_type': pystencils.field.FieldType.CUSTOM},
result,
factor):
assert input.spatial_dimensions == result.spatial_dimensions
assert input.index_shape == result.index_shape
assignments = []
ndim = input.spatial_dimensions
assignments.append(
pystencils.Assignment(result.center,
input.absolute_access(factor * pystencils.x_vector(ndim), ()))
)
return assignments
......@@ -28,8 +28,28 @@ def relu(input, result):
@crazy
def max_pooling(input: {'index_dimensions': 1, 'field_type': FieldType.CUSTOM},
result: {'index_dimensions': 1}):
def max_pooling(input: {'field_type': FieldType.CUSTOM},
result):
assert input.spatial_dimensions == result.spatial_dimensions
assert input.index_shape == result.index_shape
assignments = []
ndim = input.spatial_dimensions
offsets = itertools.product((0, 1), repeat=ndim)
assignments.append(
pystencils.Assignment(result.center,
sympy.Max(*[input.absolute_access(2 * pystencils.x_vector(ndim) + sympy.Matrix(offset), ()) # noqa
for offset in offsets])
)
)
return assignments
@crazy
def max_pooling_channels(input: {'index_dimensions': 1, 'field_type': FieldType.CUSTOM},
result: {'index_dimensions': 1}):
assert input.spatial_dimensions == result.spatial_dimensions
assert input.index_shape == result.index_shape
assignments = []
......
# -*- coding: utf-8 -*-
#
# Copyright © 2019 Stephan Seitz <stephan.seitz@fau.de>
#
# Distributed under terms of the GPLv3 license.
"""
"""
import numpy as np
import pystencils
from pystencils_reco.resampling import downsample, scale_transform
from pystencils_reco.unet import max_pooling
try:
import pyconrad.autoinit
except Exception:
import unittest.mock
pyconrad = unittest.mock.MagicMock()
def test_superresolution():
x, y = np.random.rand(20, 10), np.zeros((20, 10))
kernel = scale_transform(x, y, 0.5).compile()
print(pystencils.show_code(kernel))
kernel()
pyconrad.show_everything()
def test_downsample():
shape = (20, 10)
x, y = np.random.rand(*shape), np.zeros(shape)
kernel = downsample(x, y, 2).compile()
print(pystencils.show_code(kernel))
kernel()
pyconrad.show_everything()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment