[air] Horovod: Use Torch.encode_data if torch is imported #28440
Conversation
Signed-off-by: Kai Fricke <[email protected]>
Thanks @krfricke, lgtm as a stopgap fix! But ultimately we should refactor the checkpoint encoding/decoding logic out of the Backends and into the framework-specific checkpoints.
Then, when saving a Torch model via TorchCheckpoint.from_model(), the same encode/decode logic will apply regardless of whether I'm using TorchTrainer or HorovodTrainer.
Made an issue to track this here: #28462
@@ -190,3 +190,19 @@ def load_torch_model(
            f"to be of type `torch.nn.Module`, or a model "
            f"state dict of type dict."
        )


def contains_tensor(obj):
Do we need this? I think if torch is installed, it should be safe to always use the TorchBackend for encoding/decoding (even if the data dict does not contain a tensor). I'm worried that in the worst case contains_tensor can lead to a lot of recursion.
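For illustration, such a recursive lookup might look like the sketch below. It is written generically, with a pluggable predicate instead of a hard torch dependency (the real helper would test `isinstance(obj, torch.Tensor)` directly); the function name is hypothetical:

```python
def contains_matching(obj, is_match) -> bool:
    """Recursively search dicts, lists, and tuples for a matching value.

    Short-circuits on the first hit, so for typical checkpoint dicts
    (tensors near the top level) the search finishes early.
    """
    if is_match(obj):
        return True
    if isinstance(obj, dict):
        return any(
            contains_matching(k, is_match) or contains_matching(v, is_match)
            for k, v in obj.items()
        )
    if isinstance(obj, (list, tuple)):
        return any(contains_matching(v, is_match) for v in obj)
    return False
```

The worst case (no match anywhere) does visit every nested value, which is the recursion cost being discussed, but that is the same traversal pickling performs anyway.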
I originally didn't have it in, but thought the overhead wouldn't be too bad. But I agree: since it only concerns an internal communication channel and the intermediate objects are not exposed to the user, we can just do this whenever torch is loaded. Updated the PR.
Turns out we do need it, as torch.save seems to silently fail if a full model is passed (and not a state dict).
I think it should be fine: a similar lookup has to happen during pickling after all, and in most cases it should finish early.
This reverts commit 2a21445. Signed-off-by: Kai Fricke <[email protected]>
Force-pushed from 65e7d66 to 93913af
This reverts commit 93913af. Signed-off-by: Kai Fricke <[email protected]>
Signed-off-by: Kai Fricke <[email protected]>
Why are these changes needed?
Horovod with Tune does not work out of the box for GPU checkpoints as they get deserialized on the non-GPU trainer worker, leading to errors. With this PR, we detect if torch is imported and a tensor is supplied in the Horovod backend. If so, we use the torch backend to serialize the data.
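A minimal sketch of that dispatch, under stated assumptions: the function names below are illustrative, not Ray's actual internals, `torch_encode` stands in for the torch backend's encoder, and `sys.modules` is used to check whether torch is already imported without triggering an import on CPU-only workers:

```python
import sys


def _payload_has_tensor(obj) -> bool:
    # Only reached when torch is already in sys.modules, so this import
    # never pulls torch onto a worker that hasn't loaded it.
    import torch

    if isinstance(obj, torch.Tensor):
        return True
    if isinstance(obj, dict):
        return any(_payload_has_tensor(v) for v in obj.values())
    if isinstance(obj, (list, tuple)):
        return any(_payload_has_tensor(v) for v in obj)
    return False


def encode_data(data_dict, torch_encode=repr):
    """Use torch-aware encoding only when torch is imported and the
    payload actually contains a tensor; otherwise pass data through.
    """
    if "torch" in sys.modules and _payload_has_tensor(data_dict):
        return torch_encode(data_dict)
    return data_dict
```

The point of the two-step check is exactly the stopgap described above: checkpoints without tensors (and processes that never imported torch) take the plain path, while GPU tensors get serialized via the torch-aware path so they can be deserialized safely on a non-GPU trainer worker.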
Related issue number
Closes #28439
Checks
- [x] I've signed off every commit (by using the -s flag, i.e., git commit -s) in this PR.
- [x] I've run scripts/format.sh to lint the changes in this PR.