[discussion] fix for aot autograd outputs that don't require grad (pytorch#86838)

Fixes pytorch/functorch#1052. I got here after some discussion with Alban.

Today, if you `aot_function()` trace a program where some of its inputs have `requires_grad=True` but some outputs are expected to have `requires_grad=False`, we will incorrectly set all outputs to have `requires_grad=True`.

A simple solution is to use `autograd.Function`'s API for marking outputs as non-differentiable, based on what we witnessed when we traced the forward.

This makes the `autograd.Function` that we return **wrong** if you created it using inputs that required grad and then re-use it with inputs whose `requires_grad` fields differ. But as long as we're hiding behind dynamo, which should guard on `requires_grad`, we'll re-run `aot_function()` and get a new compiled function that does the right thing.

Pull Request resolved: pytorch#86838
Approved by: https://github.com/ezyang
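As a rough illustration of the approach described above (a minimal sketch, not the actual aot_autograd implementation), an `autograd.Function` can call `ctx.mark_non_differentiable()` on the outputs that were traced with `requires_grad=False`. The `MarkedOutputsFn` name and the toy computation are hypothetical; only `mark_non_differentiable` is the real API being referred to:

```python
import torch

# Minimal sketch (assumed names, not the aot_autograd code itself): an
# autograd.Function marks one of its outputs as non-differentiable because
# that output was observed to have requires_grad=False during tracing.
class MarkedOutputsFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        out_grad = x * 2              # output that should require grad
        out_no_grad = x.detach() + 1  # output traced with requires_grad=False
        # Record that the second output never requires grad, mirroring what
        # was witnessed when the forward was traced.
        ctx.mark_non_differentiable(out_no_grad)
        return out_grad, out_no_grad

    @staticmethod
    def backward(ctx, grad_out, _grad_for_no_grad_output):
        # Only the differentiable output contributes a gradient w.r.t. x.
        return grad_out * 2


x = torch.randn(3, requires_grad=True)
a, b = MarkedOutputsFn.apply(x)
print(a.requires_grad, b.requires_grad)  # True False

a.sum().backward()
print(x.grad)  # tensor([2., 2., 2.])
```

The re-use caveat from the message applies to a sketch like this as well: the `mark_non_differentiable` call bakes in the `requires_grad` pattern seen at trace time, so inputs with a different `requires_grad` layout need a freshly traced function, which dynamo's guards are expected to trigger.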