Skip to content

Commit

Permalink
fix: Whisper generate, move text_prompt_ids trim up for max_new_tokens calculation (huggingface#23724)
Browse files Browse the repository at this point in the history

move text_prompt_ids trimming to top
  • Loading branch information
Connor Henderson authored and gojiteji committed Jun 5, 2023
1 parent 05bb0d5 commit 30f7036
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions src/transformers/models/whisper/modeling_whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -1633,6 +1633,9 @@ def generate(
)
prompt_ids = prompt_ids.tolist()
decoder_start_token_id, *text_prompt_ids = prompt_ids
# Slicing the text prompt ids in a manner consistent with the OpenAI implementation
# to accommodate context space for the prefix (see https://github.com/openai/whisper/blob/c09a7ae299c4c34c5839a76380ae407e7d785914/whisper/decoding.py#L599)
text_prompt_ids = text_prompt_ids[-self.config.max_length // 2 - 1 :]
# Set the decoder_start_token_id to <|startofprev|>
kwargs.update({"decoder_start_token_id": decoder_start_token_id})

Expand All @@ -1647,9 +1650,7 @@ def generate(
kwargs.pop("forced_decoder_ids", None) or generation_config.forced_decoder_ids
)
forced_decoder_ids = [
# Slicing the text prompt ids in a manner consistent with the OpenAI implementation
# to accommodate context space for the prefix (see https://github.com/openai/whisper/blob/c09a7ae299c4c34c5839a76380ae407e7d785914/whisper/decoding.py#L599)
*text_prompt_ids[-self.config.max_length // 2 - 1 :],
*text_prompt_ids,
generation_config.decoder_start_token_id,
*[token for _rank, token in non_prompt_forced_decoder_ids],
]
Expand Down

0 comments on commit 30f7036

Please sign in to comment.