-
Notifications
You must be signed in to change notification settings - Fork 21
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add support for ctranslate2 whisper models (#10)
* add ctranslate2 eval
* correct calc_rtf name for all files
* add missing regex dependency
* fix cuda index
* add DEVICE_INDEX as variable
* add hub login
* remove regex included in transformers
* fix inference
* use_auth_token -> token
* add compute type for evaluation
* scripts corrections
* fix typo
* update shell script url
- Loading branch information
Showing
8 changed files
with
294 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
"""Benchmark RTF (real-time factor) of ctranslate2 faster-whisper models."""
import time

import librosa

from faster_whisper import WhisperModel

device = "cuda"
device_index = 0

models = [
    "guillaumekln/faster-whisper-tiny.en",
    "guillaumekln/faster-whisper-small.en",
    "guillaumekln/faster-whisper-base.en",
    "guillaumekln/faster-whisper-medium.en",
    "guillaumekln/faster-whisper-large-v1",
    "guillaumekln/faster-whisper-large-v2",
]

n_batches = 3
warmup_batches = 5

audio_file = "4469669.mp3"
max_len = 600  # 10 minutes


def pre_process_audio(audio_file, sr, max_len):
    """Load ``audio_file`` resampled to ``sr`` and truncate to ``max_len`` seconds.

    Args:
        audio_file: path to the audio file to load.
        sr: target sampling rate passed to librosa.
        max_len: maximum clip length in seconds.

    Returns:
        Tuple of ({"raw": samples, "sampling_rate": sr}, n_samples_kept).
        Note the second element is a sample count, not seconds.
    """
    audio, actual_sr = librosa.load(audio_file, sr=sr)
    audio_len = int(max_len * actual_sr)  # number of samples to keep
    return {"raw": audio[:audio_len], "sampling_rate": actual_sr}, audio_len


audio_dict, audio_len = pre_process_audio(audio_file, 16000, max_len)

rtfs = []

# NOTE: only the first model is benchmarked here (models[:1]); widen the
# slice to benchmark the full list.
for model in models[:1]:
    asr_model = WhisperModel(
        model_size_or_path=model,
        device=device,
        device_index=device_index,
        compute_type="float16",
    )

    for i in range(3):
        print(f"outer_loop -> {i}")
        total_time = 0.0
        for batch_num in range(n_batches + warmup_batches):
            print(f"batch_num -> {batch_num}")
            start = time.time()
            segments, _ = asr_model.transcribe(audio_dict["raw"], language="en")
            # Iterate over segments to actually run inference (lazy generator).
            _ = [segment._asdict() for segment in segments]
            end = time.time()
            # BUG FIX: the original reused `_` as the loop index, so by this
            # point it had been clobbered by the list comprehension above and
            # `if _ >= warmup_batches` raised TypeError (list >= int). Use a
            # dedicated `batch_num` index so only post-warmup batches count.
            if batch_num >= warmup_batches:
                total_time += end - start

        rtf = (total_time / n_batches) / (audio_len / 16000)
        rtfs.append(rtf)

print(f"all RTFs: {model}: {rtfs}")
rtf_val = sum(rtfs) / len(rtfs)
print(f"avg. RTF: {model}: {rtf_val}")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,123 @@ | ||
"""Run evaluation for ctranslate2 whisper models."""
import argparse
import os

import evaluate
from faster_whisper import WhisperModel
from tqdm import tqdm

from normalizer import data_utils

# Shared WER metric instance, loaded once at import time.
wer_metric = evaluate.load("wer")
|
||
|
||
def dataset_iterator(dataset):
    """Yield one evaluation example at a time from *dataset*.

    Fixes the original annotation: this is a generator, so annotating the
    return as ``dict`` was incorrect (it yields dicts, it does not return one).

    Args:
        dataset: iterable of examples, each with an "audio" dict (e.g.
            "array", "sampling_rate") and a "norm_text" string.

    Yields:
        dict: the example's audio fields merged with the normalized
        transcript under the "reference" key.
    """
    for example in dataset:
        yield {**example["audio"], "reference": example["norm_text"]}
|
||
|
||
def main(args) -> None:
    """Evaluate one ctranslate2 whisper model on one dataset split and report WER.

    Loads the model on CUDA, streams the dataset, transcribes each example,
    writes a results manifest, and prints the word error rate.
    """
    asr_model = WhisperModel(
        model_size_or_path=args.model_id,
        compute_type="float16",
        device="cuda",
        device_index=args.device,
    )

    dataset = data_utils.load_data(args)

    if args.max_eval_samples is not None and args.max_eval_samples > 0:
        print(f"Subsampling dataset to first {args.max_eval_samples} samples !")
        dataset = dataset.take(args.max_eval_samples)

    dataset = data_utils.prepare_data(dataset)

    predictions = []
    references = []

    # Run inference
    for batch in tqdm(dataset_iterator(dataset), desc=f"Evaluating {args.model_id}"):
        segments, _ = asr_model.transcribe(batch["array"], language="en")
        # Materialize the lazy segment generator to actually run inference.
        outputs = [segment._asdict() for segment in segments]
        transcript = "".join(segment["text"] for segment in outputs)
        # BUG FIX: the original used list.extend() with string arguments,
        # which appends individual *characters* (and only the first character
        # of the reference), corrupting the WER computation. Append the full
        # prediction/reference strings instead.
        predictions.append(data_utils.normalizer(transcript).strip())
        references.append(batch["reference"])

    # Write manifest results
    manifest_path = data_utils.write_manifest(
        references, predictions, args.model_id, args.dataset_path, args.dataset, args.split
    )
    print("Results saved at path:", os.path.abspath(manifest_path))

    wer = wer_metric.compute(references=references, predictions=predictions)
    wer = round(100 * wer, 2)

    print("WER:", wer, "%")
|
||
|
||
if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--model_id",
        type=str,
        required=True,
        help="Model identifier. Should be loadable with 🤗 Transformers",
    )
    parser.add_argument(
        '--dataset_path', type=str, default='esb/datasets', help='Dataset path. By default, it is `esb/datasets`'
    )
    parser.add_argument(
        "--dataset",
        type=str,
        required=True,
        help="Dataset name. *E.g.* `'librispeech_asr` for the LibriSpeech ASR dataset, or `'common_voice'` for Common Voice. The full list of dataset names "
        "can be found at `https://huggingface.co/datasets/esb/datasets`"
    )
    parser.add_argument(
        "--split",
        type=str,
        default="test",
        help="Split of the dataset. *E.g.* `'validation`' for the dev split, or `'test'` for the test split.",
    )
    parser.add_argument(
        "--device",
        type=int,
        default=-1,
        help="The device to run the pipeline on. -1 for CPU (default), 0 for the first GPU and so on.",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=16,
        help="Number of samples to go through each streamed batch.",
    )
    parser.add_argument(
        "--max_eval_samples",
        type=int,
        default=None,
        help="Number of samples to be evaluated. Put a lower number e.g. 64 for testing this script.",
    )
    parser.add_argument(
        "--no-streaming",
        dest='streaming',
        action="store_false",
        help="Choose whether you'd like to download the entire dataset or stream it during the evaluation.",
    )
    # BUG FIX: set_defaults() must run BEFORE parse_args(); the original
    # called it afterwards, where it had no effect on the parsed namespace.
    parser.set_defaults(streaming=False)
    args = parser.parse_args()

    main(args)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
