[discussion] fix for aot autograd outputs that don't require grad (pytorch#86838)

Fixes pytorch/functorch#1052

I got here after some discussion with Alban. Today, if you trace a program with `aot_function()` where some of its inputs have `requires_grad=True` but some of its outputs are expected to have `requires_grad=False`, we incorrectly set all outputs to `requires_grad=True`.
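A minimal sketch of the failure mode, assuming the `functorch.compile` entry points of this vintage (`nop` is the identity compiler; names like `out_detached` are just for illustration):

    import torch
    from functorch.compile import aot_function, nop  # nop returns the traced graph unchanged

    def f(a, b):
        # The first output is detached, so eager mode reports requires_grad=False for it.
        return a.detach(), b * 2

    a = torch.randn(3, requires_grad=True)
    b = torch.randn(3, requires_grad=True)

    compiled_f = aot_function(f, fw_compiler=nop, bw_compiler=nop)
    out_detached, out_scaled = compiled_f(a, b)

    # Eager gives (False, True); before this fix, AOT Autograd reported (True, True).
    print(out_detached.requires_grad, out_scaled.requires_grad)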

A simple solution is to use `autograd.Function`'s API for marking outputs as non-differentiable, based on which outputs we observed not to require grad when we traced the forward.
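For reference, a minimal sketch of that `torch.autograd.Function` API (`ctx.mark_non_differentiable`); the `ScaleAndCount` function here is illustrative, not the actual `CompiledFunction`:

    import torch

    class ScaleAndCount(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x):
            scaled = x * 2
            count = torch.tensor(float(x.numel()))
            # Tell autograd this output never needs a gradient, even though an input requires grad.
            ctx.mark_non_differentiable(count)
            return scaled, count

        @staticmethod
        def backward(ctx, grad_scaled, grad_count):
            # grad_count belongs to the non-differentiable output and is not used.
            return grad_scaled * 2

    x = torch.randn(3, requires_grad=True)
    scaled, count = ScaleAndCount.apply(x)
    print(scaled.requires_grad, count.requires_grad)  # True False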

This makes the `autograd.Function` that we return **wrong** if you created it with inputs that required grad and then re-use it with inputs whose `requires_grad` fields differ. But as long as we're hiding behind dynamo, which should guard on `requires_grad`, we'll re-run `aot_function()` and get out a new compiled function that does the right thing.
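Roughly the guard/recompile behaviour being relied on, as a hypothetical sketch (the `_compiled` cache and `call` wrapper are illustrative stand-ins for dynamo's guards, not real machinery):

    import torch
    from functorch.compile import aot_function, nop

    def f(a, b):
        return a.detach(), b * 2

    # Cache keyed on the inputs' requires_grad pattern: a fresh aot_function() trace
    # whenever the pattern changes, so the non-differentiability decisions baked in at
    # trace time stay valid for the inputs actually being used.
    _compiled = {}

    def call(a, b):
        key = (a.requires_grad, b.requires_grad)
        if key not in _compiled:
            _compiled[key] = aot_function(f, fw_compiler=nop, bw_compiler=nop)
        return _compiled[key](a, b)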

Pull Request resolved: pytorch#86838
Approved by: https://github.com/ezyang
ezyang authored and pytorchmergebot committed Oct 19, 2022
1 parent c9b6184 commit c97ffcf
Showing 2 changed files with 27 additions and 0 deletions.
14 changes: 14 additions & 0 deletions functorch/_src/aot_autograd.py
@@ -369,6 +369,13 @@ def add_dupe_args(args):
joint_forward_backward = create_joint_forward_backward(lambda *args: flat_fn(*add_dupe_args(args)))

out = flat_fn(*flat_args)
# Collect info on which output tensors require gradients,
# so we can mark them properly in the returned autograd.Function
_flat_outs_not_requiring_grad, _ = pytree.tree_flatten(
    pytree.tree_map(
        lambda x: isinstance(x, Tensor) and not x.requires_grad, out
    )
)
out = pytree.tree_map(
    lambda x: x.detach().contiguous() if isinstance(x, Tensor) else x,
    out,
@@ -435,6 +442,7 @@ class CompiledFunction(torch.autograd.Function):
compiled_bw = None
num_outs = _num_outs
num_symints = _num_symints
flat_outs_not_requiring_grad = _flat_outs_not_requiring_grad

@staticmethod
@disable_torchdynamo
@@ -451,6 +459,12 @@ def forward(ctx, *deduped_flat_tensor_args):
else:
    ctx.save_for_backward(*fw_outs[num_outs:])
    ctx.symints = []

fw_outs_not_requiring_grad = [
    x for (i, x) in enumerate(fw_outs[0:num_outs]) if CompiledFunction.flat_outs_not_requiring_grad[i]
]
ctx.mark_non_differentiable(*fw_outs_not_requiring_grad)

return tuple(fw_outs[0:num_outs])

@staticmethod
13 changes: 13 additions & 0 deletions test/functorch/test_aotdispatch.py
@@ -253,6 +253,13 @@ def verify_aot_autograd(self, f, inp):
self.assertEqual(ref_out, test_out)
self.assertEqual(ref_grad, test_grad)

if isinstance(ref_out, torch.Tensor):
    self.assertTrue(isinstance(test_out, torch.Tensor))
    ref_out, test_out = [ref_out], [test_out]
for ref_o, test_o in zip(ref_out, test_out):
    if isinstance(ref_o, torch.Tensor):
        self.assertEqual(ref_o.requires_grad, test_o.requires_grad)

def test_single_output(self):
    def f(a, b):
        return a + b
@@ -280,6 +287,12 @@ def f(a, b):
    inps = [i() for i in inps]
    self.verify_aot_autograd(f, inps)

def test_some_outputs_dont_require_grad(self):
    def f(a, b):
        return a.detach(), b
    inp = [torch.randn(3, 3, requires_grad=True), torch.randn(3, 3, requires_grad=True)]
    self.verify_aot_autograd(f, inp)

def test_inner_grad(self):
    def foo(x):
        y = torch.exp(x)
