
Add ctranslate2 whisper models #10

Merged · 14 commits · Sep 11, 2023
4 changes: 2 additions & 2 deletions README.md
@@ -10,7 +10,7 @@ Each library has its own set of requirements. We recommend using a clean conda environment
2) Install PyTorch by following the instructions here: https://pytorch.org/get-started/locally/
3) Install the common requirements for all libraries by running `pip install -r requirements/requirements.txt`.
4) Install the requirements for each library you wish to evaluate by running `pip install -r requirements/requirements_<library_name>.txt`.

+ 5) Connect your Hugging Face account by running `huggingface-cli login`.

# Evaluate a model

@@ -32,7 +32,7 @@ To add a new library for evaluation in this benchmark, please follow the steps below:
4) Create one bash file per model type following the convention `run_<model_type>.sh`.
- The bash script should follow the same steps as other libraries.
- Different model sizes of the same type should share the script. For example `Wav2Vec` and `Wav2Vec2` would be two separate scripts, but different sizes of `Wav2Vec2` would be part of the same script.
- 5) (Optional) You could also add a `compute_rtf.py` script for your library to evaluate the Real Time Factor of the model.
+ 5) (Optional) You could also add a `calc_rtf.py` script for your library to evaluate the Real Time Factor of the model.
6) Submit a PR for your changes.

# Add a new model
61 changes: 61 additions & 0 deletions ctranslate2/calc_rtf.py
@@ -0,0 +1,61 @@
import time
import librosa

from faster_whisper import WhisperModel

device = "cuda"
device_index = 0

models = [
    "guillaumekln/faster-whisper-tiny.en",
    "guillaumekln/faster-whisper-small.en",
    "guillaumekln/faster-whisper-base.en",
    "guillaumekln/faster-whisper-medium.en",
    "guillaumekln/faster-whisper-large-v1",
    "guillaumekln/faster-whisper-large-v2",
]

n_batches = 3
warmup_batches = 5

audio_file = "4469669.mp3"
max_len = 600 # 10 minutes


def pre_process_audio(audio_file, sr, max_len):
    audio, _sr = librosa.load(audio_file, sr=sr)
    audio_len = int(max_len * _sr)
    audio_arr = audio[:audio_len]
    return {"raw": audio_arr, "sampling_rate": _sr}, audio_len


audio_dict, audio_len = pre_process_audio(audio_file, 16000, max_len)

rtfs = []

for model in models[:1]:
    asr_model = WhisperModel(
        model_size_or_path=model,
        device=device,
        device_index=device_index,
        compute_type="float16",
    )

    for i in range(3):
> **Member:** What's the outer loop for?
>
> **Contributor (author):** I don't know. I just copied and pasted this part from the transformers folder example. 🤗
>
> **Member:** Ah! Okay, I'll sync and fix this offline.
>
> **Collaborator:** The outer loop is to compute the average of 3 runs for RTF.
print(f"outer_loop -> {i}")
total_time = 0.0
for _ in range(n_batches + warmup_batches):
print(f"batch_num -> {_}")
start = time.time()
segments, _ = asr_model.transcribe(audio_dict["raw"], language="en")
_ = [segment._asdict() for segment in segments] # Iterate over segments to run inference
end = time.time()
if _ >= warmup_batches:
total_time += end - start

rtf = (total_time / n_batches) / (audio_len / 16000)
rtfs.append(rtf)

print(f"all RTFs: {model}: {rtfs}")
rtf_val = sum(rtfs) / len(rtfs)
print(f"avg. RTF: {model}: {rtf_val}")
123 changes: 123 additions & 0 deletions ctranslate2/run_eval.py
@@ -0,0 +1,123 @@
"""Run evaluation for ctranslate2 whisper models."""""
import argparse
import os

import evaluate
from faster_whisper import WhisperModel
from tqdm import tqdm

from normalizer import data_utils

wer_metric = evaluate.load("wer")


def dataset_iterator(dataset):
    """
    Iterate over the dataset and yield a dictionary with the audio and reference text.

    Args:
        dataset: dataset to iterate over

    Yields:
        dictionary with the audio array, sampling rate, and normalized reference text
    """
    for item in dataset:
        yield {**item["audio"], "reference": item["norm_text"]}


def main(args) -> None:
    """Main function to run evaluation on a dataset."""
    asr_model = WhisperModel(
        model_size_or_path=args.model_id,
        compute_type="float16",
        device="cuda",
        device_index=args.device,
    )

    dataset = data_utils.load_data(args)

    if args.max_eval_samples is not None and args.max_eval_samples > 0:
        print(f"Subsampling dataset to first {args.max_eval_samples} samples!")
        dataset = dataset.take(args.max_eval_samples)

    dataset = data_utils.prepare_data(dataset)

    predictions = []
    references = []

    # Run inference
    for batch in tqdm(dataset_iterator(dataset), desc=f"Evaluating {args.model_id}"):
        segments, _ = asr_model.transcribe(batch["array"], language="en")
        outputs = [segment._asdict() for segment in segments]  # Iterate over segments to run inference
        predictions.append(
            data_utils.normalizer(
                "".join(segment["text"] for segment in outputs)
            ).strip()
        )
        references.append(batch["reference"])

    # Write manifest results
    manifest_path = data_utils.write_manifest(
        references, predictions, args.model_id, args.dataset_path, args.dataset, args.split
    )
    print("Results saved at path:", os.path.abspath(manifest_path))

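    # Word Error Rate = (substitutions + deletions + insertions) / words in the reference,
    # computed by the Hugging Face `evaluate` WER metric and reported below as a percentage.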
    wer = wer_metric.compute(references=references, predictions=predictions)
    wer = round(100 * wer, 2)

    print("WER:", wer, "%")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--model_id",
        type=str,
        required=True,
        help="Model identifier. Should be loadable with faster-whisper (CTranslate2 format).",
    )
    parser.add_argument(
        "--dataset_path",
        type=str,
        default="esb/datasets",
        help="Dataset path. By default, it is `esb/datasets`.",
    )
    parser.add_argument(
        "--dataset",
        type=str,
        required=True,
        help="Dataset name. *E.g.* `librispeech_asr` for the LibriSpeech ASR dataset, or `common_voice` for Common Voice. The full list of dataset names "
        "can be found at `https://huggingface.co/datasets/esb/datasets`.",
    )
    parser.add_argument(
        "--split",
        type=str,
        default="test",
        help="Split of the dataset. *E.g.* `validation` for the dev split, or `test` for the test split.",
    )
    parser.add_argument(
        "--device",
        type=int,
        default=-1,
        help="The device to run the pipeline on. -1 for CPU (default), 0 for the first GPU and so on.",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=16,
        help="Number of samples to go through each streamed batch.",
    )
    parser.add_argument(
        "--max_eval_samples",
        type=int,
        default=None,
        help="Number of samples to be evaluated. Put a lower number e.g. 64 for testing this script.",
    )
    parser.add_argument(
        "--no-streaming",
        dest="streaming",
        action="store_false",
        help="Choose whether you'd like to download the entire dataset or stream it during the evaluation.",
    )
    parser.set_defaults(streaming=False)  # must be set before parse_args() to take effect
    args = parser.parse_args()

    main(args)
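One detail worth noting when reading the inference loop above: `faster-whisper` returns segments as a lazy generator, so decoding only happens as the segments are consumed. This is why both scripts materialize the segment list before timing or scoring the results. A minimal illustration, assuming a local `audio.mp3` file:

```python
from faster_whisper import WhisperModel

model = WhisperModel("guillaumekln/faster-whisper-tiny.en", device="cuda", compute_type="float16")

# transcribe() returns immediately; `segments` is a lazy generator.
segments, info = model.transcribe("audio.mp3", language="en")

# Decoding actually runs here, as the generator is consumed.
text = "".join(segment.text for segment in segments)
```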
102 changes: 102 additions & 0 deletions ctranslate2/run_whisper.sh
@@ -0,0 +1,102 @@
#!/bin/bash

export PYTHONPATH="..":$PYTHONPATH

MODEL_IDs=("guillaumekln/faster-whisper-tiny.en" "guillaumekln/faster-whisper-small.en" "guillaumekln/faster-whisper-base.en" "guillaumekln/faster-whisper-medium.en" "guillaumekln/faster-whisper-large-v1" "guillaumekln/faster-whisper-large-v2")
BATCH_SIZE=1
DEVICE_INDEX=0

num_models=${#MODEL_IDs[@]}

for (( i=0; i<${num_models}; i++ ));
do
    MODEL_ID=${MODEL_IDs[$i]}

    python run_eval.py \
        --model_id=${MODEL_ID} \
        --dataset_path="https://huggingface.co/datasets/hf-audio/esb-datasets-test-only" \
        --dataset="ami" \
        --split="test" \
        --device=${DEVICE_INDEX} \
        --batch_size=${BATCH_SIZE} \
        --max_eval_samples=-1

    python run_eval.py \
        --model_id=${MODEL_ID} \
        --dataset_path="https://huggingface.co/datasets/hf-audio/esb-datasets-test-only" \
        --dataset="earnings22" \
        --split="test" \
        --device=${DEVICE_INDEX} \
        --batch_size=${BATCH_SIZE} \
        --max_eval_samples=-1

    python run_eval.py \
        --model_id=${MODEL_ID} \
        --dataset_path="https://huggingface.co/datasets/hf-audio/esb-datasets-test-only" \
        --dataset="gigaspeech" \
        --split="test" \
        --device=${DEVICE_INDEX} \
        --batch_size=${BATCH_SIZE} \
        --max_eval_samples=-1

    python run_eval.py \
        --model_id=${MODEL_ID} \
        --dataset_path="https://huggingface.co/datasets/hf-audio/esb-datasets-test-only" \
        --dataset="librispeech" \
        --split="test.clean" \
        --device=${DEVICE_INDEX} \
        --batch_size=${BATCH_SIZE} \
        --max_eval_samples=-1

    python run_eval.py \
        --model_id=${MODEL_ID} \
        --dataset_path="https://huggingface.co/datasets/hf-audio/esb-datasets-test-only" \
        --dataset="librispeech" \
        --split="test.other" \
        --device=${DEVICE_INDEX} \
        --batch_size=${BATCH_SIZE} \
        --max_eval_samples=-1

    python run_eval.py \
        --model_id=${MODEL_ID} \
        --dataset_path="https://huggingface.co/datasets/hf-audio/esb-datasets-test-only" \
        --dataset="spgispeech" \
        --split="test" \
        --device=${DEVICE_INDEX} \
        --batch_size=${BATCH_SIZE} \
        --max_eval_samples=-1

    python run_eval.py \
        --model_id=${MODEL_ID} \
        --dataset_path="https://huggingface.co/datasets/hf-audio/esb-datasets-test-only" \
        --dataset="tedlium" \
        --split="test" \
        --device=${DEVICE_INDEX} \
        --batch_size=${BATCH_SIZE} \
        --max_eval_samples=-1

    python run_eval.py \
        --model_id=${MODEL_ID} \
        --dataset_path="https://huggingface.co/datasets/hf-audio/esb-datasets-test-only" \
        --dataset="voxpopuli" \
        --split="test" \
        --device=${DEVICE_INDEX} \
        --batch_size=${BATCH_SIZE} \
        --max_eval_samples=-1

    python run_eval.py \
        --model_id=${MODEL_ID} \
        --dataset_path="https://huggingface.co/datasets/hf-audio/esb-datasets-test-only" \
        --dataset="common_voice" \
        --split="test" \
        --device=${DEVICE_INDEX} \
        --batch_size=${BATCH_SIZE} \
        --max_eval_samples=-1

    # Evaluate results
    RUNDIR=`pwd` && \
    cd ../normalizer && \
    python -c "import eval_utils; eval_utils.score_results('${RUNDIR}/results', '${MODEL_ID}')" && \
    cd $RUNDIR

done
File renamed without changes.
2 changes: 1 addition & 1 deletion normalizer/data_utils.py
@@ -43,7 +43,7 @@ def load_data(args):
        args.dataset,
        split=args.split,
        streaming=args.streaming,
-       use_auth_token=True,
+       token=True,
    )

    return dataset
1 change: 0 additions & 1 deletion requirements/requirements.txt
@@ -4,4 +4,3 @@ evaluate
datasets
librosa
jiwer

5 changes: 5 additions & 0 deletions requirements/requirements_ctranslate2.txt
@@ -0,0 +1,5 @@
datasets
evaluate
faster-whisper>=0.8.0
jiwer
librosa