From 70d5df61079c50c770347a71b42b518e0fe4d0ff Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Fri, 16 Aug 2024 11:20:41 +0100 Subject: [PATCH] Generate: unify `LogitsWarper` and `LogitsProcessor` (#32626) --- docs/source/en/internal/generation_utils.md | 4 - docs/source/ja/internal/generation_utils.md | 3 - docs/source/zh/internal/generation_utils.md | 3 - .../generation/configuration_utils.py | 4 +- src/transformers/generation/logits_process.py | 54 ++-- src/transformers/generation/utils.py | 204 ++++---------- .../bark/generation_configuration_bark.py | 8 +- .../models/musicgen/modeling_musicgen.py | 20 +- .../modeling_musicgen_melody.py | 20 +- src/transformers/models/rag/modeling_rag.py | 2 - tests/generation/test_utils.py | 153 +++------- tests/models/biogpt/test_modeling_biogpt.py | 4 - tests/models/mamba/test_modeling_mamba.py | 4 + tests/models/mamba2/test_modeling_mamba2.py | 6 + .../models/musicgen/test_modeling_musicgen.py | 24 +- .../test_modeling_musicgen_melody.py | 24 +- .../test_modeling_recurrent_gemma.py | 4 + tests/models/whisper/test_modeling_whisper.py | 266 ++---------------- utils/check_docstrings.py | 1 + utils/check_repo.py | 1 + 20 files changed, 186 insertions(+), 623 deletions(-) diff --git a/docs/source/en/internal/generation_utils.md b/docs/source/en/internal/generation_utils.md index 3738a4cae7b249..0221622c408076 100644 --- a/docs/source/en/internal/generation_utils.md +++ b/docs/source/en/internal/generation_utils.md @@ -158,9 +158,6 @@ generation. [[autodoc]] LogitsProcessorList - __call__ -[[autodoc]] LogitsWarper - - __call__ - [[autodoc]] MinLengthLogitsProcessor - __call__ @@ -421,4 +418,3 @@ A [`Constraint`] can be used to force the generation to include specific tokens [[autodoc]] WatermarkDetector - __call__ - diff --git a/docs/source/ja/internal/generation_utils.md b/docs/source/ja/internal/generation_utils.md index d65067fc0bbd4c..9e3ce77995439c 100644 --- a/docs/source/ja/internal/generation_utils.md +++ b/docs/source/ja/internal/generation_utils.md @@ -157,9 +157,6 @@ generation_output[:2] [[autodoc]] LogitsProcessorList - __call__ -[[autodoc]] LogitsWarper - - __call__ - [[autodoc]] MinLengthLogitsProcessor - __call__ diff --git a/docs/source/zh/internal/generation_utils.md b/docs/source/zh/internal/generation_utils.md index c82deecd3ddfcc..75f28c233ee02e 100644 --- a/docs/source/zh/internal/generation_utils.md +++ b/docs/source/zh/internal/generation_utils.md @@ -151,9 +151,6 @@ generation_output[:2] [[autodoc]] LogitsProcessorList - __call__ -[[autodoc]] LogitsWarper - - __call__ - [[autodoc]] MinLengthLogitsProcessor - __call__ diff --git a/src/transformers/generation/configuration_utils.py b/src/transformers/generation/configuration_utils.py index c7e626f1a7c284..aa5e77ac681740 100644 --- a/src/transformers/generation/configuration_utils.py +++ b/src/transformers/generation/configuration_utils.py @@ -190,9 +190,9 @@ class GenerationConfig(PushToHubMixin): triggers a [disjunctive constraint](https://github.com/huggingface/transformers/issues/14081), where one can allow different forms of each word. renormalize_logits (`bool`, *optional*, defaults to `False`): - Whether to renormalize the logits after applying all the logits processors or warpers (including the custom + Whether to renormalize the logits after applying all the logits processors (including the custom ones). 
It's highly recommended to set this flag to `True` as the search algorithms suppose the score logits - are normalized but some logit processors or warpers break the normalization. + are normalized but some logit processors break the normalization. constraints (`List[Constraint]`, *optional*): Custom constraints that can be added to the generation to ensure that the output will contain the use of certain tokens as defined by `Constraint` objects, in the most sensible way possible. diff --git a/src/transformers/generation/logits_process.py b/src/transformers/generation/logits_process.py index b226a059d106b1..7f89e239245bec 100644 --- a/src/transformers/generation/logits_process.py +++ b/src/transformers/generation/logits_process.py @@ -55,6 +55,12 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to class LogitsWarper: """Abstract base class for all logit warpers that can be applied during generation with multinomial sampling.""" + def __init__(self): + logger.warning_once( + "`LogitsWarper` is deprecated and will be removed in v4.48. Your class should inherit `LogitsProcessor` " + "instead, which has the same properties and interface." + ) + @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: raise NotImplementedError( @@ -64,9 +70,9 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to class LogitsProcessorList(list): """ - This class can be used to create a list of [`LogitsProcessor`] or [`LogitsWarper`] to subsequently process a - `scores` input tensor. This class inherits from list and adds a specific *__call__* method to apply each - [`LogitsProcessor`] or [`LogitsWarper`] to the inputs. + This class can be used to create a list of [`LogitsProcessor`] to subsequently process a `scores` input tensor. + This class inherits from list and adds a specific *__call__* method to apply each [`LogitsProcessor`] to the + inputs. """ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.FloatTensor: @@ -233,9 +239,9 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to return scores_processed -class TemperatureLogitsWarper(LogitsWarper): +class TemperatureLogitsWarper(LogitsProcessor): r""" - [`LogitsWarper`] for temperature (exponential scaling output probability distribution), which effectively means + [`LogitsProcessor`] for temperature (exponential scaling output probability distribution), which effectively means that it can control the randomness of the predicted tokens. Often used together with [`TopPLogitsWarper`] and [`TopKLogitsWarper`]. @@ -408,10 +414,10 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to return scores_processed -class TopPLogitsWarper(LogitsWarper): +class TopPLogitsWarper(LogitsProcessor): """ - [`LogitsWarper`] that performs top-p, i.e. restricting to top tokens summing to prob_cut_off <= prob_cut_off. Often - used together with [`TemperatureLogitsWarper`] and [`TopKLogitsWarper`]. + [`LogitsProcessor`] that performs top-p, i.e. restricting to top tokens summing to prob_cut_off <= prob_cut_off. + Often used together with [`TemperatureLogitsWarper`] and [`TopKLogitsWarper`]. 
Args: top_p (`float`): @@ -475,10 +481,10 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to return scores_processed -class TopKLogitsWarper(LogitsWarper): +class TopKLogitsWarper(LogitsProcessor): r""" - [`LogitsWarper`] that performs top-k, i.e. restricting to the k highest probability elements. Often used together - with [`TemperatureLogitsWarper`] and [`TopPLogitsWarper`]. + [`LogitsProcessor`] that performs top-k, i.e. restricting to the k highest probability elements. Often used + together with [`TemperatureLogitsWarper`] and [`TopPLogitsWarper`]. Args: top_k (`int`): @@ -528,9 +534,9 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to return scores_processed -class MinPLogitsWarper(LogitsWarper): +class MinPLogitsWarper(LogitsProcessor): """ - [`LogitsWarper`] that performs min-p, i.e. keeps all tokens that are above a minimum probability, scaled by the + [`LogitsProcessor`] that performs min-p, i.e. keeps all tokens that are above a minimum probability, scaled by the probability of the most likely token. As a result, the filter becomes more agressive in the presence of high-probability tokens, which is a sign of a confident output that we shouldn't deviate from. @@ -605,11 +611,11 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to return scores_processed -class TypicalLogitsWarper(LogitsWarper): +class TypicalLogitsWarper(LogitsProcessor): r""" - [`LogitsWarper`] that performs typical decoding. Inspired on how humans use language, it prioritizes tokens whose - log probability is close to the entropy of the token probability distribution. This means that the most likely - tokens may be discarded in the process. + [`LogitsProcessor`] that performs typical decoding. Inspired on how humans use language, it prioritizes tokens + whose log probability is close to the entropy of the token probability distribution. This means that the most + likely tokens may be discarded in the process. See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information. @@ -693,9 +699,9 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to return scores_processed -class EpsilonLogitsWarper(LogitsWarper): +class EpsilonLogitsWarper(LogitsProcessor): r""" - [`LogitsWarper`] that performs epsilon-sampling, i.e. restricting to tokens with `prob >= epsilon`. Takes the + [`LogitsProcessor`] that performs epsilon-sampling, i.e. restricting to tokens with `prob >= epsilon`. Takes the largest min_tokens_to_keep tokens if no tokens satisfy this constraint. See [Truncation Sampling as Language Model Desmoothing](https://arxiv.org/abs/2210.15191) for more information. @@ -762,15 +768,15 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to return scores_processed -class EtaLogitsWarper(LogitsWarper): +class EtaLogitsWarper(LogitsProcessor): r""" - [`LogitsWarper`] that performs eta-sampling, a technique to filter out tokens with probabilities below a dynamic + [`LogitsProcessor`] that performs eta-sampling, a technique to filter out tokens with probabilities below a dynamic cutoff value, `eta`, which is calculated based on a combination of the hyperparameter `epsilon` and the entropy of the token probabilities, i.e. `eta := min(epsilon, sqrt(epsilon * e^-entropy(probabilities)))`. Takes the largest min_tokens_to_keep tokens if no tokens satisfy this constraint. 
It addresses the issue of poor quality in long samples of text generated by neural language models leading to more coherent and fluent text. See [Truncation Sampling as Language Model Desmoothing](https://arxiv.org/abs/2210.15191) for more information. Note: `do_sample` - must be set to `True` for this `LogitsWarper` to work. + must be set to `True` for this `LogitsProcessor` to work. Args: @@ -1708,9 +1714,9 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to return scores_processed -class LogitNormalization(LogitsProcessor, LogitsWarper): +class LogitNormalization(LogitsProcessor): r""" - [`LogitsWarper`] and [`LogitsProcessor`] for normalizing the scores using log-softmax. It's important to normalize + [`LogitsProcessor`] for normalizing the scores using log-softmax. It's important to normalize the scores during beam search, after applying the logits processors or warpers, since the search algorithm used in this library doesn't do it (it only does it before, but they may need re-normalization) but it still supposes that the scores are normalized when comparing the hypotheses. diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 998288bd38dfd8..24c9e3bb183383 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -735,61 +735,6 @@ def _get_candidate_generator( ) return candidate_generator - def _get_logits_warper( - self, - generation_config: GenerationConfig, - device: str, - ) -> LogitsProcessorList: - """ - This class returns a [`LogitsProcessorList`] list object that contains all relevant [`LogitsWarper`] instances - used for multinomial sampling. - """ - - # instantiate warpers list - warpers = LogitsProcessorList() - - # In beam methods, we need to keep at least one non-eos token to explore continuations that might have a - # better score (i.e. 
keep len(list(generation_config._eos_token_tensor)) + 1) - if generation_config.num_beams > 1: - if isinstance(generation_config._eos_token_tensor, list): - min_tokens_to_keep = len(generation_config._eos_token_tensor) + 1 - elif isinstance(generation_config._eos_token_tensor, torch.Tensor): - min_tokens_to_keep = generation_config._eos_token_tensor.shape[0] + 1 - else: - min_tokens_to_keep = 2 - else: - min_tokens_to_keep = 1 - - # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files - # all samplers can be found in `generation_utils_samplers.py` - if generation_config.temperature is not None and generation_config.temperature != 1.0: - warpers.append(TemperatureLogitsWarper(generation_config.temperature)) - if generation_config.top_k is not None and generation_config.top_k != 0: - warpers.append(TopKLogitsWarper(top_k=generation_config.top_k, min_tokens_to_keep=min_tokens_to_keep)) - if generation_config.top_p is not None and generation_config.top_p < 1.0: - warpers.append(TopPLogitsWarper(top_p=generation_config.top_p, min_tokens_to_keep=min_tokens_to_keep)) - if generation_config.min_p is not None: - # Applied after temperature scaling (see https://github.com/ggerganov/llama.cpp/pull/3841#issuecomment-2073826084) - warpers.append(MinPLogitsWarper(min_p=generation_config.min_p, min_tokens_to_keep=min_tokens_to_keep)) - if generation_config.typical_p is not None and generation_config.typical_p < 1.0: - warpers.append( - TypicalLogitsWarper(mass=generation_config.typical_p, min_tokens_to_keep=min_tokens_to_keep) - ) - if generation_config.epsilon_cutoff is not None and 0.0 < generation_config.epsilon_cutoff < 1.0: - warpers.append( - EpsilonLogitsWarper(epsilon=generation_config.epsilon_cutoff, min_tokens_to_keep=min_tokens_to_keep) - ) - if generation_config.eta_cutoff is not None and 0.0 < generation_config.eta_cutoff < 1.0: - warpers.append( - EtaLogitsWarper( - epsilon=generation_config.eta_cutoff, min_tokens_to_keep=min_tokens_to_keep, device=device - ) - ) - # `LogitNormalization` should always be the last logit processor, when present - if generation_config.renormalize_logits is True: - warpers.append(LogitNormalization()) - return warpers - def _get_logits_processor( self, generation_config: GenerationConfig, @@ -960,7 +905,58 @@ def _get_logits_processor( context_width=generation_config.watermarking_config.context_width, ) ) + + # TODO (joao): find a strategy to specify the order of the processors processors = self._merge_criteria_processor_list(processors, logits_processor) + + # Processors previously known as `LogitsWarpers`, only applied with sampling strategies + if generation_config.do_sample: + # In beam methods, we need to keep at least one non-eos token to explore continuations that might have a + # better score (i.e. 
keep len(list(generation_config._eos_token_tensor)) + 1) + if generation_config.num_beams > 1: + if isinstance(generation_config._eos_token_tensor, list): + min_tokens_to_keep = len(generation_config._eos_token_tensor) + 1 + elif isinstance(generation_config._eos_token_tensor, torch.Tensor): + min_tokens_to_keep = generation_config._eos_token_tensor.shape[0] + 1 + else: + min_tokens_to_keep = 2 + else: + min_tokens_to_keep = 1 + + # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files + # all samplers can be found in `generation_utils_samplers.py` + if generation_config.temperature is not None and generation_config.temperature != 1.0: + processors.append(TemperatureLogitsWarper(generation_config.temperature)) + if generation_config.top_k is not None and generation_config.top_k != 0: + processors.append( + TopKLogitsWarper(top_k=generation_config.top_k, min_tokens_to_keep=min_tokens_to_keep) + ) + if generation_config.top_p is not None and generation_config.top_p < 1.0: + processors.append( + TopPLogitsWarper(top_p=generation_config.top_p, min_tokens_to_keep=min_tokens_to_keep) + ) + if generation_config.min_p is not None: + # Applied after temperature scaling (see https://github.com/ggerganov/llama.cpp/pull/3841#issuecomment-2073826084) + processors.append( + MinPLogitsWarper(min_p=generation_config.min_p, min_tokens_to_keep=min_tokens_to_keep) + ) + if generation_config.typical_p is not None and generation_config.typical_p < 1.0: + processors.append( + TypicalLogitsWarper(mass=generation_config.typical_p, min_tokens_to_keep=min_tokens_to_keep) + ) + if generation_config.epsilon_cutoff is not None and 0.0 < generation_config.epsilon_cutoff < 1.0: + processors.append( + EpsilonLogitsWarper( + epsilon=generation_config.epsilon_cutoff, min_tokens_to_keep=min_tokens_to_keep + ) + ) + if generation_config.eta_cutoff is not None and 0.0 < generation_config.eta_cutoff < 1.0: + processors.append( + EtaLogitsWarper( + epsilon=generation_config.eta_cutoff, min_tokens_to_keep=min_tokens_to_keep, device=device + ) + ) + # `LogitNormalization` should always be the last logit processor, when present if generation_config.renormalize_logits is True: processors.append(LogitNormalization()) @@ -1940,22 +1936,11 @@ def generate( model_kwargs=model_kwargs, ) - # 12. prepare logits warper (if `do_sample` is `True`) - prepared_logits_warper = ( - self._get_logits_warper( - generation_config, - device=input_ids.device, - ) - if generation_config.do_sample - else None - ) - - # 13. run assisted generate + # 12. 
run assisted generate result = self._assisted_decoding( input_ids, candidate_generator=candidate_generator, logits_processor=prepared_logits_processor, - logits_warper=prepared_logits_warper, stopping_criteria=prepared_stopping_criteria, generation_config=generation_config, synced_gpus=synced_gpus, @@ -1968,16 +1953,10 @@ def generate( raise ValueError( f"dola decoding is not supported with stateful models, such as {self.__class__.__name__}" ) - prepared_logits_warper = ( - self._get_logits_warper(generation_config, device=input_ids.device) - if generation_config.do_sample - else None - ) result = self._dola_decoding( input_ids, dola_layers=generation_config.dola_layers, logits_processor=prepared_logits_processor, - logits_warper=prepared_logits_warper, stopping_criteria=prepared_stopping_criteria, generation_config=generation_config, synced_gpus=synced_gpus, @@ -2005,14 +1984,7 @@ def generate( ) elif generation_mode in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH): - # 11. prepare logits warper - prepared_logits_warper = ( - self._get_logits_warper(generation_config, device=input_ids.device) - if generation_config.do_sample - else None - ) - - # 12. expand input_ids with `num_return_sequences` additional sequences per batch + # 11. expand input_ids with `num_return_sequences` additional sequences per batch input_ids, model_kwargs = self._expand_inputs_for_generation( input_ids=input_ids, expand_size=generation_config.num_return_sequences, @@ -2020,11 +1992,10 @@ def generate( **model_kwargs, ) - # 13. run sample (it degenerates to greedy search when `generation_config.do_sample=False`) + # 12. run sample (it degenerates to greedy search when `generation_config.do_sample=False`) result = self._sample( input_ids, logits_processor=prepared_logits_processor, - logits_warper=prepared_logits_warper, stopping_criteria=prepared_stopping_criteria, generation_config=generation_config, synced_gpus=synced_gpus, @@ -2033,14 +2004,7 @@ def generate( ) elif generation_mode in (GenerationMode.BEAM_SAMPLE, GenerationMode.BEAM_SEARCH): - # 11. prepare logits warper - prepared_logits_warper = ( - self._get_logits_warper(generation_config, device=input_ids.device) - if generation_config.do_sample - else None - ) - - # 12. prepare beam search scorer + # 11. prepare beam search scorer beam_scorer = BeamSearchScorer( batch_size=batch_size, num_beams=generation_config.num_beams, @@ -2051,7 +2015,7 @@ def generate( max_length=generation_config.max_length, ) - # 13. interleave input_ids with `num_beams` additional sequences per batch + # 12. interleave input_ids with `num_beams` additional sequences per batch input_ids, model_kwargs = self._expand_inputs_for_generation( input_ids=input_ids, expand_size=generation_config.num_beams, @@ -2059,12 +2023,11 @@ def generate( **model_kwargs, ) - # 14. run beam sample + # 13. run beam sample result = self._beam_search( input_ids, beam_scorer, logits_processor=prepared_logits_processor, - logits_warper=prepared_logits_warper, stopping_criteria=prepared_stopping_criteria, generation_config=generation_config, synced_gpus=synced_gpus, @@ -2287,7 +2250,6 @@ def _dola_decoding( generation_config: GenerationConfig, synced_gpus: bool, streamer: "BaseStreamer", - logits_warper: Optional[LogitsProcessorList], **model_kwargs, ) -> Union[GenerateNonBeamOutput, torch.LongTensor]: r""" @@ -2316,10 +2278,6 @@ def _dola_decoding( streamer (`BaseStreamer`, *optional*): Streamer object that will be used to stream the generated sequences. 
Generated tokens are passed through `streamer.put(token_ids)` and the streamer is responsible for any further processing. - logits_warper (`LogitsProcessorList`, *optional*): - An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used - to warp the prediction score distribution of the language modeling head applied before multinomial - sampling at each generation step. model_kwargs: Additional model specific keyword arguments will be forwarded to the `forward` function of the model. If model is an encoder-decoder model the kwargs should include `encoder_outputs`. @@ -2344,11 +2302,6 @@ def _dola_decoding( return_dict_in_generate = generation_config.return_dict_in_generate has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria) do_sample = generation_config.do_sample - if do_sample is True and not isinstance(logits_warper, LogitsProcessorList): - raise ValueError( - "`do_sample` is set to `True`, `logits_warper` must be a `LogitsProcessorList` instance (it is " - f"{logits_warper})." - ) # init attention / hidden states / scores tuples scores = () if (return_dict_in_generate and output_scores) else None @@ -2436,8 +2389,7 @@ def _dola_decoding( ) # pre-process distribution next_token_scores = logits_processor(input_ids, next_token_logits) - if do_sample: # sample - next_token_scores = logits_warper(input_ids, next_token_scores) + # Store scores, attentions and hidden_states when required if return_dict_in_generate: if output_scores: @@ -2893,7 +2845,6 @@ def _sample( generation_config: GenerationConfig, synced_gpus: bool, streamer: Optional["BaseStreamer"], - logits_warper: Optional[LogitsProcessorList], **model_kwargs, ) -> Union[GenerateNonBeamOutput, torch.LongTensor]: r""" @@ -2916,11 +2867,6 @@ def _sample( streamer (`BaseStreamer`, *optional*): Streamer object that will be used to stream the generated sequences. Generated tokens are passed through `streamer.put(token_ids)` and the streamer is responsible for any further processing. - logits_warper (`LogitsProcessorList`, *optional*): - An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used - to warp the prediction score distribution of the language modeling head applied before multinomial - sampling at each generation step. Only required with sampling strategies (i.e. `do_sample` is set in - `generation_config`) model_kwargs: Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is an encoder-decoder model the kwargs should include `encoder_outputs`. @@ -2942,11 +2888,6 @@ def _sample( max_length = generation_config.max_length has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria) do_sample = generation_config.do_sample - if do_sample is True and not isinstance(logits_warper, LogitsProcessorList): - raise ValueError( - "`do_sample` is set to `True`, `logits_warper` must be a `LogitsProcessorList` instance (it is " - f"{logits_warper})." 
- ) # init attention / hidden states / scores tuples scores = () if (return_dict_in_generate and output_scores) else None @@ -2990,8 +2931,6 @@ def _sample( # pre-process distribution next_token_scores = logits_processor(input_ids, next_token_logits) - if do_sample: - next_token_scores = logits_warper(input_ids, next_token_scores) # Store scores, attentions and hidden_states when required if return_dict_in_generate: @@ -3105,7 +3044,6 @@ def _beam_search( stopping_criteria: StoppingCriteriaList, generation_config: GenerationConfig, synced_gpus: bool, - logits_warper: Optional[LogitsProcessorList], **model_kwargs, ) -> Union[GenerateBeamOutput, torch.LongTensor]: r""" @@ -3128,11 +3066,6 @@ def _beam_search( The generation configuration to be used as parametrization of the decoding method. synced_gpus (`bool`): Whether to continue running the while loop until max_length (needed for ZeRO stage 3) - logits_warper (`LogitsProcessorList`, *optional*): - An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used - to warp the prediction score distribution of the language modeling head applied before multinomial - sampling at each generation step. Only required with sampling strategies (i.e. `do_sample` is set in - `generation_config`) model_kwargs: Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is an encoder-decoder model the kwargs should include `encoder_outputs`. @@ -3154,11 +3087,6 @@ def _beam_search( return_dict_in_generate = generation_config.return_dict_in_generate sequential = generation_config.low_memory do_sample = generation_config.do_sample - if do_sample is True and not isinstance(logits_warper, LogitsProcessorList): - raise ValueError( - "`do_sample` is set to `True`, `logits_warper` must be a `LogitsProcessorList` instance (it is " - f"{logits_warper})." - ) batch_size = len(beam_scorer._beam_hyps) num_beams = beam_scorer.num_beams @@ -3249,8 +3177,6 @@ def _beam_search( ) # (batch_size * num_beams, vocab_size) next_token_scores_processed = logits_processor(input_ids, next_token_scores) - if do_sample: - next_token_scores_processed = logits_warper(input_ids, next_token_scores_processed) next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as( next_token_scores_processed ) @@ -3698,10 +3624,6 @@ def _constrained_beam_search( stopping_criteria (`StoppingCriteriaList`): An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`] used to tell if the generation loop should stop. - logits_warper (`LogitsProcessorList`): - An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used - to warp the prediction score distribution of the language modeling head applied before multinomial - sampling at each generation step. generation_config ([`~generation.GenerationConfig`]): The generation configuration to be used as parametrization of the decoding method. synced_gpus (`bool`): @@ -3915,7 +3837,6 @@ def _assisted_decoding( input_ids: torch.LongTensor, candidate_generator: CandidateGenerator, logits_processor: LogitsProcessorList, - logits_warper: LogitsProcessorList, stopping_criteria: StoppingCriteriaList, generation_config: GenerationConfig, synced_gpus: bool, @@ -3937,10 +3858,6 @@ def _assisted_decoding( logits_processor (`LogitsProcessorList`): An instance of [`LogitsProcessorList`]. 
List of instances of class derived from [`LogitsProcessor`] used to modify the prediction scores of the language modeling head applied at each generation step. - logits_warper (`LogitsProcessorList`): - An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used - to warp the prediction score distribution of the language modeling head applied before multinomial - sampling at each generation step. Only used if sampling is active. stopping_criteria (`StoppingCriteriaList`): An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`] used to tell if the generation loop should stop. @@ -3963,7 +3880,7 @@ def _assisted_decoding( `model.config.is_encoder_decoder=True`. """ # init values - do_sample = logits_warper is not None + do_sample = generation_config.do_sample output_attentions = generation_config.output_attentions output_hidden_states = generation_config.output_hidden_states output_scores = generation_config.output_scores @@ -4047,9 +3964,6 @@ def _assisted_decoding( if len(logits_processor) > 0: for i in range(candidate_length + 1): new_logits[:, i, :] = logits_processor(candidate_input_ids[:, : cur_len + i], new_logits[:, i, :]) - if do_sample and len(logits_warper) > 0: - for i in range(candidate_length + 1): - new_logits[:, i, :] = logits_warper(candidate_input_ids[:, : cur_len + i], new_logits[:, i, :]) # 3. Select the accepted tokens. There are two possible cases: # Case 1: `do_sample=True` and we have logits for the candidates (originally from speculative decoding) diff --git a/src/transformers/models/bark/generation_configuration_bark.py b/src/transformers/models/bark/generation_configuration_bark.py index b03fd6796a47a1..036c9caa83baba 100644 --- a/src/transformers/models/bark/generation_configuration_bark.py +++ b/src/transformers/models/bark/generation_configuration_bark.py @@ -56,9 +56,9 @@ def __init__( eos_token_id (`int`, *optional*, defaults to 10_000): The id of the *end-of-sequence* token. renormalize_logits (`bool`, *optional*, defaults to `True`): - Whether to renormalize the logits after applying all the logits processors or warpers (including the + Whether to renormalize the logits after applying all the logits processors (including the custom ones). It's highly recommended to set this flag to `True` as the search algorithms suppose the - score logits are normalized but some logit processors or warpers break the normalization. + score logits are normalized but some logit processors break the normalization. max_new_tokens (`int`, *optional*, defaults to 768): The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt. output_scores (`bool`, *optional*, defaults to `False`): @@ -143,9 +143,9 @@ def __init__( Args: renormalize_logits (`bool`, *optional*, defaults to `True`): - Whether to renormalize the logits after applying all the logits processors or warpers (including the + Whether to renormalize the logits after applying all the logits processors (including the custom ones). It's highly recommended to set this flag to `True` as the search algorithms suppose the - score logits are normalized but some logit processors or warpers break the normalization. + score logits are normalized but some logit processors break the normalization. output_scores (`bool`, *optional*, defaults to `False`): Whether or not to return the prediction scores. See `scores` under returned tensors for more details. 
return_dict_in_generate (`bool`, *optional*, defaults to `False`): diff --git a/src/transformers/models/musicgen/modeling_musicgen.py b/src/transformers/models/musicgen/modeling_musicgen.py index b0e456db8add38..f720faac038e51 100644 --- a/src/transformers/models/musicgen/modeling_musicgen.py +++ b/src/transformers/models/musicgen/modeling_musicgen.py @@ -1609,13 +1609,6 @@ def generate( ) if generation_mode in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH): - # 11. prepare logits warper - prepared_logits_warper = ( - self._get_logits_warper(generation_config, device=input_ids.device) - if generation_config.do_sample - else None - ) - # expand input_ids with `num_return_sequences` additional sequences per batch input_ids, model_kwargs = self._expand_inputs_for_generation( input_ids=input_ids, @@ -1623,11 +1616,10 @@ def generate( **model_kwargs, ) - # 12. run sample + # 11. run sample outputs = self._sample( input_ids, logits_processor=logits_processor, - logits_warper=prepared_logits_warper, stopping_criteria=stopping_criteria, generation_config=generation_config, synced_gpus=synced_gpus, @@ -2649,13 +2641,6 @@ def generate( ) if generation_mode in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH): - # 11. prepare logits warper - prepared_logits_warper = ( - self._get_logits_warper(generation_config, device=input_ids.device) - if generation_config.do_sample - else None - ) - # expand input_ids with `num_return_sequences` additional sequences per batch input_ids, model_kwargs = self._expand_inputs_for_generation( input_ids=input_ids, @@ -2664,11 +2649,10 @@ def generate( **model_kwargs, ) - # 12. run sample + # 11. run sample outputs = self._sample( input_ids, logits_processor=logits_processor, - logits_warper=prepared_logits_warper, stopping_criteria=stopping_criteria, generation_config=generation_config, synced_gpus=synced_gpus, diff --git a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py index ba19e546a1dbc7..a8a8fe96098952 100644 --- a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py +++ b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py @@ -1531,13 +1531,6 @@ def generate( ) if generation_mode in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH): - # 11. prepare logits warper - prepared_logits_warper = ( - self._get_logits_warper(generation_config, device=input_ids.device) - if generation_config.do_sample - else None - ) - # expand input_ids with `num_return_sequences` additional sequences per batch input_ids, model_kwargs = self._expand_inputs_for_generation( input_ids=input_ids, @@ -1545,11 +1538,10 @@ def generate( **model_kwargs, ) - # 12. run sample + # 11. run sample outputs = self._sample( input_ids, logits_processor=logits_processor, - logits_warper=prepared_logits_warper, stopping_criteria=stopping_criteria, generation_config=generation_config, synced_gpus=synced_gpus, @@ -2490,13 +2482,6 @@ def generate( ) if generation_mode in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH): - # 11. prepare logits warper - prepared_logits_warper = ( - self._get_logits_warper(generation_config, device=input_ids.device) - if generation_config.do_sample - else None - ) - # expand input_ids with `num_return_sequences` additional sequences per batch input_ids, model_kwargs = self._expand_inputs_for_generation( input_ids=input_ids, @@ -2505,11 +2490,10 @@ def generate( **model_kwargs, ) - # 12. run sample + # 11. 
run sample outputs = self._sample( input_ids, logits_processor=logits_processor, - logits_warper=prepared_logits_warper, stopping_criteria=stopping_criteria, generation_config=generation_config, synced_gpus=synced_gpus, diff --git a/src/transformers/models/rag/modeling_rag.py b/src/transformers/models/rag/modeling_rag.py index d2f92bfd71411f..bc375b68e947ab 100644 --- a/src/transformers/models/rag/modeling_rag.py +++ b/src/transformers/models/rag/modeling_rag.py @@ -1558,7 +1558,6 @@ def extend_enc_output(tensor, num_beams=None): generation_config=generation_config, synced_gpus=False, streamer=None, - logits_warper=None, **model_kwargs, ) elif generation_config.num_beams > 1: @@ -1580,7 +1579,6 @@ def extend_enc_output(tensor, num_beams=None): stopping_criteria=prepared_stopping_criteria, generation_config=generation_config, synced_gpus=False, - logits_warper=None, **model_kwargs, ) else: diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 17f788b26e2fe2..72da44115f5c50 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -118,26 +118,24 @@ def _get_input_ids_and_config(self, batch_size=2): return config, input_ids, attention_mask - @staticmethod - def _get_logits_processor_and_warper_kwargs( - input_length, - forced_bos_token_id=None, - forced_eos_token_id=None, - ): - process_kwargs = { + def _get_logits_processor_kwargs(self, do_sample=False): + logits_processor_kwargs = { "bad_words_ids": [[1, 0]], "repetition_penalty": 1.2, "remove_invalid_values": True, } - # NoRepeatNGramLogitsProcessor + forced tokens may result in no valid continuations - if forced_bos_token_id is None and forced_eos_token_id is None: - process_kwargs["no_repeat_ngram_size"] = 2 + if do_sample: + logits_processor_kwargs.update( + { + "top_k": 10, + "top_p": 0.7, + "temperature": 0.7, + } + ) - warp_kwargs = {"top_k": 10, "top_p": 0.7, "temperature": 0.7} - return process_kwargs, warp_kwargs + return logits_processor_kwargs - @staticmethod - def _get_beam_kwargs(num_return_sequences=1): + def _get_beam_kwargs(self, num_return_sequences=1): beam_kwargs = { "early_stopping": False, "length_penalty": 2.0, @@ -146,8 +144,7 @@ def _get_beam_kwargs(num_return_sequences=1): } return beam_kwargs - @staticmethod - def _get_diverse_beam_kwargs(num_return_sequences=1): + def _get_diverse_beam_kwargs(self, num_return_sequences=1): beam_kwargs = { "early_stopping": False, "length_penalty": 2.0, @@ -158,8 +155,7 @@ def _get_diverse_beam_kwargs(num_return_sequences=1): } return beam_kwargs - @staticmethod - def _get_constrained_beam_kwargs(num_return_sequences=1): + def _get_constrained_beam_kwargs(self, num_return_sequences=1): beam_kwargs = { "early_stopping": False, "length_penalty": 2.0, @@ -199,12 +195,7 @@ def _greedy_generate( output_hidden_states=False, return_dict_in_generate=False, ): - logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs( - input_ids.shape[-1], - forced_bos_token_id=model.config.forced_bos_token_id, - forced_eos_token_id=model.config.forced_eos_token_id, - ) - + logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False) model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {} output_generate = model.generate( input_ids, @@ -216,7 +207,7 @@ def _greedy_generate( output_scores=output_scores, output_logits=output_logits, return_dict_in_generate=return_dict_in_generate, - **logits_process_kwargs, + **logits_processor_kwargs, **model_kwargs, ) @@ -228,8 +219,6 @@ def 
_sample_generate( input_ids, attention_mask, num_return_sequences, - logits_warper_kwargs, - process_kwargs, output_scores=False, output_logits=False, output_attentions=False, @@ -237,6 +226,7 @@ def _sample_generate( return_dict_in_generate=False, ): torch.manual_seed(0) + logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=True) model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {} output_generate = model.generate( input_ids, @@ -249,8 +239,7 @@ def _sample_generate( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict_in_generate=return_dict_in_generate, - **logits_warper_kwargs, - **process_kwargs, + **logits_processor_kwargs, **model_kwargs, ) @@ -262,13 +251,13 @@ def _beam_search_generate( input_ids, attention_mask, beam_kwargs, - logits_process_kwargs, output_scores=False, output_logits=False, output_attentions=False, output_hidden_states=False, return_dict_in_generate=False, ): + logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False) model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {} output_generate = model.generate( input_ids, @@ -280,7 +269,7 @@ def _beam_search_generate( output_hidden_states=output_hidden_states, return_dict_in_generate=return_dict_in_generate, **beam_kwargs, - **logits_process_kwargs, + **logits_processor_kwargs, **model_kwargs, ) @@ -292,7 +281,6 @@ def _beam_sample_generate( input_ids, attention_mask, beam_kwargs, - logits_warper_kwargs, output_scores=False, output_logits=False, output_attentions=False, @@ -300,6 +288,7 @@ def _beam_sample_generate( return_dict_in_generate=False, ): torch.manual_seed(0) + logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=True) model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {} output_generate = model.generate( input_ids, @@ -311,7 +300,7 @@ def _beam_sample_generate( output_hidden_states=output_hidden_states, return_dict_in_generate=return_dict_in_generate, **beam_kwargs, - **logits_warper_kwargs, + **logits_processor_kwargs, **model_kwargs, ) @@ -323,13 +312,13 @@ def _group_beam_search_generate( input_ids, attention_mask, beam_kwargs, - logits_process_kwargs, output_scores=False, output_logits=False, output_attentions=False, output_hidden_states=False, return_dict_in_generate=False, ): + logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False) model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {} output_generate = model.generate( input_ids, @@ -341,7 +330,7 @@ def _group_beam_search_generate( output_hidden_states=output_hidden_states, return_dict_in_generate=return_dict_in_generate, **beam_kwargs, - **logits_process_kwargs, + **logits_processor_kwargs, **model_kwargs, ) @@ -354,13 +343,13 @@ def _constrained_beam_search_generate( attention_mask, constraints, beam_kwargs, - logits_process_kwargs, output_scores=False, output_logits=False, output_attentions=False, output_hidden_states=False, return_dict_in_generate=False, ): + logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False) model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {} output_generate = model.generate( input_ids, @@ -373,7 +362,7 @@ def _constrained_beam_search_generate( return_dict_in_generate=return_dict_in_generate, constraints=constraints, **beam_kwargs, - **logits_process_kwargs, + **logits_processor_kwargs, **model_kwargs, ) @@ -395,12 +384,7 
@@ def _contrastive_generate( "top_k": 5, } - logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs( - input_ids.shape[-1], - forced_bos_token_id=model.config.forced_bos_token_id, - forced_eos_token_id=model.config.forced_eos_token_id, - ) - + logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False) model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {} output_generate = model.generate( input_ids, @@ -412,7 +396,7 @@ def _contrastive_generate( output_scores=output_scores, output_logits=output_logits, return_dict_in_generate=return_dict_in_generate, - **logits_process_kwargs, + **logits_processor_kwargs, **model_kwargs, **contrastive_search_kwargs, ) @@ -495,19 +479,11 @@ def test_sample_generate(self): config, input_ids, attention_mask = self._get_input_ids_and_config() model = model_class(config).to(torch_device).eval() - process_kwargs, logits_warper_kwargs = self._get_logits_processor_and_warper_kwargs( - input_ids.shape[-1], - forced_bos_token_id=model.config.forced_bos_token_id, - forced_eos_token_id=model.config.forced_eos_token_id, - ) - output_generate = self._sample_generate( model=model, input_ids=input_ids, attention_mask=attention_mask, num_return_sequences=1, - logits_warper_kwargs=logits_warper_kwargs, - process_kwargs=process_kwargs, ) if model.config.is_encoder_decoder: @@ -521,20 +497,11 @@ def test_sample_generate_dict_output(self): config.use_cache = False model = model_class(config).to(torch_device).eval() - - process_kwargs, logits_warper_kwargs = self._get_logits_processor_and_warper_kwargs( - input_ids.shape[-1], - forced_bos_token_id=model.config.forced_bos_token_id, - forced_eos_token_id=model.config.forced_eos_token_id, - ) - output_generate = self._sample_generate( model=model, input_ids=input_ids, attention_mask=attention_mask, num_return_sequences=2, - logits_warper_kwargs=logits_warper_kwargs, - process_kwargs=process_kwargs, output_scores=True, output_logits=True, output_hidden_states=True, @@ -561,19 +528,12 @@ def test_beam_search_generate(self): model = model_class(config).to(torch_device).eval() - logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs( - input_ids.shape[-1], - config.forced_bos_token_id, - config.forced_eos_token_id, - ) beam_kwargs = self._get_beam_kwargs() - output_generate = self._beam_search_generate( model=model, input_ids=input_ids, attention_mask=attention_mask, beam_kwargs=beam_kwargs, - logits_process_kwargs=logits_process_kwargs, ) if model.config.is_encoder_decoder: @@ -589,18 +549,12 @@ def test_beam_search_generate_dict_output(self): config.use_cache = False model = model_class(config).to(torch_device).eval() - logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs( - input_ids.shape[-1], - config.forced_bos_token_id, - config.forced_eos_token_id, - ) beam_kwargs = self._get_beam_kwargs() output_generate = self._beam_search_generate( model=model, input_ids=input_ids, attention_mask=attention_mask, beam_kwargs=beam_kwargs, - logits_process_kwargs=logits_process_kwargs, output_scores=True, output_logits=True, output_hidden_states=True, @@ -633,12 +587,6 @@ def test_beam_search_generate_dict_outputs_use_cache(self): self.skipTest(reason="Won't fix: model with non-standard dictionary output shapes") model = model_class(config).to(torch_device).eval() - logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs( - input_ids.shape[-1], - config.forced_bos_token_id, - config.forced_eos_token_id, - ) - beam_kwargs = 
self._get_beam_kwargs() config.use_cache = True @@ -649,7 +597,6 @@ def test_beam_search_generate_dict_outputs_use_cache(self): input_ids=input_ids, attention_mask=attention_mask, beam_kwargs=beam_kwargs, - logits_process_kwargs=logits_process_kwargs, output_scores=True, output_logits=True, output_hidden_states=True, @@ -693,17 +640,13 @@ def test_beam_sample_generate(self): for model_class in self.all_generative_model_classes: config, input_ids, attention_mask = self._get_input_ids_and_config() - _, logits_warper_kwargs = self._get_logits_processor_and_warper_kwargs(input_ids.shape[-1]) - model = model_class(config).to(torch_device).eval() beam_kwargs = self._get_beam_kwargs() - output_generate = self._beam_sample_generate( model=model, input_ids=input_ids, attention_mask=attention_mask, beam_kwargs=beam_kwargs, - logits_warper_kwargs=logits_warper_kwargs, ) if model.config.is_encoder_decoder: @@ -711,7 +654,13 @@ def test_beam_sample_generate(self): else: self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + input_ids.shape[-1]) - if "inputs_embeds" in set(inspect.signature(model.prepare_inputs_for_generation).parameters): + prepare_inputs_for_generation_args = set(inspect.signature(model.prepare_inputs_for_generation).parameters) + # `inputs_embeds` input is well supported when `cache_positions` is used, because it means the modeling + # code is up to date with our most recent standards + if ( + "inputs_embeds" in prepare_inputs_for_generation_args + and "cache_positions" in prepare_inputs_for_generation_args + ): input_embeds = model.get_input_embeddings()(input_ids) beam_kwargs.update({"inputs_embeds": input_embeds}) output_generate2 = self._beam_sample_generate( @@ -719,7 +668,6 @@ def test_beam_sample_generate(self): input_ids=None, attention_mask=attention_mask, beam_kwargs=beam_kwargs, - logits_warper_kwargs=logits_warper_kwargs, ) torch.testing.assert_close(output_generate[:, input_embeds.shape[1] :], output_generate2) @@ -732,7 +680,6 @@ def test_beam_sample_generate_dict_output(self): config.use_cache = False model = model_class(config).to(torch_device).eval() - _, logits_warper_kwargs = self._get_logits_processor_and_warper_kwargs(input_ids.shape[-1]) beam_kwargs = self._get_beam_kwargs() output_generate = self._beam_sample_generate( @@ -740,7 +687,6 @@ def test_beam_sample_generate_dict_output(self): input_ids=input_ids, attention_mask=attention_mask, beam_kwargs=beam_kwargs, - logits_warper_kwargs=logits_warper_kwargs, output_scores=True, output_logits=True, output_hidden_states=True, @@ -788,12 +734,6 @@ def test_group_beam_search_generate(self): config, input_ids, attention_mask = self._get_input_ids_and_config() model = model_class(config).to(torch_device).eval() - logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs( - input_ids.shape[-1], - config.forced_bos_token_id, - config.forced_eos_token_id, - ) - # check `generate()` and `group_beam_search()` are equal beam_kwargs = self._get_diverse_beam_kwargs() output_generate = self._group_beam_search_generate( @@ -801,7 +741,6 @@ def test_group_beam_search_generate(self): input_ids=input_ids, attention_mask=attention_mask, beam_kwargs=beam_kwargs, - logits_process_kwargs=logits_process_kwargs, ) if model.config.is_encoder_decoder: self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1) @@ -816,7 +755,6 @@ def test_group_beam_search_generate(self): input_ids=input_ids, attention_mask=attention_mask, beam_kwargs=beam_kwargs, - logits_process_kwargs=logits_process_kwargs, ) if 
model.config.is_encoder_decoder: self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1) @@ -829,19 +767,12 @@ def test_group_beam_search_generate_dict_output(self): config.use_cache = False model = model_class(config).to(torch_device).eval() - logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs( - input_ids.shape[-1], - config.forced_bos_token_id, - config.forced_eos_token_id, - ) - beam_kwargs = self._get_diverse_beam_kwargs() output_generate = self._group_beam_search_generate( model=model, input_ids=input_ids, attention_mask=attention_mask, beam_kwargs=beam_kwargs, - logits_process_kwargs=logits_process_kwargs, output_scores=True, output_logits=True, output_hidden_states=True, @@ -871,12 +802,6 @@ def test_constrained_beam_search_generate(self): model = model_class(config).to(torch_device).eval() - logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs( - input_ids.shape[-1], - config.forced_bos_token_id, - config.forced_eos_token_id, - ) - # Sample constraints min_id = 3 max_id = config.vocab_size @@ -893,7 +818,6 @@ def test_constrained_beam_search_generate(self): attention_mask=attention_mask, constraints=constraints, beam_kwargs=beam_kwargs, - logits_process_kwargs=logits_process_kwargs, ) if model.config.is_encoder_decoder: @@ -919,7 +843,6 @@ def test_constrained_beam_search_generate(self): attention_mask=attention_mask, constraints=constraints, beam_kwargs=beam_kwargs, - logits_process_kwargs=logits_process_kwargs, ) if model.config.is_encoder_decoder: @@ -938,11 +861,6 @@ def test_constrained_beam_search_generate_dict_output(self): config.use_cache = False model = model_class(config).to(torch_device).eval() - logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs( - input_ids.shape[-1], - config.forced_bos_token_id, - config.forced_eos_token_id, - ) # Sample constraints min_id = 3 @@ -959,7 +877,6 @@ def test_constrained_beam_search_generate_dict_output(self): attention_mask=attention_mask, constraints=constraints, beam_kwargs=beam_kwargs, - logits_process_kwargs=logits_process_kwargs, output_scores=True, output_logits=True, output_hidden_states=True, diff --git a/tests/models/biogpt/test_modeling_biogpt.py b/tests/models/biogpt/test_modeling_biogpt.py index 1ccb2b54cc9af3..4f1d5d6a42f8a9 100644 --- a/tests/models/biogpt/test_modeling_biogpt.py +++ b/tests/models/biogpt/test_modeling_biogpt.py @@ -414,10 +414,6 @@ def test_biogpt_sequence_classification_model_for_multi_label(self): result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels) self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) - @unittest.skip(reason="The `input_embeds` when fed don't produce the same results.") - def test_beam_sample_generate(self): - pass - @require_torch class BioGptModelIntegrationTest(unittest.TestCase): diff --git a/tests/models/mamba/test_modeling_mamba.py b/tests/models/mamba/test_modeling_mamba.py index cd800da9765169..e7e3a7242cddc4 100644 --- a/tests/models/mamba/test_modeling_mamba.py +++ b/tests/models/mamba/test_modeling_mamba.py @@ -433,6 +433,10 @@ def recursive_check(tuple_object, dict_object): dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True) check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True}) + @unittest.skip("The `input_embeds` when fed don't produce the same results.") + def test_beam_sample_generate(self): + pass + @require_torch class MambaIntegrationTests(unittest.TestCase): 
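# Illustrative usage sketch (outside the diff, an assumption-labeled example): after this change, the classes that
# used to subclass `LogitsWarper` (TemperatureLogitsWarper, TopKLogitsWarper, TopPLogitsWarper, ...) inherit from
# `LogitsProcessor`, and custom warpers should do the same — the `__call__` interface is unchanged, as the
# deprecation message in `logits_process.py` states. Per the new `_get_logits_processor` code above, user-supplied
# processors are merged into the main list first and the sampling processors (formerly "warpers") are appended
# afterwards when `do_sample=True`. The checkpoint name below is only a placeholder for any causal LM, and
# `HalveScoresProcessor` is a toy class invented for this sketch.
import torch

from transformers import AutoModelForCausalLM, AutoTokenizer, LogitsProcessor, LogitsProcessorList


class HalveScoresProcessor(LogitsProcessor):
    """Toy processor; before this patch it could equally have inherited `LogitsWarper` with the same body."""

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # Scale every logit down; the built-in sampling processors run after user-supplied ones.
        return scores / 2.0


tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder checkpoint
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tokenizer("Hello world", return_tensors="pt")
outputs = model.generate(
    **inputs,
    do_sample=True,
    top_k=10,
    max_new_tokens=20,
    logits_processor=LogitsProcessorList([HalveScoresProcessor()]),
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))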
diff --git a/tests/models/mamba2/test_modeling_mamba2.py b/tests/models/mamba2/test_modeling_mamba2.py
index 13cc22561fe174..276ecf2fd6b0fb 100644
--- a/tests/models/mamba2/test_modeling_mamba2.py
+++ b/tests/models/mamba2/test_modeling_mamba2.py
@@ -283,6 +283,12 @@ def recursive_check(tuple_object, dict_object):
             dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
             check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})
 
+    @unittest.skip(
+        reason="Mamba2 does not support generating with input embeddings (custom cache_position computation)"
+    )
+    def test_inputs_embeds_matches_input_ids_with_generate(self):
+        pass
+
 
 @require_torch
 @slow
diff --git a/tests/models/musicgen/test_modeling_musicgen.py b/tests/models/musicgen/test_modeling_musicgen.py
index 7fc2f8c9db47d0..870a4c92767b9b 100644
--- a/tests/models/musicgen/test_modeling_musicgen.py
+++ b/tests/models/musicgen/test_modeling_musicgen.py
@@ -293,15 +293,9 @@ def _get_input_ids_and_config(self, batch_size=2):
         attention_mask = torch.ones((batch_size, sequence_length), dtype=torch.long)
         return config, input_ids, attention_mask
 
-    @staticmethod
-    def _get_logits_processor_and_warper_kwargs(
-        input_length,
-        forced_bos_token_id=None,
-        forced_eos_token_id=None,
-    ):
-        process_kwargs = {}
-        warper_kwargs = {}
-        return process_kwargs, warper_kwargs
+    def _get_logits_processor_kwargs(self, do_sample=False):
+        logits_processor_kwargs = {}
+        return logits_processor_kwargs
 
     def test_greedy_generate_stereo_outputs(self):
         for model_class in self.greedy_sample_model_classes:
@@ -1483,15 +1477,9 @@ def _sample_generate(
 
         return output_generate
 
-    @staticmethod
-    def _get_logits_processor_and_warper_kwargs(
-        input_length,
-        forced_bos_token_id=None,
-        forced_eos_token_id=None,
-    ):
-        process_kwargs = {}
-        warper_kwargs = {}
-        return process_kwargs, warper_kwargs
+    def _get_logits_processor_kwargs(self, do_sample=False):
+        logits_processor_kwargs = {}
+        return logits_processor_kwargs
 
     def test_greedy_generate_dict_outputs(self):
         for model_class in self.greedy_sample_model_classes:
diff --git a/tests/models/musicgen_melody/test_modeling_musicgen_melody.py b/tests/models/musicgen_melody/test_modeling_musicgen_melody.py
index 7cebf037d27af4..9b34f4dde6594f 100644
--- a/tests/models/musicgen_melody/test_modeling_musicgen_melody.py
+++ b/tests/models/musicgen_melody/test_modeling_musicgen_melody.py
@@ -296,15 +296,9 @@ def _get_input_ids_and_config(self, batch_size=2):
         attention_mask = torch.ones((batch_size, sequence_length), dtype=torch.long)
         return config, input_ids, attention_mask
 
-    @staticmethod
-    def _get_logits_processor_and_warper_kwargs(
-        input_length,
-        forced_bos_token_id=None,
-        forced_eos_token_id=None,
-    ):
-        process_kwargs = {}
-        warper_kwargs = {}
-        return process_kwargs, warper_kwargs
+    def _get_logits_processor_kwargs(self, do_sample=False):
+        logits_processor_kwargs = {}
+        return logits_processor_kwargs
 
     def test_greedy_generate_stereo_outputs(self):
         for model_class in self.greedy_sample_model_classes:
@@ -1467,15 +1461,9 @@ def _sample_generate(
 
         return output_generate
 
-    @staticmethod
-    def _get_logits_processor_and_warper_kwargs(
-        input_length,
-        forced_bos_token_id=None,
-        forced_eos_token_id=None,
-    ):
-        process_kwargs = {}
-        warper_kwargs = {}
-        return process_kwargs, warper_kwargs
+    def _get_logits_processor_kwargs(self, do_sample=False):
+        logits_processor_kwargs = {}
+        return logits_processor_kwargs
 
     def test_greedy_generate_dict_outputs(self):
         for model_class in self.greedy_sample_model_classes:
diff --git a/tests/models/recurrent_gemma/test_modeling_recurrent_gemma.py b/tests/models/recurrent_gemma/test_modeling_recurrent_gemma.py
index ad542db2733b9e..1a58ee2970d8eb 100644
--- a/tests/models/recurrent_gemma/test_modeling_recurrent_gemma.py
+++ b/tests/models/recurrent_gemma/test_modeling_recurrent_gemma.py
@@ -413,6 +413,10 @@ def _check_hidden_states_for_generate(
     def test_initialization(self):
         pass
 
+    @unittest.skip(reason="RecurrentGemma does not support generating with input embeddings (missing position_ids)")
+    def test_inputs_embeds_matches_input_ids_with_generate(self):
+        pass
+
 
 @require_torch_gpu
 @slow
diff --git a/tests/models/whisper/test_modeling_whisper.py b/tests/models/whisper/test_modeling_whisper.py
index 155588ad02c61d..6deebf552b91f5 100644
--- a/tests/models/whisper/test_modeling_whisper.py
+++ b/tests/models/whisper/test_modeling_whisper.py
@@ -68,14 +68,7 @@
         set_seed,
     )
     from transformers.generation import (
-        BeamSampleDecoderOnlyOutput,
-        BeamSampleEncoderDecoderOutput,
-        BeamSearchDecoderOnlyOutput,
-        BeamSearchEncoderDecoderOutput,
-        GenerateBeamDecoderOnlyOutput,
-        GenerateBeamEncoderDecoderOutput,
         GenerateEncoderDecoderOutput,
-        PhrasalConstraint,
     )
     from transformers.generation.logits_process import LogitsProcessor
     from transformers.models.whisper.modeling_whisper import WhisperDecoder, WhisperEncoder, sinusoids
@@ -419,6 +412,30 @@ def is_pipeline_test_to_skip(
 
         return False
 
+    def _get_logits_processor_kwargs(self, do_sample=False):
+        # Overwritten from `GenerationTesterMixin`, Whisper needs `"temperature": 0.0` to be able to do beam search
+        logits_processor_kwargs = super()._get_logits_processor_kwargs(do_sample=do_sample)
+        logits_processor_kwargs["temperature"] = 0.0
+        return logits_processor_kwargs
+
+    def _get_beam_kwargs(self, num_return_sequences=1):
+        # Overwritten from `GenerationTesterMixin`, Whisper's `num_return_sequences` differs from the core `generate`
+        beam_kwargs = super()._get_beam_kwargs(num_return_sequences=num_return_sequences)
+        beam_kwargs["num_return_sequences"] = beam_kwargs["num_beams"]
+        return beam_kwargs
+
+    def _get_diverse_beam_kwargs(self, num_return_sequences=1):
+        # Overwritten from `GenerationTesterMixin`, Whisper's `num_return_sequences` differs from the core `generate`
+        beam_kwargs = super()._get_diverse_beam_kwargs(num_return_sequences=num_return_sequences)
+        beam_kwargs["num_return_sequences"] = beam_kwargs["num_beams"]
+        return beam_kwargs
+
+    def _get_constrained_beam_kwargs(self, num_return_sequences=1):
+        # Overwritten from `GenerationTesterMixin`, Whisper's `num_return_sequences` differs from the core `generate`
+        beam_kwargs = super()._get_constrained_beam_kwargs(num_return_sequences=num_return_sequences)
+        beam_kwargs["num_return_sequences"] = beam_kwargs["num_beams"]
+        return beam_kwargs
+
     def setUp(self):
         self.model_tester = WhisperModelTester(self)
         self.config_tester = ConfigTester(self, config_class=WhisperConfig)
@@ -1551,241 +1568,6 @@ def test_longform_generate_multi_batch(self):
 
     def test_longform_generate_multi_batch_cond_prev(self):
         self._check_longform_generate_multi_batch(condition_on_prev_tokens=True)
 
-    def test_beam_sample_generate_dict_output(self):
-        # We overwrite test_beam_sample_generate_dict_output in test_utils as
-        # we can only perform beam search if the temperature is set to 0 in Whisper.
-        config, input_ids, attention_mask = self._get_input_ids_and_config()
-
-        # disable cache
-        config.use_cache = False
-
-        model = WhisperForConditionalGeneration(config).to(torch_device).eval()
-        _, logits_warper_kwargs = self._get_logits_processor_and_warper_kwargs(input_ids.shape[-1])
-        beam_kwargs = self._get_beam_kwargs()
-
-        # With Whisper, we can only perform a beam search if the temperature is set to 0.
-        logits_warper_kwargs["temperature"] = 0
-        # We will return num_beams sequences per input only if num_return_sequences == num_beams:
-        beam_kwargs["num_return_sequences"] = beam_kwargs["num_beams"]
-
-        output_generate = self._beam_sample_generate(
-            model=model,
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            beam_kwargs=beam_kwargs,
-            logits_warper_kwargs=logits_warper_kwargs,
-            output_scores=True,
-            output_logits=True,
-            output_hidden_states=True,
-            output_attentions=True,
-            return_dict_in_generate=True,
-        )
-        if model.config.is_encoder_decoder:
-            self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
-            self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput)
-            # Retrocompatibility check
-            self.assertIsInstance(output_generate, BeamSampleEncoderDecoderOutput)
-        else:
-            self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + input_ids.shape[-1])
-            self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput)
-            # Retrocompatibility check
-            self.assertIsInstance(output_generate, BeamSampleDecoderOnlyOutput)
-
-        self._check_outputs(output_generate, input_ids, model.config, num_return_sequences=beam_kwargs["num_beams"])
-
-    def test_beam_search_generate_dict_output(self):
-        # We overwrite test_beam_search_generate_dict_output in test_utils as
-        # we can only perform beam search if the temperature is set to 0 in Whisper.
-        for model_class in self.all_generative_model_classes:
-            config, input_ids, attention_mask = self._get_input_ids_and_config()
-
-            # disable cache
-            config.use_cache = False
-
-            model = model_class(config).to(torch_device).eval()
-            logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs(
-                input_ids.shape[-1],
-                config.forced_bos_token_id,
-                config.forced_eos_token_id,
-            )
-            beam_kwargs = self._get_beam_kwargs()
-
-            # With Whisper, we can only perform a beam search if the temperature is set to 0.
-            logits_process_kwargs["temperature"] = 0
-            # We will return num_beams sequences per input only if num_return_sequences == num_beams:
-            beam_kwargs["num_return_sequences"] = beam_kwargs["num_beams"]
-
-            output_generate = self._beam_search_generate(
-                model=model,
-                input_ids=input_ids,
-                attention_mask=attention_mask,
-                beam_kwargs=beam_kwargs,
-                logits_process_kwargs=logits_process_kwargs,
-                output_scores=True,
-                output_logits=True,
-                output_hidden_states=True,
-                output_attentions=True,
-                return_dict_in_generate=True,
-            )
-            if model.config.is_encoder_decoder:
-                self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
-                self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput)
-                # Retrocompatibility check
-                self.assertIsInstance(output_generate, BeamSearchEncoderDecoderOutput)
-            else:
-                self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + input_ids.shape[-1])
-                self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput)
-                # Retrocompatibility check
-                self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput)
-
-            self._check_outputs(
-                output_generate, input_ids, model.config, num_return_sequences=beam_kwargs["num_beams"]
-            )
-
-    def test_beam_search_generate_dict_outputs_use_cache(self):
-        # We overwrite test_beam_search_generate_dict_outputs_use_cache in test_utils as
-        # we can only perform beam search if the temperature is set to 0 in Whisper.
-        for model_class in self.all_generative_model_classes:
-            # enable cache
-            config, input_ids, attention_mask = self._get_input_ids_and_config()
-
-            if not hasattr(config, "use_cache"):
-                self.skipTest("This model doesn't support caching")
-
-            model = model_class(config).to(torch_device).eval()
-            logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs(
-                input_ids.shape[-1],
-                config.forced_bos_token_id,
-                config.forced_eos_token_id,
-            )
-
-            beam_kwargs = self._get_beam_kwargs()
-
-            # We will return num_beams sequences per input only if num_return_sequences == num_beams:
-            beam_kwargs["num_return_sequences"] = beam_kwargs["num_beams"]
-
-            config.use_cache = True
-            config.is_decoder = True
-            model = model_class(config).to(torch_device).eval()
-            output_generate = self._beam_search_generate(
-                model=model,
-                input_ids=input_ids,
-                attention_mask=attention_mask,
-                beam_kwargs=beam_kwargs,
-                logits_process_kwargs=logits_process_kwargs,
-                output_scores=True,
-                output_logits=True,
-                output_hidden_states=True,
-                output_attentions=True,
-                return_dict_in_generate=True,
-            )
-
-            if model.config.is_encoder_decoder:
-                self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
-            else:
-                self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + input_ids.shape[-1])
-            self._check_outputs(
-                output_generate, input_ids, model.config, use_cache=True, num_return_sequences=beam_kwargs["num_beams"]
-            )
-
-    def test_group_beam_search_generate_dict_output(self):
-        # We overwrite test_group_beam_search_generate_dict_output in test_utils as
-        # we can only perform beam search if the temperature is set to 0 in Whisper.
-        for model_class in self.all_generative_model_classes:
-            config, input_ids, attention_mask = self._get_input_ids_and_config()
-            config.use_cache = False
-
-            model = model_class(config).to(torch_device).eval()
-            logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs(
-                input_ids.shape[-1],
-                config.forced_bos_token_id,
-                config.forced_eos_token_id,
-            )
-
-            beam_kwargs = self._get_diverse_beam_kwargs()
-
-            # We will return num_beams sequences per input only if num_return_sequences == num_beams:
-            beam_kwargs["num_return_sequences"] = beam_kwargs["num_beams"]
-
-            output_generate = self._group_beam_search_generate(
-                model=model,
-                input_ids=input_ids,
-                attention_mask=attention_mask,
-                beam_kwargs=beam_kwargs,
-                logits_process_kwargs=logits_process_kwargs,
-                output_scores=True,
-                output_logits=True,
-                output_hidden_states=True,
-                output_attentions=True,
-                return_dict_in_generate=True,
-            )
-            if model.config.is_encoder_decoder:
-                self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
-                self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput)
-                # Retrocompatibility check
-                self.assertIsInstance(output_generate, BeamSearchEncoderDecoderOutput)
-            else:
-                self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + input_ids.shape[-1])
-                self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput)
-                # Retrocompatibility check
-                self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput)
-
-            self._check_outputs(
-                output_generate, input_ids, model.config, num_return_sequences=beam_kwargs["num_beams"]
-            )
-
-    def test_constrained_beam_search_generate_dict_output(self):
-        for model_class in self.all_generative_model_classes:
-            config, input_ids, attention_mask = self._get_input_ids_and_config()
-
-            # disable cache
-            config.use_cache = False
-
-            model = model_class(config).to(torch_device).eval()
-            logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs(
-                input_ids.shape[-1],
-                config.forced_bos_token_id,
-                config.forced_eos_token_id,
-            )
-
-            # Sample constraints
-            min_id = 3
-            max_id = model.config.vocab_size
-            force_tokens = torch.randint(min_id, max_id, (1, 2)).tolist()[0]
-            constraints = [
-                PhrasalConstraint(force_tokens),
-            ]
-
-            beam_kwargs = self._get_constrained_beam_kwargs()
-            output_generate = self._constrained_beam_search_generate(
-                model=model,
-                input_ids=input_ids,
-                attention_mask=attention_mask,
-                constraints=constraints,
-                beam_kwargs=beam_kwargs,
-                logits_process_kwargs=logits_process_kwargs,
-                output_scores=True,
-                output_logits=True,
-                output_hidden_states=True,
-                output_attentions=True,
-                return_dict_in_generate=True,
-            )
-
-            if model.config.is_encoder_decoder:
-                self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
-                self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput)
-                # Retrocompatibility check
-                self.assertIsInstance(output_generate, BeamSearchEncoderDecoderOutput)
-            else:
-                self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + input_ids.shape[-1])
-                self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput)
-                # Retrocompatibility check
-                self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput)
-
-            self._check_outputs(
-                output_generate, input_ids, model.config, num_return_sequences=beam_kwargs["num_return_sequences"]
-            )
-
     @is_flaky()  # TODO (joao, sanchit): fails ~9% of the times. Does the original test have the same issue?
     def test_custom_4d_attention_mask(self):
         config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
diff --git a/utils/check_docstrings.py b/utils/check_docstrings.py
index f57427c4f65a3a..928bd332d27936 100644
--- a/utils/check_docstrings.py
+++ b/utils/check_docstrings.py
@@ -70,6 +70,7 @@
     # Deprecated
     "InputExample",
     "InputFeatures",
+    "LogitsWarper",
     # Signature is *args/**kwargs
     "TFSequenceSummary",
     "TFBertTokenizer",
diff --git a/utils/check_repo.py b/utils/check_repo.py
index 02570e3c60c3ef..acd6662cc2fbb8 100644
--- a/utils/check_repo.py
+++ b/utils/check_repo.py
@@ -932,6 +932,7 @@ def find_all_documented_objects() -> List[str]:
     "LineByLineTextDataset",
     "LineByLineWithRefDataset",
     "LineByLineWithSOPTextDataset",
+    "LogitsWarper",
    "NerPipeline",
     "PretrainedBartModel",
     "PretrainedFSMTModel",
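
Illustration only, not part of the patch: the test hunks above converge on a single `_get_logits_processor_kwargs(do_sample=...)` hook in `GenerationTesterMixin`, replacing the old processor/warper kwarg pair. A minimal sketch of the override pattern under these assumptions: `MyModelGenerationTest` is a hypothetical test class, the import path is abbreviated, and pinning `temperature` mirrors what the Whisper override in this patch does.

    import unittest

    from generation.test_utils import GenerationTesterMixin  # illustrative import path

    class MyModelGenerationTest(GenerationTesterMixin, unittest.TestCase):
        def _get_logits_processor_kwargs(self, do_sample=False):
            # One hook now covers what used to be split between processor and warper kwargs.
            logits_processor_kwargs = super()._get_logits_processor_kwargs(do_sample=do_sample)
            logits_processor_kwargs["temperature"] = 0.0  # model-specific tweak, as in the Whisper hunk
            return logits_processor_kwargs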