Add support for ctranslate2 whisper models (#10)
* add ctranslate2 eval

* correct calc_rtf name for all files

* add missing regex dependency

* fix cuda index

* add DEVICE_INDEX as variable

* add hub login

* remove regex included in transformers

* fix inference

* use_auth_token -> token

* add compute type for evaluation

* scripts corrections

* fix typo

* update shell script url
chainyo authored Sep 11, 2023
1 parent cbd8b9b commit 9a596ac
Showing 8 changed files with 294 additions and 4 deletions.
4 changes: 2 additions & 2 deletions README.md
@@ -10,7 +10,7 @@ Each library has its own set of requirements. We recommend using a clean conda environment
2) Install PyTorch by following the instructions here: https://pytorch.org/get-started/locally/
3) Install the common requirements for all libraries by running `pip install -r requirements/requirements.txt`.
4) Install the requirements for each library you wish to evaluate by running `pip install -r requirements/requirements_<library_name>.txt`.

5) Connect your Hugging Face account by running `huggingface-cli login`.

# Evaluate a model

@@ -32,7 +32,7 @@ To add a new library for evaluation in this benchmark, please follow the steps below:
4) Create one bash file per model type following the convention `run_<model_type>.sh`.
- The bash script should follow the same steps as other libraries.
- Different model sizes of the same type should share the script. For example `Wav2Vec` and `Wav2Vec2` would be two separate scripts, but different sizes of `Wav2Vec2` would be part of the same script.
- 5) (Optional) You could also add a `compute_rtf.py` script for your library to evaluate the Real Time Factor of the model.
+ 5) (Optional) You could also add a `calc_rtf.py` script for your library to evaluate the Real Time Factor of the model.
6) Submit a PR for your changes.

# Add a new model
61 changes: 61 additions & 0 deletions ctranslate2/calc_rtf.py
@@ -0,0 +1,61 @@
import time
import librosa

from faster_whisper import WhisperModel

device = "cuda"
device_index = 0

models = [
    "guillaumekln/faster-whisper-tiny.en",
    "guillaumekln/faster-whisper-small.en",
    "guillaumekln/faster-whisper-base.en",
    "guillaumekln/faster-whisper-medium.en",
    "guillaumekln/faster-whisper-large-v1",
    "guillaumekln/faster-whisper-large-v2",
]

n_batches = 3
warmup_batches = 5

audio_file = "4469669.mp3"
max_len = 600 # 10 minutes


def pre_process_audio(audio_file, sr, max_len):
    """Load an audio file and truncate it to at most `max_len` seconds."""
    audio, _sr = librosa.load(audio_file, sr=sr)
    audio_len = int(max_len * _sr)  # truncation length in samples
    audio_arr = audio[:audio_len]
    return {"raw": audio_arr, "sampling_rate": _sr}, audio_len


audio_dict, audio_len = pre_process_audio(audio_file, 16000, max_len)

rtfs = []

for model in models[:1]:  # NOTE: only the first model is benchmarked; iterate over `models` to run all
    asr_model = WhisperModel(
        model_size_or_path=model,
        device=device,
        device_index=device_index,
        compute_type="float16",
    )

    for i in range(3):
        print(f"outer_loop -> {i}")
        total_time = 0.0
        for batch_num in range(n_batches + warmup_batches):
            print(f"batch_num -> {batch_num}")
            start = time.time()
            segments, _ = asr_model.transcribe(audio_dict["raw"], language="en")
            _ = [segment._asdict() for segment in segments]  # iterate over segments to run inference
            end = time.time()
            if batch_num >= warmup_batches:  # only the timed batches count; the warmup ones are discarded
                total_time += end - start

        rtf = (total_time / n_batches) / (audio_len / 16000)  # avg. transcription time / audio duration (s)
        rtfs.append(rtf)

    print(f"all RTFs: {model}: {rtfs}")
    rtf_val = sum(rtfs) / len(rtfs)
    print(f"avg. RTF: {model}: {rtf_val}")
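For orientation: the real-time factor computed above is mean transcription time divided by audio duration, so values below 1.0 mean faster-than-real-time transcription. A minimal sketch of the same arithmetic with made-up timings (illustrative only, not measured results):

```python
# Illustrative RTF arithmetic; the timings below are hypothetical.
audio_seconds = 600.0           # the 10-minute clip used by calc_rtf.py
avg_transcribe_seconds = 30.0   # hypothetical mean over the timed batches
rtf = avg_transcribe_seconds / audio_seconds
print(f"RTF = {rtf:.3f}")       # 0.050 -> 20x faster than real time
```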
123 changes: 123 additions & 0 deletions ctranslate2/run_eval.py
@@ -0,0 +1,123 @@
"""Run evaluation for ctranslate2 whisper models."""
import argparse
import os

import evaluate
from faster_whisper import WhisperModel
from tqdm import tqdm

from normalizer import data_utils

wer_metric = evaluate.load("wer")


def dataset_iterator(dataset):
    """
    Iterate over the dataset and yield a dictionary with the audio and reference text.

    Args:
        dataset: dataset to iterate over

    Yields:
        dictionary: {"audio": audio, "reference": reference}
    """
    for item in dataset:
        yield {**item["audio"], "reference": item["norm_text"]}


def main(args) -> None:
    """Main function to run evaluation on a dataset."""
    asr_model = WhisperModel(
        model_size_or_path=args.model_id,
        compute_type="float16",
        device="cuda",
        device_index=args.device,
    )

    dataset = data_utils.load_data(args)

    if args.max_eval_samples is not None and args.max_eval_samples > 0:
        print(f"Subsampling dataset to first {args.max_eval_samples} samples!")
        dataset = dataset.take(args.max_eval_samples)

    dataset = data_utils.prepare_data(dataset)

    predictions = []
    references = []

    # Run inference
    for batch in tqdm(dataset_iterator(dataset), desc=f"Evaluating {args.model_id}"):
        segments, _ = asr_model.transcribe(batch["array"], language="en")
        outputs = [segment._asdict() for segment in segments]
        # Use append, not extend: extending with a string adds it character by character.
        predictions.append(
            data_utils.normalizer(
                "".join(segment["text"] for segment in outputs)
            ).strip()
        )
        references.append(batch["reference"])

    # Write manifest results
    manifest_path = data_utils.write_manifest(
        references, predictions, args.model_id, args.dataset_path, args.dataset, args.split
    )
    print("Results saved at path:", os.path.abspath(manifest_path))

    wer = wer_metric.compute(references=references, predictions=predictions)
    wer = round(100 * wer, 2)

    print("WER:", wer, "%")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--model_id",
        type=str,
        required=True,
        help="Model identifier. Should be loadable with faster-whisper as a CTranslate2 Whisper checkpoint.",
    )
    parser.add_argument(
        "--dataset_path",
        type=str,
        default="esb/datasets",
        help="Dataset path. By default, it is `esb/datasets`.",
    )
    parser.add_argument(
        "--dataset",
        type=str,
        required=True,
        help="Dataset name. *E.g.* `librispeech_asr` for the LibriSpeech ASR dataset, or `common_voice` for Common Voice. The full list of dataset names "
        "can be found at `https://huggingface.co/datasets/esb/datasets`.",
    )
    parser.add_argument(
        "--split",
        type=str,
        default="test",
        help="Split of the dataset. *E.g.* `validation` for the dev split, or `test` for the test split.",
    )
    parser.add_argument(
        "--device",
        type=int,
        default=-1,
        help="The device to run the pipeline on. -1 for CPU (default), 0 for the first GPU and so on.",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=16,
        help="Number of samples to go through each streamed batch.",
    )
    parser.add_argument(
        "--max_eval_samples",
        type=int,
        default=None,
        help="Number of samples to be evaluated. Put a lower number e.g. 64 for testing this script.",
    )
    parser.add_argument(
        "--no-streaming",
        dest="streaming",
        action="store_false",
        help="Choose whether you'd like to download the entire dataset or stream it during the evaluation.",
    )
    # Defaults must be set before parsing; calling set_defaults after parse_args has no effect.
    parser.set_defaults(streaming=False)
    args = parser.parse_args()

    main(args)
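The script reports corpus-level WER through 🤗 Evaluate's `wer` metric, as loaded at the top of the file. As a sanity check, here is a minimal, self-contained sketch of the same metric call on toy strings (the strings and the resulting figure are illustrative, not benchmark output):

```python
import evaluate

wer_metric = evaluate.load("wer")

# WER = (substitutions + insertions + deletions) / reference word count.
wer = wer_metric.compute(
    references=["the cat sat on the mat"],
    predictions=["the cat sat on a mat"],  # one substitution out of six words
)
print(round(100 * wer, 2))  # 16.67
```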
102 changes: 102 additions & 0 deletions ctranslate2/run_whisper.sh
@@ -0,0 +1,102 @@
#!/bin/bash

export PYTHONPATH="..":$PYTHONPATH

MODEL_IDs=("guillaumekln/faster-whisper-tiny.en" "guillaumekln/faster-whisper-small.en" "guillaumekln/faster-whisper-base.en" "guillaumekln/faster-whisper-medium.en" "guillaumekln/faster-whisper-large-v1" "guillaumekln/faster-whisper-large-v2")
BATCH_SIZE=1
DEVICE_INDEX=0

num_models=${#MODEL_IDs[@]}

for (( i=0; i<${num_models}; i++ ));
do
MODEL_ID=${MODEL_IDs[$i]}

python run_eval.py \
--model_id=${MODEL_ID} \
--dataset_path="https://huggingface.co/datasets/hf-audio/esb-datasets-test-only" \
--dataset="ami" \
--split="test" \
--device=${DEVICE_INDEX} \
--batch_size=${BATCH_SIZE} \
--max_eval_samples=-1

python run_eval.py \
--model_id=${MODEL_ID} \
--dataset_path="https://huggingface.co/datasets/hf-audio/esb-datasets-test-only" \
--dataset="earnings22" \
--split="test" \
--device=${DEVICE_INDEX} \
--batch_size=${BATCH_SIZE} \
--max_eval_samples=-1

python run_eval.py \
--model_id=${MODEL_ID} \
--dataset_path="https://huggingface.co/datasets/hf-audio/esb-datasets-test-only" \
--dataset="gigaspeech" \
--split="test" \
--device=${DEVICE_INDEX} \
--batch_size=${BATCH_SIZE} \
--max_eval_samples=-1

python run_eval.py \
--model_id=${MODEL_ID} \
--dataset_path="https://huggingface.co/datasets/hf-audio/esb-datasets-test-only" \
--dataset="librispeech" \
--split="test.clean" \
--device=${DEVICE_INDEX} \
--batch_size=${BATCH_SIZE} \
--max_eval_samples=-1

python run_eval.py \
--model_id=${MODEL_ID} \
--dataset_path="https://huggingface.co/datasets/hf-audio/esb-datasets-test-only" \
--dataset="librispeech" \
--split="test.other" \
--device=${DEVICE_INDEX} \
--batch_size=${BATCH_SIZE} \
--max_eval_samples=-1

python run_eval.py \
--model_id=${MODEL_ID} \
--dataset_path="https://huggingface.co/datasets/hf-audio/esb-datasets-test-only" \
--dataset="spgispeech" \
--split="test" \
--device=${DEVICE_INDEX} \
--batch_size=${BATCH_SIZE} \
--max_eval_samples=-1

python run_eval.py \
--model_id=${MODEL_ID} \
--dataset_path="https://huggingface.co/datasets/hf-audio/esb-datasets-test-only" \
--dataset="tedlium" \
--split="test" \
--device=${DEVICE_INDEX} \
--batch_size=${BATCH_SIZE} \
--max_eval_samples=-1

python run_eval.py \
--model_id=${MODEL_ID} \
--dataset_path="https://huggingface.co/datasets/hf-audio/esb-datasets-test-only" \
--dataset="voxpopuli" \
--split="test" \
--device=${DEVICE_INDEX} \
--batch_size=${BATCH_SIZE} \
--max_eval_samples=-1

python run_eval.py \
--model_id=${MODEL_ID} \
--dataset_path="https://huggingface.co/datasets/hf-audio/esb-datasets-test-only" \
--dataset="common_voice" \
--split="test" \
--device=${DEVICE_INDEX} \
--batch_size=${BATCH_SIZE} \
--max_eval_samples=-1

# Evaluate results
RUNDIR=`pwd` && \
cd ../normalizer && \
python -c "import eval_utils; eval_utils.score_results('${RUNDIR}/results', '${MODEL_ID}')" && \
cd $RUNDIR

done
File renamed without changes.
2 changes: 1 addition & 1 deletion normalizer/data_utils.py
@@ -43,7 +43,7 @@ def load_data(args):
        args.dataset,
        split=args.split,
        streaming=args.streaming,
-        use_auth_token=True,
+        token=True,
    )

    return dataset
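For context, `token=True` makes `datasets.load_dataset` authenticate with the token stored by `huggingface-cli login` (hence the new login step in the README); `use_auth_token` is the older spelling that newer `datasets` releases deprecate. A minimal sketch of the call in isolation, with argument values mirroring `run_whisper.sh` (illustrative here, not part of the diff):

```python
from datasets import load_dataset

# Stream the AMI test split, authenticating with the locally stored HF token.
dataset = load_dataset(
    "hf-audio/esb-datasets-test-only",
    "ami",
    split="test",
    streaming=True,
    token=True,
)
```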
1 change: 0 additions & 1 deletion requirements/requirements.txt
@@ -4,4 +4,3 @@ evaluate
datasets
librosa
jiwer

5 changes: 5 additions & 0 deletions requirements/requirements_ctranslate2.txt
@@ -0,0 +1,5 @@
datasets
evaluate
faster-whisper>=0.8.0
jiwer
librosa
