From c0e26995b329e521b0d21e506a35756c63138ad4 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Tue, 20 Aug 2024 12:08:50 +0200 Subject: [PATCH] Revert "Mamba / FalconMamba: Fix mamba left padding (#32677)" This reverts commit 91b799b3030dc29c8ee9a5323b02d4e47ee62377. --- .../models/codegen/modeling_codegen.py | 29 ++++++++++++++-- .../models/cohere/modeling_cohere.py | 15 +++++---- src/transformers/models/dbrx/modeling_dbrx.py | 15 +++++---- .../models/falcon/modeling_falcon.py | 33 +++++++++++++++++-- .../models/gemma/modeling_gemma.py | 15 +++++---- .../models/gemma2/modeling_gemma2.py | 19 +++-------- .../models/gpt_neo/modeling_gpt_neo.py | 29 ++++++++++++++-- .../models/gpt_neox/modeling_gpt_neox.py | 30 +++++++++++++++-- src/transformers/models/gptj/modeling_gptj.py | 29 ++++++++++++++-- .../models/llama/modeling_llama.py | 15 +++++---- .../models/nemotron/modeling_nemotron.py | 15 +++++---- src/transformers/models/olmo/modeling_olmo.py | 15 +++++---- .../models/persimmon/modeling_persimmon.py | 15 +++++---- src/transformers/models/phi/modeling_phi.py | 15 +++++---- src/transformers/models/phi3/modeling_phi3.py | 15 +++++---- .../models/qwen2/modeling_qwen2.py | 15 +++++---- .../models/qwen2_moe/modeling_qwen2_moe.py | 15 +++++---- .../models/stablelm/modeling_stablelm.py | 15 +++++---- .../models/starcoder2/modeling_starcoder2.py | 15 +++++---- tests/models/bart/test_modeling_bart.py | 5 +++ tests/models/bert/test_modeling_bert.py | 5 +++ tests/models/whisper/test_modeling_whisper.py | 5 +++ 22 files changed, 258 insertions(+), 121 deletions(-) diff --git a/src/transformers/models/codegen/modeling_codegen.py b/src/transformers/models/codegen/modeling_codegen.py index a8df9ed7f3fb08..5b740f92846604 100644 --- a/src/transformers/models/codegen/modeling_codegen.py +++ b/src/transformers/models/codegen/modeling_codegen.py @@ -619,10 +619,33 @@ def prepare_inputs_for_generation(self, input_ids, inputs_embeds=None, past_key_ position_ids = position_ids[:, -input_ids.shape[1] :] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and past_key_values is None: - model_inputs = {"inputs_embeds": inputs_embeds} + if inputs_embeds is not None and cache_position[0] == 0: + model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None} else: - model_inputs = {"input_ids": input_ids.contiguous()} + # The clone here is for the same reason as for `position_ids`. 
+ model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None} + + if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2: + if model_inputs["inputs_embeds"] is not None: + batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape + device = model_inputs["inputs_embeds"].device + else: + batch_size, sequence_length = model_inputs["input_ids"].shape + device = model_inputs["input_ids"].device + + dtype = self.lm_head.weight.dtype + min_dtype = torch.finfo(dtype).min + + attention_mask = _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=past_key_values.get_max_length(), + dtype=dtype, + device=device, + min_dtype=min_dtype, + cache_position=cache_position, + batch_size=batch_size, + ) model_inputs.update( { diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py index 4bc03ade8908c3..afcea137b58bb9 100644 --- a/src/transformers/models/cohere/modeling_cohere.py +++ b/src/transformers/models/cohere/modeling_cohere.py @@ -1132,17 +1132,18 @@ def prepare_inputs_for_generation( # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and cache_position[0] == 0: - model_inputs = {"inputs_embeds": inputs_embeds} + model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None} else: - model_inputs = {"input_ids": input_ids} + # The clone here is for the same reason as for `position_ids`. + model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None} if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2: - if inputs_embeds is not None: - batch_size, sequence_length = inputs_embeds.shape - device = inputs_embeds.device + if model_inputs["inputs_embeds"] is not None: + batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape + device = model_inputs["inputs_embeds"].device else: - batch_size, sequence_length = input_ids.shape - device = input_ids.device + batch_size, sequence_length = model_inputs["input_ids"].shape + device = model_inputs["input_ids"].device dtype = self.lm_head.weight.dtype min_dtype = torch.finfo(dtype).min diff --git a/src/transformers/models/dbrx/modeling_dbrx.py b/src/transformers/models/dbrx/modeling_dbrx.py index f07b910fcf3ea8..3486d5ed3ab08c 100644 --- a/src/transformers/models/dbrx/modeling_dbrx.py +++ b/src/transformers/models/dbrx/modeling_dbrx.py @@ -1403,17 +1403,18 @@ def prepare_inputs_for_generation( # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and cache_position[0] == 0: - model_inputs = {"inputs_embeds": inputs_embeds} + model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None} else: - model_inputs = {"input_ids": input_ids} + # The clone here is for the same reason as for `position_ids`. 
+ model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None} if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2: - if inputs_embeds is not None: - batch_size, sequence_length = inputs_embeds.shape - device = inputs_embeds.device + if model_inputs["inputs_embeds"] is not None: + batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape + device = model_inputs["inputs_embeds"].device else: - batch_size, sequence_length = input_ids.shape - device = input_ids.device + batch_size, sequence_length = model_inputs["input_ids"].shape + device = model_inputs["input_ids"].device dtype = self.lm_head.weight.dtype min_dtype = torch.finfo(dtype).min diff --git a/src/transformers/models/falcon/modeling_falcon.py b/src/transformers/models/falcon/modeling_falcon.py index 37385dd9fd6f9d..78ee0fd1c124bb 100644 --- a/src/transformers/models/falcon/modeling_falcon.py +++ b/src/transformers/models/falcon/modeling_falcon.py @@ -1123,10 +1123,37 @@ def prepare_inputs_for_generation( if past_key_values: position_ids = position_ids[:, -input_ids.shape[1] :] - if inputs_embeds is not None and past_key_values is None: - model_inputs = {"inputs_embeds": inputs_embeds} + # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture. + position_ids = position_ids.clone(memory_format=torch.contiguous_format) + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and cache_position[0] == 0: + model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None} else: - model_inputs = {"input_ids": input_ids} + # The clone here is for the same reason as for `position_ids`. 
+ model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None} + + if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2: + if model_inputs["inputs_embeds"] is not None: + batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape + device = model_inputs["inputs_embeds"].device + else: + batch_size, sequence_length = model_inputs["input_ids"].shape + device = model_inputs["input_ids"].device + + dtype = self.lm_head.weight.dtype + min_dtype = torch.finfo(dtype).min + + attention_mask = _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=past_key_values.get_max_length(), + dtype=dtype, + device=device, + min_dtype=min_dtype, + cache_position=cache_position, + batch_size=batch_size, + ) model_inputs.update( { diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index 2710234717200f..a05d2c059e2164 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -1143,17 +1143,18 @@ def prepare_inputs_for_generation( # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and cache_position[0] == 0: - model_inputs = {"inputs_embeds": inputs_embeds} + model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None} else: - model_inputs = {"input_ids": input_ids} + # The clone here is for the same reason as for `position_ids`. + model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None} if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2: - if inputs_embeds is not None: - batch_size, sequence_length = inputs_embeds.shape - device = inputs_embeds.device + if model_inputs["inputs_embeds"] is not None: + batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape + device = model_inputs["inputs_embeds"].device else: - batch_size, sequence_length = input_ids.shape - device = input_ids.device + batch_size, sequence_length = model_inputs["input_ids"].shape + device = model_inputs["input_ids"].device dtype = self.lm_head.weight.dtype min_dtype = torch.finfo(dtype).min diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index e4953a161316be..b4308f4f776758 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -104,7 +104,6 @@ def _prepare_4d_causal_attention_mask_with_cache_position( causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( padding_mask, min_dtype ) - return causal_mask @@ -301,7 +300,6 @@ def forward( attn_weights = attn_weights / self.config.attn_logit_softcapping attn_weights = torch.tanh(attn_weights) attn_weights = attn_weights * self.config.attn_logit_softcapping - if attention_mask is not None: # no matter the length, we just slice it causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] attn_weights = attn_weights + causal_mask @@ -502,11 +500,9 @@ def forward( key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) - causal_mask = attention_mask if attention_mask is not None: causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] - # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom 
attn_mask, # Reference: https://github.com/pytorch/pytorch/issues/112577. if query_states.device.type == "cuda" and causal_mask is not None: @@ -517,7 +513,6 @@ def forward( # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. is_causal = True if causal_mask is None and q_len > 1 else False - attn_output = torch.nn.functional.scaled_dot_product_attention( query_states, key_states, @@ -583,7 +578,6 @@ def forward( attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask) if attention_mask.shape[-1] <= 1: # when decoding attention_mask = attention_mask[:, :, :, -self.sliding_window :] - residual = hidden_states hidden_states = self.input_layernorm(hidden_states) @@ -996,7 +990,6 @@ def forward( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs = self.model( input_ids=input_ids, @@ -1063,7 +1056,6 @@ def prepare_inputs_for_generation( input_ids = input_ids[:, -cache_position.shape[0] :] elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) input_ids = input_ids[:, cache_position] - if attention_mask is not None and position_ids is None: # create position_ids on the fly for batch generation position_ids = attention_mask.long().cumsum(-1) - 1 @@ -1079,10 +1071,10 @@ def prepare_inputs_for_generation( # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and cache_position[0] == 0: - model_inputs = {"inputs_embeds": inputs_embeds} + model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None} else: # The clone here is for the same reason as for `position_ids`. 
- model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format)} + model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None} if ( isinstance(past_key_values, HybridCache) @@ -1093,12 +1085,10 @@ def prepare_inputs_for_generation( batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape device = model_inputs["inputs_embeds"].device else: - batch_size, sequence_length = input_ids.shape - device = input_ids.device - + batch_size, sequence_length = model_inputs["input_ids"].shape + device = model_inputs["input_ids"].device dtype = self.lm_head.weight.dtype min_dtype = torch.finfo(dtype).min - attention_mask = _prepare_4d_causal_attention_mask_with_cache_position( attention_mask, sequence_length=sequence_length, @@ -1109,7 +1099,6 @@ def prepare_inputs_for_generation( cache_position=cache_position, batch_size=batch_size, ) - model_inputs.update( { "position_ids": position_ids, diff --git a/src/transformers/models/gpt_neo/modeling_gpt_neo.py b/src/transformers/models/gpt_neo/modeling_gpt_neo.py index f5f8ebb2a9dbba..5bcbf0fb666821 100755 --- a/src/transformers/models/gpt_neo/modeling_gpt_neo.py +++ b/src/transformers/models/gpt_neo/modeling_gpt_neo.py @@ -815,10 +815,33 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_ position_ids = position_ids[:, -input_ids.shape[1] :] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and past_key_values is None: - model_inputs = {"inputs_embeds": inputs_embeds} + if inputs_embeds is not None and cache_position[0] == 0: + model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None} else: - model_inputs = {"input_ids": input_ids} + # The clone here is for the same reason as for `position_ids`. + model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None} + + if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2: + if model_inputs["inputs_embeds"] is not None: + batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape + device = model_inputs["inputs_embeds"].device + else: + batch_size, sequence_length = model_inputs["input_ids"].shape + device = model_inputs["input_ids"].device + + dtype = self.lm_head.weight.dtype + min_dtype = torch.finfo(dtype).min + + attention_mask = _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=past_key_values.get_max_length(), + dtype=dtype, + device=device, + min_dtype=min_dtype, + cache_position=cache_position, + batch_size=batch_size, + ) model_inputs.update( { diff --git a/src/transformers/models/gpt_neox/modeling_gpt_neox.py b/src/transformers/models/gpt_neox/modeling_gpt_neox.py index 32988e88df34a8..b10c2464695d2a 100755 --- a/src/transformers/models/gpt_neox/modeling_gpt_neox.py +++ b/src/transformers/models/gpt_neox/modeling_gpt_neox.py @@ -1054,10 +1054,34 @@ def prepare_inputs_for_generation( attention_mask = input_ids.new_ones(input_shape) # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and past_key_values is None: - model_inputs = {"inputs_embeds": inputs_embeds} + if inputs_embeds is not None and cache_position[0] == 0: + model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None} else: - model_inputs = {"input_ids": input_ids} + # The clone here is for the same reason as for `position_ids`. 
+ model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None} + + if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2: + if model_inputs["inputs_embeds"] is not None: + batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape + device = model_inputs["inputs_embeds"].device + else: + batch_size, sequence_length = model_inputs["input_ids"].shape + device = model_inputs["input_ids"].device + + dtype = self.embed_out.weight.dtype + min_dtype = torch.finfo(dtype).min + + attention_mask = _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=past_key_values.get_max_length(), + dtype=dtype, + device=device, + min_dtype=min_dtype, + cache_position=cache_position, + batch_size=batch_size, + ) + model_inputs.update( { "attention_mask": attention_mask, diff --git a/src/transformers/models/gptj/modeling_gptj.py b/src/transformers/models/gptj/modeling_gptj.py index fa658d9e05782c..c139c3b93d4b7d 100644 --- a/src/transformers/models/gptj/modeling_gptj.py +++ b/src/transformers/models/gptj/modeling_gptj.py @@ -964,10 +964,33 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_ position_ids = position_ids[:, -input_ids.shape[1] :] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and past_key_values is None: - model_inputs = {"inputs_embeds": inputs_embeds} + if inputs_embeds is not None and cache_position[0] == 0: + model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None} else: - model_inputs = {"input_ids": input_ids} + # The clone here is for the same reason as for `position_ids`. + model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None} + + if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2: + if model_inputs["inputs_embeds"] is not None: + batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape + device = model_inputs["inputs_embeds"].device + else: + batch_size, sequence_length = model_inputs["input_ids"].shape + device = model_inputs["input_ids"].device + + dtype = self.lm_head.weight.dtype + min_dtype = torch.finfo(dtype).min + + attention_mask = _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=past_key_values.get_max_length(), + dtype=dtype, + device=device, + min_dtype=min_dtype, + cache_position=cache_position, + batch_size=batch_size, + ) model_inputs.update( { diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index dd053c805fb8ca..4a0887629cc62b 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -1265,17 +1265,18 @@ def prepare_inputs_for_generation( # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and cache_position[0] == 0: - model_inputs = {"inputs_embeds": inputs_embeds} + model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None} else: - model_inputs = {"input_ids": input_ids} + # The clone here is for the same reason as for `position_ids`. 
+ model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None} if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2: - if inputs_embeds is not None: - batch_size, sequence_length = inputs_embeds.shape - device = inputs_embeds.device + if model_inputs["inputs_embeds"] is not None: + batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape + device = model_inputs["inputs_embeds"].device else: - batch_size, sequence_length = input_ids.shape - device = input_ids.device + batch_size, sequence_length = model_inputs["input_ids"].shape + device = model_inputs["input_ids"].device dtype = self.lm_head.weight.dtype min_dtype = torch.finfo(dtype).min diff --git a/src/transformers/models/nemotron/modeling_nemotron.py b/src/transformers/models/nemotron/modeling_nemotron.py index c56a275bcd4c0b..db4bce273ca195 100644 --- a/src/transformers/models/nemotron/modeling_nemotron.py +++ b/src/transformers/models/nemotron/modeling_nemotron.py @@ -1136,17 +1136,18 @@ def prepare_inputs_for_generation( # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and cache_position[0] == 0: - model_inputs = {"inputs_embeds": inputs_embeds} + model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None} else: - model_inputs = {"input_ids": input_ids} + # The clone here is for the same reason as for `position_ids`. + model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None} if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2: - if inputs_embeds is not None: - batch_size, sequence_length = inputs_embeds.shape - device = inputs_embeds.device + if model_inputs["inputs_embeds"] is not None: + batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape + device = model_inputs["inputs_embeds"].device else: - batch_size, sequence_length = input_ids.shape - device = input_ids.device + batch_size, sequence_length = model_inputs["input_ids"].shape + device = model_inputs["input_ids"].device dtype = self.lm_head.weight.dtype min_dtype = torch.finfo(dtype).min diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py index 61a8a2bf6b60ba..1940660f61b501 100644 --- a/src/transformers/models/olmo/modeling_olmo.py +++ b/src/transformers/models/olmo/modeling_olmo.py @@ -1176,17 +1176,18 @@ def prepare_inputs_for_generation( # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and cache_position[0] == 0: - model_inputs = {"inputs_embeds": inputs_embeds} + model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None} else: - model_inputs = {"input_ids": input_ids} + # The clone here is for the same reason as for `position_ids`. 
+ model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None} if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2: - if inputs_embeds is not None: - batch_size, sequence_length = inputs_embeds.shape - device = inputs_embeds.device + if model_inputs["inputs_embeds"] is not None: + batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape + device = model_inputs["inputs_embeds"].device else: - batch_size, sequence_length = input_ids.shape - device = input_ids.device + batch_size, sequence_length = model_inputs["input_ids"].shape + device = model_inputs["input_ids"].device dtype = self.lm_head.weight.dtype min_dtype = torch.finfo(dtype).min diff --git a/src/transformers/models/persimmon/modeling_persimmon.py b/src/transformers/models/persimmon/modeling_persimmon.py index 885d744266368a..1e4f56c0674d20 100644 --- a/src/transformers/models/persimmon/modeling_persimmon.py +++ b/src/transformers/models/persimmon/modeling_persimmon.py @@ -993,17 +993,18 @@ def prepare_inputs_for_generation( # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and cache_position[0] == 0: - model_inputs = {"inputs_embeds": inputs_embeds} + model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None} else: - model_inputs = {"input_ids": input_ids} + # The clone here is for the same reason as for `position_ids`. + model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None} if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2: - if inputs_embeds is not None: - batch_size, sequence_length = inputs_embeds.shape - device = inputs_embeds.device + if model_inputs["inputs_embeds"] is not None: + batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape + device = model_inputs["inputs_embeds"].device else: - batch_size, sequence_length = input_ids.shape - device = input_ids.device + batch_size, sequence_length = model_inputs["input_ids"].shape + device = model_inputs["input_ids"].device dtype = self.lm_head.weight.dtype min_dtype = torch.finfo(dtype).min diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py index c44545976316c1..6d63c0ea7e8e51 100644 --- a/src/transformers/models/phi/modeling_phi.py +++ b/src/transformers/models/phi/modeling_phi.py @@ -1278,17 +1278,18 @@ def prepare_inputs_for_generation( # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and cache_position[0] == 0: - model_inputs = {"inputs_embeds": inputs_embeds} + model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None} else: - model_inputs = {"input_ids": input_ids} + # The clone here is for the same reason as for `position_ids`. 
+ model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None} if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2: - if inputs_embeds is not None: - batch_size, sequence_length = inputs_embeds.shape - device = inputs_embeds.device + if model_inputs["inputs_embeds"] is not None: + batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape + device = model_inputs["inputs_embeds"].device else: - batch_size, sequence_length = input_ids.shape - device = input_ids.device + batch_size, sequence_length = model_inputs["input_ids"].shape + device = model_inputs["input_ids"].device dtype = self.lm_head.weight.dtype min_dtype = torch.finfo(dtype).min diff --git a/src/transformers/models/phi3/modeling_phi3.py b/src/transformers/models/phi3/modeling_phi3.py index 40be45e812d004..89967c8405043c 100644 --- a/src/transformers/models/phi3/modeling_phi3.py +++ b/src/transformers/models/phi3/modeling_phi3.py @@ -1318,17 +1318,18 @@ def prepare_inputs_for_generation( # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and cache_position[0] == 0: - model_inputs = {"inputs_embeds": inputs_embeds} + model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None} else: - model_inputs = {"input_ids": input_ids} + # The clone here is for the same reason as for `position_ids`. + model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None} if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2: - if inputs_embeds is not None: - batch_size, sequence_length = inputs_embeds.shape - device = inputs_embeds.device + if model_inputs["inputs_embeds"] is not None: + batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape + device = model_inputs["inputs_embeds"].device else: - batch_size, sequence_length = input_ids.shape - device = input_ids.device + batch_size, sequence_length = model_inputs["input_ids"].shape + device = model_inputs["input_ids"].device dtype = self.lm_head.weight.dtype min_dtype = torch.finfo(dtype).min diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py index f30f3bdac76e66..28b414b1901b85 100644 --- a/src/transformers/models/qwen2/modeling_qwen2.py +++ b/src/transformers/models/qwen2/modeling_qwen2.py @@ -1176,17 +1176,18 @@ def prepare_inputs_for_generation( # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and cache_position[0] == 0: - model_inputs = {"inputs_embeds": inputs_embeds} + model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None} else: - model_inputs = {"input_ids": input_ids} + # The clone here is for the same reason as for `position_ids`. 
+ model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None} if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2: - if inputs_embeds is not None: - batch_size, sequence_length = inputs_embeds.shape - device = inputs_embeds.device + if model_inputs["inputs_embeds"] is not None: + batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape + device = model_inputs["inputs_embeds"].device else: - batch_size, sequence_length = input_ids.shape - device = input_ids.device + batch_size, sequence_length = model_inputs["input_ids"].shape + device = model_inputs["input_ids"].device dtype = self.lm_head.weight.dtype min_dtype = torch.finfo(dtype).min diff --git a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py index 8cf5c200d8455f..12ebe26e058dba 100644 --- a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py +++ b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py @@ -1372,17 +1372,18 @@ def prepare_inputs_for_generation( # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and cache_position[0] == 0: - model_inputs = {"inputs_embeds": inputs_embeds} + model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None} else: - model_inputs = {"input_ids": input_ids} + # The clone here is for the same reason as for `position_ids`. + model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None} if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2: - if inputs_embeds is not None: - batch_size, sequence_length = inputs_embeds.shape - device = inputs_embeds.device + if model_inputs["inputs_embeds"] is not None: + batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape + device = model_inputs["inputs_embeds"].device else: - batch_size, sequence_length = input_ids.shape - device = input_ids.device + batch_size, sequence_length = model_inputs["input_ids"].shape + device = model_inputs["input_ids"].device dtype = self.lm_head.weight.dtype min_dtype = torch.finfo(dtype).min diff --git a/src/transformers/models/stablelm/modeling_stablelm.py b/src/transformers/models/stablelm/modeling_stablelm.py index 2f326184e120fc..988948a9a827ba 100755 --- a/src/transformers/models/stablelm/modeling_stablelm.py +++ b/src/transformers/models/stablelm/modeling_stablelm.py @@ -1271,17 +1271,18 @@ def prepare_inputs_for_generation( # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and cache_position[0] == 0: - model_inputs = {"inputs_embeds": inputs_embeds} + model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None} else: - model_inputs = {"input_ids": input_ids} + # The clone here is for the same reason as for `position_ids`. 
+ model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None} if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2: - if inputs_embeds is not None: - batch_size, sequence_length = inputs_embeds.shape - device = inputs_embeds.device + if model_inputs["inputs_embeds"] is not None: + batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape + device = model_inputs["inputs_embeds"].device else: - batch_size, sequence_length = input_ids.shape - device = input_ids.device + batch_size, sequence_length = model_inputs["input_ids"].shape + device = model_inputs["input_ids"].device dtype = self.lm_head.weight.dtype min_dtype = torch.finfo(dtype).min diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py b/src/transformers/models/starcoder2/modeling_starcoder2.py index d35b1911499b7a..f3b365776ead16 100644 --- a/src/transformers/models/starcoder2/modeling_starcoder2.py +++ b/src/transformers/models/starcoder2/modeling_starcoder2.py @@ -1152,17 +1152,18 @@ def prepare_inputs_for_generation( # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and cache_position[0] == 0: - model_inputs = {"inputs_embeds": inputs_embeds} + model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None} else: - model_inputs = {"input_ids": input_ids} + # The clone here is for the same reason as for `position_ids`. + model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None} if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2: - if inputs_embeds is not None: - batch_size, sequence_length = inputs_embeds.shape - device = inputs_embeds.device + if model_inputs["inputs_embeds"] is not None: + batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape + device = model_inputs["inputs_embeds"].device else: - batch_size, sequence_length = input_ids.shape - device = input_ids.device + batch_size, sequence_length = model_inputs["input_ids"].shape + device = model_inputs["input_ids"].device dtype = self.lm_head.weight.dtype min_dtype = torch.finfo(dtype).min diff --git a/tests/models/bart/test_modeling_bart.py b/tests/models/bart/test_modeling_bart.py index 61a0aa9091d437..dd0cb5bf4c0b9e 100644 --- a/tests/models/bart/test_modeling_bart.py +++ b/tests/models/bart/test_modeling_bart.py @@ -1540,3 +1540,8 @@ def test_retain_grad_hidden_states_attentions(self): @unittest.skip def test_save_load_fast_init_from_base(self): pass + + @unittest.skip(reason="Generate needs input ids") + def test_inputs_embeds_matches_input_ids_with_generate(self): + # generate only works with input ids for bartforcausalLM + pass diff --git a/tests/models/bert/test_modeling_bert.py b/tests/models/bert/test_modeling_bert.py index 6ae9f6c279de3a..766cc1c1bbf18f 100644 --- a/tests/models/bert/test_modeling_bert.py +++ b/tests/models/bert/test_modeling_bert.py @@ -502,6 +502,11 @@ def test_model_as_decoder(self): config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder() self.model_tester.create_and_check_model_as_decoder(*config_and_inputs) + @unittest.skip(reason="Generate needs input ids") + def test_inputs_embeds_matches_input_ids_with_generate(self): + # generate only works with input ids for bertforcausalLM + pass + def test_model_as_decoder_with_default_input_mask(self): # This regression test was failing with PyTorch < 1.3 ( diff --git a/tests/models/whisper/test_modeling_whisper.py 
b/tests/models/whisper/test_modeling_whisper.py index 8eb7262809c784..155588ad02c61d 100644 --- a/tests/models/whisper/test_modeling_whisper.py +++ b/tests/models/whisper/test_modeling_whisper.py @@ -4111,6 +4111,11 @@ def test_generate_without_input_ids(self): # generate only works with input ids for whisper pass + @unittest.skip(reason="Generate needs input ids") + def test_inputs_embeds_matches_input_ids_with_generate(self): + # generate only works with input ids for whisper + pass + @unittest.skip(reason="Decoder can't keep attention grads") def test_retain_grad_hidden_states_attentions(self): return
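
The hunks above all restore the same two-part pattern in `prepare_inputs_for_generation`: pick either `inputs_embeds` (prefill only) or a contiguously cloned `input_ids`, then, when a `StaticCache` is used with a 2D attention mask, expand that mask to 4D with `_prepare_4d_causal_attention_mask_with_cache_position`. The sketch below is a minimal, hedged illustration of that restored behaviour, not the transformers API: the helper names `select_generation_inputs_sketch` and `build_causal_mask_sketch` are hypothetical stand-ins, and the mask builder follows the tail of the helper visible in the gemma2 hunk, with the parts not shown in this patch treated as assumptions.

import torch


def select_generation_inputs_sketch(input_ids, inputs_embeds, cache_position):
    # Prefill step (cache_position[0] == 0) may consume `inputs_embeds`; every later
    # decoding step falls back to `input_ids`. The clone with a contiguous memory
    # format keeps the tensor stride stable so torch.compile's "reduce-overhead"
    # CUDA graphs are not recaptured, as the in-diff comment on `position_ids` explains.
    if inputs_embeds is not None and cache_position[0] == 0:
        return {"inputs_embeds": inputs_embeds, "input_ids": None}
    return {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}


def build_causal_mask_sketch(
    attention_mask, sequence_length, target_length, dtype, device, min_dtype, cache_position, batch_size
):
    # Expand a 2D padding mask into a (batch, 1, sequence_length, target_length) causal
    # mask, the shape the models above expect when a StaticCache meets a 2D attention mask.
    causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
    if sequence_length != 1:
        causal_mask = torch.triu(causal_mask, diagonal=1)  # block attention to future positions
    # Only positions beyond the current cache position stay masked.
    causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
    causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1).clone()
    if attention_mask is not None:
        mask_length = attention_mask.shape[-1]
        # Where both the causal mask and the padding mask allow attention, the sum is 0;
        # everywhere the padding mask forbids it, fill with the minimum dtype value.
        padding_mask = (causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]) == 0
        causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
            padding_mask, min_dtype
        )
    return causal_mask


if __name__ == "__main__":
    # Tiny smoke test of the two sketches above.
    cache_position = torch.arange(3)
    inputs = select_generation_inputs_sketch(torch.tensor([[5, 6, 7]]), None, cache_position)
    mask = build_causal_mask_sketch(
        attention_mask=torch.tensor([[1, 1, 1]]),
        sequence_length=3,
        target_length=8,
        dtype=torch.float32,
        device="cpu",
        min_dtype=torch.finfo(torch.float32).min,
        cache_position=cache_position,
        batch_size=1,
    )
    print(inputs["input_ids"].shape, mask.shape)  # torch.Size([1, 3]) torch.Size([1, 1, 3, 8])

The per-model hunks differ only in which projection supplies the dtype (`self.lm_head.weight.dtype`, or `self.embed_out.weight.dtype` for GPT-NeoX) and in gemma2's use of `HybridCache` instead of `StaticCache`; the control flow itself is identical across all of them.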