[2D][TP] Enable DDP TP integration with unit test #106583
Conversation
[ghstack-poisoned]
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/106583
Note: Links to docs will display an error until the docs builds have been completed. ✅ No Failures as of commit c1b598f with merge base d8ad748. This comment was automatically generated by Dr. CI and updates every 15 minutes.
[ghstack-poisoned]
ghstack-source-id: 82c02d6cb4119a3eb23fcff7d51740efa9c997bb Pull Request resolved: #106583
_update_model_param(param_list)  # type: ignore[arg-type]

def pre_dp_model_transform(model: nn.Module):
I recommend naming the arg "module" instead of "model" as we push for distributed API composability, since "model" generally refers to the root module, whereas "module" could mean a submodule. In that case, you may also prefer the function name pre_dp_module_transform.
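For illustration only (this is just the rename being suggested, nothing more), the signature would then read:

```python
import torch.nn as nn

def pre_dp_module_transform(module: nn.Module):
    ...
```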
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
More broadly, I was wondering: should we build this logic into DDP/FSDP and hide it from the user (avoiding the extra call)?
The natural follow-up question is how users would disable this logic if they are using DTensor in their own way and do not want this conversion done for them. In that case, our "official" TP API parallelize_module() could mark the constructed DTensors specially, and DDP/FSDP would register this special logic only if it detects such marked DTensors in their managed parameters.
What do you think?
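To make the idea concrete, here is a hedged sketch of what that marking could look like; the _tp_managed attribute and both helper functions are hypothetical, not an existing API:

```python
import torch.nn as nn
from torch.distributed._tensor import DTensor

def _mark_tp_dtensors(module: nn.Module) -> None:
    # Hypothetically called at the end of parallelize_module(): tag the
    # DTensors that the TP API itself constructed.
    for param in module.parameters():
        if isinstance(param, DTensor):
            param._tp_managed = True  # hypothetical marker attribute

def _should_register_dtensor_hooks(module: nn.Module) -> bool:
    # DDP/FSDP would register the local<->DTensor conversion logic only if
    # they find parameters carrying the marker.
    return any(
        isinstance(p, DTensor) and getattr(p, "_tp_managed", False)
        for p in module.parameters()
    )
```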
Also, if users want special handling, we could either give them an option to register a customized handler here, or they could choose not to call this API, rather than embedding this into the DDP/FSDP code?
"But this causes issues in state_dict"
Could you guys clarify what the issues with state dict are when using the extensions? (Or point me to the right doc that describes this.)
Actually, I also prefer to put this logic in FSDP and DDP. What I suggested is to implement it with hooks, which is what this PR does. But I think the information should be a state of DDP instead of using _st_info attached to the parameter. Also, registering the hooks should happen inside the constructor of DDP.
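A rough sketch of that alternative, assuming a hypothetical wrapper class and stand-in hook bodies (this is not the actual DDP constructor): the per-parameter DTensor info lives on the DDP object instead of a _st_info attribute on each parameter, and the hooks are registered in the constructor.

```python
import torch.nn as nn

class DDPWithTPState(nn.Module):  # hypothetical stand-in for DDP
    def __init__(self, module: nn.Module):
        super().__init__()
        self.module = module
        # DDP-owned state: param name -> DTensor sharding info,
        # instead of attaching _st_info to each parameter.
        self._dtensor_info: dict = {}
        module.register_forward_pre_hook(self._reconstruct_dtensor)
        module.register_forward_hook(self._localize_dtensor)

    def _reconstruct_dtensor(self, module, inputs):
        ...  # rebuild DTensor params from local tensors using self._dtensor_info

    def _localize_dtensor(self, module, inputs, output):
        ...  # swap DTensor params back to local tensors, recording their info

    def forward(self, *args, **kwargs):
        return self.module(*args, **kwargs)
```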
""" | ||
|
||
_localize_dtensor(model, None, None) | ||
model.register_forward_pre_hook(_reconstruct_dtensor) |
How do we expect the hook ordering to compose with other hook-based APIs?
Composing with the composable APIs is out of scope of this PR, but the idea is the same: for replicate, TP needs to convert local tensors to DTensors before forward begins and do the reverse after the forward. For FSDP, the hooks are very complicated; I have not thought more about that, but it follows what we are doing in the extension.
To give an example of what I am wondering: how would this compose with a hook-based activation checkpointing API? Should this registered hook come before AC or after AC? Are we making this registration order clear to the user?
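For reference, a tiny illustration of the default ordering (to my understanding, forward pre-hooks fire in registration order, so whichever of the DTensor-reconstruction hook and an AC wrapper's hook is registered first runs first):

```python
import torch
import torch.nn as nn

m = nn.Linear(4, 4)
order = []
# Hooks return None, so they only record when they ran.
m.register_forward_pre_hook(lambda mod, inp: order.append("registered first"))
m.register_forward_pre_hook(lambda mod, inp: order.append("registered second"))
m(torch.randn(2, 4))
print(order)  # ['registered first', 'registered second']
```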
Reconstruct DTensor parameters from local tensors
"""
param_list = []
for name, t in model.named_parameters():
If we have DDP above FSDP, then will this try to convert all FSDP parameters to DTensor as well?
No, we only convert parameters which are DTensors.
What I meant is: would this convert DTensors under an FSDP-managed module, not a DDP-managed module?
For now, no. We hope that down the road we can merge FSDP-managed modules into this API, too.
My point is that model.named_parameters() recurses into submodules, which may be managed by FSDP. There is no check against that to stop the recursion. This means that if there is a DDP module above FSDP modules, then this will convert the FSDP-managed DTensors to local tensors too. Is that the desired behavior?
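One way this could be addressed, sketched under the assumption that the DDP-level transform should skip FSDP-managed subtrees (the traversal helper below is hypothetical, not part of this PR):

```python
import torch.nn as nn
from torch.distributed._tensor import DTensor
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

def _iter_tp_dtensor_params(module: nn.Module, prefix: str = ""):
    # Stop descending when we hit an FSDP-managed subtree, so a DDP wrapper
    # above FSDP does not touch FSDP-managed DTensors.
    if isinstance(module, FSDP):
        return
    for name, param in module.named_parameters(recurse=False):
        if isinstance(param, DTensor):
            yield f"{prefix}{name}", param
    for child_name, child in module.named_children():
        yield from _iter_tp_dtensor_params(child, f"{prefix}{child_name}.")
```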
[ghstack-poisoned]
ghstack-source-id: 06dced7618a6887b143210bd59bf69ad20077d16 Pull Request resolved: #106583
Bunch of trivial doc fixes.
My concerns with this PR are the following:
1. No error checking at all. We don't fail if we run into an FSDP module, for example. Do we support all forms of DTensor sharding?
2. No module traversal caching. We should cache the per-param sharding_info and use that to flatten/unflatten the DTensors. This would be faster and more composable, since it would respect the model decisions at the time we called pre_dp_module_transform (see the sketch below).
3. It's not explicit about how it leaves the model outside of fwd/bwd. This is relevant if we apply more parallelization transforms after it.
Should we have a broader API discussion? Maybe we can land the transform function as private first (i.e. with a leading underscore)?
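A minimal sketch of the caching idea from point 2, assuming a hypothetical ShardingInfo record and helper name (not what this PR implements): record each DTensor parameter's spec once, at pre_dp_module_transform time, and let the hooks reuse the cache instead of re-traversing the module tree every iteration.

```python
from dataclasses import dataclass
from typing import Dict, Tuple

import torch.nn as nn
from torch.distributed._tensor import DTensor

@dataclass
class ShardingInfo:  # hypothetical record of what flatten/unflatten needs
    device_mesh: object
    placements: Tuple

def build_sharding_cache(module: nn.Module) -> Dict[str, ShardingInfo]:
    # Built once when the transform is applied; the forward pre/post hooks
    # would read this cache to unflatten/flatten the DTensor parameters.
    cache: Dict[str, ShardingInfo] = {}
    for name, param in module.named_parameters():
        if isinstance(param, DTensor):
            cache[name] = ShardingInfo(param.device_mesh, tuple(param.placements))
    return cache
```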
@awgu I agree that we want a broader API discussion, and there is indeed one ongoing right now among @wanchaol, @fegin, @wz337, @rohan-varma, and you to eventually leverage DeviceMesh for a unified solution for both [functional, composable] × [DDP, FSDP]. Since we already made the 2d_fsdp API public and TP is still in prototype (no backward compatibility guarantee), I think it's still OK to make this API public for now.
@kumpera sure. For 1 and 2, I will send follow-up PRs to address them. For 3, I think we also need a hook for
Accepting to unblock. I agree with @awgu: we should have a broader discussion about how DeviceMesh is used for DDP and FSDP. Then we may want to move the implementation into DDP.
[ghstack-poisoned]
ghstack-source-id: a64981ad1765d6f4332c84bfd456a6d089a28292 Pull Request resolved: #106583
nice work!
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Pull Request resolved: #107397 Approved by: https://github.com/wanchaol ghstack dependencies: #107313, #106583
Stack from ghstack (oldest at bottom):