Commit a9797544 authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Make test_tfmad_gradient_check_torch_native also pass with CUDA

with_cuda needed to be True not 'with_cuda'
parent 0e367089
......@@ -231,7 +231,7 @@ def test_tfmad_gradient_check_torch_native(with_offsets, with_cuda):
[dict[f] for f in auto_diff.forward_input_fields]), atol=1e-4, raise_exception=True)
@pytest.mark.parametrize('with_cuda', (False, 'with_cuda'))
@pytest.mark.parametrize('with_cuda', (False, True))
def test_tfmad_gradient_check_two_outputs(with_cuda):
torch = pytest.importorskip('torch')
import torch
......@@ -274,9 +274,9 @@ def test_tfmad_gradient_check_two_outputs(with_cuda):
dict = {
a: a_tensor,
b: b_tensor,
out1_tensor: out1_tensor,
out2_tensor: out2_tensor,
out3_tensor: out3_tensor,
out1: out1_tensor,
out2: out2_tensor,
out3: out3_tensor,
}
torch.autograd.gradcheck(function.apply, tuple(
[dict[f] for f in auto_diff.forward_input_fields]), atol=1e-4, raise_exception=True)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment