From 2e688188c3bfdce378a4ce021d8aa6ef2d99df68 Mon Sep 17 00:00:00 2001
From: Tim Dettmers
Date: Wed, 24 May 2023 03:53:28 -0700
Subject: [PATCH] Paged Optimizer + Lion Optimizer for Trainer (#23217)

* Added lion and paged optimizers and made original tests pass.

* Added tests for paged and lion optimizers.

* Added and fixed optimizer tests.

* Style and quality checks.

---------

Co-authored-by: younesbelkada
---
 src/transformers/trainer.py       |  32 +++++
 src/transformers/training_args.py |   9 +-
 tests/trainer/test_trainer.py     | 193 +++++++++++++++++++++++++++++-
 3 files changed, 230 insertions(+), 4 deletions(-)

diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index e708de37015282..72fcd34d7ff723 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -1170,6 +1170,38 @@ def get_optimizer_cls_and_kwargs(args: TrainingArguments) -> Tuple[Any, Any]:
                 optimizer_kwargs.update(adam_kwargs)
             except ImportError:
                 raise ValueError("Trainer tried to instantiate apex FusedAdam but apex is not installed!")
+        elif args.optim in [
+            OptimizerNames.ADAMW_BNB,
+            OptimizerNames.ADAMW_8BIT,
+            OptimizerNames.PAGED_ADAMW,
+            OptimizerNames.PAGED_ADAMW_8BIT,
+            OptimizerNames.LION,
+            OptimizerNames.LION_8BIT,
+            OptimizerNames.PAGED_LION,
+            OptimizerNames.PAGED_LION_8BIT,
+        ]:
+            try:
+                from bitsandbytes.optim import AdamW, Lion
+
+                is_paged = False
+                optim_bits = 32
+                optimizer_cls = None
+                additional_optim_kwargs = adam_kwargs
+                if "paged" in args.optim:
+                    is_paged = True
+                if "8bit" in args.optim:
+                    optim_bits = 8
+                if "adam" in args.optim:
+                    optimizer_cls = AdamW
+                elif "lion" in args.optim:
+                    optimizer_cls = Lion
+                    additional_optim_kwargs = {"betas": (args.adam_beta1, args.adam_beta2)}
+
+                bnb_kwargs = {"is_paged": is_paged, "optim_bits": optim_bits}
+                optimizer_kwargs.update(additional_optim_kwargs)
+                optimizer_kwargs.update(bnb_kwargs)
+            except ImportError:
+                raise ValueError("Trainer tried to instantiate bnb optimizer but bnb is not installed!")
         elif args.optim == OptimizerNames.ADAMW_BNB:
             try:
                 from bitsandbytes.optim import Adam8bit
diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py
index e2df8dc333b042..57aca25712dea4 100644
--- a/src/transformers/training_args.py
+++ b/src/transformers/training_args.py
@@ -139,10 +139,17 @@ class OptimizerNames(ExplicitEnum):
     ADAMW_TORCH_XLA = "adamw_torch_xla"
     ADAMW_APEX_FUSED = "adamw_apex_fused"
     ADAFACTOR = "adafactor"
-    ADAMW_BNB = "adamw_bnb_8bit"
     ADAMW_ANYPRECISION = "adamw_anyprecision"
     SGD = "sgd"
     ADAGRAD = "adagrad"
+    ADAMW_BNB = "adamw_bnb_8bit"
+    ADAMW_8BIT = "adamw_8bit"  # just an alias for adamw_bnb_8bit
+    LION_8BIT = "lion_8bit"
+    LION = "lion_32bit"
+    PAGED_ADAMW = "paged_adamw_32bit"
+    PAGED_ADAMW_8BIT = "paged_adamw_8bit"
+    PAGED_LION = "paged_lion_32bit"
+    PAGED_LION_8BIT = "paged_lion_8bit"
 
 
 @dataclass
diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index 63a12635880419..95b92d5295d024 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -2474,6 +2474,11 @@ def hp_name(trial):
         "lr": TrainingArguments.learning_rate,
     }
 
+    default_lion_kwargs = {
+        "betas": (TrainingArguments.adam_beta1, TrainingArguments.adam_beta2),
+        "lr": TrainingArguments.learning_rate,
+    }
+
     default_anyprecision_kwargs = {
         "use_kahan_summation": False,
         "momentum_dtype": torch.float32,
@@ -2525,11 +2530,59 @@ def hp_name(trial):
 
         optim_test_params.append(
             (
                 TrainingArguments(optim=OptimizerNames.ADAMW_BNB, output_dir="None"),
-                bnb.optim.Adam8bit,
+                bnb.optim.AdamW,
                 default_adam_kwargs,
             )
         )
+
+        optim_test_params.append(
+            (
+                TrainingArguments(optim=OptimizerNames.ADAMW_8BIT, output_dir="None"),
+                bnb.optim.AdamW,
+                default_adam_kwargs,
+            )
+        )
+
+        optim_test_params.append(
+            (
+                TrainingArguments(optim=OptimizerNames.PAGED_ADAMW, output_dir="None"),
+                bnb.optim.AdamW,
+                default_adam_kwargs,
+            )
+        )
+
+        optim_test_params.append(
+            (
+                TrainingArguments(optim=OptimizerNames.PAGED_ADAMW_8BIT, output_dir="None"),
+                bnb.optim.AdamW,
+                default_adam_kwargs,
+            )
+        )
+
+        optim_test_params.append(
+            (
+                TrainingArguments(optim=OptimizerNames.LION, output_dir="None"),
+                bnb.optim.Lion,
+                default_lion_kwargs,
+            )
+        )
+
+        optim_test_params.append(
+            (
+                TrainingArguments(optim=OptimizerNames.LION_8BIT, output_dir="None"),
+                bnb.optim.Lion,
+                default_lion_kwargs,
+            )
+        )
+
+        optim_test_params.append(
+            (
+                TrainingArguments(optim=OptimizerNames.PAGED_LION_8BIT, output_dir="None"),
+                bnb.optim.Lion,
+                default_lion_kwargs,
+            )
+        )
 
     if is_torchdistx_available():
         import torchdistx
@@ -2598,15 +2651,113 @@ def test_bnb_adam8bit(self):
         modules = {
             "bitsandbytes": mock,
             "bitsandbytes.optim": mock.optim,
-            "bitsandbytes.optim.Adam8bit": mock.optim.Adam8bit,
+            "bitsandbytes.optim.AdamW": mock.optim.AdamW,
         }
         with patch.dict("sys.modules", modules):
             self.check_optim_and_kwargs(
                 TrainingArguments(optim=OptimizerNames.ADAMW_BNB, output_dir="None"),
-                mock.optim.Adam8bit,
+                mock.optim.AdamW,
                 default_adam_kwargs,
             )
 
+    def test_bnb_paged_adam8bit_alias(self):
+        mock = Mock()
+        modules = {
+            "bitsandbytes": mock,
+            "bitsandbytes.optim": mock.optim,
+            "bitsandbytes.optim.AdamW": mock.optim.AdamW,
+        }
+        with patch.dict("sys.modules", modules):
+            self.check_optim_and_kwargs(
+                TrainingArguments(optim=OptimizerNames.ADAMW_8BIT, output_dir="None"),
+                mock.optim.AdamW,
+                default_adam_kwargs,
+            )
+
+    def test_bnb_paged_adam(self):
+        mock = Mock()
+        modules = {
+            "bitsandbytes": mock,
+            "bitsandbytes.optim": mock.optim,
+            "bitsandbytes.optim.AdamW": mock.optim.AdamW,
+        }
+        with patch.dict("sys.modules", modules):
+            self.check_optim_and_kwargs(
+                TrainingArguments(optim=OptimizerNames.PAGED_ADAMW, output_dir="None"),
+                mock.optim.AdamW,
+                default_adam_kwargs,
+            )
+
+    def test_bnb_paged_adam8bit(self):
+        mock = Mock()
+        modules = {
+            "bitsandbytes": mock,
+            "bitsandbytes.optim": mock.optim,
+            "bitsandbytes.optim.AdamW": mock.optim.AdamW,
+        }
+        with patch.dict("sys.modules", modules):
+            self.check_optim_and_kwargs(
+                TrainingArguments(optim=OptimizerNames.PAGED_ADAMW_8BIT, output_dir="None"),
+                mock.optim.AdamW,
+                default_adam_kwargs,
+            )
+
+    def test_bnb_lion(self):
+        mock = Mock()
+        modules = {
+            "bitsandbytes": mock,
+            "bitsandbytes.optim": mock.optim,
+            "bitsandbytes.optim.Lion": mock.optim.Lion,
+        }
+        with patch.dict("sys.modules", modules):
+            self.check_optim_and_kwargs(
+                TrainingArguments(optim=OptimizerNames.LION, output_dir="None"),
+                mock.optim.Lion,
+                default_lion_kwargs,
+            )
+
+    def test_bnb_lion8bit(self):
+        mock = Mock()
+        modules = {
+            "bitsandbytes": mock,
+            "bitsandbytes.optim": mock.optim,
+            "bitsandbytes.optim.Lion": mock.optim.Lion,
+        }
+        with patch.dict("sys.modules", modules):
+            self.check_optim_and_kwargs(
+                TrainingArguments(optim=OptimizerNames.LION_8BIT, output_dir="None"),
+                mock.optim.Lion,
+                default_lion_kwargs,
+            )
+
+    def test_bnb_paged_lion8bit(self):
+        mock = Mock()
+        modules = {
+            "bitsandbytes": mock,
+            "bitsandbytes.optim": mock.optim,
+            "bitsandbytes.optim.Lion": mock.optim.Lion,
+        }
+        with patch.dict("sys.modules", modules):
+            self.check_optim_and_kwargs(
+                TrainingArguments(optim=OptimizerNames.PAGED_LION_8BIT, output_dir="None"),
+                mock.optim.Lion,
+                default_lion_kwargs,
+            )
+
+    def test_bnb_paged_lion(self):
+        mock = Mock()
+        modules = {
+            "bitsandbytes": mock,
+            "bitsandbytes.optim": mock.optim,
+            "bitsandbytes.optim.Lion": mock.optim.Lion,
+        }
+        with patch.dict("sys.modules", modules):
+            self.check_optim_and_kwargs(
+                TrainingArguments(optim=OptimizerNames.PAGED_LION, output_dir="None"),
+                mock.optim.Lion,
+                default_lion_kwargs,
+            )
+
     def test_bnb_adam8bit_no_bnb(self):
         args = TrainingArguments(optim=OptimizerNames.ADAMW_BNB, output_dir="None")
 
@@ -2616,6 +2767,42 @@ def test_bnb_adam8bit_no_bnb(self):
         with self.assertRaises(ValueError):
             Trainer.get_optimizer_cls_and_kwargs(args)
 
+    def test_bnb_paged_adam_no_bnb(self):
+        args = TrainingArguments(optim=OptimizerNames.PAGED_ADAMW, output_dir="None")
+
+        # Pretend that bnb does not exist, even if installed. By setting bnb to None, importing
+        # bnb will fail even if bnb is installed.
+        with patch.dict("sys.modules", {"bitsandbytes.optim": None}):
+            with self.assertRaises(ValueError):
+                Trainer.get_optimizer_cls_and_kwargs(args)
+
+    def test_bnb_paged_adam8bit_no_bnb(self):
+        args = TrainingArguments(optim=OptimizerNames.PAGED_ADAMW_8BIT, output_dir="None")
+
+        # Pretend that bnb does not exist, even if installed. By setting bnb to None, importing
+        # bnb will fail even if bnb is installed.
+        with patch.dict("sys.modules", {"bitsandbytes.optim": None}):
+            with self.assertRaises(ValueError):
+                Trainer.get_optimizer_cls_and_kwargs(args)
+
+    def test_bnb_paged_lion_no_bnb(self):
+        args = TrainingArguments(optim=OptimizerNames.PAGED_LION, output_dir="None")
+
+        # Pretend that bnb does not exist, even if installed. By setting bnb to None, importing
+        # bnb will fail even if bnb is installed.
+        with patch.dict("sys.modules", {"bitsandbytes.optim": None}):
+            with self.assertRaises(ValueError):
+                Trainer.get_optimizer_cls_and_kwargs(args)
+
+    def test_bnb_paged_lion8bit_no_bnb(self):
+        args = TrainingArguments(optim=OptimizerNames.PAGED_LION_8BIT, output_dir="None")
+
+        # Pretend that bnb does not exist, even if installed. By setting bnb to None, importing
+        # bnb will fail even if bnb is installed.
+        with patch.dict("sys.modules", {"bitsandbytes.optim": None}):
+            with self.assertRaises(ValueError):
+                Trainer.get_optimizer_cls_and_kwargs(args)
+
     def test_anyprecision_adamw(self):
         # Pretend that torchdistx is installed and mock torchdistx.optimizers.AnyPrecisionAdamW exists.
         # Trainer.get_optimizer_cls_and_kwargs does not use AnyPrecisioinAdamW. It only has to return the
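Illustrative usage (not part of the patch): a minimal sketch of how the optimizer names wired up above would be selected from user code, assuming bitsandbytes is installed and a CUDA device is available. The output directory and learning rate below are hypothetical placeholders; the optim strings are the enum values added in training_args.py.

    from transformers import Trainer, TrainingArguments

    # Each new enum value can be passed to `optim` as a plain string:
    # "adamw_8bit" (alias for "adamw_bnb_8bit"), "paged_adamw_32bit", "paged_adamw_8bit",
    # "lion_32bit", "lion_8bit", "paged_lion_32bit", "paged_lion_8bit".
    args = TrainingArguments(
        output_dir="test-paged-optim",  # hypothetical output directory
        optim="paged_adamw_8bit",
        learning_rate=2e-4,
    )

    # As in the tests above, the string is resolved to a bitsandbytes class: here
    # bnb.optim.AdamW, with {"is_paged": True, "optim_bits": 8} merged into the kwargs.
    optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(args)

When such a TrainingArguments instance is handed to Trainer, the same resolution happens internally, so no further code changes are needed to switch between the paged/8-bit AdamW and Lion variants.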