
[FR] jit.trace module methods / module as an input #18569

Closed
ssnl opened this issue Mar 28, 2019 · 8 comments
Labels: oncall: jit

ssnl (Collaborator) commented Mar 28, 2019

Context: #17583

We killed support for tracing functions that reference tensors requiring grad. This effectively makes it impossible to trace a function that computes on model parameters (e.g., computing the spectral norms of all conv weights in a CNN), unless we represent the function as another module. Yet using another module isn't always ideal, because the two modules would then need to share parameters.

An alternative workaround is to pass every parameter as an input to the traced function, and then use another wrapper on top of the returned callable. This results in rather ugly, unreadable code.
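A minimal sketch of that workaround (the function and wrapper names here are made up for illustration, not from the issue):

```python
import torch
import torch.nn as nn

conv = nn.Conv2d(3, 3, 3)

def _kernel_stat(weight, x):
    # the parameter is threaded through as an ordinary tensor input
    return (weight * x).sum()

# trace with the parameter supplied as an example input
traced = torch.jit.trace(_kernel_stat, (conv.weight, torch.randn(3, 3, 3, 3)))

def kernel_stat(x):
    # wrapper that re-supplies the parameter on every call,
    # hiding the extra argument from callers
    return traced(conv.weight, x)
```

Every parameter the function reads has to be threaded through by hand like this, which is what makes the approach unreadable at scale.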

So here is my FR: supporting either

  • Tracing an nn.Module method, or
  • Tracing with an nn.Module as an input

Really these two can be viewed as the same thing, because a method is just a bound function whose first input is the nn.Module object, i.e., self.

Here is the proposed API:

class Network(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 3, 3)

    def forward(self, x):
        return self.conv(x).relu_()

    @jit.trace
    def estimate_spectral_norm(self, x):
        w = self.conv.weight
        # power iteration, or solve exactly, etc.
        ...

facebook-github-bot added the oncall: jit label on Mar 28, 2019
suo (Member) commented Apr 1, 2019

@zdevito this would not be so hard if modules were first class, right?

zdevito (Contributor) commented Apr 1, 2019

@ssnl I'm not sure I understand the issue; can you post an example failure? I want to make sure there isn't something we can do now.

ssnl (Collaborator, Author) commented Apr 1, 2019

@zdevito

import torch
import torch.nn as nn


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 3, 3)

        # Want to trace self.weighted_kernel_sum

        # Method 1:
        #   torch.jit.trace(self.weighted_kernel_sum, (torch.randn(3, 3, 3, 3),))
        # Error:    
        #   RuntimeError: Cannot insert a Tensor that requires grad as a constant. 
        #                 Consider making it a parameter or input, or detaching the gradient

        # Method 2:
        #
        #   torch.jit.trace(Net.weighted_kernel_sum, (self, torch.randn(3, 3, 3, 3),))
        # Error:
        #   RuntimeError: Only tensors and (possibly nested) tuples of tensors 
        #   are supported as inputs or outputs of traced functions, but instead 
        #   got value of type Net.
        #   Value: Net(
        #     (conv): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1))
        #   )

    def forward(self, x):
        return self.conv(x)

    def weighted_kernel_sum(self, weight):  # I want to trace this thing
        return (weight * self.conv.weight).sum()


Net()

ssnl (Collaborator, Author) commented Apr 1, 2019

Method 1 used to work until sometime around January.

ssnl (Collaborator, Author) commented Apr 9, 2019

@zdevito Does the above example make sense to you? :)

zdevito (Contributor) commented Apr 9, 2019

Oh, sorry. I missed the reply a week ago :( Yes, the example makes sense. I think I understand the problem now. We can make this work.

zdevito (Contributor) commented Apr 9, 2019

Here is my proposed fix: #19070

suo (Member) commented May 3, 2019

Closing this as it's been folded into #19070. Feel free to reopen if you think the proposed fix is not enough.

@suo suo closed this as completed May 3, 2019
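For readers landing here later: current PyTorch exposes torch.jit.trace_module, which covers this request by tracing named methods of a module. This API postdates the thread; the sketch below mirrors the reproduction example above.

```python
import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 3, 3)

    def forward(self, x):
        return self.conv(x)

    def weighted_kernel_sum(self, weight):
        return (weight * self.conv.weight).sum()

net = Net()
# map each method name to its example inputs
traced = torch.jit.trace_module(
    net,
    {
        "forward": torch.randn(1, 3, 16, 16),
        "weighted_kernel_sum": torch.randn(3, 3, 3, 3),
    },
)
out = traced.weighted_kernel_sum(torch.ones(3, 3, 3, 3))
```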