diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 3b1bef6f040084..49b213cc5e92cd 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1297,6 +1297,43 @@ def _validate_generated_length(self, generation_config, input_ids_length, has_de UserWarning, ) + def _extend_attention_mask(self, model_kwargs: Dict[str, Any], new_mask_length: int) -> Dict[str, Any]: + if self.config.is_encoder_decoder: + key = "decoder_attention_mask" + else: + key = "attention_mask" + + if key not in model_kwargs: + return model_kwargs + + mask = model_kwargs[key] + mask_extension_length = new_mask_length - mask.shape[1] + + if mask_extension_length < 0: + raise ValueError("Cannot extend attention mask to a length less than it already is") + + model_kwargs[key] = torch.cat( + [mask, mask.new_ones((mask.shape[0], mask_extension_length))], + dim=-1, + ) + + return model_kwargs + + def _extend_token_type_ids(self, model_kwargs: Dict[str, Any], new_length: int) -> Dict[str, Any]: + if "token_type_ids" not in model_kwargs or model_kwargs["token_type_ids"] is None: + return model_kwargs + + token_type_ids = model_kwargs["token_type_ids"] + final_token_type = token_type_ids[:, -1].unsqueeze(-1) + extension_length = new_length - token_type_ids.shape[1] + token_type_copies = final_token_type.repeat(1, extension_length) + model_kwargs["token_type_ids"] = torch.cat( + [model_kwargs["token_type_ids"], token_type_copies], + dim=-1, + ) + + return model_kwargs + @torch.no_grad() def generate( self, @@ -4441,47 +4478,21 @@ def assisted_decoding( # `candidate_length + 1` relevant logits from this process: in the event that all candidates are correct, # we use this forward pass to also pick the subsequent logits in the original model. - # 2.1. Run a forward pass on the candidate sequence - if "past_key_values" in model_kwargs: - model_attn = torch.ones_like(candidate_input_ids) - model_input_ids = candidate_input_ids[:, -candidate_length - 1 :] - if self.config.is_encoder_decoder: - outputs = self( - decoder_input_ids=model_input_ids, - decoder_attention_mask=model_attn, - past_key_values=model_kwargs["past_key_values"], - encoder_outputs=model_kwargs["encoder_outputs"], - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - use_cache=True, - ) - else: - outputs = self( - model_input_ids, - attention_mask=model_attn, - past_key_values=model_kwargs["past_key_values"], - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - use_cache=True, - ) - else: - if self.config.is_encoder_decoder: - outputs = self( - decoder_input_ids=candidate_input_ids, - encoder_outputs=model_kwargs["encoder_outputs"], - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - use_cache=True, - ) - else: - outputs = self( - candidate_input_ids, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - use_cache=True, - ) + # 2.1. Prepare the model inputs + candidate_kwargs = copy.copy(model_kwargs) + candidate_kwargs = self._extend_attention_mask(candidate_kwargs, candidate_input_ids.shape[1]) + candidate_kwargs = self._extend_token_type_ids(candidate_kwargs, candidate_input_ids.shape[1]) + + model_inputs = self.prepare_inputs_for_generation(candidate_input_ids, **candidate_kwargs) + + # 2.2. 
Run a forward pass on the candidate sequence + outputs = self( + **model_inputs, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) - # 2.2. Process the new logits + # 2.3. Process the new logits new_logits = outputs.logits[:, -candidate_length - 1 :] # excludes the input prompt if present if len(logits_processor) > 0: for i in range(candidate_length):
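The two helpers added above are the core of the `generation/utils.py` change: during assisted decoding the main model now receives several candidate tokens at once, so `attention_mask`/`token_type_ids` must be stretched to the candidate length before `prepare_inputs_for_generation` is called. A minimal standalone sketch of the mask-extension step (names and shapes are illustrative, not the library code):

```python
import torch

def extend_attention_mask(mask: torch.Tensor, new_length: int) -> torch.Tensor:
    # mask: (batch, current_length), 1 = real token. Candidate tokens are real
    # tokens too, so the extension is all ones, mirroring _extend_attention_mask.
    extension_length = new_length - mask.shape[1]
    if extension_length < 0:
        raise ValueError("Cannot extend attention mask to a length less than it already is")
    return torch.cat([mask, mask.new_ones((mask.shape[0], extension_length))], dim=-1)

mask = torch.ones(1, 5, dtype=torch.long)  # 5 tokens already in the sequence
extended = extend_attention_mask(mask, 8)  # 3 candidate tokens appended
print(extended.shape)                      # torch.Size([1, 8])
```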
diff --git a/src/transformers/models/bark/modeling_bark.py b/src/transformers/models/bark/modeling_bark.py index bdafb6347755d3..649719e0eefa5d 100644 --- a/src/transformers/models/bark/modeling_bark.py +++ b/src/transformers/models/bark/modeling_bark.py @@ -483,9 +483,18 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwarg position_ids = kwargs.get("position_ids", None) if past_key_values is not None: - # only last token for inputs_ids if past is defined in kwargs + # Omit tokens covered by past_key_values seq_len = input_ids.shape[1] - input_ids = input_ids[:, [-1]] + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] # input_embeds have already been used and is not required anymore input_embeds = None @@ -507,7 +516,7 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwarg position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) if past_key_values: - position_ids = position_ids[:, -1].unsqueeze(-1) + position_ids = position_ids[:, -input_ids.shape[1] :] else: position_ids = None diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index 52dfa5e39229f8..9e7763ca23d885 100755 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -1443,7 +1443,16 @@ def prepare_inputs_for_generation( ): # cut decoder_input_ids if past_key_values is used if past_key_values is not None: - decoder_input_ids = decoder_input_ids[:, -1:] + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if decoder_input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = decoder_input_ids.shape[1] - 1 + + decoder_input_ids = decoder_input_ids[:, remove_prefix_length:] return { "input_ids": None, # encoder_outputs is defined. input_ids not needed @@ -1934,7 +1943,16 @@ def prepare_inputs_for_generation( attention_mask = input_ids.new_ones(input_ids.shape) if past_key_values: - input_ids = input_ids[:, -1:] + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] # first step, decoder_cached_states are empty return { "input_ids": input_ids, # encoder_outputs is defined. input_ids not needed diff --git a/src/transformers/models/bert/modeling_bert.py b/src/transformers/models/bert/modeling_bert.py index 29846b8051f867..1b0fad3f9d6546 100755 --- a/src/transformers/models/bert/modeling_bert.py +++ b/src/transformers/models/bert/modeling_bert.py @@ -1282,7 +1282,16 @@ def prepare_inputs_for_generation( # cut decoder_input_ids if past_key_values is used if past_key_values is not None: - input_ids = input_ids[:, -1:] + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] return { "input_ids": input_ids, diff --git a/src/transformers/models/bert_generation/modeling_bert_generation.py b/src/transformers/models/bert_generation/modeling_bert_generation.py index f245ac155e75ca..abe2d828b28bb9 100755 --- a/src/transformers/models/bert_generation/modeling_bert_generation.py +++ b/src/transformers/models/bert_generation/modeling_bert_generation.py @@ -993,9 +993,18 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attenti if attention_mask is None: attention_mask = input_ids.new_ones(input_shape) - # cut decoder_input_ids if past is used + # cut decoder_input_ids if past_key_values is used if past_key_values is not None: - input_ids = input_ids[:, -1:] + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values} diff --git a/src/transformers/models/big_bird/modeling_big_bird.py b/src/transformers/models/big_bird/modeling_big_bird.py index 867aca67e99e8c..e266b1a67b7d41 100755 --- a/src/transformers/models/big_bird/modeling_big_bird.py +++ b/src/transformers/models/big_bird/modeling_big_bird.py @@ -2628,9 +2628,18 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attenti if attention_mask is None: attention_mask = input_ids.new_ones(input_shape) - # cut decoder_input_ids if past is used + # cut decoder_input_ids if past_key_values is used if past_key_values is not None: - input_ids = input_ids[:, -1:] + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values} diff --git a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py index a32f3ecde76fdb..4e279f9dc059fe 100755 --- a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +++ b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py @@ -2627,7 +2627,16 @@ def prepare_inputs_for_generation( ): # cut decoder_input_ids if past_key_values is used if past_key_values is not None: - decoder_input_ids =
decoder_input_ids[:, -1:] + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if decoder_input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = decoder_input_ids.shape[1] - 1 + + decoder_input_ids = decoder_input_ids[:, remove_prefix_length:] return { "input_ids": None, # encoder_outputs is defined. input_ids not needed diff --git a/src/transformers/models/biogpt/modeling_biogpt.py b/src/transformers/models/biogpt/modeling_biogpt.py index 7534ed17fe849a..d1c471aa8090c9 100755 --- a/src/transformers/models/biogpt/modeling_biogpt.py +++ b/src/transformers/models/biogpt/modeling_biogpt.py @@ -729,9 +729,18 @@ def forward( def prepare_inputs_for_generation( self, input_ids, attention_mask, inputs_embeds=None, past_key_values=None, **kwargs ): - # only last token for inputs_ids if past is defined in kwargs - if past_key_values: - input_ids = input_ids[:, -1].unsqueeze(-1) + # only last tokens for inputs_ids if past is defined in kwargs + if past_key_values is not None: + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] if inputs_embeds is not None and past_key_values is None: model_inputs = {"inputs_embeds": inputs_embeds} diff --git a/src/transformers/models/blenderbot/modeling_blenderbot.py b/src/transformers/models/blenderbot/modeling_blenderbot.py index bdb8c52a552041..1db81905210b63 100755 --- a/src/transformers/models/blenderbot/modeling_blenderbot.py +++ b/src/transformers/models/blenderbot/modeling_blenderbot.py @@ -1392,7 +1392,16 @@ def prepare_inputs_for_generation( ): # cut decoder_input_ids if past is used if past_key_values is not None: - decoder_input_ids = decoder_input_ids[:, -1:] + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if decoder_input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = decoder_input_ids.shape[1] - 1 + + decoder_input_ids = decoder_input_ids[:, remove_prefix_length:] return { "input_ids": None, # encoder_outputs is defined. input_ids not needed @@ -1622,7 +1631,16 @@ def prepare_inputs_for_generation( attention_mask = input_ids.new_ones(input_ids.shape) if past_key_values: - input_ids = input_ids[:, -1:] + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] # first step, decoder_cached_states are empty return { "input_ids": input_ids, # encoder_outputs is defined. 
input_ids not needed diff --git a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py index a1e888aec90807..129de3dd1456e3 100755 --- a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py +++ b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py @@ -1359,7 +1359,16 @@ def prepare_inputs_for_generation( ): # cut decoder_input_ids if past is used if past_key_values is not None: - decoder_input_ids = decoder_input_ids[:, -1:] + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if decoder_input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = decoder_input_ids.shape[1] - 1 + + decoder_input_ids = decoder_input_ids[:, remove_prefix_length:] return { "input_ids": None, # encoder_outputs is defined. input_ids not needed @@ -1589,7 +1598,16 @@ def prepare_inputs_for_generation( attention_mask = input_ids.new_ones(input_ids.shape) if past_key_values: - input_ids = input_ids[:, -1:] + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] # first step, decoder_cached_states are empty return { "input_ids": input_ids, # encoder_outputs is defined. input_ids not needed diff --git a/src/transformers/models/blip/modeling_blip_text.py b/src/transformers/models/blip/modeling_blip_text.py index 2ae3ac053beab9..49b958afc2ebae 100644 --- a/src/transformers/models/blip/modeling_blip_text.py +++ b/src/transformers/models/blip/modeling_blip_text.py @@ -920,7 +920,16 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attenti # cut decoder_input_ids if past_key_values is used if past_key_values is not None: - input_ids = input_ids[:, -1:] + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] return { "input_ids": input_ids, diff --git a/src/transformers/models/bloom/modeling_bloom.py b/src/transformers/models/bloom/modeling_bloom.py index d12ec1724f7097..d90bb6ad8fdfd5 100644 --- a/src/transformers/models/bloom/modeling_bloom.py +++ b/src/transformers/models/bloom/modeling_bloom.py @@ -844,9 +844,18 @@ def prepare_inputs_for_generation( inputs_embeds: Optional[torch.Tensor] = None, **kwargs, ) -> dict: - # only last token for input_ids if past is not None - if past_key_values: - input_ids = input_ids[:, -1].unsqueeze(-1) + # only last tokens for input_ids if past is not None + if past_key_values is not None: + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] # the cache may be in the standard format (e.g. 
in contrastive search), convert to bloom's format if needed if past_key_values[0][0].shape[0] == input_ids.shape[0]: diff --git a/src/transformers/models/camembert/modeling_camembert.py b/src/transformers/models/camembert/modeling_camembert.py index 4635c061980b53..8d7d279579e3e9 100644 --- a/src/transformers/models/camembert/modeling_camembert.py +++ b/src/transformers/models/camembert/modeling_camembert.py @@ -1542,9 +1542,18 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attenti if attention_mask is None: attention_mask = input_ids.new_ones(input_shape) - # cut decoder_input_ids if past is used + # cut decoder_input_ids if past_key_values is used if past_key_values is not None: - input_ids = input_ids[:, -1:] + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values} diff --git a/src/transformers/models/codegen/modeling_codegen.py b/src/transformers/models/codegen/modeling_codegen.py index 93d5aa7ee47650..172a45544bac0d 100644 --- a/src/transformers/models/codegen/modeling_codegen.py +++ b/src/transformers/models/codegen/modeling_codegen.py @@ -617,11 +617,20 @@ def set_output_embeddings(self, new_embeddings): def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs): token_type_ids = kwargs.get("token_type_ids", None) - # only last token for inputs_ids if past is defined in kwargs + # Omit tokens covered by past_key_values if past_key_values: - input_ids = input_ids[:, -1].unsqueeze(-1) + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] if token_type_ids is not None: - token_type_ids = token_type_ids[:, -1].unsqueeze(-1) + token_type_ids = token_type_ids[:, -input_ids.shape[1] :] attention_mask = kwargs.get("attention_mask", None) position_ids = kwargs.get("position_ids", None) @@ -631,7 +640,7 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwarg position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) if past_key_values: - position_ids = position_ids[:, -1].unsqueeze(-1) + position_ids = position_ids[:, -input_ids.shape[1] :] return { "input_ids": input_ids, diff --git a/src/transformers/models/ctrl/modeling_ctrl.py b/src/transformers/models/ctrl/modeling_ctrl.py index 70cd4ec0597a14..cec68de07dda75 100644 --- a/src/transformers/models/ctrl/modeling_ctrl.py +++ b/src/transformers/models/ctrl/modeling_ctrl.py @@ -526,9 +526,18 @@ def set_output_embeddings(self, new_embeddings): self.lm_head = new_embeddings def prepare_inputs_for_generation(self, input_ids, past_key_values=None, use_cache=None, **kwargs): - # only last token for inputs_ids if past is defined in kwargs - if past_key_values: - input_ids = input_ids[:, -1].unsqueeze(-1) + # only last tokens for inputs_ids if past is defined in kwargs + if past_key_values is not None: + past_length = past_key_values[0][0].shape[2] + + # Some 
generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] return {"input_ids": input_ids, "past_key_values": past_key_values, "use_cache": use_cache} diff --git a/src/transformers/models/data2vec/modeling_data2vec_text.py b/src/transformers/models/data2vec/modeling_data2vec_text.py index 7cbaee692564b4..a521ccb39aaf0c 100644 --- a/src/transformers/models/data2vec/modeling_data2vec_text.py +++ b/src/transformers/models/data2vec/modeling_data2vec_text.py @@ -1009,9 +1009,18 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attenti if attention_mask is None: attention_mask = input_ids.new_ones(input_shape) - # cut decoder_input_ids if past is used + # cut decoder_input_ids if past_key_values is used if past_key_values is not None: - input_ids = input_ids[:, -1:] + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values} diff --git a/src/transformers/models/deprecated/open_llama/modeling_open_llama.py b/src/transformers/models/deprecated/open_llama/modeling_open_llama.py index c975aa40877c26..6853f5333f137c 100644 --- a/src/transformers/models/deprecated/open_llama/modeling_open_llama.py +++ b/src/transformers/models/deprecated/open_llama/modeling_open_llama.py @@ -843,8 +843,17 @@ def forward( def prepare_inputs_for_generation( self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs ): - if past_key_values: - input_ids = input_ids[:, -1:] + if past_key_values is not None: + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] position_ids = kwargs.get("position_ids", None) if attention_mask is not None and position_ids is None: @@ -852,7 +861,7 @@ def prepare_inputs_for_generation( position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) if past_key_values: - position_ids = position_ids[:, -1].unsqueeze(-1) + position_ids = position_ids[:, -input_ids.shape[1] :] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_key_values is None: diff --git a/src/transformers/models/electra/modeling_electra.py b/src/transformers/models/electra/modeling_electra.py index c06d306c1a241d..da3ee8e51d3602 100644 --- a/src/transformers/models/electra/modeling_electra.py +++ b/src/transformers/models/electra/modeling_electra.py @@ -1667,9 +1667,18 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attenti if attention_mask is None: attention_mask = input_ids.new_ones(input_shape) - # cut decoder_input_ids if past is used + # cut decoder_input_ids if past_key_values is used if past_key_values is not None: - input_ids = 
input_ids[:, -1:] + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values} diff --git a/src/transformers/models/ernie/modeling_ernie.py b/src/transformers/models/ernie/modeling_ernie.py index 7ee6f4381290ae..d55155f80093bc 100644 --- a/src/transformers/models/ernie/modeling_ernie.py +++ b/src/transformers/models/ernie/modeling_ernie.py @@ -1223,7 +1223,16 @@ def prepare_inputs_for_generation( # cut decoder_input_ids if past_key_values is used if past_key_values is not None: - input_ids = input_ids[:, -1:] + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] return { "input_ids": input_ids, diff --git a/src/transformers/models/falcon/modeling_falcon.py b/src/transformers/models/falcon/modeling_falcon.py index ab29322613bea3..33b9fdde739e58 100644 --- a/src/transformers/models/falcon/modeling_falcon.py +++ b/src/transformers/models/falcon/modeling_falcon.py @@ -1228,7 +1228,16 @@ def prepare_inputs_for_generation( **kwargs, ) -> dict: if past_key_values is not None: - input_ids = input_ids[:, -1:] + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] # Note: versions of Falcon with alibi do not use position_ids. It is used with RoPE. 
if not self.transformer.use_alibi and attention_mask is not None and position_ids is None: @@ -1236,7 +1245,7 @@ def prepare_inputs_for_generation( position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) if past_key_values: - position_ids = position_ids[:, -1].unsqueeze(-1) + position_ids = position_ids[:, -input_ids.shape[1] :] return { "input_ids": input_ids, diff --git a/src/transformers/models/gpt2/modeling_gpt2.py b/src/transformers/models/gpt2/modeling_gpt2.py index 714f0351b3e4df..838e7ca2992520 100644 --- a/src/transformers/models/gpt2/modeling_gpt2.py +++ b/src/transformers/models/gpt2/modeling_gpt2.py @@ -1005,11 +1005,20 @@ def set_output_embeddings(self, new_embeddings): def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs): token_type_ids = kwargs.get("token_type_ids", None) - # only last token for inputs_ids if past is defined in kwargs + # Omit tokens covered by past_key_values if past_key_values: - input_ids = input_ids[:, -1].unsqueeze(-1) + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] if token_type_ids is not None: - token_type_ids = token_type_ids[:, -1].unsqueeze(-1) + token_type_ids = token_type_ids[:, -input_ids.shape[1] :] attention_mask = kwargs.get("attention_mask", None) position_ids = kwargs.get("position_ids", None) @@ -1019,7 +1028,7 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_ position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) if past_key_values: - position_ids = position_ids[:, -1].unsqueeze(-1) + position_ids = position_ids[:, -input_ids.shape[1] :] else: position_ids = None @@ -1038,6 +1047,7 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_ "token_type_ids": token_type_ids, } ) + return model_inputs @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING) @@ -1201,11 +1211,20 @@ def set_output_embeddings(self, new_embeddings): def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs): token_type_ids = kwargs.get("token_type_ids", None) - # only last token for inputs_ids if past is defined in kwargs + # Omit tokens covered by past_key_values if past_key_values: - input_ids = input_ids[:, -1].unsqueeze(-1) + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] if token_type_ids is not None: - token_type_ids = token_type_ids[:, -1].unsqueeze(-1) + token_type_ids = token_type_ids[:, -input_ids.shape[1] :] attention_mask = kwargs.get("attention_mask", None) position_ids = kwargs.get("position_ids", None) @@ -1215,7 +1234,7 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwarg position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) if past_key_values: - position_ids = position_ids[:, -1].unsqueeze(-1) + position_ids = position_ids[:, -input_ids.shape[1] :] else: 
position_ids = None diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py index d58e00af1dac13..be90f61e45bf1b 100644 --- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py @@ -737,11 +737,23 @@ def set_output_embeddings(self, new_embeddings): def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs): token_type_ids = kwargs.get("token_type_ids", None) - # only last token for inputs_ids if past is defined in kwargs + # Omit tokens covered by past_key_values if past_key_values: - input_ids = input_ids[:, -1].unsqueeze(-1) + if self.config.multi_query: + past_length = past_key_values[0].shape[1] + else: + past_length = past_key_values[0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] if token_type_ids is not None: - token_type_ids = token_type_ids[:, -1].unsqueeze(-1) + token_type_ids = token_type_ids[:, -input_ids.shape[1] :] attention_mask = kwargs.get("attention_mask", None) position_ids = kwargs.get("position_ids", None) @@ -751,7 +763,7 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_ position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) if past_key_values: - position_ids = position_ids[:, -1].unsqueeze(-1) + position_ids = position_ids[:, -input_ids.shape[1] :] else: position_ids = None diff --git a/src/transformers/models/gpt_neo/modeling_gpt_neo.py b/src/transformers/models/gpt_neo/modeling_gpt_neo.py index 6364cfc316220a..3ad49554c0ac8f 100755 --- a/src/transformers/models/gpt_neo/modeling_gpt_neo.py +++ b/src/transformers/models/gpt_neo/modeling_gpt_neo.py @@ -680,11 +680,20 @@ def set_output_embeddings(self, new_embeddings): def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs): token_type_ids = kwargs.get("token_type_ids", None) - # only last token for inputs_ids if past is defined in kwargs + # Omit tokens covered by past_key_values if past_key_values: - input_ids = input_ids[:, -1].unsqueeze(-1) + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] if token_type_ids is not None: - token_type_ids = token_type_ids[:, -1].unsqueeze(-1) + token_type_ids = token_type_ids[:, -input_ids.shape[1] :] attention_mask = kwargs.get("attention_mask", None) position_ids = kwargs.get("position_ids", None) @@ -694,7 +703,7 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_ position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) if past_key_values: - position_ids = position_ids[:, -1].unsqueeze(-1) + position_ids = position_ids[:, -input_ids.shape[1] :] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_key_values is None: diff --git 
a/src/transformers/models/gpt_neox/modeling_gpt_neox.py b/src/transformers/models/gpt_neox/modeling_gpt_neox.py index b4aa4154459cf7..9391805a77b851 100755 --- a/src/transformers/models/gpt_neox/modeling_gpt_neox.py +++ b/src/transformers/models/gpt_neox/modeling_gpt_neox.py @@ -808,10 +808,19 @@ def prepare_inputs_for_generation( self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs ): input_shape = input_ids.shape # cut decoder_input_ids if past is used - if past_key_values and past_key_values[0] is not None: - input_ids = input_ids[:, -1:] + if past_key_values is not None: + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] position_ids = kwargs.get("position_ids", None) if attention_mask is not None and position_ids is None: @@ -819,7 +828,7 @@ def prepare_inputs_for_generation( position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) if past_key_values: - position_ids = position_ids[:, -1].unsqueeze(-1) + position_ids = position_ids[:, -input_ids.shape[1] :] # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly if attention_mask is None: diff --git a/src/transformers/models/gptj/modeling_gptj.py b/src/transformers/models/gptj/modeling_gptj.py index a93bdeaacd9d23..6b5607f235b1a6 100644 --- a/src/transformers/models/gptj/modeling_gptj.py +++ b/src/transformers/models/gptj/modeling_gptj.py @@ -785,11 +785,20 @@ def set_output_embeddings(self, new_embeddings): def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs): token_type_ids = kwargs.get("token_type_ids", None) - # only last token for inputs_ids if past is defined in kwargs + # Omit tokens covered by past_key_values if past_key_values: - input_ids = input_ids[:, -1].unsqueeze(-1) + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] if token_type_ids is not None: - token_type_ids = token_type_ids[:, -1].unsqueeze(-1) + token_type_ids = token_type_ids[:, -input_ids.shape[1] :] attention_mask = kwargs.get("attention_mask", None) position_ids = kwargs.get("position_ids", None) @@ -799,7 +808,7 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_ position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) if past_key_values: - position_ids = position_ids[:, -1].unsqueeze(-1) + position_ids = position_ids[:, -input_ids.shape[1] :] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_key_values is None: diff --git 
a/src/transformers/models/imagegpt/modeling_imagegpt.py b/src/transformers/models/imagegpt/modeling_imagegpt.py index 5f193a137b00cc..54edcd30fc870d 100755 --- a/src/transformers/models/imagegpt/modeling_imagegpt.py +++ b/src/transformers/models/imagegpt/modeling_imagegpt.py @@ -912,11 +912,20 @@ def set_output_embeddings(self, new_embeddings): def prepare_inputs_for_generation(self, input_ids: torch.Tensor, past_key_values: Optional[bool] = None, **kwargs): token_type_ids = kwargs.get("token_type_ids", None) - # only last token for inputs_ids if past is defined in kwargs + # Omit tokens covered by past_key_values if past_key_values: - input_ids = input_ids[:, -1].unsqueeze(-1) + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] if token_type_ids is not None: - token_type_ids = token_type_ids[:, -1].unsqueeze(-1) + token_type_ids = token_type_ids[:, -input_ids.shape[1] :] attention_mask = kwargs.get("attention_mask", None) position_ids = kwargs.get("position_ids", None) @@ -926,7 +935,7 @@ def prepare_inputs_for_generation(self, input_ids: torch.Tensor, past_key_values position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) if past_key_values: - position_ids = position_ids[:, -1].unsqueeze(-1) + position_ids = position_ids[:, -input_ids.shape[1] :] else: position_ids = None return { diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 55753d5f75d9af..4afa3293ed46c5 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -1080,8 +1080,17 @@ def forward( def prepare_inputs_for_generation( self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs ): - if past_key_values: - input_ids = input_ids[:, -1:] + if past_key_values is not None: + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] position_ids = kwargs.get("position_ids", None) if attention_mask is not None and position_ids is None: @@ -1089,7 +1098,7 @@ def prepare_inputs_for_generation( position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) if past_key_values: - position_ids = position_ids[:, -1].unsqueeze(-1) + position_ids = position_ids[:, -input_ids.shape[1] :] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_key_values is None: diff --git a/src/transformers/models/longt5/modeling_longt5.py b/src/transformers/models/longt5/modeling_longt5.py index d08ed83af07ea1..4e8aef0678367d 100644 --- a/src/transformers/models/longt5/modeling_longt5.py +++ b/src/transformers/models/longt5/modeling_longt5.py @@ -2103,9 +2103,18 @@ def prepare_inputs_for_generation( encoder_outputs=None, **kwargs, ): - # cut decoder_input_ids if past is used + # cut decoder_input_ids if past_key_values is used if past_key_values is not None: - input_ids = input_ids[:, 
-1:] + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] return { "decoder_input_ids": input_ids, diff --git a/src/transformers/models/m2m_100/modeling_m2m_100.py b/src/transformers/models/m2m_100/modeling_m2m_100.py index 88e543b54b5249..6db8bbb5213b14 100755 --- a/src/transformers/models/m2m_100/modeling_m2m_100.py +++ b/src/transformers/models/m2m_100/modeling_m2m_100.py @@ -1367,7 +1367,16 @@ def prepare_inputs_for_generation( ): # cut decoder_input_ids if past is used if past_key_values is not None: - decoder_input_ids = decoder_input_ids[:, -1:] + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if decoder_input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = decoder_input_ids.shape[1] - 1 + + decoder_input_ids = decoder_input_ids[:, remove_prefix_length:] return { "input_ids": None, # encoder_outputs is defined. input_ids not needed diff --git a/src/transformers/models/marian/modeling_marian.py b/src/transformers/models/marian/modeling_marian.py index b4e3aac5be0b42..69de5b2e7d0e6f 100755 --- a/src/transformers/models/marian/modeling_marian.py +++ b/src/transformers/models/marian/modeling_marian.py @@ -1509,7 +1509,16 @@ def prepare_inputs_for_generation( ) -> Dict: # cut decoder_input_ids if past is used if past_key_values is not None: - decoder_input_ids = decoder_input_ids[:, -1:] + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if decoder_input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = decoder_input_ids.shape[1] - 1 + + decoder_input_ids = decoder_input_ids[:, remove_prefix_length:] return { "input_ids": None, # encoder_outputs is defined. input_ids not needed @@ -1740,7 +1749,16 @@ def prepare_inputs_for_generation( attention_mask = input_ids.new_ones(input_ids.shape) if past_key_values: - input_ids = input_ids[:, -1:] + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] # first step, decoder_cached_states are empty return { "input_ids": input_ids, # encoder_outputs is defined. 
input_ids not needed diff --git a/src/transformers/models/markuplm/modeling_markuplm.py b/src/transformers/models/markuplm/modeling_markuplm.py index ca6bea40337257..530c66a0c80b36 100755 --- a/src/transformers/models/markuplm/modeling_markuplm.py +++ b/src/transformers/models/markuplm/modeling_markuplm.py @@ -948,7 +948,16 @@ def prepare_inputs_for_generation( # cut decoder_input_ids if past_key_values is used if past_key_values is not None: - input_ids = input_ids[:, -1:] + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] return { "input_ids": input_ids, diff --git a/src/transformers/models/mbart/modeling_mbart.py b/src/transformers/models/mbart/modeling_mbart.py index 276f94aebdbb9e..b53ad8848dd3c2 100755 --- a/src/transformers/models/mbart/modeling_mbart.py +++ b/src/transformers/models/mbart/modeling_mbart.py @@ -1413,7 +1413,16 @@ def prepare_inputs_for_generation( ): # cut decoder_input_ids if past is used if past_key_values is not None: - decoder_input_ids = decoder_input_ids[:, -1:] + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if decoder_input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = decoder_input_ids.shape[1] - 1 + + decoder_input_ids = decoder_input_ids[:, remove_prefix_length:] return { "input_ids": None, # encoder_outputs is defined. input_ids not needed @@ -1897,7 +1906,16 @@ def prepare_inputs_for_generation( attention_mask = input_ids.new_ones(input_ids.shape) if past_key_values: - input_ids = input_ids[:, -1:] + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] # first step, decoder_cached_states are empty return { "input_ids": input_ids, # encoder_outputs is defined. 
input_ids not needed diff --git a/src/transformers/models/megatron_bert/modeling_megatron_bert.py b/src/transformers/models/megatron_bert/modeling_megatron_bert.py index 1c1eeff667d44f..5d0ad6e3410c8f 100755 --- a/src/transformers/models/megatron_bert/modeling_megatron_bert.py +++ b/src/transformers/models/megatron_bert/modeling_megatron_bert.py @@ -1251,9 +1251,18 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attenti if attention_mask is None: attention_mask = input_ids.new_ones(input_shape) - # cut decoder_input_ids if past is used + # cut decoder_input_ids if past_key_values is used if past_key_values is not None: - input_ids = input_ids[:, -1:] + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values} diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index a55f16a23d5b52..62610ceb41ac66 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -1083,8 +1083,18 @@ def forward( def prepare_inputs_for_generation( self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs ): + # Omit tokens covered by past_key_values if past_key_values: - input_ids = input_ids[:, -1:] + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] position_ids = kwargs.get("position_ids", None) if attention_mask is not None and position_ids is None: @@ -1092,7 +1102,7 @@ def prepare_inputs_for_generation( position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) if past_key_values: - position_ids = position_ids[:, -1].unsqueeze(-1) + position_ids = position_ids[:, -input_ids.shape[1] :] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_key_values is None: diff --git a/src/transformers/models/mpt/modeling_mpt.py b/src/transformers/models/mpt/modeling_mpt.py index 0c608dbd2a93bc..d760bec9854a8e 100644 --- a/src/transformers/models/mpt/modeling_mpt.py +++ b/src/transformers/models/mpt/modeling_mpt.py @@ -605,9 +605,18 @@ def prepare_inputs_for_generation( use_cache: Optional[bool] = None, **kwargs, ) -> dict: - # only last token for input_ids if past is not None - if past_key_values: - input_ids = input_ids[:, -1].unsqueeze(-1) + # only last tokens for input_ids if past is not None + if past_key_values is not None: + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if 
inputs_embeds is not None and past_key_values is None: diff --git a/src/transformers/models/mt5/modeling_mt5.py b/src/transformers/models/mt5/modeling_mt5.py index 3d03503ddd402e..0de50afe9d6d24 100644 --- a/src/transformers/models/mt5/modeling_mt5.py +++ b/src/transformers/models/mt5/modeling_mt5.py @@ -1836,9 +1836,18 @@ def prepare_inputs_for_generation( encoder_outputs=None, **kwargs, ): - # cut decoder_input_ids if past is used + # cut decoder_input_ids if past_key_values is used if past_key_values is not None: - input_ids = input_ids[:, -1:] + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] return { "decoder_input_ids": input_ids, diff --git a/src/transformers/models/musicgen/modeling_musicgen.py b/src/transformers/models/musicgen/modeling_musicgen.py index f178a6762005e6..16766e953c8574 100644 --- a/src/transformers/models/musicgen/modeling_musicgen.py +++ b/src/transformers/models/musicgen/modeling_musicgen.py @@ -1995,9 +1995,17 @@ def prepare_inputs_for_generation( if decoder_attention_mask is not None: decoder_attention_mask = decoder_attention_mask.repeat((2, 1)) - # cut decoder_input_ids if past is used if past_key_values is not None: - decoder_input_ids = decoder_input_ids[:, -1:] + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if decoder_input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = decoder_input_ids.shape[1] - 1 + + decoder_input_ids = decoder_input_ids[:, remove_prefix_length:] return { "input_ids": None, # encoder_outputs is defined. input_ids not needed diff --git a/src/transformers/models/mvp/modeling_mvp.py b/src/transformers/models/mvp/modeling_mvp.py index 21a82f95c33383..5c1ed05249ef5c 100644 --- a/src/transformers/models/mvp/modeling_mvp.py +++ b/src/transformers/models/mvp/modeling_mvp.py @@ -1572,7 +1572,16 @@ def prepare_inputs_for_generation( ): # cut decoder_input_ids if past is used if past_key_values is not None: - decoder_input_ids = decoder_input_ids[:, -1:] + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if decoder_input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = decoder_input_ids.shape[1] - 1 + + decoder_input_ids = decoder_input_ids[:, remove_prefix_length:] return { "input_ids": None, # encoder_outputs is defined. input_ids not needed @@ -2054,7 +2063,16 @@ def prepare_inputs_for_generation( attention_mask = input_ids.new_ones(input_ids.shape) if past_key_values: - input_ids = input_ids[:, -1:] + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] # first step, decoder_cached_states are empty return { "input_ids": input_ids, # encoder_outputs is defined. 
input_ids not needed diff --git a/src/transformers/models/nllb_moe/modeling_nllb_moe.py b/src/transformers/models/nllb_moe/modeling_nllb_moe.py index f37f64627dfad4..3701bbecef2e73 100644 --- a/src/transformers/models/nllb_moe/modeling_nllb_moe.py +++ b/src/transformers/models/nllb_moe/modeling_nllb_moe.py @@ -1808,7 +1808,16 @@ def prepare_inputs_for_generation( ): # cut decoder_input_ids if past is used if past_key_values is not None: - decoder_input_ids = decoder_input_ids[:, -1:] + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if decoder_input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = decoder_input_ids.shape[1] - 1 + + decoder_input_ids = decoder_input_ids[:, remove_prefix_length:] return { "input_ids": None, # encoder_outputs is defined. input_ids not needed diff --git a/src/transformers/models/opt/modeling_opt.py b/src/transformers/models/opt/modeling_opt.py index d24211f039365e..8f3f246524348d 100644 --- a/src/transformers/models/opt/modeling_opt.py +++ b/src/transformers/models/opt/modeling_opt.py @@ -981,8 +981,17 @@ def forward( def prepare_inputs_for_generation( self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs ): - if past_key_values: - input_ids = input_ids[:, -1:] + if past_key_values is not None: + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_key_values is None: diff --git a/src/transformers/models/pegasus/modeling_pegasus.py b/src/transformers/models/pegasus/modeling_pegasus.py index 67934520fbb6d9..55856f7b06b6be 100755 --- a/src/transformers/models/pegasus/modeling_pegasus.py +++ b/src/transformers/models/pegasus/modeling_pegasus.py @@ -1466,7 +1466,16 @@ def prepare_inputs_for_generation( ): # cut decoder_input_ids if past is used if past_key_values is not None: - decoder_input_ids = decoder_input_ids[:, -1:] + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if decoder_input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = decoder_input_ids.shape[1] - 1 + + decoder_input_ids = decoder_input_ids[:, remove_prefix_length:] return { "input_ids": None, # encoder_outputs is defined. input_ids not needed @@ -1719,7 +1728,16 @@ def prepare_inputs_for_generation( attention_mask = input_ids.new_ones(input_ids.shape) if past_key_values: - input_ids = input_ids[:, -1:] + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] # first step, decoder_cached_states are empty return { "input_ids": input_ids, # encoder_outputs is defined. 
input_ids not needed diff --git a/src/transformers/models/pegasus_x/modeling_pegasus_x.py b/src/transformers/models/pegasus_x/modeling_pegasus_x.py index def82bdbaa7182..e87e9c7164ab44 100755 --- a/src/transformers/models/pegasus_x/modeling_pegasus_x.py +++ b/src/transformers/models/pegasus_x/modeling_pegasus_x.py @@ -1671,7 +1671,16 @@ def prepare_inputs_for_generation( ): # cut decoder_input_ids if past is used if past_key_values is not None: - decoder_input_ids = decoder_input_ids[:, -1:] + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if decoder_input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = decoder_input_ids.shape[1] - 1 + + decoder_input_ids = decoder_input_ids[:, remove_prefix_length:] return { "input_ids": None, # encoder_outputs is defined. input_ids not needed diff --git a/src/transformers/models/persimmon/modeling_persimmon.py b/src/transformers/models/persimmon/modeling_persimmon.py index c09657c065f2be..a0bc5726382336 100644 --- a/src/transformers/models/persimmon/modeling_persimmon.py +++ b/src/transformers/models/persimmon/modeling_persimmon.py @@ -847,8 +847,17 @@ def forward( def prepare_inputs_for_generation( self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs ): - if past_key_values: - input_ids = input_ids[:, -1:] + if past_key_values is not None: + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] position_ids = kwargs.get("position_ids", None) if attention_mask is not None and position_ids is None: @@ -856,7 +865,7 @@ def prepare_inputs_for_generation( position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) if past_key_values: - position_ids = position_ids[:, -1].unsqueeze(-1) + position_ids = position_ids[:, -input_ids.shape[1] :] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_key_values is None: diff --git a/src/transformers/models/pix2struct/modeling_pix2struct.py b/src/transformers/models/pix2struct/modeling_pix2struct.py index 288e31a126e675..e19761803e267d 100644 --- a/src/transformers/models/pix2struct/modeling_pix2struct.py +++ b/src/transformers/models/pix2struct/modeling_pix2struct.py @@ -1798,9 +1798,18 @@ def prepare_inputs_for_generation( if decoder_attention_mask is None: decoder_attention_mask = torch.ones_like(input_ids).to(input_ids.device) - # cut decoder_input_ids if past is used + # cut decoder_input_ids if past_key_values is used if past_key_values is not None: - input_ids = input_ids[:, -1:] + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] return { "flattened_patches": flattened_patches, diff --git a/src/transformers/models/plbart/modeling_plbart.py b/src/transformers/models/plbart/modeling_plbart.py index 93532f4b0d8c22..3a880839236d43 100644 
diff --git a/src/transformers/models/plbart/modeling_plbart.py b/src/transformers/models/plbart/modeling_plbart.py
index 93532f4b0d8c22..3a880839236d43 100644
--- a/src/transformers/models/plbart/modeling_plbart.py
+++ b/src/transformers/models/plbart/modeling_plbart.py
@@ -1379,7 +1379,16 @@ def prepare_inputs_for_generation(
     ) -> Dict[str, Any]:
         # cut decoder_input_ids if past is used
         if past_key_values is not None:
-            decoder_input_ids = decoder_input_ids[:, -1:]
+            past_length = past_key_values[0][0].shape[2]
+
+            # Some generation methods already pass only the last input ID
+            if decoder_input_ids.shape[1] > past_length:
+                remove_prefix_length = past_length
+            else:
+                # Default to old behavior: keep only final ID
+                remove_prefix_length = decoder_input_ids.shape[1] - 1
+
+            decoder_input_ids = decoder_input_ids[:, remove_prefix_length:]

         return {
             "input_ids": None,  # encoder_outputs is defined. input_ids not needed
@@ -1739,7 +1748,16 @@ def prepare_inputs_for_generation(
             attention_mask = input_ids.new_ones(input_ids.shape)

         if past_key_values:
-            input_ids = input_ids[:, -1:]
+            past_length = past_key_values[0][0].shape[2]
+
+            # Some generation methods already pass only the last input ID
+            if input_ids.shape[1] > past_length:
+                remove_prefix_length = past_length
+            else:
+                # Default to old behavior: keep only final ID
+                remove_prefix_length = input_ids.shape[1] - 1
+
+            input_ids = input_ids[:, remove_prefix_length:]
         # first step, decoder_cached_states are empty
         return {
             "input_ids": input_ids,  # encoder_outputs is defined. input_ids not needed
diff --git a/src/transformers/models/qdqbert/modeling_qdqbert.py b/src/transformers/models/qdqbert/modeling_qdqbert.py
index 47546930ebdfc1..fead8fc0cf7f42 100755
--- a/src/transformers/models/qdqbert/modeling_qdqbert.py
+++ b/src/transformers/models/qdqbert/modeling_qdqbert.py
@@ -1151,9 +1151,18 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attenti
         if attention_mask is None:
             attention_mask = input_ids.new_ones(input_shape)

-        # cut decoder_input_ids if past is used
+        # cut decoder_input_ids if past_key_values is used
         if past_key_values is not None:
-            input_ids = input_ids[:, -1:]
+            past_length = past_key_values[0][0].shape[2]
+
+            # Some generation methods already pass only the last input ID
+            if input_ids.shape[1] > past_length:
+                remove_prefix_length = past_length
+            else:
+                # Default to old behavior: keep only final ID
+                remove_prefix_length = input_ids.shape[1] - 1
+
+            input_ids = input_ids[:, remove_prefix_length:]

         return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values}
diff --git a/src/transformers/models/rembert/modeling_rembert.py b/src/transformers/models/rembert/modeling_rembert.py
index 745be26ebfc97f..235bff89f8a354 100755
--- a/src/transformers/models/rembert/modeling_rembert.py
+++ b/src/transformers/models/rembert/modeling_rembert.py
@@ -1147,9 +1147,18 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attenti
         if attention_mask is None:
             attention_mask = input_ids.new_ones(input_shape)

-        # cut decoder_input_ids if past is used
+        # cut decoder_input_ids if past_key_values is used
         if past_key_values is not None:
-            input_ids = input_ids[:, -1:]
+            past_length = past_key_values[0][0].shape[2]
+
+            # Some generation methods already pass only the last input ID
+            if input_ids.shape[1] > past_length:
+                remove_prefix_length = past_length
+            else:
+                # Default to old behavior: keep only final ID
+                remove_prefix_length = input_ids.shape[1] - 1
+
+            input_ids = input_ids[:, remove_prefix_length:]

         return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values}
diff --git a/src/transformers/models/roberta/modeling_roberta.py b/src/transformers/models/roberta/modeling_roberta.py
index 67e0fee422c4cc..6d4cc991d22ca0 100644
--- a/src/transformers/models/roberta/modeling_roberta.py
+++ b/src/transformers/models/roberta/modeling_roberta.py
@@ -1007,9 +1007,18 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attenti
         if attention_mask is None:
             attention_mask = input_ids.new_ones(input_shape)

-        # cut decoder_input_ids if past is used
+        # cut decoder_input_ids if past_key_values is used
         if past_key_values is not None:
-            input_ids = input_ids[:, -1:]
+            past_length = past_key_values[0][0].shape[2]
+
+            # Some generation methods already pass only the last input ID
+            if input_ids.shape[1] > past_length:
+                remove_prefix_length = past_length
+            else:
+                # Default to old behavior: keep only final ID
+                remove_prefix_length = input_ids.shape[1] - 1
+
+            input_ids = input_ids[:, remove_prefix_length:]

         return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values}
diff --git a/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py b/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py
index ddd87fa9ce0c1e..da1cd6331bc314 100644
--- a/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py
+++ b/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py
@@ -1014,9 +1014,18 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attenti
         if attention_mask is None:
             attention_mask = input_ids.new_ones(input_shape)

-        # cut decoder_input_ids if past is used
+        # cut decoder_input_ids if past_key_values is used
         if past_key_values is not None:
-            input_ids = input_ids[:, -1:]
+            past_length = past_key_values[0][0].shape[2]
+
+            # Some generation methods already pass only the last input ID
+            if input_ids.shape[1] > past_length:
+                remove_prefix_length = past_length
+            else:
+                # Default to old behavior: keep only final ID
+                remove_prefix_length = input_ids.shape[1] - 1
+
+            input_ids = input_ids[:, remove_prefix_length:]

         return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values}
diff --git a/src/transformers/models/roc_bert/modeling_roc_bert.py b/src/transformers/models/roc_bert/modeling_roc_bert.py
index 35d4be9f20e0c0..a5b1b63050b1ef 100644
--- a/src/transformers/models/roc_bert/modeling_roc_bert.py
+++ b/src/transformers/models/roc_bert/modeling_roc_bert.py
@@ -1560,9 +1560,18 @@ def prepare_inputs_for_generation(
         if attention_mask is None:
             attention_mask = input_ids.new_ones(input_shape)

-        # cut decoder_input_ids if past is used
+        # cut decoder_input_ids if past_key_values is used
         if past_key_values is not None:
-            input_ids = input_ids[:, -1:]
+            past_length = past_key_values[0][0].shape[2]
+
+            # Some generation methods already pass only the last input ID
+            if input_ids.shape[1] > past_length:
+                remove_prefix_length = past_length
+            else:
+                # Default to old behavior: keep only final ID
+                remove_prefix_length = input_ids.shape[1] - 1
+
+            input_ids = input_ids[:, remove_prefix_length:]

         if input_shape_ids is not None:
             input_shape_ids = input_shape_ids[:, -1:]
         if input_pronunciation_ids is not None:
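# --- Illustrative sketch (not part of the diff) ---------------------------------
# Why keeping all uncached tokens matters: assisted decoding scores a whole block
# of draft tokens with one forward pass of the main model, so the trimmed inputs
# must still contain every candidate position. A toy greedy acceptance rule over
# precomputed logits (`count_accepted` and the shapes here are hypothetical):
import torch

def count_accepted(candidate_ids, logits_at_candidates):
    # Accept the longest prefix where the main model's greedy choice agrees with
    # the draft; the first disagreement stops the run.
    selected = logits_at_candidates.argmax(dim=-1)
    matches = (candidate_ids == selected).long().cumprod(dim=-1)
    return int(matches.sum())

candidates = torch.tensor([[3, 5, 2]])
logits = torch.zeros(1, 3, 10)
logits[0, 0, 3] = 1.0  # agrees on the first draft token
logits[0, 1, 5] = 1.0  # agrees on the second
logits[0, 2, 7] = 1.0  # disagrees on the third
assert count_accepted(candidates, logits) == 2
# ---------------------------------------------------------------------------------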
diff --git a/src/transformers/models/roformer/modeling_roformer.py b/src/transformers/models/roformer/modeling_roformer.py
index 2c3feeda12708c..b9c36a305ff1cd 100644
--- a/src/transformers/models/roformer/modeling_roformer.py
+++ b/src/transformers/models/roformer/modeling_roformer.py
@@ -1178,9 +1178,18 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attenti
         if attention_mask is None:
             attention_mask = input_ids.new_ones(input_shape)

-        # cut decoder_input_ids if past is used
+        # cut decoder_input_ids if past_key_values is used
         if past_key_values is not None:
-            input_ids = input_ids[:, -1:]
+            past_length = past_key_values[0][0].shape[2]
+
+            # Some generation methods already pass only the last input ID
+            if input_ids.shape[1] > past_length:
+                remove_prefix_length = past_length
+            else:
+                # Default to old behavior: keep only final ID
+                remove_prefix_length = input_ids.shape[1] - 1
+
+            input_ids = input_ids[:, remove_prefix_length:]

         return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values}
diff --git a/src/transformers/models/speech_to_text_2/modeling_speech_to_text_2.py b/src/transformers/models/speech_to_text_2/modeling_speech_to_text_2.py
index bfd801b242719f..f9b5dec4209273 100755
--- a/src/transformers/models/speech_to_text_2/modeling_speech_to_text_2.py
+++ b/src/transformers/models/speech_to_text_2/modeling_speech_to_text_2.py
@@ -963,7 +963,16 @@ def prepare_inputs_for_generation(
             attention_mask = input_ids.new_ones(input_ids.shape)

         if past_key_values:
-            input_ids = input_ids[:, -1:]
+            past_length = past_key_values[0][0].shape[2]
+
+            # Some generation methods already pass only the last input ID
+            if input_ids.shape[1] > past_length:
+                remove_prefix_length = past_length
+            else:
+                # Default to old behavior: keep only final ID
+                remove_prefix_length = input_ids.shape[1] - 1
+
+            input_ids = input_ids[:, remove_prefix_length:]
         # first step, decoder_cached_states are empty
         return {
             "input_ids": input_ids,  # encoder_outputs is defined. input_ids not needed
diff --git a/src/transformers/models/speecht5/modeling_speecht5.py b/src/transformers/models/speecht5/modeling_speecht5.py
index 48334deb377865..c4de7de09089ca 100644
--- a/src/transformers/models/speecht5/modeling_speecht5.py
+++ b/src/transformers/models/speecht5/modeling_speecht5.py
@@ -2508,7 +2508,16 @@ def prepare_inputs_for_generation(
     ):
         # cut decoder_input_ids if past is used
         if past_key_values is not None:
-            decoder_input_ids = decoder_input_ids[:, -1:]
+            past_length = past_key_values[0][0].shape[2]
+
+            # Some generation methods already pass only the last input ID
+            if decoder_input_ids.shape[1] > past_length:
+                remove_prefix_length = past_length
+            else:
+                # Default to old behavior: keep only final ID
+                remove_prefix_length = decoder_input_ids.shape[1] - 1
+
+            decoder_input_ids = decoder_input_ids[:, remove_prefix_length:]

         return {
             "encoder_outputs": encoder_outputs,
diff --git a/src/transformers/models/switch_transformers/modeling_switch_transformers.py b/src/transformers/models/switch_transformers/modeling_switch_transformers.py
index 6c2fe8269782b6..541db4382dd649 100644
--- a/src/transformers/models/switch_transformers/modeling_switch_transformers.py
+++ b/src/transformers/models/switch_transformers/modeling_switch_transformers.py
@@ -1727,9 +1727,18 @@ def prepare_inputs_for_generation(
         encoder_outputs=None,
         **kwargs,
     ):
-        # cut decoder_input_ids if past is used
+        # cut decoder_input_ids if past_key_values is used
         if past_key_values is not None:
-            input_ids = input_ids[:, -1:]
+            past_length = past_key_values[0][0].shape[2]
+
+            # Some generation methods already pass only the last input ID
+            if input_ids.shape[1] > past_length:
+                remove_prefix_length = past_length
+            else:
+                # Default to old behavior: keep only final ID
+                remove_prefix_length = input_ids.shape[1] - 1
+
+            input_ids = input_ids[:, remove_prefix_length:]

         return {
             "decoder_input_ids": input_ids,
diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py
index e6d9deefa14639..9716c7ffaffa0c 100644
--- a/src/transformers/models/t5/modeling_t5.py
+++ b/src/transformers/models/t5/modeling_t5.py
@@ -1810,9 +1810,18 @@ def prepare_inputs_for_generation(
         encoder_outputs=None,
         **kwargs,
     ):
-        # cut decoder_input_ids if past is used
+        # cut decoder_input_ids if past_key_values is used
         if past_key_values is not None:
-            input_ids = input_ids[:, -1:]
+            past_length = past_key_values[0][0].shape[2]
+
+            # Some generation methods already pass only the last input ID
+            if input_ids.shape[1] > past_length:
+                remove_prefix_length = past_length
+            else:
+                # Default to old behavior: keep only final ID
+                remove_prefix_length = input_ids.shape[1] - 1
+
+            input_ids = input_ids[:, remove_prefix_length:]

         return {
             "decoder_input_ids": input_ids,
diff --git a/src/transformers/models/trocr/modeling_trocr.py b/src/transformers/models/trocr/modeling_trocr.py
index 50829592a02e72..c0541814be466e 100644
--- a/src/transformers/models/trocr/modeling_trocr.py
+++ b/src/transformers/models/trocr/modeling_trocr.py
@@ -1003,7 +1003,16 @@ def prepare_inputs_for_generation(
             attention_mask = input_ids.new_ones(input_ids.shape)

         if past_key_values:
-            input_ids = input_ids[:, -1:]
+            past_length = past_key_values[0][0].shape[2]
+
+            # Some generation methods already pass only the last input ID
+            if input_ids.shape[1] > past_length:
+                remove_prefix_length = past_length
+            else:
+                # Default to old behavior: keep only final ID
+                remove_prefix_length = input_ids.shape[1] - 1
+
+            input_ids = input_ids[:, remove_prefix_length:]
         # first step, decoder_cached_states are empty
         return {
             "input_ids": input_ids,  # encoder_outputs is defined. input_ids not needed
diff --git a/src/transformers/models/umt5/modeling_umt5.py b/src/transformers/models/umt5/modeling_umt5.py
index 8323054144f549..bd35111be16e9f 100644
--- a/src/transformers/models/umt5/modeling_umt5.py
+++ b/src/transformers/models/umt5/modeling_umt5.py
@@ -1307,9 +1307,18 @@ def prepare_inputs_for_generation(
         encoder_outputs=None,
         **kwargs,
     ):
-        # cut decoder_input_ids if past is used
+        # cut decoder_input_ids if past_key_values is used
         if past_key_values is not None:
-            input_ids = input_ids[:, -1:]
+            past_length = past_key_values[0][0].shape[2]
+
+            # Some generation methods already pass only the last input ID
+            if input_ids.shape[1] > past_length:
+                remove_prefix_length = past_length
+            else:
+                # Default to old behavior: keep only final ID
+                remove_prefix_length = input_ids.shape[1] - 1
+
+            input_ids = input_ids[:, remove_prefix_length:]

         return {
             "decoder_input_ids": input_ids,
diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py
index be5f50dbffa2aa..de1565fa76127b 100644
--- a/src/transformers/models/whisper/modeling_whisper.py
+++ b/src/transformers/models/whisper/modeling_whisper.py
@@ -1810,9 +1810,17 @@ def prepare_inputs_for_generation(
         attention_mask=None,
         **kwargs,
     ):
-        # cut decoder_input_ids if past is used
         if past_key_values is not None:
-            decoder_input_ids = decoder_input_ids[:, -1:]
+            past_length = past_key_values[0][0].shape[2]
+
+            # Some generation methods already pass only the last input ID
+            if decoder_input_ids.shape[1] > past_length:
+                remove_prefix_length = past_length
+            else:
+                # Default to old behavior: keep only final ID
+                remove_prefix_length = decoder_input_ids.shape[1] - 1
+
+            decoder_input_ids = decoder_input_ids[:, remove_prefix_length:]

         return {
             "encoder_outputs": encoder_outputs,
diff --git a/src/transformers/models/xglm/modeling_xglm.py b/src/transformers/models/xglm/modeling_xglm.py
index 5f8778f98dcd2d..0c769dbbb5f324 100755
--- a/src/transformers/models/xglm/modeling_xglm.py
+++ b/src/transformers/models/xglm/modeling_xglm.py
@@ -851,21 +851,30 @@ def forward(
     def prepare_inputs_for_generation(
         self, input_ids, past_key_values=None, attention_mask=None, use_cache=None, **kwargs
     ):
+        if past_key_values is not None:
+            past_length = past_key_values[0][0].shape[2]
+
+            # Some generation methods already pass only the last input ID
+            if input_ids.shape[1] > past_length:
+                remove_prefix_length = past_length
+            else:
+                # Default to old behavior: keep only final ID
+                remove_prefix_length = input_ids.shape[1] - 1
+
+            input_ids = input_ids[:, remove_prefix_length:]
+
         position_ids = kwargs.get("position_ids", None)
         if attention_mask is not None and position_ids is None:
             # create position_ids on the fly for batch generation
             position_ids = attention_mask.long().cumsum(-1) - 1
             position_ids.masked_fill_(attention_mask == 0, 1)
             if past_key_values:
-                position_ids = position_ids[:, -1].unsqueeze(-1)
+                position_ids = position_ids[:, -input_ids.shape[1] :]
         else:
             position_ids = None
         # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
         if attention_mask is None:
             attention_mask = input_ids.new_ones(input_ids.shape)
-
-        if past_key_values:
-            input_ids = input_ids[:, -1:]
         # first step, decoder_cached_states are empty
         return {
             "input_ids": input_ids,  # encoder_outputs is defined. input_ids not needed
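# --- Illustrative sketch (not part of the diff) ---------------------------------
# The companion position_ids change (persimmon and xglm above): once input_ids may
# keep several uncached tokens, position_ids must be sliced to the same width so
# the two stay aligned. The values below are made up for the demonstration.
import torch

attention_mask = torch.tensor([[0, 1, 1, 1, 1, 1, 1, 1]])  # one left-padded row
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)

input_ids = torch.tensor([[7, 8, 9]])  # three not-yet-cached candidate tokens
# Old: position_ids[:, -1].unsqueeze(-1) keeps one position for three tokens.
# New: keep as many trailing positions as there are input IDs to process.
assert position_ids[:, -input_ids.shape[1] :].tolist() == [[4, 5, 6]]
# ---------------------------------------------------------------------------------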
diff --git a/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py b/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py
index 761e96a11b7344..da454b1e3331f9 100644
--- a/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py
+++ b/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py
@@ -1011,9 +1011,18 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attenti
         if attention_mask is None:
             attention_mask = input_ids.new_ones(input_shape)

-        # cut decoder_input_ids if past is used
+        # cut decoder_input_ids if past_key_values is used
         if past_key_values is not None:
-            input_ids = input_ids[:, -1:]
+            past_length = past_key_values[0][0].shape[2]
+
+            # Some generation methods already pass only the last input ID
+            if input_ids.shape[1] > past_length:
+                remove_prefix_length = past_length
+            else:
+                # Default to old behavior: keep only final ID
+                remove_prefix_length = input_ids.shape[1] - 1
+
+            input_ids = input_ids[:, remove_prefix_length:]

         return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values}
diff --git a/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py b/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py
index 025bab3887c0c7..26e0361abdb523 100644
--- a/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py
+++ b/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py
@@ -970,9 +970,18 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attenti
         if attention_mask is None:
             attention_mask = input_ids.new_ones(input_shape)

-        # cut decoder_input_ids if past is used
+        # cut decoder_input_ids if past_key_values is used
         if past_key_values is not None:
-            input_ids = input_ids[:, -1:]
+            past_length = past_key_values[0][0].shape[2]
+
+            # Some generation methods already pass only the last input ID
+            if input_ids.shape[1] > past_length:
+                remove_prefix_length = past_length
+            else:
+                # Default to old behavior: keep only final ID
+                remove_prefix_length = input_ids.shape[1] - 1
+
+            input_ids = input_ids[:, remove_prefix_length:]

         return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values}
diff --git a/src/transformers/models/xmod/modeling_xmod.py b/src/transformers/models/xmod/modeling_xmod.py
index 61002bd2772e52..28fddc2fdbd6b5 100644
--- a/src/transformers/models/xmod/modeling_xmod.py
+++ b/src/transformers/models/xmod/modeling_xmod.py
@@ -1118,9 +1118,18 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attenti
         if attention_mask is None:
             attention_mask = input_ids.new_ones(input_shape)

-        # cut decoder_input_ids if past is used
+        # cut decoder_input_ids if past_key_values is used
         if past_key_values is not None:
-            input_ids = input_ids[:, -1:]
+            past_length = past_key_values[0][0].shape[2]
+
+            # Some generation methods already pass only the last input ID
+            if input_ids.shape[1] > past_length:
+                remove_prefix_length = past_length
+            else:
+                # Default to old behavior: keep only final ID
+                remove_prefix_length = input_ids.shape[1] - 1
+
+            input_ids = input_ids[:, remove_prefix_length:]

         return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values}
diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py
index f73e3f60a5530b..8e3079f748dfcc 100644
--- a/tests/generation/test_utils.py
+++ b/tests/generation/test_utils.py
@@ -2906,3 +2906,89 @@ def test_default_max_length_warning(self):
         model.generation_config.max_length = 10
         model.generate(input_ids)
         self.assertEqual(len(warning_list), 0)
+
+    def test_model_kwarg_assisted_decoding_decoder_only(self):
+        # PT-only test: TF doesn't support assisted decoding yet.
+        model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
+        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
+        model.config.pad_token_id = tokenizer.eos_token_id
+
+        text = "Hello world"
+        tokenized_inputs = tokenizer([text], return_tensors="pt")
+        input_ids = tokenized_inputs.input_ids.to(torch_device)
+
+        # Traditional way of generating text
+        outputs_normal = model.generate(input_ids)
+        self.assertEqual(outputs_normal.shape, (1, 20))
+
+        # Should be different with token_type_ids
+        outputs_tti = model.generate(
+            input_ids,
+            token_type_ids=torch.zeros(input_ids.shape, dtype=torch.long).to(torch_device),
+        )
+        with self.assertRaises(AssertionError):
+            self.assertListEqual(outputs_tti.tolist(), outputs_normal.tolist())
+
+        # Assistant model
+        assistant = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
+        assistant.config.pad_token_id = tokenizer.eos_token_id
+
+        # If assisted generation passes model_kwargs correctly, should be same as previous
+        outputs_assisted = model.generate(
+            input_ids,
+            token_type_ids=torch.zeros(input_ids.shape, dtype=torch.long).to(torch_device),
+            assistant_model=assistant,
+        )
+        self.assertListEqual(outputs_assisted.tolist(), outputs_tti.tolist())
+
+    def test_model_kwarg_assisted_decoding_encoder_decoder(self):
+        # PT-only test: TF doesn't support assisted decoding yet.
+        # Bart subclass with a kwarg that distorts the output
+        class FakeBart(BartForConditionalGeneration):
+            def forward(self, input_ids, foo=False, **kwargs):
+                outs = super().forward(input_ids, **kwargs)
+
+                if foo:
+                    outs["logits"][:, :, :] = 0.0
+
+                return outs
+
+            def prepare_inputs_for_generation(self, *args, foo=False, **kwargs):
+                inputs = super().prepare_inputs_for_generation(*args, **kwargs)
+
+                inputs["foo"] = foo
+                return inputs
+
+        model = FakeBart.from_pretrained("hf-internal-testing/tiny-random-BartForConditionalGeneration").to(
+            torch_device
+        )
+        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-BartForConditionalGeneration")
+
+        text = "Hello world"
+        tokenized_inputs = tokenizer([text], return_tensors="pt")
+        input_ids = tokenized_inputs.input_ids.to(torch_device)
+
+        # Traditional way of generating text
+        outputs_normal = model.generate(input_ids)
+        self.assertEqual(outputs_normal.shape, (1, 20))
+
+        # Should be different with foo
+        outputs_foo = model.generate(
+            input_ids,
+            foo=True,
+        )
+        with self.assertRaises(AssertionError):
+            self.assertListEqual(outputs_foo.tolist(), outputs_normal.tolist())
+
+        # Assistant model
+        assistant = AutoModelForSeq2SeqLM.from_pretrained(
+            "hf-internal-testing/tiny-random-BartForConditionalGeneration"
+        ).to(torch_device)
+
+        # If assisted generation passes model_kwargs correctly, should be same as previous
+        outputs_assisted = model.generate(
+            input_ids,
+            foo=True,
+            assistant_model=assistant,
+        )
+        self.assertListEqual(outputs_assisted.tolist(), outputs_foo.tolist())
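# --- Illustrative sketch (not part of the diff) ---------------------------------
# End-to-end behavior the new tests pin down: extra model kwargs such as
# token_type_ids now flow through the assisted-decoding path, so with greedy
# decoding the assisted call should reproduce the unassisted call made with the
# same kwargs. Uses the same tiny test checkpoint as the tests above; output
# quality is irrelevant for the check.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2")
assistant = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2")
model.config.pad_token_id = tokenizer.eos_token_id
assistant.config.pad_token_id = tokenizer.eos_token_id

input_ids = tokenizer(["Hello world"], return_tensors="pt").input_ids
kwargs = {"token_type_ids": torch.zeros_like(input_ids)}

plain = model.generate(input_ids, **kwargs)
assisted = model.generate(input_ids, assistant_model=assistant, **kwargs)
assert plain.tolist() == assisted.tolist()
# ---------------------------------------------------------------------------------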