diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index 2225b033aa0a9e..68b8b598ec0978 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -379,9 +379,10 @@ def prepare_inputs_for_generation(
         # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
         # Exception 1: when passing input_embeds, input_ids may be missing entries
         # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
+        # Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case
         if past_key_values is not None:
             model_inputs["past_key_values"] = past_key_values
-            if inputs_embeds is not None:  # Exception 1
+            if inputs_embeds is not None or cache_position[-1] >= input_ids.shape[1]:  # Exception 1 or Exception 3
                 input_ids = input_ids[:, -cache_position.shape[0] :]
             elif input_ids.shape[1] != cache_position.shape[0]:  # Default case (the "else", a no op, is Exception 2)
                 input_ids = input_ids[:, cache_position]
@@ -2609,8 +2610,14 @@ def _dola_decoding(
                     outputs.hidden_states[candidate_premature_layer][:, -1, :]
                 ).to(final_logits.device)
 
+            # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
+            model_kwargs = self._update_model_kwargs_for_generation(
+                outputs,
+                model_kwargs,
+                is_encoder_decoder=self.config.is_encoder_decoder,
+            )
             if synced_gpus and this_peer_finished:
-                continue  # don't waste resources running the code we don't need
+                continue
 
             next_token_logits = _dola_select_contrast(
                 candidate_premature_layers, candidate_premature_logits, final_logits
@@ -2652,11 +2659,6 @@ def _dola_decoding(
             input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
             if streamer is not None:
                 streamer.put(next_tokens.cpu())
-            model_kwargs = self._update_model_kwargs_for_generation(
-                outputs,
-                model_kwargs,
-                is_encoder_decoder=self.config.is_encoder_decoder,
-            )
 
             # stop when each sentence is finished
             unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
@@ -3016,8 +3018,14 @@ def _contrastive_search(
                 )
             # contrastive_search main logic end
 
+            # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
+            model_kwargs = self._update_model_kwargs_for_generation(
+                outputs,
+                model_kwargs,
+                is_encoder_decoder=self.config.is_encoder_decoder,
+            )
             if synced_gpus and this_peer_finished:
-                continue  # don't waste resources running the code we don't need
+                continue
 
             # finished sentences should have their next token be a padding token
             if has_eos_stopping_criteria:
@@ -3027,11 +3035,6 @@ def _contrastive_search(
             input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
             if streamer is not None:
                 streamer.put(next_tokens.cpu())
-            model_kwargs = self._update_model_kwargs_for_generation(
-                outputs,
-                model_kwargs,
-                is_encoder_decoder=self.config.is_encoder_decoder,
-            )
 
             # stop when each sentence is finished
             unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
@@ -3168,8 +3171,14 @@ def _sample(
             # forward pass to get next token
             outputs = self(**model_inputs, return_dict=True)
 
+            # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
+            model_kwargs = self._update_model_kwargs_for_generation(
+                outputs,
+                model_kwargs,
+                is_encoder_decoder=self.config.is_encoder_decoder,
+            )
             if synced_gpus and this_peer_finished:
-                continue  # don't waste resources running the code we don't need
+                continue
 
             # Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
             # (the clone itself is always small)
@@ -3214,11 +3223,6 @@ def _sample(
             input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
             if streamer is not None:
                 streamer.put(next_tokens.cpu())
-            model_kwargs = self._update_model_kwargs_for_generation(
-                outputs,
-                model_kwargs,
-                is_encoder_decoder=self.config.is_encoder_decoder,
-            )
 
             unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
             this_peer_finished = unfinished_sequences.max() == 0
@@ -3415,9 +3419,15 @@ def _beam_search(
             else:  # Unchanged original behavior
                 outputs = self(**model_inputs, return_dict=True)
 
+            # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
+            model_kwargs = self._update_model_kwargs_for_generation(
+                outputs,
+                model_kwargs,
+                is_encoder_decoder=self.config.is_encoder_decoder,
+            )
             if synced_gpus and this_peer_finished:
                 cur_len = cur_len + 1
-                continue  # don't waste resources running the code we don't need
+                continue
 
             # Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
             # (the clone itself is always small)
@@ -3491,12 +3501,6 @@ def _beam_search(
 
             input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
 
-            model_kwargs = self._update_model_kwargs_for_generation(
-                outputs,
-                model_kwargs,
-                is_encoder_decoder=self.config.is_encoder_decoder,
-            )
-
             # This is needed to properly delete outputs.logits which may be very large for first iteration
             # Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration
             # IMPORTANT: Note that this should appear BEFORE the call to _reorder_cache() to save the maximum memory
@@ -3670,9 +3674,15 @@ def _group_beam_search(
 
             outputs = self(**model_inputs, return_dict=True)
 
+            # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
+            model_kwargs = self._update_model_kwargs_for_generation(
+                outputs,
+                model_kwargs,
+                is_encoder_decoder=self.config.is_encoder_decoder,
+            )
             if synced_gpus and this_peer_finished:
                 cur_len = cur_len + 1
-                continue  # don't waste resources running the code we don't need
+                continue
 
             if output_scores:
                 processed_score = torch.zeros_like(outputs.logits[:, -1, :])
@@ -3782,12 +3792,6 @@ def _group_beam_search(
 
             input_ids = torch.cat([input_ids, current_tokens.unsqueeze(-1)], dim=-1)
 
-            model_kwargs = self._update_model_kwargs_for_generation(
-                outputs,
-                model_kwargs,
-                is_encoder_decoder=self.config.is_encoder_decoder,
-            )
-
             # This is needed to properly delete outputs.logits which may be very large for first iteration
             # Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration
             # IMPORTANT: Note that this should appear BEFORE the call to _reorder_cache() to save the maximum memory
@@ -3948,9 +3952,15 @@ def _constrained_beam_search(
 
             outputs = self(**model_inputs, return_dict=True)
 
+            # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
+            model_kwargs = self._update_model_kwargs_for_generation(
+                outputs,
+                model_kwargs,
+                is_encoder_decoder=self.config.is_encoder_decoder,
+            )
             if synced_gpus and this_peer_finished:
                 cur_len = cur_len + 1
-                continue  # don't waste resources running the code we don't need
+                continue
 
             # Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
             # (the clone itself is always small)
@@ -4018,11 +4028,6 @@ def _constrained_beam_search(
             beam_idx = beam_outputs["next_beam_indices"]
 
             input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
-            model_kwargs = self._update_model_kwargs_for_generation(
-                outputs,
-                model_kwargs,
-                is_encoder_decoder=self.config.is_encoder_decoder,
-            )
 
             # This is needed to properly delete outputs.logits which may be very large for first iteration
             # Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration
@@ -4162,17 +4167,8 @@ def _assisted_decoding(
         unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
         model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
 
-        # This is needed if return_dict_in_generate is True
-        start_from_empty_dynamic_cache = False
-        past_key_values = model_kwargs.get("past_key_values", None)
-        if isinstance(past_key_values, DynamicCache) or (
-            isinstance(past_key_values, EncoderDecoderCache)
-            and isinstance(past_key_values.self_attention_cache, DynamicCache)
-        ):
-            if past_key_values.get_seq_length() == 0:
-                start_from_empty_dynamic_cache = True
-
         this_peer_finished = False
+        is_first_iteration = True  # to preserve the same API in the output as other generation methods
         while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
             cur_len = input_ids.shape[-1]
 
@@ -4271,34 +4267,36 @@ def _assisted_decoding(
             # 5. Update the candidate generation strategy if needed
             candidate_generator.update_candidate_strategy(input_ids, new_logits, n_matches)
 
+            # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
+            model_kwargs = self._update_model_kwargs_for_generation(
+                outputs,
+                model_kwargs,
+                is_encoder_decoder=self.config.is_encoder_decoder,
+                num_new_tokens=n_matches + 1,
+            )
             if synced_gpus and this_peer_finished:
-                continue  # don't waste resources running the code we don't need
+                continue
 
             # Store scores, attentions and hidden_states when required
             # Assistant: modified to append one tuple element per token, as in the other generation methods.
             if return_dict_in_generate:
+                newly_added_length = n_matches + 1
                 if output_scores:
-                    scores += tuple(new_logits[:, i, :] for i in range(n_matches + 1))
+                    scores += tuple(new_logits[:, i, :] for i in range(newly_added_length))
                 if output_logits:
-                    raw_logits += (next_token_logits,)
-
-                if "past_key_values" not in model_kwargs or start_from_empty_dynamic_cache:
-                    added_len = new_cur_len
-                    # set it to false for other iterations
-                    start_from_empty_dynamic_cache = False
-                else:
-                    added_len = n_matches + 1
+                    raw_logits += tuple(next_token_logits[:, i, :] for i in range(newly_added_length))
 
+                newly_added_length = new_cur_len if is_first_iteration else newly_added_length
                 if output_attentions:
                     if self.config.is_encoder_decoder:
                         cross_attentions = _split_model_outputs(
-                            cross_attentions, outputs.cross_attentions, cur_len, added_len
+                            cross_attentions, outputs.cross_attentions, cur_len, newly_added_length
                         )
                         decoder_attentions = _split_model_outputs(
                             decoder_attentions,
                             outputs.decoder_attentions,
                             cur_len,
-                            added_len,
+                            newly_added_length,
                             is_decoder_attention=True,
                         )
                     else:
@@ -4306,28 +4304,22 @@ def _assisted_decoding(
                         decoder_attentions = _split_model_outputs(
                             decoder_attentions,
                             outputs.attentions,
                             cur_len,
-                            added_len,
+                            newly_added_length,
                             is_decoder_attention=True,
                         )
                 if output_hidden_states:
                     if self.config.is_encoder_decoder:
                         decoder_hidden_states = _split_model_outputs(
-                            decoder_hidden_states, outputs.decoder_hidden_states, cur_len, added_len
+                            decoder_hidden_states, outputs.decoder_hidden_states, cur_len, newly_added_length
                         )
                     else:
                         decoder_hidden_states = _split_model_outputs(
-                            decoder_hidden_states, outputs.hidden_states, cur_len, added_len
+                            decoder_hidden_states, outputs.hidden_states, cur_len, newly_added_length
                         )
-            model_kwargs = self._update_model_kwargs_for_generation(
-                outputs,
-                model_kwargs,
-                is_encoder_decoder=self.config.is_encoder_decoder,
-                num_new_tokens=n_matches + 1,
-            )
-
             unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
             this_peer_finished = unfinished_sequences.max() == 0
+            is_first_iteration = False
 
         if streamer is not None:
             streamer.end()
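
The hunks above all follow the same pattern: `_update_model_kwargs_for_generation` now runs before the `synced_gpus and this_peer_finished` early `continue`, and `prepare_inputs_for_generation` gains Exception 3 so a peer whose `cache_position` has moved past its own `input_ids` can still be fed a dummy token. Below is a minimal, hypothetical sketch of why the kwargs update must be unconditional; the names (`run_peer`, an integer `cache_position`) are illustrative stand-ins and not part of the patch or the Transformers API.

```python
# Hypothetical, self-contained illustration of the invariant enforced by this diff:
# with synced_gpus, a peer whose own sequences are finished must still update its
# per-step state (the analogue of model_kwargs / cache_position) BEFORE skipping
# the rest of the loop body, so its dummy forward passes keep advancing in lockstep.
def run_peer(total_steps: int, finishes_at: int, synced_gpus: bool = True) -> list:
    cache_position = 0          # stand-in for model_kwargs["cache_position"]
    this_peer_finished = False
    positions_seen = []

    for step in range(total_steps):
        positions_seen.append(cache_position)  # the "forward pass" consumes the current position

        # unconditional state update, mirroring the relocated _update_model_kwargs_for_generation call
        cache_position += 1
        if synced_gpus and this_peer_finished:
            continue  # dummy step: skip sampling and stopping checks, but stay in sync

        if step + 1 >= finishes_at:
            this_peer_finished = True

    return positions_seen


# A peer that finishes after 2 steps still walks through positions 0..4,
# matching peers that keep generating for all 5 steps.
assert run_peer(total_steps=5, finishes_at=2) == [0, 1, 2, 3, 4]
```

If the update only ran after the skip, the finished peer would keep replaying a stale position on every dummy step, which is exactly the drift the relocated call and Exception 3 are meant to prevent.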