diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index 96f91a0a43dd77..6175009e4128d2 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -1633,6 +1633,9 @@ def generate( ) prompt_ids = prompt_ids.tolist() decoder_start_token_id, *text_prompt_ids = prompt_ids + # Slicing the text prompt ids in a manner consistent with the OpenAI implementation + # to accommodate context space for the prefix (see https://github.com/openai/whisper/blob/c09a7ae299c4c34c5839a76380ae407e7d785914/whisper/decoding.py#L599) + text_prompt_ids = text_prompt_ids[-self.config.max_length // 2 - 1 :] # Set the decoder_start_token_id to <|startofprev|> kwargs.update({"decoder_start_token_id": decoder_start_token_id}) @@ -1647,9 +1650,7 @@ def generate( kwargs.pop("forced_decoder_ids", None) or generation_config.forced_decoder_ids ) forced_decoder_ids = [ - # Slicing the text prompt ids in a manner consistent with the OpenAI implementation - # to accomodate context space for the prefix (see https://github.com/openai/whisper/blob/c09a7ae299c4c34c5839a76380ae407e7d785914/whisper/decoding.py#L599) - *text_prompt_ids[-self.config.max_length // 2 - 1 :], + *text_prompt_ids, generation_config.decoder_start_token_id, *[token for _rank, token in non_prompt_forced_decoder_ids], ]