Commit 1e141dfb authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Change torch native for new interface

parent 75dd38f2
Pipeline #20141 failed with stage
in 1 minute and 39 seconds
......@@ -78,10 +78,12 @@ def create_autograd_function(autodiff_obj, use_cuda):
grad_outputs = [a.contiguous().cuda() for a in grad_outputs]
else:
grad_outputs = [a.contiguous().cpu() for a in grad_outputs]
gradients = {f.name: grad_outputs[i] for i, f in enumerate(autodiff_obj.backward_input_fields)}
assert all(f.shape == grad_outputs[i].shape for i, f in enumerate(autodiff_obj.backward_input_fields))
grad_fields = [f for f in autodiff_obj.backward_input_fields if f not in autodiff_obj.forward_input_fields]
gradients = {f.name: grad_outputs[i] for i, f in enumerate(grad_fields)}
assert all(f.shape == grad_outputs[i].shape for i, f in enumerate(grad_fields))
assert all(f.strides == tuple(grad_outputs[i].stride(j) for j in range(grad_outputs[i].ndim))
for i, f in enumerate(autodiff_obj.backward_input_fields))
for i, f in enumerate(grad_fields))
assert all(a.is_cuda == use_cuda for a in grad_outputs), "Some of the tensors where on the wrong device. "
f"Op was compiled for CUDA: {str(use_cuda)}"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment