Add support for GrokAdamW optimizer #32521

Merged · 8 commits · Aug 13, 2024
51 changes: 51 additions & 0 deletions docs/source/en/trainer.md
@@ -432,6 +432,57 @@ trainer = trl.SFTTrainer(
trainer.train()
```

## GrokAdamW optimizer

The GrokAdamW optimizer is designed to enhance training performance and stability, particularly for models that benefit from grokking signal functions. To use GrokAdamW, first install the optimizer package with `pip install grokadamw`.

<Tip>

GrokAdamW is most useful for models that need the extra control its grokking signal functions provide in order to train stably and generalize well, beyond what standard AdamW offers.

</Tip>

Below is a simple script to demonstrate how to fine-tune [google/gemma-2b](https://huggingface.co/google/gemma-2b) on the IMDB dataset using the GrokAdamW optimizer:

```python
import torch
import datasets
from transformers import TrainingArguments, AutoTokenizer, AutoModelForCausalLM, DataCollatorForLanguageModeling, Trainer

# Load the IMDB dataset
train_dataset = datasets.load_dataset('imdb', split='train')

# Define the training arguments
args = TrainingArguments(
    output_dir="./test-grokadamw",
    max_steps=1000,
    per_device_train_batch_size=4,
    optim="grokadamw",
    logging_strategy="steps",
    logging_steps=1,
    learning_rate=2e-5,
    save_strategy="no",
    run_name="grokadamw-imdb",
)

# Load the model and tokenizer
model_id = "google/gemma-2b"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True).to(0)

# Tokenize the raw text so the Trainer receives input_ids instead of strings
def tokenize(batch):
    return tokenizer(batch["text"], truncation=True, max_length=512)

train_dataset = train_dataset.map(tokenize, batched=True, remove_columns=train_dataset.column_names)

# Initialize the Trainer with a causal-LM collator so labels are built from input_ids
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
)

# Train the model
trainer.train()
```

The key piece is `optim="grokadamw"` in `TrainingArguments`: the [`Trainer`] then instantiates GrokAdamW instead of the default optimizer, and the tokenized dataset is passed in as usual.
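
GrokAdamW's hyperparameters can also be overridden without touching code, through the `optim_args` string of `TrainingArguments` (the same mechanism other optimizer integrations use). A minimal sketch: the keys mirror what this PR's `Trainer` change parses (`alpha_init`, `lamb`, `gamma`, `grokking_signal_decay_rate`, `gradient_clipping`), while the values shown are illustrative only, not recommendations:

```python
from transformers import TrainingArguments

# Minimal sketch: override GrokAdamW defaults via the optim_args string.
# Keys match what Trainer reads for optim="grokadamw"; values here are illustrative only.
args = TrainingArguments(
    output_dir="./test-grokadamw",
    optim="grokadamw",
    optim_args="alpha_init=0.95,lamb=1.5,gamma=0.1,grokking_signal_decay_rate=0.1,gradient_clipping=1.0",
    learning_rate=2e-5,
    per_device_train_batch_size=4,
    max_steps=1000,
)
```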

## Accelerate and Trainer

The [`Trainer`] class is powered by [Accelerate](https://hf.co/docs/accelerate), a library for easily training PyTorch models in distributed environments with support for integrations such as [FullyShardedDataParallel (FSDP)](https://pytorch.org/blog/introducing-pytorch-fully-sharded-data-parallel-api/) and [DeepSpeed](https://www.deepspeed.ai/).
8 changes: 8 additions & 0 deletions src/transformers/testing_utils.py
@@ -76,6 +76,7 @@
    is_g2p_en_available,
    is_galore_torch_available,
    is_gguf_available,
    is_grokadamw_available,
    is_ipex_available,
    is_jieba_available,
    is_jinja_available,
@@ -358,6 +359,13 @@ def require_lomo(test_case):
    return unittest.skipUnless(is_lomo_available(), "test requires LOMO")(test_case)


def require_grokadamw(test_case):
    """
    Decorator marking a test that requires GrokAdamW. These tests are skipped when GrokAdamW isn't installed.
    """
    return unittest.skipUnless(is_grokadamw_available(), "test requires GrokAdamW")(test_case)


def require_cv2(test_case):
    """
    Decorator marking a test that requires OpenCV.
18 changes: 18 additions & 0 deletions src/transformers/trainer.py
@@ -153,6 +153,7 @@
    is_bitsandbytes_available,
    is_datasets_available,
    is_galore_torch_available,
    is_grokadamw_available,
    is_in_notebook,
    is_ipex_available,
    is_lomo_available,
@@ -1442,6 +1443,23 @@ def optimizer_hook(param):
                optimizer_cls = Lomo

            optimizer_kwargs.update({"model": model})
        elif args.optim == OptimizerNames.GROKADAMW:
            if not is_grokadamw_available():
                raise ValueError("Please install grokadamw with `pip install grokadamw`")

            from grokadamw import GrokAdamW

            optimizer_cls = GrokAdamW
            optimizer_kwargs.update(
                {
                    "alpha_init": float(optim_args.get("alpha_init", 0.98)),
                    "lamb": float(optim_args.get("lamb", 2.0)),
                    "gamma": float(optim_args.get("gamma", 0.1)),
                    "grokking_signal_decay_rate": float(optim_args.get("grokking_signal_decay_rate", 0.1)),
                    "gradient_clipping": float(optim_args.get("gradient_clipping", 1.0)),
                }
            )

        else:
            raise ValueError(f"Trainer cannot instantiate unsupported optimizer: {args.optim}")
        return optimizer_cls, optimizer_kwargs
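
For context, this hunk sits in the `Trainer` helper that maps `args.optim` to an optimizer class and its keyword arguments (`Trainer.get_optimizer_cls_and_kwargs` in current `transformers`; that name is inferred from the surrounding code rather than shown in this diff). A minimal sketch of checking what it resolves for `optim="grokadamw"`, assuming `grokadamw` is installed:

```python
from transformers import Trainer, TrainingArguments

# Minimal sketch (assumes `pip install grokadamw`): resolve the optimizer class and kwargs
# the same way Trainer does when it builds the optimizer.
args = TrainingArguments(output_dir="./tmp-grokadamw", optim="grokadamw", learning_rate=2e-5)
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(args)

print(optimizer_cls)     # GrokAdamW class from the grokadamw package
print(optimizer_kwargs)  # should include lr plus alpha_init, lamb, gamma, grokking_signal_decay_rate, gradient_clipping
```
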
1 change: 1 addition & 0 deletions src/transformers/training_args.py
@@ -175,6 +175,7 @@ class OptimizerNames(ExplicitEnum):
    GALORE_ADAFACTOR_LAYERWISE = "galore_adafactor_layerwise"
    LOMO = "lomo"
    ADALOMO = "adalomo"
    GROKADAMW = "grokadamw"


# Sometimes users will pass in a `str` repr of a dict in the CLI
1 change: 1 addition & 0 deletions src/transformers/utils/__init__.py
@@ -137,6 +137,7 @@
    is_g2p_en_available,
    is_galore_torch_available,
    is_gguf_available,
    is_grokadamw_available,
    is_hqq_available,
    is_in_notebook,
    is_ipex_available,
5 changes: 5 additions & 0 deletions src/transformers/utils/import_utils.py
@@ -101,6 +101,7 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[
_fbgemm_gpu_available = _is_package_available("fbgemm_gpu")
_galore_torch_available = _is_package_available("galore_torch")
_lomo_available = _is_package_available("lomo_optim")
_grokadamw_available = _is_package_available("grokadamw")
# `importlib.metadata.version` doesn't work with `bs4` but `beautifulsoup4`. For `importlib.util.find_spec`, reversed.
_bs4_available = importlib.util.find_spec("bs4") is not None
_coloredlogs_available = _is_package_available("coloredlogs")
@@ -353,6 +354,10 @@ def is_lomo_available():
    return _lomo_available


def is_grokadamw_available():
    return _grokadamw_available


def is_pyctcdecode_available():
    return _pyctcdecode_available

23 changes: 23 additions & 0 deletions tests/trainer/test_trainer.py
@@ -62,6 +62,7 @@
    require_bitsandbytes,
    require_deepspeed,
    require_galore_torch,
    require_grokadamw,
    require_intel_extension_for_pytorch,
    require_lomo,
    require_optuna,
@@ -1366,6 +1367,28 @@ def test_adalomo(self):
        # Check this works
        _ = trainer.train()

    @require_grokadamw
    @require_torch_gpu
    def test_grokadamw(self):
        config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
        tiny_llama = LlamaForCausalLM(config)
        x = torch.randint(0, 100, (128,))
        train_dataset = RepeatDataset(x)

        with tempfile.TemporaryDirectory() as tmpdir:
            # Trainer without inf/nan filter
            args = TrainingArguments(
                tmpdir,
                learning_rate=2e-5,
                logging_steps=5,
                optim="grokadamw",
                max_steps=20,
            )
            trainer = Trainer(tiny_llama, args, train_dataset=train_dataset)

            # Check this works
            _ = trainer.train()

    def test_galore_matched_modules(self):
        regex_patterns = [r".*.attn.*", r".*.mlp.*"]
