Support computing nbest oracle WER. #10
@@ -0,0 +1,27 @@

# How to use a pre-trained model to transcribe a sound file

You need to prepare 4 files:

- a model checkpoint file, e.g., epoch-20.pt
- HLG.pt, the decoding graph
- words.txt, the word symbol table
- a sound file, whose sampling rate has to be 16 kHz.
  Supported formats are those supported by `torchaudio.load()`,
  e.g., wav and flac.

Also, you need to install `kaldifeat`. Please refer to
<https://github.com/csukuangfj/kaldifeat> for installation.

Once you have the above files ready and have `kaldifeat` installed,
you can run:

```
./conformer_ctc/pretrained.py \
  --checkpoint /path/to/your/checkpoint.pt \
  --words-file /path/to/words.txt \
  --hlg /path/to/HLG.pt \
  --sound-file /path/to/your/sound.wav
```

and you will see the transcribed result.
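Since the script requires 16 kHz input, it can be handy to verify a file's sampling rate before decoding. A minimal sketch for wav files using only Python's standard `wave` module (the helper name is my own; flac files would need something like `torchaudio.info` instead):

```python
import wave

def check_sample_rate(path: str, expected: int = 16000) -> bool:
    """Return True if a wav file's sampling rate equals `expected` (Hz)."""
    with wave.open(path, "rb") as f:
        return f.getframerate() == expected
```

If the rate does not match, the file should be resampled (e.g., with `sox` or `torchaudio`) before being passed to `pretrained.py`.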
@@ -21,6 +21,7 @@
from icefall.decode import (
    get_lattice,
    nbest_decoding,
    nbest_oracle,
    one_best_decoding,
    rescore_with_attention_decoder,
    rescore_with_n_best_list,
@@ -56,6 +57,18 @@ def get_parser():
        "consecutive checkpoints before the checkpoint specified by "
        "'--epoch'. ",
    )

    parser.add_argument(
        "--scale",
        type=float,
        default=1.0,
        help="The scale to be applied to `lattice.scores`. "
        "It's needed if you use any kind of n-best based rescoring. "
        "Currently, it is used when the decoding method is: nbest, "
        "nbest-rescoring, attention-decoder, and nbest-oracle. "
        "A smaller value results in more unique paths.",
    )

Review comment: If this scale is only used for the nbest-oracle mode, perhaps that should be clarified, e.g. via the name and the documentation? Right now it's a bit unclear whether this would affect other things.

Reply: I think it is also useful for other n-best rescoring methods, e.g., attention-decoder rescoring. Tuning this value can […] I'm adding more documentation to clarify its usage.
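The claim in the help text that "a smaller value results in more unique paths" can be illustrated with a toy sampler. This is not how k2 actually draws paths from a lattice; it is only a sketch of the temperature-like effect of scaling log-scores before sampling, with all names my own:

```python
import math
import random

def sample_paths(log_scores, scale, num_paths, seed=0):
    """Sample path indices with probability proportional to exp(scale * score).

    A smaller scale flattens the distribution, so the sampled set
    typically covers more distinct paths -- mirroring the effect of
    `--scale` on `lattice.scores` before n-best path sampling.
    """
    rng = random.Random(seed)
    weights = [math.exp(scale * s) for s in log_scores]
    return rng.choices(range(len(log_scores)), weights=weights, k=num_paths)

# One dominant path and nine near-ties:
scores = [0.0] + [-5.0] * 9
sharp = set(sample_paths(scores, scale=1.0, num_paths=100))
flat = set(sample_paths(scores, scale=0.1, num_paths=100))
# `flat` typically contains at least as many unique paths as `sharp`.
```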
    return parser
@@ -85,10 +98,14 @@ def get_params() -> AttributeDict:
            # - nbest-rescoring
            # - whole-lattice-rescoring
            # - attention-decoder
            # - nbest-oracle
            # "method": "nbest",
            # "method": "nbest-rescoring",
            # "method": "whole-lattice-rescoring",
            "method": "attention-decoder",
            # "method": "nbest-oracle",
            # num_paths is used when method is "nbest", "nbest-rescoring",
            # and attention-decoder
            # attention-decoder, and nbest-oracle
            "num_paths": 100,
        }
    )
@@ -179,6 +196,19 @@ def decode_one_batch(
        subsampling_factor=params.subsampling_factor,
    )

    if params.method == "nbest-oracle":
        # Note: You can also pass rescored lattices to it.
        # We choose the HLG decoded lattice for speed reasons
        # as HLG decoding is faster and the oracle WER
        # is slightly worse than that of rescored lattices.
        return nbest_oracle(
            lattice=lattice,
            num_paths=params.num_paths,
            ref_texts=supervisions["text"],
            lexicon=lexicon,
            scale=params.scale,
        )

    if params.method in ["1best", "nbest"]:
        if params.method == "1best":
            best_path = one_best_decoding(
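Conceptually, the nbest-oracle mode asks: among the n-best hypotheses for an utterance, how good is the one closest to the reference? A minimal sketch of that selection step, independent of icefall and k2 (function names are my own; the real `nbest_oracle` works on paths sampled from the lattice):

```python
def edit_distance(ref, hyp):
    """Levenshtein distance between two word sequences (one-row DP)."""
    m, n = len(ref), len(hyp)
    dp = list(range(n + 1))
    for i in range(1, m + 1):
        prev, dp[0] = dp[0], i
        for j in range(1, n + 1):
            cur = dp[j]
            dp[j] = min(
                dp[j] + 1,      # deletion from ref
                dp[j - 1] + 1,  # insertion into ref
                prev + (ref[i - 1] != hyp[j - 1]),  # substitution
            )
            prev = cur
    return dp[n]

def oracle_hyp(ref, nbest):
    """Pick the hypothesis with the fewest word errors w.r.t. ref."""
    return min(nbest, key=lambda hyp: edit_distance(ref, hyp))
```

The oracle WER is then computed by scoring `oracle_hyp(...)` against the reference for every utterance; it lower-bounds what any n-best rescoring method could achieve on that lattice.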
@@ -190,8 +220,9 @@ def decode_one_batch(
            lattice=lattice,
            num_paths=params.num_paths,
            use_double_scores=params.use_double_scores,
            scale=params.scale,
        )
        key = f"no_rescore-{params.num_paths}"
        key = f"no_rescore-scale-{params.scale}-{params.num_paths}"

    hyps = get_texts(best_path)
    hyps = [[lexicon.word_table[i] for i in ids] for ids in hyps]
@@ -212,6 +243,7 @@ def decode_one_batch(
            G=G,
            num_paths=params.num_paths,
            lm_scale_list=lm_scale_list,
            scale=params.scale,
        )
    elif params.method == "whole-lattice-rescoring":
        best_path_dict = rescore_with_whole_lattice(
@@ -231,6 +263,7 @@ def decode_one_batch(
            memory_key_padding_mask=memory_key_padding_mask,
            sos_id=sos_id,
            eos_id=eos_id,
            scale=params.scale,
        )
    else:
        assert False, f"Unsupported decoding method: {params.method}"
@@ -284,7 +317,6 @@ def decode_dataset(
    results = []

    num_cuts = 0
    tot_num_cuts = len(dl.dataset.cuts)

    results = defaultdict(list)
    for batch_idx, batch in enumerate(dl):
@@ -315,8 +347,7 @@ def decode_dataset(
        if batch_idx % 100 == 0:
            logging.info(
                f"batch {batch_idx}, cuts processed until now is "
                f"{num_cuts}/{tot_num_cuts} "
                f"({float(num_cuts)/tot_num_cuts*100:.6f}%)"
                f"{num_cuts} "
            )
    return results
@@ -0,0 +1,162 @@
#!/usr/bin/env python3

import argparse
import logging

import k2
import kaldifeat
import torch
import torchaudio
from conformer import Conformer

from icefall.decode import (
    get_lattice,
    one_best_decoding,
)
from icefall.utils import AttributeDict, get_texts


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--checkpoint",
        type=str,
        required=True,
        help="Path to the checkpoint. "
        "The checkpoint is assumed to be saved by "
        "icefall.checkpoint.save_checkpoint().",
    )

    parser.add_argument(
        "--words-file",
        type=str,
        required=True,
        help="Path to words.txt",
    )

    parser.add_argument(
        "--hlg", type=str, required=True, help="Path to HLG.pt."
    )

    parser.add_argument(
        "--sound-file",
        type=str,
        required=True,
        help="The input sound file to transcribe. "
        "Supported formats are those supported by torchaudio.load(). "
        "For example, wav and flac are supported. "
        "The sample rate has to be 16 kHz.",
    )

    return parser


def get_params() -> AttributeDict:
    params = AttributeDict(
        {
            "feature_dim": 80,
            "nhead": 8,
            "num_classes": 5000,
            "attention_dim": 512,
            "subsampling_factor": 4,
            "num_decoder_layers": 6,
            "vgg_frontend": False,
            "is_espnet_structure": True,
            "mmi_loss": False,
            "use_feat_batchnorm": True,
            "search_beam": 20,
            "output_beam": 8,
            "min_active_states": 30,
            "max_active_states": 10000,
            "use_double_scores": True,
        }
    )
    return params


def main():
    parser = get_parser()
    args = parser.parse_args()

    params = get_params()
    params.update(vars(args))

    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda", 0)

    logging.info(f"device: {device}")

    model = Conformer(
        num_features=params.feature_dim,
        nhead=params.nhead,
        d_model=params.attention_dim,
        num_classes=params.num_classes,
        subsampling_factor=params.subsampling_factor,
        num_decoder_layers=params.num_decoder_layers,
        vgg_frontend=params.vgg_frontend,
        is_espnet_structure=params.is_espnet_structure,
        mmi_loss=params.mmi_loss,
        use_feat_batchnorm=params.use_feat_batchnorm,
    )

    checkpoint = torch.load(args.checkpoint, map_location="cpu")
    model.load_state_dict(checkpoint["model"])
    model.to(device)

    HLG = k2.Fsa.from_dict(torch.load(params.hlg))
    HLG = HLG.to(device)

    wave, samp_freq = torchaudio.load(params.sound_file)
    wave = wave.squeeze().to(device)

    opts = kaldifeat.FbankOptions()
    opts.device = device
    opts.frame_opts.dither = 0
    opts.frame_opts.snip_edges = False
    opts.frame_opts.samp_freq = samp_freq
    opts.mel_opts.num_bins = 80

    fbank = kaldifeat.Fbank(opts)

    features = fbank(wave)
    features = features.unsqueeze(0)

    nnet_output, _, _ = model(features)
    supervision_segments = torch.tensor(
        [[0, 0, nnet_output.shape[1]]], dtype=torch.int32
    )

    lattice = get_lattice(
        nnet_output=nnet_output,
        HLG=HLG,
        supervision_segments=supervision_segments,
        search_beam=params.search_beam,
        output_beam=params.output_beam,
        min_active_states=params.min_active_states,
        max_active_states=params.max_active_states,
        subsampling_factor=params.subsampling_factor,
    )

    best_path = one_best_decoding(
        lattice=lattice, use_double_scores=params.use_double_scores
    )

    hyps = get_texts(best_path)
    word_sym_table = k2.SymbolTable.from_file(params.words_file)
    hyps = [[word_sym_table[i] for i in ids] for ids in hyps]
    logging.info(hyps)


if __name__ == "__main__":
    formatter = (
        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    )

    logging.basicConfig(format=formatter, level=logging.INFO)
    main()
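In the script above, `supervision_segments` has one row per utterance, where each row is `[sequence_index, start_frame, num_frames]`; with a single sound file there is exactly one row starting at frame 0. A sketch (plain Python, hypothetical helper name) of how the same structure would extend to a batch of utterances with different output lengths:

```python
def make_supervision_segments(num_frames_per_utt):
    """Build [sequence_index, start_frame, num_frames] rows for a batch.

    Each utterance occupies its own row of the padded nnet_output and
    starts at frame 0, so start_frame is always 0 here.
    """
    return [[i, 0, n] for i, n in enumerate(num_frames_per_utt)]

# For a batch with utterances of 98 and 75 output frames:
# make_supervision_segments([98, 75]) -> [[0, 0, 98], [1, 0, 75]]
```

The real code would wrap the result in a `torch.int32` tensor, as in `pretrained.py`.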
Review comment: This is how a pre-trained model can be used to transcribe a sound file. @danpovey It depends on […]

Review comment: Only HLG decoding with the transformer encoder output is added. Do we need to use the attention decoder for rescoring?

Reply: This is great -- thanks! Regarding using the attention decoder for rescoring -- yes, I'd like you to add that, because this will probably be a main feature of the tutorial, and I think having good results is probably worthwhile.