[2D][TP] Enable DDP TP integration with unit test #106583

Closed (4 commits)

Conversation

@fduwjj (Contributor) commented on Aug 3, 2023

@pytorch-bot (bot) commented on Aug 3, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/106583

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit c1b598f with merge base d8ad748:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot added the release notes: distributed (ddp) label on Aug 3, 2023
@fduwjj added the ciflow/trunk, module: dtensor, and release notes: distributed (dtensor) labels on Aug 3, 2023
fduwjj added a commit that referenced this pull request Aug 3, 2023
ghstack-source-id: 82c02d6cb4119a3eb23fcff7d51740efa9c997bb
Pull Request resolved: #106583
Resolved review threads (outdated) on docs/source/distributed.tensor.parallel.rst and torch/distributed/tensor/parallel/ddp.py
_update_model_param(param_list) # type: ignore[arg-type]


def pre_dp_model_transform(model: nn.Module):
Contributor:

I recommend naming the argument module instead of model, as we push for distributed API composability: model generally refers to the root module, whereas module could also mean a submodule.

In that case, you may also prefer the function name pre_dp_module_transform.

Contributor:

More broadly, I was wondering: Should we build this logic into DDP/FSDP and hide it from the user (avoiding the extra call)?

The natural follow-up question is how users would disable this logic if they are using DTensor in their own way and do not want this conversion done for them. In that case, our "official" TP API parallelize_module() can mark the constructed DTensors specially, and DDP/FSDP can register this special logic only if it detects such marked DTensors in their managed parameters.
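As a concrete illustration of this marking idea, here is a minimal sketch. Everything below is hypothetical: the attribute name `_created_by_parallelize_module` and both helpers are made up for this discussion and are not part of this PR or of PyTorch.

```python
import torch.nn as nn
from torch.distributed._tensor import DTensor

# Hypothetical marker attribute; not part of any real API.
_TP_MARK = "_created_by_parallelize_module"


def _mark_tp_dtensors(module: nn.Module) -> None:
    # What parallelize_module() could do: tag the DTensor parameters it creates.
    for param in module.parameters():
        if isinstance(param, DTensor):
            setattr(param, _TP_MARK, True)


def _has_marked_dtensors(module: nn.Module) -> bool:
    # What DDP/FSDP could check before registering the conversion hooks.
    return any(
        isinstance(p, DTensor) and getattr(p, _TP_MARK, False)
        for p in module.parameters()
    )
```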

What do you think?

Contributor Author (fduwjj):

For FSDP, we already have the extension embedded inside FSDP, right? But this causes issues in state_dict, so @fegin, @kumpera and I think this would be a better UX for TP + DP.

Contributor Author (fduwjj):

Also, if users want special handling, we can either give them an option to register a customized handler here, or they can choose not to call this API at all, rather than embedding this into the DDP/FSDP code.

Contributor:

> But this causes issues in state_dict

Could you guys clarify what the issues with state dict are when using the extensions? (or point me to the right doc that describes this)

@fegin (Contributor), Aug 15, 2023:

Actually, I also prefer putting this logic in FSDP and DDP. What I suggested is to implement it with hooks, which is what this PR does. But I think the information should be a state of DDP instead of using _st_info attached to the parameter. Also, registering the hooks should happen inside the constructor of DDP.

"""

_localize_dtensor(model, None, None)
model.register_forward_pre_hook(_reconstruct_dtensor)
Contributor:

How do we expect the hook ordering to compose with other hook-based APIs?

Contributor Author (fduwjj):

Composing with the composable APIs is out of scope for this PR. But the idea is the same: for replicate, TP needs to convert the tensors to DTensors before forward begins and do the reverse after forward. For FSDP, the hooks are very complicated; I have not thought more about that, but it would follow what we are doing in the extension.
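For readers following along, a rough sketch of that pre-forward conversion pattern is below. This is an illustration only: the names `_to_local`, `_to_dtensor`, and the module-level cache are made up here; the PR's actual helpers are `_localize_dtensor` and `_reconstruct_dtensor` in torch/distributed/tensor/parallel/ddp.py.

```python
import torch.nn as nn
from torch.distributed._tensor import DTensor

# name -> (device_mesh, placements), captured when parameters are localized.
_sharding_cache: dict = {}


def _owner_and_leaf(module: nn.Module, name: str):
    # Resolve a dotted name like "block.linear.weight" to (block.linear, "weight").
    parent, _, leaf = name.rpartition(".")
    return (module.get_submodule(parent) if parent else module), leaf


def _to_local(module: nn.Module) -> None:
    # Swap every DTensor parameter for its local shard so DDP only ever
    # sees plain torch.Tensor parameters and gradients.
    for name, param in list(module.named_parameters()):
        if isinstance(param, DTensor):
            _sharding_cache[name] = (param.device_mesh, param.placements)
            owner, leaf = _owner_and_leaf(module, name)
            owner.register_parameter(leaf, nn.Parameter(param.to_local()))


def _to_dtensor(module: nn.Module, inputs) -> None:
    # Forward pre-hook: rebuild DTensor parameters from the cached
    # mesh/placement info before the TP layers run.
    for name, (mesh, placements) in _sharding_cache.items():
        local = module.get_parameter(name)
        owner, leaf = _owner_and_leaf(module, name)
        owner.register_parameter(
            leaf, nn.Parameter(DTensor.from_local(local, mesh, placements))
        )


# Usage mirrors the two lines of code context quoted above:
# _to_local(model)
# model.register_forward_pre_hook(_to_dtensor)
```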

Contributor:

To give an example of what I am wondering: how would this compose with a hook-based activation checkpointing API? Should this registered hook come before AC or after AC? Are we making this registration order clear to the user?

Reconstruct DTensor parameters from local tensors
"""
param_list = []
for name, t in model.named_parameters():
Contributor:

If we have DDP above FSDP, then will this try to convert all FSDP parameters to DTensor as well?

Contributor Author (fduwjj):

No, we only convert parameters which are DTensors.

Contributor:

What I meant is: would this convert DTensors under an FSDP-managed module, not just under a DDP-managed module?

Contributor Author (fduwjj):

For now, no. We hope that down the road we can merge FSDP-managed modules into this API, too.

Contributor:

My point is that model.named_parameters() recurses into submodules, which may be managed by FSDP. There is no check against that to stop the recursion. This means that if there is a DDP module above FSDP modules, then this will convert the FSDP-managed DTensors to local tensors too. Is that the desired behavior?
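To make the concern concrete, here is a hedged sketch of what an explicit boundary check could look like. This is hypothetical and not something the PR implements; named_parameters() alone recurses straight through FSDP wrappers, so a guard would have to walk children manually.

```python
import torch.nn as nn
from torch.distributed._tensor import DTensor
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


def _ddp_managed_dtensor_params(module: nn.Module, prefix: str = ""):
    # Yield only DTensor parameters that are not inside an FSDP-wrapped
    # submodule, by walking children manually and stopping at FSDP wrappers.
    if isinstance(module, FSDP):
        return
    for name, param in module.named_parameters(recurse=False):
        if isinstance(param, DTensor):
            yield prefix + name, param
    for child_name, child in module.named_children():
        yield from _ddp_managed_dtensor_params(child, prefix + child_name + ".")
```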

fduwjj added a commit that referenced this pull request Aug 4, 2023
ghstack-source-id: 06dced7618a6887b143210bd59bf69ad20077d16
Pull Request resolved: #106583
@fduwjj requested a review from kumpera on August 4, 2023 at 03:42
@kumpera (Contributor) left a comment:

Bunch of trivial doc fixes.

My concerns with this PR are the following:

  1. No error checking at all.
     We don't fail if we run into an FSDP module, for example.
     Do we support all forms of DTensor sharding?

  2. No module traversal caching.
     We should cache the per-param sharding_info and use that to flatten/unflatten the DTensors.
     This would be faster and more composable, since it would respect the model decisions at the time we called pre_dp_module_transform (see the sketch after this list).

  3. It's not explicit about how it leaves the model outside of fwd/bwd.
     This is relevant if we apply more parallelization transforms after it.
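As a sketch of what the caching in point 2 might look like (hypothetical names, not code from this PR): the per-parameter sharding metadata could be collected in a single traversal when the transform is applied, and the hooks would then read from this cache instead of re-walking the module tree.

```python
from dataclasses import dataclass

import torch.nn as nn
from torch.distributed._tensor import DeviceMesh, DTensor


@dataclass
class _ShardingInfo:
    # Per-parameter metadata captured once when the transform is applied.
    mesh: DeviceMesh
    placements: tuple


def _build_sharding_cache(module: nn.Module) -> dict:
    # One traversal at transform time; flatten/unflatten later reads this
    # cache, so it reflects the model as it looked at the time of the call.
    return {
        name: _ShardingInfo(param.device_mesh, tuple(param.placements))
        for name, param in module.named_parameters()
        if isinstance(param, DTensor)
    }
```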

Resolved review threads on torch/distributed/tensor/parallel/ddp.py (mostly outdated)
@awgu (Contributor) commented on Aug 4, 2023:

Should we have a broader API discussion? Maybe we can land the transform function as private first (i.e. with a leading underscore)?

@fduwjj (Contributor Author) commented on Aug 4, 2023:

@awgu I agree that we want a broader API discussion, and there is indeed one going on right now among @wanchaol, @fegin, @wz337, @rohan-varma, and you, to eventually leverage DeviceMesh for a unified solution for both [functional, composable] × [DDP, FSDP]. Since we already made the 2d_fsdp API public and TP is still in prototype (no backward compatibility guarantee), I think it's still OK to make this API public for now.

@fduwjj (Contributor Author) commented on Aug 4, 2023:

@kumpera sure. For 1 and 2, I will send follow-up PRs to address them. For 3, I think we also need a hook for state_dict and optimizer_state_dict.

@fegin (Contributor) left a comment:

Accepting to unblock. I agree with @awgu that we should have a broader discussion about how DeviceMesh is used for DDP and FSDP. Then we may want to move the implementation into DDP.

fduwjj added a commit that referenced this pull request Aug 16, 2023
ghstack-source-id: a64981ad1765d6f4332c84bfd456a6d089a28292
Pull Request resolved: #106583
@wanchaol (Contributor) left a comment:

nice work!

@fduwjj (Contributor Author) commented on Aug 16, 2023:

@pytorchbot merge

@pytorchmergebot (Collaborator):
Merge started

Your change will be merged once all checks pass (ETA 0-4 hours).

pytorchmergebot pushed a commit that referenced this pull request Aug 17, 2023
@facebook-github-bot deleted the gh/fduwjj/103/head branch on August 20, 2023 at 14:16
Labels: ciflow/trunk, Merged, module: dtensor, release notes: distributed (ddp), release notes: distributed (dtensor)
7 participants