enable average tokens across devices (#34373)
* enable average tokens across devices

* reduce earlier in case model needs it

* simplify if statement

* reformat code to make ruff happy

* add doc for argument: average_tokens_across_devices

* cannot find world size when pytorch is unavailable

* format code

---------

Co-authored-by: Zach Mueller <[email protected]>
Co-authored-by: Arthur <[email protected]>
3 people committed Nov 5, 2024
1 parent f784d95 commit 5b36cda
Showing 2 changed files with 31 additions and 1 deletion.
10 changes: 9 additions & 1 deletion src/transformers/trainer.py
@@ -3602,7 +3602,12 @@ def training_step(
             with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                 scaled_loss.backward()
         else:
-            loss *= self.args.gradient_accumulation_steps
+            if num_items_in_batch is not None:
+                if self.compute_loss_func or self.model_accepts_loss_kwargs:
+                    loss *= self.args.gradient_accumulation_steps
+                # Average tokens across devices is orthogonal to gradient accumulation
+                if self.args.average_tokens_across_devices:
+                    loss *= self.args.world_size
             self.accelerator.backward(loss, **kwargs)

         return loss.detach() / self.args.gradient_accumulation_steps
@@ -3617,6 +3622,9 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
             labels = inputs.pop("labels")
         else:
             labels = None
+        if self.args.average_tokens_across_devices and num_items_in_batch is not None:
+            num_items_in_batch_tensor = torch.tensor(num_items_in_batch, device=self.args.device)
+            num_items_in_batch = int(self.accelerator.gather(num_items_in_batch_tensor).sum().cpu())
         if self.model_accepts_loss_kwargs:
             loss_kwargs = {}
             if num_items_in_batch is not None:
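Why training_step now multiplies by world_size: when the flag is on, compute_loss divides each rank's summed loss by the token count gathered over all ranks, but DDP/Accelerate then averages the per-rank gradients rather than summing them. Scaling each rank's loss by world_size cancels that averaging, so the effective gradient corresponds to the per-token mean over the whole global batch. Below is a toy sketch of that arithmetic only, not the Trainer code itself; the per-rank loss sums and token counts are invented values.

# Toy check of the rescaling in this commit (made-up numbers, no GPUs or torch needed).
# Each "rank" holds the sum of per-token losses and the token count of its micro-batch.
rank_loss_sums = [12.0, 30.0]   # hypothetical per-rank summed losses
rank_num_tokens = [4, 6]        # hypothetical per-rank token counts
world_size = len(rank_loss_sums)

# Ground truth: mean loss over every token on every rank.
global_mean = sum(rank_loss_sums) / sum(rank_num_tokens)

# With average_tokens_across_devices: compute_loss divides by the *gathered* total
# token count, and training_step multiplies by world_size.
num_items_in_batch = sum(rank_num_tokens)            # stands in for accelerator.gather(...).sum()
per_rank_loss = [s / num_items_in_batch * world_size for s in rank_loss_sums]

# DDP averages gradients (stand-in here: the losses) across ranks,
# which recovers exactly the global per-token mean.
ddp_averaged = sum(per_rank_loss) / world_size
assert abs(ddp_averaged - global_mean) < 1e-12
print(global_mean, ddp_averaged)  # 4.2 4.2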
22 changes: 22 additions & 0 deletions src/transformers/training_args.py
@@ -1530,6 +1530,15 @@ class TrainingArguments:
         },
     )

+    average_tokens_across_devices: Optional[bool] = field(
+        default=False,
+        metadata={
+            "help": "Whether or not to average tokens across devices. If enabled, will use all_reduce to "
+            "synchronize num_tokens_in_batch for precise loss calculation. Reference: "
+            "https://github.com/huggingface/transformers/issues/34242"
+        },
+    )
+
     def __post_init__(self):
         # Parse in args that could be `dict` sent in from the CLI as a string
         for field in _VALID_DICT_FIELDS:
@@ -1763,6 +1772,19 @@ def __post_init__(self):
         if self.framework == "pt" and is_torch_available():
             self.device

+        # Disable average tokens when using single device
+        if self.average_tokens_across_devices:
+            try:
+                if self.world_size == 1:
+                    logger.warning(
+                        "average_tokens_across_devices is set to True but it is invalid when world size is"
+                        "1. Turn it to False automatically."
+                    )
+                    self.average_tokens_across_devices = False
+            except ImportError as e:
+                logger.warning(f"Can not specify world size due to {e}. Turn average_tokens_across_devices to False.")
+                self.average_tokens_across_devices = False
+
         if self.torchdynamo is not None:
             warnings.warn(
                 "`torchdynamo` is deprecated and will be removed in version 5 of 🤗 Transformers. Use"
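For reference, a minimal usage sketch of the new argument. This is a hypothetical script: the gpt2 checkpoint, the wikitext slice, and the output directory are placeholders, and only average_tokens_across_devices=True is specific to this commit. Per the __post_init__ guard above, the flag is reset to False unless the script actually runs on more than one device (e.g. launched with torchrun or accelerate launch).

from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # gpt2 has no pad token by default
model = AutoModelForCausalLM.from_pretrained("gpt2")

# Tiny slice of a public corpus, tokenized for causal LM training.
dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="train[:1%]")
dataset = dataset.map(
    lambda batch: tokenizer(batch["text"], truncation=True, max_length=128),
    batched=True,
    remove_columns=dataset.column_names,
)
dataset = dataset.filter(lambda example: len(example["input_ids"]) > 1)

args = TrainingArguments(
    output_dir="out",
    per_device_train_batch_size=8,
    gradient_accumulation_steps=4,
    # New in this commit: gather token counts across ranks so the loss is a
    # global per-token mean rather than a per-device one.
    average_tokens_across_devices=True,
)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=dataset,
    data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
)
trainer.train()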
