Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Stephan Seitz
pystencil_reco
Commits
c7f0ad48
Commit
c7f0ad48
authored
Mar 02, 2020
by
Stephan Seitz
Browse files
Extend test_homography to support also constant H
parent
585ac16b
Pipeline
#22395
failed with stage
in 60 minutes and 1 second
Changes
1
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
tests/test_superresolution.py
View file @
c7f0ad48
...
...
@@ -16,7 +16,6 @@ import sympy
import
pystencils
import
pystencils_reco.transforms
from
pystencils.data_types
import
create_type
from
pystencils_reco
import
crazy
from
pystencils_reco._projective_matrix
import
ProjectiveMatrix
from
pystencils_reco.filters
import
gauss_filter
...
...
@@ -42,7 +41,8 @@ def test_superresolution():
pyconrad
.
show_everything
()
def
test_torch_simple
():
@
pytest
.
mark
.
parametrize
(
'constant_h'
,
(
'constant_h'
,
False
))
def
test_torch_simple
(
constant_h
):
import
pytest
pytest
.
importorskip
(
"torch"
)
...
...
@@ -50,18 +50,21 @@ def test_torch_simple():
x
,
y
=
pystencils
.
fields
(
'x,y: float32[2d]'
)
h
=
pystencils
.
fields
(
'h0,h1,h2,h3,h4,h5,h6,h7: float32[2d]'
)
@
crazy
def
move
(
x
,
y
):
h
=
pystencils
.
fields
(
'h(8): float32[2d]'
)
A
=
sympy
.
Matrix
([[
h
.
center
(
0
),
h
.
center
(
1
),
h
.
center
(
2
)],
[
h
.
center
(
3
),
h
.
center
(
4
),
h
.
center
(
5
)],
[
h
.
center
(
6
),
h
.
center
(
7
),
1
]])
A
=
sympy
.
Matrix
([[
h
[
0
].
center
,
h
[
1
].
center
,
h
[
2
].
center
],
[
h
[
3
].
center
,
h
[
4
].
center
,
h
[
5
].
center
],
[
h
[
6
].
center
,
h
[
7
].
center
,
1
]])
return
{
y
.
center
:
x
.
interpolated_access
(
ProjectiveMatrix
(
A
)
@
pystencils
.
x_vector
(
2
))
}
kernel
=
move
(
x
,
y
).
create_pytorch_op
()
if
constant_h
:
kernel
=
move
(
x
,
y
).
create_pytorch_op
(
constant_fields
=
h
)
else
:
kernel
=
move
(
x
,
y
).
create_pytorch_op
()
pystencils
.
autodiff
.
show_code
(
kernel
.
ast
)
x
=
torch
.
ones
((
10
,
40
)).
cuda
()
...
...
@@ -83,11 +86,6 @@ def test_torch_simple():
def
test_torch_matrix
():
import
pytest
pytest
.
importorskip
(
"torch"
)
import
torch
# x, y = torch.zeros((20, 20)), torch.zeros((20, 20))
x
,
y
=
pystencils
.
fields
(
'x,y: float32[2d]'
)
a
=
sympy
.
Symbol
(
'a'
)
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment