diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index d0b2237b5b1a98..c40c353a2a0e2b 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -4606,6 +4606,11 @@ def create_accelerator_and_postprocess(self):
         # some Trainer classes need to use `gather` instead of `gather_for_metrics`, thus we store a flag
         self.gather_function = self.accelerator.gather_for_metrics
 
+        if "use_gather_object" in inspect.signature(self.gather_function).parameters.keys():
+            self.gather_function = functools.partial(
+                self.gather_function, use_gather_object=self.args.eval_use_gather_object
+            )
+
         # deepspeed and accelerate flags covering both trainer args and accelerate launcher
         self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None
         self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None
diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py
index 98f35501928965..e7211bd76919ff 100644
--- a/src/transformers/training_args.py
+++ b/src/transformers/training_args.py
@@ -773,8 +773,11 @@ class TrainingArguments:
             that takes a boolean argument `compute_result`, which when passed `True`, will trigger the final global
             summary statistics from the batch-level summary statistics you've accumulated over the evaluation set.
-        eval_on_start(`bool`, *optional*, defaults to `False`):
+        eval_on_start (`bool`, *optional*, defaults to `False`):
            Whether to perform a evaluation step (sanity check) before the training to ensure the validation steps works
            correctly.
+
+        eval_use_gather_object (`bool`, *optional*, defaults to `False`):
+            Whether to recursively gather objects in a nested list/tuple/dictionary of objects from all devices.
     """
 
     framework = "pt"
@@ -1465,6 +1468,13 @@ class TrainingArguments:
         },
     )
 
+    eval_use_gather_object: Optional[bool] = field(
+        default=False,
+        metadata={
+            "help": "Whether to recursively gather objects in a nested list/tuple/dictionary of objects from all devices."
+        },
+    )
+
     def __post_init__(self):
         # Parse in args that could be `dict` sent in from the CLI as a string
         for field in _VALID_DICT_FIELDS:
@@ -1992,6 +2002,12 @@ def __post_init__(self):
                 FutureWarning,
             )
 
+        if self.eval_use_gather_object and not is_accelerate_available("0.30.0"):
+            raise ValueError(
+                "--eval_use_gather_object requires `accelerate` >= 0.30.0. "
+                "Your current version of `accelerate` does not support this; please update it."
+            )
+
     def __str__(self):
         self_as_dict = asdict(self)
 
diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index dba4b97515f8ba..25f45c3e8398ab 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -132,6 +132,7 @@
 
 # for version specific tests in TrainerIntegrationTest
 require_accelerate_version_min_0_28 = partial(require_accelerate, min_version="0.28")
+require_accelerate_version_min_0_30 = partial(require_accelerate, min_version="0.30")
 GRAD_ACCUM_KWARGS_VERSION_AVAILABLE = is_accelerate_available("0.28")
 if is_accelerate_available():
     from accelerate import Accelerator
@@ -3565,6 +3566,17 @@ class TorchDtypeTrainingArguments(TrainingArguments):
         self.assertIn("torch_dtype", args_dict)
         self.assertEqual(args_dict["torch_dtype"], dtype)
 
+    @require_accelerate_version_min_0_30
+    def test_eval_use_gather_object(self):
+        train_dataset = RegressionDataset()
+        eval_dataset = RegressionDataset()
+        model = RegressionDictModel()
+        args = TrainingArguments("./regression", report_to="none", eval_use_gather_object=True)
+        trainer = Trainer(model, args, train_dataset=train_dataset, eval_dataset=eval_dataset)
+        trainer.train()
+        _ = trainer.evaluate()
+        _ = trainer.predict(eval_dataset)
+
 
 @require_torch
 @is_staging_test
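
For readers who want to try the new flag outside the test suite, here is a minimal, self-contained sketch. It is not part of the diff: `ToyRegressionDataset` and `ToyRegressionModel` are invented stand-ins for the repo-internal `RegressionDataset` / `RegressionDictModel` helpers, and the script only illustrates how `eval_use_gather_object=True` is passed through `TrainingArguments` to the gather step added above.

```python
# Hedged usage sketch (not from the PR). The toy dataset/model below are
# illustrative; only `eval_use_gather_object` comes from the diff itself.
import torch
from torch import nn
from torch.utils.data import Dataset

from transformers import Trainer, TrainingArguments


class ToyRegressionDataset(Dataset):
    """64 points from y = 2x + 1, batched by the default data collator."""

    def __init__(self, length=64):
        self.x = torch.randn(length, 1)
        self.y = 2.0 * self.x + 1.0

    def __len__(self):
        return len(self.x)

    def __getitem__(self, i):
        return {"input_x": self.x[i], "labels": self.y[i]}


class ToyRegressionModel(nn.Module):
    """Returns a dict, mirroring the kind of nested output `gather_object` collects."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(1, 1)

    def forward(self, input_x, labels=None):
        preds = self.linear(input_x)
        loss = nn.functional.mse_loss(preds, labels) if labels is not None else None
        return {"loss": loss, "logits": preds}


if __name__ == "__main__":
    args = TrainingArguments(
        output_dir="./toy_regression",
        report_to="none",
        eval_use_gather_object=True,  # the new flag; requires accelerate >= 0.30.0
    )
    dataset = ToyRegressionDataset()
    trainer = Trainer(ToyRegressionModel(), args, train_dataset=dataset, eval_dataset=dataset)
    trainer.train()
    print(trainer.evaluate())
```

On a single process the flag is a no-op; in multi-GPU evaluation it makes `Trainer` call `gather_for_metrics(..., use_gather_object=True)`, which gathers arbitrary (non-tensor) nested objects across devices instead of requiring tensors.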