From c97ffcff464fad4aa12a86b99a75d491071cd575 Mon Sep 17 00:00:00 2001
From: "Edward Z. Yang"
Date: Wed, 19 Oct 2022 15:39:12 -0400
Subject: [PATCH] [discussion] fix for aot autograd outputs that dont require grad (#86838)

Fixes https://github.com/pytorch/functorch/issues/1052

I got here after some discussion with Alban. Today, if you `aot_function()`
trace a program where some of its inputs have `requires_grad=True`, but some
outputs are expected to have `requires_grad=False`, we will incorrectly set
all outputs to have `requires_grad=True`.

A simple solution is to use the `autograd.Function` API for marking outputs
as non-differentiable, based on what we witnessed when we traced the forward.

This will make the `autograd.Function` that we return **wrong** if you
created it using inputs that required grad, and tried to re-use it with
inputs that have a different `requires_grad` field. But as long as we're
hiding behind dynamo, which should guard on `requires_grad`, we'll re-run
`aot_function()` and get out a new compiled function that does the right
thing.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/86838
Approved by: https://github.com/ezyang
---
 functorch/_src/aot_autograd.py     | 14 ++++++++++++++
 test/functorch/test_aotdispatch.py | 13 +++++++++++++
 2 files changed, 27 insertions(+)

diff --git a/functorch/_src/aot_autograd.py b/functorch/_src/aot_autograd.py
index de8a00c68f6da..b1e29b6ac4103 100644
--- a/functorch/_src/aot_autograd.py
+++ b/functorch/_src/aot_autograd.py
@@ -369,6 +369,13 @@ def add_dupe_args(args):
     joint_forward_backward = create_joint_forward_backward(lambda *args: flat_fn(*add_dupe_args(args)))
 
     out = flat_fn(*flat_args)
+    # Collect info on which output tensors require gradients,
+    # so we can mark them properly in the returned autograd.Function
+    _flat_outs_not_requiring_grad, _ = pytree.tree_flatten(
+        pytree.tree_map(
+            lambda x: isinstance(x, Tensor) and not x.requires_grad, out
+        )
+    )
     out = pytree.tree_map(
         lambda x: x.detach().contiguous() if isinstance(x, Tensor) else x,
         out,
@@ -435,6 +442,7 @@ class CompiledFunction(torch.autograd.Function):
         compiled_bw = None
         num_outs = _num_outs
         num_symints = _num_symints
+        flat_outs_not_requiring_grad = _flat_outs_not_requiring_grad
 
         @staticmethod
         @disable_torchdynamo
@@ -451,6 +459,12 @@ def forward(ctx, *deduped_flat_tensor_args):
             else:
                 ctx.save_for_backward(*fw_outs[num_outs:])
                 ctx.symints = []
+
+            fw_outs_not_requiring_grad = [
+                x for (i, x) in enumerate(fw_outs[0:num_outs]) if CompiledFunction.flat_outs_not_requiring_grad[i]
+            ]
+            ctx.mark_non_differentiable(*fw_outs_not_requiring_grad)
+
             return tuple(fw_outs[0:num_outs])
 
         @staticmethod
diff --git a/test/functorch/test_aotdispatch.py b/test/functorch/test_aotdispatch.py
index b211805442b40..e9a46b0882e2e 100644
--- a/test/functorch/test_aotdispatch.py
+++ b/test/functorch/test_aotdispatch.py
@@ -253,6 +253,13 @@ def verify_aot_autograd(self, f, inp):
         self.assertEqual(ref_out, test_out)
         self.assertEqual(ref_grad, test_grad)
 
+        if isinstance(ref_out, torch.Tensor):
+            self.assertTrue(isinstance(test_out, torch.Tensor))
+            ref_out, test_out = [ref_out], [test_out]
+        for ref_o, test_o in zip(ref_out, test_out):
+            if isinstance(ref_o, torch.Tensor):
+                self.assertEqual(ref_o.requires_grad, test_o.requires_grad)
+
     def test_single_output(self):
         def f(a, b):
             return a + b
@@ -280,6 +287,12 @@ def f(a, b):
         inps = [i() for i in inps]
         self.verify_aot_autograd(f, inps)
 
+    def test_some_outputs_dont_require_grad(self):
+        def f(a, b):
+            return a.detach(), b
+        inp = [torch.randn(3, 3, requires_grad=True), torch.randn(3, 3, requires_grad=True)]
+        self.verify_aot_autograd(f, inp)
+
     def test_inner_grad(self):
         def foo(x):
             y = torch.exp(x)
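
For reference, here is a minimal standalone sketch (not part of the patch) of the ctx.mark_non_differentiable mechanism the fix relies on. It mirrors the new test case f(a, b) = (a.detach(), b): the first output should come back with requires_grad=False even though both inputs require grad. The class name SomeOutputsNoGrad and the variable names are hypothetical, for illustration only.

import torch

class SomeOutputsNoGrad(torch.autograd.Function):
    @staticmethod
    def forward(ctx, a, b):
        # Mirror the new test case: the first output should not require
        # grad even though both inputs do.
        out0 = a.clone()
        out1 = b.clone()
        ctx.mark_non_differentiable(out0)
        return out0, out1

    @staticmethod
    def backward(ctx, grad_out0, grad_out1):
        # grad_out0 belongs to the non-differentiable output; no gradient
        # flows back through it. out1 depends only on b.
        grad_a = None
        grad_b = grad_out1 if ctx.needs_input_grad[1] else None
        return grad_a, grad_b

a = torch.randn(3, 3, requires_grad=True)
b = torch.randn(3, 3, requires_grad=True)
out0, out1 = SomeOutputsNoGrad.apply(a, b)
print(out0.requires_grad)  # False: marked non-differentiable in forward
print(out1.requires_grad)  # True

This is the same behavior the patch wires up automatically: tracing records which flat outputs did not require grad, and CompiledFunction.forward then marks exactly those outputs via ctx.mark_non_differentiable.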