Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add gather_use_object arguments #31514

Merged
merged 12 commits into from
Jun 28, 2024
5 changes: 5 additions & 0 deletions src/transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4606,6 +4606,11 @@ def create_accelerator_and_postprocess(self):
# some Trainer classes need to use `gather` instead of `gather_for_metrics`, thus we store a flag
self.gather_function = self.accelerator.gather_for_metrics

if "use_gather_object" in inspect.signature(self.gather_function).parameters.keys():
self.gather_function = functools.partial(
self.gather_function, use_gather_object=self.args.eval_use_gather_object
)
SangbumChoi marked this conversation as resolved.
Show resolved Hide resolved

# deepspeed and accelerate flags covering both trainer args and accelerate launcher
self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None
self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None
Expand Down
18 changes: 17 additions & 1 deletion src/transformers/training_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -773,8 +773,11 @@ class TrainingArguments:
that takes a boolean argument `compute_result`, which when passed `True`, will trigger the final global
summary statistics from the batch-level summary statistics you've accumulated over the evaluation set.

eval_on_start(`bool`, *optional*, defaults to `False`):
eval_on_start (`bool`, *optional*, defaults to `False`):
Whether to perform an evaluation step (sanity check) before training to ensure the validation steps work correctly.

eval_use_gather_object (`bool`, *optional*, defaults to `False`):
Whether to recursively gather objects in a nested list/tuple/dictionary of objects from all devices.
SangbumChoi marked this conversation as resolved.
Show resolved Hide resolved
"""

framework = "pt"
Expand Down Expand Up @@ -1465,6 +1468,13 @@ class TrainingArguments:
},
)

eval_use_gather_object: Optional[bool] = field(
default=False,
metadata={
"help": "Whether to run recursively gather object in a nested list/tuple/dictionary of objects from all devices."
},
)

def __post_init__(self):
# Parse in args that could be `dict` sent in from the CLI as a string
for field in _VALID_DICT_FIELDS:
Expand Down Expand Up @@ -1992,6 +2002,12 @@ def __post_init__(self):
FutureWarning,
)

if self.eval_use_gather_object and not is_accelerate_available("0.30.0"):
raise ValueError(
"--eval_use_gather_object requires Accelerate to be version of `accelerate` < 0.30.0."
SangbumChoi marked this conversation as resolved.
Show resolved Hide resolved
"This is not supported and we recommend you to update your version."
)

def __str__(self):
self_as_dict = asdict(self)

Expand Down
12 changes: 12 additions & 0 deletions tests/trainer/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@

# for version specific tests in TrainerIntegrationTest
require_accelerate_version_min_0_28 = partial(require_accelerate, min_version="0.28")
require_accelerate_version_min_0_30 = partial(require_accelerate, min_version="0.30")
GRAD_ACCUM_KWARGS_VERSION_AVAILABLE = is_accelerate_available("0.28")
if is_accelerate_available():
from accelerate import Accelerator
Expand Down Expand Up @@ -3565,6 +3566,17 @@ class TorchDtypeTrainingArguments(TrainingArguments):
self.assertIn("torch_dtype", args_dict)
self.assertEqual(args_dict["torch_dtype"], dtype)

@require_accelerate_version_min_0_30
def test_eval_use_gather_object(self):
    """Smoke test: train, evaluate, and predict run end-to-end when
    ``eval_use_gather_object=True`` (requires accelerate >= 0.30)."""
    training_args = TrainingArguments("./regression", report_to="none", eval_use_gather_object=True)
    regression_model = RegressionDictModel()
    train_set, eval_set = RegressionDataset(), RegressionDataset()
    trainer = Trainer(regression_model, training_args, train_dataset=train_set, eval_dataset=eval_set)
    # Exercise every code path that goes through the gather function.
    trainer.train()
    trainer.evaluate()
    trainer.predict(eval_set)


@require_torch
@is_staging_test
Expand Down
Loading