use correct PG when collecting metrics with HYBRID shard #551
Conversation
olmo/optim.py (Outdated)
@@ -144,12 +147,12 @@ def is_grad_norm_metric(metric_name: str) -> bool:
         # Reduce mins.
         if per_param_min_metrics:
             all_mins = torch.cat(per_param_min_metrics).to(device)
-            dist.reduce(all_mins, 0, op=dist.ReduceOp.MIN)
+            dist.reduce(all_mins, 0, op=dist.ReduceOp.MIN, group=process_group)
This was the original approach I went with, but I had a few concerns:
- Rank 0 refers to the global rank, so only the 'first' group will perform the reduce. The code says it will warn for the other groups, though I vaguely remember my runs crashing instead. If it doesn't crash then I guess it's not a real problem. You can get around this by getting each group's rank 0 using process_group.get_global_rank(group, 0) (see the sketch after this list).
- Using multiple process groups on the same stream without some sort of synchronization can lead to deadlocks (https://pytorch.org/docs/stable/distributed.html#torch.distributed.new_group). I don't fully understand it myself. The way torch seems to get around this is that it puts the ops on different streams AND (if dynamo compiling is off) has the streams wait on each other. For us, it may be that we either have to pass the process group to all dist.* calls or have to synchronize when we make distributed calls over different process groups.
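A minimal sketch of the fix point (1) suggests, assuming the reduce lives in a helper that already receives a process_group (the helper name here is made up, and torch.distributed.get_global_rank is used to translate group rank 0 into the global rank that dist.reduce expects as dst):

import torch
import torch.distributed as dist

def reduce_mins_to_group_root(per_param_min_metrics, device, process_group=None):
    # Hypothetical helper mirroring the min-reduce in olmo/optim.py.
    all_mins = torch.cat(per_param_min_metrics).to(device)
    group = process_group if process_group is not None else dist.group.WORLD
    # dist.reduce takes a *global* rank as dst, so translate this group's
    # rank 0 into its global rank instead of hard-coding 0.
    dst = dist.get_global_rank(group, 0)
    dist.reduce(all_mins, dst, op=dist.ReduceOp.MIN, group=group)
    return all_mins

With the default world group this is a no-op (group rank 0 is global rank 0), but with hybrid sharding each shard group gets its own root, so every group performs its own reduce.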
@2015aroras thanks, 6911bfb should fix point (1). I'll see what happens now about (2).
Seems to be working fine on a single-node test at least.
@2015aroras I'm going to merge and give this a try with the 70B. If I run into other issues I'll rethink this strategy.
Follow-up to #540. Fixes how we collect per-param optim metrics when using hybrid sharding. The process group we're using here is the same process group that FSDP uses during hybrid sharding when reducing the grad norms, for example, so it should be the right one. See https://github.com/pytorch/pytorch/blob/cb17721899d4d6a55d66d4f7188e36c20a078231/torch/distributed/fsdp/fully_sharded_data_parallel.py#L1149.
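A minimal sketch of that wiring, assuming fsdp_model is the HYBRID_SHARD-wrapped module: fsdp_model.process_group is the intra-node sharding group FSDP itself reduces over, and the metric reduce below is a simplified stand-in for the actual collection in olmo/optim.py, not the real code path:

import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy

dist.init_process_group("nccl")
# Assumes one process per GPU (e.g. launched with torchrun).
device = torch.device("cuda", dist.get_rank() % torch.cuda.device_count())
torch.cuda.set_device(device)

fsdp_model = FSDP(
    torch.nn.Linear(8, 8).to(device),
    sharding_strategy=ShardingStrategy.HYBRID_SHARD,
)

# With HYBRID_SHARD, fsdp_model.process_group is the intra-node sharding group,
# i.e. the same group FSDP reduces grad norms over, so metric reduces get scoped
# to it rather than to the global world.
shard_group = fsdp_model.process_group

# Simplified stand-in for the per-param metric reduce: min over the local shards,
# reduced to this shard group's root rank.
local_min = torch.stack([p.detach().min() for p in fsdp_model.parameters()]).min()
dist.reduce(local_min, dist.get_global_rank(shard_group, 0), op=dist.ReduceOp.MIN, group=shard_group)

dist.destroy_process_group()

On the real code path the group would instead be threaded through to every dist.reduce / dist.all_reduce inside the metric collection, which is what this PR does.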