[FR] jit.trace module methods / module as an input #18569
Comments
@zdevito this would not be so hard if modules were 1st class, right?
@ssnl I am not sure I understand the issue right; can you post an example failure? I want to make sure there isn't something we can already do.
```python
import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 3, 3)
        # Want to trace self.weighted_kernel_sum
        #
        # Method 1:
        #   torch.jit.trace(self.weighted_kernel_sum, (torch.randn(3, 3, 3, 3),))
        # Error:
        #   RuntimeError: Cannot insert a Tensor that requires grad as a constant.
        #   Consider making it a parameter or input, or detaching the gradient
        #
        # Method 2:
        #   torch.jit.trace(Net.weighted_kernel_sum, (self, torch.randn(3, 3, 3, 3),))
        # Error:
        #   RuntimeError: Only tensors and (possibly nested) tuples of tensors
        #   are supported as inputs or outputs of traced functions, but instead
        #   got value of type Net.
        #   Value: Net(
        #     (conv): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1))
        #   )

    def forward(self, x):
        return self.conv(x)

    def weighted_kernel_sum(self, weight):  # I want to trace this thing
        return (weight * self.conv.weight).sum()

Net()
```
Method 1 used to work, until some time around January.
@zdevito Does the above example make sense to you? :)
Oh, sorry. I missed the reply a week ago :( Yes, the example makes sense. I think I understand the problem now. We can make this work.
Here is my proposed fix: #19070
Closing this as it's been folded into #19070. Feel free to reopen if you think the proposed fix is not enough.
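For readers landing on this thread later: recent PyTorch releases expose `torch.jit.trace_module`, which appears to cover this feature request by letting you trace named methods of a module directly. A minimal sketch, assuming a PyTorch version that ships `torch.jit.trace_module` (the `Net` class mirrors the example above):

```python
import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 3, 3)

    def forward(self, x):
        return self.conv(x)

    def weighted_kernel_sum(self, weight):
        return (weight * self.conv.weight).sum()

net = Net()

# trace_module takes a dict mapping method names to example inputs,
# so methods other than forward can be traced without wrappers.
# Parameters like self.conv.weight are treated as module attributes,
# not baked-in constants, avoiding the "requires grad" error.
traced = torch.jit.trace_module(
    net,
    {
        "forward": torch.randn(1, 3, 8, 8),
        "weighted_kernel_sum": torch.randn(3, 3, 3, 3),
    },
)

w = torch.randn(3, 3, 3, 3)
assert torch.allclose(traced.weighted_kernel_sum(w), net.weighted_kernel_sum(w))
```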
Context: #17583
We killed support for tracing functions that reference tensors requiring grad. This effectively makes it impossible to trace a function that computes on model parameters (e.g., computing the spectral norms of all conv weights in a CNN), unless we represent the function as another module. Using another module isn't always ideal, because the two modules would then need to share parameters (uhh).
An alternative workaround is to pass every parameter as an explicit input to the traced function, and then wrap the returned callable so it re-supplies those parameters on each call. This results in rather ugly and unreadable code.
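The workaround just described might look like this (a sketch with hypothetical names, not the author's exact code): the parameter is passed as an explicit trace input, and a small wrapper hides it from callers.

```python
import torch
import torch.nn as nn

conv = nn.Conv2d(3, 3, 3)

def kernel_sum_fn(weight, conv_weight):
    # conv_weight arrives as a trace input rather than a captured
    # constant, so the "requires grad as a constant" error is avoided
    return (weight * conv_weight).sum()

traced = torch.jit.trace(
    kernel_sum_fn, (torch.randn(3, 3, 3, 3), conv.weight)
)

def weighted_kernel_sum(weight):
    # wrapper re-supplies the parameter on every call
    return traced(weight, conv.weight)
```

This works, but every parameter the function touches has to be threaded through the signature by hand, which is exactly the ugliness the FR is about.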
So here is my FR: support either
- tracing an `nn.Module` method, or
- accepting an `nn.Module` as an input

Really these two can be viewed as the same thing, because a method is just a bound function whose first input is the `nn.Module` object, i.e., `self`. Here is the proposed API: