[fix] update the condition for aliveness of TensorWrapper #98748
Conversation
🔗 Helpful Links 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/98748
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit b74e160. This comment was automatically generated by Dr. CI and updates every 15 minutes.
@@ -3204,13 +3204,9 @@ def f(x):

        B = 5
        x = torch.randn(B, 3)
-       with self.assertRaises(RuntimeError):
+       with self.assertRaisesRegex(RuntimeError, "Batching rule not implemented for aten::_make_dual"):
Added regex to match for clarity of expected failure.
            vmap(f)(x)

        x = torch.randn([])
        with self.assertRaises(RuntimeError):
            grad(f)(x)
The test used to fail due to this issue:
import torch
from torch.func import jacfwd, jacrev, vmap, vjp, jvp, grad
from functorch import make_fx
from torch._C._functorch import unwrap_if_dead, is_dead_tensor_wrapper
from torch.autograd.forward_ad import make_dual, dual_level

torch.manual_seed(420)
x = torch.randn(())

def f(x):
    y = torch.autograd.functional.jacobian(
        lambda x: x.sin().sum(), x, strategy='forward-mode', vectorize=True)
    return y

grad(f)(x)
Output:
RuntimeError: unwrapped_count > 0 INTERNAL ASSERT FAILED at "/home/kshiteej/Pytorch/pytorch_functorch/aten/src/ATen/functorch/TensorWrapper.cpp":213, please report a bug to PyTorch. Should have at least one dead wrapper
I'm kind of confused. How does test_autograd_functional_jvp_inside_transform succeed (in that grad(f)(x) raises a RuntimeError), but in test_autograd_functional_jacfwd_inside_transform the RuntimeError isn't raised?
Also, do you know why this succeeds now? We are not creating a new torch.tensor using a list of tensors, so it sounds like there is some interesting interaction going on.
It looks to be happening because of a bad interaction with legacy _vmap. If I update it to torch.func.vmap, the code fails with Batching rule not implemented for make_dual.

pytorch/torch/autograd/functional.py, Lines 476 to 477 in c377a85:

outputs_before_split = _vmap(jvp)(tangents)
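For illustration, a minimal hypothetical repro of that failure mode (not code from this PR): torch.func.vmap over any function that calls torch.autograd.forward_ad.make_dual should hit the missing aten::_make_dual batching rule.

import torch
import torch.autograd.forward_ad as fwAD
from torch.func import vmap

def uses_make_dual(x):
    # make_dual dispatches to aten::_make_dual, for which torch.func.vmap
    # had no batching rule at the time of this PR.
    with fwAD.dual_level():
        dual = fwAD.make_dual(x, torch.ones_like(x))
        _, tangent = fwAD.unpack_dual(dual.sin())
    return tangent

x = torch.randn(5, 3)
# Expected: RuntimeError: Batching rule not implemented for aten::_make_dual
vmap(uses_make_dual)(x)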
Also, this doesn't really succeed in a workable fashion. In the following script, it returns a BatchedTensor, which fails if you try to print it or compare it.
import torch
from torch.func import jacfwd, jacrev, vmap, vjp, jvp, grad
from functorch import make_fx
from torch._C._functorch import unwrap_if_dead, is_dead_tensor_wrapper
from torch.autograd.forward_ad import make_dual, dual_level

torch.manual_seed(420)
x = torch.randn(())

def f(x):
    y = torch.autograd.functional.jacobian(
        lambda x: x.sin().sum(), x, strategy='forward-mode', vectorize=True)
    return y

def f_exp(x):
    y = jacrev(lambda x: x.sin().sum())(x)
    return y

j = jacrev(f)(x)
# print(j)  # RuntimeError: Batching rule not implemented for aten::is_nonzero. We could not generate a fallback.
expected_j = jacrev(f_exp)(x)
print(expected_j)  # Works
torch.testing.assert_close(j, expected_j)  # RuntimeError: Batching rule not implemented for aten::is_nonzero. We could not generate a fallback.
I don't think legacy vmap is expected to work with functorch transforms, right?
> How does test_autograd_functional_jvp_inside_transform succeed (in that grad(f)(x) raises a RuntimeError) but in test_autograd_functional_jacfwd_inside_transform the RuntimeError isn't raised?
Both have a different approach to the computation of jvp.
Under jacobian:

pytorch/torch/autograd/functional.py, Lines 458 to 477 in c377a85:

# Step 2: Compute vmap over computation with dual tensors
def jvp(tangents):
    with fwAD.dual_level():
        dual_inputs = tuple(
            fwAD.make_dual(input, tangent.view_as(input)) for input, tangent in zip(inputs, tangents))
        _is_outputs_tuple, dual_outputs = _as_tuple(func(*dual_inputs), "outputs")
        output_info.append(_is_outputs_tuple)
        jv = []
        primal_outs = []
        for dual_out in dual_outputs:
            primal, tangent = fwAD.unpack_dual(dual_out)
            primal_outs.append(primal)
            if tangent is not None:
                jv.append(tangent)
            else:
                jv.append(torch.zeros_like(primal))
        output_info.append(primal_outs)
        return tuple(jv)

outputs_before_split = _vmap(jvp)(tangents)
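(As an aside, the dual-tensor step above can also be written without any vmap, which may make the mechanism clearer. A minimal standalone sketch, not the library code, that loops over basis tangents for a scalar-output function like the x.sin().sum() example used earlier:)

import torch
import torch.autograd.forward_ad as fwAD

def forward_mode_grad(func, x):
    # One forward-mode JVP per basis tangent; for a scalar-output func this
    # recovers the gradient one entry at a time.
    cols = []
    for i in range(x.numel()):
        tangent = torch.zeros_like(x).reshape(-1)
        tangent[i] = 1.0
        with fwAD.dual_level():
            dual = fwAD.make_dual(x, tangent.view_as(x))
            out = func(dual)
            _, jv = fwAD.unpack_dual(out)
        cols.append(jv if jv is not None else torch.zeros_like(out))
    return torch.stack(cols).view_as(x)

x = torch.randn(3)
print(forward_mode_grad(lambda t: t.sin().sum(), x))  # should match x.cos()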
Under jvp: (this one hits a different failure - "You are attempting to call Tensor.requires_grad_() (or perhaps using torch.autograd.functional.* APIs) inside of a function being transformed by a functorch transform.")
pytorch/torch/autograd/functional.py, Lines 371 to 408 in c377a85:

with torch.enable_grad():
    is_inputs_tuple, inputs = _as_tuple(inputs, "inputs", "jvp")
    inputs = _grad_preprocess(inputs, create_graph=create_graph, need_graph=True)

    if v is not None:
        _, v = _as_tuple(v, "v", "jvp")
        v = _grad_preprocess(v, create_graph=create_graph, need_graph=False)
        _validate_v(v, inputs, is_inputs_tuple)
    else:
        if len(inputs) != 1 or inputs[0].nelement() != 1:
            raise RuntimeError("The vector v can only be None if the input to "
                               "the user-provided function is a single Tensor "
                               "with a single element.")

    outputs = func(*inputs)
    is_outputs_tuple, outputs = _as_tuple(outputs, "outputs of the user-provided function", "jvp")
    _check_requires_grad(outputs, "outputs", strict=strict)

    # The backward is linear so the value of grad_outputs is not important as
    # it won't appear in the double backward graph. We only need to ensure that
    # it does not contain inf or nan.
    grad_outputs = tuple(torch.zeros_like(out, requires_grad=True) for out in outputs)

    grad_inputs = _autograd_grad(outputs, inputs, grad_outputs, create_graph=True)
    _check_requires_grad(grad_inputs, "grad_inputs", strict=strict)

if create_graph:
    with torch.enable_grad():
        grad_res = _autograd_grad(grad_inputs, grad_outputs, v, create_graph=create_graph)
        jvp = _fill_in_zeros(grad_res, outputs, strict, create_graph, "back_trick")
else:
    grad_res = _autograd_grad(grad_inputs, grad_outputs, v, create_graph=create_graph)
    jvp = _fill_in_zeros(grad_res, outputs, strict, create_graph, "back_trick")

# Cleanup objects and return them to the user
outputs = _grad_postprocess(outputs, create_graph)
jvp = _grad_postprocess(jvp, create_graph)

return _tuple_postprocess(outputs, is_outputs_tuple), _tuple_postprocess(jvp, is_outputs_tuple)
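(For illustration, a small hypothetical repro of that second failure; the exact error message is assumed from the quote above.)

import torch
from torch.func import grad

def f(x):
    # torch.autograd.functional.jvp preprocesses its inputs with requires_grad_()
    # (see _grad_preprocess above), which is rejected inside functorch transforms.
    _, jv = torch.autograd.functional.jvp(lambda t: t.sin().sum(), (x,))
    return jv

x = torch.randn(())
# Expected: RuntimeError mentioning Tensor.requires_grad_() being called inside
# a function transformed by a functorch transform
grad(f)(x)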
Seems like pytorch/test/functorch/test_eager_transforms.py, Lines 3199 to 3212 in 69eef5a
LGTM, thanks for clarifying
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Fixes #95561
Fixes #98021