[Train] Torch data transfer automatic conversion #20333
Conversation
Co-authored-by: matthewdeng <[email protected]>
save_checkpoint automatic conversion
LGTM!
python/ray/train/backend.py (Outdated)
@@ -66,13 +122,14 @@ class BackendExecutor:
    def __init__(
        self,
        backend_config: BackendConfig,
        backend: Backend,
nit: Undo this change since it doesn't make sense to have conflicting BackendConfig and Backend.
OK, I made Backend into a singleton, so it's fine to instantiate any of the backends any number of times, and we don't need to pass around a single instance.
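For illustration, a minimal sketch of what a singleton Backend could look like; the class names and metaclass approach here are assumptions for the example, not the actual Ray Train implementation:

```python
# Illustrative sketch only; the real Ray Train code may implement this differently.
class Singleton(type):
    """Metaclass that returns the same instance for every instantiation."""

    _instances = {}

    def __call__(cls, *args, **kwargs):
        if cls not in cls._instances:
            cls._instances[cls] = super().__call__(*args, **kwargs)
        return cls._instances[cls]


class Backend(metaclass=Singleton):
    """Base class for training backends (hypothetical stand-in)."""


class TorchBackend(Backend):
    pass


# Instantiating the backend any number of times yields the same object,
# so nothing needs to pass a single instance around.
assert TorchBackend() is TorchBackend()
```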
When saving a model in a checkpoint or reporting a model, the user currently has to manually extract the module from the DDP wrapper and move it to CPU so that it can be properly deserialized on the driver.
This PR adds functionality to automatically do the above so the user does not have to add this logic to their training script.
TODO: Possibly use this logic not just for checkpoints, but for return values as well.
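For context, a minimal sketch of the manual steps this PR automates; the helper name below is illustrative and not part of the PR's actual implementation:

```python
import torch
from torch.nn.parallel import DistributedDataParallel


def prepare_model_for_checkpoint(model: torch.nn.Module) -> torch.nn.Module:
    """Unwrap a DDP-wrapped model and move it to CPU so it can be
    serialized and deserialized on the driver process."""
    if isinstance(model, DistributedDataParallel):
        # DDP keeps the original model on the `.module` attribute.
        model = model.module
    # Move parameters off the GPU so the checkpoint can be loaded
    # on a driver that may not have a GPU.
    return model.to("cpu")


# Before this change, training scripts had to do this manually before
# saving a checkpoint or reporting the model; with automatic conversion,
# the wrapped GPU model can be passed directly.
```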
Why are these changes needed?
Related issue number
Checks
I've run scripts/format.sh to lint the changes in this PR.