use correct PG when collecting metrics with HYBRID shard #551

Merged: 3 commits, Apr 19, 2024
Changes from 1 commit
13 changes: 8 additions & 5 deletions olmo/optim.py
@@ -39,7 +39,10 @@ def _clean_param_name(self, name: str) -> str:

@torch.no_grad()
def clip_grads_and_collect_metrics(
self, global_step: int, collect_param_metrics: bool = True
self,
global_step: int,
collect_param_metrics: bool = True,
process_group: Optional[dist.ProcessGroup] = None,
) -> Dict[str, torch.Tensor]:
"""
Clips gradients for every group that has the field `max_grad_norm`.
@@ -144,12 +147,12 @@ def is_grad_norm_metric(metric_name: str) -> bool:
# Reduce mins.
if per_param_min_metrics:
all_mins = torch.cat(per_param_min_metrics).to(device)
dist.reduce(all_mins, 0, op=dist.ReduceOp.MIN)
dist.reduce(all_mins, 0, op=dist.ReduceOp.MIN, group=process_group)
@2015aroras (Collaborator) commented on Apr 18, 2024:
This was the original approach I went with, but I had a couple of concerns:

  1. Rank 0 refers to the global rank, so only the 'first' group will perform the reduce. The code says it will warn for the other groups, but I vaguely remember my runs crashing instead. If it doesn't crash then I guess it's not a real problem. You can get around this by resolving each group's rank 0 with process_group.get_global_rank(group, 0) (see the sketch after this list).
  2. Using multiple process groups on the same stream without some sort of synchronization can lead to deadlocks (https://pytorch.org/docs/stable/distributed.html#torch.distributed.new_group). I don't fully understand it myself. Torch seems to get around this by putting the ops on different streams and (if dynamo compiling is off) having the streams wait on each other. For us, it may be that we have to either pass the process group to all dist.* calls or synchronize whenever we make distributed calls over different process groups.
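A minimal sketch of the workaround described in (1), assuming the module-level dist.get_global_rank (PyTorch >= 2.0) is the intended call; the helper name below is illustrative, not from the PR:

```python
from typing import Optional

import torch
import torch.distributed as dist


def reduce_min_to_group_leader(t: torch.Tensor, group: Optional[dist.ProcessGroup] = None) -> None:
    # dist.reduce interprets `dst` as a rank in the *global* process group,
    # even when `group` is passed, so translate the group-local rank 0 into
    # its global rank before reducing.
    dst = 0 if group is None else dist.get_global_rank(group, 0)
    dist.reduce(t, dst, op=dist.ReduceOp.MIN, group=group)
```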

Member Author:
@2015aroras thanks, 6911bfb should fix point (1). I'll see what happens now about (2).

Member Author:
Seems to be working fine on a single-node test, at least.

Member Author:
@2015aroras I'm going to merge and give this a try with the 70B. If I run into other issues I'll rethink this strategy.

per_param_min_metrics = all_mins.split(1)
# Reduce maxs.
if per_param_max_metrics:
all_maxs = torch.cat(per_param_max_metrics).to(device)
dist.reduce(all_maxs, 0, op=dist.ReduceOp.MAX)
dist.reduce(all_maxs, 0, op=dist.ReduceOp.MAX, group=process_group)
per_param_max_metrics = all_maxs.split(1)
# Reduce sums or just norms.
all_norms = torch.cat(per_param_norm_metrics).to(device) ** 2.0
@@ -159,13 +162,13 @@ def is_grad_norm_metric(metric_name: str) -> bool:
all_sums_norms_numels = torch.cat(
[all_sums.unsqueeze(0), all_norms.unsqueeze(0), all_numels.unsqueeze(0)], dim=0
)
dist.all_reduce(all_sums_norms_numels, op=dist.ReduceOp.SUM)
dist.all_reduce(all_sums_norms_numels, op=dist.ReduceOp.SUM, group=process_group)
all_sums, all_norms, all_numels = all_sums_norms_numels.split(1)
# Get averages.
# NOTE: could get infs for non-rank0 processes but that's okay.
per_param_avg_metrics = (all_sums / all_numels).squeeze(0).split(1)
else:
dist.all_reduce(all_norms, op=dist.ReduceOp.SUM)
dist.all_reduce(all_norms, op=dist.ReduceOp.SUM, group=process_group)
grad_norm_metric_mask = torch.tensor(
[float(is_grad_norm_metric(n)) for n in per_param_norm_metric_names], device=all_norms.device
)
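For context on why the `group=` argument matters in the reductions above, a hedged illustration (assumption: under HYBRID sharding each node holds a full replica of the model, sharded across that node's ranks, so a world-wide reduction would count every shard once per node; the helper name is illustrative):

```python
from typing import Optional

import torch
import torch.distributed as dist


def total_grad_norm(local_shard_sq_norm: torch.Tensor, group: Optional[dist.ProcessGroup] = None) -> torch.Tensor:
    # Each rank contributes the sum of squared gradient entries it owns.
    # Summing within the intra-node shard group counts every shard exactly
    # once; summing across the whole world would multiply the result by the
    # number of replicas (nodes).
    total = local_shard_sq_norm.clone()
    dist.all_reduce(total, op=dist.ReduceOp.SUM, group=group)
    return total.sqrt()
```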
6 changes: 5 additions & 1 deletion olmo/train.py
@@ -704,7 +704,11 @@ def train_step(self, batch: Dict[str, Any], reduce_global_loss: bool = True) ->
# Clip gradient norms and collect param/gradient/optim metrics.
should_log_optim_metrics_this_step = self.should_log_optim_metrics_this_step()
optim_metrics = self.optim.clip_grads_and_collect_metrics(
self.global_step, collect_param_metrics=should_log_optim_metrics_this_step
self.global_step,
collect_param_metrics=should_log_optim_metrics_this_step,
# passing this process group here ensures metrics are reduced correctly when we're using
# HYBRID sharding.
process_group=self.fsdp_model.process_group,
)

# Adjust the learning rate.
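A minimal sketch of where that process group comes from, assuming torch's FSDP with ShardingStrategy.HYBRID_SHARD; to my understanding the wrapped module's `process_group` attribute is then the intra-node sharding group, which is what makes the metric reductions above come out correctly scoped:

```python
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy

# Assumes torch.distributed is already initialized (e.g. launched via torchrun)
# and a CUDA device is set for this rank.
model = nn.Linear(8, 8)  # stand-in for the real model
fsdp_model = FSDP(model, sharding_strategy=ShardingStrategy.HYBRID_SHARD)

# Intra-node sharding group; replication across nodes is handled by a
# separate group that FSDP manages internally.
shard_group = fsdp_model.process_group
```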