Support computing nbest oracle WER. #10

Merged
merged 11 commits
Aug 20, 2021
19 changes: 17 additions & 2 deletions egs/librispeech/ASR/conformer_ctc/README.md
@@ -10,15 +10,30 @@ You need to prepare 4 files:
Supported formats are those supported by `torchaudio.load()`,
e.g., wav and flac.

Also, you need to install `kaldifeat`. Please refer to
<https://github.com/csukuangfj/kaldifeat> for installation.

Once you have the above files ready, you can run:
Once you have the above files ready and have `kaldifeat` installed,
you can run:

```
./conformer_ctc/pretrained.py \
--checkpoint /path/to/your/checkpoint.pt \
--words-file /path/to/words.txt \
--hlg /path/to/HLG.pt \
--sound-file /path/to/your/sound.wav
/path/to/your/sound.wav
```

and you will see the transcribed result.

If you want to transcribe multiple files at the same time, you can use:

```
./conformer_ctc/pretrained.py \
--checkpoint /path/to/your/checkpoint.pt \
--words-file /path/to/words.txt \
--hlg /path/to/HLG.pt \
/path/to/your/sound1.wav \
/path/to/your/sound2.wav \
  /path/to/your/sound3.wav
```
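
Note that `pretrained.py` expects recordings sampled at 16 kHz. If a recording
uses a different sample rate, you can resample it first, for example with
torchaudio. The snippet below is a minimal sketch; the file names are
placeholders and a reasonably recent torchaudio is assumed:

```
import torchaudio
import torchaudio.functional as F

wave, sr = torchaudio.load("recording.wav")  # load at the original sample rate
if sr != 16000:
    wave = F.resample(wave, orig_freq=sr, new_freq=16000)
torchaudio.save("recording-16k.wav", wave, sample_rate=16000)
```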
81 changes: 64 additions & 17 deletions egs/librispeech/ASR/conformer_ctc/pretrained.py
@@ -2,11 +2,15 @@

import argparse
import logging
import math
from typing import List

import k2
import kaldifeat
import torch
import torchaudio
from conformer import Conformer
from torch.nn.utils.rnn import pad_sequence

from icefall.decode import get_lattice, one_best_decoding
from icefall.utils import AttributeDict, get_texts
@@ -38,10 +42,10 @@ def get_parser():
)

parser.add_argument(
"--sound-file",
"sound_files",
type=str,
required=True,
help="The input sound file to transcribe. "
nargs="+",
help="The input sound file(s) to transcribe. "
"Supported formats are those supported by torchaudio.load(). "
"For example, wav, flac are supported. "
"The sample rate has to be 16kHz.",
@@ -56,7 +60,7 @@ def get_params() -> AttributeDict:
"feature_dim": 80,
"nhead": 8,
"num_classes": 5000,
"sample_freq": 16000,
"sample_rate": 16000,
"attention_dim": 512,
"subsampling_factor": 4,
"num_decoder_layers": 6,
@@ -74,6 +78,30 @@
return params


def read_sound_files(
filenames: List[str], expected_sample_rate: float
) -> List[torch.Tensor]:
    """Read a list of sound files into a list of 1-D float32 torch tensors.
Args:
filenames:
A list of sound filenames.
expected_sample_rate:
The expected sample rate of the sound files.
Returns:
Return a list of 1-D float32 torch tensors.
"""
ans = []
for f in filenames:
wave, sample_rate = torchaudio.load(f)
assert sample_rate == expected_sample_rate, (
f"expected sample rate: {expected_sample_rate}. "
f"Given: {sample_rate}"
)
# We use only the first channel
ans.append(wave[0])
return ans


def main():
parser = get_parser()
args = parser.parse_args()
@@ -87,6 +115,7 @@ def main():

logging.info(f"device: {device}")

logging.info("Create model")
model = Conformer(
num_features=params.feature_dim,
nhead=params.nhead,
@@ -103,28 +132,39 @@ def main():
checkpoint = torch.load(args.checkpoint, map_location="cpu")
model.load_state_dict(checkpoint["model"])
model.to(device)
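# eval() puts dropout/batch-norm layers into inference mode so decoding is deterministic.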
model.eval()

HLG = k2.Fsa.from_dict(torch.load(params.hlg))
HLG = HLG.to(device)

model.to(device)
opts = kaldifeat.FbankOptions()
opts.device = device
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = params.sample_rate
opts.mel_opts.num_bins = params.feature_dim
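# dither=0 makes the feature extraction deterministic; snip_edges=False keeps
# the number of frames proportional to the waveform length.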

wave, sample_freq = torchaudio.load(params.sound_file)
assert sample_freq == params.sample_freq
wave = wave.to(device)
fbank = kaldifeat.Fbank(opts)

features = torchaudio.compliance.kaldi.fbank(
waveform=wave,
num_mel_bins=params.feature_dim,
snip_edges=False,
sample_frequency=params.sample_freq,
waves = read_sound_files(
filenames=params.sound_files, expected_sample_rate=params.sample_rate
)
waves = [w.to(device) for w in waves]

features = features.unsqueeze(0)
logging.info(f"Decoding started")
features = fbank(waves)
Collaborator Author commented:
Replacing torchaudio.compliance.kaldi with kaldifeat since it is easier to extract features for multiple sound files at the same time.

Collaborator commented:
Nice. Adding kaldifeat to Lhotse is still on my radar. I might remove all other kaldi-related feature extractors at the same time, but I don't think I'll be able to do it before the tutorial.
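
For context, the batched extraction that makes this convenient boils down to roughly the following (a standalone sketch; the waveforms are random placeholders):

```
import kaldifeat
import torch

opts = kaldifeat.FbankOptions()
opts.frame_opts.samp_freq = 16000
opts.mel_opts.num_bins = 80
fbank = kaldifeat.Fbank(opts)

# Two utterances of different lengths; kaldifeat accepts a list of 1-D tensors
# and returns a list of 2-D (num_frames x num_bins) feature tensors.
waves = [torch.randn(2 * 16000), torch.randn(3 * 16000)]
features = fbank(waves)
```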


nnet_output, _, _ = model(features)
features = pad_sequence(
features, batch_first=True, padding_value=math.log(1e-10)
)
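# The features are log-mel filterbank values, so padding with log(1e-10)
# corresponds to (almost) zero energy, i.e. silence, for the shorter utterances.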

with torch.no_grad():
nnet_output, _, _ = model(features)

batch_size = nnet_output.shape[0]
supervision_segments = torch.tensor(
[[0, 0, nnet_output.shape[1]]], dtype=torch.int32
[[i, 0, nnet_output.shape[1]] for i in range(batch_size)],
dtype=torch.int32,
)
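# Each row is (sequence_index, start_frame, num_frames); every utterance in the
# padded batch is treated as spanning the full length of the network output.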

lattice = get_lattice(
@@ -145,7 +185,14 @@ def main():
hyps = get_texts(best_path)
word_sym_table = k2.SymbolTable.from_file(params.words_file)
hyps = [[word_sym_table[i] for i in ids] for ids in hyps]
logging.info(hyps)

s = "\n"
for filename, hyp in zip(params.sound_files, hyps):
words = " ".join(hyp)
s += f"{filename}:\n{words}\n\n"
logging.info(s)

logging.info(f"Decoding Done")


if __name__ == "__main__":