[AIR] Raise ValueError if TorchCheckpoint can't serialize model #27998
def test_from_model_value_error():
    class StubModel(torch.nn.Module):
        __module__ = "__main__"

        def forward(self, x):
            return x

    model = StubModel()
    with pytest.raises(ValueError):
        TorchCheckpoint.from_model(model)
This test makes assumptions about the implementation of from_model, but I wasn't sure how else to test it.
This particular catch-warn mechanism seems very niche. Can we figure out how to get the
What error did you get when you ran it on multiple machines?
Sorry typo -- meant workers.
@richardliaw can you elaborate on this? The try-catch mechanism works with any number of workers, and it works with both Jupyter notebooks and Python programs. Also, the error
@bveeramani I think Richard's point is that we should start with generalizable problem of surfacing the correct error and not
For this particular error, it is niche but I'm wondering if this should instead be handled as part of a try/catch.
I agree we should fix the general problem. That being said, I still think we should raise a
How could we handle this as part of a try/catch?
If we can fix the general problem, then we have access to the original error. We then catch this particular error and raise the
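As a rough illustration of the try/catch approach discussed above, here is a minimal sketch that catches a serialization failure and re-raises it as a ValueError while keeping the original error attached. It uses the standard-library pickle module as a stand-in for Torch serialization, and `serialize_or_raise` and `make_local_instance` are hypothetical names invented for this example. It is not the implementation in this PR or anything in Ray.

```python
import pickle


def serialize_or_raise(obj):
    """Pickle ``obj``; re-raise pickling failures as a clearer ValueError.

    Hypothetical helper for illustration only; not part of Ray or PyTorch.
    """
    try:
        return pickle.dumps(obj)
    except (pickle.PicklingError, AttributeError, TypeError) as exc:
        # Chain the original exception so its traceback is not lost.
        raise ValueError(
            f"Could not serialize object of type {type(obj).__name__!r}: {exc}"
        ) from exc


def make_local_instance():
    # A class defined inside a function cannot be pickled, because pickle
    # stores classes by importable name rather than by value.
    class LocalModel:
        pass

    return LocalModel()
```

Calling `serialize_or_raise(make_local_instance())` raises a ValueError whose `__cause__` is the original pickling error, so the caller still has access to the underlying failure.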
This pull request has been automatically marked as stale because it has not had recent activity. It will be closed in 14 days if no further activity occurs. Thank you for your contributions.
Closing for now. Can re-open when #27922 is fixed.
Depends on: TorchCheckpoint.from_state_dict #27970
Why are these changes needed?
If a model is defined in the top-level directory, then Torch can't serialize the model. This PR adds a clearer error message.
Before:
If you're using more than one worker:
If you're using one worker:
After:
Related issue number
Closes #27922
Checks
git commit -s) in this PR.
scripts/format.sh to lint the changes in this PR.