
use correct PG when collecting metrics with HYBRID shard #551

Merged: 3 commits merged into main from epwalsh/hybrid-shard on Apr 19, 2024

Conversation

@epwalsh (Member) commented Apr 18, 2024

Follow-up to #540. Fixes how we collect per-param optim metrics when using hybrid sharding. The process group we pass is the same one FSDP uses during hybrid sharding when reducing the grad norms, for example, so it should be the right one. See https://github.com/pytorch/pytorch/blob/cb17721899d4d6a55d66d4f7188e36c20a078231/torch/distributed/fsdp/fully_sharded_data_parallel.py#L1149.
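As a rough sketch of the idea (not the PR's actual code; the helper and call-site names below are hypothetical), the change amounts to threading an optional process group into the per-param reductions and passing the FSDP root module's sharding group, which under HYBRID_SHARD is the intra-replica group FSDP itself reduces over:

from typing import List, Optional

import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


def reduce_per_param_mins(
    per_param_min_metrics: List[torch.Tensor],
    device: torch.device,
    process_group: Optional[dist.ProcessGroup] = None,
) -> torch.Tensor:
    # Reduce over the given group instead of the default (world) group.
    all_mins = torch.cat(per_param_min_metrics).to(device)
    dist.reduce(all_mins, 0, op=dist.ReduceOp.MIN, group=process_group)
    return all_mins


def collect_metrics(fsdp_model: FSDP, mins: List[torch.Tensor], device: torch.device) -> torch.Tensor:
    # Under HYBRID_SHARD, `fsdp_model.process_group` is the sharding (intra-replica)
    # group, i.e. the same group FSDP uses when reducing grad norms.
    return reduce_per_param_mins(mins, device, process_group=fsdp_model.process_group)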

@epwalsh requested a review from @2015aroras on April 18, 2024 17:37
olmo/optim.py (Outdated)

@@ -144,12 +147,12 @@ def is_grad_norm_metric(metric_name: str) -> bool:
     # Reduce mins.
     if per_param_min_metrics:
         all_mins = torch.cat(per_param_min_metrics).to(device)
-        dist.reduce(all_mins, 0, op=dist.ReduceOp.MIN)
+        dist.reduce(all_mins, 0, op=dist.ReduceOp.MIN, group=process_group)
@2015aroras (Collaborator) commented Apr 18, 2024:
This was the original approach I went with, but I had a few concerns:

  1. Rank 0 here refers to the global rank, so only the 'first' group will actually perform the reduce. The code says it will warn for the other groups, though I vaguely remember my runs crashing instead; if it doesn't crash, then I guess it's not a real problem. You can get around this by translating each group's rank 0 into its global rank, e.g. with dist.get_global_rank(process_group, 0) (see the sketch after this list).
  2. Using multiple process groups on the same stream without some sort of synchronization can lead to deadlocks (https://pytorch.org/docs/stable/distributed.html#torch.distributed.new_group). I don't fully understand it myself. Torch seems to get around this by putting the ops on different streams AND (if dynamo compiling is off) having the streams wait on each other. For us, it may be that we either have to pass the process group to all dist.* calls or have to synchronize whenever we make distributed calls over different process groups.
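For illustration, a minimal sketch of the workaround described in point (1), using torch.distributed.get_global_rank (available in recent PyTorch) to translate each group's rank 0 into the global destination rank that dist.reduce expects. The helper name is hypothetical and this is not the actual 6911bfb change:

from typing import Optional

import torch
import torch.distributed as dist


def reduce_min_to_group_leader(
    values: torch.Tensor,
    process_group: Optional[dist.ProcessGroup] = None,
) -> torch.Tensor:
    # dist.reduce() expects a *global* destination rank. With hybrid sharding,
    # each sharding group should reduce to its own "rank 0", so translate the
    # group-local rank 0 into its global rank first.
    if process_group is None:
        dst = 0  # default group: group rank and global rank coincide
    else:
        dst = dist.get_global_rank(process_group, 0)
    dist.reduce(values, dst, op=dist.ReduceOp.MIN, group=process_group)
    return values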

@epwalsh (Member, Author) commented:
@2015aroras thanks, 6911bfb should fix point (1). I'll see what happens now about (2).

@epwalsh (Member, Author) commented:
Seems to be working fine in a single-node test, at least.

@epwalsh (Member, Author) commented:
@2015aroras I'm going to merge and give this a try with the 70B. If I run into other issues I'll rethink this strategy.

@epwalsh merged commit 7be71cd into main on Apr 19, 2024 (9 of 11 checks passed).
@epwalsh deleted the epwalsh/hybrid-shard branch on April 19, 2024 at 15:56.