Commit ae84d47
Merge pull request #602 from allenai/no_shard_ddp_clip
fixed host-device sync at each clipping step
ananyahjha93 authored May 30, 2024
2 parents 40210bb + cfbaab5 commit ae84d47
Showing 2 changed files with 10 additions and 9 deletions.
18 changes: 9 additions & 9 deletions olmo/optim.py
@@ -219,11 +219,11 @@ def is_grad_norm_metric(metric_name: str) -> bool:
         for group in self.param_groups:
             if (max_norm_ratio := group.get("max_grad_norm_ratio")) is not None:
                 num_clipped = self._do_adaptive_clipping(
-                    group, max_norm_ratio, global_step, all_metrics, collect_param_metrics=True
+                    group, max_norm_ratio, global_step, all_metrics, collect_param_metrics=collect_param_metrics
                 )
             elif (max_norm := group.get("max_grad_norm")) is not None:
                 num_clipped = self._do_global_fixed_clipping(
-                    group, max_norm, all_metrics, collect_param_metrics=True
+                    group, max_norm, all_metrics, collect_param_metrics=collect_param_metrics
                 )
             else:
                 # No clipping needed.
@@ -232,14 +232,14 @@ def is_grad_norm_metric(metric_name: str) -> bool:
             if num_clipped is not None:
                 num_grads_clipped += num_clipped
 
-        if num_eligible_grads > 0:
-            clipping_rate = torch.tensor(num_grads_clipped / num_eligible_grads, device="cpu")
-        else:
-            clipping_rate = torch.tensor(0.0, device="cpu")
-        all_metrics["clipping_rate"] = clipping_rate
+        if collect_param_metrics:
+            if num_eligible_grads > 0:
+                clipping_rate = torch.tensor(num_grads_clipped / num_eligible_grads, device="cpu")
+            else:
+                clipping_rate = torch.tensor(0.0, device="cpu")
+            all_metrics["clipping_rate"] = clipping_rate
 
-        # per_param_norm, clipping_rate and total_grad_norm are computed even with collect_param_metrics set to False
-        # return those values
+        # total_grad_norm is computed at all steps, even when collect_param_metrics is set to False
         return all_metrics
 
     @torch.no_grad()
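The substance of the change: the clipping_rate CPU tensor (and, via the pass-through argument, per-parameter metric collection inside the clipping helpers) is now gated on collect_param_metrics instead of being built unconditionally, so the device-to-host copy that previously ran at every clipping step only happens on steps where metrics are actually collected. Below is a minimal, self-contained sketch of the general pattern, assuming a CUDA-style lazy execution model; the function and variable names are illustrative and not taken from olmo/optim.py.

import torch

def clip_grads(grads, max_norm, collect_param_metrics=False):
    # Illustrative only: this is not the repository's Optimizer code.
    # The clipping math stays on the gradients' device, so it only queues
    # device work and never blocks the host.
    total_norm = torch.linalg.vector_norm(
        torch.stack([torch.linalg.vector_norm(g.detach()) for g in grads])
    )
    clip_coef = torch.clamp(max_norm / (total_norm + 1e-6), max=1.0)
    for g in grads:
        g.detach().mul_(clip_coef)

    metrics = {"total_grad_norm": total_norm}  # left as a device tensor
    if collect_param_metrics:
        # Reading a device scalar onto the CPU (float()/.item() or a CPU
        # tensor copy) blocks until queued device work finishes -- the kind
        # of host-device sync the commit moves off the per-step path.
        metrics["clipping_rate"] = torch.tensor(float(clip_coef < 1.0), device="cpu")
    return metrics

Keeping totals as device tensors and deferring any CPU reads to metric-collection steps lets the host keep queuing work between logs, which is the point of the fix.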
1 change: 1 addition & 0 deletions tests/grad_norm_test.py
@@ -99,6 +99,7 @@ def _patch_config(cfg, max_norm):
     cfg.optimizer.learning_rate = 1e-3
     cfg.optimizer.weight_decay = 0.1
     cfg.optimizer.eps = 1e-8
+    cfg.optimizer.metrics_log_interval = 10
     cfg.scheduler.name = "constant"
     cfg.scheduler.units = "steps"
     cfg.scheduler.t_warmup = 100
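The test now sets metrics_log_interval, which (assuming it behaves like OLMo's documented interval for detailed optimizer metrics) controls how often the expensive, sync-inducing metrics are gathered. A hedged sketch of a training loop deriving collect_param_metrics from such an interval, reusing the hypothetical clip_grads helper from the sketch above; none of these names come from OLMo's Trainer:

import torch

torch.manual_seed(0)
model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
metrics_log_interval = 10  # mirrors cfg.optimizer.metrics_log_interval in the test

for step in range(1, 21):
    loss = model(torch.randn(16, 4)).pow(2).mean()
    loss.backward()
    # Pay the host-device sync only every `metrics_log_interval` steps;
    # all other steps clip gradients without blocking the host.
    collect = step % metrics_log_interval == 0
    metrics = clip_grads([p.grad for p in model.parameters()],
                         max_norm=1.0, collect_param_metrics=collect)
    if collect:
        print(step, metrics["clipping_rate"].item())
    optimizer.step()
    optimizer.zero_grad()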
