Users are running into problems when tracing a non-forward method on a module. They get errors because the weights used in the trace are either treated as constants, or they are tensors recorded by autograd (they require gradients), which the tracer refuses to handle. The trace API already knows how to trace a `forward` method while correctly capturing its weights. It would be easy to extend the API so that, in the general case, you can trace multiple methods of a module to create a single ScriptModule with multiple methods.
```python
import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 3, 3)

    def forward(self, x):
        return self.conv(x)

    def weighted_kernel_sum(self, weight):  # I want to trace this thing
        return (weight * self.conv.weight).sum()

n = Net()
example_forward_input = torch.rand(1, 3, 10, 10)  # (N, C, H, W) input for the 3-channel conv
example_weight = torch.rand(3, 3, 3, 3)           # same shape as self.conv.weight

traced_forward = torch.jit.trace(n, example_forward_input)
traced_weight_kernel_sum = torch.jit.trace(n.weighted_kernel_sum, example_weight)
# current: either an error because constants require gradients,
# or the weights are captured as constants

# proposed generic API:
fully_traced = torch.jit.trace(n, {'forward': example_forward_input,
                                   'weighted_kernel_sum': example_weight})
# fully_traced has both forward and weighted_kernel_sum present

# syntax sugar for the old behavior:
m = torch.jit.trace(n, example_forward_input)
# --> torch.jit.trace(n, {'forward': example_forward_input})

# syntax sugar for tracing a single method:
m = torch.jit.trace(n.weighted_kernel_sum, example_weight)
# --> torch.jit.trace(n.weighted_kernel_sum.__self__,
#                     {n.weighted_kernel_sum.__name__: example_weight})
```
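The syntax-sugar rules in the proposal amount to normalizing every call into a single generic form: a module plus a dict mapping method names to example inputs. The following is a minimal pure-Python sketch of that normalization, assuming the three calling forms described in the proposal. `normalize_trace_args` is a hypothetical helper, not part of `torch.jit`, and the `Net` class here is a stand-in for `nn.Module` so the sketch runs without PyTorch.

```python
import inspect
from typing import Any, Dict, Tuple


def normalize_trace_args(obj: Any, example_inputs: Any) -> Tuple[Any, Dict[str, Any]]:
    """Hypothetical helper: map each proposed form of trace() onto
    the generic form (module, {method_name: example_inputs})."""
    # Bound method: trace(n.weighted_kernel_sum, w)
    #   -> (n, {'weighted_kernel_sum': w})
    if inspect.ismethod(obj):
        return obj.__self__, {obj.__name__: example_inputs}
    # Generic form: trace(n, {'forward': x, 'weighted_kernel_sum': w})
    if isinstance(example_inputs, dict):
        return obj, dict(example_inputs)
    # Old behavior for modules: trace(n, x) -> (n, {'forward': x})
    if callable(getattr(obj, "forward", None)):
        return obj, {"forward": example_inputs}
    raise TypeError("expected a module, a bound method, or a dict of inputs")


class Net:
    # Stand-in for nn.Module so the sketch runs without torch.
    def forward(self, x):
        return x

    def weighted_kernel_sum(self, weight):
        return weight


n = Net()
module, inputs = normalize_trace_args(n.weighted_kernel_sum, "example_weight")
print(module is n, inputs)
```

Checking for a bound method first means `trace(n.forward, x)` and `trace(n, x)` normalize to the same thing, so a single code path can then trace every method named in the dict. For comparison, later PyTorch releases added a separate function, `torch.jit.trace_module(mod, inputs)`, whose `inputs` argument is a dict from method names to example inputs, instead of overloading `torch.jit.trace` itself.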
Adding the 1.1 milestone. The changes behind these errors were introduced in 1.0 and revealed many cases where people were inadvertently (but successfully) capturing parameters as constants. Given the number of reports, we need to make sure this is fixed for the 1.1 release.