Add support for GrokAdamW optimizer (#32521)
* add grokadamw

* reformat

* code review feedback, unit test

* reformat

* reformat
ehartford committed Aug 13, 2024
1 parent b5016d5 commit 481e156
Showing 7 changed files with 107 additions and 0 deletions.
51 changes: 51 additions & 0 deletions docs/source/en/trainer.md
@@ -432,6 +432,57 @@ trainer = trl.SFTTrainer(
trainer.train()
```

## GrokAdamW optimizer

The GrokAdamW optimizer is designed to enhance training performance and stability, particularly for models that benefit from grokking signal functions. To use GrokAdamW, first install the optimizer package with `pip install grokadamw`.

<Tip>

GrokAdamW extends AdamW with configurable grokking signal functions and is aimed at models that are slow to generalize (grok), where it can improve training performance and stability.

</Tip>

Below is a simple script to demonstrate how to fine-tune [google/gemma-2b](https://huggingface.co/google/gemma-2b) on the IMDB dataset using the GrokAdamW optimizer:

```python
import torch
import datasets
from transformers import TrainingArguments, AutoTokenizer, AutoModelForCausalLM, DataCollatorForLanguageModeling, Trainer

# Load the IMDB dataset
train_dataset = datasets.load_dataset('imdb', split='train')

# Define the training arguments
args = TrainingArguments(
    output_dir="./test-grokadamw",
    max_steps=1000,
    per_device_train_batch_size=4,
    optim="grokadamw",
    logging_strategy="steps",
    logging_steps=1,
    learning_rate=2e-5,
    save_strategy="no",
    run_name="grokadamw-imdb",
)

# Load the model and tokenizer
model_id = "google/gemma-2b"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True).to(0)

# Tokenize the raw IMDB text so the Trainer receives model inputs
def tokenize(batch):
    return tokenizer(batch["text"], truncation=True, max_length=512)

train_dataset = train_dataset.map(tokenize, batched=True, remove_columns=train_dataset.column_names)

# Initialize the Trainer with a causal-LM collator that pads batches and builds the labels
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
)

# Train the model
trainer.train()
```

Setting `optim="grokadamw"` in [`TrainingArguments`] is the only GrokAdamW-specific step; the rest of the script is a standard [`Trainer`] fine-tuning setup for `google/gemma-2b` on IMDB.
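
The GrokAdamW hyperparameter defaults used by [`Trainer`] (`alpha_init=0.98`, `lamb=2.0`, `gamma=0.1`, `grokking_signal_decay_rate=0.1`, `gradient_clipping=1.0`) can be overridden through the `optim_args` field of [`TrainingArguments`], which is parsed as comma-separated `key=value` pairs. A minimal sketch with purely illustrative values, not tuned recommendations:

```python
from transformers import TrainingArguments

# Illustrative overrides only; any key you omit keeps the default from the Trainer integration
args = TrainingArguments(
    output_dir="./test-grokadamw",
    optim="grokadamw",
    optim_args="alpha_init=0.95,lamb=1.5,gamma=0.2",
    learning_rate=2e-5,
    max_steps=1000,
)
```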

## Accelerate and Trainer

The [`Trainer`] class is powered by [Accelerate](https://hf.co/docs/accelerate), a library for easily training PyTorch models in distributed environments with support for integrations such as [FullyShardedDataParallel (FSDP)](https://pytorch.org/blog/introducing-pytorch-fully-sharded-data-parallel-api/) and [DeepSpeed](https://www.deepspeed.ai/).
8 changes: 8 additions & 0 deletions src/transformers/testing_utils.py
@@ -76,6 +76,7 @@
    is_g2p_en_available,
    is_galore_torch_available,
    is_gguf_available,
    is_grokadamw_available,
    is_ipex_available,
    is_jieba_available,
    is_jinja_available,
@@ -358,6 +359,13 @@ def require_lomo(test_case):
    return unittest.skipUnless(is_lomo_available(), "test requires LOMO")(test_case)


def require_grokadamw(test_case):
    """
    Decorator marking a test that requires GrokAdamW. These tests are skipped when GrokAdamW isn't installed.
    """
    return unittest.skipUnless(is_grokadamw_available(), "test requires GrokAdamW")(test_case)


def require_cv2(test_case):
    """
    Decorator marking a test that requires OpenCV.
18 changes: 18 additions & 0 deletions src/transformers/trainer.py
@@ -153,6 +153,7 @@
    is_bitsandbytes_available,
    is_datasets_available,
    is_galore_torch_available,
    is_grokadamw_available,
    is_in_notebook,
    is_ipex_available,
    is_lomo_available,
@@ -1442,6 +1443,23 @@ def optimizer_hook(param):
                optimizer_cls = Lomo

            optimizer_kwargs.update({"model": model})
        elif args.optim == OptimizerNames.GROKADAMW:
            if not is_grokadamw_available():
                raise ValueError("Please install grokadamw with `pip install grokadamw`")

            from grokadamw import GrokAdamW

            optimizer_cls = GrokAdamW
            optimizer_kwargs.update(
                {
                    "alpha_init": float(optim_args.get("alpha_init", 0.98)),
                    "lamb": float(optim_args.get("lamb", 2.0)),
                    "gamma": float(optim_args.get("gamma", 0.1)),
                    "grokking_signal_decay_rate": float(optim_args.get("grokking_signal_decay_rate", 0.1)),
                    "gradient_clipping": float(optim_args.get("gradient_clipping", 1.0)),
                }
            )

        else:
            raise ValueError(f"Trainer cannot instantiate unsupported optimizer: {args.optim}")
        return optimizer_cls, optimizer_kwargs
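
Note that every value above is coerced to `float`, so callable hyperparameters such as a custom grokking signal function cannot be passed through `optim_args`. In that case the optimizer can be constructed directly and handed to `Trainer` via its `optimizers` argument, bypassing this factory. A hedged sketch, assuming `GrokAdamW` follows the usual `torch.optim` convention of taking the parameters and `lr` first (the same convention the integration above relies on) and that `train_dataset` is an already tokenized dataset such as the IMDB example in the docs:

```python
from grokadamw import GrokAdamW
from transformers import AutoModelForCausalLM, Trainer, TrainingArguments

model = AutoModelForCausalLM.from_pretrained("google/gemma-2b")

# Assumption: GrokAdamW accepts params and lr like a standard torch optimizer,
# plus the keyword arguments exposed in the hunk above.
optimizer = GrokAdamW(
    model.parameters(),
    lr=2e-5,
    alpha_init=0.98,
    lamb=2.0,
    gamma=0.1,
    grokking_signal_decay_rate=0.1,
    gradient_clipping=1.0,
)

args = TrainingArguments(output_dir="./grokadamw-manual", max_steps=1000)
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_dataset,  # an already tokenized dataset (see the docs example above)
    optimizers=(optimizer, None),  # Trainer builds its default LR scheduler around the given optimizer
)
```
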
1 change: 1 addition & 0 deletions src/transformers/training_args.py
@@ -175,6 +175,7 @@ class OptimizerNames(ExplicitEnum):
    GALORE_ADAFACTOR_LAYERWISE = "galore_adafactor_layerwise"
    LOMO = "lomo"
    ADALOMO = "adalomo"
    GROKADAMW = "grokadamw"


# Sometimes users will pass in a `str` repr of a dict in the CLI
1 change: 1 addition & 0 deletions src/transformers/utils/__init__.py
@@ -137,6 +137,7 @@
    is_g2p_en_available,
    is_galore_torch_available,
    is_gguf_available,
    is_grokadamw_available,
    is_hqq_available,
    is_in_notebook,
    is_ipex_available,
5 changes: 5 additions & 0 deletions src/transformers/utils/import_utils.py
@@ -101,6 +101,7 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[
_fbgemm_gpu_available = _is_package_available("fbgemm_gpu")
_galore_torch_available = _is_package_available("galore_torch")
_lomo_available = _is_package_available("lomo_optim")
_grokadamw_available = _is_package_available("grokadamw")
# `importlib.metadata.version` doesn't work with `bs4` but `beautifulsoup4`. For `importlib.util.find_spec`, reversed.
_bs4_available = importlib.util.find_spec("bs4") is not None
_coloredlogs_available = _is_package_available("coloredlogs")
@@ -353,6 +354,10 @@ def is_lomo_available():
    return _lomo_available


def is_grokadamw_available():
    return _grokadamw_available


def is_pyctcdecode_available():
    return _pyctcdecode_available

23 changes: 23 additions & 0 deletions tests/trainer/test_trainer.py
@@ -62,6 +62,7 @@
    require_bitsandbytes,
    require_deepspeed,
    require_galore_torch,
    require_grokadamw,
    require_intel_extension_for_pytorch,
    require_lomo,
    require_optuna,
@@ -1366,6 +1367,28 @@ def test_adalomo(self):
        # Check this works
        _ = trainer.train()

    @require_grokadamw
    @require_torch_gpu
    def test_grokadamw(self):
        config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
        tiny_llama = LlamaForCausalLM(config)
        x = torch.randint(0, 100, (128,))
        train_dataset = RepeatDataset(x)

        with tempfile.TemporaryDirectory() as tmpdir:
            # Trainer without inf/nan filter
            args = TrainingArguments(
                tmpdir,
                learning_rate=2e-5,
                logging_steps=5,
                optim="grokadamw",
                max_steps=20,
            )
            trainer = Trainer(tiny_llama, args, train_dataset=train_dataset)

            # Check this works
            _ = trainer.train()

    def test_galore_matched_modules(self):
        regex_patterns = [r".*.attn.*", r".*.mlp.*"]

