Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[AIR] Raise ValueError if TorchCheckpoint can't serialize model #27998

Closed

Conversation

bveeramani
Copy link
Member

@bveeramani bveeramani commented Aug 18, 2022

Depends on:

Why are these changes needed?

If a model is defined in the top-level directory, then Torch can't serialize the model. This PR adds a clearer error message.

Before:
If you're using more than one worker:

RuntimeError: Some workers returned results while others didn't. Make sure that session.report() (legacy API:train.report() and train.save_checkpoint()) are called the same number of times on all workers.

If you're using one worker:

_pickle.PicklingError: Can't pickle <class 'main.Identity'>: attribute lookup Identity on main failed

After:

ValueError: TorchCheckpoint can't serialize model of type Identity because Identity is defined in the top-level environment. To work around this error, call TorchCheckpoint.from_state_dict instead of TorchCheckpoint.from_model. Alternatively, move the definition of Identity to a different module.

Related issue number

Closes #27922

Checks

  • I've signed off every commit(by using the -s flag, i.e., git commit -s) in this PR.
  • I've run scripts/format.sh to lint the changes in this PR.
  • I've included any doc changes needed for https://docs.ray.io/en/master/.
  • I've made sure the tests are passing. Note that there might be a few flaky tests, see the recent failures at https://flakey-tests.ray.io/
  • Testing Strategy
    • Unit tests
    • Release tests
    • This PR is not tested :(

Comment on lines +15 to +24
def test_from_model_value_error():
class StubModel(torch.nn.Module):
__module__ = "__main__"

def forward(x):
return x

model = StubModel()
with pytest.raises(ValueError):
TorchCheckpoint.from_model(model)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test makes assumptions about the implementation of from_model, but I wasn't sure how else to test it.

@richardliaw
Copy link
Contributor

richardliaw commented Aug 19, 2022

This particular catch-warn mechanism seems very niche.

Can we figure out how to get the _pickle.PicklingError: Can't pickle <class 'main.Identity'>: attribute lookup Identity on main failed error to raise even on multiple machines workers?

@bveeramani
Copy link
Member Author

bveeramani commented Aug 19, 2022

This particular catch-warn mechanism seems very niche.

Can we figure out how to get the _pickle.PicklingError: Can't pickle <class 'main.Identity'>: attribute lookup Identity on main failed error to raise even on multiple machines?

What error did you get when you ran it on multiple machines?

@richardliaw
Copy link
Contributor

Sorry typo -- meant workers.

@bveeramani bveeramani self-assigned this Aug 30, 2022
@bveeramani
Copy link
Member Author

This particular catch-warn mechanism seems very niche.

@richardliaw can you elaborate on this? The try-catch mechanism works with any number of workers, and it works with both Jupyter notebooks and Python programs.

Also, the error _pickle.PicklingError: Can't pickle <class 'main.Identity'>: attribute lookup Identity on main failed isn't actionable. It isn't clear what the user needs to do.

@matthewdeng
Copy link
Contributor

@bveeramani I think Richard's point is that we should start with generalizable problem of surfacing the correct error and not

RuntimeError: Some workers returned results while others didn't. Make sure that session.report() (legacy API:train.report() and train.save_checkpoint()) are called the same number of times on all workers.

For this particular error, it is niche but I'm wondering if this should instead be handled as part of a try/catch.

@amogkam amogkam assigned amogkam and unassigned amogkam Aug 30, 2022
@bveeramani
Copy link
Member Author

@bveeramani I think Richard's point is that we should start with generalizable problem of surfacing the correct error and not

RuntimeError: Some workers returned results while others didn't. Make sure that session.report() (legacy API:train.report() and train.save_checkpoint()) are called the same number of times on all workers.

I agree we should fix the general problem. That being said, I still think we should raise a ValueError. The "correct error" is confusing and unactionable.

For this particular error, it is niche but I'm wondering if this should instead be handled as part of a try/catch.

How could we handle this as part of a try/catch?

@bveeramani bveeramani added the tests-ok The tagger certifies test failures are unrelated and assumes personal liability. label Aug 30, 2022
@matthewdeng
Copy link
Contributor

If we can fix the general problem, then we have access to the original error. We then catch this particular error and raise the ValueError.

@bveeramani bveeramani removed their assignment Sep 23, 2022
@stale
Copy link

stale bot commented Sep 30, 2022

This pull request has been automatically marked as stale because it has not had recent activity. It will be closed in 14 days if no further activity occurs. Thank you for your contributions.

  • If you'd like to keep this open, just leave any comment, and the stale label will be removed.

@stale stale bot added the stale The issue is stale. It will be closed within 7 days unless there are further conversation label Sep 30, 2022
@bveeramani
Copy link
Member Author

Closing for now. Can re-open when #27922 is fixed

@bveeramani bveeramani closed this Oct 3, 2022
@stale stale bot removed the stale The issue is stale. It will be closed within 7 days unless there are further conversation label Oct 3, 2022
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
tests-ok The tagger certifies test failures are unrelated and assumes personal liability.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[AIR] Correct error doesn't propagate
4 participants