From a25e037dbae4f71c5aea1d5d3f43374a1e48498e Mon Sep 17 00:00:00 2001
From: Joao Gante
Date: Sat, 22 Jun 2024 11:26:06 +0000
Subject: [PATCH 1/3] enable strict signature

---
 src/transformers/generation/utils.py  | 63 ++++++++-----------
 .../models/mamba/modeling_mamba.py    |  2 -
 tests/generation/test_utils.py        |  4 --
 3 files changed, 27 insertions(+), 42 deletions(-)

diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index dd1719294e8f7e..e2d655a96fd870 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -2642,13 +2642,12 @@ def _sample(
             # prepare model inputs
             model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
 
+            # prepare variable output controls (note: some models won't accept all output controls)
+            model_inputs.update({"output_attentions": output_attentions} if output_attentions else {})
+            model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {})
+
             # forward pass to get next token
-            outputs = self(
-                **model_inputs,
-                return_dict=True,
-                output_attentions=output_attentions,
-                output_hidden_states=output_hidden_states,
-            )
+            outputs = self(**model_inputs, return_dict=True)
 
             if synced_gpus and this_peer_finished:
                 continue  # don't waste resources running the code we don't need
@@ -2869,6 +2868,10 @@ def _beam_search(
         while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
             model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
 
+            # prepare variable output controls (note: some models won't accept all output controls)
+            model_inputs.update({"output_attentions": output_attentions} if output_attentions else {})
+            model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {})
+
             # if sequential is True, split the input to batches of batch_size and run sequentially
             if sequential:
                 if any(
@@ -2894,24 +2897,13 @@ def _beam_search(
                     model_inputs, split_size=batch_size, full_batch_size=batch_beam_size
                 )
                 outputs_per_sub_batch = [
-                    self(
-                        **inputs_per_sub_batch,
-                        return_dict=True,
-                        output_attentions=output_attentions,
-                        output_hidden_states=output_hidden_states,
-                    )
-                    for inputs_per_sub_batch in inputs_per_sub_batches
+                    self(**inputs_per_sub_batch, return_dict=True) for inputs_per_sub_batch in inputs_per_sub_batches
                 ]
 
                 outputs = stack_model_outputs(outputs_per_sub_batch)
 
             else:  # Unchanged original behavior
-                outputs = self(
-                    **model_inputs,
-                    return_dict=True,
-                    output_attentions=output_attentions,
-                    output_hidden_states=output_hidden_states,
-                )
+                outputs = self(**model_inputs, return_dict=True)
 
             if synced_gpus and this_peer_finished:
                 cur_len = cur_len + 1
@@ -3191,12 +3183,12 @@ def _group_beam_search(
 
             # do one decoder step on all beams of all sentences in batch
             model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
-            outputs = self(
-                **model_inputs,
-                return_dict=True,
-                output_attentions=output_attentions,
-                output_hidden_states=output_hidden_states,
-            )
+
+            # prepare variable output controls (note: some models won't accept all output controls)
+            model_inputs.update({"output_attentions": output_attentions} if output_attentions else {})
+            model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {})
+
+            outputs = self(**model_inputs, return_dict=True)
 
             if synced_gpus and this_peer_finished:
                 cur_len = cur_len + 1
@@ -3472,12 +3464,11 @@ def _constrained_beam_search(
         while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
             model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
 
-            outputs = self(
-                **model_inputs,
-                return_dict=True,
-                output_attentions=output_attentions,
-                output_hidden_states=output_hidden_states,
-            )
+            # prepare variable output controls (note: some models won't accept all output controls)
+            model_inputs.update({"output_attentions": output_attentions} if output_attentions else {})
+            model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {})
+
+            outputs = self(**model_inputs, return_dict=True)
 
             if synced_gpus and this_peer_finished:
                 cur_len = cur_len + 1
@@ -3740,11 +3731,11 @@ def _assisted_decoding(
                 model_inputs["num_logits_to_keep"] = candidate_length + 1
 
             # 2.2. Run a forward pass on the candidate sequence
-            outputs = self(
-                **model_inputs,
-                output_attentions=output_attentions,
-                output_hidden_states=output_hidden_states,
-            )
+            # prepare variable output controls (note: some models won't accept all output controls)
+            model_inputs.update({"output_attentions": output_attentions} if output_attentions else {})
+            model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {})
+
+            outputs = self(**model_inputs)
 
             # 2.3. Process the new logits
             new_logits = outputs.logits[:, -candidate_length - 1 :]  # excludes the input prompt if present
diff --git a/src/transformers/models/mamba/modeling_mamba.py b/src/transformers/models/mamba/modeling_mamba.py
index 04430ada87a04c..aa1bec59f5cadd 100644
--- a/src/transformers/models/mamba/modeling_mamba.py
+++ b/src/transformers/models/mamba/modeling_mamba.py
@@ -545,7 +545,6 @@ def forward(
         use_cache: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
-        **kwargs,  # `attention_mask` is passed by the tokenizer and we don't want it
     ) -> Union[Tuple, MambaOutput]:
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -673,7 +672,6 @@ def forward(
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         use_cache: Optional[bool] = None,
-        **kwargs,  # for now we need this for generation
     ) -> Union[Tuple, MambaCausalLMOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py
index f61adbbd906c37..6215bc87edf52c 100644
--- a/tests/generation/test_utils.py
+++ b/tests/generation/test_utils.py
@@ -464,8 +464,6 @@ def test_greedy_generate_dict_outputs_use_cache(self):
 
             if not hasattr(config, "use_cache"):
                 self.skipTest("This model doesn't support caching")
-            if any(model_name in model_class.__name__.lower() for model_name in ["rwkv"]):
-                self.skipTest("Won't fix: model with non-standard dictionary output shapes")
 
             config.use_cache = True
             config.is_decoder = True
@@ -626,8 +624,6 @@ def test_beam_search_generate_dict_outputs_use_cache(self):
 
             if not hasattr(config, "use_cache"):
                 self.skipTest("This model doesn't support caching")
-            if any(model_name in model_class.__name__.lower() for model_name in ["rwkv"]):
-                self.skipTest("Won't fix: model with non-standard dictionary output shapes")
 
             model = model_class(config).to(torch_device).eval()
             logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs(

From 20b49b50b4f34fed2dec484d3d9183eadcb57031 Mon Sep 17 00:00:00 2001
From: Joao Gante
Date: Sat, 22 Jun 2024 11:29:46 +0000
Subject: [PATCH 2/3] this should not have been deleted

---
 tests/generation/test_utils.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py
index 6215bc87edf52c..f61adbbd906c37 100644
--- a/tests/generation/test_utils.py
+++ b/tests/generation/test_utils.py
@@ -464,6 +464,8 @@ def test_greedy_generate_dict_outputs_use_cache(self):
 
             if not hasattr(config, "use_cache"):
                 self.skipTest("This model doesn't support caching")
+            if any(model_name in model_class.__name__.lower() for model_name in ["rwkv"]):
+                self.skipTest("Won't fix: model with non-standard dictionary output shapes")
 
             config.use_cache = True
             config.is_decoder = True
@@ -624,6 +626,8 @@ def test_beam_search_generate_dict_outputs_use_cache(self):
 
             if not hasattr(config, "use_cache"):
                 self.skipTest("This model doesn't support caching")
+            if any(model_name in model_class.__name__.lower() for model_name in ["rwkv"]):
+                self.skipTest("Won't fix: model with non-standard dictionary output shapes")
 
             model = model_class(config).to(torch_device).eval()
             logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs(

From 13580116554c2ae029fd5404ccb6dbc2296b1bb8 Mon Sep 17 00:00:00 2001
From: Joao Gante
Date: Sat, 22 Jun 2024 11:35:04 +0000
Subject: [PATCH 3/3] recurrent_gemma too

---
 .../models/recurrent_gemma/modeling_recurrent_gemma.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py b/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py
index 2a8e1c25f6382c..40032851bfdc51 100644
--- a/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py
+++ b/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py
@@ -684,7 +684,6 @@ def forward(
         use_cache: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
-        **kwargs,
     ) -> Union[Tuple, BaseModelOutputWithNoAttention]:
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -823,7 +822,6 @@ def forward(
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         use_cache: Optional[bool] = None,
-        **kwargs,  # for now we need this for generation
    ) -> Union[Tuple, CausalLMOutput]:
         r"""
         Args:
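
Note (not part of the patches above): a minimal, self-contained sketch of the calling convention these hunks converge on. Output controls are added to `model_inputs` only when they are requested, so a model whose forward() has a strict signature (no `output_attentions` and, after this series, no `**kwargs`, as in Mamba and RecurrentGemma) never receives keyword arguments it cannot accept. `strict_forward` and `call_model` below are illustrative names, not transformers APIs.

def strict_forward(input_ids, use_cache=None, output_hidden_states=None, return_dict=None):
    # Mirrors a forward() that accepts `output_hidden_states` but not `output_attentions`
    # and has no `**kwargs` catch-all.
    return {"logits": [0.0 for _ in input_ids]}


def call_model(forward, model_inputs, output_attentions=False, output_hidden_states=False):
    # prepare variable output controls (note: some models won't accept all output controls)
    model_inputs.update({"output_attentions": output_attentions} if output_attentions else {})
    model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {})
    return forward(**model_inputs, return_dict=True)


call_model(strict_forward, {"input_ids": [1, 2, 3]})  # ok: no extra kwargs are forwarded
call_model(strict_forward, {"input_ids": [1, 2, 3]}, output_hidden_states=True)  # ok: accepted by forward()
# call_model(strict_forward, {"input_ids": [1, 2, 3]}, output_attentions=True) would raise a
# TypeError, which is the strict-signature behavior this series is meant to surface.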