Add option to skip optim steps for 0 grad params #636

Merged: epwalsh merged 13 commits into main from epwalsh/selective-wd on Jul 9, 2024

Conversation

@epwalsh (Member) commented Jun 28, 2024

#605 should be reviewed and merged first.

This PR adds the ability to skip optimizer updates for the parts of parameters whose gradients are zero, such as the embedding rows for tokens not present in the current batch (assuming no weight tying).
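For context, a minimal sketch of the idea (illustrative only, not the actual `olmo/optim.py` implementation; the function name and signature are made up): wherever the gradient is exactly zero, the weights and the optimizer state for those entries are left untouched, so unused embedding rows are neither decayed nor moved by stale momentum.

```python
import torch


@torch.no_grad()
def selective_adamw_step(p, grad, exp_avg, exp_avg_sq, *, lr, wd, beta1, beta2, eps, step):
    """Sketch of an AdamW-style step that skips entries with zero gradient."""
    # Entries whose gradient is exactly zero (e.g. embedding rows for tokens
    # absent from the current batch) keep their weights and state unchanged.
    mask = (grad != 0).to(p.dtype)

    # Decoupled weight decay, applied only where the gradient is non-zero.
    p.mul_(1 - mask * (lr * wd))

    # Moment updates: the decay factor collapses to 1 where the mask is 0,
    # so inactive entries don't bleed their running estimates toward zero.
    exp_avg.mul_(1 - mask * (1 - beta1)).add_(grad, alpha=1 - beta1)
    exp_avg_sq.mul_(1 - mask * (1 - beta2)).addcmul_(grad, grad, value=1 - beta2)

    # Bias-corrected update, zeroed out for masked entries.
    step_size = lr / (1 - beta1**step)
    denom = (exp_avg_sq / (1 - beta2**step)).sqrt().add_(eps)
    p.add_(-step_size * mask * exp_avg / denom)
```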

epwalsh and others added 9 commits May 28, 2024 13:53
- Adds configuration field `optimizer.record_update_metrics`, which
  defaults to `False`, but when set to `True` will trigger AdamW to
  collect the step-size norm and absolute max for each parameter
  (see the sketch after this commit list).
- Changes the behavior of the Lion optimizer to only record the update cosine
  similarity when `optimizer.record_update_metrics` is `True` in order to be
  consistent with the API.
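Roughly, the per-parameter metric collection looks something like this (a sketch under assumed names; the metric keys and plumbing are illustrative, not the exact `olmo/optim.py` code):

```python
import torch


def collect_update_metrics(metrics: dict, name: str, update: torch.Tensor) -> None:
    # Called only when `record_update_metrics` is enabled, so the extra
    # norm/max reductions stay off the default hot path.
    metrics[f"update/{name}.norm"] = torch.linalg.vector_norm(update).item()
    metrics[f"update/{name}.abs_max"] = update.abs().max().item()
```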
olmo/optim.py Outdated
# Perform step weight decay
mask: Optional[torch.Tensor] = None
if self._selective_updates:
    mask = grad != 0
Collaborator commented:

thought: you could instead do `mask = grad != 0 if self._selective_updates else 1`, and assume the mask is always present in subsequent logic.

Member (author) replied:

good call: 1024122
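For reference, the suggested pattern looks roughly like this (a sketch: with selective updates disabled the mask is just the scalar `1`, which broadcasts through the same expressions, so downstream logic needs no `if mask is not None` branching):

```python
import torch


def make_update_mask(grad: torch.Tensor, selective_updates: bool):
    # Boolean tensor when selective updates are on, the scalar 1 otherwise.
    return grad != 0 if selective_updates else 1


@torch.no_grad()
def apply_weight_decay(p: torch.Tensor, mask, lr: float, wd: float) -> None:
    # Written once for both code paths; the multiplication turns the decay
    # into a no-op wherever the gradient was zero.
    p.mul_(1 - mask * (lr * wd))
```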

olmo/optim.py Outdated
@@ -373,9 +376,12 @@ def __init__(
    super().__init__(params, defaults)
    for group in self.param_groups:
        group["initial_lr"] = group["lr"]
    self._selective_updates = selective_updates
Collaborator commented:

nit: Like in the other PR, this could be moved into the parent class

Member (author) replied:

done: e597e5f

@@ -510,16 +512,20 @@ def step(self, closure=None) -> None:
class AdamW(torch.optim.AdamW, Optimizer):
    def __init__(self, *args, record_update_metrics: bool = False, selective_updates: bool = False, **kwargs):
        super().__init__(*args, **kwargs)
        self._record_step_size = record_update_metrics

        # Need to set these here just like in our base `Optimizer` class since our `Optimizer.__init__`
Collaborator commented:

Any reason we don't call `Optimizer.__init__` too? Because multiple inheritance is complicated?

Member (author) replied:

Yea, this gets messy because our `Optimizer.__init__()` also calls PyTorch's `Optimizer.__init__()`, which would then get called twice here. We could avoid that by not calling `torch.optim.AdamW.__init__()` (via `super().__init__()`), but then we'd have to copy over all the other setup that happens inside `torch.optim.AdamW.__init__()`.
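Hypothetical usage of the resulting class (the import path is inferred from the `olmo/optim.py` diff above; nothing here is confirmed beyond the constructor signature shown in the diff): the two extra keyword arguments are consumed by the wrapper and everything else is forwarded to `torch.optim.AdamW.__init__()`.

```python
import torch
from olmo.optim import AdamW  # import path assumed from the diff's `olmo/optim.py`

model = torch.nn.Embedding(50_000, 128)
optim = AdamW(
    model.parameters(),
    lr=1e-4,
    weight_decay=0.1,
    record_update_metrics=True,  # collect step-size norm / abs max per parameter
    selective_updates=True,      # skip updates where the gradient is exactly zero
)
```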

epwalsh merged commit bc60b8a into main on Jul 9, 2024
12 checks passed
epwalsh deleted the epwalsh/selective-wd branch on July 9, 2024 at 17:57