Generate: Fix modern llm generate calls with synced_gpus (hugging…
gante authored Oct 12, 2024
1 parent 617b212 commit 37ea040
Showing 1 changed file with 63 additions and 71 deletions.
134 changes: 63 additions & 71 deletions src/transformers/generation/utils.py
@@ -379,9 +379,10 @@ def prepare_inputs_for_generation(
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
# Exception 1: when passing input_embeds, input_ids may be missing entries
# Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
# Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case
if past_key_values is not None:
model_inputs["past_key_values"] = past_key_values
if inputs_embeds is not None: # Exception 1
if inputs_embeds is not None or cache_position[-1] >= input_ids.shape[1]: # Exception 1 or Exception 3
input_ids = input_ids[:, -cache_position.shape[0] :]
elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
input_ids = input_ids[:, cache_position]
@@ -2609,8 +2610,14 @@ def _dola_decoding(
outputs.hidden_states[candidate_premature_layer][:, -1, :]
).to(final_logits.device)

# synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
model_kwargs = self._update_model_kwargs_for_generation(
outputs,
model_kwargs,
is_encoder_decoder=self.config.is_encoder_decoder,
)
if synced_gpus and this_peer_finished:
continue # don't waste resources running the code we don't need
continue

next_token_logits = _dola_select_contrast(
candidate_premature_layers, candidate_premature_logits, final_logits
@@ -2652,11 +2659,6 @@ def _dola_decoding(
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
if streamer is not None:
streamer.put(next_tokens.cpu())
model_kwargs = self._update_model_kwargs_for_generation(
outputs,
model_kwargs,
is_encoder_decoder=self.config.is_encoder_decoder,
)

# stop when each sentence is finished
unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
@@ -3016,8 +3018,14 @@ def _contrastive_search(
)
# contrastive_search main logic end

# synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
model_kwargs = self._update_model_kwargs_for_generation(
outputs,
model_kwargs,
is_encoder_decoder=self.config.is_encoder_decoder,
)
if synced_gpus and this_peer_finished:
continue # don't waste resources running the code we don't need
continue

# finished sentences should have their next token be a padding token
if has_eos_stopping_criteria:
@@ -3027,11 +3035,6 @@ def _contrastive_search(
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
if streamer is not None:
streamer.put(next_tokens.cpu())
model_kwargs = self._update_model_kwargs_for_generation(
outputs,
model_kwargs,
is_encoder_decoder=self.config.is_encoder_decoder,
)

# stop when each sentence is finished
unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
@@ -3168,8 +3171,14 @@ def _sample(
# forward pass to get next token
outputs = self(**model_inputs, return_dict=True)

# synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
model_kwargs = self._update_model_kwargs_for_generation(
outputs,
model_kwargs,
is_encoder_decoder=self.config.is_encoder_decoder,
)
if synced_gpus and this_peer_finished:
continue # don't waste resources running the code we don't need
continue

# Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
# (the clone itself is always small)
@@ -3214,11 +3223,6 @@ def _sample(
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
if streamer is not None:
streamer.put(next_tokens.cpu())
model_kwargs = self._update_model_kwargs_for_generation(
outputs,
model_kwargs,
is_encoder_decoder=self.config.is_encoder_decoder,
)

unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
this_peer_finished = unfinished_sequences.max() == 0
@@ -3415,9 +3419,15 @@ def _beam_search(
else: # Unchanged original behavior
outputs = self(**model_inputs, return_dict=True)

# synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
model_kwargs = self._update_model_kwargs_for_generation(
outputs,
model_kwargs,
is_encoder_decoder=self.config.is_encoder_decoder,
)
if synced_gpus and this_peer_finished:
cur_len = cur_len + 1
continue # don't waste resources running the code we don't need
continue

# Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
# (the clone itself is always small)
@@ -3491,12 +3501,6 @@ def _beam_search(

input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)

model_kwargs = self._update_model_kwargs_for_generation(
outputs,
model_kwargs,
is_encoder_decoder=self.config.is_encoder_decoder,
)

# This is needed to properly delete outputs.logits which may be very large for first iteration
# Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration
# IMPORTANT: Note that this should appear BEFORE the call to _reorder_cache() to save the maximum memory
@@ -3670,9 +3674,15 @@ def _group_beam_search(

outputs = self(**model_inputs, return_dict=True)

# synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
model_kwargs = self._update_model_kwargs_for_generation(
outputs,
model_kwargs,
is_encoder_decoder=self.config.is_encoder_decoder,
)
if synced_gpus and this_peer_finished:
cur_len = cur_len + 1
continue # don't waste resources running the code we don't need
continue

if output_scores:
processed_score = torch.zeros_like(outputs.logits[:, -1, :])
@@ -3782,12 +3792,6 @@ def _group_beam_search(

input_ids = torch.cat([input_ids, current_tokens.unsqueeze(-1)], dim=-1)

model_kwargs = self._update_model_kwargs_for_generation(
outputs,
model_kwargs,
is_encoder_decoder=self.config.is_encoder_decoder,
)

# This is needed to properly delete outputs.logits which may be very large for first iteration
# Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration
# IMPORTANT: Note that this should appear BEFORE the call to _reorder_cache() to save the maximum memory
@@ -3948,9 +3952,15 @@ def _constrained_beam_search(

outputs = self(**model_inputs, return_dict=True)

# synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
model_kwargs = self._update_model_kwargs_for_generation(
outputs,
model_kwargs,
is_encoder_decoder=self.config.is_encoder_decoder,
)
if synced_gpus and this_peer_finished:
cur_len = cur_len + 1
continue # don't waste resources running the code we don't need
continue

# Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
# (the clone itself is always small)
@@ -4018,11 +4028,6 @@ def _constrained_beam_search(
beam_idx = beam_outputs["next_beam_indices"]

input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
model_kwargs = self._update_model_kwargs_for_generation(
outputs,
model_kwargs,
is_encoder_decoder=self.config.is_encoder_decoder,
)

# This is needed to properly delete outputs.logits which may be very large for first iteration
# Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration
@@ -4162,17 +4167,8 @@ def _assisted_decoding(
unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)

# This is needed if return_dict_in_generate is True
start_from_empty_dynamic_cache = False
past_key_values = model_kwargs.get("past_key_values", None)
if isinstance(past_key_values, DynamicCache) or (
isinstance(past_key_values, EncoderDecoderCache)
and isinstance(past_key_values.self_attention_cache, DynamicCache)
):
if past_key_values.get_seq_length() == 0:
start_from_empty_dynamic_cache = True

this_peer_finished = False
is_first_iteration = True # to preserve the same API in the output as other generation methods
while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
cur_len = input_ids.shape[-1]

@@ -4271,63 +4267,59 @@ def _assisted_decoding(
# 5. Update the candidate generation strategy if needed
candidate_generator.update_candidate_strategy(input_ids, new_logits, n_matches)

# synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
model_kwargs = self._update_model_kwargs_for_generation(
outputs,
model_kwargs,
is_encoder_decoder=self.config.is_encoder_decoder,
num_new_tokens=n_matches + 1,
)
if synced_gpus and this_peer_finished:
continue # don't waste resources running the code we don't need
continue

# Store scores, attentions and hidden_states when required
# Assistant: modified to append one tuple element per token, as in the other generation methods.
if return_dict_in_generate:
newly_added_length = n_matches + 1
if output_scores:
scores += tuple(new_logits[:, i, :] for i in range(n_matches + 1))
scores += tuple(new_logits[:, i, :] for i in range(newly_added_length))
if output_logits:
raw_logits += (next_token_logits,)

if "past_key_values" not in model_kwargs or start_from_empty_dynamic_cache:
added_len = new_cur_len
# set it to false for other iterations
start_from_empty_dynamic_cache = False
else:
added_len = n_matches + 1
raw_logits += tuple(next_token_logits[:, i, :] for i in range(newly_added_length))

newly_added_length = new_cur_len if is_first_iteration else newly_added_length
if output_attentions:
if self.config.is_encoder_decoder:
cross_attentions = _split_model_outputs(
cross_attentions, outputs.cross_attentions, cur_len, added_len
cross_attentions, outputs.cross_attentions, cur_len, newly_added_length
)
decoder_attentions = _split_model_outputs(
decoder_attentions,
outputs.decoder_attentions,
cur_len,
added_len,
newly_added_length,
is_decoder_attention=True,
)
else:
decoder_attentions = _split_model_outputs(
decoder_attentions,
outputs.attentions,
cur_len,
added_len,
newly_added_length,
is_decoder_attention=True,
)
if output_hidden_states:
if self.config.is_encoder_decoder:
decoder_hidden_states = _split_model_outputs(
decoder_hidden_states, outputs.decoder_hidden_states, cur_len, added_len
decoder_hidden_states, outputs.decoder_hidden_states, cur_len, newly_added_length
)
else:
decoder_hidden_states = _split_model_outputs(
decoder_hidden_states, outputs.hidden_states, cur_len, added_len
decoder_hidden_states, outputs.hidden_states, cur_len, newly_added_length
)

model_kwargs = self._update_model_kwargs_for_generation(
outputs,
model_kwargs,
is_encoder_decoder=self.config.is_encoder_decoder,
num_new_tokens=n_matches + 1,
)

unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
this_peer_finished = unfinished_sequences.max() == 0
is_first_iteration = False

if streamer is not None:
streamer.end()
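Every hunk above applies the same pattern: the call to self._update_model_kwargs_for_generation(...) is moved from after the sampling/selection code to immediately after the forward pass, before the `if synced_gpus and this_peer_finished: continue` early exit, so a rank whose sequences are already finished (and which only runs dummy forward passes to keep collective ops in sync) still advances its cache bookkeeping in lockstep with the active ranks. Below is a minimal, self-contained sketch of that control flow with toy stand-ins; toy_model, update_model_kwargs, and generate_sketch are illustrative names, not transformers APIs, and cache_position is reduced to a plain counter.

import torch


def toy_model(input_ids: torch.Tensor) -> torch.Tensor:
    # Stand-in for a causal LM forward pass: uniform logits over a tiny vocabulary.
    batch_size, seq_len = input_ids.shape
    return torch.zeros(batch_size, seq_len, 8)


def update_model_kwargs(model_kwargs: dict) -> dict:
    # Stand-in for _update_model_kwargs_for_generation: advance the cache bookkeeping.
    model_kwargs["cache_position"] = model_kwargs["cache_position"] + 1
    return model_kwargs


def generate_sketch(input_ids, max_new_tokens=5, synced_gpus=True, this_peer_finished=False):
    model_kwargs = {"cache_position": input_ids.shape[1] - 1}
    for _ in range(max_new_tokens):
        logits = toy_model(input_ids)

        # The fix illustrated: kwargs are updated unconditionally, BEFORE the
        # synced_gpus skip, so a finished rank's cache state keeps advancing.
        model_kwargs = update_model_kwargs(model_kwargs)
        if synced_gpus and this_peer_finished:
            continue  # dummy forward pass only: keep collective ops in sync

        next_tokens = logits[:, -1, :].argmax(dim=-1, keepdim=True)
        input_ids = torch.cat([input_ids, next_tokens], dim=-1)
    return input_ids, model_kwargs


ids, kwargs = generate_sketch(torch.ones(1, 3, dtype=torch.long))
print(ids.shape, kwargs["cache_position"])  # cache_position advanced on every iteration

In the real loops the update also refreshes entries such as past_key_values and the attention mask, and the hunk in prepare_inputs_for_generation additionally handles Exception 3 from the diff: on those dummy passes cache_position may run past the end of input_ids, in which case only the trailing dummy tokens are kept.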
