-
Notifications
You must be signed in to change notification settings - Fork 26.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix Whisper Conversion Script: Correct decoder_attention_heads and _download function #26834
Changes from 10 commits
ae38b19
d246584
479e8f7
a5c1cee
769da1a
df384c1
1bea0b0
adf8608
2ddddd6
dacfee7
763d56b
bdde5c4
722c57a
b1862fa
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | |||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -35,6 +35,85 @@ Tips: | ||||||||||||||||
This model was contributed by [Arthur Zucker](https://huggingface.co/ArthurZ). The Tensorflow version of this model was contributed by [amyeroberts](https://huggingface.co/amyeroberts). | |||||||||||||||||
The original code can be found [here](https://github.com/openai/whisper). | |||||||||||||||||
|
|||||||||||||||||
## Inference | |||||||||||||||||
|
|||||||||||||||||
Here is a step-by-step guide to transcribing an audio sample using a pre-trained Whisper model: | |||||||||||||||||
|
|||||||||||||||||
```python | |||||||||||||||||
>>> import torchaudio | |||||||||||||||||
>>> from transformers import WhisperProcessor, WhisperForConditionalGeneration | |||||||||||||||||
|
|||||||||||||||||
>>> # Select an audio file: | |||||||||||||||||
>>> audio_path = "https://huggingface.co/datasets/sanchit-gandhi/librispeech_long/resolve/main/audio.wav" | |||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let's use a shorter audio file - this will take a long time to transcribe There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Changed to use |
|||||||||||||||||
|
|||||||||||||||||
>>> # Load the Whisper model in Hugging Face format: | |||||||||||||||||
>>> processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en") | |||||||||||||||||
>>> model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en") | |||||||||||||||||
>>> model.config.forced_decoder_ids = None | |||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No need - this API is deprecated now (cc @ArthurZucker)
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I didn't know. Removed. |
|||||||||||||||||
|
|||||||||||||||||
>>> # Select an audio file: | |||||||||||||||||
>>> waveform, sampling_rate = torchaudio.load(audio_path) | |||||||||||||||||
|
|||||||||||||||||
>>> # Use the model and processor to transcribe the audio: | |||||||||||||||||
>>> input_features = processor( | |||||||||||||||||
... waveform.squeeze().numpy(), sampling_rate=sampling_rate, return_tensors="pt" | |||||||||||||||||
... ).input_features | |||||||||||||||||
|
|||||||||||||||||
>>> # Generate token ids | |||||||||||||||||
>>> predicted_ids = model.generate(input_features) | |||||||||||||||||
|
|||||||||||||||||
>>> # Decode token ids to text | |||||||||||||||||
>>> transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True) | |||||||||||||||||
|
|||||||||||||||||
>>> transcription[0] | |||||||||||||||||
' Chapter 16.' | |||||||||||||||||
``` | |||||||||||||||||
|
|||||||||||||||||
## Format Conversion | |||||||||||||||||
|
|||||||||||||||||
For users with models in the original OpenAI format who wish to utilize them with the Hugging Face library, a conversion script is provided. The example below demonstrates how to transform Whisper models from OpenAI to Hugging Face format: | |||||||||||||||||
|
|||||||||||||||||
```bash | |||||||||||||||||
# Change to the whisper directory where the script resides: | |||||||||||||||||
cd src/transformers/models/whisper/ | |||||||||||||||||
# Converts the model from OpenAI to Hugging Face format: | |||||||||||||||||
convert_openai_to_hf.py \ | |||||||||||||||||
--checkpoint_path tiny \ | |||||||||||||||||
--pytorch_dump_folder_path whisper-tiny-hf | |||||||||||||||||
``` | |||||||||||||||||
|
|||||||||||||||||
For those more comfortable working directly in Python, the conversion can also be achieved with the code snippet below: | |||||||||||||||||
|
|||||||||||||||||
```python | |||||||||||||||||
>>> from transformers.models.whisper.convert_openai_to_hf import convert_openai_whisper_to_tfms | |||||||||||||||||
>>> convert_openai_whisper_to_tfms("tiny.en", "whisper-tiny.en-hf") # doctest: +IGNORE_RESULT | |||||||||||||||||
``` | |||||||||||||||||
|
|||||||||||||||||
Now can test it by doing inference with an audio file: | |||||||||||||||||
|
|||||||||||||||||
```python | |||||||||||||||||
>>> # Load the newly converted model: | |||||||||||||||||
>>> processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en") | |||||||||||||||||
>>> model = WhisperForConditionalGeneration.from_pretrained("whisper-tiny.en-hf") | |||||||||||||||||
>>> model.config.forced_decoder_ids = None | |||||||||||||||||
|
|||||||||||||||||
>>> # Select an audio file: | |||||||||||||||||
>>> waveform, sampling_rate = torchaudio.load(audio_path) | |||||||||||||||||
|
|||||||||||||||||
>>> # Use the model and processor to transcribe the audio: | |||||||||||||||||
>>> input_features = processor( | |||||||||||||||||
... waveform.squeeze().numpy(), sampling_rate=sampling_rate, return_tensors="pt" | |||||||||||||||||
... ).input_features | |||||||||||||||||
|
|||||||||||||||||
>>> # Transcribe the example: | |||||||||||||||||
>>> predicted_ids = model.generate(input_features) | |||||||||||||||||
>>> transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True) | |||||||||||||||||
|
|||||||||||||||||
>>> transcription[0] | |||||||||||||||||
' Chapter 16. I might have told you of the beginning of this liaison in a few lines' | |||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We get a different result here to the one we got before ( There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I do not know why, but the transcriptions in HF and OpenAI do not match using this audio. To further explore this, I’ve created a Kaggle notebook and transcribed the audio here using Here are the transcription results:
From the results, transcriptions using the OpenAI library are consistent, even after conversion from the HF model, suggesting the conversion is accurate. However, the transcriptions using the HF library, both with your model in https://huggingface.co/openai/whisper-tiny.en, and after converting the original I ignore the reason for this. Maybe some post-processing step. I have also not tested other model sizes or audios. Should I open a new ticket with this? |
|||||||||||||||||
``` | |||||||||||||||||
|
|||||||||||||||||
This step is not usually required if we are using the models already [provided by OpenAI in the Hugging Face Hub](https://huggingface.co/openai). | |||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let's just promote this directly instead and remove the notes on the conversion script: there is no need for users to have to use the conversion script since we've already converted all official Whisper checkpoints on the Hub https://huggingface.co/collections/openai/whisper-release-6501bba2cf999715fd953013 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. OK! Conversion example completely removed. |
|||||||||||||||||
|
|||||||||||||||||
## WhisperConfig | |||||||||||||||||
|
|||||||||||||||||
|
Original file line number | Diff line number | Diff line change | ||
---|---|---|---|---|
@@ -1,3 +1,5 @@ | ||||
#!/usr/bin/env python | ||||
"""Converts a Whisper model in OpenAI format to Hugging Face format.""" | ||||
# Copyright 2022 The HuggingFace Inc. team and the OpenAI team. All rights reserved. | ||||
# | ||||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||||
|
@@ -14,6 +16,7 @@ | |||
|
||||
import argparse | ||||
import hashlib | ||||
import io | ||||
import os | ||||
import urllib | ||||
import warnings | ||||
|
@@ -90,7 +93,7 @@ def make_linear_from_emb(emb): | |||
return lin_layer | ||||
|
||||
|
||||
def _download(url: str, root: str) -> bytes: | ||||
def _download(url: str, root: str) -> io.BytesIO: | ||||
os.makedirs(root, exist_ok=True) | ||||
filename = os.path.basename(url) | ||||
|
||||
|
@@ -103,7 +106,7 @@ def _download(url: str, root: str) -> bytes: | |||
if os.path.isfile(download_target): | ||||
model_bytes = open(download_target, "rb").read() | ||||
if hashlib.sha256(model_bytes).hexdigest() == expected_sha256: | ||||
return model_bytes | ||||
return torch.load(io.BytesIO(model_bytes)) | ||||
else: | ||||
warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") | ||||
|
||||
|
@@ -125,12 +128,13 @@ def _download(url: str, root: str) -> bytes: | |||
"Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model." | ||||
) | ||||
|
||||
return model_bytes | ||||
return torch.load(io.BytesIO(model_bytes)) | ||||
|
||||
|
||||
def convert_openai_whisper_to_tfms(checkpoint_path, pytorch_dump_folder_path): | ||||
if ".pt" not in checkpoint_path: | ||||
original_checkpoint = _download(_MODELS[checkpoint_path]) | ||||
root = os.path.dirname(pytorch_dump_folder_path) or "." | ||||
original_checkpoint = _download(_MODELS[checkpoint_path], root) | ||||
else: | ||||
original_checkpoint = torch.load(checkpoint_path, map_location="cpu") | ||||
dimensions = original_checkpoint["dims"] | ||||
|
@@ -151,7 +155,7 @@ def convert_openai_whisper_to_tfms(checkpoint_path, pytorch_dump_folder_path): | |||
encoder_layers=dimensions["n_audio_layer"], | ||||
encoder_attention_heads=dimensions["n_audio_head"], | ||||
decoder_layers=dimensions["n_text_layer"], | ||||
decoder_attention_heads=dimensions["n_text_state"], | ||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Surprised we managed to convert the original checkpoints with this bug @ArthurZucker 🤔 The state dicts surely won't have matched? Maybe we hardcoded this before? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah, but I think I hardcoded the values when converting and then later on made it automatic. I checked by actually re-running the script and seeing that this was a nice type 🤣 but good sign that no one else tried to convert the checkpoints ! There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not sure about the history behind this. But looking at the
Then it was deleted and recovered in: #20600 That's where the problem seems to come from. So the original script you used may have worked properly. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nice digging! Yep I think I uploaded an old version late by a few commits |
||||
decoder_attention_heads=dimensions["n_text_head"], | ||||
max_source_positions=dimensions["n_audio_ctx"], | ||||
) | ||||
|
||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should avoid using extra dependencies like
torchaudio
in our code snippets - could you maybe refactor this to only use dependencies in the Transformers library?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Changed to use
datasets.load_dataset
.