Add support for GrokAdamW optimizer #32521

Merged · 8 commits into huggingface:main · Aug 13, 2024

Conversation

ehartford (Contributor)

What does this PR do?

Add support for GrokAdamW optimizer

This PR adds support for the GrokAdamW optimizer to the transformers library.

Changes Introduced

  • Integrated the GrokAdamW optimizer into the Trainer class.
  • Added error handling to prompt users to install the grokadamw package if not already installed.

Motivation

The GrokAdamW optimizer enhances training performance and stability for certain models, providing users with more optimization options.

Dependencies

  • grokadamw: Users need to install this package via pip install grokadamw.

Code Changes

  • trainer.py: Added a new conditional block to import and configure the GrokAdamW optimizer.
elif args.optim == "grokadamw":
    try:
        from grokadamw import GrokAdamW

        optimizer_cls = GrokAdamW
        optimizer_kwargs.update(
            {
                "alpha_init": float(optim_args.get("alpha_init", 0.98)),
                "lamb": float(optim_args.get("lamb", 2.0)),
                "gamma": float(optim_args.get("gamma", 0.1)),
                "grokking_signal_decay_rate": float(optim_args.get("grokking_signal_decay_rate", 0.1)),
                "gradient_clipping": float(optim_args.get("gradient_clipping", 1.0)),
            }
        )
    except ImportError:
        raise ValueError("Please install grokadamw with `pip install grokadamw`")
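
On the user side, selecting the optimizer then only requires setting optim="grokadamw" on TrainingArguments. A minimal sketch, where model, train_dataset, and the output directory are placeholders for your own objects:

from transformers import Trainer, TrainingArguments

# Sketch: model and train_dataset are assumed to be defined elsewhere.
training_args = TrainingArguments(
    output_dir="out",
    optim="grokadamw",  # requires `pip install grokadamw`
    learning_rate=2e-5,
    per_device_train_batch_size=8,
)

trainer = Trainer(model=model, args=training_args, train_dataset=train_dataset)
trainer.train()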

Testing

  • Verified the integration with a test script to ensure the optimizer works as expected.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline, Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@muellerzr and @SunMarc

@amyeroberts (Collaborator) left a comment

Thanks for adding this! Is there an associated issue to add this as a feature request?

I'll let @muellerzr and @SunMarc chip in on whether this is something we want to add. One thing to note is that this implementation is under the MIT license. @muellerzr do you know how we typically handle different licensing for integrations like this?

As commented - tests would need to be added - the LOMO PR is a good reference here
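
For reference, such a test might look roughly like the sketch below, modeled loosely on the LOMO optimizer tests (in the real suite it would also be gated on grokadamw being installed); the tiny config, toy dataset, and test name here are illustrative assumptions, not the PR's actual test:

import tempfile

import torch
from torch.utils.data import Dataset
from transformers import LlamaConfig, LlamaForCausalLM, Trainer, TrainingArguments

class TinyDataset(Dataset):
    # Illustrative toy dataset: random token ids reused as labels for causal LM training.
    def __init__(self, length=64, seq_len=16, vocab_size=100):
        self.x = torch.randint(0, vocab_size, (length, seq_len))

    def __len__(self):
        return len(self.x)

    def __getitem__(self, i):
        return {"input_ids": self.x[i], "labels": self.x[i]}

def test_grokadamw_smoke():
    # Tiny randomly initialized model so the run finishes quickly on CPU.
    config = LlamaConfig(
        vocab_size=100, hidden_size=32, intermediate_size=64, num_hidden_layers=2, num_attention_heads=4
    )
    model = LlamaForCausalLM(config)
    with tempfile.TemporaryDirectory() as tmpdir:
        args = TrainingArguments(tmpdir, optim="grokadamw", max_steps=5, learning_rate=1e-3, report_to="none")
        trainer = Trainer(model=model, args=args, train_dataset=TinyDataset())
        trainer.train()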

src/transformers/trainer.py: inline review comment (outdated, resolved)
@muellerzr (Contributor)

@amyeroberts I do not; best we wait until @LysandreJik is back for that question! (As long as it's not too problematic, I don't see an issue with adding more optimizers.)

@ehartford (Contributor, Author)

There is not an associated issue. I am hoping to add the optimizer in order to make it more accessible to users.

I have updated the license to Apache 2.0 to accommodate your license.

I added the tests as requested.

@muellerzr (Contributor) left a comment

Very nice!

@ehartford (Contributor, Author)

What do you think about the grokking functions? As it stands, I don't see how users can easily pass those in, so I just use the default grokking function (which, honestly, is what 99.9% of people will do). Otherwise they would need to pass the function as a string that would then be eval'd, which is ugly.

But this seems like a good compromise for simplicity's sake, and if users want custom grokking functions they can use the optimizer manually.
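
For anyone who does want a custom grokking function, a rough sketch of the manual route is to construct GrokAdamW directly and hand it to the Trainer through its optimizers argument. Here model and train_dataset are placeholders, and the grokking_signal_fns keyword is an assumption about the grokadamw API that may differ between versions:

from grokadamw import GrokAdamW
from transformers import Trainer, TrainingArguments

# Hypothetical custom grokking signal: relative gap between validation and training loss.
def grokking_signal_fn(training_loss: float, validation_loss: float) -> float:
    if training_loss == 0:
        return 0.0
    return (validation_loss - training_loss) / training_loss

# model and train_dataset are placeholders; the grokking_signal_fns keyword is assumed.
optimizer = GrokAdamW(
    model.parameters(),
    lr=2e-5,
    alpha_init=0.98,
    lamb=2.0,
    gamma=0.1,
    grokking_signal_fns=[grokking_signal_fn],
    grokking_signal_decay_rate=0.1,
    gradient_clipping=1.0,
)

trainer = Trainer(
    model=model,
    args=TrainingArguments(output_dir="out"),
    train_dataset=train_dataset,
    optimizers=(optimizer, None),  # Trainer creates its default LR scheduler when None is given
)
trainer.train()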

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@muellerzr (Contributor)

@ehartford currently they'd need to pass them in via optim_args, but since many optimizers are getting complex and need special TrainingArguments values, I'd be open to a follow-up PR that refactors things to include optim_args and optim_kwargs as part of TrainingArguments, which would then get passed to the optimizer init.
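
For reference, the scalar hyperparameters can already be overridden today through the optim_args string, which the Trainer parses into comma-separated key=value pairs. A sketch using the keys from the trainer.py snippet above:

from transformers import TrainingArguments

# Sketch: the key=value pairs are parsed by the Trainer and forwarded to GrokAdamW.
training_args = TrainingArguments(
    output_dir="out",
    optim="grokadamw",
    optim_args="alpha_init=0.95,lamb=1.5,gamma=0.05,grokking_signal_decay_rate=0.2,gradient_clipping=0.5",
)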

@ehartford (Contributor, Author)

Hello,
May I ask, do you need anything else from me?
Is another review required?

@amyeroberts (Collaborator) left a comment

Looks great! Thanks for adding and iterating on this ❤️

@amyeroberts merged commit 481e156 into huggingface:main on Aug 13, 2024
24 checks passed
@Shas3011

@ehartford running this I get hit with the error below. Can you check?

Thanks for the contribution anyway.

  File "/home/shas/qbFinetuning/processor/LLM/finetunerbasemodel.py", line 69, in run_full_finetuningjob
    self._train()
  File "/home/shas/qbFinetuning/processor/LLM/error_utils.py", line 82, in wrapper_exit
    return func(*args, **kwargs)
  File "/home/shas/qbFinetuning/processor/LLM/error_utils.py", line 55, in wrapper_retry
    return func(*args, **kwargs)
  File "/home/shas/qbFinetuning/processor/LLM/finetunerbasemodel.py", line 54, in _train
    return self.train(*args,**kwargs)
  File "/home/shas/.local/lib/python3.10/site-packages/trl/trainer/sft_trainer.py", line 451, in train
    output = super().train(*args, **kwargs)
  File "/home/shas/miniforge3/lib/python3.10/site-packages/transformers/trainer.py", line 1963, in train
    return inner_training_loop(
  File "/home/shas/.local/lib/python3.10/site-packages/accelerate/utils/memory.py", line 146, in decorator
    return function(batch_size, *args, **kwargs)
  File "/home/shas/miniforge3/lib/python3.10/site-packages/transformers/trainer.py", line 2366, in _inner_training_loop
    self.optimizer.step()
  File "/home/shas/.local/lib/python3.10/site-packages/torch/optim/lr_scheduler.py", line 130, in wrapper
    return func.__get__(opt, opt.__class__)(*args, **kwargs)
  File "/home/shas/.local/lib/python3.10/site-packages/accelerate/optimizer.py", line 170, in step
    self.optimizer.step(closure)
  File "/home/shas/.local/lib/python3.10/site-packages/torch/optim/lr_scheduler.py", line 130, in wrapper
    return func.__get__(opt, opt.__class__)(*args, **kwargs)
  File "/home/shas/.local/lib/python3.10/site-packages/torch/optim/optimizer.py", line 484, in wrapper
    out = func(*args, **kwargs)
  File "/home/shas/.local/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/home/shas/.local/lib/python3.10/site-packages/grokadamw/grokadamw.py", line 53, in step
    return self._step_impl(closure, use_amp=False)
  File "/home/shas/.local/lib/python3.10/site-packages/grokadamw/grokadamw.py", line 92, in _step_impl
    _apply_updates()
  File "/home/shas/.local/lib/python3.10/site-packages/grokadamw/grokadamw.py", line 86, in _apply_updates
    self._update_group(group, params_with_grad, grads, grokking_signal)
  File "/home/shas/.local/lib/python3.10/site-packages/grokadamw/grokadamw.py", line 132, in _update_group
    state = group['state'][p]
KeyError: 'state'

@ehartford (Contributor, Author)

Please update GrokAdamW
I have version 0.1.2

@Shas3011

Thanks a ton, working now!
