From c10ea409e8a666fad52309a12cd3f4b7e7ec2dff Mon Sep 17 00:00:00 2001 From: sangbumchoi Date: Thu, 20 Jun 2024 07:35:20 +0000 Subject: [PATCH 01/12] add gather_use_object arguments --- src/transformers/trainer.py | 6 +++--- src/transformers/training_args.py | 10 ++++++++++ 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index d0b2237b5b1a98..d0bfbfc775dccc 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -3837,19 +3837,19 @@ def evaluation_loop( all_losses.add(losses) if inputs_decode is not None: inputs_decode = self.accelerator.pad_across_processes(inputs_decode, dim=1, pad_index=-100) - inputs_decode = self.gather_function((inputs_decode)) + inputs_decode = self.gather_function((inputs_decode), use_gather_object=self.args.use_gather_object) if not self.args.batch_eval_metrics or description == "Prediction": all_inputs.add(inputs_decode) if logits is not None: logits = self.accelerator.pad_across_processes(logits, dim=1, pad_index=-100) if self.preprocess_logits_for_metrics is not None: logits = self.preprocess_logits_for_metrics(logits, labels) - logits = self.gather_function((logits)) + logits = self.gather_function((logits), use_gather_object=self.args.use_gather_object) if not self.args.batch_eval_metrics or description == "Prediction": all_preds.add(logits) if labels is not None: labels = self.accelerator.pad_across_processes(labels, dim=1, pad_index=-100) - labels = self.gather_function((labels)) + labels = self.gather_function((labels), use_gather_object=self.args.use_gather_object) if not self.args.batch_eval_metrics or description == "Prediction": all_labels.add(labels) diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index 98f35501928965..54cc7d5bbc458a 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -775,6 +775,9 @@ class TrainingArguments: eval_on_start(`bool`, *optional*, defaults to `False`): Whether to perform a evaluation step (sanity check) before the training to ensure the validation steps works correctly. + + use_gather_object(`bool`, *optional*, defaults to `False`): + Whether to run recursively gather object in a nested list/tuple/dictionary of objects from all devices. """ framework = "pt" @@ -1465,6 +1468,13 @@ class TrainingArguments: }, ) + use_gather_object: Optional[bool] = field( + default=False, + metadata={ + "help": "Whether to run recursively gather object in a nested list/tuple/dictionary of objects from all devices." 
+ }, + ) + def __post_init__(self): # Parse in args that could be `dict` sent in from the CLI as a string for field in _VALID_DICT_FIELDS: From 32c94aefbb96fdced152874fe0ee423bc7e1aa5c Mon Sep 17 00:00:00 2001 From: sangbumchoi Date: Thu, 20 Jun 2024 08:25:02 +0000 Subject: [PATCH 02/12] fix name and pass the CI test for Seq2SeqTrainer --- src/transformers/trainer.py | 32 ++++++++++++++++++++++++------- src/transformers/training_args.py | 4 ++-- 2 files changed, 27 insertions(+), 9 deletions(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index d0bfbfc775dccc..2276e5dcc2c2ea 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -3831,25 +3831,45 @@ def evaluation_loop( if is_torch_xla_available(): xm.mark_step() + use_gather_object_in_gather_function = ( + "use_gather_object" in inspect.signature(self.gather_function).parameters.keys() + ) + # Update containers if losses is not None: - losses = self.gather_function((losses.repeat(batch_size))) + if use_gather_object_in_gather_function: + losses = self.gather_function( + (losses.repeat(batch_size)), use_gather_object=self.args.eval_use_gather_object + ) + else: + losses = self.gather_function((losses.repeat(batch_size))) all_losses.add(losses) if inputs_decode is not None: inputs_decode = self.accelerator.pad_across_processes(inputs_decode, dim=1, pad_index=-100) - inputs_decode = self.gather_function((inputs_decode), use_gather_object=self.args.use_gather_object) + if use_gather_object_in_gather_function: + inputs_decode = self.gather_function( + (inputs_decode), use_gather_object=self.args.eval_use_gather_object + ) + else: + inputs_decode = self.gather_function((inputs_decode)) if not self.args.batch_eval_metrics or description == "Prediction": all_inputs.add(inputs_decode) if logits is not None: logits = self.accelerator.pad_across_processes(logits, dim=1, pad_index=-100) if self.preprocess_logits_for_metrics is not None: logits = self.preprocess_logits_for_metrics(logits, labels) - logits = self.gather_function((logits), use_gather_object=self.args.use_gather_object) + if use_gather_object_in_gather_function: + logits = self.gather_function((logits), use_gather_object=self.args.eval_use_gather_object) + else: + logits = self.gather_function((logits)) if not self.args.batch_eval_metrics or description == "Prediction": all_preds.add(logits) if labels is not None: labels = self.accelerator.pad_across_processes(labels, dim=1, pad_index=-100) - labels = self.gather_function((labels), use_gather_object=self.args.use_gather_object) + if use_gather_object_in_gather_function: + labels = self.gather_function((labels), use_gather_object=self.args.eval_use_gather_object) + else: + labels = self.gather_function((labels)) if not self.args.batch_eval_metrics or description == "Prediction": all_labels.add(labels) @@ -4669,6 +4689,4 @@ def _fsdp_qlora_plugin_updates(self): and self.model.hf_quantizer.quantization_config.bnb_4bit_quant_storage.is_floating_point and version.parse(accelerate_version) > version.parse("0.27.0") ): - fsdp_plugin.set_mixed_precision( - self.model.hf_quantizer.quantization_config.bnb_4bit_quant_storage, override=True - ) + fsdp_plugin.set_mixed_precisi \ No newline at end of file diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index 54cc7d5bbc458a..525f64d94901c3 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -776,7 +776,7 @@ class TrainingArguments: eval_on_start(`bool`, *optional*, defaults to 
`False`): Whether to perform a evaluation step (sanity check) before the training to ensure the validation steps works correctly. - use_gather_object(`bool`, *optional*, defaults to `False`): + eval_use_gather_object(`bool`, *optional*, defaults to `False`): Whether to run recursively gather object in a nested list/tuple/dictionary of objects from all devices. """ @@ -1468,7 +1468,7 @@ class TrainingArguments: }, ) - use_gather_object: Optional[bool] = field( + eval_use_gather_object: Optional[bool] = field( default=False, metadata={ "help": "Whether to run recursively gather object in a nested list/tuple/dictionary of objects from all devices." From 2ad148d3b46583d95ec79804ec1615c6bc30f087 Mon Sep 17 00:00:00 2001 From: sangbumchoi Date: Thu, 20 Jun 2024 08:27:27 +0000 Subject: [PATCH 03/12] make style --- src/transformers/trainer.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 2276e5dcc2c2ea..7b1e4d148feec3 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -4689,4 +4689,6 @@ def _fsdp_qlora_plugin_updates(self): and self.model.hf_quantizer.quantization_config.bnb_4bit_quant_storage.is_floating_point and version.parse(accelerate_version) > version.parse("0.27.0") ): - fsdp_plugin.set_mixed_precisi \ No newline at end of file + fsdp_plugin.set_mixed_precision( + self.model.hf_quantizer.quantization_config.bnb_4bit_quant_storage, override=True + ) From fef38b481e42275a701496ac8ea8c156b195ee16 Mon Sep 17 00:00:00 2001 From: sangbumchoi Date: Fri, 21 Jun 2024 00:44:59 +0000 Subject: [PATCH 04/12] make it to functools --- src/transformers/trainer.py | 31 ++++++++----------------------- 1 file changed, 8 insertions(+), 23 deletions(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 7b1e4d148feec3..3900d7204550c5 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -3831,45 +3831,30 @@ def evaluation_loop( if is_torch_xla_available(): xm.mark_step() - use_gather_object_in_gather_function = ( - "use_gather_object" in inspect.signature(self.gather_function).parameters.keys() - ) + if "use_gather_object" in inspect.signature(self.gather_function).parameters.keys(): + self.gather_function = functools.partial( + self.gather_function, use_gather_objects=self.args.eval_use_gather_object + ) # Update containers if losses is not None: - if use_gather_object_in_gather_function: - losses = self.gather_function( - (losses.repeat(batch_size)), use_gather_object=self.args.eval_use_gather_object - ) - else: - losses = self.gather_function((losses.repeat(batch_size))) + losses = self.gather_function((losses.repeat(batch_size))) all_losses.add(losses) if inputs_decode is not None: inputs_decode = self.accelerator.pad_across_processes(inputs_decode, dim=1, pad_index=-100) - if use_gather_object_in_gather_function: - inputs_decode = self.gather_function( - (inputs_decode), use_gather_object=self.args.eval_use_gather_object - ) - else: - inputs_decode = self.gather_function((inputs_decode)) + inputs_decode = self.gather_function((inputs_decode)) if not self.args.batch_eval_metrics or description == "Prediction": all_inputs.add(inputs_decode) if logits is not None: logits = self.accelerator.pad_across_processes(logits, dim=1, pad_index=-100) if self.preprocess_logits_for_metrics is not None: logits = self.preprocess_logits_for_metrics(logits, labels) - if use_gather_object_in_gather_function: - logits = self.gather_function((logits), 
use_gather_object=self.args.eval_use_gather_object) - else: - logits = self.gather_function((logits)) + logits = self.gather_function((logits)) if not self.args.batch_eval_metrics or description == "Prediction": all_preds.add(logits) if labels is not None: labels = self.accelerator.pad_across_processes(labels, dim=1, pad_index=-100) - if use_gather_object_in_gather_function: - labels = self.gather_function((labels), use_gather_object=self.args.eval_use_gather_object) - else: - labels = self.gather_function((labels)) + labels = self.gather_function((labels)) if not self.args.batch_eval_metrics or description == "Prediction": all_labels.add(labels) From 5ec137d31775b6e5f75b7cc382cd99f295a69c13 Mon Sep 17 00:00:00 2001 From: sangbumchoi Date: Fri, 21 Jun 2024 00:55:26 +0000 Subject: [PATCH 05/12] fix typo --- src/transformers/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 3900d7204550c5..57278ea73491e3 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -3833,7 +3833,7 @@ def evaluation_loop( if "use_gather_object" in inspect.signature(self.gather_function).parameters.keys(): self.gather_function = functools.partial( - self.gather_function, use_gather_objects=self.args.eval_use_gather_object + self.gather_function, use_gather_object=self.args.eval_use_gather_object ) # Update containers From affaa01fea3feef6b23eae8bee8778a8b0f98404 Mon Sep 17 00:00:00 2001 From: sangbumchoi Date: Fri, 21 Jun 2024 19:06:57 +0900 Subject: [PATCH 06/12] add accelerate version: --- src/transformers/trainer.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 57278ea73491e3..ef2dfc0f1e7ba3 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -3831,7 +3831,9 @@ def evaluation_loop( if is_torch_xla_available(): xm.mark_step() - if "use_gather_object" in inspect.signature(self.gather_function).parameters.keys(): + if "use_gather_object" in inspect.signature( + self.gather_function + ).parameters.keys() and is_accelerate_available("0.30.0"): self.gather_function = functools.partial( self.gather_function, use_gather_object=self.args.eval_use_gather_object ) From f44dabf43a1d1ab244a1f1c87cf80c367563d072 Mon Sep 17 00:00:00 2001 From: sangbumchoi Date: Tue, 25 Jun 2024 02:59:13 +0000 Subject: [PATCH 07/12] adding warning --- src/transformers/trainer.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index ef2dfc0f1e7ba3..871470f056c517 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -3831,12 +3831,18 @@ def evaluation_loop( if is_torch_xla_available(): xm.mark_step() - if "use_gather_object" in inspect.signature( - self.gather_function - ).parameters.keys() and is_accelerate_available("0.30.0"): - self.gather_function = functools.partial( - self.gather_function, use_gather_object=self.args.eval_use_gather_object - ) + if self.args.eval_use_gather_object: + if "use_gather_object" in inspect.signature( + self.gather_function + ).parameters.keys() and is_accelerate_available("0.30.0"): + self.gather_function = functools.partial( + self.gather_function, use_gather_object=self.args.eval_use_gather_object + ) + else: + logger.warning( + "You are using eval_use_gather_object with a version of `accelerate` < 0.30.0." + "It is recommended to update your version. 
Since use_gather_object might not be in gather function." + ) # Update containers if losses is not None: From 8b8589afe59485f7f768a6ec21ef03b626e36290 Mon Sep 17 00:00:00 2001 From: Sangbum Daniel Choi <34004152+SangbumChoi@users.noreply.github.com> Date: Wed, 26 Jun 2024 23:13:01 +0900 Subject: [PATCH 08/12] Update src/transformers/trainer.py Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com> --- src/transformers/trainer.py | 19 +++++++------------ 1 file changed, 7 insertions(+), 12 deletions(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 871470f056c517..b054c8cc974b68 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -3831,18 +3831,13 @@ def evaluation_loop( if is_torch_xla_available(): xm.mark_step() - if self.args.eval_use_gather_object: - if "use_gather_object" in inspect.signature( - self.gather_function - ).parameters.keys() and is_accelerate_available("0.30.0"): - self.gather_function = functools.partial( - self.gather_function, use_gather_object=self.args.eval_use_gather_object - ) - else: - logger.warning( - "You are using eval_use_gather_object with a version of `accelerate` < 0.30.0." - "It is recommended to update your version. Since use_gather_object might not be in gather function." - ) + if self.args.eval_use_gather_object and not is_accelerate_available("0.30.0"): + logger.warning("You are using eval_use_gather_object = True with a version of `accelerate` < 0.30.0. This is not supported" + " and we recommend you to update your version.") + if "use_gather_object" in inspect.signature(self.gather_function).parameters.keys(): + self.gather_function = functools.partial( + self.gather_function, use_gather_object=self.args.eval_use_gather_object + ) # Update containers if losses is not None: From 63f32b0f32a3b5d06afac817047900d175b227a1 Mon Sep 17 00:00:00 2001 From: sangbumchoi Date: Wed, 26 Jun 2024 23:15:43 +0900 Subject: [PATCH 09/12] make style --- src/transformers/trainer.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index b054c8cc974b68..1938f1b98ff3e5 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -3832,12 +3832,14 @@ def evaluation_loop( xm.mark_step() if self.args.eval_use_gather_object and not is_accelerate_available("0.30.0"): - logger.warning("You are using eval_use_gather_object = True with a version of `accelerate` < 0.30.0. This is not supported" - " and we recommend you to update your version.") + logger.warning( + "You are using eval_use_gather_object = True with a version of `accelerate` < 0.30.0. This is not supported" + " and we recommend you to update your version." 
+ ) if "use_gather_object" in inspect.signature(self.gather_function).parameters.keys(): self.gather_function = functools.partial( - self.gather_function, use_gather_object=self.args.eval_use_gather_object - ) + self.gather_function, use_gather_object=self.args.eval_use_gather_object + ) # Update containers if losses is not None: From 621148fd1996485d4d2ffa9f93de6c640f4c8861 Mon Sep 17 00:00:00 2001 From: Sangbum Daniel Choi <34004152+SangbumChoi@users.noreply.github.com> Date: Thu, 27 Jun 2024 09:11:02 +0900 Subject: [PATCH 10/12] Update src/transformers/training_args.py --- src/transformers/training_args.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index 525f64d94901c3..50c862f791a1ee 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -773,10 +773,10 @@ class TrainingArguments: that takes a boolean argument `compute_result`, which when passed `True`, will trigger the final global summary statistics from the batch-level summary statistics you've accumulated over the evaluation set. - eval_on_start(`bool`, *optional*, defaults to `False`): + eval_on_start (`bool`, *optional*, defaults to `False`): Whether to perform a evaluation step (sanity check) before the training to ensure the validation steps works correctly. - eval_use_gather_object(`bool`, *optional*, defaults to `False`): + eval_use_gather_object (`bool`, *optional*, defaults to `False`): Whether to run recursively gather object in a nested list/tuple/dictionary of objects from all devices. """ From b47b0bd3300f2ef27e3b4ffed9b5ff9e21ba2c7e Mon Sep 17 00:00:00 2001 From: sangbumchoi Date: Thu, 27 Jun 2024 00:18:28 +0000 Subject: [PATCH 11/12] check function move to initial part --- src/transformers/trainer.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 1938f1b98ff3e5..d68b410f1f1d41 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -3831,16 +3831,6 @@ def evaluation_loop( if is_torch_xla_available(): xm.mark_step() - if self.args.eval_use_gather_object and not is_accelerate_available("0.30.0"): - logger.warning( - "You are using eval_use_gather_object = True with a version of `accelerate` < 0.30.0. This is not supported" - " and we recommend you to update your version." - ) - if "use_gather_object" in inspect.signature(self.gather_function).parameters.keys(): - self.gather_function = functools.partial( - self.gather_function, use_gather_object=self.args.eval_use_gather_object - ) - # Update containers if losses is not None: losses = self.gather_function((losses.repeat(batch_size))) @@ -4616,6 +4606,16 @@ def create_accelerator_and_postprocess(self): # some Trainer classes need to use `gather` instead of `gather_for_metrics`, thus we store a flag self.gather_function = self.accelerator.gather_for_metrics + if self.args.eval_use_gather_object and not is_accelerate_available("0.30.0"): + logger.warning( + "You are using eval_use_gather_object = True with a version of `accelerate` < 0.30.0. This is not supported" + " and we recommend you to update your version." 
+            )
+        if "use_gather_object" in inspect.signature(self.gather_function).parameters.keys():
+            self.gather_function = functools.partial(
+                self.gather_function, use_gather_object=self.args.eval_use_gather_object
+            )
+
         # deepspeed and accelerate flags covering both trainer args and accelerate launcher
         self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None
         self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None

From 68f04d82f0ffe62df85fb7a40326c447b2eee01d Mon Sep 17 00:00:00 2001
From: sangbumchoi
Date: Thu, 27 Jun 2024 21:09:23 +0900
Subject: [PATCH 12/12] add test for eval_use_gather_object

---
 src/transformers/trainer.py       |  5 -----
 src/transformers/training_args.py |  6 ++++++
 tests/trainer/test_trainer.py     | 12 ++++++++++++
 3 files changed, 18 insertions(+), 5 deletions(-)

diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index d68b410f1f1d41..c40c353a2a0e2b 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -4606,11 +4606,6 @@ def create_accelerator_and_postprocess(self):
         # some Trainer classes need to use `gather` instead of `gather_for_metrics`, thus we store a flag
         self.gather_function = self.accelerator.gather_for_metrics
 
-        if self.args.eval_use_gather_object and not is_accelerate_available("0.30.0"):
-            logger.warning(
-                "You are using eval_use_gather_object = True with a version of `accelerate` < 0.30.0. This is not supported"
-                " and we recommend you to update your version."
-            )
         if "use_gather_object" in inspect.signature(self.gather_function).parameters.keys():
             self.gather_function = functools.partial(
                 self.gather_function, use_gather_object=self.args.eval_use_gather_object

diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py
index 50c862f791a1ee..e7211bd76919ff 100644
--- a/src/transformers/training_args.py
+++ b/src/transformers/training_args.py
@@ -2002,6 +2002,12 @@ def __post_init__(self):
                 FutureWarning,
             )
 
+        if self.eval_use_gather_object and not is_accelerate_available("0.30.0"):
+            raise ValueError(
+                "--eval_use_gather_object requires `accelerate` version >= 0.30.0."
+                " Your current version is not supported, so we recommend you update `accelerate`."
+            )
+
     def __str__(self):
         self_as_dict = asdict(self)

diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index dba4b97515f8ba..25f45c3e8398ab 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -132,6 +132,7 @@
 # for version specific tests in TrainerIntegrationTest
 require_accelerate_version_min_0_28 = partial(require_accelerate, min_version="0.28")
+require_accelerate_version_min_0_30 = partial(require_accelerate, min_version="0.30")
 GRAD_ACCUM_KWARGS_VERSION_AVAILABLE = is_accelerate_available("0.28")
 if is_accelerate_available():
     from accelerate import Accelerator
@@ -3565,6 +3566,17 @@ class TorchDtypeTrainingArguments(TrainingArguments):
         self.assertIn("torch_dtype", args_dict)
         self.assertEqual(args_dict["torch_dtype"], dtype)
 
+    @require_accelerate_version_min_0_30
+    def test_eval_use_gather_object(self):
+        train_dataset = RegressionDataset()
+        eval_dataset = RegressionDataset()
+        model = RegressionDictModel()
+        args = TrainingArguments("./regression", report_to="none", eval_use_gather_object=True)
+        trainer = Trainer(model, args, train_dataset=train_dataset, eval_dataset=eval_dataset)
+        trainer.train()
+        _ = trainer.evaluate()
+        _ = trainer.predict(eval_dataset)
+
 @require_torch
 @is_staging_test
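A note on what `eval_use_gather_object` enables, since the patches above only wire the flag through. `gather`/`gather_for_metrics` normally all-gather tensors, which must have the same shape on every process; evaluation outputs that are arbitrary Python objects (for example variable-length detection results) cannot be collected that way. Starting with `accelerate` 0.30.0, `Accelerator.gather_for_metrics` accepts `use_gather_object`, which routes through `gather_object` and pickles whole objects across ranks. A minimal sketch of the behavior outside `Trainer`, assuming `accelerate` >= 0.30.0 (the dictionary payload is illustrative only):

```python
from accelerate import Accelerator

accelerator = Accelerator()

# A per-process result that is not a tensor and whose size depends on the rank,
# so a plain tensor all-gather cannot collect it.
predictions = [{"rank": accelerator.process_index, "num_boxes": accelerator.process_index + 1}]

# use_gather_object=True pickles the objects and concatenates the lists from
# all processes, so every rank sees the full set of predictions.
all_predictions = accelerator.gather_for_metrics(predictions, use_gather_object=True)

if accelerator.is_main_process:
    print(all_predictions)
```

Run with `accelerate launch`; inside `Trainer`, passing `eval_use_gather_object=True` in `TrainingArguments` opts the evaluation loop into this same path, which is what the new test exercises.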
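The `inspect.signature` check that survives into the final patch is a feature probe rather than a version pin: the new keyword is bound with `functools.partial` only if the installed `gather_function` actually accepts it, so older `accelerate` releases keep working. A standalone sketch of that pattern, using stand-in functions rather than the real `accelerate` API:

```python
import functools
import inspect

def gather_new(data, use_gather_object=False):
    # Stands in for a gather_for_metrics that supports the new keyword.
    return {"data": data, "use_gather_object": use_gather_object}

def gather_old(data):
    # Stands in for gather_for_metrics from an older accelerate release.
    return {"data": data}

for gather_function in (gather_new, gather_old):
    # Bind the keyword only when the callable actually exposes it.
    if "use_gather_object" in inspect.signature(gather_function).parameters:
        gather_function = functools.partial(gather_function, use_gather_object=True)
    print(gather_function("x"))
# -> {'data': 'x', 'use_gather_object': True}
# -> {'data': 'x'}
```

Note that PATCH 11 moves this probe out of `evaluation_loop` and into `create_accelerator_and_postprocess`, so the wrapping happens once at setup rather than on every evaluation step.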