#!/bin/bash
# Evaluate every faster-whisper model on every ESB test split, then score results.

export PYTHONPATH="..":$PYTHONPATH

MODEL_IDs=("guillaumekln/faster-whisper-tiny.en" "guillaumekln/faster-whisper-small.en" "guillaumekln/faster-whisper-base.en" "guillaumekln/faster-whisper-medium.en" "guillaumekln/faster-whisper-large-v1" "guillaumekln/faster-whisper-large-v2")
BATCH_SIZE=1
DEVICE_INDEX=0
DATASET_PATH="https://huggingface.co/datasets/hf-audio/esb-datasets-test-only"

# dataset:split pairs to evaluate (librispeech has two test splits).
# The original script repeated the same python invocation nine times; this
# loop is behaviorally identical and keeps the dataset list in one place.
DATASETS=(
    "ami:test"
    "earnings22:test"
    "gigaspeech:test"
    "librispeech:test.clean"
    "librispeech:test.other"
    "spgispeech:test"
    "tedlium:test"
    "voxpopuli:test"
    "common_voice:test"
)

for MODEL_ID in "${MODEL_IDs[@]}"; do
    for PAIR in "${DATASETS[@]}"; do
        DATASET="${PAIR%%:*}"   # text before the first ':'
        SPLIT="${PAIR#*:}"      # text after the first ':'

        python run_eval.py \
            --model_id=${MODEL_ID} \
            --dataset_path="${DATASET_PATH}" \
            --dataset="${DATASET}" \
            --split="${SPLIT}" \
            --device=${DEVICE_INDEX} \
            --batch_size=${BATCH_SIZE} \
            --max_eval_samples=-1
    done

    # Evaluate results
    RUNDIR=`pwd` && \
    cd ../normalizer && \
    python -c "import eval_utils; eval_utils.score_results('${RUNDIR}/results', '${MODEL_ID}')" && \
    cd $RUNDIR

done
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,4 +4,3 @@ evaluate | |
datasets | ||
librosa | ||
jiwer | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
datasets | ||
evaluate | ||
faster-whisper>=0.8.0 | ||
jiwer | ||
librosa |