
[fix] update the condition for aliveness of TensorWrapper #98748

Conversation

@kshitij12345 (Collaborator) commented Apr 10, 2023

Fixes #95561
Fixes #98021

@pytorch-bot (bot) commented Apr 10, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/98748

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit b74e160:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@kshitij12345 kshitij12345 marked this pull request as ready for review April 10, 2023 15:36
@kshitij12345 added the release notes: functorch label (release notes category; pertaining to torch.func or pytorch/functorch) Apr 10, 2023
@@ -3204,13 +3204,9 @@ def f(x):
 
         B = 5
         x = torch.randn(B, 3)
-        with self.assertRaises(RuntimeError):
+        with self.assertRaisesRegex(RuntimeError, "Batching rule not implemented for aten::_make_dual"):
@kshitij12345 (Collaborator, Author) commented on Apr 10, 2023:
Added a regex to match, for clarity about which failure is expected.
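
(For reference, a minimal standalone sketch, not from this PR, of why assertRaisesRegex is the stricter check: the block passes only if the raised RuntimeError also carries the expected message, so an unrelated RuntimeError would make the test fail instead of silently passing.)

import unittest

class Demo(unittest.TestCase):
    def test_expected_failure_message(self):
        # Passes only because both the exception type and its message match.
        with self.assertRaisesRegex(RuntimeError, "Batching rule not implemented"):
            raise RuntimeError("Batching rule not implemented for aten::_make_dual")

    def test_any_runtime_error(self):
        # assertRaises is satisfied by any RuntimeError, whatever its message says.
        with self.assertRaises(RuntimeError):
            raise RuntimeError("some unrelated internal assert")

if __name__ == "__main__":
    unittest.main()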

            vmap(f)(x)

        x = torch.randn([])
        with self.assertRaises(RuntimeError):
            grad(f)(x)
@kshitij12345 (Collaborator, Author) commented:

The test used to fail due to this issue:

import torch
from torch.func import jacfwd, jacrev, vmap, vjp, jvp, grad
from functorch import make_fx
from torch._C._functorch import unwrap_if_dead, is_dead_tensor_wrapper
from torch.autograd.forward_ad import make_dual, dual_level

torch.manual_seed(420)

x = torch.randn(())

def f(x):
    y = torch.autograd.functional.jacobian(
                lambda x: x.sin().sum(), x, strategy='forward-mode', vectorize=True)
    return y

grad(f)(x)

Output:

RuntimeError: unwrapped_count > 0 INTERNAL ASSERT FAILED at "/home/kshiteej/Pytorch/pytorch_functorch/aten/src/ATen/functorch/TensorWrapper.cpp":213, please report a bug to PyTorch. Should have at least one dead wrapper

A reviewer (Contributor) commented:

I'm kind of confused: how does test_autograd_functional_jvp_inside_transform succeed (in that grad(f)(x) raises a RuntimeError), while in test_autograd_functional_jacfwd_inside_transform the RuntimeError isn't raised?

A reviewer (Contributor) commented:

Also, do you know why this succeeds now? We are not creating a new torch.tensor from a list of tensors, so it sounds like there is some interesting interaction going on.

@kshitij12345 (Collaborator, Author) commented on Apr 11, 2023:

It looks like this happens because of a bad interaction with the legacy _vmap. If I update it to torch.func.vmap, the code fails with the "Batching rule not implemented" error for make_dual.

outputs_before_split = _vmap(jvp)(tangents)
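
As a side note, here is a small sketch of the distinction between the two vmaps (the import path for the legacy one is taken from how torch.autograd.functional imports it, so treat the exact location as an assumption rather than a public API). Both batch over the leading dimension, but they come from different machinery, which is presumably why they interact differently with functorch transforms:

import torch
from torch._vmap_internals import _vmap  # legacy vmap used by torch.autograd.functional
from torch.func import vmap              # composable torch.func / functorch vmap

xs = torch.randn(5, 3)

# Both map torch.sin over the leading (batch) dimension of xs.
print(_vmap(torch.sin)(xs).shape)  # torch.Size([5, 3])
print(vmap(torch.sin)(xs).shape)   # torch.Size([5, 3])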

Also, this doesn't really succeed in a usable way. In the following script, it returns a BatchedTensor, which fails if you try to print it or compare it.

import torch
from torch.func import jacfwd, jacrev, vmap, vjp, jvp, grad
from functorch import make_fx
from torch._C._functorch import unwrap_if_dead, is_dead_tensor_wrapper
from torch.autograd.forward_ad import make_dual, dual_level

torch.manual_seed(420)

x = torch.randn(())

def f(x):
    y = torch.autograd.functional.jacobian(
                lambda x: x.sin().sum(), x, strategy='forward-mode', vectorize=True)
    return y

def f_exp(x):
    y = jacrev(lambda x: x.sin().sum())(x)
    return y

j = jacrev(f)(x)
# print(j)  # RuntimeError: Batching rule not implemented for aten::is_nonzero. We could not generate a fallback.

expected_j = jacrev(f_exp)(x)
print(expected_j)  # Works

torch.testing.assert_close(j, expected_j)  # RuntimeError: Batching rule not implemented for aten::is_nonzero. We could not generate a fallback.

I don't think legacy vmap is expected to work with functorch transforms, right?
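
For comparison, a minimal sketch (mine, not from the PR) of the composable path that does nest cleanly: building the Jacobian with torch.func.jacfwd instead of torch.autograd.functional.jacobian, and then taking grad of the result.

import torch
from torch.func import grad, jacfwd

x = torch.randn(())

def f(x):
    # For scalar x, the Jacobian of sin(x).sum() is just cos(x).
    return jacfwd(lambda x: x.sin().sum())(x)

print(grad(f)(x))     # -sin(x): grad composes with jacfwd
print(-torch.sin(x))  # matches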

@kshitij12345 (Collaborator, Author) commented:

How does test_autograd_functional_jvp_inside_transform succeed (in that grad(f)(x) raises a RuntimeError) but in test_autograd_functional_jacfwd_inside_transform the RuntimeError isn't raised?

The two take different approaches to computing the jvp.

Under jacobian:

# Step 2: Compute vmap over computation with dual tensors
def jvp(tangents):
    with fwAD.dual_level():
        dual_inputs = tuple(
            fwAD.make_dual(input, tangent.view_as(input)) for input, tangent in zip(inputs, tangents))
        _is_outputs_tuple, dual_outputs = _as_tuple(func(*dual_inputs), "outputs")
        output_info.append(_is_outputs_tuple)
        jv = []
        primal_outs = []
        for dual_out in dual_outputs:
            primal, tangent = fwAD.unpack_dual(dual_out)
            primal_outs.append(primal)
            if tangent is not None:
                jv.append(tangent)
            else:
                jv.append(torch.zeros_like(primal))
        output_info.append(primal_outs)
        return tuple(jv)

outputs_before_split = _vmap(jvp)(tangents)
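
In other words, jacobian(strategy='forward-mode') pushes a dual (primal, tangent) pair through the function and reads the tangent off the output. A minimal standalone sketch of that dual-tensor mechanism using the public forward-mode API (my example, not the code above):

import torch
import torch.autograd.forward_ad as fwAD

x = torch.randn(3)
t = torch.ones(3)  # tangent direction

with fwAD.dual_level():
    dual_x = fwAD.make_dual(x, t)
    dual_out = dual_x.sin().sum()
    primal, tangent = fwAD.unpack_dual(dual_out)

print(primal)                    # sin(x).sum()
print(tangent)                   # jvp along t
print((torch.cos(x) * t).sum())  # matches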

Under jvp: (this one hits a different failure - "You are attempting to call Tensor.requires_grad_() (or perhaps using torch.autograd.functional.* APIs) inside of a function being transformed by a functorch transform.")

with torch.enable_grad():
    is_inputs_tuple, inputs = _as_tuple(inputs, "inputs", "jvp")
    inputs = _grad_preprocess(inputs, create_graph=create_graph, need_graph=True)

    if v is not None:
        _, v = _as_tuple(v, "v", "jvp")
        v = _grad_preprocess(v, create_graph=create_graph, need_graph=False)
        _validate_v(v, inputs, is_inputs_tuple)
    else:
        if len(inputs) != 1 or inputs[0].nelement() != 1:
            raise RuntimeError("The vector v can only be None if the input to "
                               "the user-provided function is a single Tensor "
                               "with a single element.")

    outputs = func(*inputs)
    is_outputs_tuple, outputs = _as_tuple(outputs, "outputs of the user-provided function", "jvp")
    _check_requires_grad(outputs, "outputs", strict=strict)

    # The backward is linear so the value of grad_outputs is not important as
    # it won't appear in the double backward graph. We only need to ensure that
    # it does not contain inf or nan.
    grad_outputs = tuple(torch.zeros_like(out, requires_grad=True) for out in outputs)

    grad_inputs = _autograd_grad(outputs, inputs, grad_outputs, create_graph=True)
    _check_requires_grad(grad_inputs, "grad_inputs", strict=strict)

    if create_graph:
        with torch.enable_grad():
            grad_res = _autograd_grad(grad_inputs, grad_outputs, v, create_graph=create_graph)
            jvp = _fill_in_zeros(grad_res, outputs, strict, create_graph, "back_trick")
    else:
        grad_res = _autograd_grad(grad_inputs, grad_outputs, v, create_graph=create_graph)
        jvp = _fill_in_zeros(grad_res, outputs, strict, create_graph, "back_trick")

# Cleanup objects and return them to the user
outputs = _grad_postprocess(outputs, create_graph)
jvp = _grad_postprocess(jvp, create_graph)

return _tuple_postprocess(outputs, is_outputs_tuple), _tuple_postprocess(jvp, is_outputs_tuple)
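
And a minimal standalone sketch (my example, using a simple elementwise f) of the "back_trick" above: because the backward pass is linear in grad_outputs, differentiating the vjp with respect to a placeholder cotangent recovers the jvp.

import torch

def f(x):
    return x.sin()

x = torch.randn(3, requires_grad=True)
v = torch.randn(3)  # direction for the jvp

out = f(x)
u = torch.zeros_like(out, requires_grad=True)            # placeholder cotangent
vjp = torch.autograd.grad(out, x, u, create_graph=True)  # J^T u, linear in u
jvp = torch.autograd.grad(vjp, u, v)                     # d(J^T u)/du applied along v = J v

print(jvp[0])
print(torch.cos(x) * v)  # matches, since J = diag(cos(x)) for f = sin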

@kshitij12345 (Collaborator, Author) commented:

It seems like test_autograd_functional_jacfwd_inside_transform fails with accidental errors rather than the intended one. (I have added comments to the relevant lines of code.)

def test_autograd_functional_jacfwd_inside_transform(self, device):
    def f(x):
        y = torch.autograd.functional.jacobian(
            lambda x: x.sin().sum(), x, strategy='forward-mode', vectorize=True)
        return y

    B = 5
    x = torch.randn(B, 3)
    with self.assertRaises(RuntimeError):
        vmap(f)(x)

    x = torch.randn([])
    with self.assertRaises(RuntimeError):
        grad(f)(x)

@ngimel added the triaged label (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module) Apr 12, 2023
@zou3519 (Contributor) left a comment:

LGTM, thanks for clarifying

@kshitij12345 (Collaborator, Author) commented:

@pytorchbot merge

@pytorch-bot (bot) added the ciflow/trunk label (trigger trunk jobs on your pull request) Apr 13, 2023
@pytorchmergebot (Collaborator) commented:

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.
