Support computing nbest oracle WER. #10

Merged: 11 commits, Aug 20, 2021
27 changes: 27 additions & 0 deletions egs/librispeech/ASR/conformer_ctc/README.md
@@ -0,0 +1,27 @@

@csukuangfj (Collaborator, Author) commented on Aug 18, 2021:

This is how a pre-trained model can be used to transcribe a sound file.
@danpovey

It depends on (see the sketch after this list):

  • torchaudio, for reading sound files
  • kaldifeat, for feature extraction
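
A minimal sketch of how the two fit together (illustrative only; the real script, `conformer_ctc/pretrained.py`, is added below):

```
import kaldifeat
import torchaudio

# torchaudio reads the waveform; kaldifeat turns it into fbank features.
wave, sample_rate = torchaudio.load("test.wav")  # placeholder file name

opts = kaldifeat.FbankOptions()
opts.frame_opts.samp_freq = sample_rate
opts.mel_opts.num_bins = 80  # 80-dim fbank, matching the model config below

fbank = kaldifeat.Fbank(opts)
features = fbank(wave.squeeze())
```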

@csukuangfj (Collaborator, Author) commented on Aug 18, 2021:

Only HLG decoding with the transformer encoder output is added.
Do we need to use the attention decoder for rescoring?

A collaborator replied:

This is great-- thanks!
Regarding using the attention decoder for rescoring-- yes, I'd like you to add that, because this will probably
be a main feature of the tutorial, and I think having good results is probably worthwhile.

# How to use a pre-trained model to transcribe a sound file

You need to prepare 4 files:

- a model checkpoint file, e.g., epoch-20.pt
- HLG.pt, the decoding graph
- words.txt, the word symbol table
- a sound file, whose sampling rate has to be 16 kHz.
  Supported formats are those supported by `torchaudio.load()`,
  e.g., wav and flac.

Also, you need to install `kaldifeat`. Please refer to
<https://github.com/csukuangfj/kaldifeat> for installation.

Once you have the above files ready and have `kaldifeat` installed,
you can run:

```
./conformer_ctc/pretrained.py \
--checkpoint /path/to/your/checkpoint.pt \
--words-file /path/to/words.txt \
--hlg /path/to/HLG.pt \
--sound-file /path/to/your/sound.wav
```

and you will see the transcribed result.
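
If you are unsure whether a sound file meets the 16 kHz requirement, here is a quick check with `torchaudio` (an illustrative snippet, not part of this recipe):

```
import torchaudio

wave, sample_rate = torchaudio.load("/path/to/your/sound.wav")
assert sample_rate == 16000, f"expected 16 kHz, got {sample_rate} Hz"
```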
41 changes: 36 additions & 5 deletions egs/librispeech/ASR/conformer_ctc/decode.py
@@ -21,6 +21,7 @@
from icefall.decode import (
    get_lattice,
    nbest_decoding,
    nbest_oracle,
    one_best_decoding,
    rescore_with_attention_decoder,
    rescore_with_n_best_list,
@@ -56,6 +57,18 @@ def get_parser():
        "consecutive checkpoints before the checkpoint specified by "
        "'--epoch'. ",
    )

    parser.add_argument(
        "--scale",
A collaborator commented on this line:

If this scale is only used for the nbest-oracle mode, perhaps that should be clarified, e.g. via the name and the documentation? Right now it's a bit unclear whether this would affect other things.

@csukuangfj (Collaborator, Author) replied:

I think it is also useful for other n-best rescoring methods, e.g., attention-decoder rescoring. Tuning this value can
change the number of unique paths in an n-best list, which can potentially affect the final WER.

I'm adding more documentation to clarify its usage.
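
A minimal sketch of that interaction, assuming the `k2.random_paths()` API that icefall's n-best code builds on (the function name here is illustrative):

```
import k2

def sample_paths_with_scale(lattice, num_paths, scale):
    # A smaller scale flattens the score distribution, so
    # k2.random_paths() samples a more diverse (more unique)
    # set of paths from the lattice.
    saved_scores = lattice.scores.clone()
    lattice.scores *= scale
    paths = k2.random_paths(
        lattice, use_double_scores=True, num_paths=num_paths
    )
    lattice.scores = saved_scores  # restore the original scores
    return paths
```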

        type=float,
        default=1.0,
        help="The scale to be applied to `lattice.scores`. "
        "It's needed if you use any kind of n-best based rescoring. "
        "Currently, it is used when the decoding method is: nbest, "
        "nbest-rescoring, attention-decoder, and nbest-oracle. "
        "A smaller value results in more unique paths.",
    )

    return parser


@@ -85,10 +98,14 @@ def get_params() -> AttributeDict:
            # - nbest-rescoring
            # - whole-lattice-rescoring
            # - attention-decoder
            # - nbest-oracle
            # "method": "nbest",
            # "method": "nbest-rescoring",
            # "method": "whole-lattice-rescoring",
            "method": "attention-decoder",
            # "method": "nbest-oracle",
            # num_paths is used when method is "nbest", "nbest-rescoring",
            # attention-decoder, and nbest-oracle
            "num_paths": 100,
        }
    )
@@ -179,6 +196,19 @@ def decode_one_batch(
        subsampling_factor=params.subsampling_factor,
    )

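    # Conceptually, nbest_oracle() samples num_paths paths from the
    # lattice, converts each path to a word sequence, and keeps, for
    # each utterance, the hypothesis closest in edit distance to the
    # reference; the resulting WER is a lower bound for any rescoring
    # of this n-best list.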
if params.method == "nbest-oracle":
# Note: You can also pass rescored lattices to it.
# We choose the HLG decoded lattice for speed reasons
# as HLG decoding is faster and the oracle WER
# is slightly worse than that of rescored lattices.
return nbest_oracle(
lattice=lattice,
num_paths=params.num_paths,
ref_texts=supervisions["text"],
lexicon=lexicon,
scale=params.scale,
)

if params.method in ["1best", "nbest"]:
if params.method == "1best":
best_path = one_best_decoding(
Expand All @@ -190,8 +220,9 @@ def decode_one_batch(
lattice=lattice,
num_paths=params.num_paths,
use_double_scores=params.use_double_scores,
scale=params.scale,
)
key = f"no_rescore-{params.num_paths}"
key = f"no_rescore-scale-{params.scale}-{params.num_paths}"

hyps = get_texts(best_path)
hyps = [[lexicon.word_table[i] for i in ids] for ids in hyps]
@@ -212,6 +243,7 @@
            G=G,
            num_paths=params.num_paths,
            lm_scale_list=lm_scale_list,
            scale=params.scale,
        )
    elif params.method == "whole-lattice-rescoring":
        best_path_dict = rescore_with_whole_lattice(
@@ -231,6 +263,7 @@
            memory_key_padding_mask=memory_key_padding_mask,
            sos_id=sos_id,
            eos_id=eos_id,
            scale=params.scale,
        )
    else:
        assert False, f"Unsupported decoding method: {params.method}"
@@ -284,7 +317,6 @@ def decode_dataset(
    results = []

    num_cuts = 0

    results = defaultdict(list)
    for batch_idx, batch in enumerate(dl):
@@ -315,8 +347,7 @@
        if batch_idx % 100 == 0:
            logging.info(
                f"batch {batch_idx}, cuts processed until now is "
                f"{num_cuts}"
            )
    return results

162 changes: 162 additions & 0 deletions egs/librispeech/ASR/conformer_ctc/pretrained.py
@@ -0,0 +1,162 @@
#!/usr/bin/env python3

import argparse
import logging

import k2
import kaldifeat
import torch
import torchaudio
from conformer import Conformer

from icefall.decode import (
    get_lattice,
    one_best_decoding,
)
from icefall.utils import AttributeDict, get_texts


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--checkpoint",
        type=str,
        required=True,
        help="Path to the checkpoint. "
        "The checkpoint is assumed to be saved by "
        "icefall.checkpoint.save_checkpoint().",
    )

    parser.add_argument(
        "--words-file",
        type=str,
        required=True,
        help="Path to words.txt",
    )

    parser.add_argument(
        "--hlg", type=str, required=True, help="Path to HLG.pt."
    )

    parser.add_argument(
        "--sound-file",
        type=str,
        required=True,
        help="The input sound file to transcribe. "
        "Supported formats are those supported by torchaudio.load(). "
        "For example, wav and flac are supported. "
        "The sample rate has to be 16 kHz.",
    )

    return parser


def get_params() -> AttributeDict:
    params = AttributeDict(
        {
            "feature_dim": 80,
            "nhead": 8,
            "num_classes": 5000,
            "attention_dim": 512,
            "subsampling_factor": 4,
            "num_decoder_layers": 6,
            "vgg_frontend": False,
            "is_espnet_structure": True,
            "mmi_loss": False,
            "use_feat_batchnorm": True,
            "search_beam": 20,
            "output_beam": 8,
            "min_active_states": 30,
            "max_active_states": 10000,
            "use_double_scores": True,
        }
    )
    return params


def main():
    parser = get_parser()
    args = parser.parse_args()

    params = get_params()
    params.update(vars(args))

    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda", 0)

    logging.info(f"device: {device}")

    model = Conformer(
        num_features=params.feature_dim,
        nhead=params.nhead,
        d_model=params.attention_dim,
        num_classes=params.num_classes,
        subsampling_factor=params.subsampling_factor,
        num_decoder_layers=params.num_decoder_layers,
        vgg_frontend=params.vgg_frontend,
        is_espnet_structure=params.is_espnet_structure,
        mmi_loss=params.mmi_loss,
        use_feat_batchnorm=params.use_feat_batchnorm,
    )

    checkpoint = torch.load(args.checkpoint, map_location="cpu")
    model.load_state_dict(checkpoint["model"])
    model.to(device)
    model.eval()  # inference mode, so batchnorm uses running statistics

    HLG = k2.Fsa.from_dict(torch.load(params.hlg))
    HLG = HLG.to(device)

    wave, samp_freq = torchaudio.load(params.sound_file)
    wave = wave.squeeze().to(device)

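    # Configure kaldifeat to compute the same 80-dim fbank features the
    # model was trained on; dither is disabled so the output is
    # deterministic.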
    opts = kaldifeat.FbankOptions()
    opts.device = device
    opts.frame_opts.dither = 0
    opts.frame_opts.snip_edges = False
    opts.frame_opts.samp_freq = samp_freq
    opts.mel_opts.num_bins = 80

    fbank = kaldifeat.Fbank(opts)

    features = fbank(wave)
    features = features.unsqueeze(0)

    nnet_output, _, _ = model(features)
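    # Each row of supervision_segments is (sequence_index, start_frame,
    # num_frames); the single row here spans the whole utterance.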
    supervision_segments = torch.tensor(
        [[0, 0, nnet_output.shape[1]]], dtype=torch.int32
    )

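    # get_lattice() decodes with HLG; search_beam/output_beam and the
    # active-state bounds prune the search (they map to k2's pruned
    # intersection).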
    lattice = get_lattice(
        nnet_output=nnet_output,
        HLG=HLG,
        supervision_segments=supervision_segments,
        search_beam=params.search_beam,
        output_beam=params.output_beam,
        min_active_states=params.min_active_states,
        max_active_states=params.max_active_states,
        subsampling_factor=params.subsampling_factor,
    )

    best_path = one_best_decoding(
        lattice=lattice, use_double_scores=params.use_double_scores
    )

    hyps = get_texts(best_path)
    word_sym_table = k2.SymbolTable.from_file(params.words_file)
    hyps = [[word_sym_table[i] for i in ids] for ids in hyps]
    logging.info(hyps)


if __name__ == "__main__":
    formatter = (
        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    )

    logging.basicConfig(format=formatter, level=logging.INFO)
    main()