Add option to skip optim steps for 0 grad params #636
Conversation
- Adds configuration field `optimizer.record_update_metrics`, which defaults to `False`, but when set to `True` will trigger AdamW to collect the step size norm and absolute max for each parameter.
- Changes the behavior of the Lion optimizer to only record the update cosine similarity when `optimizer.record_update_metrics` is `True`, in order to be consistent with the API.
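For reference, here is a minimal sketch of how the recorded metrics can be derived from AdamW's optimizer state. The helper name and signature are illustrative only, not the PR's actual implementation:

```python
import torch

def adamw_step_metrics(
    exp_avg: torch.Tensor,
    exp_avg_sq: torch.Tensor,
    step: int,
    lr: float,
    betas=(0.9, 0.999),
    eps: float = 1e-8,
):
    """Hypothetical helper: reconstruct the AdamW update for one parameter
    from optimizer state and return its L2 norm and absolute max."""
    beta1, beta2 = betas
    bias_correction1 = 1 - beta1 ** step
    bias_correction2 = 1 - beta2 ** step
    denom = exp_avg_sq.sqrt() / bias_correction2 ** 0.5 + eps
    update = -(lr / bias_correction1) * exp_avg / denom
    return torch.linalg.vector_norm(update), update.abs().max()
```

Keeping the default at `False` avoids this extra per-parameter work on the normal training path.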
olmo/optim.py
Outdated
# Perform step weight decay
mask: Optional[torch.Tensor] = None
if self._selective_updates:
    mask = grad != 0
thought: you could instead do `mask = grad != 0 if self._selective_updates else 1`, and assume the mask is always present in subsequent logic.
good call: 1024122
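For illustration, a toy sketch of the suggested pattern (made-up function and parameter names, not the actual `olmo/optim.py` code), where `1` acts as a broadcastable no-op mask:

```python
import torch

def lion_like_step(p, grad, exp_avg, lr, wd, betas=(0.9, 0.99), selective_updates=False):
    """Sketch of a Lion-style update where the mask is always defined."""
    beta1, beta2 = betas
    mask = grad != 0 if selective_updates else 1  # 1 broadcasts as "update everything"

    # Decoupled weight decay, applied only where the mask is on.
    p.data.mul_(1 - mask * (lr * wd))

    # Sign-based update, gated by the same mask.
    update = exp_avg.clone().mul_(beta1).add_(grad, alpha=1 - beta1).sign_()
    p.data.add_(mask * update, alpha=-lr)

    # Momentum update; where the mask is 0 the gradient is 0 too, so the add is a no-op there.
    exp_avg.mul_(1 - (1 - beta2) * mask).add_(grad, alpha=1 - beta2)
```

Because `1` broadcasts against any tensor, none of the downstream lines need an `if mask is not None` branch.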
olmo/optim.py
Outdated
@@ -373,9 +376,12 @@ def __init__(
        super().__init__(params, defaults)
        for group in self.param_groups:
            group["initial_lr"] = group["lr"]
        self._selective_updates = selective_updates
nit: Like in the other PR, this could be moved into the parent class
done: e597e5f
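A rough sketch of what hoisting the flags into the shared base class could look like (illustrative only; the real base class in `olmo/optim.py` is more involved and its signature may differ):

```python
import torch

class Optimizer(torch.optim.Optimizer):
    """Illustrative base class: the flags shared by AdamW and Lion live here,
    so each concrete optimizer doesn't repeat the assignments."""

    def __init__(self, *args, record_update_metrics: bool = False,
                 selective_updates: bool = False, **kwargs):
        self._record_update_metrics = record_update_metrics
        self._selective_updates = selective_updates
        super().__init__(*args, **kwargs)
```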
@@ -510,16 +512,20 @@ def step(self, closure=None) -> None:
class AdamW(torch.optim.AdamW, Optimizer):
    def __init__(self, *args, record_update_metrics: bool = False, selective_updates: bool = False, **kwargs):
        super().__init__(*args, **kwargs)
        self._record_step_size = record_update_metrics

        # Need to set these here just like in our base `Optimizer` class since our `Optimizer.__init__`
Any reason we don't call `Optimizer.__init__` too? Because multiple inheritance is complicated?
Yeah, this gets messy because our `Optimizer.__init__()` also calls PyTorch's `Optimizer.__init__()`, which would then get called twice here unless we didn't call `torch.optim.AdamW.__init__()` (via `super().__init__()`), but then we'd have to copy over all the other code that happens within `torch.optim.AdamW.__init__()`.
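A toy example (stand-in class names, no real torch classes) of how explicitly invoking both parent initializers ends up running the shared base `__init__` twice:

```python
class TorchOptimizer:                 # stand-in for torch.optim.Optimizer
    def __init__(self):
        print("torch.optim.Optimizer.__init__")

class TorchAdamW(TorchOptimizer):     # stand-in for torch.optim.AdamW
    def __init__(self):
        super().__init__()

class OurOptimizer(TorchOptimizer):   # stand-in for the repo's Optimizer base class
    def __init__(self):
        super().__init__()

class AdamW(TorchAdamW, OurOptimizer):
    def __init__(self):
        TorchAdamW.__init__(self)     # roughly what super().__init__() reaches in the PR
        OurOptimizer.__init__(self)   # adding this call re-initializes the shared base

AdamW()  # prints the base-class message twice
```

With cooperative `super()`, the first explicit call already walks the rest of the MRO, so the second explicit call runs `OurOptimizer` and the shared base a second time.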
#605 should be reviewed and merged first.
This PR adds the ability to skip optimizer updates for the parts of parameters that have zero gradients, such as the embedding rows for tokens not present in the current batch (assuming no weight tying).
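To make the intent concrete, here is a toy demonstration (plain PyTorch, no OLMo APIs) of which embedding entries carry non-zero gradients and would therefore still be updated:

```python
import torch

emb = torch.nn.Embedding(10, 4)
tokens = torch.tensor([1, 3, 3])   # only rows 1 and 3 appear in this "batch"
emb(tokens).sum().backward()

grad = emb.weight.grad
mask = grad != 0                   # per-element mask of entries the batch touched
print(mask.any(dim=-1))            # True only at indices 1 and 3
```

With `selective_updates` enabled, weight decay, the moment updates, and the parameter step are applied only where this mask is true, so the untouched embedding rows stay exactly as they were.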