pycodegen / pystencils_autodiff

Commit 1e141dfb authored Nov 29, 2019 by Stephan Seitz

Change torch native for new interface

Parent: 75dd38f2
Pipeline #20141 failed with stage in 1 minute and 39 seconds

1 changed file:
src/pystencils_autodiff/backends/_torch_native.py
@@ -78,10 +78,12 @@ def create_autograd_function(autodiff_obj, use_cuda):
             grad_outputs = [a.contiguous().cuda() for a in grad_outputs]
         else:
             grad_outputs = [a.contiguous().cpu() for a in grad_outputs]
-        gradients = {f.name: grad_outputs[i] for i, f in enumerate(autodiff_obj.backward_input_fields)}
-        assert all(f.shape == grad_outputs[i].shape for i, f in enumerate(autodiff_obj.backward_input_fields))
+        grad_fields = [f for f in autodiff_obj.backward_input_fields
+                       if f not in autodiff_obj.forward_input_fields]
+        gradients = {f.name: grad_outputs[i] for i, f in enumerate(grad_fields)}
+        assert all(f.shape == grad_outputs[i].shape for i, f in enumerate(grad_fields))
         assert all(f.strides == tuple(grad_outputs[i].stride(j) for j in range(grad_outputs[i].ndim))
-                   for i, f in enumerate(autodiff_obj.backward_input_fields))
+                   for i, f in enumerate(grad_fields))
         assert all(a.is_cuda == use_cuda for a in grad_outputs), \
             "Some of the tensors were on the wrong device. " \
             f"Op was compiled for CUDA: {str(use_cuda)}"
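The substance of the change: previously every entry of autodiff_obj.backward_input_fields was matched positionally against the incoming grad_outputs tensors; now fields that also appear among the forward inputs are filtered out first, so only the genuinely new backward inputs (grad_fields) consume gradient tensors. A minimal sketch of that mapping, using hypothetical stand-in fields rather than real pystencils objects:

from collections import namedtuple

import torch

# Hypothetical stand-in for a pystencils field; only name/shape matter here.
Field = namedtuple('Field', ['name', 'shape'])

# Assumed setup: 'x' appears in both passes, 'dy' only in the backward pass.
forward_input_fields = [Field('x', (4, 4))]
backward_input_fields = [Field('x', (4, 4)), Field('dy', (4, 4))]

# New interface: skip backward inputs that are also forward inputs, so
# grad_outputs only has to supply tensors for the remaining fields.
grad_fields = [f for f in backward_input_fields
               if f not in forward_input_fields]

grad_outputs = [torch.zeros(4, 4)]  # one tensor per entry of grad_fields
gradients = {f.name: grad_outputs[i] for i, f in enumerate(grad_fields)}

assert all(f.shape == grad_outputs[i].shape
           for i, f in enumerate(grad_fields))
print(sorted(gradients))  # ['dy'] -- 'x' no longer consumes a grad tensor

Under the old mapping, the sketch above would have paired 'x' with grad_outputs[0] and left 'dy' without a tensor, which is exactly the positional mismatch the shape assertions in the diff guard against.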