Semantic discrepancy on requires_grad after compiling Tensor.detach #1052

Closed
sangongs opened this issue Oct 12, 2022 · 14 comments
Comments

@sangongs

Reproduce:

import torch
from functorch.compile import aot_function

def fn(x):
    return x.detach()

aot_fn = aot_function(fn, fw_compiler=lambda fx_module, _: fx_module)

x = torch.randn(1, requires_grad=True)
ref = fn(x)
res = aot_fn(x)

assert(ref.requires_grad == res.requires_grad)

PyTorch version: 1.13.0.dev20220929+cu116

Not sure if this is related to #376.

@samdow
Contributor

samdow commented Oct 12, 2022

cc @bdhirsh. I might be totally off, but is this related to any of the work you were doing to make requires_grad track correctly on proxies?

@bdhirsh
Contributor

bdhirsh commented Oct 12, 2022

Hmm, I don't think so. It looks like it's because we're compiling the whole thing (including the detach() call) into an autograd.Function, and autograd.Function will unconditionally mark all of its forward outputs as requiring gradients.

@albanD brought up a good point: in AOT Autograd, we already run the forward once to get the expected output(s) to pass to the joint graph for tracing, and at that point we should know the expected requires_grad-ness of every forward output. We can use autograd.Function's mark_non_differentiable API to (statically) mark those outputs as not requiring gradients, which would fix this problem.

That would technically make the autograd.Function() that we create do the wrong thing if you re-used it with inputs that have different values set for .requires_grad. But we're hiding behind dynamo, and dynamo already specializes on requires_grad-ness today, so we can expect to trace out a new autograd.Function object whenever that happens.
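
For illustration, here is a minimal sketch of that approach, assuming the compiled forward returns a tuple of tensors. The wrapper and the requires_grad bookkeeping below are hypothetical, not the actual AOT Autograd code:

import torch

def make_compiled_fn(compiled_forward, traced_output_requires_grad):
    # traced_output_requires_grad: the requires_grad flags observed on each
    # forward output when the traced forward was run once during tracing.
    class CompiledFn(torch.autograd.Function):
        @staticmethod
        def forward(ctx, *args):
            ctx.num_inputs = len(args)
            outs = compiled_forward(*args)
            # Statically mark the outputs that did not require grad at trace
            # time, so autograd.Function does not force requires_grad=True
            # on every output.
            ctx.mark_non_differentiable(
                *[o for o, rg in zip(outs, traced_output_requires_grad) if not rg]
            )
            return outs

        @staticmethod
        def backward(ctx, *grads):
            # The real implementation would run the compiled backward graph;
            # this sketch returns no gradients.
            return (None,) * ctx.num_inputs

    return CompiledFn.apply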

@bdhirsh
Contributor

bdhirsh commented Oct 12, 2022

Here's a potential fix based on my discussion with Alban: pytorch/pytorch#86838

@sangongs
Author

sangongs commented Oct 13, 2022

Thanks to @bdhirsh for the quick fix. However, the following program still fails after cherry-picking the PR:

import torch
from functorch.compile import aot_function, make_boxed_func
from torchinductor.compile_fx import compile_fx_inner

def fn(x):
    y = x.view(-1).detach()
    return y

aot_fn = aot_function(fn, fw_compiler=compile_fx_inner)


x = torch.randn(1, 2, requires_grad=True)
ref = fn(x)
res = aot_fn(x)

assert(ref.requires_grad == res.requires_grad)

@bdhirsh
Contributor

bdhirsh commented Oct 13, 2022

It looks like that repro runs ok with the aot-eager backend, but not with inductor:

import torch
from functorch.compile import aot_function, make_boxed_func, nop
from torchinductor.compile_fx import compile_fx_inner

def fn(x):
    y = x.view(-1).detach()
    return y

aot_fn = aot_function(fn, fw_compiler=nop)


x = torch.randn(1, 2, requires_grad=True)
ref = fn(x)
res = aot_fn(x)

assert(ref.requires_grad == res.requires_grad)

When I print the output of inductor's codegen, I get:

def call(args):
    primals_1, = args
    args.clear()
    primals_1_size = primals_1.size()
    s0 = primals_1_size[1]
    return (as_strided(primals_1, (s0, ), (1, )), )

Where it looks like inductor is treating the .detach() call(s) in the original graph as no-ops.

To be fair, it seems reasonable to argue that inductor shouldn't have to worry about requires_grad when it's compiling. I'm not exactly sure what the fix is, though. It looks like even though we're calling mark_non_differentiable() on the outputs in AOT Autograd, the outputs are still being set with requires_grad=True.

@sangongs
Author

sangongs commented Oct 13, 2022

Where it looks like inductor is treating the .detach() call(s) in the original graph as no-ops.

Yes, Inductor treats .detach() calls as no-ops:
https://github.com/pytorch/torchdynamo/blob/986da5a19055e99901220ffdc18b80558b54aa7b/torchinductor/lowering.py#L467-L469

To be fair, it seems reasonable to argue that inductor shouldn't have to worry about requires_grad when it's compiling.

Agreed. It would be good if AOT Autograd could handle this automatically, although, in theory, Inductor could generate code to handle detach.
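
For reference, the linked lowering is essentially an identity function registered for aten.detach. A rough paraphrase (illustrative only; the decorator and module names are taken from the linked torchinductor source, see the link above for the exact form):

import torch
from torchinductor.lowering import register_lowering  # module linked above

aten = torch.ops.aten

@register_lowering(aten.detach)
def detach(x):
    # Drop the graph-level detach: inductor just forwards the input and
    # leaves the autograd bookkeeping to AOT Autograd.
    return x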

@bdhirsh
Contributor

bdhirsh commented Oct 13, 2022

@albanD Does this sound like correct autograd.Function behavior? Based on the docs, I would have expected any tensors marked with mark_non_differentiable in the forward to come out with requires_grad=False.

If that sounds like incorrect behavior, I can dig into autograd.Function a bit more. Alternatively, if this is a limitation of autograd.Function then our options are probably either to move on to something else (like what Richard has brought up before), or properly handle requires_grad-ness in inductor. Here's my example:

class CompiledFunction(torch.autograd.Function):

    @staticmethod
    def forward(ctx, a):
        out = torch.as_strided(a, a.shape, a.stride())
        # Explicitly mark the output as being non-differentiable
        # EVEN IF it appears to require gradients.
        ctx.mark_non_differentiable(out)
        return out

    @staticmethod
    def backward(ctx, *flat_args):
        # ignore, not called
        return tuple(flat_args)

a = torch.ones(2, 2, requires_grad=True)
b = CompiledFunction.apply(a)
# Prints True. Even though we marked it as non-differentiable?
print(b.requires_grad)

@sangongs
Author

I think I found the reason why mark_non_differentiable does not unset requires_grad in this case. It might be because of this piece of code:
https://github.com/pytorch/pytorch/blob/ae45dab57e22e3d04516e7dd81ef8dbefd51bfe3/torch/csrc/autograd/custom_function.cpp#L290-L299

Basically, if the output is a view, then mark_non_differentiable() has no effect on it.

Maybe a stupid question, but can we just apply .detach() to the non-differentiable outputs instead of calling mark_non_differentiable()?
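
For illustration, a minimal standalone version of that suggestion, detaching an output that should not require grad after the autograd.Function call instead of relying on mark_non_differentiable. Where exactly the detach would live in AOT Autograd is an open question here; the compiled_fn wrapper below is hypothetical:

import torch

class CompiledFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, a):
        # Same view output as in the example above.
        return torch.as_strided(a, a.shape, a.stride())

    @staticmethod
    def backward(ctx, *flat_args):
        # ignore, not called
        return tuple(flat_args)

def compiled_fn(a):
    out = CompiledFunction.apply(a)
    # The traced forward said this output should not require grad, so detach
    # it explicitly instead of relying on mark_non_differentiable, which is
    # skipped for view outputs.
    return out.detach()

a = torch.ones(2, 2, requires_grad=True)
print(compiled_fn(a).requires_grad)  # False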

@bdhirsh
Contributor

bdhirsh commented Oct 13, 2022

@sangongs nice catch. I'll defer to Alban, but... I think that seems reasonable (it feels bad, because it looks like we're trying to ignore autograd.Function's existing behavior, but the only reason for that is that there was a .detach() in the original graph that the compiler removed, so... we're just adding it back).

@sangongs
Author

I came up with a work-around in Inductor to deal with this special tensor.view().detach() case: pytorch/torchdynamo#1661

@albanD
Contributor

albanD commented Oct 16, 2022

This is a good catch.
In general, indeed, setting requires_grad on a differentiable view has no effect, since its requires_grad field is set to reflect its base's requires_grad-ness.

In this case, if the user explicitly states that this output is not differentiable, then we should properly detach it so that it becomes a non-differentiable view.
cc @soulitzer I think this is something we want to solve on the custom Function side.
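
For context, a small standalone illustration of the view behavior described above:

import torch

base = torch.randn(2, 2, requires_grad=True)
view = base.view(-1)        # differentiable view of a tensor that requires grad
print(view.requires_grad)   # True: a differentiable view reflects its base's requires_grad

# A non-differentiable view of the same data, as produced by detach():
print(view.detach().requires_grad)  # False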

@bdhirsh
Contributor

bdhirsh commented Oct 17, 2022

@albanD to confirm - you think that this is something that should be handled transparently by autograd.Function?

i.e., if one of the tensors that the user marks with ctx.mark_non_differentiable(...) is a differentiable view, autograd.Function should implicitly .detach() it?

@albanD
Contributor

albanD commented Oct 17, 2022

you think that this is something that should be handled transparently by autograd.Function?

Yes

ezyang added a commit to pytorch/pytorch that referenced this issue Oct 19, 2022
…that dont require grad"

Fixes pytorch/functorch#1052

I got here after some discussion with Alban. Today, if you aot_function() trace a program where some of its inputs have `requires_grad=True`, but some outputs are expected to have `requires_grad=False`, we will incorrectly set all outputs to have `requires_grad=True`.

A simple solution is to use autograd.function's API for marking outputs as non-differentiable, based on what we witnessed when we traced the forward.

This will make the `autograd.Function` that we return **wrong**, if you created it using inputs that required grad, and tried to re-use it with inputs that have different `requires_grad` field. But as long as we're hiding behind dynamo, which should guard on requires_grad, then we'll re-run `aot_function()` and get out a new compiled function that does the right thing.

@sangongs
Author

sangongs commented Nov 7, 2022

Looks like the issue is still not fixed for backends like inductor that do not handle detach(). @bdhirsh do you have a plan to implement this:

i.e., if one of the tensors that the user marks with ctx.mark_non_differentiable(...) is a differentiable view, autograd.Function should implicitly .detach() it?
