pycodegen / pystencils_autodiff · Commits

Commit 25ebe582, authored Dec 16, 2019 by Stephan Seitz

Allow use_cuda to be truthy (instead of only True/False)

parent f3772f74
Pipeline #20540 failed with stage in 5 minutes and 41 seconds
Changes: 1 · Pipelines: 1

src/pystencils_autodiff/backends/_torch_native.py
...
@@ -74,9 +74,12 @@ def create_autograd_function(autodiff_obj, use_cuda):
             return tuple(output_tensors.values())

         def backward(self, *grad_outputs):
+            nonlocal use_cuda
             if use_cuda:
+                use_cuda = True
                 grad_outputs = [a.contiguous().cuda() for a in grad_outputs]
             else:
+                use_cuda = False
                 grad_outputs = [a.contiguous().cpu() for a in grad_outputs]

             grad_fields = [f for f in autodiff_obj.backward_input_fields if f not in autodiff_obj.forward_input_fields]
...
@@ -84,7 +87,6 @@ def create_autograd_function(autodiff_obj, use_cuda):
             assert all(f.shape == grad_outputs[i].shape for i, f in enumerate(grad_fields))
             assert all(f.strides == tuple(grad_outputs[i].stride(j) for j in range(grad_outputs[i].ndim))
                        for i, f in enumerate(grad_fields))
-            assert use_cuda in (True, False), "use_cuda needs to be True or False"
             assert all(a.is_cuda == use_cuda for a in grad_outputs), (
                 "Some of the tensors were on the wrong device. "
                 f"Op was compiled for CUDA: {str(use_cuda)}")
...
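For context, here is a minimal sketch of the normalization pattern this commit introduces in `backward`: any truthy value for `use_cuda` now selects the CUDA path and is collapsed to a plain bool before the later `a.is_cuda == use_cuda` check, so callers may pass e.g. `1` or a device handle instead of a strict `True`. The helper `normalize_grad_outputs` below is hypothetical, for illustration only, and is not part of the commit.

import torch

def normalize_grad_outputs(use_cuda, grad_outputs):
    # Mirrors the backward() logic above: a truthy value selects the
    # CUDA path and is collapsed to a plain bool, so the subsequent
    # `a.is_cuda == use_cuda` assertion still holds.
    if use_cuda:
        use_cuda = True
        grad_outputs = [a.contiguous().cuda() for a in grad_outputs]
    else:
        use_cuda = False
        grad_outputs = [a.contiguous().cpu() for a in grad_outputs]
    return use_cuda, grad_outputs

# Hypothetical usage: a truthy non-bool (here the int 1) now works,
# where the removed assert would previously have rejected it.
if torch.cuda.is_available():
    use_cuda, outs = normalize_grad_outputs(1, [torch.zeros(4)])
    assert use_cuda is True and all(a.is_cuda for a in outs)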