From c9222bdb09a9524a9f3a54588f9ba79924903762 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Sat, 31 Jul 2021 15:55:42 +0800 Subject: [PATCH 01/19] Fix an error in TDNN-LSTM training. --- egs/librispeech/ASR/tdnn_lstm_ctc/train.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/train.py b/egs/librispeech/ASR/tdnn_lstm_ctc/train.py index d94a2f7258..3330b07a5c 100755 --- a/egs/librispeech/ASR/tdnn_lstm_ctc/train.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/train.py @@ -501,8 +501,9 @@ def run(rank, world_size, args): ) scheduler = StepLR(optimizer, step_size=8, gamma=0.1) - optimizer.load_state_dict(checkpoints["optimizer"]) - scheduler.load_state_dict(checkpoints["scheduler"]) + if checkpoints: + optimizer.load_state_dict(checkpoints["optimizer"]) + scheduler.load_state_dict(checkpoints["scheduler"]) librispeech = LibriSpeechAsrDataModule(args) train_dl = librispeech.train_dataloaders() From 1fa30998da5ac06d4c742227cc949ed68c256df7 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Sat, 31 Jul 2021 20:24:47 +0800 Subject: [PATCH 02/19] WIP: Refactoring --- .gitignore | 1 + egs/librispeech/ASR/conformer_ctc/decode.py | 6 +- egs/librispeech/ASR/conformer_ctc/train.py | 2 +- .../ASR/conformer_ctc/transformer.py | 12 +- egs/librispeech/ASR/local/compile_hlg.py | 34 ++---- .../ASR/local/compute_fbank_librispeech.py | 18 ++- .../ASR/local/compute_fbank_musan.py | 15 ++- egs/librispeech/ASR/local/download_lm.py | 52 +++++++-- egs/librispeech/ASR/local/prepare_lang.py | 10 +- egs/librispeech/ASR/local/prepare_lang_bpe.py | 16 ++- egs/librispeech/ASR/local/train_bpe_model.py | 9 +- egs/librispeech/ASR/prepare.sh | 103 +++++++++++------- egs/librispeech/ASR/shared | 1 + egs/librispeech/ASR/tdnn_lstm_ctc/decode.py | 6 +- egs/librispeech/ASR/tdnn_lstm_ctc/train.py | 2 +- icefall/lexicon.py | 30 +++-- .../local => icefall/shared}/parse_options.sh | 0 17 files changed, 195 insertions(+), 122 deletions(-) create mode 120000 egs/librispeech/ASR/shared rename {egs/librispeech/ASR/local => icefall/shared}/parse_options.sh (100%) diff --git a/.gitignore b/.gitignore index 6cb9f22997..839a1c34a3 100644 --- a/.gitignore +++ b/.gitignore @@ -4,3 +4,4 @@ path.sh exp exp*/ *.pt +download/ diff --git a/egs/librispeech/ASR/conformer_ctc/decode.py b/egs/librispeech/ASR/conformer_ctc/decode.py index d1cbc14de9..3a8db1b818 100755 --- a/egs/librispeech/ASR/conformer_ctc/decode.py +++ b/egs/librispeech/ASR/conformer_ctc/decode.py @@ -62,7 +62,7 @@ def get_params() -> AttributeDict: params = AttributeDict( { "exp_dir": Path("conformer_ctc/exp"), - "lang_dir": Path("data/lang/bpe"), + "lang_dir": Path("data/lang_bpe"), "lm_dir": Path("data/lm"), "feature_dim": 80, "nhead": 8, @@ -367,15 +367,13 @@ def main(): logging.info(f"device: {device}") - HLG = k2.Fsa.from_dict(torch.load(f"{params.lm_dir}/HLG_bpe.pt")) + HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt")) HLG = HLG.to(device) assert HLG.requires_grad is False if not hasattr(HLG, "lm_scores"): HLG.lm_scores = HLG.scores.clone() - # HLG = k2.ctc_topo(4999).to(device) - if params.method in ( "nbest-rescoring", "whole-lattice-rescoring", diff --git a/egs/librispeech/ASR/conformer_ctc/train.py b/egs/librispeech/ASR/conformer_ctc/train.py index 40d3cf7fbb..d411a37834 100755 --- a/egs/librispeech/ASR/conformer_ctc/train.py +++ b/egs/librispeech/ASR/conformer_ctc/train.py @@ -125,7 +125,7 @@ def get_params() -> AttributeDict: params = AttributeDict( { "exp_dir": Path("conformer_ctc/exp"), - "lang_dir": Path("data/lang/bpe"), + "lang_dir": Path("data/lang_bpe"), "feature_dim": 80, "weight_decay": 0.0, "subsampling_factor": 4, diff --git a/egs/librispeech/ASR/conformer_ctc/transformer.py b/egs/librispeech/ASR/conformer_ctc/transformer.py index 1df16e3467..06027cf64a 100644 --- a/egs/librispeech/ASR/conformer_ctc/transformer.py +++ b/egs/librispeech/ASR/conformer_ctc/transformer.py @@ -188,7 +188,7 @@ def decoder_forward( encoder_mask: Tensor, supervision: Supervisions = None, graph_compiler: object = None, - token_ids: List[int] = None, + token_ids: List[List[int]] = None, sos_id: Optional[int] = None, eos_id: Optional[int] = None, ) -> Tensor: @@ -199,6 +199,7 @@ def decoder_forward( supervision: Supervison in lhotse format, get from batch['supervisions'] graph_compiler: use graph_compiler.L_inv (Its labels are words, while its aux_labels are phones) , graph_compiler.words and graph_compiler.oov + token_ids: A list of lists. Each list contains word piece IDs for an utterance. sos_id: sos token id eos_id: eos token id @@ -210,7 +211,10 @@ def decoder_forward( supervision, graph_compiler.lexicon.words, graph_compiler.oov ) ys_in_pad, ys_out_pad = add_sos_eos( - batch_text, graph_compiler.L_inv, sos_id, eos_id, + batch_text, + graph_compiler.L_inv, + sos_id, + eos_id, ) elif token_ids is not None: _sos = torch.tensor([sos_id]) @@ -225,7 +229,7 @@ def decoder_forward( ys_out_pad = pad_list(ys_out, -1) else: - raise ValueError("Invalid input for decoder self attetion") + raise ValueError("Invalid input for decoder self attention") ys_in_pad = ys_in_pad.to(x.device) ys_out_pad = ys_out_pad.to(x.device) @@ -284,7 +288,7 @@ def decoder_nll( ys_in_pad = pad_list(ys_in, eos_id) ys_out_pad = pad_list(ys_out, -1) else: - raise ValueError("Invalid input for decoder self attetion") + raise ValueError("Invalid input for decoder self attention") ys_in_pad = ys_in_pad.to(x.device, dtype=torch.int64) ys_out_pad = ys_out_pad.to(x.device, dtype=torch.int64) diff --git a/egs/librispeech/ASR/local/compile_hlg.py b/egs/librispeech/ASR/local/compile_hlg.py index 605d72daed..c02fb7c0db 100755 --- a/egs/librispeech/ASR/local/compile_hlg.py +++ b/egs/librispeech/ASR/local/compile_hlg.py @@ -26,7 +26,7 @@ def compile_HLG(lang_dir: str) -> k2.Fsa: """ Args: lang_dir: - The language directory, e.g., data/lang or data/lang/bpe. + The language directory, e.g., data/lang_phone or data/lang_bpe. Return: An FSA representing HLG. @@ -103,30 +103,18 @@ def compile_HLG(lang_dir: str) -> k2.Fsa: return HLG -def phone_based_HLG(): - if Path("data/lm/HLG.pt").is_file(): - return - - logging.info("Compiling phone based HLG") - HLG = compile_HLG("data/lang") - - logging.info("Saving HLG.pt to data/lm") - torch.save(HLG.as_dict(), "data/lm/HLG.pt") - - -def bpe_based_HLG(): - if Path("data/lm/HLG_bpe.pt").is_file(): - return - - logging.info("Compiling BPE based HLG") - HLG = compile_HLG("data/lang/bpe") - logging.info("Saving HLG_bpe.pt to data/lm") - torch.save(HLG.as_dict(), "data/lm/HLG_bpe.pt") +def main(): + for d in ["data/lang_phone", "data/lang_bpe"]: + d = Path(d) + logging.info(f"Processing {d}") + if (d / "HLG.pt").is_file(): + logging.info(f"{d}/HLG.pt already exists - skipping") + continue -def main(): - phone_based_HLG() - bpe_based_HLG() + HLG = compile_HLG(d) + logging.info(f"Saving HLG.pt to {d}") + torch.save(HLG.as_dict(), f"{d}/HLG.pt") if __name__ == "__main__": diff --git a/egs/librispeech/ASR/local/compute_fbank_librispeech.py b/egs/librispeech/ASR/local/compute_fbank_librispeech.py index 947d9f8d9d..0c07aaa1ab 100755 --- a/egs/librispeech/ASR/local/compute_fbank_librispeech.py +++ b/egs/librispeech/ASR/local/compute_fbank_librispeech.py @@ -1,11 +1,13 @@ #!/usr/bin/env python3 """ -This file computes fbank features of the librispeech dataset. -Its looks for manifests in the directory data/manifests -and generated fbank features are saved in data/fbank. +This file computes fbank features of the LibriSpeech dataset. +Its looks for manifests in the directory data/manifests. + +The generated fbank features are saved in data/fbank. """ +import logging import os from pathlib import Path @@ -40,9 +42,9 @@ def compute_fbank_librispeech(): with get_executor() as ex: # Initialize the executor only once. for partition, m in manifests.items(): if (output_dir / f"cuts_{partition}.json.gz").is_file(): - print(f"{partition} already exists - skipping.") + logging.info(f"{partition} already exists - skipping.") continue - print("Processing", partition) + logging.info(f"Processing {partition}") cut_set = CutSet.from_manifests( recordings=m["recordings"], supervisions=m["supervisions"], @@ -65,4 +67,10 @@ def compute_fbank_librispeech(): if __name__ == "__main__": + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) + + logging.basicConfig(format=formatter, level=logging.INFO) + compute_fbank_librispeech() diff --git a/egs/librispeech/ASR/local/compute_fbank_musan.py b/egs/librispeech/ASR/local/compute_fbank_musan.py index d63131da89..6a46e6978a 100755 --- a/egs/librispeech/ASR/local/compute_fbank_musan.py +++ b/egs/librispeech/ASR/local/compute_fbank_musan.py @@ -2,10 +2,12 @@ """ This file computes fbank features of the musan dataset. -Its looks for manifests in the directory data/manifests -and generated fbank features are saved in data/fbank. +Its looks for manifests in the directory data/manifests. + +The generated fbank features are saved in data/fbank. """ +import logging import os from pathlib import Path @@ -34,10 +36,10 @@ def compute_fbank_musan(): musan_cuts_path = output_dir / "cuts_musan.json.gz" if musan_cuts_path.is_file(): - print(f"{musan_cuts_path} already exists - skipping") + logging.info(f"{musan_cuts_path} already exists - skipping") return - print("Extracting features for Musan") + logging.info("Extracting features for Musan") extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins)) @@ -63,4 +65,9 @@ def compute_fbank_musan(): if __name__ == "__main__": + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) + + logging.basicConfig(format=formatter, level=logging.INFO) compute_fbank_musan() diff --git a/egs/librispeech/ASR/local/download_lm.py b/egs/librispeech/ASR/local/download_lm.py index 0bdc2935ba..5c9e2a6751 100755 --- a/egs/librispeech/ASR/local/download_lm.py +++ b/egs/librispeech/ASR/local/download_lm.py @@ -2,10 +2,25 @@ # Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang) """ -This file downloads librispeech LM files to data/lm +This file downloads the following LibriSpeech LM files: + + - 3-gram.pruned.1e-7.arpa.gz + - 4-gram.arpa.gz + - librispeech-vocab.txt + - librispeech-lexicon.txt + +from http://www.openslr.org/resources/11 +and save them in the user provided directory. + +Files are not re-downloaded if they already exist. + +Usage: + ./local/download_lm.py --out-dir ./download/lm """ +import argparse import gzip +import logging import os import shutil from pathlib import Path @@ -14,9 +29,17 @@ from tqdm.auto import tqdm -def download_lm(): +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--out-dir", type=str, help="Output directory.") + + args = parser.parse_args() + return args + + +def main(out_dir: str): url = "http://www.openslr.org/resources/11" - target_dir = Path("data/lm") + out_dir = Path(out_dir) files_to_download = ( "3-gram.pruned.1e-7.arpa.gz", @@ -26,7 +49,7 @@ def download_lm(): ) for f in tqdm(files_to_download, desc="Downloading LibriSpeech LM files"): - filename = target_dir / f + filename = out_dir / f if filename.is_file() is False: urlretrieve_progress( f"{url}/{f}", @@ -34,17 +57,26 @@ def download_lm(): desc=f"Downloading {filename}", ) else: - print(f"{filename} already exists - skipping") + logging.info(f"{filename} already exists - skipping") if ".gz" in str(filename): - unzip_file = Path(os.path.splitext(filename)[0]) - if unzip_file.is_file() is False: + unzipped = Path(os.path.splitext(filename)[0]) + if unzipped.is_file() is False: with gzip.open(filename, "rb") as f_in: - with open(unzip_file, "wb") as f_out: + with open(unzipped, "wb") as f_out: shutil.copyfileobj(f_in, f_out) else: - print(f"{unzip_file} already exist - skipping") + logging.info(f"{unzipped} already exist - skipping") if __name__ == "__main__": - download_lm() + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) + + logging.basicConfig(format=formatter, level=logging.INFO) + + args = get_args() + logging.info(f"out_dir: {args.out_dir}") + + main(out_dir=args.out_dir) diff --git a/egs/librispeech/ASR/local/prepare_lang.py b/egs/librispeech/ASR/local/prepare_lang.py index b9d13f5bb4..f7fde7796f 100755 --- a/egs/librispeech/ASR/local/prepare_lang.py +++ b/egs/librispeech/ASR/local/prepare_lang.py @@ -3,7 +3,7 @@ # Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang) """ -This script takes as input a lexicon file "data/lang/lexicon.txt" +This script takes as input a lexicon file "data/lang_phone/lexicon.txt" consisting of words and tokens (i.e., phones) and does the following: 1. Add disambiguation symbols to the lexicon and generate lexicon_disambig.txt @@ -20,8 +20,6 @@ 5. Generate L_disambig.pt, in k2 format. """ import math -import re -import sys from collections import defaultdict from pathlib import Path from typing import Any, Dict, List, Tuple @@ -284,7 +282,9 @@ def lexicon_to_fst( disambig_token = token2id["#0"] disambig_word = word2id["#0"] arcs = add_self_loops( - arcs, disambig_token=disambig_token, disambig_word=disambig_word, + arcs, + disambig_token=disambig_token, + disambig_word=disambig_word, ) final_state = next_state @@ -301,7 +301,7 @@ def lexicon_to_fst( def main(): - out_dir = Path("data/lang") + out_dir = Path("data/lang_phone") lexicon_filename = out_dir / "lexicon.txt" sil_token = "SIL" sil_prob = 0.5 diff --git a/egs/librispeech/ASR/local/prepare_lang_bpe.py b/egs/librispeech/ASR/local/prepare_lang_bpe.py index 0c3e9ede54..e31220d9b2 100755 --- a/egs/librispeech/ASR/local/prepare_lang_bpe.py +++ b/egs/librispeech/ASR/local/prepare_lang_bpe.py @@ -5,10 +5,10 @@ """ This script takes as inputs the following two files: - - data/lang/bpe/bpe.model, - - data/lang/bpe/words.txt + - data/lang_bpe/bpe.model, + - data/lang_bpe/words.txt -and generates the following files in the directory data/lang/bpe: +and generates the following files in the directory data/lang_bpe: - lexicon.txt - lexicon_disambig.txt @@ -88,7 +88,9 @@ def lexicon_to_fst_no_sil( disambig_token = token2id["#0"] disambig_word = word2id["#0"] arcs = add_self_loops( - arcs, disambig_token=disambig_token, disambig_word=disambig_word, + arcs, + disambig_token=disambig_token, + disambig_word=disambig_word, ) final_state = next_state @@ -140,7 +142,7 @@ def generate_lexicon( def main(): - lang_dir = Path("data/lang/bpe") + lang_dir = Path("data/lang_bpe") model_file = lang_dir / "bpe.model" word_sym_table = k2.SymbolTable.from_file(lang_dir / "words.txt") @@ -173,7 +175,9 @@ def main(): write_lexicon(lang_dir / "lexicon_disambig.txt", lexicon_disambig) L = lexicon_to_fst_no_sil( - lexicon, token2id=token_sym_table, word2id=word_sym_table, + lexicon, + token2id=token_sym_table, + word2id=word_sym_table, ) L_disambig = lexicon_to_fst_no_sil( diff --git a/egs/librispeech/ASR/local/train_bpe_model.py b/egs/librispeech/ASR/local/train_bpe_model.py index b5c6c7541a..59746ad9a6 100755 --- a/egs/librispeech/ASR/local/train_bpe_model.py +++ b/egs/librispeech/ASR/local/train_bpe_model.py @@ -14,18 +14,17 @@ # # Please install a version >=0.1.96 +import shutil from pathlib import Path import sentencepiece as spm -import shutil - def main(): model_type = "unigram" vocab_size = 5000 - model_prefix = f"data/lang/bpe/{model_type}_{vocab_size}" - train_text = "data/lang/bpe/train.txt" + model_prefix = f"data/lang_bpe/{model_type}_{vocab_size}" + train_text = "data/lang_bpe/train.txt" character_coverage = 1.0 input_sentence_size = 100000000 @@ -53,7 +52,7 @@ def main(): sp = spm.SentencePieceProcessor(model_file=str(model_file)) vocab_size = sp.vocab_size() - shutil.copyfile(model_file, "data/lang/bpe/bpe.model") + shutil.copyfile(model_file, "data/lang_bpe/bpe.model") if __name__ == "__main__": diff --git a/egs/librispeech/ASR/prepare.sh b/egs/librispeech/ASR/prepare.sh index 406527b713..ae676b199b 100755 --- a/egs/librispeech/ASR/prepare.sh +++ b/egs/librispeech/ASR/prepare.sh @@ -6,8 +6,38 @@ nj=15 stage=-1 stop_stage=100 -. local/parse_options.sh || exit 1 - +# We assume dl_dir (download dir) contains the following +# directories and files. If not, they will be downloaded +# by this script automatically. +# +# - $dl_dir/LibriSpeech +# You can find BOOKS.TXT, test-clean, train-clean-360, etc, inside it. +# You can download them from https://www.openslr.org/12 +# +# - $dl_dir/lm +# This directory contains the following files downloaded from +# http://www.openslr.org/resources/11 +# +# - 3-gram.pruned.1e-7.arpa.gz +# - 3-gram.pruned.1e-7.arpa +# - 4-gram.arpa.gz +# - 4-gram.arpa +# - librispeech-vocab.txt +# - librispeech-lexicon.txt +# +# - $do_dir/musan +# This directory contains the following directories downloaded from +# http://www.openslr.org/17/ +# +# - music +# - noise +# - speech +dl_dir=$PWD/download + +. shared/parse_options.sh || exit 1 + + +# All generated files by this script are saved in "data" mkdir -p data log() { @@ -16,10 +46,11 @@ log() { echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" } +log "dl_dir: $dl_dir" + if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then log "stage -1: Download LM" - mkdir -p data/lm - ./local/download_lm.py + ./local/download_lm.py --out-dir=$dl_dir/lm fi if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then @@ -28,38 +59,28 @@ if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then # If you have pre-downloaded it to /path/to/LibriSpeech, # you can create a symlink # - # ln -sfv /path/to/LibriSpeech data/ - # - # The script checks that if - # - # data/LibriSpeech/test-clean/.completed exists, + # ln -sfv /path/to/LibriSpeech $dl_dir/LibriSpeech # - # it will not re-download it. - # - # The same goes for dev-clean, dev-other, test-other, train-clean-100 - # train-clean-360, and train-other-500 - - mkdir -p data/LibriSpeech - lhotse download librispeech --full data + if [ ! -d $dl_dir/LibriSpeech/train-other-500 ]; then + lhotse download librispeech --full $dl_dir + fi # If you have pre-downloaded it to /path/to/musan, # you can create a symlink # - # ln -sfv /path/to/musan data/ + # ln -sfv /path/to/musan $dl_dir/ # - # and create a file data/.musan_completed - # to avoid downloading it again - if [ ! -f data/.musan_completed ]; then - lhotse download musan data + if [ ! -d $dl_dir/musan ]; then + lhotse download musan $dl_dir fi fi if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then - log "Stage 1: Prepare librispeech manifest" - # We assume that you have downloaded the librispeech corpus - # to data/LibriSpeech + log "Stage 1: Prepare LibriSpeech manifest" + # We assume that you have downloaded the LibriSpeech corpus + # to $dl_dir/LibriSpeech mkdir -p data/manifests - lhotse prepare librispeech -j $nj data/LibriSpeech data/manifests + lhotse prepare librispeech -j $nj $dl_dir/LibriSpeech data/manifests fi if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then @@ -67,7 +88,7 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then # We assume that you have downloaded the musan corpus # to data/musan mkdir -p data/manifests - lhotse prepare musan data/musan data/manifests + lhotse prepare musan $dl_dir/musan data/manifests fi if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then @@ -84,24 +105,25 @@ fi if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then log "Stage 5: Prepare phone based lang" - # TODO: add BPE based lang - mkdir -p data/lang + mkdir -p data/lang_phone (echo '!SIL SIL'; echo ' SPN'; echo ' SPN'; ) | - cat - data/lm/librispeech-lexicon.txt | - sort | uniq > data/lang/lexicon.txt + cat - $dl_dir/lm/librispeech-lexicon.txt | + sort | uniq > data/lang_phone/lexicon.txt - if [ ! -f data/lang/L_disambig.pt ]; then + if [ ! -f data/lang_phone/L_disambig.pt ]; then ./local/prepare_lang.py fi fi if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then log "State 6: Prepare BPE based lang" - mkdir -p data/lang/bpe - cp data/lang/words.txt data/lang/bpe/ + mkdir -p data/lang_bpe + # We reuse words.txt from phone based lexicon + # so that the two can share G.pt later. + cp data/lang_phone/words.txt data/lang_bpe/ - if [ ! -f data/lang/bpe/train.txt ]; then + if [ ! -f data/lang_bpe/train.txt ]; then log "Generate data for BPE training" files=$( find "data/LibriSpeech/train-clean-100" -name "*.trans.txt" @@ -110,12 +132,12 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then ) for f in ${files[@]}; do cat $f | cut -d " " -f 2- - done > data/lang/bpe/train.txt + done > data/lang_bpe/train.txt fi python3 ./local/train_bpe_model.py - if [ ! -f data/lang/bpe/L_disambig.pt ]; then + if [ ! -f data/lang_bpe/L_disambig.pt ]; then ./local/prepare_lang_bpe.py fi fi @@ -125,22 +147,23 @@ if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then # We assume you have install kaldilm, if not, please install # it using: pip install kaldilm + mkdir -p data/lm if [ ! -f data/lm/G_3_gram.fst.txt ]; then # It is used in building HLG python3 -m kaldilm \ - --read-symbol-table="data/lang/words.txt" \ + --read-symbol-table="data/lang_phone/words.txt" \ --disambig-symbol='#0' \ --max-order=3 \ - data/lm/3-gram.pruned.1e-7.arpa > data/lm/G_3_gram.fst.txt + $dl_dir/lm/3-gram.pruned.1e-7.arpa > data/lm/G_3_gram.fst.txt fi if [ ! -f data/lm/G_4_gram.fst.txt ]; then # It is used for LM rescoring python3 -m kaldilm \ - --read-symbol-table="data/lang/words.txt" \ + --read-symbol-table="data/lang_phone/words.txt" \ --disambig-symbol='#0' \ --max-order=4 \ - data/lm/4-gram.arpa > data/lm/G_4_gram.fst.txt + $dl_dir/lm/4-gram.arpa > data/lm/G_4_gram.fst.txt fi fi diff --git a/egs/librispeech/ASR/shared b/egs/librispeech/ASR/shared new file mode 120000 index 0000000000..4c5e91438c --- /dev/null +++ b/egs/librispeech/ASR/shared @@ -0,0 +1 @@ +../../../icefall/shared/ \ No newline at end of file diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py b/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py index 2c45b4e317..137fa795c0 100755 --- a/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py @@ -58,7 +58,7 @@ def get_params() -> AttributeDict: params = AttributeDict( { "exp_dir": Path("tdnn_lstm_ctc/exp/"), - "lang_dir": Path("data/lang"), + "lang_dir": Path("data/lang_phone"), "lm_dir": Path("data/lm"), "feature_dim": 80, "subsampling_factor": 3, @@ -328,7 +328,7 @@ def main(): logging.info(f"device: {device}") - HLG = k2.Fsa.from_dict(torch.load("data/lm/HLG.pt")) + HLG = k2.Fsa.from_dict(torch.load("data/lang_phone/HLG.pt")) HLG = HLG.to(device) assert HLG.requires_grad is False @@ -340,7 +340,7 @@ def main(): logging.info("Loading G_4_gram.fst.txt") logging.warning("It may take 8 minutes.") with open(params.lm_dir / "G_4_gram.fst.txt") as f: - first_word_disambig_id = lexicon.words["#0"] + first_word_disambig_id = lexicon.word_table["#0"] G = k2.Fsa.from_openfst(f.read(), acceptor=False) # G.aux_labels is not needed in later computations, so diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/train.py b/egs/librispeech/ASR/tdnn_lstm_ctc/train.py index 3330b07a5c..dbb9f64ecf 100755 --- a/egs/librispeech/ASR/tdnn_lstm_ctc/train.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/train.py @@ -127,7 +127,7 @@ def get_params() -> AttributeDict: params = AttributeDict( { "exp_dir": Path("tdnn_lstm_ctc/exp"), - "lang_dir": Path("data/lang"), + "lang_dir": Path("data/lang_phone"), "lr": 1e-3, "feature_dim": 80, "weight_decay": 5e-4, diff --git a/icefall/lexicon.py b/icefall/lexicon.py index 3b52c70c92..89747b11b0 100644 --- a/icefall/lexicon.py +++ b/icefall/lexicon.py @@ -1,7 +1,8 @@ import logging import re +import sys from pathlib import Path -from typing import List, Tuple, Union +from typing import List, Tuple import k2 import torch @@ -31,13 +32,19 @@ def read_lexicon(filename: str) -> List[Tuple[str, List[str]]]: continue if len(a) < 2: - print(f"Found bad line {line} in lexicon file {filename}") - print("Every line is expected to contain at least 2 fields") + logging.info( + f"Found bad line {line} in lexicon file {filename}" + ) + logging.info( + "Every line is expected to contain at least 2 fields" + ) sys.exit(1) word = a[0] if word == "": - print(f"Found bad line {line} in lexicon file {filename}") - print(" should not be a valid word") + logging.info( + f"Found bad line {line} in lexicon file {filename}" + ) + logging.info(" should not be a valid word") sys.exit(1) tokens = a[1:] @@ -61,13 +68,12 @@ def write_lexicon(filename: str, lexicon: List[Tuple[str, List[str]]]) -> None: class Lexicon(object): - """Phone based lexicon. - - TODO: Add BpeLexicon for BPE models. - """ + """Phone based lexicon.""" def __init__( - self, lang_dir: Path, disambig_pattern: str = re.compile(r"^#\d+$"), + self, + lang_dir: Path, + disambig_pattern: str = re.compile(r"^#\d+$"), ): """ Args: @@ -121,7 +127,9 @@ def tokens(self) -> List[int]: class BpeLexicon(Lexicon): def __init__( - self, lang_dir: Path, disambig_pattern: str = re.compile(r"^#\d+$"), + self, + lang_dir: Path, + disambig_pattern: str = re.compile(r"^#\d+$"), ): """ Refer to the help information in Lexicon.__init__. diff --git a/egs/librispeech/ASR/local/parse_options.sh b/icefall/shared/parse_options.sh similarity index 100% rename from egs/librispeech/ASR/local/parse_options.sh rename to icefall/shared/parse_options.sh From f6091b10c09ef7c32c94f1d426758e698f6db056 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Mon, 2 Aug 2021 23:48:26 +0800 Subject: [PATCH 03/19] Refactor transformer.py --- .../ASR/conformer_ctc/conformer.py | 24 +- egs/librispeech/ASR/conformer_ctc/decode.py | 12 +- .../ASR/conformer_ctc/subsampling.py | 144 ++++ .../ASR/conformer_ctc/test_subsampling.py | 33 + .../ASR/conformer_ctc/test_transformer.py | 36 + egs/librispeech/ASR/conformer_ctc/train.py | 23 +- .../ASR/conformer_ctc/transformer.py | 753 +++++++++--------- .../ASR/local/compute_fbank_librispeech.py | 10 +- .../ASR/local/compute_fbank_musan.py | 7 + icefall/decode.py | 39 +- 10 files changed, 685 insertions(+), 396 deletions(-) create mode 100644 egs/librispeech/ASR/conformer_ctc/subsampling.py create mode 100755 egs/librispeech/ASR/conformer_ctc/test_subsampling.py create mode 100644 egs/librispeech/ASR/conformer_ctc/test_transformer.py diff --git a/egs/librispeech/ASR/conformer_ctc/conformer.py b/egs/librispeech/ASR/conformer_ctc/conformer.py index 1e82eff2fa..d3952d3b1a 100644 --- a/egs/librispeech/ASR/conformer_ctc/conformer.py +++ b/egs/librispeech/ASR/conformer_ctc/conformer.py @@ -89,15 +89,21 @@ def encode( ) -> Tuple[Tensor, Optional[Tensor]]: """ Args: - x: Tensor of dimension (batch_size, num_features, input_length). - supervisions : Supervison in lhotse format, i.e., batch['supervisions'] + x: + The model input. Its shape is [N, T, C]. + supervisions: + Supervision in lhotse format. + See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa + CAUTION: It contains length information, i.e., start and number of + frames, before subsampling + It is read directly from the batch, without any sorting. It is used + to compute encoder padding mask, which is used as memory key padding + mask for the decoder. Returns: Tensor: Predictor tensor of dimension (input_length, batch_size, d_model). Tensor: Mask tensor of dimension (batch_size, input_length) """ - x = x.permute(0, 2, 1) # (B, F, T) -> (B, T, F) - x = self.encoder_embed(x) x, pos_emb = self.encoder_pos(x) x = x.permute(1, 0, 2) # (B, T, F) -> (T, B, F) @@ -796,8 +802,7 @@ def multi_head_attention_forward( bsz, num_heads, tgt_len, src_len ) attn_output_weights = attn_output_weights.masked_fill( - key_padding_mask.unsqueeze(1).unsqueeze(2), - float("-inf"), + key_padding_mask.unsqueeze(1).unsqueeze(2), float("-inf"), ) attn_output_weights = attn_output_weights.view( bsz * num_heads, tgt_len, src_len @@ -867,12 +872,7 @@ def __init__( ) self.norm = nn.BatchNorm1d(channels) self.pointwise_conv2 = nn.Conv1d( - channels, - channels, - kernel_size=1, - stride=1, - padding=0, - bias=bias, + channels, channels, kernel_size=1, stride=1, padding=0, bias=bias, ) self.activation = Swish() diff --git a/egs/librispeech/ASR/conformer_ctc/decode.py b/egs/librispeech/ASR/conformer_ctc/decode.py index 3a8db1b818..9ebb76fa1d 100755 --- a/egs/librispeech/ASR/conformer_ctc/decode.py +++ b/egs/librispeech/ASR/conformer_ctc/decode.py @@ -147,15 +147,10 @@ def decode_one_batch( feature = feature.to(device) # at entry, feature is [N, T, C] - feature = feature.permute(0, 2, 1) # now feature is [N, C, T] - supervisions = batch["supervisions"] nnet_output, memory, memory_key_padding_mask = model(feature, supervisions) - # nnet_output is [N, C, T] - - nnet_output = nnet_output.permute(0, 2, 1) - # now nnet_output is [N, T, C] + # nnet_output is [N, T, C] supervision_segments = torch.stack( ( @@ -227,6 +222,8 @@ def decode_one_batch( model=model, memory=memory, memory_key_padding_mask=memory_key_padding_mask, + sos_id=lexicon.sos_id, + eos_id=lexicon.eos_id, ) else: assert False, f"Unsupported decoding method: {params.method}" @@ -468,5 +465,8 @@ def main(): logging.info("Done!") +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + if __name__ == "__main__": main() diff --git a/egs/librispeech/ASR/conformer_ctc/subsampling.py b/egs/librispeech/ASR/conformer_ctc/subsampling.py new file mode 100644 index 0000000000..5c3e1222ef --- /dev/null +++ b/egs/librispeech/ASR/conformer_ctc/subsampling.py @@ -0,0 +1,144 @@ +import torch +import torch.nn as nn + + +class Conv2dSubsampling(nn.Module): + """Convolutional 2D subsampling (to 1/4 length). + + Convert an input of shape [N, T, idim] to an output + with shape [N, T', odim], where + T' = ((T-1)//2 - 1)//2, which approximates T' == T//4 + + It is based on + https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py # noqa + """ + + def __init__(self, idim: int, odim: int) -> None: + """ + Args: + idim: + Input dim. The input shape is [N, T, idim]. + Caution: It requires: T >=7, idim >=7 + odim: + Output dim. The output shape is [N, ((T-1)//2 - 1)//2, odim] + """ + assert idim >= 7 + super().__init__() + self.conv = nn.Sequential( + nn.Conv2d( + in_channels=1, out_channels=odim, kernel_size=3, stride=2 + ), + nn.ReLU(), + nn.Conv2d( + in_channels=odim, out_channels=odim, kernel_size=3, stride=2 + ), + nn.ReLU(), + ) + self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Subsample x. + + Args: + x: + Its shape is [N, T, idim]. + + Returns: + Return a tensor of shape [N, ((T-1)//2 - 1)//2, odim] + """ + # On entry, x is [N, T, idim] + x = x.unsqueeze(1) # [N, T, idim] -> [N, 1, T, idim] i.e., [N, C, H, W] + x = self.conv(x) + # Now x is of shape [N, odim, ((T-1)//2 - 1)//2, ((idim-1)//2 - 1)//2] + b, c, t, f = x.size() + x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) + # Now x is of shape [N, ((T-1)//2 - 1))//2, odim] + return x + + +class VggSubsampling(nn.Module): + """Trying to follow the setup described in the following paper: + https://arxiv.org/pdf/1910.09799.pdf + + This paper is not 100% explicit so I am guessing to some extent, + and trying to compare with other VGG implementations. + + Convert an input of shape [N, T, idim] to an output + with shape [N, T', odim], where + T' = ((T-1)//2 - 1)//2, which approximates T' = T//4 + """ + + def __init__(self, idim: int, odim: int) -> None: + """Construct a VggSubsampling object. + + This uses 2 VGG blocks with 2 Conv2d layers each, + subsampling its input by a factor of 4 in the time dimensions. + + Args: + idim: + Input dim. The input shape is [N, T, idim]. + Caution: It requires: T >=7, idim >=7 + odim: + Output dim. The output shape is [N, ((T-1)//2 - 1)//2, odim] + """ + super().__init__() + + cur_channels = 1 + layers = [] + block_dims = [32, 64] + + # The decision to use padding=1 for the 1st convolution, then padding=0 + # for the 2nd and for the max-pooling, and ceil_mode=True, was driven by + # a back-compatibility concern so that the number of frames at the + # output would be equal to: + # (((T-1)//2)-1)//2. + # We can consider changing this by using padding=1 on the + # 2nd convolution, so the num-frames at the output would be T//4. + for block_dim in block_dims: + layers.append( + torch.nn.Conv2d( + in_channels=cur_channels, + out_channels=block_dim, + kernel_size=3, + padding=1, + stride=1, + ) + ) + layers.append(torch.nn.ReLU()) + layers.append( + torch.nn.Conv2d( + in_channels=block_dim, + out_channels=block_dim, + kernel_size=3, + padding=0, + stride=1, + ) + ) + layers.append( + torch.nn.MaxPool2d( + kernel_size=2, stride=2, padding=0, ceil_mode=True + ) + ) + cur_channels = block_dim + + self.layers = nn.Sequential(*layers) + + self.out = nn.Linear( + block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Subsample x. + + Args: + x: + Its shape is [N, T, idim]. + + Returns: + Return a tensor of shape [N, ((T-1)//2 - 1)//2, odim] + """ + x = x.unsqueeze(1) + x = self.layers(x) + b, c, t, f = x.size() + x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) + return x diff --git a/egs/librispeech/ASR/conformer_ctc/test_subsampling.py b/egs/librispeech/ASR/conformer_ctc/test_subsampling.py new file mode 100755 index 0000000000..937845d779 --- /dev/null +++ b/egs/librispeech/ASR/conformer_ctc/test_subsampling.py @@ -0,0 +1,33 @@ +#!/usr/bin/env python3 + +from subsampling import Conv2dSubsampling +from subsampling import VggSubsampling +import torch + + +def test_conv2d_subsampling(): + N = 3 + odim = 2 + + for T in range(7, 19): + for idim in range(7, 20): + model = Conv2dSubsampling(idim=idim, odim=odim) + x = torch.empty(N, T, idim) + y = model(x) + assert y.shape[0] == N + assert y.shape[1] == ((T - 1) // 2 - 1) // 2 + assert y.shape[2] == odim + + +def test_vgg_subsampling(): + N = 3 + odim = 2 + + for T in range(7, 19): + for idim in range(7, 20): + model = VggSubsampling(idim=idim, odim=odim) + x = torch.empty(N, T, idim) + y = model(x) + assert y.shape[0] == N + assert y.shape[1] == ((T - 1) // 2 - 1) // 2 + assert y.shape[2] == odim diff --git a/egs/librispeech/ASR/conformer_ctc/test_transformer.py b/egs/librispeech/ASR/conformer_ctc/test_transformer.py new file mode 100644 index 0000000000..a6569e8d76 --- /dev/null +++ b/egs/librispeech/ASR/conformer_ctc/test_transformer.py @@ -0,0 +1,36 @@ +#!/usr/bin/env python3 + +import torch +from transformer import Transformer, encoder_padding_mask + + +def test_encoder_padding_mask(): + supervisions = { + "sequence_idx": torch.tensor([0, 1, 2]), + "start_frame": torch.tensor([0, 0, 0]), + "num_frames": torch.tensor([18, 7, 13]), + } + + max_len = ((18 - 1) // 2 - 1) // 2 + mask = encoder_padding_mask(max_len, supervisions) + expected_mask = torch.tensor( + [ + [False, False, False], # ((18 - 1)//2 - 1)//2 = 3, + [False, True, True], # ((7 - 1)//2 - 1)//2 = 1, + [False, False, True], # ((13 - 1)//2 - 1)//2 = 2, + ] + ) + assert torch.all(torch.eq(mask, expected_mask)) + + +def test_transformer(): + num_features = 40 + num_classes = 87 + model = Transformer(num_features=num_features, num_classes=num_classes) + + N = 31 + + for T in range(7, 30): + x = torch.rand(N, T, num_features) + y, _, _ = model(x) + assert y.shape == (N, (((T - 1) // 2) - 1) // 2, num_classes) diff --git a/egs/librispeech/ASR/conformer_ctc/train.py b/egs/librispeech/ASR/conformer_ctc/train.py index d411a37834..552db81ecc 100755 --- a/egs/librispeech/ASR/conformer_ctc/train.py +++ b/egs/librispeech/ASR/conformer_ctc/train.py @@ -275,15 +275,13 @@ def compute_loss( device = graph_compiler.device feature = batch["inputs"] # at entry, feature is [N, T, C] - feature = feature.permute(0, 2, 1) # now feature is [N, C, T] assert feature.ndim == 3 feature = feature.to(device) supervisions = batch["supervisions"] with torch.set_grad_enabled(is_training): nnet_output, encoder_memory, memory_mask = model(feature, supervisions) - # nnet_output is [N, C, T] - nnet_output = nnet_output.permute(0, 2, 1) # [N, C, T] -> [N, T, C] + # nnet_output is [N, T, C] # NOTE: We need `encode_supervisions` to sort sequences with # different duration in decreasing order, required by @@ -536,6 +534,22 @@ def train_one_epoch( f" best valid loss: {params.best_valid_loss:.4f} " f"best valid epoch: {params.best_valid_epoch}" ) + if tb_writer is not None: + tb_writer.add_scalar( + "train/valid_ctc_loss", + params.valid_ctc_loss, + params.batch_idx_train, + ) + tb_writer.add_scalar( + "train/valid_att_loss", + params.valid_att_loss, + params.batch_idx_train, + ) + tb_writer.add_scalar( + "train/valid_loss", + params.valid_loss, + params.batch_idx_train, + ) params.train_loss = tot_loss / tot_frames @@ -675,5 +689,8 @@ def main(): run(rank=0, world_size=1, args=args) +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + if __name__ == "__main__": main() diff --git a/egs/librispeech/ASR/conformer_ctc/transformer.py b/egs/librispeech/ASR/conformer_ctc/transformer.py index 06027cf64a..b2123b8fcf 100644 --- a/egs/librispeech/ASR/conformer_ctc/transformer.py +++ b/egs/librispeech/ASR/conformer_ctc/transformer.py @@ -1,6 +1,4 @@ -#!/usr/bin/env python3 - -# Copyright (c) 2021 University of Chinese Academy of Sciences (author: Han Zhu) +# Copyright (c) 2021 University of Chinese Academy of Sciences (author: Han Zhu) # Apache 2.0 import math @@ -8,30 +6,16 @@ import k2 import torch -from torch import Tensor, nn +import torch.nn as nn +from subsampling import Conv2dSubsampling, VggSubsampling from icefall.utils import get_texts # Note: TorchScript requires Dict/List/etc. to be fully typed. -Supervisions = Dict[str, Tensor] +Supervisions = Dict[str, torch.Tensor] class Transformer(nn.Module): - """ - Args: - num_features (int): Number of input features - num_classes (int): Number of output classes - subsampling_factor (int): subsampling factor of encoder (the convolution layers before transformers) - d_model (int): attention dimension - nhead (int): number of head - dim_feedforward (int): feedforward dimention - num_encoder_layers (int): number of encoder layers - num_decoder_layers (int): number of decoder layers - dropout (float): dropout rate - normalize_before (bool): whether to use layer_norm before the first block. - vgg_frontend (bool): whether to use vgg frontend. - """ - def __init__( self, num_features: int, @@ -48,6 +32,36 @@ def __init__( mmi_loss: bool = True, use_feat_batchnorm: bool = False, ) -> None: + """ + Args: + num_features: + The input dimension of the model. + num_classes: + The output dimension of the model. + subsampling_factor: + Number of output frames is num_in_frames // subsampling_factor. + Currently, subsampling_factor MUST be 4. + d_model: + Attention dimension. + nhead: + Number of heads in multi-head attention. + Must satisfy d_model // nhead == 0. + dim_feedforward: + The output dimension of the feedforward layers in encoder/decoder. + num_encoder_layers: + Number of encoder layers. + num_decoder_layers: + Number of decoder layers. + dropout: + Dropout in encoder/decoder. + normalize_before: + If True, use pre-layer norm; False to use post-layer norm. + vgg_frontend: + True to use vgg style frontend for subsampling. + mmi_loss: + use_feat_batchnorm: + True to use batchnorm for the input layer. + """ super().__init__() self.use_feat_batchnorm = use_feat_batchnorm if use_feat_batchnorm: @@ -59,18 +73,23 @@ def __init__( if subsampling_factor != 4: raise NotImplementedError("Support only 'subsampling_factor=4'.") - self.encoder_embed = ( - VggSubsampling(num_features, d_model) - if vgg_frontend - else Conv2dSubsampling(num_features, d_model) - ) + # self.encoder_embed converts the input of shape [N, T, num_classes] + # to the shape [N, T//subsampling_factor, d_model]. + # That is, it does two things simultaneously: + # (1) subsampling: T -> T//subsampling_factor + # (2) embedding: num_classes -> d_model + if vgg_frontend: + self.encoder_embed = VggSubsampling(num_features, d_model) + else: + self.encoder_embed = Conv2dSubsampling(num_features, d_model) + self.encoder_pos = PositionalEncoding(d_model, dropout) encoder_layer = TransformerEncoderLayer( - d_model, - nhead, - dim_feedforward, - dropout, + d_model=d_model, + nhead=nhead, + dim_feedforward=dim_feedforward, + dropout=dropout, normalize_before=normalize_before, ) @@ -80,9 +99,12 @@ def __init__( encoder_norm = None self.encoder = nn.TransformerEncoder( - encoder_layer, num_encoder_layers, encoder_norm + encoder_layer=encoder_layer, + num_layers=num_encoder_layers, + norm=encoder_norm, ) + # TODO(fangjun): remove dropout self.encoder_output_layer = nn.Sequential( nn.Dropout(p=dropout), nn.Linear(d_model, num_classes) ) @@ -97,14 +119,16 @@ def __init__( self.num_classes ) # bpe model already has sos/eos symbol - self.decoder_embed = nn.Embedding(self.decoder_num_class, d_model) + self.decoder_embed = nn.Embedding( + num_embeddings=self.decoder_num_class, embedding_dim=d_model + ) self.decoder_pos = PositionalEncoding(d_model, dropout) decoder_layer = TransformerDecoderLayer( - d_model, - nhead, - dim_feedforward, - dropout, + d_model=d_model, + nhead=nhead, + dim_feedforward=dim_feedforward, + dropout=dropout, normalize_before=normalize_before, ) @@ -114,7 +138,9 @@ def __init__( decoder_norm = None self.decoder = nn.TransformerDecoder( - decoder_layer, num_decoder_layers, decoder_norm + decoder_layer=decoder_layer, + num_layers=num_decoder_layers, + norm=decoder_norm, ) self.decoder_output_layer = torch.nn.Linear( @@ -126,93 +152,145 @@ def __init__( self.decoder_criterion = None def forward( - self, x: Tensor, supervision: Optional[Supervisions] = None - ) -> Tuple[Tensor, Tensor, Optional[Tensor]]: + self, x: torch.Tensor, supervision: Optional[Supervisions] = None + ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: """ Args: - x: Tensor of dimension (batch_size, num_features, input_length). - supervision: Supervison in lhotse format, get from batch['supervisions'] + x: + The input tensor. Its shape is [N, T, C]. + supervision: + Supervision in lhotse format. + See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa + (CAUTION: It contains length information, i.e., start and number of + frames, before subsampling) Returns: - Tensor: After log-softmax tensor of dimension (batch_size, number_of_classes, input_length). - Tensor: Before linear layer tensor of dimension (input_length, batch_size, d_model). - Optional[Tensor]: Mask tensor of dimension (batch_size, input_length) or None. - + Return a tuple containing 3 tensors: + - CTC output for ctc decoding. Its shape is [N, T, C] + - Encoder output with shape [T, N, C]. It can be used as key and + value for the decoder. + - Encoder output padding mask. It can be used as + memory_key_padding_mask for the decoder. Its shape is [N, T]. + It is None if `supervision` is None. """ if self.use_feat_batchnorm: + x = x.permute(0, 2, 1) # [N, T, C] -> [N, C, T] x = self.feat_batchnorm(x) - encoder_memory, memory_mask = self.encode(x, supervision) + x = x.permute(0, 2, 1) # [N, C, T] -> [N, T, C] + encoder_memory, memory_key_padding_mask = self.encode(x, supervision) x = self.encoder_output(encoder_memory) - return x, encoder_memory, memory_mask + return x, encoder_memory, memory_key_padding_mask def encode( - self, x: Tensor, supervisions: Optional[Supervisions] = None - ) -> Tuple[Tensor, Optional[Tensor]]: + self, x: torch.Tensor, supervisions: Optional[Supervisions] = None + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: """ Args: - x: Tensor of dimension (batch_size, num_features, input_length). - supervisions : Supervison in lhotse format, i.e., batch['supervisions'] - + x: + The model input. Its shape is [N, T, C]. + supervisions: + Supervision in lhotse format. + See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa + CAUTION: It contains length information, i.e., start and number of + frames, before subsampling + It is read directly from the batch, without any sorting. It is used + to compute encoder padding mask, which is used as memory key padding + mask for the decoder. Returns: - Tensor: Predictor tensor of dimension (input_length, batch_size, d_model). - Optional[Tensor]: Mask tensor of dimension (batch_size, input_length) or None. + Return a tuple with two tensors: + - The encoder output, with shape [T, N, C] + - encoder padding mask, with shape [N, T]. + The mask is None if `supervisions` is None. + It is used as memory key padding mask in the decoder. """ - x = x.permute(0, 2, 1) # (B, F, T) -> (B, T, F) - x = self.encoder_embed(x) x = self.encoder_pos(x) - x = x.permute(1, 0, 2) # (B, T, F) -> (T, B, F) + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) mask = encoder_padding_mask(x.size(0), supervisions) - mask = mask.to(x.device) if mask != None else None - x = self.encoder(x, src_key_padding_mask=mask) # (T, B, F) + mask = mask.to(x.device) if mask is not None else None + x = self.encoder(x, src_key_padding_mask=mask) # (T, N, C) return x, mask - def encoder_output(self, x: Tensor) -> Tensor: + def encoder_output(self, x: torch.Tensor) -> torch.Tensor: """ Args: - x: Tensor of dimension (input_length, batch_size, d_model). + x: + The output tensor from the transformer encoder. + Its shape is [T, N, C] Returns: - Tensor: After log-softmax tensor of dimension (batch_size, number_of_classes, input_length). + Return a tensor that can be used for CTC decoding. + Its shape is [N, T, C] """ - x = self.encoder_output_layer(x).permute( - 1, 2, 0 - ) # (T, B, F) ->(B, F, T) - x = nn.functional.log_softmax(x, dim=1) # (B, F, T) + x = self.encoder_output_layer(x) + x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + x = nn.functional.log_softmax(x, dim=-1) # (N, T, C) return x def decoder_forward( self, - x: Tensor, - encoder_mask: Tensor, - supervision: Supervisions = None, - graph_compiler: object = None, + memory: torch.Tensor, + memory_key_padding_mask: torch.Tensor, + supervision: Optional[Supervisions] = None, + L_inv: Optional[k2.Fsa] = None, + word_table: Optional[k2.SymbolTable] = None, + oov_str: Optional[str] = None, token_ids: List[List[int]] = None, sos_id: Optional[int] = None, eos_id: Optional[int] = None, - ) -> Tensor: + ) -> torch.Tensor: """ + Note: + If phone based lexicon is used, the following arguments are required: + + - supervision + - L_inv + - word_table + - oov_str + + If BPE based lexicon is used, the following arguments are required: + + - token_ids + - sos_id + - eos_id + Args: - x: Tensor of dimension (input_length, batch_size, d_model). - encoder_mask: Mask tensor of dimension (batch_size, input_length) - supervision: Supervison in lhotse format, get from batch['supervisions'] - graph_compiler: use graph_compiler.L_inv (Its labels are words, while its aux_labels are phones) - , graph_compiler.words and graph_compiler.oov - token_ids: A list of lists. Each list contains word piece IDs for an utterance. - sos_id: sos token id - eos_id: eos token id + memory: + It's the output of the encoder with shape [T, N, C] + memory_key_padding_mask: + The padding mask from the encoder. + supervision: + Supervision in lhotse format. + See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa + (CAUTION: It contains length information, i.e., start and number of + frames, before subsampling) + L_inv: + It is an FSA with labels being word IDs and aux_labels being + token IDs (e.g., phone IDs or word piece IDs). + word_table: + Word table providing mapping between words and IDs. + oov_str: + The OOV word, e.g., '' + token_ids: + A list-of-list IDs. Each sublist contains IDs for an utterance. + The IDs can be either phone IDs or word piece IDs. + sos_id: + sos token id + eos_id: + eos token id Returns: - Tensor: Decoder loss. + A scalar, the **sum** of label smoothing loss over utterances + in the batch without any normalization. """ - if supervision is not None and graph_compiler is not None: + if supervision is not None and word_table is not None: batch_text = get_normal_transcripts( - supervision, graph_compiler.lexicon.words, graph_compiler.oov + supervision, word_table, oov_str ) ys_in_pad, ys_out_pad = add_sos_eos( batch_text, - graph_compiler.L_inv, + L_inv, sos_id, eos_id, ) @@ -227,31 +305,31 @@ def decoder_forward( ] ys_in_pad = pad_list(ys_in, eos_id) ys_out_pad = pad_list(ys_out, -1) - else: raise ValueError("Invalid input for decoder self attention") - ys_in_pad = ys_in_pad.to(x.device) - ys_out_pad = ys_out_pad.to(x.device) + device = memory.device + ys_in_pad = ys_in_pad.to(device) + ys_out_pad = ys_out_pad.to(device) tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to( - x.device + device ) tgt_key_padding_mask = decoder_padding_mask(ys_in_pad) - tgt = self.decoder_embed(ys_in_pad) # (B, T) -> (B, T, F) + tgt = self.decoder_embed(ys_in_pad) # (N, T) -> (N, T, C) tgt = self.decoder_pos(tgt) - tgt = tgt.permute(1, 0, 2) # (B, T, F) -> (T, B, F) + tgt = tgt.permute(1, 0, 2) # (N, T, C) -> (T, N, C) pred_pad = self.decoder( tgt=tgt, - memory=x, + memory=memory, tgt_mask=tgt_mask, tgt_key_padding_mask=tgt_key_padding_mask, - memory_key_padding_mask=encoder_mask, - ) # (T, B, F) - pred_pad = pred_pad.permute(1, 0, 2) # (T, B, F) -> (B, T, F) - pred_pad = self.decoder_output_layer(pred_pad) # (B, T, F) + memory_key_padding_mask=memory_key_padding_mask, + ) # (T, N, C) + pred_pad = pred_pad.permute(1, 0, 2) # (T, N, C) -> (N, T, C) + pred_pad = self.decoder_output_layer(pred_pad) # (N, T, C) decoder_loss = self.decoder_criterion(pred_pad, ys_out_pad) @@ -259,23 +337,31 @@ def decoder_forward( def decoder_nll( self, - x: Tensor, - encoder_mask: Tensor, + memory: torch.Tensor, + memory_key_padding_mask: torch.Tensor, token_ids: List[List[int]], sos_id: int, eos_id: int, - ) -> Tensor: + ) -> torch.Tensor: """ Args: - x: encoder-output, Tensor of dimension (input_length, batch_size, d_model). - encoder_mask: Mask tensor of dimension (batch_size, input_length) - token_ids: n-best list extracted from lattice before rescore - + memory: + It's the output of the encoder with shape [T, N, C] + memory_key_padding_mask: + The padding mask from the encoder. + token_ids: + A list-of-list IDs (e.g., word piece IDs). + Each sublist represents an utterance. + sos_id: + The token ID for SOS. + eos_id: + The token ID for EOS. Returns: - Tensor: negative log-likelihood. + A 2-D tensor of shape (len(token_ids), max_token_length) + representing the cross entropy loss (i.e., negative log-likelihood). """ - # The common part between this fuction and decoder_forward could be - # extracted as a seperated function. + # The common part between this function and decoder_forward could be + # extracted as a separate function. if token_ids is not None: _sos = torch.tensor([sos_id]) _eos = torch.tensor([eos_id]) @@ -290,11 +376,12 @@ def decoder_nll( else: raise ValueError("Invalid input for decoder self attention") - ys_in_pad = ys_in_pad.to(x.device, dtype=torch.int64) - ys_out_pad = ys_out_pad.to(x.device, dtype=torch.int64) + device = memory.device + ys_in_pad = ys_in_pad.to(device, dtype=torch.int64) + ys_out_pad = ys_out_pad.to(device, dtype=torch.int64) tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to( - x.device + device ) tgt_key_padding_mask = decoder_padding_mask(ys_in_pad) @@ -304,10 +391,10 @@ def decoder_nll( tgt = tgt.permute(1, 0, 2) # (B, T, F) -> (T, B, F) pred_pad = self.decoder( tgt=tgt, - memory=x, + memory=memory, tgt_mask=tgt_mask, tgt_key_padding_mask=tgt_key_padding_mask, - memory_key_padding_mask=encoder_mask, + memory_key_padding_mask=memory_key_padding_mask, ) # (T, B, F) pred_pad = pred_pad.permute(1, 0, 2) # (T, B, F) -> (B, T, F) pred_pad = self.decoder_output_layer(pred_pad) # (B, T, F) @@ -326,16 +413,24 @@ def decoder_nll( class TransformerEncoderLayer(nn.Module): """ - Modified from torch.nn.TransformerEncoderLayer. Add support of normalize_before, + Modified from torch.nn.TransformerEncoderLayer. + Add support of normalize_before, i.e., use layer_norm before the first block. Args: - d_model: the number of expected features in the input (required). - nhead: the number of heads in the multiheadattention models (required). - dim_feedforward: the dimension of the feedforward network model (default=2048). - dropout: the dropout value (default=0.1). - activation: the activation function of intermediate layer, relu or gelu (default=relu). - normalize_before: whether to use layer_norm before the first block. + d_model: + the number of expected features in the input (required). + nhead: + the number of heads in the multiheadattention models (required). + dim_feedforward: + the dimension of the feedforward network model (default=2048). + dropout: + the dropout value (default=0.1). + activation: + the activation function of intermediate layer, relu or + gelu (default=relu). + normalize_before: + whether to use layer_norm before the first block. Examples:: >>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8) @@ -375,23 +470,24 @@ def __setstate__(self, state): def forward( self, - src: Tensor, - src_mask: Optional[Tensor] = None, - src_key_padding_mask: Optional[Tensor] = None, - ) -> Tensor: + src: torch.Tensor, + src_mask: Optional[torch.Tensor] = None, + src_key_padding_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: """ Pass the input through the encoder layer. Args: src: the sequence to the encoder layer (required). src_mask: the mask for the src sequence (optional). - src_key_padding_mask: the mask for the src keys per batch (optional). + src_key_padding_mask: the mask for the src keys per batch (optional) Shape: src: (S, N, E). src_mask: (S, S). src_key_padding_mask: (N, S). - S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number + S is the source sequence length, T is the target sequence length, + N is the batch size, E is the feature number """ residual = src if self.normalize_before: @@ -419,15 +515,22 @@ def forward( class TransformerDecoderLayer(nn.Module): """ - Modified from torch.nn.TransformerDecoderLayer. Add support of normalize_before, + Modified from torch.nn.TransformerDecoderLayer. + Add support of normalize_before, i.e., use layer_norm before the first block. Args: - d_model: the number of expected features in the input (required). - nhead: the number of heads in the multiheadattention models (required). - dim_feedforward: the dimension of the feedforward network model (default=2048). - dropout: the dropout value (default=0.1). - activation: the activation function of intermediate layer, relu or gelu (default=relu). + d_model: + the number of expected features in the input (required). + nhead: + the number of heads in the multiheadattention models (required). + dim_feedforward: + the dimension of the feedforward network model (default=2048). + dropout: + the dropout value (default=0.1). + activation: + the activation function of intermediate layer, relu or + gelu (default=relu). Examples:: >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8) @@ -471,22 +574,28 @@ def __setstate__(self, state): def forward( self, - tgt: Tensor, - memory: Tensor, - tgt_mask: Optional[Tensor] = None, - memory_mask: Optional[Tensor] = None, - tgt_key_padding_mask: Optional[Tensor] = None, - memory_key_padding_mask: Optional[Tensor] = None, - ) -> Tensor: + tgt: torch.Tensor, + memory: torch.Tensor, + tgt_mask: Optional[torch.Tensor] = None, + memory_mask: Optional[torch.Tensor] = None, + tgt_key_padding_mask: Optional[torch.Tensor] = None, + memory_key_padding_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: """Pass the inputs (and mask) through the decoder layer. Args: - tgt: the sequence to the decoder layer (required). - memory: the sequence from the last layer of the encoder (required). - tgt_mask: the mask for the tgt sequence (optional). - memory_mask: the mask for the memory sequence (optional). - tgt_key_padding_mask: the mask for the tgt keys per batch (optional). - memory_key_padding_mask: the mask for the memory keys per batch (optional). + tgt: + the sequence to the decoder layer (required). + memory: + the sequence from the last layer of the encoder (required). + tgt_mask: + the mask for the tgt sequence (optional). + memory_mask: + the mask for the memory sequence (optional). + tgt_key_padding_mask: + the mask for the tgt keys per batch (optional). + memory_key_padding_mask: + the mask for the memory keys per batch (optional). Shape: tgt: (T, N, E). @@ -495,7 +604,8 @@ def forward( memory_mask: (T, S). tgt_key_padding_mask: (N, T). memory_key_padding_mask: (N, S). - S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number + S is the source sequence length, T is the target sequence length, + N is the batch size, E is the feature number """ residual = tgt if self.normalize_before: @@ -546,164 +656,55 @@ def _get_activation_fn(activation: str): ) -class Conv2dSubsampling(nn.Module): - """Convolutional 2D subsampling (to 1/4 length). - Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py - - Args: - idim: Input dimension. - odim: Output dimension. - - """ - - def __init__(self, idim: int, odim: int) -> None: - """Construct a Conv2dSubsampling object.""" - super(Conv2dSubsampling, self).__init__() - self.conv = nn.Sequential( - nn.Conv2d( - in_channels=1, out_channels=odim, kernel_size=3, stride=2 - ), - nn.ReLU(), - nn.Conv2d( - in_channels=odim, out_channels=odim, kernel_size=3, stride=2 - ), - nn.ReLU(), - ) - self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim) - - def forward(self, x: Tensor) -> Tensor: - """Subsample x. - - Args: - x: Input tensor of dimension (batch_size, input_length, num_features). (#batch, time, idim). - - Returns: - torch.Tensor: Subsampled tensor of dimension (batch_size, input_length, d_model). - where time' = time // 4. - - """ - x = x.unsqueeze(1) # (b, c, t, f) - x = self.conv(x) - b, c, t, f = x.size() - x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) - return x +class PositionalEncoding(nn.Module): + """This class implements the positional encoding + proposed in the following paper: + - Attention Is All You Need: https://arxiv.org/pdf/1706.03762.pdf -class VggSubsampling(nn.Module): - """Trying to follow the setup described here https://arxiv.org/pdf/1910.09799.pdf - This paper is not 100% explicit so I am guessing to some extent, - and trying to compare with other VGG implementations. + PE(pos, 2i) = sin(pos / (10000^(2i/d_modle)) + PE(pos, 2i+1) = cos(pos / (10000^(2i/d_modle)) - Args: - idim: Input dimension. - odim: Output dimension. + Note:: + 1 / (10000^(2i/d_model)) = exp(-log(10000^(2i/d_model))) + = exp(-1* 2i / d_model * log(100000)) + = exp(2i * -(log(10000) / d_model)) """ - def __init__(self, idim: int, odim: int) -> None: - """Construct a VggSubsampling object. This uses 2 VGG blocks with 2 - Conv2d layers each, subsampling its input by a factor of 4 in the - time dimensions. - - Args: - idim: Number of features at input, e.g. 40 or 80 for MFCC - (will be treated as the image height). - odim: Output dimension (number of features), e.g. 256 + def __init__(self, d_model: int, dropout: float = 0.1) -> None: """ - super(VggSubsampling, self).__init__() - - cur_channels = 1 - layers = [] - block_dims = [32, 64] - - # The decision to use padding=1 for the 1st convolution, then padding=0 - # for the 2nd and for the max-pooling, and ceil_mode=True, was driven by - # a back-compatibility concern so that the number of frames at the - # output would be equal to: - # (((T-1)//2)-1)//2. - # We can consider changing this by using padding=1 on the 2nd convolution, - # so the num-frames at the output would be T//4. - for block_dim in block_dims: - layers.append( - torch.nn.Conv2d( - in_channels=cur_channels, - out_channels=block_dim, - kernel_size=3, - padding=1, - stride=1, - ) - ) - layers.append(torch.nn.ReLU()) - layers.append( - torch.nn.Conv2d( - in_channels=block_dim, - out_channels=block_dim, - kernel_size=3, - padding=0, - stride=1, - ) - ) - layers.append( - torch.nn.MaxPool2d( - kernel_size=2, stride=2, padding=0, ceil_mode=True - ) - ) - cur_channels = block_dim - - self.layers = nn.Sequential(*layers) - - self.out = nn.Linear( - block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim - ) - - def forward(self, x: Tensor) -> Tensor: - """Subsample x. - Args: - x: Input tensor of dimension (batch_size, input_length, num_features). (#batch, time, idim). - - Returns: - torch.Tensor: Subsampled tensor of dimension (batch_size, input_length', d_model). - where input_length' == (((input_length - 1) // 2) - 1) // 2 - + d_model: + Embedding dimension. + dropout: + Dropout probability to be applied to the output of this module. """ - x = x.unsqueeze(1) # (b, c, t, f) - x = self.layers(x) - b, c, t, f = x.size() - x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) - return x - - -class PositionalEncoding(nn.Module): - """ - Positional encoding. - - Args: - d_model: Embedding dimension. - dropout: Dropout rate. - max_len: Maximum input length. - - """ - - def __init__( - self, d_model: int, dropout: float = 0.1, max_len: int = 5000 - ) -> None: - """Construct an PositionalEncoding object.""" - super(PositionalEncoding, self).__init__() + super().__init__() self.d_model = d_model self.xscale = math.sqrt(self.d_model) self.dropout = nn.Dropout(p=dropout) self.pe = None - self.extend_pe(torch.tensor(0.0).expand(1, max_len)) - def extend_pe(self, x: Tensor) -> None: - """Reset the positional encodings.""" + def extend_pe(self, x: torch.Tensor) -> None: + """Extend the time t in the positional encoding if required. + + The shape of `self.pe` is [1, T1, d_model]. The shape of the input x + is [N, T, d_model]. If T > T1, then we change the shape of self.pe + to [N, T, d_model]. Otherwise, nothing is done. + + Args: + x: + It is a tensor of shape [N, T, C]. + Returns: + Return None. + """ if self.pe is not None: if self.pe.size(1) >= x.size(1): if self.pe.dtype != x.dtype or self.pe.device != x.device: self.pe = self.pe.to(dtype=x.dtype, device=x.device) return - pe = torch.zeros(x.size(1), self.d_model) + pe = torch.zeros(x.size(1), self.d_model, dtype=torch.float32) position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1) div_term = torch.exp( torch.arange(0, self.d_model, 2, dtype=torch.float32) @@ -712,34 +713,44 @@ def extend_pe(self, x: Tensor) -> None: pe[:, 0::2] = torch.sin(position * div_term) pe[:, 1::2] = torch.cos(position * div_term) pe = pe.unsqueeze(0) + # Now pe is of shape [1, T, d_model], where T is x.size(1) self.pe = pe.to(device=x.device, dtype=x.dtype) - def forward(self, x: Tensor) -> Tensor: + def forward(self, x: torch.Tensor) -> torch.Tensor: """ Add positional encoding. Args: - x: Input tensor of dimention (batch_size, input_length, d_model). + x: + Its shape is [N, T, C] Returns: - torch.Tensor: Encoded tensor of dimention (batch_size, input_length, d_model). - + Return a tensor of shape [N, T, C] """ self.extend_pe(x) - x = x * self.xscale + self.pe[:, : x.size(1)] + x = x * self.xscale + self.pe[:, : x.size(1), :] return self.dropout(x) class Noam(object): """ - Implements Noam optimizer. Proposed in "Attention Is All You Need", https://arxiv.org/pdf/1706.03762.pdf - Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/optimizer.py + Implements Noam optimizer. + + Proposed in + "Attention Is All You Need", https://arxiv.org/pdf/1706.03762.pdf + + Modified from + https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/optimizer.py # noqa Args: - params (iterable): iterable of parameters to optimize or dicts defining parameter groups - model_size: attention dimension of the transformer model - factor: learning rate factor - warm_step: warmup steps + params: + iterable of parameters to optimize or dicts defining parameter groups + model_size: + attention dimension of the transformer model + factor: + learning rate factor + warm_step: + warmup steps """ def __init__( @@ -812,7 +823,8 @@ class LabelSmoothingLoss(nn.Module): """ Label-smoothing loss. KL-divergence between q_{smoothed ground truth prob.}(w) and p_{prob. computed by model}(w) is minimized. - Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/label_smoothing_loss.py + Modified from + https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/label_smoothing_loss.py # noqa Args: size: the number of class @@ -841,19 +853,23 @@ def __init__( self.true_dist = None self.normalize_length = normalize_length - def forward(self, x: Tensor, target: Tensor) -> Tensor: + def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """ Compute loss between x and target. Args: - x: prediction of dimention (batch_size, input_length, number_of_classes). - target: target masked with self.padding_id of dimention (batch_size, input_length). + x: + prediction of dimension + (batch_size, input_length, number_of_classes). + target: + target masked with self.padding_id of + dimension (batch_size, input_length). Returns: - torch.Tensor: scalar float value + A scalar tensor containing the loss without normalization. """ assert x.size(2) == self.size - batch_size = x.size(0) + # batch_size = x.size(0) x = x.view(-1, self.size) target = target.view(-1) with torch.no_grad(): @@ -871,12 +887,23 @@ def forward(self, x: Tensor, target: Tensor) -> Tensor: def encoder_padding_mask( max_len: int, supervisions: Optional[Supervisions] = None -) -> Optional[Tensor]: - """Make mask tensor containing indices of padded part. +) -> Optional[torch.Tensor]: + """Make mask tensor containing indexes of padded part. + + TODO:: + This function **assumes** that the model uses + a subsampling factor of 4. We should remove that + assumption later. Args: - max_len: maximum length of input features - supervisions : Supervison in lhotse format, i.e., batch['supervisions'] + max_len: + Maximum length of input features. + CAUTION: It is the length after subsampling. + supervisions: + Supervision in lhotse format. + See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa + (CAUTION: It contains length information, i.e., start and number of + frames, before subsampling) Returns: Tensor: Mask tensor of dimension (batch_size, input_length), True denote the masked indices. @@ -916,16 +943,23 @@ def encoder_padding_mask( return mask -def decoder_padding_mask(ys_pad: Tensor, ignore_id: int = -1) -> Tensor: - """Generate a length mask for input. The masked position are filled with bool(True), - Unmasked positions are filled with bool(False). +def decoder_padding_mask( + ys_pad: torch.Tensor, ignore_id: int = -1 +) -> torch.Tensor: + """Generate a length mask for input. + + The masked position are filled with bool(True), + Unmasked positions are filled with bool(False). Args: - ys_pad: padded tensor of dimension (batch_size, input_length). - ignore_id: the ignored number (the padding number) in ys_pad + ys_pad: + padded tensor of dimension (batch_size, input_length). + ignore_id: + the ignored number (the padding number) in ys_pad Returns: - Tensor: a mask tensor of dimension (batch_size, input_length). + Tensor: + a bool tensor of the same shape as the input tensor. """ ys_mask = ys_pad == ignore_id return ys_mask @@ -934,13 +968,20 @@ def decoder_padding_mask(ys_pad: Tensor, ignore_id: int = -1) -> Tensor: def get_normal_transcripts( supervision: Supervisions, words: k2.SymbolTable, oov: str = "" ) -> List[List[int]]: - """Get normal transcripts (1 input recording has 1 transcript) from lhotse cut format. - Achieved by concatenate the transcripts corresponding to the same recording. + """Get normal transcripts (1 input recording has 1 transcript) + from lhotse cut format. + + Achieved by concatenating the transcripts corresponding to the + same recording. Args: - supervision : Supervison in lhotse format, i.e., batch['supervisions'] - words: The word symbol table. - oov: Out of vocabulary word. + supervision: + Supervision in lhotse format. + See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa + words: + The word symbol table. + oov: + Out of vocabulary word. Returns: List[List[int]]: List of concatenated transcripts, length is batch_size @@ -960,15 +1001,15 @@ def get_normal_transcripts( return batch_text -def generate_square_subsequent_mask(sz: int) -> Tensor: - """Generate a square mask for the sequence. The masked positions are filled with float('-inf'). - Unmasked positions are filled with float(0.0). +def generate_square_subsequent_mask(sz: int) -> torch.Tensor: + """Generate a square mask for the sequence. The masked positions are + filled with float('-inf'). Unmasked positions are filled with float(0.0). Args: - sz: mask size + sz: mask size Returns: - Tensor: a square mask of dimension (sz, sz) + A square mask of dimension (sz, sz) """ mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1) mask = ( @@ -981,39 +1022,49 @@ def generate_square_subsequent_mask(sz: int) -> Tensor: def add_sos_eos( ys: List[List[int]], - lexicon: k2.Fsa, + L_inv: k2.Fsa, sos_id: int, eos_id: int, ignore_id: int = -1, -) -> Tuple[Tensor, Tensor]: +) -> Tuple[torch.Tensor, torch.Tensor]: """Add and labels. Args: - ys: batch of unpadded target sequences - lexicon: Its labels are words, while its aux_labels are phones. - sos_id: index of - eos_id: index of - ignore_id: index of padding + ys: + Batch of unpadded target sequences (i.e., word IDs) + L_inv: + Its labels are words, while its aux_labels are tokens. + sos_id: + index of + eos_id: + index of + ignore_id: + value for padding Returns: - Tensor: Input of transformer decoder. Padded tensor of dimention (batch_size, max_length). - Tensor: Output of transformer decoder. padded tensor of dimention (batch_size, max_length). + Return a tuple containing two tensors: + - Input of transformer decoder. + Padded tensor of dimension (batch_size, max_length). + - Output of transformer decoder. + Padded tensor of dimension (batch_size, max_length). """ _sos = torch.tensor([sos_id]) _eos = torch.tensor([eos_id]) - ys = get_hierarchical_targets(ys, lexicon) + ys = get_hierarchical_targets(ys, L_inv) ys_in = [torch.cat([_sos, y], dim=0) for y in ys] ys_out = [torch.cat([y, _eos], dim=0) for y in ys] - return pad_list(ys_in, eos), pad_list(ys_out, ignore_id) + return pad_list(ys_in, eos_id), pad_list(ys_out, ignore_id) -def pad_list(ys: List[Tensor], pad_value: float) -> Tensor: +def pad_list(ys: List[torch.Tensor], pad_value: float) -> torch.Tensor: """Perform padding for the list of tensors. Args: - ys: List of tensors. len(ys) = batch_size. - pad_value: Value for padding. + ys: + List of tensors. len(ys) = batch_size. + pad_value: + Value for padding. Returns: Tensor: Padded tensor (batch_size, max_length, `*`). @@ -1039,25 +1090,25 @@ def pad_list(ys: List[Tensor], pad_value: float) -> Tensor: def get_hierarchical_targets( - ys: List[List[int]], lexicon: k2.Fsa -) -> List[Tensor]: - """Get hierarchical transcripts (i.e., phone level transcripts) from transcripts (i.e., word level transcripts). + ys: List[List[int]], L_inv: Optional[k2.Fsa] = None +) -> List[torch.Tensor]: + """Get hierarchical transcripts (i.e., phone level transcripts) from + transcripts (i.e., word level transcripts). Args: - ys: Word level transcripts. - lexicon: Its labels are words, while its aux_labels are phones. + ys: + Word level transcripts. Each sublist is a transcript of an utterance. + L_inv: + Its labels are words, while its aux_labels are tokens. Returns: - List[Tensor]: Phone level transcripts. - + List[torch.Tensor]: + Token level transcripts. """ - if lexicon is None: - return ys - else: - L_inv = lexicon + if L_inv is None: + return [torch.tensor(y) for y in ys] - n_batch = len(ys) device = L_inv.device transcripts = k2.create_fsa_vec( @@ -1081,19 +1132,3 @@ def get_hierarchical_targets( ys = [torch.tensor(y) for y in ys] return ys - - -def test_transformer(): - t = Transformer(40, 1281) - T = 200 - f = torch.rand(31, 40, T) - g, _, _ = t(f) - assert g.shape == (31, 1281, (((T - 1) // 2) - 1) // 2) - - -def main(): - test_transformer() - - -if __name__ == "__main__": - main() diff --git a/egs/librispeech/ASR/local/compute_fbank_librispeech.py b/egs/librispeech/ASR/local/compute_fbank_librispeech.py index 0c07aaa1ab..d81096070f 100755 --- a/egs/librispeech/ASR/local/compute_fbank_librispeech.py +++ b/egs/librispeech/ASR/local/compute_fbank_librispeech.py @@ -11,11 +11,18 @@ import os from pathlib import Path +import torch from lhotse import CutSet, Fbank, FbankConfig, LilcomHdf5Writer from lhotse.recipes.utils import read_manifests_if_cached from icefall.utils import get_executor +# Torch's multithreaded behavior needs to be disabled or it wastes a lot of CPU and +# slow things down. Do this outside of main() in case it needs to take effect +# even when we are not invoking the main (e.g. when spawning subprocesses). +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + def compute_fbank_librispeech(): src_dir = Path("data/manifests") @@ -46,8 +53,7 @@ def compute_fbank_librispeech(): continue logging.info(f"Processing {partition}") cut_set = CutSet.from_manifests( - recordings=m["recordings"], - supervisions=m["supervisions"], + recordings=m["recordings"], supervisions=m["supervisions"], ) if "train" in partition: cut_set = ( diff --git a/egs/librispeech/ASR/local/compute_fbank_musan.py b/egs/librispeech/ASR/local/compute_fbank_musan.py index 6a46e6978a..0fc515d8c2 100755 --- a/egs/librispeech/ASR/local/compute_fbank_musan.py +++ b/egs/librispeech/ASR/local/compute_fbank_musan.py @@ -11,11 +11,18 @@ import os from pathlib import Path +import torch from lhotse import CutSet, Fbank, FbankConfig, LilcomHdf5Writer, combine from lhotse.recipes.utils import read_manifests_if_cached from icefall.utils import get_executor +# Torch's multithreaded behavior needs to be disabled or it wastes a lot of CPU and +# slow things down. Do this outside of main() in case it needs to take effect +# even when we are not invoking the main (e.g. when spawning subprocesses). +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + def compute_fbank_musan(): src_dir = Path("data/manifests") diff --git a/icefall/decode.py b/icefall/decode.py index ed08405fa0..0e9baf2e46 100644 --- a/icefall/decode.py +++ b/icefall/decode.py @@ -555,24 +555,31 @@ def rescore_with_attention_decoder( model: nn.Module, memory: torch.Tensor, memory_key_padding_mask: torch.Tensor, + sos_id: int, + eos_id: int, ) -> Dict[str, k2.Fsa]: """This function extracts n paths from the given lattice and uses an attention decoder to rescore them. The path with the highest score is used as the decoding output. - lattice: - An FsaVec. It can be the return value of :func:`get_lattice`. - num_paths: - Number of paths to extract from the given lattice for rescoring. - model: - A transformer model. See the class "Transformer" in - conformer_ctc/transformer.py for its interface. - memory: - The encoder memory of the given model. It is the output of - the last torch.nn.TransformerEncoder layer in the given model. - Its shape is `[T, N, C]`. - memory_key_padding_mask: - The padding mask for memory with shape [N, T]. + Args: + lattice: + An FsaVec. It can be the return value of :func:`get_lattice`. + num_paths: + Number of paths to extract from the given lattice for rescoring. + model: + A transformer model. See the class "Transformer" in + conformer_ctc/transformer.py for its interface. + memory: + The encoder memory of the given model. It is the output of + the last torch.nn.TransformerEncoder layer in the given model. + Its shape is `[T, N, C]`. + memory_key_padding_mask: + The padding mask for memory with shape [N, T]. + sos_id: + The token ID for SOS. + eos_id: + The token ID for EOS. Returns: A dict of FsaVec, whose key contains a string ngram_lm_scale_attention_scale and the value is the @@ -661,7 +668,11 @@ def rescore_with_attention_decoder( # TODO: pass the sos_token_id and eos_token_id via function arguments nll = model.decoder_nll( - expanded_memory, expanded_memory_key_padding_mask, token_ids, 1, 1 + memory=expanded_memory, + memory_key_padding_mask=expanded_memory_key_padding_mask, + token_ids=token_ids, + sos_id=sos_id, + eos_id=eos_id, ) assert nll.ndim == 2 assert nll.shape[0] == num_word_seqs From 2be7a0a55590ea40d7def5c513202bd6e5a18ab7 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Tue, 3 Aug 2021 17:24:06 +0800 Subject: [PATCH 04/19] Remove unused code. --- .../ASR/conformer_ctc/conformer.py | 12 +- egs/librispeech/ASR/conformer_ctc/decode.py | 32 +- .../ASR/conformer_ctc/test_transformer.py | 55 +++- .../ASR/conformer_ctc/transformer.py | 284 +++++------------- 4 files changed, 161 insertions(+), 222 deletions(-) diff --git a/egs/librispeech/ASR/conformer_ctc/conformer.py b/egs/librispeech/ASR/conformer_ctc/conformer.py index d3952d3b1a..a00664a992 100644 --- a/egs/librispeech/ASR/conformer_ctc/conformer.py +++ b/egs/librispeech/ASR/conformer_ctc/conformer.py @@ -84,7 +84,7 @@ def __init__( # and throws an error without this change. self.after_norm = identity - def encode( + def run_encoder( self, x: Tensor, supervisions: Optional[Supervisions] = None ) -> Tuple[Tensor, Optional[Tensor]]: """ @@ -802,7 +802,8 @@ def multi_head_attention_forward( bsz, num_heads, tgt_len, src_len ) attn_output_weights = attn_output_weights.masked_fill( - key_padding_mask.unsqueeze(1).unsqueeze(2), float("-inf"), + key_padding_mask.unsqueeze(1).unsqueeze(2), + float("-inf"), ) attn_output_weights = attn_output_weights.view( bsz * num_heads, tgt_len, src_len @@ -872,7 +873,12 @@ def __init__( ) self.norm = nn.BatchNorm1d(channels) self.pointwise_conv2 = nn.Conv1d( - channels, channels, kernel_size=1, stride=1, padding=0, bias=bias, + channels, + channels, + kernel_size=1, + stride=1, + padding=0, + bias=bias, ) self.activation = Swish() diff --git a/egs/librispeech/ASR/conformer_ctc/decode.py b/egs/librispeech/ASR/conformer_ctc/decode.py index 9ebb76fa1d..0611814f69 100755 --- a/egs/librispeech/ASR/conformer_ctc/decode.py +++ b/egs/librispeech/ASR/conformer_ctc/decode.py @@ -15,6 +15,7 @@ import torch.nn as nn from conformer import Conformer +from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler from icefall.checkpoint import average_checkpoints, load_checkpoint from icefall.dataset.librispeech import LibriSpeechAsrDataModule from icefall.decode import ( @@ -85,7 +86,7 @@ def get_params() -> AttributeDict: # - whole-lattice-rescoring # - attention-decoder # "method": "whole-lattice-rescoring", - "method": "1best", + "method": "attention-decoder", # num_paths is used when method is "nbest", "nbest-rescoring", # and attention-decoder "num_paths": 100, @@ -100,6 +101,8 @@ def decode_one_batch( HLG: k2.Fsa, batch: dict, lexicon: Lexicon, + sos_id: int, + eos_id: int, G: Optional[k2.Fsa] = None, ) -> Dict[str, List[List[int]]]: """Decode one batch and return the result in a dict. The dict has the @@ -133,6 +136,10 @@ def decode_one_batch( for the format of the `batch`. lexicon: It contains word symbol table. + sos_id: + The token ID of the SOS. + eos_id: + The token ID of the EOS. G: An LM. It is not None when params.method is "nbest-rescoring" or "whole-lattice-rescoring". In general, the G in HLG @@ -222,8 +229,8 @@ def decode_one_batch( model=model, memory=memory, memory_key_padding_mask=memory_key_padding_mask, - sos_id=lexicon.sos_id, - eos_id=lexicon.eos_id, + sos_id=sos_id, + eos_id=eos_id, ) else: assert False, f"Unsupported decoding method: {params.method}" @@ -242,6 +249,8 @@ def decode_dataset( model: nn.Module, HLG: k2.Fsa, lexicon: Lexicon, + sos_id: int, + eos_id: int, G: Optional[k2.Fsa] = None, ) -> Dict[str, List[Tuple[List[int], List[int]]]]: """Decode dataset. @@ -257,6 +266,10 @@ def decode_dataset( The decoding graph. lexicon: It contains word symbol table. + sos_id: + The token ID for SOS. + eos_id: + The token ID for EOS. G: An LM. It is not None when params.method is "nbest-rescoring" or "whole-lattice-rescoring". In general, the G in HLG @@ -284,6 +297,8 @@ def decode_dataset( batch=batch, lexicon=lexicon, G=G, + sos_id=sos_id, + eos_id=eos_id, ) for lm_scale, hyps in hyps_dict.items(): @@ -364,6 +379,15 @@ def main(): logging.info(f"device: {device}") + graph_compiler = BpeCtcTrainingGraphCompiler( + params.lang_dir, + device=device, + sos_token="", + eos_token="", + ) + sos_id = graph_compiler.sos_id + eos_id = graph_compiler.eos_id + HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt")) HLG = HLG.to(device) assert HLG.requires_grad is False @@ -456,6 +480,8 @@ def main(): HLG=HLG, lexicon=lexicon, G=G, + sos_id=sos_id, + eos_id=eos_id, ) save_results( diff --git a/egs/librispeech/ASR/conformer_ctc/test_transformer.py b/egs/librispeech/ASR/conformer_ctc/test_transformer.py index a6569e8d76..08e6806074 100644 --- a/egs/librispeech/ASR/conformer_ctc/test_transformer.py +++ b/egs/librispeech/ASR/conformer_ctc/test_transformer.py @@ -1,7 +1,16 @@ #!/usr/bin/env python3 import torch -from transformer import Transformer, encoder_padding_mask +from transformer import ( + Transformer, + encoder_padding_mask, + generate_square_subsequent_mask, + decoder_padding_mask, + add_sos, + add_eos, +) + +from torch.nn.utils.rnn import pad_sequence def test_encoder_padding_mask(): @@ -34,3 +43,47 @@ def test_transformer(): x = torch.rand(N, T, num_features) y, _, _ = model(x) assert y.shape == (N, (((T - 1) // 2) - 1) // 2, num_classes) + + +def test_generate_square_subsequent_mask(): + s = 5 + mask = generate_square_subsequent_mask(s) + inf = float("inf") + expected_mask = torch.tensor( + [ + [0.0, -inf, -inf, -inf, -inf], + [0.0, 0.0, -inf, -inf, -inf], + [0.0, 0.0, 0.0, -inf, -inf], + [0.0, 0.0, 0.0, 0.0, -inf], + [0.0, 0.0, 0.0, 0.0, 0.0], + ] + ) + assert torch.all(torch.eq(mask, expected_mask)) + + +def test_decoder_padding_mask(): + x = [torch.tensor([1, 2]), torch.tensor([3]), torch.tensor([2, 5, 8])] + y = pad_sequence(x, batch_first=True, padding_value=-1) + mask = decoder_padding_mask(y, ignore_id=-1) + expected_mask = torch.tensor( + [ + [False, False, True], + [False, True, True], + [False, False, False], + ] + ) + assert torch.all(torch.eq(mask, expected_mask)) + + +def test_add_sos(): + x = [[1, 2], [3], [2, 5, 8]] + y = add_sos(x, sos_id=0) + expected_y = [[0, 1, 2], [0, 3], [0, 2, 5, 8]] + assert y == expected_y + + +def test_add_eos(): + x = [[1, 2], [3], [2, 5, 8]] + y = add_eos(x, eos_id=0) + expected_y = [[1, 2, 0], [3, 0], [2, 5, 8, 0]] + assert y == expected_y diff --git a/egs/librispeech/ASR/conformer_ctc/transformer.py b/egs/librispeech/ASR/conformer_ctc/transformer.py index b2123b8fcf..a974be4e02 100644 --- a/egs/librispeech/ASR/conformer_ctc/transformer.py +++ b/egs/librispeech/ASR/conformer_ctc/transformer.py @@ -10,6 +10,7 @@ from subsampling import Conv2dSubsampling, VggSubsampling from icefall.utils import get_texts +from torch.nn.utils.rnn import pad_sequence # Note: TorchScript requires Dict/List/etc. to be fully typed. Supervisions = Dict[str, torch.Tensor] @@ -177,14 +178,17 @@ def forward( x = x.permute(0, 2, 1) # [N, T, C] -> [N, C, T] x = self.feat_batchnorm(x) x = x.permute(0, 2, 1) # [N, C, T] -> [N, T, C] - encoder_memory, memory_key_padding_mask = self.encode(x, supervision) - x = self.encoder_output(encoder_memory) + encoder_memory, memory_key_padding_mask = self.run_encoder( + x, supervision + ) + x = self.ctc_output(encoder_memory) return x, encoder_memory, memory_key_padding_mask - def encode( + def run_encoder( self, x: torch.Tensor, supervisions: Optional[Supervisions] = None ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: - """ + """Run the transformer encoder. + Args: x: The model input. Its shape is [N, T, C]. @@ -194,8 +198,8 @@ def encode( CAUTION: It contains length information, i.e., start and number of frames, before subsampling It is read directly from the batch, without any sorting. It is used - to compute encoder padding mask, which is used as memory key padding - mask for the decoder. + to compute the encoder padding mask, which is used as memory key + padding mask for the decoder. Returns: Return a tuple with two tensors: - The encoder output, with shape [T, N, C] @@ -212,7 +216,7 @@ def encode( return x, mask - def encoder_output(self, x: torch.Tensor) -> torch.Tensor: + def ctc_output(self, x: torch.Tensor) -> torch.Tensor: """ Args: x: @@ -232,46 +236,16 @@ def decoder_forward( self, memory: torch.Tensor, memory_key_padding_mask: torch.Tensor, - supervision: Optional[Supervisions] = None, - L_inv: Optional[k2.Fsa] = None, - word_table: Optional[k2.SymbolTable] = None, - oov_str: Optional[str] = None, - token_ids: List[List[int]] = None, - sos_id: Optional[int] = None, - eos_id: Optional[int] = None, + token_ids: List[List[int]], + sos_id: int, + eos_id: int, ) -> torch.Tensor: """ - Note: - If phone based lexicon is used, the following arguments are required: - - - supervision - - L_inv - - word_table - - oov_str - - If BPE based lexicon is used, the following arguments are required: - - - token_ids - - sos_id - - eos_id - Args: memory: It's the output of the encoder with shape [T, N, C] memory_key_padding_mask: The padding mask from the encoder. - supervision: - Supervision in lhotse format. - See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa - (CAUTION: It contains length information, i.e., start and number of - frames, before subsampling) - L_inv: - It is an FSA with labels being word IDs and aux_labels being - token IDs (e.g., phone IDs or word piece IDs). - word_table: - Word table providing mapping between words and IDs. - oov_str: - The OOV word, e.g., '' token_ids: A list-of-list IDs. Each sublist contains IDs for an utterance. The IDs can be either phone IDs or word piece IDs. @@ -284,29 +258,13 @@ def decoder_forward( A scalar, the **sum** of label smoothing loss over utterances in the batch without any normalization. """ - if supervision is not None and word_table is not None: - batch_text = get_normal_transcripts( - supervision, word_table, oov_str - ) - ys_in_pad, ys_out_pad = add_sos_eos( - batch_text, - L_inv, - sos_id, - eos_id, - ) - elif token_ids is not None: - _sos = torch.tensor([sos_id]) - _eos = torch.tensor([eos_id]) - ys_in = [ - torch.cat([_sos, torch.tensor(y)], dim=0) for y in token_ids - ] - ys_out = [ - torch.cat([torch.tensor(y), _eos], dim=0) for y in token_ids - ] - ys_in_pad = pad_list(ys_in, eos_id) - ys_out_pad = pad_list(ys_out, -1) - else: - raise ValueError("Invalid input for decoder self attention") + ys_in = add_sos(token_ids, sos_id=sos_id) + ys_in = [torch.tensor(y) for y in ys_in] + ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=eos_id) + + ys_out = add_eos(token_ids, eos_id=eos_id) + ys_out = [torch.tensor(y) for y in ys_out] + ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=-1) device = memory.device ys_in_pad = ys_in_pad.to(device) @@ -316,6 +274,8 @@ def decoder_forward( device ) + # TODO: Use eos_id as ignore_id. + # tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) tgt_key_padding_mask = decoder_padding_mask(ys_in_pad) tgt = self.decoder_embed(ys_in_pad) # (N, T) -> (N, T, C) @@ -362,19 +322,14 @@ def decoder_nll( """ # The common part between this function and decoder_forward could be # extracted as a separate function. - if token_ids is not None: - _sos = torch.tensor([sos_id]) - _eos = torch.tensor([eos_id]) - ys_in = [ - torch.cat([_sos, torch.tensor(y)], dim=0) for y in token_ids - ] - ys_out = [ - torch.cat([torch.tensor(y), _eos], dim=0) for y in token_ids - ] - ys_in_pad = pad_list(ys_in, eos_id) - ys_out_pad = pad_list(ys_out, -1) - else: - raise ValueError("Invalid input for decoder self attention") + + ys_in = add_sos(token_ids, sos_id=sos_id) + ys_in = [torch.tensor(y) for y in ys_in] + ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=eos_id) + + ys_out = add_eos(token_ids, eos_id=eos_id) + ys_out = [torch.tensor(y) for y in ys_out] + ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=-1) device = memory.device ys_in_pad = ys_in_pad.to(device, dtype=torch.int64) @@ -384,6 +339,8 @@ def decoder_nll( device ) + # TODO: Use eos_id as ignore_id. + # tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) tgt_key_padding_mask = decoder_padding_mask(ys_in_pad) tgt = self.decoder_embed(ys_in_pad) # (B, T) -> (B, T, F) @@ -948,8 +905,8 @@ def decoder_padding_mask( ) -> torch.Tensor: """Generate a length mask for input. - The masked position are filled with bool(True), - Unmasked positions are filled with bool(False). + The masked position are filled with True, + Unmasked positions are filled with False. Args: ys_pad: @@ -965,45 +922,16 @@ def decoder_padding_mask( return ys_mask -def get_normal_transcripts( - supervision: Supervisions, words: k2.SymbolTable, oov: str = "" -) -> List[List[int]]: - """Get normal transcripts (1 input recording has 1 transcript) - from lhotse cut format. - - Achieved by concatenating the transcripts corresponding to the - same recording. - - Args: - supervision: - Supervision in lhotse format. - See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa - words: - The word symbol table. - oov: - Out of vocabulary word. - - Returns: - List[List[int]]: List of concatenated transcripts, length is batch_size - """ - - texts = [ - [token if token in words else oov for token in text.split(" ")] - for text in supervision["text"] - ] - texts_ids = [[words[token] for token in text] for text in texts] - - batch_text = [ - [] for _ in range(int(supervision["sequence_idx"].max().item()) + 1) - ] - for sequence_idx, text in zip(supervision["sequence_idx"], texts_ids): - batch_text[sequence_idx] = batch_text[sequence_idx] + text - return batch_text - - def generate_square_subsequent_mask(sz: int) -> torch.Tensor: """Generate a square mask for the sequence. The masked positions are filled with float('-inf'). Unmasked positions are filled with float(0.0). + The mask can be used for masked self-attention. + + For instance, if sz is 3, it returns:: + + tensor([[0., -inf, -inf], + [0., 0., -inf], + [0., 0., 0]]) Args: sz: mask size @@ -1020,115 +948,41 @@ def generate_square_subsequent_mask(sz: int) -> torch.Tensor: return mask -def add_sos_eos( - ys: List[List[int]], - L_inv: k2.Fsa, - sos_id: int, - eos_id: int, - ignore_id: int = -1, -) -> Tuple[torch.Tensor, torch.Tensor]: - """Add and labels. +def add_sos(token_ids: List[List[int]], sos_id: int) -> List[List[int]]: + """Prepend sos_id to each utterance. Args: - ys: - Batch of unpadded target sequences (i.e., word IDs) - L_inv: - Its labels are words, while its aux_labels are tokens. + token_ids: + A list-of-list of token IDs. Each sublist contains + token IDs (e.g., word piece IDs) of an utterance. sos_id: - index of - eos_id: - index of - ignore_id: - value for padding - - Returns: - Return a tuple containing two tensors: - - Input of transformer decoder. - Padded tensor of dimension (batch_size, max_length). - - Output of transformer decoder. - Padded tensor of dimension (batch_size, max_length). - """ - - _sos = torch.tensor([sos_id]) - _eos = torch.tensor([eos_id]) - ys = get_hierarchical_targets(ys, L_inv) - ys_in = [torch.cat([_sos, y], dim=0) for y in ys] - ys_out = [torch.cat([y, _eos], dim=0) for y in ys] - return pad_list(ys_in, eos_id), pad_list(ys_out, ignore_id) - - -def pad_list(ys: List[torch.Tensor], pad_value: float) -> torch.Tensor: - """Perform padding for the list of tensors. - - Args: - ys: - List of tensors. len(ys) = batch_size. - pad_value: - Value for padding. - - Returns: - Tensor: Padded tensor (batch_size, max_length, `*`). - - Examples: - >>> x = [torch.ones(4), torch.ones(2), torch.ones(1)] - >>> x - [tensor([1., 1., 1., 1.]), tensor([1., 1.]), tensor([1.])] - >>> pad_list(x, 0) - tensor([[1., 1., 1., 1.], - [1., 1., 0., 0.], - [1., 0., 0., 0.]]) + The ID of the SOS token. + Return: + Return a new list-of-list, where each sublist starts + with SOS ID. """ - n_batch = len(ys) - max_len = max(x.size(0) for x in ys) - pad = ys[0].new_full((n_batch, max_len, *ys[0].size()[1:]), pad_value) - - for i in range(n_batch): - pad[i, : ys[i].size(0)] = ys[i] - - return pad + ans = [] + for utt in token_ids: + ans.append([sos_id] + utt) + return ans -def get_hierarchical_targets( - ys: List[List[int]], L_inv: Optional[k2.Fsa] = None -) -> List[torch.Tensor]: - """Get hierarchical transcripts (i.e., phone level transcripts) from - transcripts (i.e., word level transcripts). +def add_eos(token_ids: List[List[int]], eos_id: int) -> List[List[int]]: + """Append eos_id to each utterance. Args: - ys: - Word level transcripts. Each sublist is a transcript of an utterance. - L_inv: - Its labels are words, while its aux_labels are tokens. + token_ids: + A list-of-list of token IDs. Each sublist contains + token IDs (e.g., word piece IDs) of an utterance. + eos_id: + The ID of the EOS token. - Returns: - List[torch.Tensor]: - Token level transcripts. + Return: + Return a new list-of-list, where each sublist ends + with EOS ID. """ - - if L_inv is None: - return [torch.tensor(y) for y in ys] - - device = L_inv.device - - transcripts = k2.create_fsa_vec( - [k2.linear_fsa(x, device=device) for x in ys] - ) - transcripts_with_self_loops = k2.add_epsilon_self_loops(transcripts) - - transcripts_lexicon = k2.intersect( - L_inv, transcripts_with_self_loops, treat_epsilons_specially=False - ) - # Don't call invert_() above because we want to return phone IDs, - # which is the `aux_labels` of transcripts_lexicon - transcripts_lexicon = k2.remove_epsilon(transcripts_lexicon) - transcripts_lexicon = k2.top_sort(transcripts_lexicon) - - transcripts_lexicon = k2.shortest_path( - transcripts_lexicon, use_double_scores=True - ) - - ys = get_texts(transcripts_lexicon) - ys = [torch.tensor(y) for y in ys] - - return ys + ans = [] + for utt in token_ids: + ans.append(utt + [eos_id]) + return ans From a6d9b3c9ab625154fe37ed26ca468a14751cecc7 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Tue, 3 Aug 2021 22:16:34 +0800 Subject: [PATCH 05/19] Minor fixes. --- egs/librispeech/ASR/conformer_ctc/decode.py | 17 ++++++++++++++--- egs/librispeech/ASR/local/compile_hlg.py | 2 +- icefall/utils.py | 19 +++++++++++++------ 3 files changed, 28 insertions(+), 10 deletions(-) diff --git a/egs/librispeech/ASR/conformer_ctc/decode.py b/egs/librispeech/ASR/conformer_ctc/decode.py index 0611814f69..889a0a4744 100755 --- a/egs/librispeech/ASR/conformer_ctc/decode.py +++ b/egs/librispeech/ASR/conformer_ctc/decode.py @@ -326,20 +326,31 @@ def save_results( test_set_name: str, results_dict: Dict[str, List[Tuple[List[int], List[int]]]], ): + if params.method == "attention-decoder": + # Set it to False since there are too many logs. + enable_log = False + else: + enable_log = True test_set_wers = dict() for key, results in results_dict.items(): recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt" store_transcripts(filename=recog_path, texts=results) - logging.info(f"The transcripts are stored in {recog_path}") + if enable_log: + logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. errs_filename = params.exp_dir / f"errs-{test_set_name}-{key}.txt" with open(errs_filename, "w") as f: - wer = write_error_stats(f, f"{test_set_name}-{key}", results) + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=enable_log + ) test_set_wers[key] = wer - logging.info("Wrote detailed error stats to {}".format(errs_filename)) + if enable_log: + logging.info( + "Wrote detailed error stats to {}".format(errs_filename) + ) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = params.exp_dir / f"wer-summary-{test_set_name}.txt" diff --git a/egs/librispeech/ASR/local/compile_hlg.py b/egs/librispeech/ASR/local/compile_hlg.py index c02fb7c0db..b304021616 100755 --- a/egs/librispeech/ASR/local/compile_hlg.py +++ b/egs/librispeech/ASR/local/compile_hlg.py @@ -45,7 +45,7 @@ def compile_HLG(lang_dir: str) -> k2.Fsa: logging.info("Loading G_3_gram.fst.txt") with open("data/lm/G_3_gram.fst.txt") as f: G = k2.Fsa.from_openfst(f.read(), acceptor=False) - torch.save(G.as_dict(), "G_3_gram.pt") + torch.save(G.as_dict(), "data/lm/G_3_gram.pt") first_token_disambig_id = lexicon.token_table["#0"] first_word_disambig_id = lexicon.word_table["#0"] diff --git a/icefall/utils.py b/icefall/utils.py index 1f2cf95f34..3d48badfef 100644 --- a/icefall/utils.py +++ b/icefall/utils.py @@ -225,7 +225,10 @@ def store_transcripts( def write_error_stats( - f: TextIO, test_set_name: str, results: List[Tuple[str, str]] + f: TextIO, + test_set_name: str, + results: List[Tuple[str, str]], + enable_log: bool = True, ) -> float: """Write statistics based on predicted results and reference transcripts. @@ -255,6 +258,9 @@ def write_error_stats( results: An iterable of tuples. The first element is the reference transcript while the second element is the predicted result. + enable_log: + If True, also print detailed WER to the console. + Otherwise, it is written only to the given file. Returns: Return None. """ @@ -290,11 +296,12 @@ def write_error_stats( tot_errs = sub_errs + ins_errs + del_errs tot_err_rate = "%.2f" % (100.0 * tot_errs / ref_len) - logging.info( - f"[{test_set_name}] %WER {tot_errs / ref_len:.2%} " - f"[{tot_errs} / {ref_len}, {ins_errs} ins, " - f"{del_errs} del, {sub_errs} sub ]" - ) + if enable_log: + logging.info( + f"[{test_set_name}] %WER {tot_errs / ref_len:.2%} " + f"[{tot_errs} / {ref_len}, {ins_errs} ins, " + f"{del_errs} del, {sub_errs} sub ]" + ) print(f"%WER = {tot_err_rate}", file=f) print( From b1b21eb1e4d2d0079aa0b8ae104ccecf13e574f6 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Wed, 4 Aug 2021 14:57:06 +0800 Subject: [PATCH 06/19] Fix decoder padding mask. --- .../ASR/conformer_ctc/transformer.py | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/egs/librispeech/ASR/conformer_ctc/transformer.py b/egs/librispeech/ASR/conformer_ctc/transformer.py index a974be4e02..2722e5ba6a 100644 --- a/egs/librispeech/ASR/conformer_ctc/transformer.py +++ b/egs/librispeech/ASR/conformer_ctc/transformer.py @@ -105,10 +105,7 @@ def __init__( norm=encoder_norm, ) - # TODO(fangjun): remove dropout - self.encoder_output_layer = nn.Sequential( - nn.Dropout(p=dropout), nn.Linear(d_model, num_classes) - ) + self.encoder_output_layer = nn.Linear(d_model, num_classes) if num_decoder_layers > 0: if mmi_loss: @@ -274,9 +271,12 @@ def decoder_forward( device ) - # TODO: Use eos_id as ignore_id. - # tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) - tgt_key_padding_mask = decoder_padding_mask(ys_in_pad) + tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) + # TODO: Use length information to create the decoder padding mask + # We set the first column to False since the first column in ys_in_pad + # contains sos_id, which is the same as eos_id in our current setting. + tgt_key_padding_mask[:, 0] = False + # tgt_key_padding_mask = decoder_padding_mask(ys_in_pad) tgt = self.decoder_embed(ys_in_pad) # (N, T) -> (N, T, C) tgt = self.decoder_pos(tgt) @@ -339,9 +339,9 @@ def decoder_nll( device ) - # TODO: Use eos_id as ignore_id. - # tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) - tgt_key_padding_mask = decoder_padding_mask(ys_in_pad) + tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) + tgt_key_padding_mask[:, 0] = False + # tgt_key_padding_mask = decoder_padding_mask(ys_in_pad) tgt = self.decoder_embed(ys_in_pad) # (B, T) -> (B, T, F) tgt = self.decoder_pos(tgt) From 897307f4454adab43ad0b88eef705a8efb530888 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Sat, 7 Aug 2021 16:41:16 +0800 Subject: [PATCH 07/19] Add MMI training with word pieces. --- egs/librispeech/ASR/conformer_mmi/__init__.py | 0 .../ASR/conformer_mmi/conformer.py | 918 ++++++++++++++++ egs/librispeech/ASR/conformer_mmi/decode.py | 507 +++++++++ .../ASR/conformer_mmi/subsampling.py | 144 +++ .../ASR/conformer_mmi/test_subsampling.py | 33 + .../ASR/conformer_mmi/test_transformer.py | 89 ++ egs/librispeech/ASR/conformer_mmi/train.py | 688 ++++++++++++ .../ASR/conformer_mmi/transformer.py | 976 ++++++++++++++++++ .../ASR/local/convert_transcript_to_corpus.py | 100 ++ .../ASR/local/ngram_entropy_pruning.py | 627 +++++++++++ egs/librispeech/ASR/prepare.sh | 70 +- icefall/bpe_graph_compiler.py | 10 +- icefall/bpe_mmi_graph_compiler.py | 178 ++++ icefall/lexicon.py | 6 +- icefall/mmi.py | 222 ++++ icefall/shared/make_kn_lm.py | 377 +++++++ test/test_bpe_mmi_graph_compiler.py | 30 + 17 files changed, 4968 insertions(+), 7 deletions(-) create mode 100644 egs/librispeech/ASR/conformer_mmi/__init__.py create mode 100644 egs/librispeech/ASR/conformer_mmi/conformer.py create mode 100755 egs/librispeech/ASR/conformer_mmi/decode.py create mode 100644 egs/librispeech/ASR/conformer_mmi/subsampling.py create mode 100755 egs/librispeech/ASR/conformer_mmi/test_subsampling.py create mode 100644 egs/librispeech/ASR/conformer_mmi/test_transformer.py create mode 100755 egs/librispeech/ASR/conformer_mmi/train.py create mode 100644 egs/librispeech/ASR/conformer_mmi/transformer.py create mode 100755 egs/librispeech/ASR/local/convert_transcript_to_corpus.py create mode 100644 egs/librispeech/ASR/local/ngram_entropy_pruning.py create mode 100644 icefall/bpe_mmi_graph_compiler.py create mode 100644 icefall/mmi.py create mode 100755 icefall/shared/make_kn_lm.py create mode 100644 test/test_bpe_mmi_graph_compiler.py diff --git a/egs/librispeech/ASR/conformer_mmi/__init__.py b/egs/librispeech/ASR/conformer_mmi/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/egs/librispeech/ASR/conformer_mmi/conformer.py b/egs/librispeech/ASR/conformer_mmi/conformer.py new file mode 100644 index 0000000000..ac49b7b1c4 --- /dev/null +++ b/egs/librispeech/ASR/conformer_mmi/conformer.py @@ -0,0 +1,918 @@ +#!/usr/bin/env python3 + +# Copyright (c) 2021 University of Chinese Academy of Sciences (author: Han Zhu) +# Apache 2.0 + +import math +import warnings +from typing import Optional, Tuple + +import torch +from torch import Tensor, nn +from transformer import Supervisions, Transformer, encoder_padding_mask + + +class Conformer(Transformer): + """ + Args: + num_features (int): Number of input features + num_classes (int): Number of output classes + subsampling_factor (int): subsampling factor of encoder (the convolution layers before transformers) + d_model (int): attention dimension + nhead (int): number of head + dim_feedforward (int): feedforward dimention + num_encoder_layers (int): number of encoder layers + num_decoder_layers (int): number of decoder layers + dropout (float): dropout rate + cnn_module_kernel (int): Kernel size of convolution module + normalize_before (bool): whether to use layer_norm before the first block. + vgg_frontend (bool): whether to use vgg frontend. + """ + + def __init__( + self, + num_features: int, + num_classes: int, + subsampling_factor: int = 4, + d_model: int = 256, + nhead: int = 4, + dim_feedforward: int = 2048, + num_encoder_layers: int = 12, + num_decoder_layers: int = 6, + dropout: float = 0.1, + cnn_module_kernel: int = 31, + normalize_before: bool = True, + vgg_frontend: bool = False, + is_espnet_structure: bool = False, + use_feat_batchnorm: bool = False, + ) -> None: + super(Conformer, self).__init__( + num_features=num_features, + num_classes=num_classes, + subsampling_factor=subsampling_factor, + d_model=d_model, + nhead=nhead, + dim_feedforward=dim_feedforward, + num_encoder_layers=num_encoder_layers, + num_decoder_layers=num_decoder_layers, + dropout=dropout, + normalize_before=normalize_before, + vgg_frontend=vgg_frontend, + use_feat_batchnorm=use_feat_batchnorm, + ) + + self.encoder_pos = RelPositionalEncoding(d_model, dropout) + + encoder_layer = ConformerEncoderLayer( + d_model, + nhead, + dim_feedforward, + dropout, + cnn_module_kernel, + normalize_before, + is_espnet_structure, + ) + self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers) + self.normalize_before = normalize_before + self.is_espnet_structure = is_espnet_structure + if self.normalize_before and self.is_espnet_structure: + self.after_norm = nn.LayerNorm(d_model) + else: + # Note: TorchScript detects that self.after_norm could be used inside forward() + # and throws an error without this change. + self.after_norm = identity + + def run_encoder( + self, x: Tensor, supervisions: Optional[Supervisions] = None + ) -> Tuple[Tensor, Optional[Tensor]]: + """ + Args: + x: + The model input. Its shape is [N, T, C]. + supervisions: + Supervision in lhotse format. + See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa + CAUTION: It contains length information, i.e., start and number of + frames, before subsampling + It is read directly from the batch, without any sorting. It is used + to compute encoder padding mask, which is used as memory key padding + mask for the decoder. + + Returns: + Tensor: Predictor tensor of dimension (input_length, batch_size, d_model). + Tensor: Mask tensor of dimension (batch_size, input_length) + """ + x = self.encoder_embed(x) + x, pos_emb = self.encoder_pos(x) + x = x.permute(1, 0, 2) # (B, T, F) -> (T, B, F) + mask = encoder_padding_mask(x.size(0), supervisions) + if mask is not None: + mask = mask.to(x.device) + x = self.encoder(x, pos_emb, src_key_padding_mask=mask) # (T, B, F) + + if self.normalize_before and self.is_espnet_structure: + x = self.after_norm(x) + + return x, mask + + +class ConformerEncoderLayer(nn.Module): + """ + ConformerEncoderLayer is made up of self-attn, feedforward and convolution networks. + See: "Conformer: Convolution-augmented Transformer for Speech Recognition" + + Args: + d_model: the number of expected features in the input (required). + nhead: the number of heads in the multiheadattention models (required). + dim_feedforward: the dimension of the feedforward network model (default=2048). + dropout: the dropout value (default=0.1). + cnn_module_kernel (int): Kernel size of convolution module. + normalize_before: whether to use layer_norm before the first block. + + Examples:: + >>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8) + >>> src = torch.rand(10, 32, 512) + >>> pos_emb = torch.rand(32, 19, 512) + >>> out = encoder_layer(src, pos_emb) + """ + + def __init__( + self, + d_model: int, + nhead: int, + dim_feedforward: int = 2048, + dropout: float = 0.1, + cnn_module_kernel: int = 31, + normalize_before: bool = True, + is_espnet_structure: bool = False, + ) -> None: + super(ConformerEncoderLayer, self).__init__() + self.self_attn = RelPositionMultiheadAttention( + d_model, nhead, dropout=0.0, is_espnet_structure=is_espnet_structure + ) + + self.feed_forward = nn.Sequential( + nn.Linear(d_model, dim_feedforward), + Swish(), + nn.Dropout(dropout), + nn.Linear(dim_feedforward, d_model), + ) + + self.feed_forward_macaron = nn.Sequential( + nn.Linear(d_model, dim_feedforward), + Swish(), + nn.Dropout(dropout), + nn.Linear(dim_feedforward, d_model), + ) + + self.conv_module = ConvolutionModule(d_model, cnn_module_kernel) + + self.norm_ff_macaron = nn.LayerNorm( + d_model + ) # for the macaron style FNN module + self.norm_ff = nn.LayerNorm(d_model) # for the FNN module + self.norm_mha = nn.LayerNorm(d_model) # for the MHA module + + self.ff_scale = 0.5 + + self.norm_conv = nn.LayerNorm(d_model) # for the CNN module + self.norm_final = nn.LayerNorm( + d_model + ) # for the final output of the block + + self.dropout = nn.Dropout(dropout) + + self.normalize_before = normalize_before + + def forward( + self, + src: Tensor, + pos_emb: Tensor, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + ) -> Tensor: + """ + Pass the input through the encoder layer. + + Args: + src: the sequence to the encoder layer (required). + pos_emb: Positional embedding tensor (required). + src_mask: the mask for the src sequence (optional). + src_key_padding_mask: the mask for the src keys per batch (optional). + + Shape: + src: (S, N, E). + pos_emb: (N, 2*S-1, E) + src_mask: (S, S). + src_key_padding_mask: (N, S). + S is the source sequence length, N is the batch size, E is the feature number + """ + + # macaron style feed forward module + residual = src + if self.normalize_before: + src = self.norm_ff_macaron(src) + src = residual + self.ff_scale * self.dropout( + self.feed_forward_macaron(src) + ) + if not self.normalize_before: + src = self.norm_ff_macaron(src) + + # multi-headed self-attention module + residual = src + if self.normalize_before: + src = self.norm_mha(src) + src_att = self.self_attn( + src, + src, + src, + pos_emb=pos_emb, + attn_mask=src_mask, + key_padding_mask=src_key_padding_mask, + )[0] + src = residual + self.dropout(src_att) + if not self.normalize_before: + src = self.norm_mha(src) + + # convolution module + residual = src + if self.normalize_before: + src = self.norm_conv(src) + src = residual + self.dropout(self.conv_module(src)) + if not self.normalize_before: + src = self.norm_conv(src) + + # feed forward module + residual = src + if self.normalize_before: + src = self.norm_ff(src) + src = residual + self.ff_scale * self.dropout(self.feed_forward(src)) + if not self.normalize_before: + src = self.norm_ff(src) + + if self.normalize_before: + src = self.norm_final(src) + + return src + + +class ConformerEncoder(nn.TransformerEncoder): + r"""ConformerEncoder is a stack of N encoder layers + + Args: + encoder_layer: an instance of the ConformerEncoderLayer() class (required). + num_layers: the number of sub-encoder-layers in the encoder (required). + norm: the layer normalization component (optional). + + Examples:: + >>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8) + >>> conformer_encoder = ConformerEncoder(encoder_layer, num_layers=6) + >>> src = torch.rand(10, 32, 512) + >>> pos_emb = torch.rand(32, 19, 512) + >>> out = conformer_encoder(src, pos_emb) + """ + + def __init__( + self, encoder_layer: nn.Module, num_layers: int, norm: nn.Module = None + ) -> None: + super(ConformerEncoder, self).__init__( + encoder_layer=encoder_layer, num_layers=num_layers, norm=norm + ) + + def forward( + self, + src: Tensor, + pos_emb: Tensor, + mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + ) -> Tensor: + r"""Pass the input through the encoder layers in turn. + + Args: + src: the sequence to the encoder (required). + pos_emb: Positional embedding tensor (required). + mask: the mask for the src sequence (optional). + src_key_padding_mask: the mask for the src keys per batch (optional). + + Shape: + src: (S, N, E). + pos_emb: (N, 2*S-1, E) + mask: (S, S). + src_key_padding_mask: (N, S). + S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number + + """ + output = src + + for mod in self.layers: + output = mod( + output, + pos_emb, + src_mask=mask, + src_key_padding_mask=src_key_padding_mask, + ) + + if self.norm is not None: + output = self.norm(output) + + return output + + +class RelPositionalEncoding(torch.nn.Module): + """Relative positional encoding module. + + See : Appendix B in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" + Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/embedding.py + + Args: + d_model: Embedding dimension. + dropout_rate: Dropout rate. + max_len: Maximum input length. + + """ + + def __init__( + self, d_model: int, dropout_rate: float, max_len: int = 5000 + ) -> None: + """Construct an PositionalEncoding object.""" + super(RelPositionalEncoding, self).__init__() + self.d_model = d_model + self.xscale = math.sqrt(self.d_model) + self.dropout = torch.nn.Dropout(p=dropout_rate) + self.pe = None + self.extend_pe(torch.tensor(0.0).expand(1, max_len)) + + def extend_pe(self, x: Tensor) -> None: + """Reset the positional encodings.""" + if self.pe is not None: + # self.pe contains both positive and negative parts + # the length of self.pe is 2 * input_len - 1 + if self.pe.size(1) >= x.size(1) * 2 - 1: + # Note: TorchScript doesn't implement operator== for torch.Device + if self.pe.dtype != x.dtype or str(self.pe.device) != str( + x.device + ): + self.pe = self.pe.to(dtype=x.dtype, device=x.device) + return + # Suppose `i` means to the position of query vecotr and `j` means the + # position of key vector. We use position relative positions when keys + # are to the left (i>j) and negative relative positions otherwise (i Tuple[Tensor, Tensor]: + """Add positional encoding. + + Args: + x (torch.Tensor): Input tensor (batch, time, `*`). + + Returns: + torch.Tensor: Encoded tensor (batch, time, `*`). + torch.Tensor: Encoded tensor (batch, 2*time-1, `*`). + + """ + self.extend_pe(x) + x = x * self.xscale + pos_emb = self.pe[ + :, + self.pe.size(1) // 2 + - x.size(1) + + 1 : self.pe.size(1) // 2 + + x.size(1), + ] + return self.dropout(x), self.dropout(pos_emb) + + +class RelPositionMultiheadAttention(nn.Module): + r"""Multi-Head Attention layer with relative position encoding + + See reference: "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" + + Args: + embed_dim: total dimension of the model. + num_heads: parallel attention heads. + dropout: a Dropout layer on attn_output_weights. Default: 0.0. + + Examples:: + + >>> rel_pos_multihead_attn = RelPositionMultiheadAttention(embed_dim, num_heads) + >>> attn_output, attn_output_weights = multihead_attn(query, key, value, pos_emb) + """ + + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout: float = 0.0, + is_espnet_structure: bool = False, + ) -> None: + super(RelPositionMultiheadAttention, self).__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = embed_dim // num_heads + assert ( + self.head_dim * num_heads == self.embed_dim + ), "embed_dim must be divisible by num_heads" + + self.in_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=True) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True) + + # linear transformation for positional encoding. + self.linear_pos = nn.Linear(embed_dim, embed_dim, bias=False) + # these two learnable bias are used in matrix c and matrix d + # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 + self.pos_bias_u = nn.Parameter(torch.Tensor(num_heads, self.head_dim)) + self.pos_bias_v = nn.Parameter(torch.Tensor(num_heads, self.head_dim)) + + self._reset_parameters() + + self.is_espnet_structure = is_espnet_structure + + def _reset_parameters(self) -> None: + nn.init.xavier_uniform_(self.in_proj.weight) + nn.init.constant_(self.in_proj.bias, 0.0) + nn.init.constant_(self.out_proj.bias, 0.0) + + nn.init.xavier_uniform_(self.pos_bias_u) + nn.init.xavier_uniform_(self.pos_bias_v) + + def forward( + self, + query: Tensor, + key: Tensor, + value: Tensor, + pos_emb: Tensor, + key_padding_mask: Optional[Tensor] = None, + need_weights: bool = True, + attn_mask: Optional[Tensor] = None, + ) -> Tuple[Tensor, Optional[Tensor]]: + r""" + Args: + query, key, value: map a query and a set of key-value pairs to an output. + pos_emb: Positional embedding tensor + key_padding_mask: if provided, specified padding elements in the key will + be ignored by the attention. When given a binary mask and a value is True, + the corresponding value on the attention layer will be ignored. When given + a byte mask and a value is non-zero, the corresponding value on the attention + layer will be ignored + need_weights: output attn_output_weights. + attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all + the batches while a 3D mask allows to specify a different mask for the entries of each batch. + + Shape: + - Inputs: + - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is + the embedding dimension. + - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is + the embedding dimension. + - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is + the embedding dimension. + - pos_emb: :math:`(N, 2*L-1, E)` where L is the target sequence length, N is the batch size, E is + the embedding dimension. + - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length. + If a ByteTensor is provided, the non-zero positions will be ignored while the position + with the zero positions will be unchanged. If a BoolTensor is provided, the positions with the + value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. + - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length. + 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length, + S is the source sequence length. attn_mask ensure that position i is allowed to attend the unmasked + positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend + while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True`` + is not allowed to attend while ``False`` values will be unchanged. If a FloatTensor + is provided, it will be added to the attention weight. + + - Outputs: + - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, + E is the embedding dimension. + - attn_output_weights: :math:`(N, L, S)` where N is the batch size, + L is the target sequence length, S is the source sequence length. + """ + return self.multi_head_attention_forward( + query, + key, + value, + pos_emb, + self.embed_dim, + self.num_heads, + self.in_proj.weight, + self.in_proj.bias, + self.dropout, + self.out_proj.weight, + self.out_proj.bias, + training=self.training, + key_padding_mask=key_padding_mask, + need_weights=need_weights, + attn_mask=attn_mask, + ) + + def rel_shift(self, x: Tensor) -> Tensor: + """Compute relative positional encoding. + + Args: + x: Input tensor (batch, head, time1, 2*time1-1). + time1 means the length of query vector. + + Returns: + Tensor: tensor of shape (batch, head, time1, time2) + (note: time2 has the same value as time1, but it is for + the key, while time1 is for the query). + """ + (batch_size, num_heads, time1, n) = x.shape + assert n == 2 * time1 - 1 + # Note: TorchScript requires explicit arg for stride() + batch_stride = x.stride(0) + head_stride = x.stride(1) + time1_stride = x.stride(2) + n_stride = x.stride(3) + return x.as_strided( + (batch_size, num_heads, time1, time1), + (batch_stride, head_stride, time1_stride - n_stride, n_stride), + storage_offset=n_stride * (time1 - 1), + ) + + def multi_head_attention_forward( + self, + query: Tensor, + key: Tensor, + value: Tensor, + pos_emb: Tensor, + embed_dim_to_check: int, + num_heads: int, + in_proj_weight: Tensor, + in_proj_bias: Tensor, + dropout_p: float, + out_proj_weight: Tensor, + out_proj_bias: Tensor, + training: bool = True, + key_padding_mask: Optional[Tensor] = None, + need_weights: bool = True, + attn_mask: Optional[Tensor] = None, + ) -> Tuple[Tensor, Optional[Tensor]]: + r""" + Args: + query, key, value: map a query and a set of key-value pairs to an output. + pos_emb: Positional embedding tensor + embed_dim_to_check: total dimension of the model. + num_heads: parallel attention heads. + in_proj_weight, in_proj_bias: input projection weight and bias. + dropout_p: probability of an element to be zeroed. + out_proj_weight, out_proj_bias: the output projection weight and bias. + training: apply dropout if is ``True``. + key_padding_mask: if provided, specified padding elements in the key will + be ignored by the attention. This is an binary mask. When the value is True, + the corresponding value on the attention layer will be filled with -inf. + need_weights: output attn_output_weights. + attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all + the batches while a 3D mask allows to specify a different mask for the entries of each batch. + + Shape: + Inputs: + - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is + the embedding dimension. + - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is + the embedding dimension. + - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is + the embedding dimension. + - pos_emb: :math:`(N, 2*L-1, E)` or :math:`(1, 2*L-1, E)` where L is the target sequence + length, N is the batch size, E is the embedding dimension. + - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length. + If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions + will be unchanged. If a BoolTensor is provided, the positions with the + value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. + - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length. + 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length, + S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked + positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend + while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True`` + are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor + is provided, it will be added to the attention weight. + + Outputs: + - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, + E is the embedding dimension. + - attn_output_weights: :math:`(N, L, S)` where N is the batch size, + L is the target sequence length, S is the source sequence length. + """ + + tgt_len, bsz, embed_dim = query.size() + assert embed_dim == embed_dim_to_check + assert key.size(0) == value.size(0) and key.size(1) == value.size(1) + + head_dim = embed_dim // num_heads + assert ( + head_dim * num_heads == embed_dim + ), "embed_dim must be divisible by num_heads" + scaling = float(head_dim) ** -0.5 + + if torch.equal(query, key) and torch.equal(key, value): + # self-attention + q, k, v = nn.functional.linear( + query, in_proj_weight, in_proj_bias + ).chunk(3, dim=-1) + + elif torch.equal(key, value): + # encoder-decoder attention + # This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _start = 0 + _end = embed_dim + _w = in_proj_weight[_start:_end, :] + if _b is not None: + _b = _b[_start:_end] + q = nn.functional.linear(query, _w, _b) + # This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _start = embed_dim + _end = None + _w = in_proj_weight[_start:, :] + if _b is not None: + _b = _b[_start:] + k, v = nn.functional.linear(key, _w, _b).chunk(2, dim=-1) + + else: + # This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _start = 0 + _end = embed_dim + _w = in_proj_weight[_start:_end, :] + if _b is not None: + _b = _b[_start:_end] + q = nn.functional.linear(query, _w, _b) + + # This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _start = embed_dim + _end = embed_dim * 2 + _w = in_proj_weight[_start:_end, :] + if _b is not None: + _b = _b[_start:_end] + k = nn.functional.linear(key, _w, _b) + + # This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _start = embed_dim * 2 + _end = None + _w = in_proj_weight[_start:, :] + if _b is not None: + _b = _b[_start:] + v = nn.functional.linear(value, _w, _b) + + if not self.is_espnet_structure: + q = q * scaling + + if attn_mask is not None: + assert ( + attn_mask.dtype == torch.float32 + or attn_mask.dtype == torch.float64 + or attn_mask.dtype == torch.float16 + or attn_mask.dtype == torch.uint8 + or attn_mask.dtype == torch.bool + ), "Only float, byte, and bool types are supported for attn_mask, not {}".format( + attn_mask.dtype + ) + if attn_mask.dtype == torch.uint8: + warnings.warn( + "Byte tensor for attn_mask is deprecated. Use bool tensor instead." + ) + attn_mask = attn_mask.to(torch.bool) + + if attn_mask.dim() == 2: + attn_mask = attn_mask.unsqueeze(0) + if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: + raise RuntimeError( + "The size of the 2D attn_mask is not correct." + ) + elif attn_mask.dim() == 3: + if list(attn_mask.size()) != [ + bsz * num_heads, + query.size(0), + key.size(0), + ]: + raise RuntimeError( + "The size of the 3D attn_mask is not correct." + ) + else: + raise RuntimeError( + "attn_mask's dimension {} is not supported".format( + attn_mask.dim() + ) + ) + # attn_mask's dim is 3 now. + + # convert ByteTensor key_padding_mask to bool + if ( + key_padding_mask is not None + and key_padding_mask.dtype == torch.uint8 + ): + warnings.warn( + "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." + ) + key_padding_mask = key_padding_mask.to(torch.bool) + + q = q.contiguous().view(tgt_len, bsz, num_heads, head_dim) + k = k.contiguous().view(-1, bsz, num_heads, head_dim) + v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1) + + src_len = k.size(0) + + if key_padding_mask is not None: + assert key_padding_mask.size(0) == bsz, "{} == {}".format( + key_padding_mask.size(0), bsz + ) + assert key_padding_mask.size(1) == src_len, "{} == {}".format( + key_padding_mask.size(1), src_len + ) + + q = q.transpose(0, 1) # (batch, time1, head, d_k) + + pos_emb_bsz = pos_emb.size(0) + assert pos_emb_bsz in (1, bsz) # actually it is 1 + p = self.linear_pos(pos_emb).view(pos_emb_bsz, -1, num_heads, head_dim) + p = p.transpose(1, 2) # (batch, head, 2*time1-1, d_k) + + q_with_bias_u = (q + self.pos_bias_u).transpose( + 1, 2 + ) # (batch, head, time1, d_k) + + q_with_bias_v = (q + self.pos_bias_v).transpose( + 1, 2 + ) # (batch, head, time1, d_k) + + # compute attention score + # first compute matrix a and matrix c + # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 + k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) + matrix_ac = torch.matmul( + q_with_bias_u, k + ) # (batch, head, time1, time2) + + # compute matrix b and matrix d + matrix_bd = torch.matmul( + q_with_bias_v, p.transpose(-2, -1) + ) # (batch, head, time1, 2*time1-1) + matrix_bd = self.rel_shift(matrix_bd) + + if not self.is_espnet_structure: + attn_output_weights = ( + matrix_ac + matrix_bd + ) # (batch, head, time1, time2) + else: + attn_output_weights = ( + matrix_ac + matrix_bd + ) * scaling # (batch, head, time1, time2) + + attn_output_weights = attn_output_weights.view( + bsz * num_heads, tgt_len, -1 + ) + + assert list(attn_output_weights.size()) == [ + bsz * num_heads, + tgt_len, + src_len, + ] + + if attn_mask is not None: + if attn_mask.dtype == torch.bool: + attn_output_weights.masked_fill_(attn_mask, float("-inf")) + else: + attn_output_weights += attn_mask + + if key_padding_mask is not None: + attn_output_weights = attn_output_weights.view( + bsz, num_heads, tgt_len, src_len + ) + attn_output_weights = attn_output_weights.masked_fill( + key_padding_mask.unsqueeze(1).unsqueeze(2), + float("-inf"), + ) + attn_output_weights = attn_output_weights.view( + bsz * num_heads, tgt_len, src_len + ) + + attn_output_weights = nn.functional.softmax(attn_output_weights, dim=-1) + attn_output_weights = nn.functional.dropout( + attn_output_weights, p=dropout_p, training=training + ) + + attn_output = torch.bmm(attn_output_weights, v) + assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] + attn_output = ( + attn_output.transpose(0, 1) + .contiguous() + .view(tgt_len, bsz, embed_dim) + ) + attn_output = nn.functional.linear( + attn_output, out_proj_weight, out_proj_bias + ) + + if need_weights: + # average attention weights over heads + attn_output_weights = attn_output_weights.view( + bsz, num_heads, tgt_len, src_len + ) + return attn_output, attn_output_weights.sum(dim=1) / num_heads + else: + return attn_output, None + + +class ConvolutionModule(nn.Module): + """ConvolutionModule in Conformer model. + Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/conformer/convolution.py + + Args: + channels (int): The number of channels of conv layers. + kernel_size (int): Kernerl size of conv layers. + bias (bool): Whether to use bias in conv layers (default=True). + + """ + + def __init__( + self, channels: int, kernel_size: int, bias: bool = True + ) -> None: + """Construct an ConvolutionModule object.""" + super(ConvolutionModule, self).__init__() + # kernerl_size should be a odd number for 'SAME' padding + assert (kernel_size - 1) % 2 == 0 + + self.pointwise_conv1 = nn.Conv1d( + channels, + 2 * channels, + kernel_size=1, + stride=1, + padding=0, + bias=bias, + ) + self.depthwise_conv = nn.Conv1d( + channels, + channels, + kernel_size, + stride=1, + padding=(kernel_size - 1) // 2, + groups=channels, + bias=bias, + ) + self.norm = nn.BatchNorm1d(channels) + self.pointwise_conv2 = nn.Conv1d( + channels, + channels, + kernel_size=1, + stride=1, + padding=0, + bias=bias, + ) + self.activation = Swish() + + def forward(self, x: Tensor) -> Tensor: + """Compute convolution module. + + Args: + x: Input tensor (#time, batch, channels). + + Returns: + Tensor: Output tensor (#time, batch, channels). + + """ + # exchange the temporal dimension and the feature dimension + x = x.permute(1, 2, 0) # (#batch, channels, time). + + # GLU mechanism + x = self.pointwise_conv1(x) # (batch, 2*channels, time) + x = nn.functional.glu(x, dim=1) # (batch, channels, time) + + # 1D Depthwise Conv + x = self.depthwise_conv(x) + x = self.activation(self.norm(x)) + + x = self.pointwise_conv2(x) # (batch, channel, time) + + return x.permute(2, 0, 1) + + +class Swish(torch.nn.Module): + """Construct an Swish object.""" + + def forward(self, x: Tensor) -> Tensor: + """Return Swich activation function.""" + return x * torch.sigmoid(x) + + +def identity(x): + return x diff --git a/egs/librispeech/ASR/conformer_mmi/decode.py b/egs/librispeech/ASR/conformer_mmi/decode.py new file mode 100755 index 0000000000..6030d13e1b --- /dev/null +++ b/egs/librispeech/ASR/conformer_mmi/decode.py @@ -0,0 +1,507 @@ +#!/usr/bin/env python3 + +# Copyright 2021 Xiaomi Corporation (Author: Liyong Guo, Fangjun Kuang) + +# (still working in progress) + +import argparse +import logging +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import torch +import torch.nn as nn +from conformer import Conformer + +from icefall.bpe_mmi_graph_compiler import BpeMmiTrainingGraphCompiler +from icefall.checkpoint import average_checkpoints, load_checkpoint +from icefall.dataset.librispeech import LibriSpeechAsrDataModule +from icefall.decode import ( + get_lattice, + nbest_decoding, + one_best_decoding, + rescore_with_attention_decoder, + rescore_with_n_best_list, + rescore_with_whole_lattice, +) +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + get_texts, + setup_logger, + store_transcripts, + write_error_stats, +) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=9, + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", + ) + parser.add_argument( + "--avg", + type=int, + default=1, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", + ) + return parser + + +def get_params() -> AttributeDict: + params = AttributeDict( + { + "exp_dir": Path("conformer_mmi/exp"), + "lang_dir": Path("data/lang_bpe"), + "lm_dir": Path("data/lm"), + "feature_dim": 80, + "nhead": 8, + "attention_dim": 512, + "subsampling_factor": 4, + "num_decoder_layers": 6, + "vgg_frontend": False, + "is_espnet_structure": True, + "use_feat_batchnorm": True, + "search_beam": 20, + "output_beam": 8, + "min_active_states": 30, + "max_active_states": 10000, + "use_double_scores": True, + # Possible values for method: + # - 1best + # - nbest + # - nbest-rescoring + # - whole-lattice-rescoring + # - attention-decoder + # "method": "whole-lattice-rescoring", + "method": "1best", + # num_paths is used when method is "nbest", "nbest-rescoring", + # and attention-decoder + "num_paths": 100, + } + ) + return params + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + HLG: k2.Fsa, + batch: dict, + lexicon: Lexicon, + sos_id: int, + eos_id: int, + G: Optional[k2.Fsa] = None, +) -> Dict[str, List[List[int]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + + - key: It indicates the setting used for decoding. For example, + if no rescoring is used, the key is the string `no_rescore`. + If LM rescoring is used, the key is the string `lm_scale_xxx`, + where `xxx` is the value of `lm_scale`. An example key is + `lm_scale_0.7` + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + Args: + params: + It's the return value of :func:`get_params`. + + - params.method is "1best", it uses 1best decoding without LM rescoring. + - params.method is "nbest", it uses nbest decoding without LM rescoring. + - params.method is "nbest-rescoring", it uses nbest LM rescoring. + - params.method is "whole-lattice-rescoring", it uses whole lattice LM + rescoring. + + model: + The neural model. + HLG: + The decoding graph. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + lexicon: + It contains word symbol table. + sos_id: + The token ID of the SOS. + eos_id: + The token ID of the EOS. + G: + An LM. It is not None when params.method is "nbest-rescoring" + or "whole-lattice-rescoring". In general, the G in HLG + is a 3-gram LM, while this G is a 4-gram LM. + Returns: + Return the decoding result. See above description for the format of + the returned dict. + """ + device = HLG.device + feature = batch["inputs"] + assert feature.ndim == 3 + feature = feature.to(device) + # at entry, feature is [N, T, C] + + supervisions = batch["supervisions"] + + nnet_output, memory, memory_key_padding_mask = model(feature, supervisions) + # nnet_output is [N, T, C] + + supervision_segments = torch.stack( + ( + supervisions["sequence_idx"], + supervisions["start_frame"] // params.subsampling_factor, + supervisions["num_frames"] // params.subsampling_factor, + ), + 1, + ).to(torch.int32) + + lattice = get_lattice( + nnet_output=nnet_output, + HLG=HLG, + supervision_segments=supervision_segments, + search_beam=params.search_beam, + output_beam=params.output_beam, + min_active_states=params.min_active_states, + max_active_states=params.max_active_states, + subsampling_factor=params.subsampling_factor, + ) + + if params.method in ["1best", "nbest"]: + if params.method == "1best": + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + key = "no_rescore" + else: + best_path = nbest_decoding( + lattice=lattice, + num_paths=params.num_paths, + use_double_scores=params.use_double_scores, + ) + key = f"no_rescore-{params.num_paths}" + + hyps = get_texts(best_path) + hyps = [[lexicon.word_table[i] for i in ids] for ids in hyps] + return {key: hyps} + + assert params.method in [ + "nbest-rescoring", + "whole-lattice-rescoring", + "attention-decoder", + ] + + lm_scale_list = [0.8, 0.9, 1.0, 1.1, 1.2, 1.3] + lm_scale_list += [1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0] + + if params.method == "nbest-rescoring": + best_path_dict = rescore_with_n_best_list( + lattice=lattice, + G=G, + num_paths=params.num_paths, + lm_scale_list=lm_scale_list, + ) + elif params.method == "whole-lattice-rescoring": + best_path_dict = rescore_with_whole_lattice( + lattice=lattice, G_with_epsilon_loops=G, lm_scale_list=lm_scale_list + ) + elif params.method == "attention-decoder": + # lattice uses a 3-gram Lm. We rescore it with a 4-gram LM. + rescored_lattice = rescore_with_whole_lattice( + lattice=lattice, G_with_epsilon_loops=G, lm_scale_list=None + ) + + best_path_dict = rescore_with_attention_decoder( + lattice=rescored_lattice, + num_paths=params.num_paths, + model=model, + memory=memory, + memory_key_padding_mask=memory_key_padding_mask, + sos_id=sos_id, + eos_id=eos_id, + ) + else: + assert False, f"Unsupported decoding method: {params.method}" + + ans = dict() + for lm_scale_str, best_path in best_path_dict.items(): + hyps = get_texts(best_path) + hyps = [[lexicon.word_table[i] for i in ids] for ids in hyps] + ans[lm_scale_str] = hyps + return ans + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + HLG: k2.Fsa, + lexicon: Lexicon, + sos_id: int, + eos_id: int, + G: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[List[int], List[int]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + HLG: + The decoding graph. + lexicon: + It contains word symbol table. + sos_id: + The token ID for SOS. + eos_id: + The token ID for EOS. + G: + An LM. It is not None when params.method is "nbest-rescoring" + or "whole-lattice-rescoring". In general, the G in HLG + is a 3-gram LM, while this G is a 4-gram LM. + Returns: + Return a dict, whose key may be "no-rescore" if no LM rescoring + is used, or it may be "lm_scale_0.7" if LM rescoring is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + results = [] + + num_cuts = 0 + tot_num_cuts = len(dl.dataset.cuts) + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + + hyps_dict = decode_one_batch( + params=params, + model=model, + HLG=HLG, + batch=batch, + lexicon=lexicon, + G=G, + sos_id=sos_id, + eos_id=eos_id, + ) + + for lm_scale, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for hyp_words, ref_text in zip(hyps, texts): + ref_words = ref_text.split() + this_batch.append((ref_words, hyp_words)) + + results[lm_scale].extend(this_batch) + + num_cuts += len(batch["supervisions"]["text"]) + + if batch_idx % 100 == 0: + logging.info( + f"batch {batch_idx}, cuts processed until now is " + f"{num_cuts}/{tot_num_cuts} " + f"({float(num_cuts)/tot_num_cuts*100:.6f}%)" + ) + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[List[int], List[int]]]], +): + if params.method == "attention-decoder": + # Set it to False since there are too many logs. + enable_log = False + else: + enable_log = True + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt" + store_transcripts(filename=recog_path, texts=results) + if enable_log: + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. + errs_filename = params.exp_dir / f"errs-{test_set_name}-{key}.txt" + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=enable_log + ) + test_set_wers[key] = wer + + if enable_log: + logging.info( + "Wrote detailed error stats to {}".format(errs_filename) + ) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = params.exp_dir / f"wer-summary-{test_set_name}.txt" + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + + params = get_params() + params.update(vars(args)) + + setup_logger(f"{params.exp_dir}/log/log-decode") + logging.info("Decoding started") + logging.info(params) + + lexicon = Lexicon(params.lang_dir) + max_token_id = max(lexicon.tokens) + num_classes = max_token_id + 1 # +1 for the blank + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + graph_compiler = BpeMmiTrainingGraphCompiler( + params.lang_dir, + device=device, + sos_token="", + eos_token="", + ) + sos_id = graph_compiler.sos_id + eos_id = graph_compiler.eos_id + + HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt")) + HLG = HLG.to(device) + assert HLG.requires_grad is False + + if not hasattr(HLG, "lm_scores"): + HLG.lm_scores = HLG.scores.clone() + + if params.method in ( + "nbest-rescoring", + "whole-lattice-rescoring", + "attention-decoder", + ): + if not (params.lm_dir / "G_4_gram.pt").is_file(): + logging.info("Loading G_4_gram.fst.txt") + logging.warning("It may take 8 minutes.") + with open(params.lm_dir / "G_4_gram.fst.txt") as f: + first_word_disambig_id = lexicon.word_table["#0"] + + G = k2.Fsa.from_openfst(f.read(), acceptor=False) + # G.aux_labels is not needed in later computations, so + # remove it here. + del G.aux_labels + # CAUTION: The following line is crucial. + # Arcs entering the back-off state have label equal to #0. + # We have to change it to 0 here. + G.labels[G.labels >= first_word_disambig_id] = 0 + G = k2.Fsa.from_fsas([G]).to(device) + G = k2.arc_sort(G) + torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt") + else: + logging.info("Loading pre-compiled G_4_gram.pt") + d = torch.load(params.lm_dir / "G_4_gram.pt") + G = k2.Fsa.from_dict(d).to(device) + + if params.method in ["whole-lattice-rescoring", "attention-decoder"]: + # Add epsilon self-loops to G as we will compose + # it with the whole lattice later + G = k2.add_epsilon_self_loops(G) + G = k2.arc_sort(G) + G = G.to(device) + + # G.lm_scores is used to replace HLG.lm_scores during + # LM rescoring. + G.lm_scores = G.scores.clone() + else: + G = None + + model = Conformer( + num_features=params.feature_dim, + nhead=params.nhead, + d_model=params.attention_dim, + num_classes=num_classes, + subsampling_factor=params.subsampling_factor, + num_decoder_layers=params.num_decoder_layers, + vgg_frontend=params.vgg_frontend, + is_espnet_structure=params.is_espnet_structure, + use_feat_batchnorm=params.use_feat_batchnorm, + ) + + if params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if start >= 0: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.load_state_dict(average_checkpoints(filenames)) + + model.to(device) + model.eval() + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + librispeech = LibriSpeechAsrDataModule(args) + # CAUTION: `test_sets` is for displaying only. + # If you want to skip test-clean, you have to skip + # it inside the for loop. That is, use + # + # if test_set == 'test-clean': continue + # + test_sets = ["test-clean", "test-other"] + for test_set, test_dl in zip(test_sets, librispeech.test_dataloaders()): + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + HLG=HLG, + lexicon=lexicon, + G=G, + sos_id=sos_id, + eos_id=eos_id, + ) + + save_results( + params=params, test_set_name=test_set, results_dict=results_dict + ) + + logging.info("Done!") + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/conformer_mmi/subsampling.py b/egs/librispeech/ASR/conformer_mmi/subsampling.py new file mode 100644 index 0000000000..5c3e1222ef --- /dev/null +++ b/egs/librispeech/ASR/conformer_mmi/subsampling.py @@ -0,0 +1,144 @@ +import torch +import torch.nn as nn + + +class Conv2dSubsampling(nn.Module): + """Convolutional 2D subsampling (to 1/4 length). + + Convert an input of shape [N, T, idim] to an output + with shape [N, T', odim], where + T' = ((T-1)//2 - 1)//2, which approximates T' == T//4 + + It is based on + https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py # noqa + """ + + def __init__(self, idim: int, odim: int) -> None: + """ + Args: + idim: + Input dim. The input shape is [N, T, idim]. + Caution: It requires: T >=7, idim >=7 + odim: + Output dim. The output shape is [N, ((T-1)//2 - 1)//2, odim] + """ + assert idim >= 7 + super().__init__() + self.conv = nn.Sequential( + nn.Conv2d( + in_channels=1, out_channels=odim, kernel_size=3, stride=2 + ), + nn.ReLU(), + nn.Conv2d( + in_channels=odim, out_channels=odim, kernel_size=3, stride=2 + ), + nn.ReLU(), + ) + self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Subsample x. + + Args: + x: + Its shape is [N, T, idim]. + + Returns: + Return a tensor of shape [N, ((T-1)//2 - 1)//2, odim] + """ + # On entry, x is [N, T, idim] + x = x.unsqueeze(1) # [N, T, idim] -> [N, 1, T, idim] i.e., [N, C, H, W] + x = self.conv(x) + # Now x is of shape [N, odim, ((T-1)//2 - 1)//2, ((idim-1)//2 - 1)//2] + b, c, t, f = x.size() + x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) + # Now x is of shape [N, ((T-1)//2 - 1))//2, odim] + return x + + +class VggSubsampling(nn.Module): + """Trying to follow the setup described in the following paper: + https://arxiv.org/pdf/1910.09799.pdf + + This paper is not 100% explicit so I am guessing to some extent, + and trying to compare with other VGG implementations. + + Convert an input of shape [N, T, idim] to an output + with shape [N, T', odim], where + T' = ((T-1)//2 - 1)//2, which approximates T' = T//4 + """ + + def __init__(self, idim: int, odim: int) -> None: + """Construct a VggSubsampling object. + + This uses 2 VGG blocks with 2 Conv2d layers each, + subsampling its input by a factor of 4 in the time dimensions. + + Args: + idim: + Input dim. The input shape is [N, T, idim]. + Caution: It requires: T >=7, idim >=7 + odim: + Output dim. The output shape is [N, ((T-1)//2 - 1)//2, odim] + """ + super().__init__() + + cur_channels = 1 + layers = [] + block_dims = [32, 64] + + # The decision to use padding=1 for the 1st convolution, then padding=0 + # for the 2nd and for the max-pooling, and ceil_mode=True, was driven by + # a back-compatibility concern so that the number of frames at the + # output would be equal to: + # (((T-1)//2)-1)//2. + # We can consider changing this by using padding=1 on the + # 2nd convolution, so the num-frames at the output would be T//4. + for block_dim in block_dims: + layers.append( + torch.nn.Conv2d( + in_channels=cur_channels, + out_channels=block_dim, + kernel_size=3, + padding=1, + stride=1, + ) + ) + layers.append(torch.nn.ReLU()) + layers.append( + torch.nn.Conv2d( + in_channels=block_dim, + out_channels=block_dim, + kernel_size=3, + padding=0, + stride=1, + ) + ) + layers.append( + torch.nn.MaxPool2d( + kernel_size=2, stride=2, padding=0, ceil_mode=True + ) + ) + cur_channels = block_dim + + self.layers = nn.Sequential(*layers) + + self.out = nn.Linear( + block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Subsample x. + + Args: + x: + Its shape is [N, T, idim]. + + Returns: + Return a tensor of shape [N, ((T-1)//2 - 1)//2, odim] + """ + x = x.unsqueeze(1) + x = self.layers(x) + b, c, t, f = x.size() + x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) + return x diff --git a/egs/librispeech/ASR/conformer_mmi/test_subsampling.py b/egs/librispeech/ASR/conformer_mmi/test_subsampling.py new file mode 100755 index 0000000000..937845d779 --- /dev/null +++ b/egs/librispeech/ASR/conformer_mmi/test_subsampling.py @@ -0,0 +1,33 @@ +#!/usr/bin/env python3 + +from subsampling import Conv2dSubsampling +from subsampling import VggSubsampling +import torch + + +def test_conv2d_subsampling(): + N = 3 + odim = 2 + + for T in range(7, 19): + for idim in range(7, 20): + model = Conv2dSubsampling(idim=idim, odim=odim) + x = torch.empty(N, T, idim) + y = model(x) + assert y.shape[0] == N + assert y.shape[1] == ((T - 1) // 2 - 1) // 2 + assert y.shape[2] == odim + + +def test_vgg_subsampling(): + N = 3 + odim = 2 + + for T in range(7, 19): + for idim in range(7, 20): + model = VggSubsampling(idim=idim, odim=odim) + x = torch.empty(N, T, idim) + y = model(x) + assert y.shape[0] == N + assert y.shape[1] == ((T - 1) // 2 - 1) // 2 + assert y.shape[2] == odim diff --git a/egs/librispeech/ASR/conformer_mmi/test_transformer.py b/egs/librispeech/ASR/conformer_mmi/test_transformer.py new file mode 100644 index 0000000000..08e6806074 --- /dev/null +++ b/egs/librispeech/ASR/conformer_mmi/test_transformer.py @@ -0,0 +1,89 @@ +#!/usr/bin/env python3 + +import torch +from transformer import ( + Transformer, + encoder_padding_mask, + generate_square_subsequent_mask, + decoder_padding_mask, + add_sos, + add_eos, +) + +from torch.nn.utils.rnn import pad_sequence + + +def test_encoder_padding_mask(): + supervisions = { + "sequence_idx": torch.tensor([0, 1, 2]), + "start_frame": torch.tensor([0, 0, 0]), + "num_frames": torch.tensor([18, 7, 13]), + } + + max_len = ((18 - 1) // 2 - 1) // 2 + mask = encoder_padding_mask(max_len, supervisions) + expected_mask = torch.tensor( + [ + [False, False, False], # ((18 - 1)//2 - 1)//2 = 3, + [False, True, True], # ((7 - 1)//2 - 1)//2 = 1, + [False, False, True], # ((13 - 1)//2 - 1)//2 = 2, + ] + ) + assert torch.all(torch.eq(mask, expected_mask)) + + +def test_transformer(): + num_features = 40 + num_classes = 87 + model = Transformer(num_features=num_features, num_classes=num_classes) + + N = 31 + + for T in range(7, 30): + x = torch.rand(N, T, num_features) + y, _, _ = model(x) + assert y.shape == (N, (((T - 1) // 2) - 1) // 2, num_classes) + + +def test_generate_square_subsequent_mask(): + s = 5 + mask = generate_square_subsequent_mask(s) + inf = float("inf") + expected_mask = torch.tensor( + [ + [0.0, -inf, -inf, -inf, -inf], + [0.0, 0.0, -inf, -inf, -inf], + [0.0, 0.0, 0.0, -inf, -inf], + [0.0, 0.0, 0.0, 0.0, -inf], + [0.0, 0.0, 0.0, 0.0, 0.0], + ] + ) + assert torch.all(torch.eq(mask, expected_mask)) + + +def test_decoder_padding_mask(): + x = [torch.tensor([1, 2]), torch.tensor([3]), torch.tensor([2, 5, 8])] + y = pad_sequence(x, batch_first=True, padding_value=-1) + mask = decoder_padding_mask(y, ignore_id=-1) + expected_mask = torch.tensor( + [ + [False, False, True], + [False, True, True], + [False, False, False], + ] + ) + assert torch.all(torch.eq(mask, expected_mask)) + + +def test_add_sos(): + x = [[1, 2], [3], [2, 5, 8]] + y = add_sos(x, sos_id=0) + expected_y = [[0, 1, 2], [0, 3], [0, 2, 5, 8]] + assert y == expected_y + + +def test_add_eos(): + x = [[1, 2], [3], [2, 5, 8]] + y = add_eos(x, eos_id=0) + expected_y = [[1, 2, 0], [3, 0], [2, 5, 8, 0]] + assert y == expected_y diff --git a/egs/librispeech/ASR/conformer_mmi/train.py b/egs/librispeech/ASR/conformer_mmi/train.py new file mode 100755 index 0000000000..810a0a4dfe --- /dev/null +++ b/egs/librispeech/ASR/conformer_mmi/train.py @@ -0,0 +1,688 @@ +#!/usr/bin/env python3 + +import argparse +import logging +from pathlib import Path +from shutil import copyfile +from typing import Optional + +import k2 +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +import torch.nn as nn +from conformer import Conformer +from lhotse.utils import fix_random_seed +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter +from transformer import Noam + +from icefall.bpe_mmi_graph_compiler import BpeMmiTrainingGraphCompiler +from icefall.checkpoint import load_checkpoint +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.dataset.librispeech import LibriSpeechAsrDataModule +from icefall.dist import cleanup_dist, setup_dist +from icefall.lexicon import Lexicon +from icefall.mmi import LFMMILoss +from icefall.utils import ( + AttributeDict, + encode_supervisions, + setup_logger, + str2bool, +) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + # TODO: add extra arguments and support DDP training. + # Currently, only single GPU training is implemented. Will add + # DDP training once single GPU training is finished. + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + is saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - exp_dir: It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + + - lang_dir: It contains language related input files such as + "lexicon.txt" + + - lr: It specifies the initial learning rate + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - weight_decay: The weight_decay for the optimizer. + + - subsampling_factor: The subsampling factor for the model. + + - start_epoch: If it is not zero, load checkpoint `start_epoch-1` + and continue training from that checkpoint. + + - num_epochs: Number of epochs to train. + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - valid_interval: Run validation if batch_idx % valid_interval` is 0 + """ + params = AttributeDict( + { + "exp_dir": Path("conformer_mmi/exp"), + "lang_dir": Path("data/lang_bpe"), + "feature_dim": 80, + "weight_decay": 1e-6, + "subsampling_factor": 4, + "start_epoch": 0, + "num_epochs": 10, + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 10, + # It takes about 10 minutes (1 GPU, max_duration=200) + # to run a validation process. + # For the 100 h subset, there are 85617 batches. + # For the 960 h dataset, there are 843723 batches + "valid_interval": 8000, + "use_pruned_intersect": False, + "den_scale": 1.0, + # + "att_rate": 0.7, + "attention_dim": 512, + "nhead": 8, + "num_decoder_layers": 6, + "is_espnet_structure": True, + "use_feat_batchnorm": True, + "lr_factor": 5.0, + "warm_step": 80000, + } + ) + + return params + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, +) -> None: + """Load checkpoint from file. + + If params.start_epoch is positive, it will load the checkpoint from + `params.start_epoch - 1`. Otherwise, this function does nothing. + + Apart from loading state dict for `model`, `optimizer` and `scheduler`, + it also updates `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + optimizer: + The optimizer that we are using. + scheduler: + The learning rate scheduler we are using. + Returns: + Return None. + """ + if params.start_epoch <= 0: + return + + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + saved_params = load_checkpoint( + filename, + model=model, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: nn.Module, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + params=params, + optimizer=optimizer, + scheduler=scheduler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: nn.Module, + batch: dict, + graph_compiler: BpeMmiTrainingGraphCompiler, + is_training: bool, +): + """ + Compute MMI loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Conformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + graph_compiler: + It is used to build num_graphs and den_graphs. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + """ + device = graph_compiler.device + feature = batch["inputs"] + # at entry, feature is [N, T, C] + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + with torch.set_grad_enabled(is_training): + nnet_output, encoder_memory, memory_mask = model(feature, supervisions) + # nnet_output is [N, T, C] + + # NOTE: We need `encode_supervisions` to sort sequences with + # different duration in decreasing order, required by + # `k2.intersect_dense` called in LFMMILoss + # + # TODO: If params.use_pruned_intersect is True, there is no + # need to call encode_supervisions + supervision_segments, texts = encode_supervisions( + supervisions, subsampling_factor=params.subsampling_factor + ) + + dense_fsa_vec = k2.DenseFsaVec( + nnet_output, + supervision_segments, + allow_truncate=params.subsampling_factor - 1, + ) + + loss_fn = LFMMILoss( + graph_compiler=graph_compiler, + den_scale=params.den_scale, + use_pruned_intersect=params.use_pruned_intersect, + ) + + mmi_loss = loss_fn(dense_fsa_vec=dense_fsa_vec, texts=texts) + + if params.att_rate != 0.0: + token_ids = graph_compiler.texts_to_ids(texts) + with torch.set_grad_enabled(is_training): + if hasattr(model, "module"): + att_loss = model.module.decoder_forward( + encoder_memory, + memory_mask, + token_ids=token_ids, + sos_id=graph_compiler.sos_id, + eos_id=graph_compiler.eos_id, + ) + else: + att_loss = model.decoder_forward( + encoder_memory, + memory_mask, + token_ids=token_ids, + sos_id=graph_compiler.sos_id, + eos_id=graph_compiler.eos_id, + ) + loss = (1.0 - params.att_rate) * mmi_loss + params.att_rate * att_loss + else: + loss = mmi_loss + att_loss = torch.tensor([0]) + + # train_frames and valid_frames are used for printing. + if is_training: + params.train_frames = supervision_segments[:, 2].sum().item() + else: + params.valid_frames = supervision_segments[:, 2].sum().item() + + assert loss.requires_grad == is_training + + return loss, mmi_loss.detach(), att_loss.detach() + + +def compute_validation_loss( + params: AttributeDict, + model: nn.Module, + graph_compiler: BpeMmiTrainingGraphCompiler, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> None: + """Run the validation process. The validation loss + is saved in `params.valid_loss`. + """ + model.eval() + + tot_loss = 0.0 + tot_mmi_loss = 0.0 + tot_att_loss = 0.0 + tot_frames = 0.0 + for batch_idx, batch in enumerate(valid_dl): + loss, mmi_loss, att_loss = compute_loss( + params=params, + model=model, + batch=batch, + graph_compiler=graph_compiler, + is_training=False, + ) + assert loss.requires_grad is False + assert mmi_loss.requires_grad is False + assert att_loss.requires_grad is False + + loss_cpu = loss.detach().cpu().item() + tot_loss += loss_cpu + + tot_mmi_loss += mmi_loss.detach().cpu().item() + tot_att_loss += att_loss.detach().cpu().item() + + tot_frames += params.valid_frames + + if world_size > 1: + s = torch.tensor( + [tot_loss, tot_mmi_loss, tot_att_loss, tot_frames], + device=loss.device, + ) + dist.all_reduce(s, op=dist.ReduceOp.SUM) + s = s.cpu().tolist() + tot_loss = s[0] + tot_mmi_loss = s[1] + tot_att_loss = s[2] + tot_frames = s[3] + + params.valid_loss = tot_loss / tot_frames + params.valid_mmi_loss = tot_mmi_loss / tot_frames + params.valid_att_loss = tot_att_loss / tot_frames + + if params.valid_loss < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = params.valid_loss + + +def train_one_epoch( + params: AttributeDict, + model: nn.Module, + optimizer: torch.optim.Optimizer, + graph_compiler: BpeMmiTrainingGraphCompiler, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + graph_compiler: + It is used to convert transcripts to FSAs. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + """ + model.train() + + tot_loss = 0.0 # sum of losses over all batches + tot_mmi_loss = 0.0 + tot_att_loss = 0.0 + + tot_frames = 0.0 # sum of frames over all batches + + for batch_idx, batch in enumerate(train_dl): + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + loss, mmi_loss, att_loss = compute_loss( + params=params, + model=model, + batch=batch, + graph_compiler=graph_compiler, + is_training=True, + ) + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + loss_cpu = loss.detach().cpu().item() + mmi_loss_cpu = mmi_loss.detach().cpu().item() + att_loss_cpu = att_loss.detach().cpu().item() + + tot_frames += params.train_frames + tot_loss += loss_cpu + tot_mmi_loss += mmi_loss_cpu + tot_att_loss += att_loss_cpu + + tot_avg_loss = tot_loss / tot_frames + tot_avg_mmi_loss = tot_mmi_loss / tot_frames + tot_avg_att_loss = tot_att_loss / tot_frames + + if batch_idx % params.log_interval == 0: + logging.info( + f"Epoch {params.cur_epoch}, batch {batch_idx}, " + f"batch avg mmi loss {mmi_loss_cpu/params.train_frames:.4f}, " + f"batch avg att loss {att_loss_cpu/params.train_frames:.4f}, " + f"batch avg loss {loss_cpu/params.train_frames:.4f}, " + f"total avg mmi loss: {tot_avg_mmi_loss:.4f}, " + f"total avg att loss: {tot_avg_att_loss:.4f}, " + f"total avg loss: {tot_avg_loss:.4f}, " + f"batch size: {batch_size}" + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/current_mmi_loss", + mmi_loss_cpu / params.train_frames, + params.batch_idx_train, + ) + tb_writer.add_scalar( + "train/current_att_loss", + att_loss_cpu / params.train_frames, + params.batch_idx_train, + ) + tb_writer.add_scalar( + "train/current_loss", + loss_cpu / params.train_frames, + params.batch_idx_train, + ) + tb_writer.add_scalar( + "train/tot_avg_mmi_loss", + tot_avg_mmi_loss, + params.batch_idx_train, + ) + + tb_writer.add_scalar( + "train/tot_avg_att_loss", + tot_avg_att_loss, + params.batch_idx_train, + ) + tb_writer.add_scalar( + "train/tot_avg_loss", + tot_avg_loss, + params.batch_idx_train, + ) + + if batch_idx > 0 and batch_idx % params.valid_interval == 0: + compute_validation_loss( + params=params, + model=model, + graph_compiler=graph_compiler, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info( + f"Epoch {params.cur_epoch}, " + f"valid mmi loss {params.valid_mmi_loss:.4f}, " + f"valid att loss {params.valid_att_loss:.4f}, " + f"valid loss {params.valid_loss:.4f}, " + f"best valid loss: {params.best_valid_loss:.4f}, " + f"best valid epoch: {params.best_valid_epoch}" + ) + if tb_writer is not None: + tb_writer.add_scalar( + "train/valid_mmi_loss", + params.valid_mmi_loss, + params.batch_idx_train, + ) + tb_writer.add_scalar( + "train/valid_att_loss", + params.valid_att_loss, + params.batch_idx_train, + ) + tb_writer.add_scalar( + "train/valid_loss", + params.valid_loss, + params.batch_idx_train, + ) + + params.train_loss = tot_loss / tot_frames + + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. + args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(42) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + logging.info(params) + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + lexicon = Lexicon(params.lang_dir) + max_token_id = max(lexicon.tokens) + num_classes = max_token_id + 1 # +1 for the blank + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + + graph_compiler = BpeMmiTrainingGraphCompiler( + params.lang_dir, + device=device, + sos_token="", + eos_token="", + ) + + logging.info("About to create model") + model = Conformer( + num_features=params.feature_dim, + nhead=params.nhead, + d_model=params.attention_dim, + num_classes=num_classes, + subsampling_factor=params.subsampling_factor, + num_decoder_layers=params.num_decoder_layers, + vgg_frontend=False, + is_espnet_structure=params.is_espnet_structure, + use_feat_batchnorm=params.use_feat_batchnorm, + ) + + checkpoints = load_checkpoint_if_available(params=params, model=model) + + model.to(device) + if world_size > 1: + model = DDP(model, device_ids=[rank]) + + optimizer = Noam( + model.parameters(), + model_size=params.attention_dim, + factor=params.lr_factor, + warm_step=params.warm_step, + weight_decay=params.weight_decay, + ) + + if checkpoints: + optimizer.load_state_dict(checkpoints["optimizer"]) + + librispeech = LibriSpeechAsrDataModule(args) + train_dl = librispeech.train_dataloaders() + valid_dl = librispeech.valid_dataloaders() + + for epoch in range(params.start_epoch, params.num_epochs): + train_dl.sampler.set_epoch(epoch) + + cur_lr = optimizer._rate + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + if rank == 0: + logging.info("epoch {}, learning rate {}".format(epoch, cur_lr)) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + optimizer=optimizer, + graph_compiler=graph_compiler, + train_dl=train_dl, + valid_dl=valid_dl, + tb_writer=tb_writer, + world_size=world_size, + ) + + save_checkpoint( + params=params, + model=model, + optimizer=optimizer, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/conformer_mmi/transformer.py b/egs/librispeech/ASR/conformer_mmi/transformer.py new file mode 100644 index 0000000000..fd1a082e7c --- /dev/null +++ b/egs/librispeech/ASR/conformer_mmi/transformer.py @@ -0,0 +1,976 @@ +# Copyright (c) 2021 University of Chinese Academy of Sciences (author: Han Zhu) +# Apache 2.0 + +import math +from typing import Dict, List, Optional, Tuple + +import k2 +import torch +import torch.nn as nn +from subsampling import Conv2dSubsampling, VggSubsampling + +from icefall.utils import get_texts +from torch.nn.utils.rnn import pad_sequence + +# Note: TorchScript requires Dict/List/etc. to be fully typed. +Supervisions = Dict[str, torch.Tensor] + + +class Transformer(nn.Module): + def __init__( + self, + num_features: int, + num_classes: int, + subsampling_factor: int = 4, + d_model: int = 256, + nhead: int = 4, + dim_feedforward: int = 2048, + num_encoder_layers: int = 12, + num_decoder_layers: int = 6, + dropout: float = 0.1, + normalize_before: bool = True, + vgg_frontend: bool = False, + use_feat_batchnorm: bool = False, + ) -> None: + """ + Args: + num_features: + The input dimension of the model. + num_classes: + The output dimension of the model. + subsampling_factor: + Number of output frames is num_in_frames // subsampling_factor. + Currently, subsampling_factor MUST be 4. + d_model: + Attention dimension. + nhead: + Number of heads in multi-head attention. + Must satisfy d_model // nhead == 0. + dim_feedforward: + The output dimension of the feedforward layers in encoder/decoder. + num_encoder_layers: + Number of encoder layers. + num_decoder_layers: + Number of decoder layers. + dropout: + Dropout in encoder/decoder. + normalize_before: + If True, use pre-layer norm; False to use post-layer norm. + vgg_frontend: + True to use vgg style frontend for subsampling. + use_feat_batchnorm: + True to use batchnorm for the input layer. + """ + super().__init__() + self.use_feat_batchnorm = use_feat_batchnorm + if use_feat_batchnorm: + self.feat_batchnorm = nn.BatchNorm1d(num_features) + + self.num_features = num_features + self.num_classes = num_classes + self.subsampling_factor = subsampling_factor + if subsampling_factor != 4: + raise NotImplementedError("Support only 'subsampling_factor=4'.") + + # self.encoder_embed converts the input of shape [N, T, num_classes] + # to the shape [N, T//subsampling_factor, d_model]. + # That is, it does two things simultaneously: + # (1) subsampling: T -> T//subsampling_factor + # (2) embedding: num_classes -> d_model + if vgg_frontend: + self.encoder_embed = VggSubsampling(num_features, d_model) + else: + self.encoder_embed = Conv2dSubsampling(num_features, d_model) + + self.encoder_pos = PositionalEncoding(d_model, dropout) + + encoder_layer = TransformerEncoderLayer( + d_model=d_model, + nhead=nhead, + dim_feedforward=dim_feedforward, + dropout=dropout, + normalize_before=normalize_before, + ) + + if normalize_before: + encoder_norm = nn.LayerNorm(d_model) + else: + encoder_norm = None + + self.encoder = nn.TransformerEncoder( + encoder_layer=encoder_layer, + num_layers=num_encoder_layers, + norm=encoder_norm, + ) + + self.encoder_output_layer = nn.Linear(d_model, num_classes) + + if num_decoder_layers > 0: + self.decoder_num_class = self.num_classes + + self.decoder_embed = nn.Embedding( + num_embeddings=self.decoder_num_class, embedding_dim=d_model + ) + self.decoder_pos = PositionalEncoding(d_model, dropout) + + decoder_layer = TransformerDecoderLayer( + d_model=d_model, + nhead=nhead, + dim_feedforward=dim_feedforward, + dropout=dropout, + normalize_before=normalize_before, + ) + + if normalize_before: + decoder_norm = nn.LayerNorm(d_model) + else: + decoder_norm = None + + self.decoder = nn.TransformerDecoder( + decoder_layer=decoder_layer, + num_layers=num_decoder_layers, + norm=decoder_norm, + ) + + self.decoder_output_layer = torch.nn.Linear( + d_model, self.decoder_num_class + ) + + self.decoder_criterion = LabelSmoothingLoss(self.decoder_num_class) + else: + self.decoder_criterion = None + + def forward( + self, x: torch.Tensor, supervision: Optional[Supervisions] = None + ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + """ + Args: + x: + The input tensor. Its shape is [N, T, C]. + supervision: + Supervision in lhotse format. + See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa + (CAUTION: It contains length information, i.e., start and number of + frames, before subsampling) + + Returns: + Return a tuple containing 3 tensors: + - CTC output for ctc decoding. Its shape is [N, T, C] + - Encoder output with shape [T, N, C]. It can be used as key and + value for the decoder. + - Encoder output padding mask. It can be used as + memory_key_padding_mask for the decoder. Its shape is [N, T]. + It is None if `supervision` is None. + """ + if self.use_feat_batchnorm: + x = x.permute(0, 2, 1) # [N, T, C] -> [N, C, T] + x = self.feat_batchnorm(x) + x = x.permute(0, 2, 1) # [N, C, T] -> [N, T, C] + encoder_memory, memory_key_padding_mask = self.run_encoder( + x, supervision + ) + x = self.ctc_output(encoder_memory) + return x, encoder_memory, memory_key_padding_mask + + def run_encoder( + self, x: torch.Tensor, supervisions: Optional[Supervisions] = None + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """Run the transformer encoder. + + Args: + x: + The model input. Its shape is [N, T, C]. + supervisions: + Supervision in lhotse format. + See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa + CAUTION: It contains length information, i.e., start and number of + frames, before subsampling + It is read directly from the batch, without any sorting. It is used + to compute the encoder padding mask, which is used as memory key + padding mask for the decoder. + Returns: + Return a tuple with two tensors: + - The encoder output, with shape [T, N, C] + - encoder padding mask, with shape [N, T]. + The mask is None if `supervisions` is None. + It is used as memory key padding mask in the decoder. + """ + x = self.encoder_embed(x) + x = self.encoder_pos(x) + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + mask = encoder_padding_mask(x.size(0), supervisions) + mask = mask.to(x.device) if mask is not None else None + x = self.encoder(x, src_key_padding_mask=mask) # (T, N, C) + + return x, mask + + def ctc_output(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + x: + The output tensor from the transformer encoder. + Its shape is [T, N, C] + + Returns: + Return a tensor that can be used for CTC decoding. + Its shape is [N, T, C] + """ + x = self.encoder_output_layer(x) + x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + x = nn.functional.log_softmax(x, dim=-1) # (N, T, C) + return x + + def decoder_forward( + self, + memory: torch.Tensor, + memory_key_padding_mask: torch.Tensor, + token_ids: List[List[int]], + sos_id: int, + eos_id: int, + ) -> torch.Tensor: + """ + Args: + memory: + It's the output of the encoder with shape [T, N, C] + memory_key_padding_mask: + The padding mask from the encoder. + token_ids: + A list-of-list IDs. Each sublist contains IDs for an utterance. + The IDs can be either phone IDs or word piece IDs. + sos_id: + sos token id + eos_id: + eos token id + + Returns: + A scalar, the **sum** of label smoothing loss over utterances + in the batch without any normalization. + """ + ys_in = add_sos(token_ids, sos_id=sos_id) + ys_in = [torch.tensor(y) for y in ys_in] + ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=eos_id) + + ys_out = add_eos(token_ids, eos_id=eos_id) + ys_out = [torch.tensor(y) for y in ys_out] + ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=-1) + + device = memory.device + ys_in_pad = ys_in_pad.to(device) + ys_out_pad = ys_out_pad.to(device) + + tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to( + device + ) + + tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) + # TODO: Use length information to create the decoder padding mask + # We set the first column to False since the first column in ys_in_pad + # contains sos_id, which is the same as eos_id in our current setting. + tgt_key_padding_mask[:, 0] = False + + tgt = self.decoder_embed(ys_in_pad) # (N, T) -> (N, T, C) + tgt = self.decoder_pos(tgt) + tgt = tgt.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + pred_pad = self.decoder( + tgt=tgt, + memory=memory, + tgt_mask=tgt_mask, + tgt_key_padding_mask=tgt_key_padding_mask, + memory_key_padding_mask=memory_key_padding_mask, + ) # (T, N, C) + pred_pad = pred_pad.permute(1, 0, 2) # (T, N, C) -> (N, T, C) + pred_pad = self.decoder_output_layer(pred_pad) # (N, T, C) + + decoder_loss = self.decoder_criterion(pred_pad, ys_out_pad) + + return decoder_loss + + def decoder_nll( + self, + memory: torch.Tensor, + memory_key_padding_mask: torch.Tensor, + token_ids: List[List[int]], + sos_id: int, + eos_id: int, + ) -> torch.Tensor: + """ + Args: + memory: + It's the output of the encoder with shape [T, N, C] + memory_key_padding_mask: + The padding mask from the encoder. + token_ids: + A list-of-list IDs (e.g., word piece IDs). + Each sublist represents an utterance. + sos_id: + The token ID for SOS. + eos_id: + The token ID for EOS. + Returns: + A 2-D tensor of shape (len(token_ids), max_token_length) + representing the cross entropy loss (i.e., negative log-likelihood). + """ + # The common part between this function and decoder_forward could be + # extracted as a separate function. + ys_in = add_sos(token_ids, sos_id=sos_id) + ys_in = [torch.tensor(y) for y in ys_in] + ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=eos_id) + + ys_out = add_eos(token_ids, eos_id=eos_id) + ys_out = [torch.tensor(y) for y in ys_out] + ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=-1) + + device = memory.device + ys_in_pad = ys_in_pad.to(device, dtype=torch.int64) + ys_out_pad = ys_out_pad.to(device, dtype=torch.int64) + + tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to( + device + ) + + tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) + tgt_key_padding_mask[:, 0] = False + + tgt = self.decoder_embed(ys_in_pad) # (B, T) -> (B, T, F) + tgt = self.decoder_pos(tgt) + tgt = tgt.permute(1, 0, 2) # (B, T, F) -> (T, B, F) + pred_pad = self.decoder( + tgt=tgt, + memory=memory, + tgt_mask=tgt_mask, + tgt_key_padding_mask=tgt_key_padding_mask, + memory_key_padding_mask=memory_key_padding_mask, + ) # (T, B, F) + pred_pad = pred_pad.permute(1, 0, 2) # (T, B, F) -> (B, T, F) + pred_pad = self.decoder_output_layer(pred_pad) # (B, T, F) + # nll: negative log-likelihood + nll = torch.nn.functional.cross_entropy( + pred_pad.view(-1, self.decoder_num_class), + ys_out_pad.view(-1), + ignore_index=-1, + reduction="none", + ) + + nll = nll.view(pred_pad.shape[0], -1) + + return nll + + +class TransformerEncoderLayer(nn.Module): + """ + Modified from torch.nn.TransformerEncoderLayer. + Add support of normalize_before, + i.e., use layer_norm before the first block. + + Args: + d_model: + the number of expected features in the input (required). + nhead: + the number of heads in the multiheadattention models (required). + dim_feedforward: + the dimension of the feedforward network model (default=2048). + dropout: + the dropout value (default=0.1). + activation: + the activation function of intermediate layer, relu or + gelu (default=relu). + normalize_before: + whether to use layer_norm before the first block. + + Examples:: + >>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8) + >>> src = torch.rand(10, 32, 512) + >>> out = encoder_layer(src) + """ + + def __init__( + self, + d_model: int, + nhead: int, + dim_feedforward: int = 2048, + dropout: float = 0.1, + activation: str = "relu", + normalize_before: bool = True, + ) -> None: + super(TransformerEncoderLayer, self).__init__() + self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=0.0) + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + + self.activation = _get_activation_fn(activation) + + self.normalize_before = normalize_before + + def __setstate__(self, state): + if "activation" not in state: + state["activation"] = nn.functional.relu + super(TransformerEncoderLayer, self).__setstate__(state) + + def forward( + self, + src: torch.Tensor, + src_mask: Optional[torch.Tensor] = None, + src_key_padding_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Pass the input through the encoder layer. + + Args: + src: the sequence to the encoder layer (required). + src_mask: the mask for the src sequence (optional). + src_key_padding_mask: the mask for the src keys per batch (optional) + + Shape: + src: (S, N, E). + src_mask: (S, S). + src_key_padding_mask: (N, S). + S is the source sequence length, T is the target sequence length, + N is the batch size, E is the feature number + """ + residual = src + if self.normalize_before: + src = self.norm1(src) + src2 = self.self_attn( + src, + src, + src, + attn_mask=src_mask, + key_padding_mask=src_key_padding_mask, + )[0] + src = residual + self.dropout1(src2) + if not self.normalize_before: + src = self.norm1(src) + + residual = src + if self.normalize_before: + src = self.norm2(src) + src2 = self.linear2(self.dropout(self.activation(self.linear1(src)))) + src = residual + self.dropout2(src2) + if not self.normalize_before: + src = self.norm2(src) + return src + + +class TransformerDecoderLayer(nn.Module): + """ + Modified from torch.nn.TransformerDecoderLayer. + Add support of normalize_before, + i.e., use layer_norm before the first block. + + Args: + d_model: + the number of expected features in the input (required). + nhead: + the number of heads in the multiheadattention models (required). + dim_feedforward: + the dimension of the feedforward network model (default=2048). + dropout: + the dropout value (default=0.1). + activation: + the activation function of intermediate layer, relu or + gelu (default=relu). + + Examples:: + >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8) + >>> memory = torch.rand(10, 32, 512) + >>> tgt = torch.rand(20, 32, 512) + >>> out = decoder_layer(tgt, memory) + """ + + def __init__( + self, + d_model: int, + nhead: int, + dim_feedforward: int = 2048, + dropout: float = 0.1, + activation: str = "relu", + normalize_before: bool = True, + ) -> None: + super(TransformerDecoderLayer, self).__init__() + self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=0.0) + self.src_attn = nn.MultiheadAttention(d_model, nhead, dropout=0.0) + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.norm3 = nn.LayerNorm(d_model) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + self.dropout3 = nn.Dropout(dropout) + + self.activation = _get_activation_fn(activation) + + self.normalize_before = normalize_before + + def __setstate__(self, state): + if "activation" not in state: + state["activation"] = nn.functional.relu + super(TransformerDecoderLayer, self).__setstate__(state) + + def forward( + self, + tgt: torch.Tensor, + memory: torch.Tensor, + tgt_mask: Optional[torch.Tensor] = None, + memory_mask: Optional[torch.Tensor] = None, + tgt_key_padding_mask: Optional[torch.Tensor] = None, + memory_key_padding_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Pass the inputs (and mask) through the decoder layer. + + Args: + tgt: + the sequence to the decoder layer (required). + memory: + the sequence from the last layer of the encoder (required). + tgt_mask: + the mask for the tgt sequence (optional). + memory_mask: + the mask for the memory sequence (optional). + tgt_key_padding_mask: + the mask for the tgt keys per batch (optional). + memory_key_padding_mask: + the mask for the memory keys per batch (optional). + + Shape: + tgt: (T, N, E). + memory: (S, N, E). + tgt_mask: (T, T). + memory_mask: (T, S). + tgt_key_padding_mask: (N, T). + memory_key_padding_mask: (N, S). + S is the source sequence length, T is the target sequence length, + N is the batch size, E is the feature number + """ + residual = tgt + if self.normalize_before: + tgt = self.norm1(tgt) + tgt2 = self.self_attn( + tgt, + tgt, + tgt, + attn_mask=tgt_mask, + key_padding_mask=tgt_key_padding_mask, + )[0] + tgt = residual + self.dropout1(tgt2) + if not self.normalize_before: + tgt = self.norm1(tgt) + + residual = tgt + if self.normalize_before: + tgt = self.norm2(tgt) + tgt2 = self.src_attn( + tgt, + memory, + memory, + attn_mask=memory_mask, + key_padding_mask=memory_key_padding_mask, + )[0] + tgt = residual + self.dropout2(tgt2) + if not self.normalize_before: + tgt = self.norm2(tgt) + + residual = tgt + if self.normalize_before: + tgt = self.norm3(tgt) + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) + tgt = residual + self.dropout3(tgt2) + if not self.normalize_before: + tgt = self.norm3(tgt) + return tgt + + +def _get_activation_fn(activation: str): + if activation == "relu": + return nn.functional.relu + elif activation == "gelu": + return nn.functional.gelu + + raise RuntimeError( + "activation should be relu/gelu, not {}".format(activation) + ) + + +class PositionalEncoding(nn.Module): + """This class implements the positional encoding + proposed in the following paper: + + - Attention Is All You Need: https://arxiv.org/pdf/1706.03762.pdf + + PE(pos, 2i) = sin(pos / (10000^(2i/d_modle)) + PE(pos, 2i+1) = cos(pos / (10000^(2i/d_modle)) + + Note:: + + 1 / (10000^(2i/d_model)) = exp(-log(10000^(2i/d_model))) + = exp(-1* 2i / d_model * log(100000)) + = exp(2i * -(log(10000) / d_model)) + """ + + def __init__(self, d_model: int, dropout: float = 0.1) -> None: + """ + Args: + d_model: + Embedding dimension. + dropout: + Dropout probability to be applied to the output of this module. + """ + super().__init__() + self.d_model = d_model + self.xscale = math.sqrt(self.d_model) + self.dropout = nn.Dropout(p=dropout) + self.pe = None + + def extend_pe(self, x: torch.Tensor) -> None: + """Extend the time t in the positional encoding if required. + + The shape of `self.pe` is [1, T1, d_model]. The shape of the input x + is [N, T, d_model]. If T > T1, then we change the shape of self.pe + to [N, T, d_model]. Otherwise, nothing is done. + + Args: + x: + It is a tensor of shape [N, T, C]. + Returns: + Return None. + """ + if self.pe is not None: + if self.pe.size(1) >= x.size(1): + if self.pe.dtype != x.dtype or self.pe.device != x.device: + self.pe = self.pe.to(dtype=x.dtype, device=x.device) + return + pe = torch.zeros(x.size(1), self.d_model, dtype=torch.float32) + position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1) + div_term = torch.exp( + torch.arange(0, self.d_model, 2, dtype=torch.float32) + * -(math.log(10000.0) / self.d_model) + ) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + pe = pe.unsqueeze(0) + # Now pe is of shape [1, T, d_model], where T is x.size(1) + self.pe = pe.to(device=x.device, dtype=x.dtype) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Add positional encoding. + + Args: + x: + Its shape is [N, T, C] + + Returns: + Return a tensor of shape [N, T, C] + """ + self.extend_pe(x) + x = x * self.xscale + self.pe[:, : x.size(1), :] + return self.dropout(x) + + +class Noam(object): + """ + Implements Noam optimizer. + + Proposed in + "Attention Is All You Need", https://arxiv.org/pdf/1706.03762.pdf + + Modified from + https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/optimizer.py # noqa + + Args: + params: + iterable of parameters to optimize or dicts defining parameter groups + model_size: + attention dimension of the transformer model + factor: + learning rate factor + warm_step: + warmup steps + """ + + def __init__( + self, + params, + model_size: int = 256, + factor: float = 10.0, + warm_step: int = 25000, + weight_decay=0, + ) -> None: + """Construct an Noam object.""" + self.optimizer = torch.optim.Adam( + params, lr=0, betas=(0.9, 0.98), eps=1e-9, weight_decay=weight_decay + ) + self._step = 0 + self.warmup = warm_step + self.factor = factor + self.model_size = model_size + self._rate = 0 + + @property + def param_groups(self): + """Return param_groups.""" + return self.optimizer.param_groups + + def step(self): + """Update parameters and rate.""" + self._step += 1 + rate = self.rate() + for p in self.optimizer.param_groups: + p["lr"] = rate + self._rate = rate + self.optimizer.step() + + def rate(self, step=None): + """Implement `lrate` above.""" + if step is None: + step = self._step + return ( + self.factor + * self.model_size ** (-0.5) + * min(step ** (-0.5), step * self.warmup ** (-1.5)) + ) + + def zero_grad(self): + """Reset gradient.""" + self.optimizer.zero_grad() + + def state_dict(self): + """Return state_dict.""" + return { + "_step": self._step, + "warmup": self.warmup, + "factor": self.factor, + "model_size": self.model_size, + "_rate": self._rate, + "optimizer": self.optimizer.state_dict(), + } + + def load_state_dict(self, state_dict): + """Load state_dict.""" + for key, value in state_dict.items(): + if key == "optimizer": + self.optimizer.load_state_dict(state_dict["optimizer"]) + else: + setattr(self, key, value) + + +class LabelSmoothingLoss(nn.Module): + """ + Label-smoothing loss. KL-divergence between q_{smoothed ground truth prob.}(w) + and p_{prob. computed by model}(w) is minimized. + Modified from + https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/label_smoothing_loss.py # noqa + + Args: + size: the number of class + padding_idx: padding_idx: ignored class id + smoothing: smoothing rate (0.0 means the conventional CE) + normalize_length: normalize loss by sequence length if True + criterion: loss function to be smoothed + """ + + def __init__( + self, + size: int, + padding_idx: int = -1, + smoothing: float = 0.1, + normalize_length: bool = False, + criterion: nn.Module = nn.KLDivLoss(reduction="none"), + ) -> None: + """Construct an LabelSmoothingLoss object.""" + super(LabelSmoothingLoss, self).__init__() + self.criterion = criterion + self.padding_idx = padding_idx + assert 0.0 < smoothing <= 1.0 + self.confidence = 1.0 - smoothing + self.smoothing = smoothing + self.size = size + self.true_dist = None + self.normalize_length = normalize_length + + def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + """ + Compute loss between x and target. + + Args: + x: + prediction of dimension + (batch_size, input_length, number_of_classes). + target: + target masked with self.padding_id of + dimension (batch_size, input_length). + + Returns: + A scalar tensor containing the loss without normalization. + """ + assert x.size(2) == self.size + # batch_size = x.size(0) + x = x.view(-1, self.size) + target = target.view(-1) + with torch.no_grad(): + true_dist = x.clone() + true_dist.fill_(self.smoothing / (self.size - 1)) + ignore = target == self.padding_idx # (B,) + total = len(target) - ignore.sum().item() + target = target.masked_fill(ignore, 0) # avoid -1 index + true_dist.scatter_(1, target.unsqueeze(1), self.confidence) + kl = self.criterion(torch.log_softmax(x, dim=1), true_dist) + # denom = total if self.normalize_length else batch_size + denom = total if self.normalize_length else 1 + return kl.masked_fill(ignore.unsqueeze(1), 0).sum() / denom + + +def encoder_padding_mask( + max_len: int, supervisions: Optional[Supervisions] = None +) -> Optional[torch.Tensor]: + """Make mask tensor containing indexes of padded part. + + TODO:: + This function **assumes** that the model uses + a subsampling factor of 4. We should remove that + assumption later. + + Args: + max_len: + Maximum length of input features. + CAUTION: It is the length after subsampling. + supervisions: + Supervision in lhotse format. + See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa + (CAUTION: It contains length information, i.e., start and number of + frames, before subsampling) + + Returns: + Tensor: Mask tensor of dimension (batch_size, input_length), True denote the masked indices. + """ + if supervisions is None: + return None + + supervision_segments = torch.stack( + ( + supervisions["sequence_idx"], + supervisions["start_frame"], + supervisions["num_frames"], + ), + 1, + ).to(torch.int32) + + lengths = [ + 0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1) + ] + for idx in range(supervision_segments.size(0)): + # Note: TorchScript doesn't allow to unpack tensors as tuples + sequence_idx = supervision_segments[idx, 0].item() + start_frame = supervision_segments[idx, 1].item() + num_frames = supervision_segments[idx, 2].item() + lengths[sequence_idx] = start_frame + num_frames + + lengths = [((i - 1) // 2 - 1) // 2 for i in lengths] + bs = int(len(lengths)) + seq_range = torch.arange(0, max_len, dtype=torch.int64) + seq_range_expand = seq_range.unsqueeze(0).expand(bs, max_len) + # Note: TorchScript doesn't implement Tensor.new() + seq_length_expand = torch.tensor( + lengths, device=seq_range_expand.device, dtype=seq_range_expand.dtype + ).unsqueeze(-1) + mask = seq_range_expand >= seq_length_expand + + return mask + + +def decoder_padding_mask( + ys_pad: torch.Tensor, ignore_id: int = -1 +) -> torch.Tensor: + """Generate a length mask for input. + + The masked position are filled with True, + Unmasked positions are filled with False. + + Args: + ys_pad: + padded tensor of dimension (batch_size, input_length). + ignore_id: + the ignored number (the padding number) in ys_pad + + Returns: + Tensor: + a bool tensor of the same shape as the input tensor. + """ + ys_mask = ys_pad == ignore_id + return ys_mask + + +def generate_square_subsequent_mask(sz: int) -> torch.Tensor: + """Generate a square mask for the sequence. The masked positions are + filled with float('-inf'). Unmasked positions are filled with float(0.0). + The mask can be used for masked self-attention. + + For instance, if sz is 3, it returns:: + + tensor([[0., -inf, -inf], + [0., 0., -inf], + [0., 0., 0]]) + + Args: + sz: mask size + + Returns: + A square mask of dimension (sz, sz) + """ + mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1) + mask = ( + mask.float() + .masked_fill(mask == 0, float("-inf")) + .masked_fill(mask == 1, float(0.0)) + ) + return mask + + +def add_sos(token_ids: List[List[int]], sos_id: int) -> List[List[int]]: + """Prepend sos_id to each utterance. + + Args: + token_ids: + A list-of-list of token IDs. Each sublist contains + token IDs (e.g., word piece IDs) of an utterance. + sos_id: + The ID of the SOS token. + + Return: + Return a new list-of-list, where each sublist starts + with SOS ID. + """ + ans = [] + for utt in token_ids: + ans.append([sos_id] + utt) + return ans + + +def add_eos(token_ids: List[List[int]], eos_id: int) -> List[List[int]]: + """Append eos_id to each utterance. + + Args: + token_ids: + A list-of-list of token IDs. Each sublist contains + token IDs (e.g., word piece IDs) of an utterance. + eos_id: + The ID of the EOS token. + + Return: + Return a new list-of-list, where each sublist ends + with EOS ID. + """ + ans = [] + for utt in token_ids: + ans.append(utt + [eos_id]) + return ans diff --git a/egs/librispeech/ASR/local/convert_transcript_to_corpus.py b/egs/librispeech/ASR/local/convert_transcript_to_corpus.py new file mode 100755 index 0000000000..bb02dac581 --- /dev/null +++ b/egs/librispeech/ASR/local/convert_transcript_to_corpus.py @@ -0,0 +1,100 @@ +#!/usr/bin/env python3 + +# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang) +""" +Convert a transcript file containing words to a corpus file containing tokens +for LM training with the help of a lexicon. + +If the lexicon contains phones, the resulting LM will be a phone LM; If the +lexicon contains word pieces, the resulting LM will be a word piece LM. + +If a word has multiple pronunciations, the one that appears last in the lexicon +is used. + +If the input transcript is: + + hello zoo world hello + world zoo + foo zoo world hellO + +and if the lexicon is + + SPN + hello h e l l o + hello h e l l o 2 + world w o r l d + zoo z o o + +Then the output is + + h e l l o 2 z o o w o r l d h e l l o 2 + w o r l d z o o + SPN z o o w o r l d SPN +""" + +from pathlib import Path +from typing import Dict + +import argparse + +from icefall.lexicon import read_lexicon + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--transcript", + type=str, + help="The input transcript file." + "We assume that the transcript file consists of " + "lines. Each line consists of space separated words.", + ) + parser.add_argument("--lexicon", type=str, help="The input lexicon file.") + parser.add_argument( + "--oov", type=str, default="", help="The OOV word." + ) + + return parser.parse_args() + + +def process_line(lexicon: Dict[str, str], line: str, oov_token: str) -> None: + """ + Args: + lexicon: + A dict containing pronunciations. Its keys are words and values + are pronunciations (i.e., tokens). + line: + A line of transcript consisting of space(s) separated words. + oov_token: + The pronunciation of the oov word if a word in `line` is not present + in the lexicon. + Returns: + Return None. + """ + s = "" + words = line.strip().split() + for i, w in enumerate(words): + tokens = lexicon.get(w, oov_token) + s += " ".join(tokens) + s += " " + print(s.strip()) + + +def main(): + args = get_args() + assert Path(args.lexicon).is_file() + assert Path(args.transcript).is_file() + assert len(args.oov) > 0 + + lexicon = dict(read_lexicon(args.lexicon)) + assert args.oov in lexicon + + oov_token = lexicon[args.oov] + + with open(args.transcript) as f: + for line in f: + process_line(lexicon=lexicon, line=line, oov_token=oov_token) + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/local/ngram_entropy_pruning.py b/egs/librispeech/ASR/local/ngram_entropy_pruning.py new file mode 100644 index 0000000000..d0ffa92f6f --- /dev/null +++ b/egs/librispeech/ASR/local/ngram_entropy_pruning.py @@ -0,0 +1,627 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +# Copyright 2021 Johns Hopkins University (Author: Ruizhe Huang) +# Apache 2.0. + +# This is an implementation of ``Entropy-based Pruning of Backoff Language Models'' +# in the same way as SRILM. + +################################################ +# Useful links/References: +################################################ +# https://github.com/BitSpeech/SRILM/blob/d571a4424fb0cf08b29fbfccfddd092ea969eae3/lm/src/NgramLM.cc#L2330 +# https://github.com/BitSpeech/SRILM/blob/d571a4424fb0cf08b29fbfccfddd092ea969eae3/lm/src/NgramLM.cc#L2124 +# https://github.com/BitSpeech/SRILM/blob/d571a4424fb0cf08b29fbfccfddd092ea969eae3/lm/src/LM.cc#L527 +# https://github.com/BitSpeech/SRILM/blob/d571a4424fb0cf08b29fbfccfddd092ea969eae3/flm/src/FNgramLM.cc#L2124 +# https://github.com/sfischer13/python-arpa + +################################################ +# How to use: +################################################ +# python3 ngram_entropy_pruning.py -threshold $threshold -lm $input_lm -write-lm $pruned_lm + +################################################ +# SRILM commands: +################################################ +# to_prune_lm=egs/swbd/s5c/data/local/lm/sw1.o3g.kn.gz +# vocab=egs/swbd/s5c/data/local/lm/wordlist +# order=3 +# oov_symbol="" +# threshold=4.7e-5 +# pruned_lm=temp.${threshold}.gz +# ngram -unk -map-unk "$oov_symbol" -vocab $vocab -order $order -prune ${threshold} -lm ${to_prune_lm} -write-lm ${pruned_lm} +# +# lm= +# ngram -unk -lm $lm -ppl heldout +# ngram -unk -lm $lm -ppl heldout -debug 3 + +import argparse +import logging +import math + +import gzip +from io import StringIO +from collections import OrderedDict +from collections import defaultdict +from enum import Enum, unique +import re + +parser = argparse.ArgumentParser(description=""" + Prune an n-gram language model based on the relative entropy + between the original and the pruned model, based on Andreas Stolcke's paper. + An n-gram entry is removed, if the removal causes (training set) perplexity + of the model to increase by less than threshold relative. + + The command takes an arpa file and a pruning threshold as input, + and outputs a pruned arpa file. + """) +parser.add_argument("-threshold", + type=float, + default=1e-6, + help="Order of n-gram") +parser.add_argument("-lm", + type=str, + default=None, + help="Path to the input arpa file") +parser.add_argument("-write-lm", + type=str, + default=None, + help="Path to output arpa file after pruning") +parser.add_argument("-minorder", + type=int, + default=1, + help="The minorder parameter limits pruning to " + "ngrams of that length and above.") +parser.add_argument("-encoding", + type=str, + default="utf-8", + help="Encoding of the arpa file") +parser.add_argument("-verbose", + type=int, + default=2, + choices=[0, 1, 2, 3, 4, 5], + help="Verbose level, where " + "0 is most noisy; " + "5 is most silent") +args = parser.parse_args() + +default_encoding = args.encoding +logging.basicConfig( + format= + "%(asctime)s — %(levelname)s — %(funcName)s:%(lineno)d — %(message)s", + level=args.verbose * 10) + + +class Context(dict): + """ + This class stores data for a context h. + It behaves like a python dict object, except that it has several + additional attributes. + """ + def __init__(self): + super().__init__() + self.log_bo = None + + +class Arpa: + """ + This is a class that implement the data structure of an APRA LM. + It (as well as some other classes) is modified based on the library + by Stefan Fischer: + https://github.com/sfischer13/python-arpa + """ + + UNK = '' + SOS = '' + EOS = '' + FLOAT_NDIGITS = 7 + base = 10 + + @staticmethod + def _check_input(my_input): + if not my_input: + raise ValueError + elif isinstance(my_input, tuple): + return my_input + elif isinstance(my_input, list): + return tuple(my_input) + elif isinstance(my_input, str): + return tuple(my_input.strip().split(' ')) + else: + raise ValueError + + @staticmethod + def _check_word(input_word): + if not isinstance(input_word, str): + raise ValueError + if ' ' in input_word: + raise ValueError + + def _replace_unks(self, words): + return tuple((w if w in self else self._unk) for w in words) + + def __init__(self, path=None, encoding=None, unk=None): + self._counts = OrderedDict() + self._ngrams = OrderedDict( + ) # Use self._ngrams[len(h)][h][w] for saving the entry of (h,w) + self._vocabulary = set() + if unk is None: + self._unk = self.UNK + + if path is not None: + self.loadf(path, encoding) + + def __contains__(self, ngram): + h = ngram[:-1] # h is a tuple + w = ngram[-1] # w is a string/word + return h in self._ngrams[len(h)] and w in self._ngrams[len(h)][h] + + def contains_word(self, word): + self._check_word(word) + return word in self._vocabulary + + def add_count(self, order, count): + self._counts[order] = count + self._ngrams[order - 1] = defaultdict(Context) + + def update_counts(self): + for order in range(1, self.order() + 1): + count = sum( + [len(wlist) for _, wlist in self._ngrams[order - 1].items()]) + if count > 0: + self._counts[order] = count + + def add_entry(self, ngram, p, bo=None, order=None): + # Note: ngram is a tuple of strings, e.g. ("w1", "w2", "w3") + h = ngram[:-1] # h is a tuple + w = ngram[-1] # w is a string/word + + # Note that p and bo here are in fact in the log domain (self.base = 10) + h_context = self._ngrams[len(h)][h] + h_context[w] = p + if bo is not None: + self._ngrams[len(ngram)][ngram].log_bo = bo + + for word in ngram: + self._vocabulary.add(word) + + def counts(self): + return sorted(self._counts.items()) + + def order(self): + return max(self._counts.keys(), default=None) + + def vocabulary(self, sort=True): + if sort: + return sorted(self._vocabulary) + else: + return self._vocabulary + + def _entries(self, order): + return (self._entry(h, w) + for h, wlist in self._ngrams[order - 1].items() for w in wlist) + + def _entry(self, h, w): + # return the entry for the ngram (h, w) + ngram = h + (w, ) + log_p = self._ngrams[len(h)][h][w] + log_bo = self._log_bo(ngram) + if log_bo is not None: + return round(log_p, self.FLOAT_NDIGITS), ngram, round( + log_bo, self.FLOAT_NDIGITS) + else: + return round(log_p, self.FLOAT_NDIGITS), ngram + + def _log_bo(self, ngram): + if len(ngram) in self._ngrams and ngram in self._ngrams[len(ngram)]: + return self._ngrams[len(ngram)][ngram].log_bo + else: + return None + + def _log_p(self, ngram): + h = ngram[:-1] # h is a tuple + w = ngram[-1] # w is a string/word + if h in self._ngrams[len(h)] and w in self._ngrams[len(h)][h]: + return self._ngrams[len(h)][h][w] + else: + return None + + def log_p_raw(self, ngram): + log_p = self._log_p(ngram) + if log_p is not None: + return log_p + else: + if len(ngram) == 1: + raise KeyError + else: + log_bo = self._log_bo(ngram[:-1]) + if log_bo is None: + log_bo = 0 + return log_bo + self.log_p_raw(ngram[1:]) + + def log_joint_prob(self, sequence): + # Compute the joint prob of the sequence based on the chain rule + # Note that sequence should be a tuple of strings + # + # Reference: + # https://github.com/BitSpeech/SRILM/blob/d571a4424fb0cf08b29fbfccfddd092ea969eae3/lm/src/LM.cc#L527 + + log_joint_p = 0 + seq = sequence + while len(seq) > 0: + log_joint_p += self.log_p_raw(seq) + seq = seq[:-1] + + # If we're computing the marginal probability of the unigram + # context we have to look up instead since the former + # has prob = 0. + if len(seq) == 1 and seq[0] == self.SOS: + seq = (self.EOS, ) + + return log_joint_p + + def set_new_context(self, h): + old_context = self._ngrams[len(h)][h] + self._ngrams[len(h)][h] = Context() + return old_context + + def log_p(self, ngram): + words = self._check_input(ngram) + if self._unk: + words = self._replace_unks(words) + return self.log_p_raw(words) + + def log_s(self, sentence, sos=SOS, eos=EOS): + words = self._check_input(sentence) + if self._unk: + words = self._replace_unks(words) + if sos: + words = (sos, ) + words + if eos: + words = words + (eos, ) + result = sum( + self.log_p_raw(words[:i]) for i in range(1, + len(words) + 1)) + if sos: + result = result - self.log_p_raw(words[:1]) + return result + + def p(self, ngram): + return self.base**self.log_p(ngram) + + def s(self, sentence): + return self.base**self.log_s(sentence) + + def write(self, fp): + fp.write('\n\\data\\\n') + for order, count in self.counts(): + fp.write('ngram {}={}\n'.format(order, count)) + fp.write('\n') + for order, _ in self.counts(): + fp.write('\\{}-grams:\n'.format(order)) + for e in self._entries(order): + prob = e[0] + ngram = ' '.join(e[1]) + if len(e) == 2: + fp.write('{}\t{}\n'.format(prob, ngram)) + elif len(e) == 3: + backoff = e[2] + fp.write('{}\t{}\t{}\n'.format(prob, ngram, backoff)) + else: + raise ValueError + fp.write('\n') + fp.write('\\end\\\n') + + +class ArpaParser: + """ + This is a class that implement a parser of an arpa file + """ + @unique + class State(Enum): + DATA = 1 + COUNT = 2 + HEADER = 3 + ENTRY = 4 + + re_count = re.compile(r'^ngram (\d+)=(\d+)$') + re_header = re.compile(r'^\\(\d+)-grams:$') + re_entry = re.compile('^(-?\\d+(\\.\\d+)?([eE]-?\\d+)?)' + '\t' + '(\\S+( \\S+)*)' + '(\t((-?\\d+(\\.\\d+)?)([eE]-?\\d+)?))?$') + + def _parse(self, fp): + self._result = [] + self._state = self.State.DATA + self._tmp_model = None + self._tmp_order = None + for line in fp: + line = line.strip() + if self._state == self.State.DATA: + self._data(line) + elif self._state == self.State.COUNT: + self._count(line) + elif self._state == self.State.HEADER: + self._header(line) + elif self._state == self.State.ENTRY: + self._entry(line) + if self._state != self.State.DATA: + raise Exception(line) + return self._result + + def _data(self, line): + if line == '\\data\\': + self._state = self.State.COUNT + self._tmp_model = Arpa() + else: + pass # skip comment line + + def _count(self, line): + match = self.re_count.match(line) + if match: + order = match.group(1) + count = match.group(2) + self._tmp_model.add_count(int(order), int(count)) + elif not line: + self._state = self.State.HEADER # there are no counts + else: + raise Exception(line) + + def _header(self, line): + match = self.re_header.match(line) + if match: + self._state = self.State.ENTRY + self._tmp_order = int(match.group(1)) + elif line == '\\end\\': + self._result.append(self._tmp_model) + self._state = self.State.DATA + self._tmp_model = None + self._tmp_order = None + elif not line: + pass # skip empty line + else: + raise Exception(line) + + def _entry(self, line): + match = self.re_entry.match(line) + if match: + p = self._float_or_int(match.group(1)) + ngram = tuple(match.group(4).split(' ')) + bo_match = match.group(7) + bo = self._float_or_int(bo_match) if bo_match else None + self._tmp_model.add_entry(ngram, p, bo, self._tmp_order) + elif not line: + self._state = self.State.HEADER # last entry + else: + raise Exception(line) + + @staticmethod + def _float_or_int(s): + f = float(s) + i = int(f) + if str(i) == s: # don't drop trailing ".0" + return i + else: + return f + + def load(self, fp): + """Deserialize fp (a file-like object) to a Python object.""" + return self._parse(fp) + + def loadf(self, path, encoding=None): + """Deserialize path (.arpa, .gz) to a Python object.""" + path = str(path) + if path.endswith('.gz'): + with gzip.open(path, mode='rt', encoding=encoding) as f: + return self.load(f) + else: + with open(path, mode='rt', encoding=encoding) as f: + return self.load(f) + + def loads(self, s): + """Deserialize s (a str) to a Python object.""" + with StringIO(s) as f: + return self.load(f) + + def dump(self, obj, fp): + """Serialize obj to fp (a file-like object) in ARPA format.""" + obj.write(fp) + + def dumpf(self, obj, path, encoding=None): + """Serialize obj to path in ARPA format (.arpa, .gz).""" + path = str(path) + if path.endswith('.gz'): + with gzip.open(path, mode='wt', encoding=encoding) as f: + return self.dump(obj, f) + else: + with open(path, mode='wt', encoding=encoding) as f: + self.dump(obj, f) + + def dumps(self, obj): + """Serialize obj to an ARPA formatted str.""" + with StringIO() as f: + self.dump(obj, f) + return f.getvalue() + + +def add_log_p(prev_log_sum, log_p, base): + return math.log(base**log_p + base**prev_log_sum, base) + + +def compute_numerator_denominator(lm, h): + log_sum_seen_h = -math.inf + log_sum_seen_h_lower = -math.inf + base = lm.base + for w, log_p in lm._ngrams[len(h)][h].items(): + log_sum_seen_h = add_log_p(log_sum_seen_h, log_p, base) + + ngram = h + (w, ) + log_p_lower = lm.log_p_raw(ngram[1:]) + log_sum_seen_h_lower = add_log_p(log_sum_seen_h_lower, log_p_lower, + base) + + numerator = 1.0 - base**log_sum_seen_h + denominator = 1.0 - base**log_sum_seen_h_lower + return numerator, denominator + + +def prune(lm, threshold, minorder): + # Reference: + # https://github.com/BitSpeech/SRILM/blob/d571a4424fb0cf08b29fbfccfddd092ea969eae3/lm/src/NgramLM.cc#L2330 + + for i in range(lm.order(), max(minorder - 1, 1), + -1): # i is the order of the ngram (h, w) + logging.info("processing %d-grams ..." % i) + count_pruned_ngrams = 0 + + h_dict = lm._ngrams[i - 1] + for h in list(h_dict.keys()): + # old backoff weight, BOW(h) + log_bow = lm._log_bo(h) + if log_bow is None: + log_bow = 0 + + # Compute numerator and denominator of the backoff weight, + # so that we can quickly compute the BOW adjustment due to + # leaving out one prob. + numerator, denominator = compute_numerator_denominator(lm, h) + + # assert abs(math.log(numerator, lm.base) - math.log(denominator, lm.base) - h_dict[h].log_bo) < 1e-5 + + # Compute the marginal probability of the context, P(h) + h_log_p = lm.log_joint_prob(h) + + all_pruned = True + pruned_w_set = set() + + for w, log_p in h_dict[h].items(): + ngram = h + (w, ) + + # lower-order estimate for ngramProb, P(w|h') + backoff_prob = lm.log_p_raw(ngram[1:]) + + # Compute BOW after removing ngram, BOW'(h) + new_log_bow = math.log(numerator + lm.base ** log_p, lm.base) - \ + math.log(denominator + lm.base ** backoff_prob, lm.base) + + # Compute change in entropy due to removal of ngram + delta_prob = backoff_prob + new_log_bow - log_p + delta_entropy = - (lm.base ** h_log_p) * \ + ((lm.base ** log_p) * delta_prob + + numerator * (new_log_bow - log_bow)) + + # compute relative change in model (training set) perplexity + perp_change = lm.base**delta_entropy - 1.0 + + pruned = threshold > 0 and perp_change < threshold + + # Make sure we don't prune ngrams whose backoff nodes are needed + if pruned and \ + len(ngram) in lm._ngrams and \ + len(lm._ngrams[len(ngram)][ngram]) > 0: + pruned = False + + logging.debug("CONTEXT " + str(h) + " WORD " + w + + " CONTEXTPROB %f " % h_log_p + + " OLDPROB %f " % log_p + " NEWPROB %f " % + (backoff_prob + new_log_bow) + + " DELTA-H %f " % delta_entropy + + " DELTA-LOGP %f " % delta_prob + + " PPL-CHANGE %f " % perp_change + " PRUNED " + + str(pruned)) + + if pruned: + pruned_w_set.add(w) + count_pruned_ngrams += 1 + else: + all_pruned = False + + # If we removed all ngrams for this context we can + # remove the context itself, but only if the present + # context is not a prefix to a longer one. + if all_pruned and len(pruned_w_set) == len(h_dict[h]): + del h_dict[ + h] # this context h is no longer needed, as its ngram prob is stored at its own context h' + elif len(pruned_w_set) > 0: + # The pruning for this context h is actually done here + old_context = lm.set_new_context(h) + + for w, p_w in old_context.items(): + if w not in pruned_w_set: + lm.add_entry( + h + (w, ), + p_w) # the entry hw is stored at the context h + + # We need to recompute the back-off weight, but + # this can only be done after completing the pruning + # of the lower-order ngrams. + # Reference: + # https://github.com/BitSpeech/SRILM/blob/d571a4424fb0cf08b29fbfccfddd092ea969eae3/flm/src/FNgramLM.cc#L2124 + + logging.info("pruned %d %d-grams" % (count_pruned_ngrams, i)) + + # recompute backoff weights + for i in range(max(minorder - 1, 1) + 1, + lm.order() + + 1): # be careful of this order: from low- to high-order + for h in lm._ngrams[i - 1]: + numerator, denominator = compute_numerator_denominator(lm, h) + new_log_bow = math.log(numerator, lm.base) - math.log( + denominator, lm.base) + lm._ngrams[len(h)][h].log_bo = new_log_bow + + # update counts + lm.update_counts() + + return + + +def check_h_is_valid(lm, h): + sum_under_h = sum( + [lm.base**lm.log_p_raw(h + (w, )) for w in lm.vocabulary(sort=False)]) + if abs(sum_under_h - 1.0) > 1e-6: + logging.info("warning: %s %f" % (str(h), sum_under_h)) + return False + else: + return True + + +def validate_lm(lm): + # sanity check if the conditional probability sums to one under each context h + for i in range(lm.order(), 0, -1): # i is the order of the ngram (h, w) + logging.info("validating %d-grams ..." % i) + h_dict = lm._ngrams[i - 1] + for h in h_dict.keys(): + check_h_is_valid(lm, h) + + +def compare_two_apras(path1, path2): + pass + + +if __name__ == '__main__': + # load an arpa file + logging.info("Loading the arpa file from %s" % args.lm) + parser = ArpaParser() + models = parser.loadf(args.lm, encoding=default_encoding) + lm = models[0] # ARPA files may contain several models. + logging.info("Stats before pruning:") + for i, cnt in lm.counts(): + logging.info("ngram %d=%d" % (i, cnt)) + + # prune it, the language model will be modified in-place + logging.info("Start pruning the model with threshold=%.3E..." % + args.threshold) + prune(lm, args.threshold, args.minorder) + + # validate_lm(lm) + + # write the arpa language model to a file + logging.info("Stats after pruning:") + for i, cnt in lm.counts(): + logging.info("ngram %d=%d" % (i, cnt)) + logging.info("Saving the pruned arpa file to %s" % args.write_lm) + parser.dumpf(lm, args.write_lm, encoding=default_encoding) + logging.info("Done.") diff --git a/egs/librispeech/ASR/prepare.sh b/egs/librispeech/ASR/prepare.sh index ae676b199b..375da0d797 100755 --- a/egs/librispeech/ASR/prepare.sh +++ b/egs/librispeech/ASR/prepare.sh @@ -143,7 +143,71 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then fi if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then - log "Stage 7: Prepare G" + log "Stage 7: Prepare bigram P" + if [ ! -f data/lang_bpe/corpus.txt ]; then + ./local/convert_transcript_to_corpus.py \ + --lexicon data/lang_bpe/lexicon.txt \ + --transcript data/lang_bpe/train.txt \ + --oov "" \ + > data/lang_bpe/corpus.txt + fi + + if [ ! -f data/lang_bpe/P.arpa ]; then + ./shared/make_kn_lm.py \ + -ngram-order 2 \ + -text data/lang_bpe/corpus.txt \ + -lm data/lang_bpe/P.arpa + fi + + # TODO: Use egs/wsj/s5/utils/lang/ngram_entropy_pruning.py + # from kaldi to prune P if it causes OOM later + + if [ ! -f data/lang_bpe/P-no-prune.fst.txt ]; then + python3 -m kaldilm \ + --read-symbol-table="data/lang_bpe/tokens.txt" \ + --disambig-symbol='#0' \ + --max-order=2 \ + data/lang_bpe/P.arpa > data/lang_bpe/P-no-prune.fst.txt + fi + + thresholds=( + 1e-6 + 1e-7 + ) + for threshold in ${thresholds[@]}; do + if [ ! -f data/lang_bpe/P-pruned.${threshold}.arpa ]; then + python3 ./local/ngram_entropy_pruning.py \ + -threshold $threshold \ + -lm data/lang_bpe/P.arpa \ + -write-lm data/lang_bpe/P-pruned.${threshold}.arpa + fi + + if [ ! -f data/lang_bpe/P-pruned.${threshold}.fst.txt ]; then + python3 -m kaldilm \ + --read-symbol-table="data/lang_bpe/tokens.txt" \ + --disambig-symbol='#0' \ + --max-order=2 \ + data/lang_bpe/P-pruned.${threshold}.arpa > data/lang_bpe/P-pruned.${threshold}.fst.txt + fi + done + + if [ ! -f data/lang_bpe/P-uni.fst.txt ]; then + python3 -m kaldilm \ + --read-symbol-table="data/lang_bpe/tokens.txt" \ + --disambig-symbol='#0' \ + --max-order=1 \ + data/lang_bpe/P.arpa > data/lang_bpe/P-uni.fst.txt + fi + + ( cd data/lang_bpe; + # ln -sfv P-pruned.1e-6.fst.txt P.fst.txt + ln -sfv P-no-prune.fst.txt P.fst.txt + ) + rm -fv data/lang_bpe/P.pt data/lang_bpe/ctc_topo_P.pt +fi + +if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then + log "Stage 8: Prepare G" # We assume you have install kaldilm, if not, please install # it using: pip install kaldilm @@ -167,7 +231,7 @@ if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then fi fi -if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then - log "Stage 8: Compile HLG" +if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then + log "Stage 9: Compile HLG" python3 ./local/compile_hlg.py fi diff --git a/icefall/bpe_graph_compiler.py b/icefall/bpe_graph_compiler.py index e22cf4edcb..c28de42bfc 100644 --- a/icefall/bpe_graph_compiler.py +++ b/icefall/bpe_graph_compiler.py @@ -17,10 +17,14 @@ def __init__( """ Args: lang_dir: - This directory is expected to contain the following files: + This directory is expected to contain the following files:: - bpe.model - words.txt + + The above files are produced by the script `prepare.sh`. You + should have run that before running the training code. + device: It indicates CPU or CUDA. sos_token: @@ -57,7 +61,9 @@ def texts_to_ids(self, texts: List[str]) -> List[List[int]]: return self.sp.encode(texts, out_type=int) def compile( - self, piece_ids: List[List[int]], modified: bool = False, + self, + piece_ids: List[List[int]], + modified: bool = False, ) -> k2.Fsa: """Build a ctc graph from a list-of-list piece IDs. diff --git a/icefall/bpe_mmi_graph_compiler.py b/icefall/bpe_mmi_graph_compiler.py new file mode 100644 index 0000000000..83bc9846f9 --- /dev/null +++ b/icefall/bpe_mmi_graph_compiler.py @@ -0,0 +1,178 @@ +import logging +from pathlib import Path +from typing import List, Tuple, Union + +import k2 +import sentencepiece as spm +import torch + +from icefall.lexicon import Lexicon + + +class BpeMmiTrainingGraphCompiler(object): + def __init__( + self, + lang_dir: Path, + device: Union[str, torch.device] = "cpu", + sos_token: str = "", + eos_token: str = "", + ) -> None: + """ + Args: + lang_dir: + Path to the lang directory. It is expected to contain the + following files:: + + - tokens.txt + - words.txt + - bpe.model + - P.fst.txt + + The above files are generated by the script `prepare.sh`. You + should have run it before running the training code. + + device: + It indicates CPU or CUDA. + sos_token: + The word piece that represents sos. + eos_token: + The word piece that represents eos. + """ + self.lang_dir = Path(lang_dir) + self.lexicon = Lexicon(lang_dir) + self.device = device + self.load_sentence_piece_model() + self.build_ctc_topo_P() + + self.sos_id = self.sp.piece_to_id(sos_token) + self.eos_id = self.sp.piece_to_id(eos_token) + + assert self.sos_id != self.sp.unk_id() + assert self.eos_id != self.sp.unk_id() + + def load_sentence_piece_model(self) -> None: + """Load the pre-trained sentencepiece model + from self.lang_dir/bpe.model. + """ + model_file = self.lang_dir / "bpe.model" + sp = spm.SentencePieceProcessor() + sp.load(str(model_file)) + self.sp = sp + + def build_ctc_topo_P(self): + """Built ctc_topo_P, the composition result of + ctc_topo and P, where P is a pre-trained bigram + word piece LM. + """ + # Note: there is no need to save a pre-compiled P and ctc_topo + # as it is very fast to generate them. + logging.info(f"Loading P from {self.lang_dir/'P.fst.txt'}") + with open(self.lang_dir / "P.fst.txt") as f: + # P is not an acceptor because there is + # a back-off state, whose incoming arcs + # have label #0 and aux_label 0 (i.e., ). + P = k2.Fsa.from_openfst(f.read(), acceptor=False) + + first_token_disambig_id = self.lexicon.token_table["#0"] + + # P.aux_labels is not needed in later computations, so + # remove it here. + del P.aux_labels + # CAUTION: The following line is crucial. + # Arcs entering the back-off state have label equal to #0. + # We have to change it to 0 here. + P.labels[P.labels >= first_token_disambig_id] = 0 + + P = k2.remove_epsilon(P) + P = k2.arc_sort(P) + P = P.to(self.device) + # Add epsilon self-loops to P because we want the + # following operation "k2.intersect" to run on GPU. + P_with_self_loops = k2.add_epsilon_self_loops(P) + + max_token_id = max(self.lexicon.tokens) + logging.info( + f"Building modified ctc_topo. max_token_id: {max_token_id}" + ) + # CAUTION: We have to use a modifed version of CTC topo. + # Otherwise, the resulting ctc_topo_P is so large that it gets + # stuck in k2.intersect_dense_pruned() or it gets OOM in + # k2.intersect_dense() + ctc_topo = k2.ctc_topo(max_token_id, modified=True, device=self.device) + + ctc_topo_inv = k2.arc_sort(ctc_topo.invert_()) + + logging.info("Building ctc_topo_P") + ctc_topo_P = k2.intersect( + ctc_topo_inv, P_with_self_loops, treat_epsilons_specially=False + ).invert() + + self.ctc_topo_P = k2.arc_sort(ctc_topo_P) + + def texts_to_ids(self, texts: List[str]) -> List[List[int]]: + """Convert a list of texts to a list-of-list of piece IDs. + + Args: + texts: + A list of transcripts. Within a transcript words are + separated by spaces. An example input is:: + + ['HELLO ICEFALL', 'HELLO k2'] + Returns: + Return a list-of-list of piece IDs. + """ + return self.sp.encode(texts, out_type=int) + + def compile( + self, texts: List[str], replicate_den: bool = True + ) -> Tuple[k2.Fsa, k2.Fsa]: + """Create numerator and denominator graphs from transcripts. + + Args: + texts: + A list of transcripts. Within a transcript words are + separated by spaces. An example input is:: + + ["HELLO icefall", "HALLO WELT"] + + replicate_den: + If True, the returned den_graph is replicated to match the number + of FSAs in the returned num_graph; if False, the returned den_graph + contains only a single FSA + Returns: + A tuple (num_graphs, den_graphs), where + + - `num_graphs` is the numerator graph. It is an FsaVec with + shape `(len(texts), None, None)`. + + - `den_graphs` is the denominator graph. It is an FsaVec with the + same shape of the `num_graph` if replicate_den is True; + otherwise, it is an FsaVec containing only a single FSA. + """ + token_ids = self.texts_to_ids(texts) + token_fsas = k2.linear_fsa(token_ids, device=self.device) + + token_fsas_with_self_loops = k2.add_epsilon_self_loops(token_fsas) + + # NOTE: Use treat_epsilons_specially=False so that k2.compose + # can be run on GPU + num_graphs = k2.compose( + self.ctc_topo_P, + token_fsas_with_self_loops, + treat_epsilons_specially=False, + ) + # num_graphs may not be connected and + # not be topologically sorted after k2.compose + num_graphs = k2.connect(num_graphs) + num_graphs = k2.top_sort(num_graphs) + + ctc_topo_P_vec = k2.create_fsa_vec([self.ctc_topo_P.detach()]) + if replicate_den: + indexes = torch.zeros( + len(texts), dtype=torch.int32, device=self.device + ) + den_graphs = k2.index_fsa(ctc_topo_P_vec, indexes) + else: + den_graphs = ctc_topo_P_vec + + return num_graphs, den_graphs diff --git a/icefall/lexicon.py b/icefall/lexicon.py index 89747b11b0..43a0fda37a 100644 --- a/icefall/lexicon.py +++ b/icefall/lexicon.py @@ -78,11 +78,13 @@ def __init__( """ Args: lang_dir: - Path to the lang director. It is expected to contain the following - files: + Path to the lang directory. It is expected to contain the following + files:: + - tokens.txt - words.txt - L.pt + The above files are produced by the script `prepare.sh`. You should have run that before running the training code. disambig_pattern: diff --git a/icefall/mmi.py b/icefall/mmi.py new file mode 100644 index 0000000000..ec5d07dfeb --- /dev/null +++ b/icefall/mmi.py @@ -0,0 +1,222 @@ +from typing import List + +import k2 +import torch +from torch import nn + +from icefall.bpe_mmi_graph_compiler import BpeMmiTrainingGraphCompiler + + +def _compute_mmi_loss_exact_optimized( + dense_fsa_vec: k2.DenseFsaVec, + texts: List[str], + graph_compiler: BpeMmiTrainingGraphCompiler, + den_scale: float = 1.0, +) -> torch.Tensor: + """ + The function name contains `exact`, which means it uses a version of + intersection without pruning. + + `optimized` in the function name means this function is optimized + in that it calls k2.intersect_dense only once + + Note: + It is faster at the cost of using more memory. + + Args: + dense_fsa_vec: + It contains the neural network output. + texts: + The transcript. Each element consists of space(s) separated words. + graph_compiler: + Used to build num_graphs and den_graphs + den_scale: + The scale applied to the denominator tot_scores. + Returns: + Return a scalar loss. It is the sum over utterances in a batch, + without normalization. + """ + num_graphs, den_graphs = graph_compiler.compile(texts, replicate_den=False) + + device = num_graphs.device + + num_fsas = num_graphs.shape[0] + assert dense_fsa_vec.dim0() == num_fsas + + assert den_graphs.shape[0] == 1 + + # The motivation to concatenate num_graphs and den_graphs + # is to reduce the number of calls to k2.intersect_dense. + num_den_graphs = k2.cat([num_graphs, den_graphs]) + + # NOTE: The a_to_b_map in k2.intersect_dense must be sorted + # so the following reorders num_den_graphs. + # + # The following code computes a_to_b_map + + # [0, 1, 2, ... ] + num_graphs_indexes = torch.arange(num_fsas, dtype=torch.int32) + + # [num_fsas, num_fsas, num_fsas, ... ] + den_graphs_indexes = torch.tensor([num_fsas] * num_fsas, dtype=torch.int32) + + # [0, num_fsas, 1, num_fsas, 2, num_fsas, ... ] + num_den_graphs_indexes = ( + torch.stack([num_graphs_indexes, den_graphs_indexes]) + .t() + .reshape(-1) + .to(device) + ) + + num_den_reordered_graphs = k2.index(num_den_graphs, num_den_graphs_indexes) + + # [[0, 1, 2, ...]] + a_to_b_map = torch.arange(num_fsas, dtype=torch.int32).reshape(1, -1) + + # [[0, 1, 2, ...]] -> [0, 0, 1, 1, 2, 2, ... ] + a_to_b_map = a_to_b_map.repeat(2, 1).t().reshape(-1).to(device) + + num_den_lats = k2.intersect_dense( + num_den_reordered_graphs, + dense_fsa_vec, + output_beam=10.0, + a_to_b_map=a_to_b_map, + ) + + num_den_tot_scores = num_den_lats.get_tot_scores( + log_semiring=True, use_double_scores=True + ) + + num_tot_scores = num_den_tot_scores[::2] + den_tot_scores = num_den_tot_scores[1::2] + + tot_scores = num_tot_scores - den_scale * den_tot_scores + loss = -1 * tot_scores.sum() + return loss + + +def _compute_mmi_loss_exact_non_optimized( + dense_fsa_vec: k2.DenseFsaVec, + texts: List[str], + graph_compiler: BpeMmiTrainingGraphCompiler, + den_scale: float = 1.0, +) -> torch.Tensor: + """ + See :func:`_compute_mmi_loss_exact_optimized` for the meaning + of the arguments. + + It's more readable, though it invokes k2.intersect_dense twice. + + Note: + It uses less memory at the cost of speed. It is slower. + """ + num_graphs, den_graphs = graph_compiler.compile(texts, replicate_den=True) + + # TODO: pass output_beam as function argument + num_lats = k2.intersect_dense(num_graphs, dense_fsa_vec, output_beam=10.0) + den_lats = k2.intersect_dense(den_graphs, dense_fsa_vec, output_beam=10.0) + + num_tot_scores = num_lats.get_tot_scores( + log_semiring=True, use_double_scores=True + ) + + den_tot_scores = den_lats.get_tot_scores( + log_semiring=True, use_double_scores=True + ) + + tot_scores = num_tot_scores - den_scale * den_tot_scores + + loss = -1 * tot_scores.sum() + return loss + + +def _compute_mmi_loss_pruned( + dense_fsa_vec: k2.DenseFsaVec, + texts: List[str], + graph_compiler: BpeMmiTrainingGraphCompiler, + den_scale: float = 1.0, +) -> torch.Tensor: + """ + See :func:`_compute_mmi_loss_exact_optimized` for the meaning + of the arguments. + + `pruned` means it uses k2.intersect_dense_pruned + + Note: + It uses the least amount of memory, but the loss is not exact due + to pruning. + """ + num_graphs, den_graphs = graph_compiler.compile(texts, replicate_den=False) + + num_lats = k2.intersect_dense(num_graphs, dense_fsa_vec, output_beam=10.0) + + # the values for search_beam/output_beam/min_active_states/max_active_states + # are not tuned. You may want to tune them. + den_lats = k2.intersect_dense_pruned( + den_graphs, + dense_fsa_vec, + search_beam=20.0, + output_beam=8.0, + min_active_states=30, + max_active_states=10000, + ) + + num_tot_scores = num_lats.get_tot_scores( + log_semiring=True, use_double_scores=True + ) + + den_tot_scores = den_lats.get_tot_scores( + log_semiring=True, use_double_scores=True + ) + + tot_scores = num_tot_scores - den_scale * den_tot_scores + + loss = -1 * tot_scores.sum() + return loss + + +class LFMMILoss(nn.Module): + """ + Computes Lattice-Free Maximum Mutual Information (LFMMI) loss. + + TODO: more detailed description + """ + + def __init__( + self, + graph_compiler: BpeMmiTrainingGraphCompiler, + use_pruned_intersect: bool = False, + den_scale: float = 1.0, + ): + super().__init__() + self.graph_compiler = graph_compiler + self.den_scale = den_scale + self.use_pruned_intersect = use_pruned_intersect + + def forward( + self, + dense_fsa_vec: k2.DenseFsaVec, + texts: List[str], + ) -> torch.Tensor: + """ + Args: + dense_fsa_vec: + It contains the neural network output. + texts: + A list of strings. Each string contains space(s) separated words. + Returns: + Return a scalar loss. It is the sum over utterances in a batch, + without normalization. + """ + if self.use_pruned_intersect: + func = _compute_mmi_loss_pruned + else: + func = _compute_mmi_loss_exact_non_optimized + # func = _compute_mmi_loss_exact_optimized + + return func( + dense_fsa_vec=dense_fsa_vec, + texts=texts, + graph_compiler=self.graph_compiler, + den_scale=self.den_scale, + ) diff --git a/icefall/shared/make_kn_lm.py b/icefall/shared/make_kn_lm.py new file mode 100755 index 0000000000..58b721d219 --- /dev/null +++ b/icefall/shared/make_kn_lm.py @@ -0,0 +1,377 @@ +#!/usr/bin/env python3 + +# Copyright 2016 Johns Hopkins University (Author: Daniel Povey) +# 2018 Ruizhe Huang +# Apache 2.0. + +# This is an implementation of computing Kneser-Ney smoothed language model +# in the same way as srilm. This is a back-off, unmodified version of +# Kneser-Ney smoothing, which produces the same results as the following +# command (as an example) of srilm: +# +# $ ngram-count -order 4 -kn-modify-counts-at-end -ukndiscount -gt1min 0 -gt2min 0 -gt3min 0 -gt4min 0 \ +# -text corpus.txt -lm lm.arpa +# +# The data structure is based on: kaldi/egs/wsj/s5/utils/lang/make_phone_lm.py +# The smoothing algorithm is based on: http://www.speech.sri.com/projects/srilm/manpages/ngram-discount.7.html + +import sys +import os +import re +import io +import math +import argparse +from collections import Counter, defaultdict + + +parser = argparse.ArgumentParser(description=""" + Generate kneser-ney language model as arpa format. By default, + it will read the corpus from standard input, and output to standard output. + """) +parser.add_argument("-ngram-order", type=int, default=4, choices=[2, 3, 4, 5, 6, 7], help="Order of n-gram") +parser.add_argument("-text", type=str, default=None, help="Path to the corpus file") +parser.add_argument("-lm", type=str, default=None, help="Path to output arpa file for language models") +parser.add_argument("-verbose", type=int, default=0, choices=[0, 1, 2, 3, 4, 5], help="Verbose level") +args = parser.parse_args() + +default_encoding = "latin-1" # For encoding-agnostic scripts, we assume byte stream as input. + # Need to be very careful about the use of strip() and split() + # in this case, because there is a latin-1 whitespace character + # (nbsp) which is part of the unicode encoding range. + # Ref: kaldi/egs/wsj/s5/utils/lang/bpe/prepend_words.py @ 69cd717 +strip_chars = " \t\r\n" +whitespace = re.compile("[ \t]+") + + +class CountsForHistory: + # This class (which is more like a struct) stores the counts seen in a + # particular history-state. It is used inside class NgramCounts. + # It really does the job of a dict from int to float, but it also + # keeps track of the total count. + def __init__(self): + # The 'lambda: defaultdict(float)' is an anonymous function taking no + # arguments that returns a new defaultdict(float). + self.word_to_count = defaultdict(int) + self.word_to_context = defaultdict(set) # using a set to count the number of unique contexts + self.word_to_f = dict() # discounted probability + self.word_to_bow = dict() # back-off weight + self.total_count = 0 + + def words(self): + return self.word_to_count.keys() + + def __str__(self): + # e.g. returns ' total=12: 3->4, 4->6, -1->2' + return ' total={0}: {1}'.format( + str(self.total_count), + ', '.join(['{0} -> {1}'.format(word, count) + for word, count in self.word_to_count.items()])) + + def add_count(self, predicted_word, context_word, count): + assert count >= 0 + + self.total_count += count + self.word_to_count[predicted_word] += count + if context_word is not None: + self.word_to_context[predicted_word].add(context_word) + + +class NgramCounts: + # A note on data-structure. Firstly, all words are represented as + # integers. We store n-gram counts as an array, indexed by (history-length + # == n-gram order minus one) (note: python calls arrays "lists") of dicts + # from histories to counts, where histories are arrays of integers and + # "counts" are dicts from integer to float. For instance, when + # accumulating the 4-gram count for the '8' in the sequence '5 6 7 8', we'd + # do as follows: self.counts[3][[5,6,7]][8] += 1.0 where the [3] indexes an + # array, the [[5,6,7]] indexes a dict, and the [8] indexes a dict. + def __init__(self, ngram_order, bos_symbol='', eos_symbol=''): + assert ngram_order >= 2 + + self.ngram_order = ngram_order + self.bos_symbol = bos_symbol + self.eos_symbol = eos_symbol + + self.counts = [] + for n in range(ngram_order): + self.counts.append(defaultdict(lambda: CountsForHistory())) + + self.d = [] # list of discounting factor for each order of ngram + + # adds a raw count (called while processing input data). + # Suppose we see the sequence '6 7 8 9' and ngram_order=4, 'history' + # would be (6,7,8) and 'predicted_word' would be 9; 'count' would be + # 1. + def add_count(self, history, predicted_word, context_word, count): + self.counts[len(history)][history].add_count(predicted_word, context_word, count) + + # 'line' is a string containing a sequence of integer word-ids. + # This function adds the un-smoothed counts from this line of text. + def add_raw_counts_from_line(self, line): + if line == '': + words = [self.bos_symbol, self.eos_symbol] + else: + words = [self.bos_symbol] + whitespace.split(line) + [self.eos_symbol] + + for i in range(len(words)): + for n in range(1, self.ngram_order+1): + if i + n > len(words): + break + ngram = words[i: i + n] + predicted_word = ngram[-1] + history = tuple(ngram[: -1]) + if i == 0 or n == self.ngram_order: + context_word = None + else: + context_word = words[i-1] + + self.add_count(history, predicted_word, context_word, 1) + + def add_raw_counts_from_standard_input(self): + lines_processed = 0 + infile = io.TextIOWrapper(sys.stdin.buffer, encoding=default_encoding) # byte stream as input + for line in infile: + line = line.strip(strip_chars) + self.add_raw_counts_from_line(line) + lines_processed += 1 + if lines_processed == 0 or args.verbose > 0: + print("make_phone_lm.py: processed {0} lines of input".format(lines_processed), file=sys.stderr) + + def add_raw_counts_from_file(self, filename): + lines_processed = 0 + with open(filename, encoding=default_encoding) as fp: + for line in fp: + line = line.strip(strip_chars) + self.add_raw_counts_from_line(line) + lines_processed += 1 + if lines_processed == 0 or args.verbose > 0: + print("make_phone_lm.py: processed {0} lines of input".format(lines_processed), file=sys.stderr) + + def cal_discounting_constants(self): + # For each order N of N-grams, we calculate discounting constant D_N = n1_N / (n1_N + 2 * n2_N), + # where n1_N is the number of unique N-grams with count = 1 (counts-of-counts). + # This constant is used similarly to absolute discounting. + # Return value: d is a list of floats, where d[N+1] = D_N + + self.d = [0] # for the lowest order, i.e., 1-gram, we do not need to discount, thus the constant is 0 + # This is a special case: as we currently assumed having seen all vocabularies in the dictionary, + # but perhaps this is not the case for some other scenarios. + for n in range(1, self.ngram_order): + this_order_counts = self.counts[n] + n1 = 0 + n2 = 0 + for hist, counts_for_hist in this_order_counts.items(): + stat = Counter(counts_for_hist.word_to_count.values()) + n1 += stat[1] + n2 += stat[2] + assert n1 + 2 * n2 > 0 + self.d.append(n1 * 1.0 / (n1 + 2 * n2)) + + def cal_f(self): + # f(a_z) is a probability distribution of word sequence a_z. + # Typically f(a_z) is discounted to be less than the ML estimate so we have + # some leftover probability for the z words unseen in the context (a_). + # + # f(a_z) = (c(a_z) - D0) / c(a_) ;; for highest order N-grams + # f(_z) = (n(*_z) - D1) / n(*_*) ;; for lower order N-grams + + # highest order N-grams + n = self.ngram_order - 1 + this_order_counts = self.counts[n] + for hist, counts_for_hist in this_order_counts.items(): + for w, c in counts_for_hist.word_to_count.items(): + counts_for_hist.word_to_f[w] = max((c - self.d[n]), 0) * 1.0 / counts_for_hist.total_count + + # lower order N-grams + for n in range(0, self.ngram_order - 1): + this_order_counts = self.counts[n] + for hist, counts_for_hist in this_order_counts.items(): + + n_star_star = 0 + for w in counts_for_hist.word_to_count.keys(): + n_star_star += len(counts_for_hist.word_to_context[w]) + + if n_star_star != 0: + for w in counts_for_hist.word_to_count.keys(): + n_star_z = len(counts_for_hist.word_to_context[w]) + counts_for_hist.word_to_f[w] = max((n_star_z - self.d[n]), 0) * 1.0 / n_star_star + else: # patterns begin with , they do not have "modified count", so use raw count instead + for w in counts_for_hist.word_to_count.keys(): + n_star_z = counts_for_hist.word_to_count[w] + counts_for_hist.word_to_f[w] = max((n_star_z - self.d[n]), 0) * 1.0 / counts_for_hist.total_count + + def cal_bow(self): + # Backoff weights are only necessary for ngrams which form a prefix of a longer ngram. + # Thus, two sorts of ngrams do not have a bow: + # 1) highest order ngram + # 2) ngrams ending in + # + # bow(a_) = (1 - Sum_Z1 f(a_z)) / (1 - Sum_Z1 f(_z)) + # Note that Z1 is the set of all words with c(a_z) > 0 + + # highest order N-grams + n = self.ngram_order - 1 + this_order_counts = self.counts[n] + for hist, counts_for_hist in this_order_counts.items(): + for w in counts_for_hist.word_to_count.keys(): + counts_for_hist.word_to_bow[w] = None + + # lower order N-grams + for n in range(0, self.ngram_order - 1): + this_order_counts = self.counts[n] + for hist, counts_for_hist in this_order_counts.items(): + for w in counts_for_hist.word_to_count.keys(): + if w == self.eos_symbol: + counts_for_hist.word_to_bow[w] = None + else: + a_ = hist + (w,) + + assert len(a_) < self.ngram_order + assert a_ in self.counts[len(a_)].keys() + + a_counts_for_hist = self.counts[len(a_)][a_] + + sum_z1_f_a_z = 0 + for u in a_counts_for_hist.word_to_count.keys(): + sum_z1_f_a_z += a_counts_for_hist.word_to_f[u] + + sum_z1_f_z = 0 + _ = a_[1:] + _counts_for_hist = self.counts[len(_)][_] + for u in a_counts_for_hist.word_to_count.keys(): # Should be careful here: what is Z1 + sum_z1_f_z += _counts_for_hist.word_to_f[u] + + counts_for_hist.word_to_bow[w] = (1.0 - sum_z1_f_a_z) / (1.0 - sum_z1_f_z) + + def print_raw_counts(self, info_string): + # these are useful for debug. + print(info_string) + res = [] + for this_order_counts in self.counts: + for hist, counts_for_hist in this_order_counts.items(): + for w in counts_for_hist.word_to_count.keys(): + ngram = " ".join(hist) + " " + w + ngram = ngram.strip(strip_chars) + + res.append("{0}\t{1}".format(ngram, counts_for_hist.word_to_count[w])) + res.sort(reverse=True) + for r in res: + print(r) + + def print_modified_counts(self, info_string): + # these are useful for debug. + print(info_string) + res = [] + for this_order_counts in self.counts: + for hist, counts_for_hist in this_order_counts.items(): + for w in counts_for_hist.word_to_count.keys(): + ngram = " ".join(hist) + " " + w + ngram = ngram.strip(strip_chars) + + modified_count = len(counts_for_hist.word_to_context[w]) + raw_count = counts_for_hist.word_to_count[w] + + if modified_count == 0: + res.append("{0}\t{1}".format(ngram, raw_count)) + else: + res.append("{0}\t{1}".format(ngram, modified_count)) + res.sort(reverse=True) + for r in res: + print(r) + + def print_f(self, info_string): + # these are useful for debug. + print(info_string) + res = [] + for this_order_counts in self.counts: + for hist, counts_for_hist in this_order_counts.items(): + for w in counts_for_hist.word_to_count.keys(): + ngram = " ".join(hist) + " " + w + ngram = ngram.strip(strip_chars) + + f = counts_for_hist.word_to_f[w] + if f == 0: # f() is always 0 + f = 1e-99 + + res.append("{0}\t{1}".format(ngram, math.log(f, 10))) + res.sort(reverse=True) + for r in res: + print(r) + + def print_f_and_bow(self, info_string): + # these are useful for debug. + print(info_string) + res = [] + for this_order_counts in self.counts: + for hist, counts_for_hist in this_order_counts.items(): + for w in counts_for_hist.word_to_count.keys(): + ngram = " ".join(hist) + " " + w + ngram = ngram.strip(strip_chars) + + f = counts_for_hist.word_to_f[w] + if f == 0: # f() is always 0 + f = 1e-99 + + bow = counts_for_hist.word_to_bow[w] + if bow is None: + res.append("{1}\t{0}".format(ngram, math.log(f, 10))) + else: + res.append("{1}\t{0}\t{2}".format(ngram, math.log(f, 10), math.log(bow, 10))) + res.sort(reverse=True) + for r in res: + print(r) + + def print_as_arpa(self, fout=io.TextIOWrapper(sys.stdout.buffer, encoding='latin-1')): + # print as ARPA format. + + print('\\data\\', file=fout) + for hist_len in range(self.ngram_order): + # print the number of n-grams. + print('ngram {0}={1}'.format( + hist_len + 1, + sum([len(counts_for_hist.word_to_f) for counts_for_hist in self.counts[hist_len].values()])), + file=fout + ) + + print('', file=fout) + + for hist_len in range(self.ngram_order): + print('\\{0}-grams:'.format(hist_len + 1), file=fout) + + this_order_counts = self.counts[hist_len] + for hist, counts_for_hist in this_order_counts.items(): + for word in counts_for_hist.word_to_count.keys(): + ngram = hist + (word,) + prob = counts_for_hist.word_to_f[word] + bow = counts_for_hist.word_to_bow[word] + + if prob == 0: # f() is always 0 + prob = 1e-99 + + line = '{0}\t{1}'.format('%.7f' % math.log10(prob), ' '.join(ngram)) + if bow is not None: + line += '\t{0}'.format('%.7f' % math.log10(bow)) + print(line, file=fout) + print('', file=fout) + print('\\end\\', file=fout) + + +if __name__ == "__main__": + + ngram_counts = NgramCounts(args.ngram_order) + + if args.text is None: + ngram_counts.add_raw_counts_from_standard_input() + else: + assert os.path.isfile(args.text) + ngram_counts.add_raw_counts_from_file(args.text) + + ngram_counts.cal_discounting_constants() + ngram_counts.cal_f() + ngram_counts.cal_bow() + + if args.lm is None: + ngram_counts.print_as_arpa() + else: + with open(args.lm, 'w', encoding=default_encoding) as f: + ngram_counts.print_as_arpa(fout=f) diff --git a/test/test_bpe_mmi_graph_compiler.py b/test/test_bpe_mmi_graph_compiler.py new file mode 100644 index 0000000000..c6009d69b3 --- /dev/null +++ b/test/test_bpe_mmi_graph_compiler.py @@ -0,0 +1,30 @@ +#!/usr/bin/env python3 + +import copy +import logging +from pathlib import Path + +import k2 +import torch + +from icefall.bpe_mmi_graph_compiler import BpeMmiTrainingGraphCompiler + + +def test_bpe_mmi_graph_compiler(): + lang_dir = Path("data/lang_bpe") + if lang_dir.is_dir() is False: + return + device = torch.device("cpu") + compiler = BpeMmiTrainingGraphCompiler(lang_dir, device=device) + + texts = ["HELLO WORLD", "MMI TRAINING"] + + num_graphs, den_graphs = compiler.compile(texts) + num_graphs.labels_sym = compiler.lexicon.token_table + num_graphs.aux_labels_sym = copy.deepcopy(compiler.lexicon.token_table) + num_graphs.aux_labels_sym._id2sym[0] = "" + num_graphs[0].draw("num_graphs_0.svg", title="HELLO WORLD") + num_graphs[1].draw("num_graphs_1.svg", title="HELLO WORLD") + print(den_graphs.shape) + print(den_graphs[0].shape) + print(den_graphs[0].num_arcs) From 03242b33286215c7c7717aefe18e6e8b81a767eb Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Sat, 7 Aug 2021 18:10:41 +0800 Subject: [PATCH 08/19] Remove unused files. --- .../ASR/local/ngram_entropy_pruning.py | 627 ------------------ egs/librispeech/ASR/prepare.sh | 42 +- requirements.txt | 1 + 3 files changed, 3 insertions(+), 667 deletions(-) delete mode 100644 egs/librispeech/ASR/local/ngram_entropy_pruning.py diff --git a/egs/librispeech/ASR/local/ngram_entropy_pruning.py b/egs/librispeech/ASR/local/ngram_entropy_pruning.py deleted file mode 100644 index d0ffa92f6f..0000000000 --- a/egs/librispeech/ASR/local/ngram_entropy_pruning.py +++ /dev/null @@ -1,627 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- - -# Copyright 2021 Johns Hopkins University (Author: Ruizhe Huang) -# Apache 2.0. - -# This is an implementation of ``Entropy-based Pruning of Backoff Language Models'' -# in the same way as SRILM. - -################################################ -# Useful links/References: -################################################ -# https://github.com/BitSpeech/SRILM/blob/d571a4424fb0cf08b29fbfccfddd092ea969eae3/lm/src/NgramLM.cc#L2330 -# https://github.com/BitSpeech/SRILM/blob/d571a4424fb0cf08b29fbfccfddd092ea969eae3/lm/src/NgramLM.cc#L2124 -# https://github.com/BitSpeech/SRILM/blob/d571a4424fb0cf08b29fbfccfddd092ea969eae3/lm/src/LM.cc#L527 -# https://github.com/BitSpeech/SRILM/blob/d571a4424fb0cf08b29fbfccfddd092ea969eae3/flm/src/FNgramLM.cc#L2124 -# https://github.com/sfischer13/python-arpa - -################################################ -# How to use: -################################################ -# python3 ngram_entropy_pruning.py -threshold $threshold -lm $input_lm -write-lm $pruned_lm - -################################################ -# SRILM commands: -################################################ -# to_prune_lm=egs/swbd/s5c/data/local/lm/sw1.o3g.kn.gz -# vocab=egs/swbd/s5c/data/local/lm/wordlist -# order=3 -# oov_symbol="" -# threshold=4.7e-5 -# pruned_lm=temp.${threshold}.gz -# ngram -unk -map-unk "$oov_symbol" -vocab $vocab -order $order -prune ${threshold} -lm ${to_prune_lm} -write-lm ${pruned_lm} -# -# lm= -# ngram -unk -lm $lm -ppl heldout -# ngram -unk -lm $lm -ppl heldout -debug 3 - -import argparse -import logging -import math - -import gzip -from io import StringIO -from collections import OrderedDict -from collections import defaultdict -from enum import Enum, unique -import re - -parser = argparse.ArgumentParser(description=""" - Prune an n-gram language model based on the relative entropy - between the original and the pruned model, based on Andreas Stolcke's paper. - An n-gram entry is removed, if the removal causes (training set) perplexity - of the model to increase by less than threshold relative. - - The command takes an arpa file and a pruning threshold as input, - and outputs a pruned arpa file. - """) -parser.add_argument("-threshold", - type=float, - default=1e-6, - help="Order of n-gram") -parser.add_argument("-lm", - type=str, - default=None, - help="Path to the input arpa file") -parser.add_argument("-write-lm", - type=str, - default=None, - help="Path to output arpa file after pruning") -parser.add_argument("-minorder", - type=int, - default=1, - help="The minorder parameter limits pruning to " - "ngrams of that length and above.") -parser.add_argument("-encoding", - type=str, - default="utf-8", - help="Encoding of the arpa file") -parser.add_argument("-verbose", - type=int, - default=2, - choices=[0, 1, 2, 3, 4, 5], - help="Verbose level, where " - "0 is most noisy; " - "5 is most silent") -args = parser.parse_args() - -default_encoding = args.encoding -logging.basicConfig( - format= - "%(asctime)s — %(levelname)s — %(funcName)s:%(lineno)d — %(message)s", - level=args.verbose * 10) - - -class Context(dict): - """ - This class stores data for a context h. - It behaves like a python dict object, except that it has several - additional attributes. - """ - def __init__(self): - super().__init__() - self.log_bo = None - - -class Arpa: - """ - This is a class that implement the data structure of an APRA LM. - It (as well as some other classes) is modified based on the library - by Stefan Fischer: - https://github.com/sfischer13/python-arpa - """ - - UNK = '' - SOS = '' - EOS = '' - FLOAT_NDIGITS = 7 - base = 10 - - @staticmethod - def _check_input(my_input): - if not my_input: - raise ValueError - elif isinstance(my_input, tuple): - return my_input - elif isinstance(my_input, list): - return tuple(my_input) - elif isinstance(my_input, str): - return tuple(my_input.strip().split(' ')) - else: - raise ValueError - - @staticmethod - def _check_word(input_word): - if not isinstance(input_word, str): - raise ValueError - if ' ' in input_word: - raise ValueError - - def _replace_unks(self, words): - return tuple((w if w in self else self._unk) for w in words) - - def __init__(self, path=None, encoding=None, unk=None): - self._counts = OrderedDict() - self._ngrams = OrderedDict( - ) # Use self._ngrams[len(h)][h][w] for saving the entry of (h,w) - self._vocabulary = set() - if unk is None: - self._unk = self.UNK - - if path is not None: - self.loadf(path, encoding) - - def __contains__(self, ngram): - h = ngram[:-1] # h is a tuple - w = ngram[-1] # w is a string/word - return h in self._ngrams[len(h)] and w in self._ngrams[len(h)][h] - - def contains_word(self, word): - self._check_word(word) - return word in self._vocabulary - - def add_count(self, order, count): - self._counts[order] = count - self._ngrams[order - 1] = defaultdict(Context) - - def update_counts(self): - for order in range(1, self.order() + 1): - count = sum( - [len(wlist) for _, wlist in self._ngrams[order - 1].items()]) - if count > 0: - self._counts[order] = count - - def add_entry(self, ngram, p, bo=None, order=None): - # Note: ngram is a tuple of strings, e.g. ("w1", "w2", "w3") - h = ngram[:-1] # h is a tuple - w = ngram[-1] # w is a string/word - - # Note that p and bo here are in fact in the log domain (self.base = 10) - h_context = self._ngrams[len(h)][h] - h_context[w] = p - if bo is not None: - self._ngrams[len(ngram)][ngram].log_bo = bo - - for word in ngram: - self._vocabulary.add(word) - - def counts(self): - return sorted(self._counts.items()) - - def order(self): - return max(self._counts.keys(), default=None) - - def vocabulary(self, sort=True): - if sort: - return sorted(self._vocabulary) - else: - return self._vocabulary - - def _entries(self, order): - return (self._entry(h, w) - for h, wlist in self._ngrams[order - 1].items() for w in wlist) - - def _entry(self, h, w): - # return the entry for the ngram (h, w) - ngram = h + (w, ) - log_p = self._ngrams[len(h)][h][w] - log_bo = self._log_bo(ngram) - if log_bo is not None: - return round(log_p, self.FLOAT_NDIGITS), ngram, round( - log_bo, self.FLOAT_NDIGITS) - else: - return round(log_p, self.FLOAT_NDIGITS), ngram - - def _log_bo(self, ngram): - if len(ngram) in self._ngrams and ngram in self._ngrams[len(ngram)]: - return self._ngrams[len(ngram)][ngram].log_bo - else: - return None - - def _log_p(self, ngram): - h = ngram[:-1] # h is a tuple - w = ngram[-1] # w is a string/word - if h in self._ngrams[len(h)] and w in self._ngrams[len(h)][h]: - return self._ngrams[len(h)][h][w] - else: - return None - - def log_p_raw(self, ngram): - log_p = self._log_p(ngram) - if log_p is not None: - return log_p - else: - if len(ngram) == 1: - raise KeyError - else: - log_bo = self._log_bo(ngram[:-1]) - if log_bo is None: - log_bo = 0 - return log_bo + self.log_p_raw(ngram[1:]) - - def log_joint_prob(self, sequence): - # Compute the joint prob of the sequence based on the chain rule - # Note that sequence should be a tuple of strings - # - # Reference: - # https://github.com/BitSpeech/SRILM/blob/d571a4424fb0cf08b29fbfccfddd092ea969eae3/lm/src/LM.cc#L527 - - log_joint_p = 0 - seq = sequence - while len(seq) > 0: - log_joint_p += self.log_p_raw(seq) - seq = seq[:-1] - - # If we're computing the marginal probability of the unigram - # context we have to look up instead since the former - # has prob = 0. - if len(seq) == 1 and seq[0] == self.SOS: - seq = (self.EOS, ) - - return log_joint_p - - def set_new_context(self, h): - old_context = self._ngrams[len(h)][h] - self._ngrams[len(h)][h] = Context() - return old_context - - def log_p(self, ngram): - words = self._check_input(ngram) - if self._unk: - words = self._replace_unks(words) - return self.log_p_raw(words) - - def log_s(self, sentence, sos=SOS, eos=EOS): - words = self._check_input(sentence) - if self._unk: - words = self._replace_unks(words) - if sos: - words = (sos, ) + words - if eos: - words = words + (eos, ) - result = sum( - self.log_p_raw(words[:i]) for i in range(1, - len(words) + 1)) - if sos: - result = result - self.log_p_raw(words[:1]) - return result - - def p(self, ngram): - return self.base**self.log_p(ngram) - - def s(self, sentence): - return self.base**self.log_s(sentence) - - def write(self, fp): - fp.write('\n\\data\\\n') - for order, count in self.counts(): - fp.write('ngram {}={}\n'.format(order, count)) - fp.write('\n') - for order, _ in self.counts(): - fp.write('\\{}-grams:\n'.format(order)) - for e in self._entries(order): - prob = e[0] - ngram = ' '.join(e[1]) - if len(e) == 2: - fp.write('{}\t{}\n'.format(prob, ngram)) - elif len(e) == 3: - backoff = e[2] - fp.write('{}\t{}\t{}\n'.format(prob, ngram, backoff)) - else: - raise ValueError - fp.write('\n') - fp.write('\\end\\\n') - - -class ArpaParser: - """ - This is a class that implement a parser of an arpa file - """ - @unique - class State(Enum): - DATA = 1 - COUNT = 2 - HEADER = 3 - ENTRY = 4 - - re_count = re.compile(r'^ngram (\d+)=(\d+)$') - re_header = re.compile(r'^\\(\d+)-grams:$') - re_entry = re.compile('^(-?\\d+(\\.\\d+)?([eE]-?\\d+)?)' - '\t' - '(\\S+( \\S+)*)' - '(\t((-?\\d+(\\.\\d+)?)([eE]-?\\d+)?))?$') - - def _parse(self, fp): - self._result = [] - self._state = self.State.DATA - self._tmp_model = None - self._tmp_order = None - for line in fp: - line = line.strip() - if self._state == self.State.DATA: - self._data(line) - elif self._state == self.State.COUNT: - self._count(line) - elif self._state == self.State.HEADER: - self._header(line) - elif self._state == self.State.ENTRY: - self._entry(line) - if self._state != self.State.DATA: - raise Exception(line) - return self._result - - def _data(self, line): - if line == '\\data\\': - self._state = self.State.COUNT - self._tmp_model = Arpa() - else: - pass # skip comment line - - def _count(self, line): - match = self.re_count.match(line) - if match: - order = match.group(1) - count = match.group(2) - self._tmp_model.add_count(int(order), int(count)) - elif not line: - self._state = self.State.HEADER # there are no counts - else: - raise Exception(line) - - def _header(self, line): - match = self.re_header.match(line) - if match: - self._state = self.State.ENTRY - self._tmp_order = int(match.group(1)) - elif line == '\\end\\': - self._result.append(self._tmp_model) - self._state = self.State.DATA - self._tmp_model = None - self._tmp_order = None - elif not line: - pass # skip empty line - else: - raise Exception(line) - - def _entry(self, line): - match = self.re_entry.match(line) - if match: - p = self._float_or_int(match.group(1)) - ngram = tuple(match.group(4).split(' ')) - bo_match = match.group(7) - bo = self._float_or_int(bo_match) if bo_match else None - self._tmp_model.add_entry(ngram, p, bo, self._tmp_order) - elif not line: - self._state = self.State.HEADER # last entry - else: - raise Exception(line) - - @staticmethod - def _float_or_int(s): - f = float(s) - i = int(f) - if str(i) == s: # don't drop trailing ".0" - return i - else: - return f - - def load(self, fp): - """Deserialize fp (a file-like object) to a Python object.""" - return self._parse(fp) - - def loadf(self, path, encoding=None): - """Deserialize path (.arpa, .gz) to a Python object.""" - path = str(path) - if path.endswith('.gz'): - with gzip.open(path, mode='rt', encoding=encoding) as f: - return self.load(f) - else: - with open(path, mode='rt', encoding=encoding) as f: - return self.load(f) - - def loads(self, s): - """Deserialize s (a str) to a Python object.""" - with StringIO(s) as f: - return self.load(f) - - def dump(self, obj, fp): - """Serialize obj to fp (a file-like object) in ARPA format.""" - obj.write(fp) - - def dumpf(self, obj, path, encoding=None): - """Serialize obj to path in ARPA format (.arpa, .gz).""" - path = str(path) - if path.endswith('.gz'): - with gzip.open(path, mode='wt', encoding=encoding) as f: - return self.dump(obj, f) - else: - with open(path, mode='wt', encoding=encoding) as f: - self.dump(obj, f) - - def dumps(self, obj): - """Serialize obj to an ARPA formatted str.""" - with StringIO() as f: - self.dump(obj, f) - return f.getvalue() - - -def add_log_p(prev_log_sum, log_p, base): - return math.log(base**log_p + base**prev_log_sum, base) - - -def compute_numerator_denominator(lm, h): - log_sum_seen_h = -math.inf - log_sum_seen_h_lower = -math.inf - base = lm.base - for w, log_p in lm._ngrams[len(h)][h].items(): - log_sum_seen_h = add_log_p(log_sum_seen_h, log_p, base) - - ngram = h + (w, ) - log_p_lower = lm.log_p_raw(ngram[1:]) - log_sum_seen_h_lower = add_log_p(log_sum_seen_h_lower, log_p_lower, - base) - - numerator = 1.0 - base**log_sum_seen_h - denominator = 1.0 - base**log_sum_seen_h_lower - return numerator, denominator - - -def prune(lm, threshold, minorder): - # Reference: - # https://github.com/BitSpeech/SRILM/blob/d571a4424fb0cf08b29fbfccfddd092ea969eae3/lm/src/NgramLM.cc#L2330 - - for i in range(lm.order(), max(minorder - 1, 1), - -1): # i is the order of the ngram (h, w) - logging.info("processing %d-grams ..." % i) - count_pruned_ngrams = 0 - - h_dict = lm._ngrams[i - 1] - for h in list(h_dict.keys()): - # old backoff weight, BOW(h) - log_bow = lm._log_bo(h) - if log_bow is None: - log_bow = 0 - - # Compute numerator and denominator of the backoff weight, - # so that we can quickly compute the BOW adjustment due to - # leaving out one prob. - numerator, denominator = compute_numerator_denominator(lm, h) - - # assert abs(math.log(numerator, lm.base) - math.log(denominator, lm.base) - h_dict[h].log_bo) < 1e-5 - - # Compute the marginal probability of the context, P(h) - h_log_p = lm.log_joint_prob(h) - - all_pruned = True - pruned_w_set = set() - - for w, log_p in h_dict[h].items(): - ngram = h + (w, ) - - # lower-order estimate for ngramProb, P(w|h') - backoff_prob = lm.log_p_raw(ngram[1:]) - - # Compute BOW after removing ngram, BOW'(h) - new_log_bow = math.log(numerator + lm.base ** log_p, lm.base) - \ - math.log(denominator + lm.base ** backoff_prob, lm.base) - - # Compute change in entropy due to removal of ngram - delta_prob = backoff_prob + new_log_bow - log_p - delta_entropy = - (lm.base ** h_log_p) * \ - ((lm.base ** log_p) * delta_prob + - numerator * (new_log_bow - log_bow)) - - # compute relative change in model (training set) perplexity - perp_change = lm.base**delta_entropy - 1.0 - - pruned = threshold > 0 and perp_change < threshold - - # Make sure we don't prune ngrams whose backoff nodes are needed - if pruned and \ - len(ngram) in lm._ngrams and \ - len(lm._ngrams[len(ngram)][ngram]) > 0: - pruned = False - - logging.debug("CONTEXT " + str(h) + " WORD " + w + - " CONTEXTPROB %f " % h_log_p + - " OLDPROB %f " % log_p + " NEWPROB %f " % - (backoff_prob + new_log_bow) + - " DELTA-H %f " % delta_entropy + - " DELTA-LOGP %f " % delta_prob + - " PPL-CHANGE %f " % perp_change + " PRUNED " + - str(pruned)) - - if pruned: - pruned_w_set.add(w) - count_pruned_ngrams += 1 - else: - all_pruned = False - - # If we removed all ngrams for this context we can - # remove the context itself, but only if the present - # context is not a prefix to a longer one. - if all_pruned and len(pruned_w_set) == len(h_dict[h]): - del h_dict[ - h] # this context h is no longer needed, as its ngram prob is stored at its own context h' - elif len(pruned_w_set) > 0: - # The pruning for this context h is actually done here - old_context = lm.set_new_context(h) - - for w, p_w in old_context.items(): - if w not in pruned_w_set: - lm.add_entry( - h + (w, ), - p_w) # the entry hw is stored at the context h - - # We need to recompute the back-off weight, but - # this can only be done after completing the pruning - # of the lower-order ngrams. - # Reference: - # https://github.com/BitSpeech/SRILM/blob/d571a4424fb0cf08b29fbfccfddd092ea969eae3/flm/src/FNgramLM.cc#L2124 - - logging.info("pruned %d %d-grams" % (count_pruned_ngrams, i)) - - # recompute backoff weights - for i in range(max(minorder - 1, 1) + 1, - lm.order() + - 1): # be careful of this order: from low- to high-order - for h in lm._ngrams[i - 1]: - numerator, denominator = compute_numerator_denominator(lm, h) - new_log_bow = math.log(numerator, lm.base) - math.log( - denominator, lm.base) - lm._ngrams[len(h)][h].log_bo = new_log_bow - - # update counts - lm.update_counts() - - return - - -def check_h_is_valid(lm, h): - sum_under_h = sum( - [lm.base**lm.log_p_raw(h + (w, )) for w in lm.vocabulary(sort=False)]) - if abs(sum_under_h - 1.0) > 1e-6: - logging.info("warning: %s %f" % (str(h), sum_under_h)) - return False - else: - return True - - -def validate_lm(lm): - # sanity check if the conditional probability sums to one under each context h - for i in range(lm.order(), 0, -1): # i is the order of the ngram (h, w) - logging.info("validating %d-grams ..." % i) - h_dict = lm._ngrams[i - 1] - for h in h_dict.keys(): - check_h_is_valid(lm, h) - - -def compare_two_apras(path1, path2): - pass - - -if __name__ == '__main__': - # load an arpa file - logging.info("Loading the arpa file from %s" % args.lm) - parser = ArpaParser() - models = parser.loadf(args.lm, encoding=default_encoding) - lm = models[0] # ARPA files may contain several models. - logging.info("Stats before pruning:") - for i, cnt in lm.counts(): - logging.info("ngram %d=%d" % (i, cnt)) - - # prune it, the language model will be modified in-place - logging.info("Start pruning the model with threshold=%.3E..." % - args.threshold) - prune(lm, args.threshold, args.minorder) - - # validate_lm(lm) - - # write the arpa language model to a file - logging.info("Stats after pruning:") - for i, cnt in lm.counts(): - logging.info("ngram %d=%d" % (i, cnt)) - logging.info("Saving the pruned arpa file to %s" % args.write_lm) - parser.dumpf(lm, args.write_lm, encoding=default_encoding) - logging.info("Done.") diff --git a/egs/librispeech/ASR/prepare.sh b/egs/librispeech/ASR/prepare.sh index 375da0d797..c8e0931777 100755 --- a/egs/librispeech/ASR/prepare.sh +++ b/egs/librispeech/ASR/prepare.sh @@ -159,51 +159,13 @@ if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then -lm data/lang_bpe/P.arpa fi - # TODO: Use egs/wsj/s5/utils/lang/ngram_entropy_pruning.py - # from kaldi to prune P if it causes OOM later - - if [ ! -f data/lang_bpe/P-no-prune.fst.txt ]; then + if [ ! -f data/lang_bpe/P.fst.txt ]; then python3 -m kaldilm \ --read-symbol-table="data/lang_bpe/tokens.txt" \ --disambig-symbol='#0' \ --max-order=2 \ - data/lang_bpe/P.arpa > data/lang_bpe/P-no-prune.fst.txt + data/lang_bpe/P.arpa > data/lang_bpe/P.fst.txt fi - - thresholds=( - 1e-6 - 1e-7 - ) - for threshold in ${thresholds[@]}; do - if [ ! -f data/lang_bpe/P-pruned.${threshold}.arpa ]; then - python3 ./local/ngram_entropy_pruning.py \ - -threshold $threshold \ - -lm data/lang_bpe/P.arpa \ - -write-lm data/lang_bpe/P-pruned.${threshold}.arpa - fi - - if [ ! -f data/lang_bpe/P-pruned.${threshold}.fst.txt ]; then - python3 -m kaldilm \ - --read-symbol-table="data/lang_bpe/tokens.txt" \ - --disambig-symbol='#0' \ - --max-order=2 \ - data/lang_bpe/P-pruned.${threshold}.arpa > data/lang_bpe/P-pruned.${threshold}.fst.txt - fi - done - - if [ ! -f data/lang_bpe/P-uni.fst.txt ]; then - python3 -m kaldilm \ - --read-symbol-table="data/lang_bpe/tokens.txt" \ - --disambig-symbol='#0' \ - --max-order=1 \ - data/lang_bpe/P.arpa > data/lang_bpe/P-uni.fst.txt - fi - - ( cd data/lang_bpe; - # ln -sfv P-pruned.1e-6.fst.txt P.fst.txt - ln -sfv P-no-prune.fst.txt P.fst.txt - ) - rm -fv data/lang_bpe/P.pt data/lang_bpe/ctc_topo_P.pt fi if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then diff --git a/requirements.txt b/requirements.txt index a54edf118d..710048fede 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ kaldilm kaldialign sentencepiece>=0.1.96 +tensorboard From 56319b090350299b3d499c81c35a00a89de146ad Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Mon, 16 Aug 2021 17:03:05 +0800 Subject: [PATCH 09/19] Minor fixes. --- egs/librispeech/ASR/conformer_mmi/train.py | 102 +++++++++++++-- egs/librispeech/ASR/local/compile_hlg.py | 43 ++++--- egs/librispeech/ASR/local/prepare_lang_bpe.py | 35 +++--- egs/librispeech/ASR/local/train_bpe_model.py | 40 ++++-- egs/librispeech/ASR/prepare.sh | 116 +++++++++++------- 5 files changed, 240 insertions(+), 96 deletions(-) diff --git a/egs/librispeech/ASR/conformer_mmi/train.py b/egs/librispeech/ASR/conformer_mmi/train.py index 810a0a4dfe..f11291bbf8 100755 --- a/egs/librispeech/ASR/conformer_mmi/train.py +++ b/egs/librispeech/ASR/conformer_mmi/train.py @@ -13,7 +13,9 @@ import torch.nn as nn from conformer import Conformer from lhotse.utils import fix_random_seed +from tdnn_lstm_ctc.model import TdnnLstm from torch.nn.parallel import DistributedDataParallel as DDP +from torch.nn.utils import clip_grad_norm_ from torch.utils.tensorboard import SummaryWriter from transformer import Noam @@ -58,6 +60,26 @@ def get_parser(): help="Should various information be logged in tensorboard.", ) + parser.add_argument( + "--use-ali-model", + type=str2bool, + default=True, + help="If true, we assume that you have run tdnn_lstm_ctc/train_bpe.py " + "and you have some checkpoints inside the directory " + "tdnn_lstm_ctc/exp_bpe_500 ." + "It will use tdnn_lstm_ctc/exp_bpe_500/epoch-{ali-model-epoch}.pt " + "as the pre-trained alignment model", + ) + parser.add_argument( + "--ali-model-epoch", + type=int, + default=19, + help="If --use-ali-model is True, load " + "tdnn_lstm_ctc/exp_bpe_500/epoch-{ali-model-epoch}.pt as " + "the alignment model." + "Used only if --use-ali-model is True.", + ) + # TODO: add extra arguments and support DDP training. # Currently, only single GPU training is implemented. Will add # DDP training once single GPU training is finished. @@ -117,24 +139,21 @@ def get_params() -> AttributeDict: """ params = AttributeDict( { - "exp_dir": Path("conformer_mmi/exp"), - "lang_dir": Path("data/lang_bpe"), + "exp_dir": Path("conformer_mmi/exp_500"), + "lang_dir": Path("data/lang_bpe_500"), "feature_dim": 80, "weight_decay": 1e-6, "subsampling_factor": 4, "start_epoch": 0, - "num_epochs": 10, + "num_epochs": 50, "best_train_loss": float("inf"), "best_valid_loss": float("inf"), "best_train_epoch": -1, "best_valid_epoch": -1, "batch_idx_train": 0, "log_interval": 10, - # It takes about 10 minutes (1 GPU, max_duration=200) - # to run a validation process. - # For the 100 h subset, there are 85617 batches. - # For the 960 h dataset, there are 843723 batches - "valid_interval": 8000, + "reset_interval": 200, + "valid_interval": 10, "use_pruned_intersect": False, "den_scale": 1.0, # @@ -242,6 +261,7 @@ def save_checkpoint( def compute_loss( params: AttributeDict, model: nn.Module, + ali_model: Optional[nn.Module], batch: dict, graph_compiler: BpeMmiTrainingGraphCompiler, is_training: bool, @@ -274,6 +294,22 @@ def compute_loss( with torch.set_grad_enabled(is_training): nnet_output, encoder_memory, memory_mask = model(feature, supervisions) # nnet_output is [N, T, C] + if ali_model is not None and params.batch_idx_train < 4000: + feature = feature.permute(0, 2, 1) # [N, T, C]->[N, C, T] + ali_model_output = ali_model(feature) + # subsampling is done slightly differently, may be small length + # differences. + min_len = min(ali_model_output.shape[1], nnet_output.shape[1]) + # scale less than one so it will be encouraged + # to mimic ali_model's output + ali_model_scale = 500.0 / (params.batch_idx_train + 500) + + # Use clone() here or log-softmax backprop will fail. + nnet_output = nnet_output.clone() + + nnet_output[:, :min_len, :] += ( + ali_model_scale * ali_model_output[:, :min_len, :] + ) # NOTE: We need `encode_supervisions` to sort sequences with # different duration in decreasing order, required by @@ -337,6 +373,7 @@ def compute_loss( def compute_validation_loss( params: AttributeDict, model: nn.Module, + ali_model: Optional[nn.Module], graph_compiler: BpeMmiTrainingGraphCompiler, valid_dl: torch.utils.data.DataLoader, world_size: int = 1, @@ -354,6 +391,7 @@ def compute_validation_loss( loss, mmi_loss, att_loss = compute_loss( params=params, model=model, + ali_model=ali_model, batch=batch, graph_compiler=graph_compiler, is_training=False, @@ -394,6 +432,7 @@ def compute_validation_loss( def train_one_epoch( params: AttributeDict, model: nn.Module, + ali_model: Optional[nn.Module], optimizer: torch.optim.Optimizer, graph_compiler: BpeMmiTrainingGraphCompiler, train_dl: torch.utils.data.DataLoader, @@ -412,6 +451,9 @@ def train_one_epoch( It is returned by :func:`get_params`. model: The model for training. + ali_model: + The force alignment model for training. It is from + tdnn_lstm_ctc/train_bpe.py optimizer: The optimizer we are using. graph_compiler: @@ -432,7 +474,8 @@ def train_one_epoch( tot_att_loss = 0.0 tot_frames = 0.0 # sum of frames over all batches - + params.tot_loss = 0.0 + params.tot_frames = 0.0 for batch_idx, batch in enumerate(train_dl): params.batch_idx_train += 1 batch_size = len(batch["supervisions"]["text"]) @@ -440,6 +483,7 @@ def train_one_epoch( loss, mmi_loss, att_loss = compute_loss( params=params, model=model, + ali_model=ali_model, batch=batch, graph_compiler=graph_compiler, is_training=True, @@ -450,6 +494,7 @@ def train_one_epoch( optimizer.zero_grad() loss.backward() + clip_grad_norm_(model.parameters(), max_norm=5.0, norm_type=2.0) optimizer.step() loss_cpu = loss.detach().cpu().item() @@ -461,6 +506,9 @@ def train_one_epoch( tot_mmi_loss += mmi_loss_cpu tot_att_loss += att_loss_cpu + params.tot_frames += params.train_frames + params.tot_loss += loss_cpu + tot_avg_loss = tot_loss / tot_frames tot_avg_mmi_loss = tot_mmi_loss / tot_frames tot_avg_att_loss = tot_att_loss / tot_frames @@ -509,11 +557,18 @@ def train_one_epoch( tot_avg_loss, params.batch_idx_train, ) + if batch_idx > 0 and batch_idx % params.reset_interval == 0: + tot_loss = 0.0 # sum of losses over all batches + tot_mmi_loss = 0.0 + tot_att_loss = 0.0 + + tot_frames = 0.0 # sum of frames over all batches if batch_idx > 0 and batch_idx % params.valid_interval == 0: compute_validation_loss( params=params, model=model, + ali_model=ali_model, graph_compiler=graph_compiler, valid_dl=valid_dl, world_size=world_size, @@ -544,7 +599,7 @@ def train_one_epoch( params.batch_idx_train, ) - params.train_loss = tot_loss / tot_frames + params.train_loss = params.tot_loss / params.tot_frames if params.train_loss < params.best_train_loss: params.best_train_epoch = params.cur_epoch @@ -624,6 +679,32 @@ def run(rank, world_size, args): if checkpoints: optimizer.load_state_dict(checkpoints["optimizer"]) + if args.use_ali_model: + ali_model = TdnnLstm( + num_features=params.feature_dim, + num_classes=num_classes, + subsampling_factor=params.subsampling_factor, + ) + + ali_model_fname = Path( + f"tdnn_lstm_ctc/exp_bpe_500/epoch-{args.ali_model_epoch}.pt" + ) + assert ( + ali_model_fname.is_file() + ), f"ali model filename {ali_model_fname} does not exist!" + + ali_model.load_state_dict( + torch.load(ali_model_fname, map_location="cpu")["model"] + ) + ali_model.to(device) + + ali_model.eval() + ali_model.requires_grad_(False) + logging.info(f"Use ali_model: {ali_model_fname}") + else: + ali_model = None + logging.info("No ali_model") + librispeech = LibriSpeechAsrDataModule(args) train_dl = librispeech.train_dataloaders() valid_dl = librispeech.valid_dataloaders() @@ -646,6 +727,7 @@ def run(rank, world_size, args): train_one_epoch( params=params, model=model, + ali_model=ali_model, optimizer=optimizer, graph_compiler=graph_compiler, train_dl=train_dl, diff --git a/egs/librispeech/ASR/local/compile_hlg.py b/egs/librispeech/ASR/local/compile_hlg.py index b304021616..9f28bb74d6 100755 --- a/egs/librispeech/ASR/local/compile_hlg.py +++ b/egs/librispeech/ASR/local/compile_hlg.py @@ -1,18 +1,18 @@ #!/usr/bin/env python3 """ -This script compiles HLG from +This script takes as input lang_dir and generates HLG from - - H, the ctc topology, built from tokens contained in lexicon.txt - - L, the lexicon, built from L_disambig.pt + - H, the ctc topology, built from tokens contained in lang_dir/lexicon.txt + - L, the lexicon, built from lang_dir/L_disambig.pt Caution: We use a lexicon that contains disambiguation symbols - G, the LM, built from data/lm/G_3_gram.fst.txt -The generated HLG is saved in data/lm/HLG.pt (phone based) -or data/lm/HLG_bpe.pt (BPE based) +The generated HLG is saved in $lang_dir/HLG.pt """ +import argparse import logging from pathlib import Path @@ -22,11 +22,23 @@ from icefall.lexicon import Lexicon +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--lang-dir", + type=str, + help="""Input and output directory. + """, + ) + + return parser.parse_args() + + def compile_HLG(lang_dir: str) -> k2.Fsa: """ Args: lang_dir: - The language directory, e.g., data/lang_phone or data/lang_bpe. + The language directory, e.g., data/lang_phone or data/lang_bpe_5000. Return: An FSA representing HLG. @@ -104,17 +116,18 @@ def compile_HLG(lang_dir: str) -> k2.Fsa: def main(): - for d in ["data/lang_phone", "data/lang_bpe"]: - d = Path(d) - logging.info(f"Processing {d}") + args = get_args() + lang_dir = Path(args.lang_dir) + + if (lang_dir / "HLG.pt").is_file(): + logging.info(f"{lang_dir}/HLG.pt already exists - skipping") + return - if (d / "HLG.pt").is_file(): - logging.info(f"{d}/HLG.pt already exists - skipping") - continue + logging.info(f"Processing {lang_dir}") - HLG = compile_HLG(d) - logging.info(f"Saving HLG.pt to {d}") - torch.save(HLG.as_dict(), f"{d}/HLG.pt") + HLG = compile_HLG(lang_dir) + logging.info(f"Saving HLG.pt to {lang_dir}") + torch.save(HLG.as_dict(), f"{lang_dir}/HLG.pt") if __name__ == "__main__": diff --git a/egs/librispeech/ASR/local/prepare_lang_bpe.py b/egs/librispeech/ASR/local/prepare_lang_bpe.py index e31220d9b2..68b8db9667 100755 --- a/egs/librispeech/ASR/local/prepare_lang_bpe.py +++ b/egs/librispeech/ASR/local/prepare_lang_bpe.py @@ -3,12 +3,13 @@ # Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang) """ -This script takes as inputs the following two files: - - data/lang_bpe/bpe.model, - - data/lang_bpe/words.txt +This script takes as input `lang_dir`, which should contain:: -and generates the following files in the directory data/lang_bpe: + - lang_dir/bpe.model, + - lang_dir/words.txt + +and generates the following files in the directory `lang_dir`: - lexicon.txt - lexicon_disambig.txt @@ -17,6 +18,7 @@ - tokens.txt """ +import argparse from pathlib import Path from typing import Dict, List, Tuple @@ -141,8 +143,22 @@ def generate_lexicon( return lexicon, token2id +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--lang-dir", + type=str, + help="""Input and output directory. + It should contain the bpe.model and words.txt + """, + ) + + return parser.parse_args() + + def main(): - lang_dir = Path("data/lang_bpe") + args = get_args() + lang_dir = Path(args.lang_dir) model_file = lang_dir / "bpe.model" word_sym_table = k2.SymbolTable.from_file(lang_dir / "words.txt") @@ -189,15 +205,6 @@ def main(): torch.save(L.as_dict(), lang_dir / "L.pt") torch.save(L_disambig.as_dict(), lang_dir / "L_disambig.pt") - if False: - # Just for debugging, will remove it - L.labels_sym = k2.SymbolTable.from_file(lang_dir / "tokens.txt") - L.aux_labels_sym = k2.SymbolTable.from_file(lang_dir / "words.txt") - L_disambig.labels_sym = L.labels_sym - L_disambig.aux_labels_sym = L.aux_labels_sym - L.draw(lang_dir / "L.svg", title="L") - L_disambig.draw(lang_dir / "L_disambig.svg", title="L_disambig") - if __name__ == "__main__": main() diff --git a/egs/librispeech/ASR/local/train_bpe_model.py b/egs/librispeech/ASR/local/train_bpe_model.py index 59746ad9a6..9872a7c6ac 100755 --- a/egs/librispeech/ASR/local/train_bpe_model.py +++ b/egs/librispeech/ASR/local/train_bpe_model.py @@ -1,10 +1,5 @@ #!/usr/bin/env python3 -""" -This script takes as input "data/lang/bpe/train.txt" -and generates "data/lang/bpe/bep.model". -""" - # You can install sentencepiece via: # # pip install sentencepiece @@ -14,17 +9,41 @@ # # Please install a version >=0.1.96 +import argparse import shutil from pathlib import Path import sentencepiece as spm +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--lang-dir", + type=str, + help="""Input and output directory. + It should contain the training corpus: train.txt. + The generated bpe.model is saved to this directory. + """, + ) + parser.add_argument( + "--vocab-size", + type=int, + help="Vocabulary size for BPE training", + ) + + return parser.parse_args() + + def main(): + args = get_args() + vocab_size = args.vocab_size + lang_dir = Path(args.lang_dir) + model_type = "unigram" - vocab_size = 5000 - model_prefix = f"data/lang_bpe/{model_type}_{vocab_size}" - train_text = "data/lang_bpe/train.txt" + + model_prefix = f"{lang_dir}/{model_type}_{vocab_size}" + train_text = f"{lang_dir}/train.txt" character_coverage = 1.0 input_sentence_size = 100000000 @@ -49,10 +68,7 @@ def main(): eos_id=-1, ) - sp = spm.SentencePieceProcessor(model_file=str(model_file)) - vocab_size = sp.vocab_size() - - shutil.copyfile(model_file, "data/lang_bpe/bpe.model") + shutil.copyfile(model_file, f"{lang_dir}/bpe.model") if __name__ == "__main__": diff --git a/egs/librispeech/ASR/prepare.sh b/egs/librispeech/ASR/prepare.sh index c8e0931777..6479973bfb 100755 --- a/egs/librispeech/ASR/prepare.sh +++ b/egs/librispeech/ASR/prepare.sh @@ -36,8 +36,17 @@ dl_dir=$PWD/download . shared/parse_options.sh || exit 1 - -# All generated files by this script are saved in "data" +# vocab size for sentence piece models. +# It will generate data/lang_bpe_500, data/lang_bpe_1000, +# and data/lang_bpe_5000. +vocab_sizes=( + 500 + 1000 + 5000 +) + +# All files generated by this script are saved in "data". +# You can safely remove "data" and rerun this script to regenerate it. mkdir -p data log() { @@ -116,56 +125,68 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then fi fi + if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then log "State 6: Prepare BPE based lang" - mkdir -p data/lang_bpe - # We reuse words.txt from phone based lexicon - # so that the two can share G.pt later. - cp data/lang_phone/words.txt data/lang_bpe/ - - if [ ! -f data/lang_bpe/train.txt ]; then - log "Generate data for BPE training" - files=$( - find "data/LibriSpeech/train-clean-100" -name "*.trans.txt" - find "data/LibriSpeech/train-clean-360" -name "*.trans.txt" - find "data/LibriSpeech/train-other-500" -name "*.trans.txt" - ) - for f in ${files[@]}; do - cat $f | cut -d " " -f 2- - done > data/lang_bpe/train.txt - fi - python3 ./local/train_bpe_model.py - - if [ ! -f data/lang_bpe/L_disambig.pt ]; then - ./local/prepare_lang_bpe.py - fi + for vocab_size in ${vocab_sizes[@]}; do + lang_dir=data/lang_bpe_${vocab_size} + mkdir -p $lang_dir + # We reuse words.txt from phone based lexicon + # so that the two can share G.pt later. + cp data/lang_phone/words.txt $lang_dir + + if [ ! -f $lang_dir/train.txt ]; then + log "Generate data for BPE training" + files=$( + find "data/LibriSpeech/train-clean-100" -name "*.trans.txt" + find "data/LibriSpeech/train-clean-360" -name "*.trans.txt" + find "data/LibriSpeech/train-other-500" -name "*.trans.txt" + ) + for f in ${files[@]}; do + cat $f | cut -d " " -f 2- + done > $lang_dir/train.txt + fi + + ./local/train_bpe_model.py \ + --lang-dir $lang_dir \ + --vocab-size $vocab_size + + if [ ! -f $lang_dir/L_disambig.pt ]; then + ./local/prepare_lang_bpe.py --lang-dir $lang_dir + fi + done fi if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then log "Stage 7: Prepare bigram P" - if [ ! -f data/lang_bpe/corpus.txt ]; then - ./local/convert_transcript_to_corpus.py \ - --lexicon data/lang_bpe/lexicon.txt \ - --transcript data/lang_bpe/train.txt \ - --oov "" \ - > data/lang_bpe/corpus.txt - fi - if [ ! -f data/lang_bpe/P.arpa ]; then - ./shared/make_kn_lm.py \ - -ngram-order 2 \ - -text data/lang_bpe/corpus.txt \ - -lm data/lang_bpe/P.arpa - fi - - if [ ! -f data/lang_bpe/P.fst.txt ]; then - python3 -m kaldilm \ - --read-symbol-table="data/lang_bpe/tokens.txt" \ - --disambig-symbol='#0' \ - --max-order=2 \ - data/lang_bpe/P.arpa > data/lang_bpe/P.fst.txt - fi + for vocab_size in ${vocab_sizes[@]}; do + lang_dir=data/lang_bpe_${vocab_size} + + if [ ! -f $lang_dir/corpus.txt ]; then + ./local/convert_transcript_to_corpus.py \ + --lexicon data/lang_bpe/lexicon.txt \ + --transcript data/lang_bpe/train.txt \ + --oov "" \ + > $lang_dir/corpus.txt + fi + + if [ ! -f $lang_dir/P.arpa ]; then + ./shared/make_kn_lm.py \ + -ngram-order 2 \ + -text $lang_dir/corpus.txt \ + -lm $lang_dir/P.arpa + fi + + if [ ! -f $lang_dir/P.fst.txt ]; then + python3 -m kaldilm \ + --read-symbol-table="$lang_dir/tokens.txt" \ + --disambig-symbol='#0' \ + --max-order=2 \ + $lang_dir/P.arpa > $lang_dir/P.fst.txt + fi + done fi if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then @@ -195,5 +216,10 @@ fi if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then log "Stage 9: Compile HLG" - python3 ./local/compile_hlg.py + ./local/compile_hlg.py --lang-dir data/lang_phone + + for vocab_size in ${vocab_sizes[@]}; do + lang_dir=data/lang_bpe_${vocab_size} + ./local/compile_hlg.py --lang-dir $lang_dir + done fi From 6f5d63492a32f4e48b60ae6d71dc529ddefef22d Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Fri, 24 Sep 2021 19:45:08 +0800 Subject: [PATCH 10/19] Refactoring. --- .../ASR/conformer_ctc/transformer.py | 6 +- .../ASR/conformer_mmi/asr_datamodule.py | 354 ++++++++++++++++++ .../ASR/conformer_mmi/conformer.py | 48 ++- egs/librispeech/ASR/conformer_mmi/decode.py | 222 ++++++++--- .../ASR/conformer_mmi/transformer.py | 84 +++-- ... => convert_transcript_words_to_tokens.py} | 21 +- .../ASR/local/generate_unique_lexicon.py | 100 +++++ egs/librispeech/ASR/local/prepare_lang.py | 64 +++- egs/librispeech/ASR/local/prepare_lang_bpe.py | 28 ++ egs/librispeech/ASR/local/train_bpe_model.py | 12 +- egs/librispeech/ASR/prepare.sh | 36 +- icefall/bpe_graph_compiler.py | 6 +- icefall/bpe_mmi_graph_compiler.py | 178 --------- icefall/lexicon.py | 158 +++++--- icefall/mmi_graph_compiler.py | 216 +++++++++++ icefall/utils.py | 89 ++++- test/test_bpe_graph_compiler.py | 9 +- test/test_bpe_mmi_graph_compiler.py | 30 -- test/test_lexicon.py | 177 ++++++--- test/test_mmi_graph_compiler.py | 196 ++++++++++ 20 files changed, 1545 insertions(+), 489 deletions(-) create mode 100644 egs/librispeech/ASR/conformer_mmi/asr_datamodule.py rename egs/librispeech/ASR/local/{convert_transcript_to_corpus.py => convert_transcript_words_to_tokens.py} (83%) create mode 100755 egs/librispeech/ASR/local/generate_unique_lexicon.py delete mode 100644 icefall/bpe_mmi_graph_compiler.py create mode 100644 icefall/mmi_graph_compiler.py delete mode 100644 test/test_bpe_mmi_graph_compiler.py mode change 100644 => 100755 test/test_lexicon.py create mode 100755 test/test_mmi_graph_compiler.py diff --git a/egs/librispeech/ASR/conformer_ctc/transformer.py b/egs/librispeech/ASR/conformer_ctc/transformer.py index f1d7cbbbc7..68a4ff65cb 100644 --- a/egs/librispeech/ASR/conformer_ctc/transformer.py +++ b/egs/librispeech/ASR/conformer_ctc/transformer.py @@ -114,7 +114,10 @@ def __init__( norm=encoder_norm, ) - self.encoder_output_layer = nn.Linear(d_model, num_classes) + # TODO(fangjun): remove dropout + self.encoder_output_layer = nn.Sequential( + nn.Dropout(p=dropout), nn.Linear(d_model, num_classes) + ) if num_decoder_layers > 0: self.decoder_num_class = ( @@ -325,6 +328,7 @@ def decoder_nll( """ # The common part between this function and decoder_forward could be # extracted as a separate function. + ys_in = add_sos(token_ids, sos_id=sos_id) ys_in = [torch.tensor(y) for y in ys_in] ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=eos_id) diff --git a/egs/librispeech/ASR/conformer_mmi/asr_datamodule.py b/egs/librispeech/ASR/conformer_mmi/asr_datamodule.py new file mode 100644 index 0000000000..8290e71d13 --- /dev/null +++ b/egs/librispeech/ASR/conformer_mmi/asr_datamodule.py @@ -0,0 +1,354 @@ +# Copyright 2021 Piotr Żelasko +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import logging +from functools import lru_cache +from pathlib import Path +from typing import List, Union + +from lhotse import CutSet, Fbank, FbankConfig, load_manifest +from lhotse.dataset import ( + BucketingSampler, + CutConcatenate, + CutMix, + K2SpeechRecognitionDataset, + PrecomputedFeatures, + SingleCutSampler, + SpecAugment, +) +from lhotse.dataset.input_strategies import OnTheFlyFeatures +from torch.utils.data import DataLoader + +from icefall.dataset.datamodule import DataModule +from icefall.utils import str2bool + + +class LibriSpeechAsrDataModule(DataModule): + """ + DataModule for k2 ASR experiments. + It assumes there is always one train and valid dataloader, + but there can be multiple test dataloaders (e.g. LibriSpeech test-clean + and test-other). + + It contains all the common data pipeline modules used in ASR + experiments, e.g.: + - dynamic batch size, + - bucketing samplers, + - cut concatenation, + - augmentation, + - on-the-fly feature extraction + + This class should be derived for specific corpora used in ASR tasks. + """ + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + super().add_arguments(parser) + group = parser.add_argument_group( + title="ASR data related options", + description="These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc.", + ) + group.add_argument( + "--full-libri", + type=str2bool, + default=True, + help="When enabled, use 960h LibriSpeech. " + "Otherwise, use 100h subset.", + ) + group.add_argument( + "--feature-dir", + type=Path, + default=Path("data/fbank"), + help="Path to directory with train/valid/test cuts.", + ) + group.add_argument( + "--max-duration", + type=int, + default=200.0, + help="Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM.", + ) + group.add_argument( + "--bucketing-sampler", + type=str2bool, + default=True, + help="When enabled, the batches will come from buckets of " + "similar duration (saves padding frames).", + ) + group.add_argument( + "--num-buckets", + type=int, + default=30, + help="The number of buckets for the BucketingSampler" + "(you might want to increase it for larger datasets).", + ) + group.add_argument( + "--concatenate-cuts", + type=str2bool, + default=False, + help="When enabled, utterances (cuts) will be concatenated " + "to minimize the amount of padding.", + ) + group.add_argument( + "--duration-factor", + type=float, + default=1.0, + help="Determines the maximum duration of a concatenated cut " + "relative to the duration of the longest cut in a batch.", + ) + group.add_argument( + "--gap", + type=float, + default=1.0, + help="The amount of padding (in seconds) inserted between " + "concatenated cuts. This padding is filled with noise when " + "noise augmentation is used.", + ) + group.add_argument( + "--on-the-fly-feats", + type=str2bool, + default=False, + help="When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available.", + ) + group.add_argument( + "--shuffle", + type=str2bool, + default=True, + help="When enabled (=default), the examples will be " + "shuffled for each epoch.", + ) + group.add_argument( + "--return-cuts", + type=str2bool, + default=True, + help="When enabled, each batch will have the " + "field: batch['supervisions']['cut'] with the cuts that " + "were used to construct it.", + ) + + group.add_argument( + "--num-workers", + type=int, + default=2, + help="The number of training dataloader workers that " + "collect the batches.", + ) + + def train_dataloaders(self) -> DataLoader: + logging.info("About to get train cuts") + cuts_train = self.train_cuts() + + logging.info("About to get Musan cuts") + cuts_musan = load_manifest(self.args.feature_dir / "cuts_musan.json.gz") + + logging.info("About to create train dataset") + transforms = [CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20))] + if self.args.concatenate_cuts: + logging.info( + f"Using cut concatenation with duration factor " + f"{self.args.duration_factor} and gap {self.args.gap}." + ) + # Cut concatenation should be the first transform in the list, + # so that if we e.g. mix noise in, it will fill the gaps between + # different utterances. + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + input_transforms = [ + SpecAugment( + num_frame_masks=2, + features_mask_size=27, + num_feature_masks=2, + frames_mask_size=100, + ) + ] + + train = K2SpeechRecognitionDataset( + cut_transforms=transforms, + input_transforms=input_transforms, + return_cuts=self.args.return_cuts, + ) + + if self.args.on_the_fly_feats: + # NOTE: the PerturbSpeed transform should be added only if we + # remove it from data prep stage. + # Add on-the-fly speed perturbation; since originally it would + # have increased epoch size by 3, we will apply prob 2/3 and use + # 3x more epochs. + # Speed perturbation probably should come first before + # concatenation, but in principle the transforms order doesn't have + # to be strict (e.g. could be randomized) + # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa + # Drop feats to be on the safe side. + train = K2SpeechRecognitionDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures( + Fbank(FbankConfig(num_mel_bins=80)) + ), + input_transforms=input_transforms, + return_cuts=self.args.return_cuts, + ) + + if self.args.bucketing_sampler: + logging.info("Using BucketingSampler.") + train_sampler = BucketingSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + num_buckets=self.args.num_buckets, + bucket_method="equal_duration", + drop_last=True, + ) + else: + logging.info("Using SingleCutSampler.") + train_sampler = SingleCutSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + ) + logging.info("About to create train dataloader") + + train_dl = DataLoader( + train, + sampler=train_sampler, + batch_size=None, + num_workers=self.args.num_workers, + persistent_workers=False, + ) + + return train_dl + + def valid_dataloaders(self) -> DataLoader: + logging.info("About to get dev cuts") + cuts_valid = self.valid_cuts() + + transforms = [] + if self.args.concatenate_cuts: + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + logging.info("About to create dev dataset") + if self.args.on_the_fly_feats: + validate = K2SpeechRecognitionDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures( + Fbank(FbankConfig(num_mel_bins=80)) + ), + return_cuts=self.args.return_cuts, + ) + else: + validate = K2SpeechRecognitionDataset( + cut_transforms=transforms, + return_cuts=self.args.return_cuts, + ) + valid_sampler = SingleCutSampler( + cuts_valid, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.info("About to create dev dataloader") + valid_dl = DataLoader( + validate, + sampler=valid_sampler, + batch_size=None, + num_workers=2, + persistent_workers=False, + ) + + return valid_dl + + def test_dataloaders(self) -> Union[DataLoader, List[DataLoader]]: + cuts = self.test_cuts() + is_list = isinstance(cuts, list) + test_loaders = [] + if not is_list: + cuts = [cuts] + + for cuts_test in cuts: + logging.debug("About to create test dataset") + test = K2SpeechRecognitionDataset( + input_strategy=OnTheFlyFeatures( + Fbank(FbankConfig(num_mel_bins=80)) + ) + if self.args.on_the_fly_feats + else PrecomputedFeatures(), + return_cuts=self.args.return_cuts, + ) + sampler = SingleCutSampler( + cuts_test, max_duration=self.args.max_duration + ) + logging.debug("About to create test dataloader") + test_dl = DataLoader( + test, batch_size=None, sampler=sampler, num_workers=1 + ) + test_loaders.append(test_dl) + + if is_list: + return test_loaders + else: + return test_loaders[0] + + @lru_cache() + def train_cuts(self) -> CutSet: + logging.info("About to get train cuts") + cuts_train = load_manifest( + self.args.feature_dir / "cuts_train-clean-100.json.gz" + ) + if self.args.full_libri: + cuts_train = ( + cuts_train + + load_manifest( + self.args.feature_dir / "cuts_train-clean-360.json.gz" + ) + + load_manifest( + self.args.feature_dir / "cuts_train-other-500.json.gz" + ) + ) + return cuts_train + + @lru_cache() + def valid_cuts(self) -> CutSet: + logging.info("About to get dev cuts") + cuts_valid = load_manifest( + self.args.feature_dir / "cuts_dev-clean.json.gz" + ) + load_manifest(self.args.feature_dir / "cuts_dev-other.json.gz") + return cuts_valid + + @lru_cache() + def test_cuts(self) -> List[CutSet]: + test_sets = ["test-clean", "test-other"] + cuts = [] + for test_set in test_sets: + logging.debug("About to get test cuts") + cuts.append( + load_manifest( + self.args.feature_dir / f"cuts_{test_set}.json.gz" + ) + ) + return cuts diff --git a/egs/librispeech/ASR/conformer_mmi/conformer.py b/egs/librispeech/ASR/conformer_mmi/conformer.py index ac49b7b1c4..b19b94db1d 100644 --- a/egs/librispeech/ASR/conformer_mmi/conformer.py +++ b/egs/librispeech/ASR/conformer_mmi/conformer.py @@ -1,7 +1,20 @@ #!/usr/bin/env python3 - # Copyright (c) 2021 University of Chinese Academy of Sciences (author: Han Zhu) -# Apache 2.0 +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import math import warnings @@ -43,7 +56,6 @@ def __init__( cnn_module_kernel: int = 31, normalize_before: bool = True, vgg_frontend: bool = False, - is_espnet_structure: bool = False, use_feat_batchnorm: bool = False, ) -> None: super(Conformer, self).__init__( @@ -70,12 +82,10 @@ def __init__( dropout, cnn_module_kernel, normalize_before, - is_espnet_structure, ) self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers) self.normalize_before = normalize_before - self.is_espnet_structure = is_espnet_structure - if self.normalize_before and self.is_espnet_structure: + if self.normalize_before: self.after_norm = nn.LayerNorm(d_model) else: # Note: TorchScript detects that self.after_norm could be used inside forward() @@ -88,7 +98,7 @@ def run_encoder( """ Args: x: - The model input. Its shape is [N, T, C]. + The model input. Its shape is (N, T, C). supervisions: Supervision in lhotse format. See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa @@ -110,7 +120,7 @@ def run_encoder( mask = mask.to(x.device) x = self.encoder(x, pos_emb, src_key_padding_mask=mask) # (T, B, F) - if self.normalize_before and self.is_espnet_structure: + if self.normalize_before: x = self.after_norm(x) return x, mask @@ -144,11 +154,10 @@ def __init__( dropout: float = 0.1, cnn_module_kernel: int = 31, normalize_before: bool = True, - is_espnet_structure: bool = False, ) -> None: super(ConformerEncoderLayer, self).__init__() self.self_attn = RelPositionMultiheadAttention( - d_model, nhead, dropout=0.0, is_espnet_structure=is_espnet_structure + d_model, nhead, dropout=0.0 ) self.feed_forward = nn.Sequential( @@ -394,7 +403,7 @@ def forward(self, x: torch.Tensor) -> Tuple[Tensor, Tensor]: :, self.pe.size(1) // 2 - x.size(1) - + 1 : self.pe.size(1) // 2 + + 1 : self.pe.size(1) // 2 # noqa E203 + x.size(1), ] return self.dropout(x), self.dropout(pos_emb) @@ -421,7 +430,6 @@ def __init__( embed_dim: int, num_heads: int, dropout: float = 0.0, - is_espnet_structure: bool = False, ) -> None: super(RelPositionMultiheadAttention, self).__init__() self.embed_dim = embed_dim @@ -444,8 +452,6 @@ def __init__( self._reset_parameters() - self.is_espnet_structure = is_espnet_structure - def _reset_parameters(self) -> None: nn.init.xavier_uniform_(self.in_proj.weight) nn.init.constant_(self.in_proj.bias, 0.0) @@ -675,9 +681,6 @@ def multi_head_attention_forward( _b = _b[_start:] v = nn.functional.linear(value, _w, _b) - if not self.is_espnet_structure: - q = q * scaling - if attn_mask is not None: assert ( attn_mask.dtype == torch.float32 @@ -770,14 +773,9 @@ def multi_head_attention_forward( ) # (batch, head, time1, 2*time1-1) matrix_bd = self.rel_shift(matrix_bd) - if not self.is_espnet_structure: - attn_output_weights = ( - matrix_ac + matrix_bd - ) # (batch, head, time1, time2) - else: - attn_output_weights = ( - matrix_ac + matrix_bd - ) * scaling # (batch, head, time1, time2) + attn_output_weights = ( + matrix_ac + matrix_bd + ) * scaling # (batch, head, time1, time2) attn_output_weights = attn_output_weights.view( bsz * num_heads, tgt_len, -1 diff --git a/egs/librispeech/ASR/conformer_mmi/decode.py b/egs/librispeech/ASR/conformer_mmi/decode.py index 6030d13e1b..dc2e449c22 100755 --- a/egs/librispeech/ASR/conformer_mmi/decode.py +++ b/egs/librispeech/ASR/conformer_mmi/decode.py @@ -1,8 +1,20 @@ #!/usr/bin/env python3 - # Copyright 2021 Xiaomi Corporation (Author: Liyong Guo, Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. -# (still working in progress) import argparse import logging @@ -13,14 +25,15 @@ import k2 import torch import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule from conformer import Conformer -from icefall.bpe_mmi_graph_compiler import BpeMmiTrainingGraphCompiler +from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler from icefall.checkpoint import average_checkpoints, load_checkpoint -from icefall.dataset.librispeech import LibriSpeechAsrDataModule from icefall.decode import ( get_lattice, nbest_decoding, + nbest_oracle, one_best_decoding, rescore_with_attention_decoder, rescore_with_n_best_list, @@ -32,6 +45,7 @@ get_texts, setup_logger, store_transcripts, + str2bool, write_error_stats, ) @@ -44,51 +58,111 @@ def get_parser(): parser.add_argument( "--epoch", type=int, - default=9, + default=34, help="It specifies the checkpoint to use for decoding." "Note: Epoch counts from 0.", ) parser.add_argument( "--avg", type=int, - default=1, + default=20, help="Number of checkpoints to average. Automatically select " "consecutive checkpoints before the checkpoint specified by " "'--epoch'. ", ) + + parser.add_argument( + "--method", + type=str, + default="attention-decoder", + help="""Decoding method. + Supported values are: + - (1) 1best. Extract the best path from the decoding lattice as the + decoding result. + - (2) nbest. Extract n paths from the decoding lattice; the path + with the highest score is the decoding result. + - (3) nbest-rescoring. Extract n paths from the decoding lattice, + rescore them with an n-gram LM (e.g., a 4-gram LM), the path with + the highest score is the decoding result. + - (4) whole-lattice-rescoring. Rescore the decoding lattice with an + n-gram LM (e.g., a 4-gram LM), the best path of rescored lattice + is the decoding result. + - (5) attention-decoder. Extract n paths from the LM rescored + lattice, the path with the highest score is the decoding result. + - (6) nbest-oracle. Its WER is the lower bound of any n-best + rescoring method can achieve. Useful for debugging n-best + rescoring method. + """, + ) + + parser.add_argument( + "--num-paths", + type=int, + default=100, + help="""Number of paths for n-best based decoding method. + Used only when "method" is one of the following values: + nbest, nbest-rescoring, attention-decoder, and nbest-oracle + """, + ) + + parser.add_argument( + "--lattice-score-scale", + type=float, + default=0.5, + help="""The scale to be applied to `lattice.scores`. + It's needed if you use any kinds of n-best based rescoring. + Used only when "method" is one of the following values: + nbest, nbest-rescoring, attention-decoder, and nbest-oracle + A smaller value results in more unique paths. + """, + ) + + parser.add_argument( + "--export", + type=str2bool, + default=False, + help="""When enabled, the averaged model is saved to + conformer_mmi/exp/pretrained.pt. Note: only model.state_dict() is saved. + pretrained.pt contains a dict {"model": model.state_dict()}, + which can be loaded by `icefall.checkpoint.load_checkpoint()`. + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="conformer_mmi/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--lang-dir", + type=str, + default="data/lang_bpe", + help="The lang dir", + ) + return parser def get_params() -> AttributeDict: params = AttributeDict( { - "exp_dir": Path("conformer_mmi/exp"), - "lang_dir": Path("data/lang_bpe"), "lm_dir": Path("data/lm"), + # parameters for conformer + "subsampling_factor": 4, + "vgg_frontend": False, + "use_feat_batchnorm": True, "feature_dim": 80, "nhead": 8, "attention_dim": 512, - "subsampling_factor": 4, "num_decoder_layers": 6, - "vgg_frontend": False, - "is_espnet_structure": True, - "use_feat_batchnorm": True, + # parameters for decoding "search_beam": 20, "output_beam": 8, "min_active_states": 30, "max_active_states": 10000, "use_double_scores": True, - # Possible values for method: - # - 1best - # - nbest - # - nbest-rescoring - # - whole-lattice-rescoring - # - attention-decoder - # "method": "whole-lattice-rescoring", - "method": "1best", - # num_paths is used when method is "nbest", "nbest-rescoring", - # and attention-decoder - "num_paths": 100, } ) return params @@ -99,7 +173,7 @@ def decode_one_batch( model: nn.Module, HLG: k2.Fsa, batch: dict, - lexicon: Lexicon, + word_table: k2.SymbolTable, sos_id: int, eos_id: int, G: Optional[k2.Fsa] = None, @@ -133,8 +207,8 @@ def decode_one_batch( It is the return value from iterating `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation for the format of the `batch`. - lexicon: - It contains word symbol table. + word_table: + The word symbol table. sos_id: The token ID of the SOS. eos_id: @@ -151,12 +225,12 @@ def decode_one_batch( feature = batch["inputs"] assert feature.ndim == 3 feature = feature.to(device) - # at entry, feature is [N, T, C] + # at entry, feature is (N, T, C) supervisions = batch["supervisions"] nnet_output, memory, memory_key_padding_mask = model(feature, supervisions) - # nnet_output is [N, T, C] + # nnet_output is (N, T, C) supervision_segments = torch.stack( ( @@ -178,6 +252,24 @@ def decode_one_batch( subsampling_factor=params.subsampling_factor, ) + if params.method == "nbest-oracle": + # Note: You can also pass rescored lattices to it. + # We choose the HLG decoded lattice for speed reasons + # as HLG decoding is faster and the oracle WER + # is only slightly worse than that of rescored lattices. + best_path = nbest_oracle( + lattice=lattice, + num_paths=params.num_paths, + ref_texts=supervisions["text"], + word_table=word_table, + lattice_score_scale=params.lattice_score_scale, + oov="", + ) + hyps = get_texts(best_path) + hyps = [[word_table[i] for i in ids] for ids in hyps] + key = f"oracle_{params.num_paths}_lattice_score_scale_{params.lattice_score_scale}" # noqa + return {key: hyps} + if params.method in ["1best", "nbest"]: if params.method == "1best": best_path = one_best_decoding( @@ -189,11 +281,12 @@ def decode_one_batch( lattice=lattice, num_paths=params.num_paths, use_double_scores=params.use_double_scores, + lattice_score_scale=params.lattice_score_scale, ) - key = f"no_rescore-{params.num_paths}" + key = f"no_rescore-scale-{params.lattice_score_scale}-{params.num_paths}" # noqa hyps = get_texts(best_path) - hyps = [[lexicon.word_table[i] for i in ids] for ids in hyps] + hyps = [[word_table[i] for i in ids] for ids in hyps] return {key: hyps} assert params.method in [ @@ -202,7 +295,8 @@ def decode_one_batch( "attention-decoder", ] - lm_scale_list = [0.8, 0.9, 1.0, 1.1, 1.2, 1.3] + lm_scale_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7] + lm_scale_list += [0.8, 0.9, 1.0, 1.1, 1.2, 1.3] lm_scale_list += [1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0] if params.method == "nbest-rescoring": @@ -211,16 +305,23 @@ def decode_one_batch( G=G, num_paths=params.num_paths, lm_scale_list=lm_scale_list, + lattice_score_scale=params.lattice_score_scale, ) elif params.method == "whole-lattice-rescoring": best_path_dict = rescore_with_whole_lattice( - lattice=lattice, G_with_epsilon_loops=G, lm_scale_list=lm_scale_list + lattice=lattice, + G_with_epsilon_loops=G, + lm_scale_list=lm_scale_list, ) elif params.method == "attention-decoder": # lattice uses a 3-gram Lm. We rescore it with a 4-gram LM. rescored_lattice = rescore_with_whole_lattice( - lattice=lattice, G_with_epsilon_loops=G, lm_scale_list=None + lattice=lattice, + G_with_epsilon_loops=G, + lm_scale_list=None, ) + # TODO: pass `lattice` instead of `rescored_lattice` to + # `rescore_with_attention_decoder` best_path_dict = rescore_with_attention_decoder( lattice=rescored_lattice, @@ -230,15 +331,20 @@ def decode_one_batch( memory_key_padding_mask=memory_key_padding_mask, sos_id=sos_id, eos_id=eos_id, + lattice_score_scale=params.lattice_score_scale, ) else: assert False, f"Unsupported decoding method: {params.method}" ans = dict() - for lm_scale_str, best_path in best_path_dict.items(): - hyps = get_texts(best_path) - hyps = [[lexicon.word_table[i] for i in ids] for ids in hyps] - ans[lm_scale_str] = hyps + if best_path_dict is not None: + for lm_scale_str, best_path in best_path_dict.items(): + hyps = get_texts(best_path) + hyps = [[word_table[i] for i in ids] for ids in hyps] + ans[lm_scale_str] = hyps + else: + for lm_scale in lm_scale_list: + ans[lm_scale_str] = [[] * lattice.shape[0]] return ans @@ -247,7 +353,7 @@ def decode_dataset( params: AttributeDict, model: nn.Module, HLG: k2.Fsa, - lexicon: Lexicon, + word_table: k2.SymbolTable, sos_id: int, eos_id: int, G: Optional[k2.Fsa] = None, @@ -263,8 +369,8 @@ def decode_dataset( The neural model. HLG: The decoding graph. - lexicon: - It contains word symbol table. + word_table: + It is the word symbol table. sos_id: The token ID for SOS. eos_id: @@ -283,7 +389,11 @@ def decode_dataset( results = [] num_cuts = 0 - tot_num_cuts = len(dl.dataset.cuts) + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" results = defaultdict(list) for batch_idx, batch in enumerate(dl): @@ -294,7 +404,7 @@ def decode_dataset( model=model, HLG=HLG, batch=batch, - lexicon=lexicon, + word_table=word_table, G=G, sos_id=sos_id, eos_id=eos_id, @@ -312,10 +422,10 @@ def decode_dataset( num_cuts += len(batch["supervisions"]["text"]) if batch_idx % 100 == 0: + batch_str = f"{batch_idx}/{num_batches}" + logging.info( - f"batch {batch_idx}, cuts processed until now is " - f"{num_cuts}/{tot_num_cuts} " - f"({float(num_cuts)/tot_num_cuts*100:.6f}%)" + f"batch {batch_str}, cuts processed until now is {num_cuts}" ) return results @@ -374,8 +484,10 @@ def main(): params = get_params() params.update(vars(args)) + params.exp_dir = Path(params.exp_dir) + params.lang_dir = Path(params.lang_dir) - setup_logger(f"{params.exp_dir}/log/log-decode") + setup_logger(f"{params.exp_dir}/log-{params.method}/log-decode") logging.info("Decoding started") logging.info(params) @@ -389,7 +501,7 @@ def main(): logging.info(f"device: {device}") - graph_compiler = BpeMmiTrainingGraphCompiler( + graph_compiler = BpeCtcTrainingGraphCompiler( params.lang_dir, device=device, sos_token="", @@ -398,7 +510,9 @@ def main(): sos_id = graph_compiler.sos_id eos_id = graph_compiler.eos_id - HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt")) + HLG = k2.Fsa.from_dict( + torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu") + ) HLG = HLG.to(device) assert HLG.requires_grad is False @@ -429,7 +543,7 @@ def main(): torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt") else: logging.info("Loading pre-compiled G_4_gram.pt") - d = torch.load(params.lm_dir / "G_4_gram.pt") + d = torch.load(params.lm_dir / "G_4_gram.pt", map_location="cpu") G = k2.Fsa.from_dict(d).to(device) if params.method in ["whole-lattice-rescoring", "attention-decoder"]: @@ -453,7 +567,6 @@ def main(): subsampling_factor=params.subsampling_factor, num_decoder_layers=params.num_decoder_layers, vgg_frontend=params.vgg_frontend, - is_espnet_structure=params.is_espnet_structure, use_feat_batchnorm=params.use_feat_batchnorm, ) @@ -468,6 +581,13 @@ def main(): logging.info(f"averaging {filenames}") model.load_state_dict(average_checkpoints(filenames)) + if params.export: + logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt") + torch.save( + {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt" + ) + return + model.to(device) model.eval() num_param = sum([p.numel() for p in model.parameters()]) @@ -487,7 +607,7 @@ def main(): params=params, model=model, HLG=HLG, - lexicon=lexicon, + word_table=lexicon.word_table, G=G, sos_id=sos_id, eos_id=eos_id, diff --git a/egs/librispeech/ASR/conformer_mmi/transformer.py b/egs/librispeech/ASR/conformer_mmi/transformer.py index fd1a082e7c..68a4ff65cb 100644 --- a/egs/librispeech/ASR/conformer_mmi/transformer.py +++ b/egs/librispeech/ASR/conformer_mmi/transformer.py @@ -1,15 +1,26 @@ -# Copyright (c) 2021 University of Chinese Academy of Sciences (author: Han Zhu) -# Apache 2.0 +# Copyright 2021 University of Chinese Academy of Sciences (author: Han Zhu) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import math from typing import Dict, List, Optional, Tuple -import k2 import torch import torch.nn as nn from subsampling import Conv2dSubsampling, VggSubsampling - -from icefall.utils import get_texts from torch.nn.utils.rnn import pad_sequence # Note: TorchScript requires Dict/List/etc. to be fully typed. @@ -72,8 +83,8 @@ def __init__( if subsampling_factor != 4: raise NotImplementedError("Support only 'subsampling_factor=4'.") - # self.encoder_embed converts the input of shape [N, T, num_classes] - # to the shape [N, T//subsampling_factor, d_model]. + # self.encoder_embed converts the input of shape (N, T, num_classes) + # to the shape (N, T//subsampling_factor, d_model). # That is, it does two things simultaneously: # (1) subsampling: T -> T//subsampling_factor # (2) embedding: num_classes -> d_model @@ -103,10 +114,15 @@ def __init__( norm=encoder_norm, ) - self.encoder_output_layer = nn.Linear(d_model, num_classes) + # TODO(fangjun): remove dropout + self.encoder_output_layer = nn.Sequential( + nn.Dropout(p=dropout), nn.Linear(d_model, num_classes) + ) if num_decoder_layers > 0: - self.decoder_num_class = self.num_classes + self.decoder_num_class = ( + self.num_classes + ) # bpe model already has sos/eos symbol self.decoder_embed = nn.Embedding( num_embeddings=self.decoder_num_class, embedding_dim=d_model @@ -146,7 +162,7 @@ def forward( """ Args: x: - The input tensor. Its shape is [N, T, C]. + The input tensor. Its shape is (N, T, C). supervision: Supervision in lhotse format. See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa @@ -155,17 +171,17 @@ def forward( Returns: Return a tuple containing 3 tensors: - - CTC output for ctc decoding. Its shape is [N, T, C] - - Encoder output with shape [T, N, C]. It can be used as key and + - CTC output for ctc decoding. Its shape is (N, T, C) + - Encoder output with shape (T, N, C). It can be used as key and value for the decoder. - Encoder output padding mask. It can be used as - memory_key_padding_mask for the decoder. Its shape is [N, T]. + memory_key_padding_mask for the decoder. Its shape is (N, T). It is None if `supervision` is None. """ if self.use_feat_batchnorm: - x = x.permute(0, 2, 1) # [N, T, C] -> [N, C, T] + x = x.permute(0, 2, 1) # (N, T, C) -> (N, C, T) x = self.feat_batchnorm(x) - x = x.permute(0, 2, 1) # [N, C, T] -> [N, T, C] + x = x.permute(0, 2, 1) # (N, C, T) -> (N, T, C) encoder_memory, memory_key_padding_mask = self.run_encoder( x, supervision ) @@ -179,7 +195,7 @@ def run_encoder( Args: x: - The model input. Its shape is [N, T, C]. + The model input. Its shape is (N, T, C). supervisions: Supervision in lhotse format. See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa @@ -190,8 +206,8 @@ def run_encoder( padding mask for the decoder. Returns: Return a tuple with two tensors: - - The encoder output, with shape [T, N, C] - - encoder padding mask, with shape [N, T]. + - The encoder output, with shape (T, N, C) + - encoder padding mask, with shape (N, T). The mask is None if `supervisions` is None. It is used as memory key padding mask in the decoder. """ @@ -209,11 +225,11 @@ def ctc_output(self, x: torch.Tensor) -> torch.Tensor: Args: x: The output tensor from the transformer encoder. - Its shape is [T, N, C] + Its shape is (T, N, C) Returns: Return a tensor that can be used for CTC decoding. - Its shape is [N, T, C] + Its shape is (N, T, C) """ x = self.encoder_output_layer(x) x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C) @@ -231,7 +247,7 @@ def decoder_forward( """ Args: memory: - It's the output of the encoder with shape [T, N, C] + It's the output of the encoder with shape (T, N, C) memory_key_padding_mask: The padding mask from the encoder. token_ids: @@ -296,7 +312,7 @@ def decoder_nll( """ Args: memory: - It's the output of the encoder with shape [T, N, C] + It's the output of the encoder with shape (T, N, C) memory_key_padding_mask: The padding mask from the encoder. token_ids: @@ -312,6 +328,7 @@ def decoder_nll( """ # The common part between this function and decoder_forward could be # extracted as a separate function. + ys_in = add_sos(token_ids, sos_id=sos_id) ys_in = [torch.tensor(y) for y in ys_in] ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=eos_id) @@ -329,6 +346,9 @@ def decoder_nll( ) tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) + # TODO: Use length information to create the decoder padding mask + # We set the first column to False since the first column in ys_in_pad + # contains sos_id, which is the same as eos_id in our current setting. tgt_key_padding_mask[:, 0] = False tgt = self.decoder_embed(ys_in_pad) # (B, T) -> (B, T, F) @@ -634,13 +654,13 @@ def __init__(self, d_model: int, dropout: float = 0.1) -> None: def extend_pe(self, x: torch.Tensor) -> None: """Extend the time t in the positional encoding if required. - The shape of `self.pe` is [1, T1, d_model]. The shape of the input x - is [N, T, d_model]. If T > T1, then we change the shape of self.pe - to [N, T, d_model]. Otherwise, nothing is done. + The shape of `self.pe` is (1, T1, d_model). The shape of the input x + is (N, T, d_model). If T > T1, then we change the shape of self.pe + to (N, T, d_model). Otherwise, nothing is done. Args: x: - It is a tensor of shape [N, T, C]. + It is a tensor of shape (N, T, C). Returns: Return None. """ @@ -658,7 +678,7 @@ def extend_pe(self, x: torch.Tensor) -> None: pe[:, 0::2] = torch.sin(position * div_term) pe[:, 1::2] = torch.cos(position * div_term) pe = pe.unsqueeze(0) - # Now pe is of shape [1, T, d_model], where T is x.size(1) + # Now pe is of shape (1, T, d_model), where T is x.size(1) self.pe = pe.to(device=x.device, dtype=x.dtype) def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -667,10 +687,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: Args: x: - Its shape is [N, T, C] + Its shape is (N, T, C) Returns: - Return a tensor of shape [N, T, C] + Return a tensor of shape (N, T, C) """ self.extend_pe(x) x = x * self.xscale + self.pe[:, : x.size(1), :] @@ -766,7 +786,8 @@ def load_state_dict(self, state_dict): class LabelSmoothingLoss(nn.Module): """ - Label-smoothing loss. KL-divergence between q_{smoothed ground truth prob.}(w) + Label-smoothing loss. KL-divergence between + q_{smoothed ground truth prob.}(w) and p_{prob. computed by model}(w) is minimized. Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/label_smoothing_loss.py # noqa @@ -851,7 +872,8 @@ def encoder_padding_mask( frames, before subsampling) Returns: - Tensor: Mask tensor of dimension (batch_size, input_length), True denote the masked indices. + Tensor: Mask tensor of dimension (batch_size, input_length), + True denote the masked indices. """ if supervisions is None: return None diff --git a/egs/librispeech/ASR/local/convert_transcript_to_corpus.py b/egs/librispeech/ASR/local/convert_transcript_words_to_tokens.py similarity index 83% rename from egs/librispeech/ASR/local/convert_transcript_to_corpus.py rename to egs/librispeech/ASR/local/convert_transcript_words_to_tokens.py index bb02dac581..133499c8bc 100755 --- a/egs/librispeech/ASR/local/convert_transcript_to_corpus.py +++ b/egs/librispeech/ASR/local/convert_transcript_words_to_tokens.py @@ -8,8 +8,8 @@ If the lexicon contains phones, the resulting LM will be a phone LM; If the lexicon contains word pieces, the resulting LM will be a word piece LM. -If a word has multiple pronunciations, the one that appears last in the lexicon -is used. +If a word has multiple pronunciations, the one that appears first in the lexicon +is kept; others are removed. If the input transcript is: @@ -20,8 +20,8 @@ and if the lexicon is SPN - hello h e l l o hello h e l l o 2 + hello h e l l o world w o r l d zoo z o o @@ -32,10 +32,11 @@ SPN z o o w o r l d SPN """ +import argparse from pathlib import Path -from typing import Dict +from typing import Dict, List -import argparse +from generate_unique_lexicon import filter_multiple_pronunications from icefall.lexicon import read_lexicon @@ -57,7 +58,9 @@ def get_args(): return parser.parse_args() -def process_line(lexicon: Dict[str, str], line: str, oov_token: str) -> None: +def process_line( + lexicon: Dict[str, List[str]], line: str, oov_token: str +) -> None: """ Args: lexicon: @@ -86,7 +89,11 @@ def main(): assert Path(args.transcript).is_file() assert len(args.oov) > 0 - lexicon = dict(read_lexicon(args.lexicon)) + # Only the first pronunciation of a word is kept + lexicon = filter_multiple_pronunications(read_lexicon(args.lexicon)) + + lexicon = dict(lexicon) + assert args.oov in lexicon oov_token = lexicon[args.oov] diff --git a/egs/librispeech/ASR/local/generate_unique_lexicon.py b/egs/librispeech/ASR/local/generate_unique_lexicon.py new file mode 100755 index 0000000000..566c0743db --- /dev/null +++ b/egs/librispeech/ASR/local/generate_unique_lexicon.py @@ -0,0 +1,100 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This file takes as input a lexicon.txt and output a new lexicon, +in which each word has a unique pronunciation. + +The way to do this is to keep only the first pronunciation of a word +in lexicon.txt. +""" + + +import argparse +import logging +from pathlib import Path +from typing import List, Tuple + +from icefall.lexicon import read_lexicon, write_lexicon + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--lang-dir", + type=str, + help="""Input and output directory. + It should contain a file lexicon.txt. + This file will generate a new file uniq_lexicon.txt + in it. + """, + ) + + return parser.parse_args() + + +def filter_multiple_pronunications( + lexicon: List[Tuple[str, List[str]]] +) -> List[Tuple[str, List[str]]]: + """Remove multiple pronunciations of words from a lexicon. + + If a word has more than one pronunciation in the lexicon, only + the first one is kept, while other pronunciations are removed + from the lexicon. + + Args: + lexicon: + The input lexicon, containing a list of (word, [p1, p2, ..., pn]), + where "p1, p2, ..., pn" are the pronunciations of the "word". + Returns: + Return a new lexicon where each word has a unique pronunciation. + """ + seen = set() + ans = [] + + for word, tokens in lexicon: + if word in seen: + continue + seen.add(word) + ans.append((word, tokens)) + return ans + + +def main(): + args = get_args() + lang_dir = Path(args.lang_dir) + + lexicon_filename = lang_dir / "lexicon.txt" + + in_lexicon = read_lexicon(lexicon_filename) + + out_lexicon = filter_multiple_pronunications(in_lexicon) + + write_lexicon(lang_dir / "uniq_lexicon.txt", out_lexicon) + + logging.info(f"Number of entries in lexicon.txt: {len(in_lexicon)}") + logging.info(f"Number of entries in uniq_lexicon.txt: {len(out_lexicon)}") + + +if __name__ == "__main__": + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) + + logging.basicConfig(format=formatter, level=logging.INFO) + + main() diff --git a/egs/librispeech/ASR/local/prepare_lang.py b/egs/librispeech/ASR/local/prepare_lang.py index 0880019b3e..d913756a19 100755 --- a/egs/librispeech/ASR/local/prepare_lang.py +++ b/egs/librispeech/ASR/local/prepare_lang.py @@ -33,6 +33,7 @@ 5. Generate L_disambig.pt, in k2 format. """ +import argparse import math from collections import defaultdict from pathlib import Path @@ -42,10 +43,37 @@ import torch from icefall.lexicon import read_lexicon, write_lexicon +from icefall.utils import str2bool Lexicon = List[Tuple[str, List[str]]] +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--lang-dir", + type=str, + help="""Input and output directory. + It should contain a file lexicon.txt. + Generated files by this script are saved into this directory. + """, + ) + + parser.add_argument( + "--debug", + type=str2bool, + default=False, + help="""True for debugging, which will generate + a visualization of the lexicon FST. + + Caution: If your lexicon contains hundreds of thousands + of lines, please set it to False! + """, + ) + + return parser.parse_args() + + def write_mapping(filename: str, sym2id: Dict[str, int]) -> None: """Write a symbol to ID mapping to a file. @@ -315,8 +343,9 @@ def lexicon_to_fst( def main(): - out_dir = Path("data/lang_phone") - lexicon_filename = out_dir / "lexicon.txt" + args = get_args() + lang_dir = Path(args.lang_dir) + lexicon_filename = lang_dir / "lexicon.txt" sil_token = "SIL" sil_prob = 0.5 @@ -344,9 +373,9 @@ def main(): token2id = generate_id_map(tokens) word2id = generate_id_map(words) - write_mapping(out_dir / "tokens.txt", token2id) - write_mapping(out_dir / "words.txt", word2id) - write_lexicon(out_dir / "lexicon_disambig.txt", lexicon_disambig) + write_mapping(lang_dir / "tokens.txt", token2id) + write_mapping(lang_dir / "words.txt", word2id) + write_lexicon(lang_dir / "lexicon_disambig.txt", lexicon_disambig) L = lexicon_to_fst( lexicon, @@ -364,17 +393,20 @@ def main(): sil_prob=sil_prob, need_self_loops=True, ) - torch.save(L.as_dict(), out_dir / "L.pt") - torch.save(L_disambig.as_dict(), out_dir / "L_disambig.pt") - - if False: - # Just for debugging, will remove it - L.labels_sym = k2.SymbolTable.from_file(out_dir / "tokens.txt") - L.aux_labels_sym = k2.SymbolTable.from_file(out_dir / "words.txt") - L_disambig.labels_sym = L.labels_sym - L_disambig.aux_labels_sym = L.aux_labels_sym - L.draw(out_dir / "L.png", title="L") - L_disambig.draw(out_dir / "L_disambig.png", title="L_disambig") + torch.save(L.as_dict(), lang_dir / "L.pt") + torch.save(L_disambig.as_dict(), lang_dir / "L_disambig.pt") + + if args.debug: + labels_sym = k2.SymbolTable.from_file(lang_dir / "tokens.txt") + aux_labels_sym = k2.SymbolTable.from_file(lang_dir / "words.txt") + + L.labels_sym = labels_sym + L.aux_labels_sym = aux_labels_sym + L.draw(f"{lang_dir / 'L.svg'}", title="L.pt") + + L_disambig.labels_sym = labels_sym + L_disambig.aux_labels_sym = aux_labels_sym + L_disambig.draw(f"{lang_dir / 'L_disambig.svg'}", title="L_disambig.pt") if __name__ == "__main__": diff --git a/egs/librispeech/ASR/local/prepare_lang_bpe.py b/egs/librispeech/ASR/local/prepare_lang_bpe.py index 39d347661a..cf32f308df 100755 --- a/egs/librispeech/ASR/local/prepare_lang_bpe.py +++ b/egs/librispeech/ASR/local/prepare_lang_bpe.py @@ -49,6 +49,8 @@ write_mapping, ) +from icefall.utils import str2bool + def lexicon_to_fst_no_sil( lexicon: Lexicon, @@ -169,6 +171,20 @@ def get_args(): """, ) + parser.add_argument( + "--debug", + type=str2bool, + default=False, + help="""True for debugging, which will generate + a visualization of the lexicon FST. + + Caution: If your lexicon contains hundreds of thousands + of lines, please set it to False! + + See "test/test_bpe_lexicon.py" for usage. + """, + ) + return parser.parse_args() @@ -221,6 +237,18 @@ def main(): torch.save(L.as_dict(), lang_dir / "L.pt") torch.save(L_disambig.as_dict(), lang_dir / "L_disambig.pt") + if args.debug: + labels_sym = k2.SymbolTable.from_file(lang_dir / "tokens.txt") + aux_labels_sym = k2.SymbolTable.from_file(lang_dir / "words.txt") + + L.labels_sym = labels_sym + L.aux_labels_sym = aux_labels_sym + L.draw(f"{lang_dir / 'L.svg'}", title="L.pt") + + L_disambig.labels_sym = labels_sym + L_disambig.aux_labels_sym = aux_labels_sym + L_disambig.draw(f"{lang_dir / 'L_disambig.svg'}", title="L_disambig.pt") + if __name__ == "__main__": main() diff --git a/egs/librispeech/ASR/local/train_bpe_model.py b/egs/librispeech/ASR/local/train_bpe_model.py index 3c3ecdcae8..bc5812810e 100755 --- a/egs/librispeech/ASR/local/train_bpe_model.py +++ b/egs/librispeech/ASR/local/train_bpe_model.py @@ -15,6 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. + # You can install sentencepiece via: # # pip install sentencepiece @@ -37,10 +38,17 @@ def get_args(): "--lang-dir", type=str, help="""Input and output directory. - It should contain the training corpus: train.txt. + It should contain the training corpus: transcript_words.txt. The generated bpe.model is saved to this directory. """, ) + + parser.add_argument( + "--transcript", + type=str, + help="Training transcript.", + ) + parser.add_argument( "--vocab-size", type=int, @@ -58,7 +66,7 @@ def main(): model_type = "unigram" model_prefix = f"{lang_dir}/{model_type}_{vocab_size}" - train_text = f"{lang_dir}/train.txt" + train_text = args.transcript character_coverage = 1.0 input_sentence_size = 100000000 diff --git a/egs/librispeech/ASR/prepare.sh b/egs/librispeech/ASR/prepare.sh index 564f0d067b..1965dc491c 100755 --- a/egs/librispeech/ASR/prepare.sh +++ b/egs/librispeech/ASR/prepare.sh @@ -40,9 +40,9 @@ dl_dir=$PWD/download # It will generate data/lang_bpe_xxx, # data/lang_bpe_yyy if the array contains xxx, yyy vocab_sizes=( - 5000 - 2000 - 1000 + # 5000 + # 2000 + # 1000 500 ) @@ -116,14 +116,15 @@ fi if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then log "Stage 5: Prepare phone based lang" - mkdir -p data/lang_phone + lang_dir=data/lang_phone + mkdir -p $lang_dir (echo '!SIL SIL'; echo ' SPN'; echo ' SPN'; ) | cat - $dl_dir/lm/librispeech-lexicon.txt | - sort | uniq > data/lang_phone/lexicon.txt + sort | uniq > $lang_dir/lexicon.txt - if [ ! -f data/lang_phone/L_disambig.pt ]; then - ./local/prepare_lang.py + if [ ! -f $lang_dir/L_disambig.pt ]; then + ./local/prepare_lang.py --lang-dir $lang_dir fi fi @@ -138,7 +139,7 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then # so that the two can share G.pt later. cp data/lang_phone/words.txt $lang_dir - if [ ! -f $lang_dir/train.txt ]; then + if [ ! -f $lang_dir/transcript_words.txt ]; then log "Generate data for BPE training" files=$( find "$dl_dir/LibriSpeech/train-clean-100" -name "*.trans.txt" @@ -147,12 +148,13 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then ) for f in ${files[@]}; do cat $f | cut -d " " -f 2- - done > $lang_dir/train.txt + done > $lang_dir/transcript_words.txt fi ./local/train_bpe_model.py \ --lang-dir $lang_dir \ - --vocab-size $vocab_size + --vocab-size $vocab_size \ + --transcript $lang_dir/transcript_words.txt if [ ! -f $lang_dir/L_disambig.pt ]; then ./local/prepare_lang_bpe.py --lang-dir $lang_dir @@ -166,18 +168,18 @@ if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then for vocab_size in ${vocab_sizes[@]}; do lang_dir=data/lang_bpe_${vocab_size} - if [ ! -f $lang_dir/corpus.txt ]; then - ./local/convert_transcript_to_corpus.py \ - --lexicon data/lang_bpe/lexicon.txt \ - --transcript data/lang_bpe/train.txt \ + if [ ! -f $lang_dir/transcript_tokens.txt ]; then + ./local/convert_transcript_words_to_tokens.py \ + --lexicon $lang_dir/lexicon.txt \ + --transcript $lang_dir/transcript_words.txt \ --oov "" \ - > $lang_dir/corpus.txt + > $lang_dir/transcript_tokens.txt fi if [ ! -f $lang_dir/P.arpa ]; then ./shared/make_kn_lm.py \ -ngram-order 2 \ - -text $lang_dir/corpus.txt \ + -text $lang_dir/transcript_tokens.txt \ -lm $lang_dir/P.arpa fi @@ -226,4 +228,4 @@ if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then done fi -cd data && ln -sfv lang_bpe_5000 lang_bpe +cd data && ln -sfv lang_bpe_500 lang_bpe diff --git a/icefall/bpe_graph_compiler.py b/icefall/bpe_graph_compiler.py index 813b15f763..e76b7ea323 100644 --- a/icefall/bpe_graph_compiler.py +++ b/icefall/bpe_graph_compiler.py @@ -34,14 +34,10 @@ def __init__( """ Args: lang_dir: - This directory is expected to contain the following files:: + This directory is expected to contain the following files: - bpe.model - words.txt - - The above files are produced by the script `prepare.sh`. You - should have run that before running the training code. - device: It indicates CPU or CUDA. sos_token: diff --git a/icefall/bpe_mmi_graph_compiler.py b/icefall/bpe_mmi_graph_compiler.py deleted file mode 100644 index 83bc9846f9..0000000000 --- a/icefall/bpe_mmi_graph_compiler.py +++ /dev/null @@ -1,178 +0,0 @@ -import logging -from pathlib import Path -from typing import List, Tuple, Union - -import k2 -import sentencepiece as spm -import torch - -from icefall.lexicon import Lexicon - - -class BpeMmiTrainingGraphCompiler(object): - def __init__( - self, - lang_dir: Path, - device: Union[str, torch.device] = "cpu", - sos_token: str = "", - eos_token: str = "", - ) -> None: - """ - Args: - lang_dir: - Path to the lang directory. It is expected to contain the - following files:: - - - tokens.txt - - words.txt - - bpe.model - - P.fst.txt - - The above files are generated by the script `prepare.sh`. You - should have run it before running the training code. - - device: - It indicates CPU or CUDA. - sos_token: - The word piece that represents sos. - eos_token: - The word piece that represents eos. - """ - self.lang_dir = Path(lang_dir) - self.lexicon = Lexicon(lang_dir) - self.device = device - self.load_sentence_piece_model() - self.build_ctc_topo_P() - - self.sos_id = self.sp.piece_to_id(sos_token) - self.eos_id = self.sp.piece_to_id(eos_token) - - assert self.sos_id != self.sp.unk_id() - assert self.eos_id != self.sp.unk_id() - - def load_sentence_piece_model(self) -> None: - """Load the pre-trained sentencepiece model - from self.lang_dir/bpe.model. - """ - model_file = self.lang_dir / "bpe.model" - sp = spm.SentencePieceProcessor() - sp.load(str(model_file)) - self.sp = sp - - def build_ctc_topo_P(self): - """Built ctc_topo_P, the composition result of - ctc_topo and P, where P is a pre-trained bigram - word piece LM. - """ - # Note: there is no need to save a pre-compiled P and ctc_topo - # as it is very fast to generate them. - logging.info(f"Loading P from {self.lang_dir/'P.fst.txt'}") - with open(self.lang_dir / "P.fst.txt") as f: - # P is not an acceptor because there is - # a back-off state, whose incoming arcs - # have label #0 and aux_label 0 (i.e., ). - P = k2.Fsa.from_openfst(f.read(), acceptor=False) - - first_token_disambig_id = self.lexicon.token_table["#0"] - - # P.aux_labels is not needed in later computations, so - # remove it here. - del P.aux_labels - # CAUTION: The following line is crucial. - # Arcs entering the back-off state have label equal to #0. - # We have to change it to 0 here. - P.labels[P.labels >= first_token_disambig_id] = 0 - - P = k2.remove_epsilon(P) - P = k2.arc_sort(P) - P = P.to(self.device) - # Add epsilon self-loops to P because we want the - # following operation "k2.intersect" to run on GPU. - P_with_self_loops = k2.add_epsilon_self_loops(P) - - max_token_id = max(self.lexicon.tokens) - logging.info( - f"Building modified ctc_topo. max_token_id: {max_token_id}" - ) - # CAUTION: We have to use a modifed version of CTC topo. - # Otherwise, the resulting ctc_topo_P is so large that it gets - # stuck in k2.intersect_dense_pruned() or it gets OOM in - # k2.intersect_dense() - ctc_topo = k2.ctc_topo(max_token_id, modified=True, device=self.device) - - ctc_topo_inv = k2.arc_sort(ctc_topo.invert_()) - - logging.info("Building ctc_topo_P") - ctc_topo_P = k2.intersect( - ctc_topo_inv, P_with_self_loops, treat_epsilons_specially=False - ).invert() - - self.ctc_topo_P = k2.arc_sort(ctc_topo_P) - - def texts_to_ids(self, texts: List[str]) -> List[List[int]]: - """Convert a list of texts to a list-of-list of piece IDs. - - Args: - texts: - A list of transcripts. Within a transcript words are - separated by spaces. An example input is:: - - ['HELLO ICEFALL', 'HELLO k2'] - Returns: - Return a list-of-list of piece IDs. - """ - return self.sp.encode(texts, out_type=int) - - def compile( - self, texts: List[str], replicate_den: bool = True - ) -> Tuple[k2.Fsa, k2.Fsa]: - """Create numerator and denominator graphs from transcripts. - - Args: - texts: - A list of transcripts. Within a transcript words are - separated by spaces. An example input is:: - - ["HELLO icefall", "HALLO WELT"] - - replicate_den: - If True, the returned den_graph is replicated to match the number - of FSAs in the returned num_graph; if False, the returned den_graph - contains only a single FSA - Returns: - A tuple (num_graphs, den_graphs), where - - - `num_graphs` is the numerator graph. It is an FsaVec with - shape `(len(texts), None, None)`. - - - `den_graphs` is the denominator graph. It is an FsaVec with the - same shape of the `num_graph` if replicate_den is True; - otherwise, it is an FsaVec containing only a single FSA. - """ - token_ids = self.texts_to_ids(texts) - token_fsas = k2.linear_fsa(token_ids, device=self.device) - - token_fsas_with_self_loops = k2.add_epsilon_self_loops(token_fsas) - - # NOTE: Use treat_epsilons_specially=False so that k2.compose - # can be run on GPU - num_graphs = k2.compose( - self.ctc_topo_P, - token_fsas_with_self_loops, - treat_epsilons_specially=False, - ) - # num_graphs may not be connected and - # not be topologically sorted after k2.compose - num_graphs = k2.connect(num_graphs) - num_graphs = k2.top_sort(num_graphs) - - ctc_topo_P_vec = k2.create_fsa_vec([self.ctc_topo_P.detach()]) - if replicate_den: - indexes = torch.zeros( - len(texts), dtype=torch.int32, device=self.device - ) - den_graphs = k2.index_fsa(ctc_topo_P_vec, indexes) - else: - den_graphs = ctc_topo_P_vec - - return num_graphs, den_graphs diff --git a/icefall/lexicon.py b/icefall/lexicon.py index 1378d79fb8..80bd7c1ee8 100644 --- a/icefall/lexicon.py +++ b/icefall/lexicon.py @@ -84,6 +84,69 @@ def write_lexicon(filename: str, lexicon: List[Tuple[str, List[str]]]) -> None: f.write(f"{word} {' '.join(tokens)}\n") +def convert_lexicon_to_ragged( + filename: str, word_table: k2.SymbolTable, token_table: k2.SymbolTable +) -> k2.RaggedTensor: + """Read a lexicon and convert it to a ragged tensor. + + The ragged tensor has two axes: [word][token]. + + Caution: + We assume that each word has a unique pronunciation. + + Args: + filename: + Filename of the lexicon. It has a format that can be read + by :func:`read_lexicon`. + word_table: + The word symbol table. + token_table: + The token symbol table. + Returns: + A k2 ragged tensor with two axes [word][token]. + """ + disambig_id = word_table["#0"] + # We reuse the same words.txt from the phone based lexicon + # so that we can share the same G.fst. Here, we have to + # exclude some words present only in the phone based lexicon. + excluded_words = ["", "!SIL", ""] + + # epsilon is not a word, but it occupies a position + # + row_splits = [0] + token_ids_list = [] + + lexicon_tmp = read_lexicon(filename) + lexicon = dict(lexicon_tmp) + if len(lexicon_tmp) != len(lexicon): + raise RuntimeError( + "It's assumed that each word has a unique pronunciation" + ) + + for i in range(disambig_id): + w = word_table[i] + if w in excluded_words: + row_splits.append(row_splits[-1]) + continue + tokens = lexicon[w] + token_ids = [token_table[k] for k in tokens] + + row_splits.append(row_splits[-1] + len(token_ids)) + token_ids_list.extend(token_ids) + + cached_tot_size = row_splits[-1] + row_splits = torch.tensor(row_splits, dtype=torch.int32) + + shape = k2.ragged.create_ragged_shape2( + row_splits, + None, + cached_tot_size, + ) + values = torch.tensor(token_ids_list, dtype=torch.int32) + + return k2.RaggedTensor(shape, values) + + class Lexicon(object): """Phone based lexicon.""" @@ -96,12 +159,10 @@ def __init__( Args: lang_dir: Path to the lang directory. It is expected to contain the following - files:: - + files: - tokens.txt - words.txt - L.pt - The above files are produced by the script `prepare.sh`. You should have run that before running the training code. disambig_pattern: @@ -121,7 +182,7 @@ def __init__( torch.save(L_inv.as_dict(), lang_dir / "Linv.pt") # We save L_inv instead of L because it will be used to intersect with - # transcript, both of whose labels are word IDs. + # transcript FSAs, both of whose labels are word IDs. self.L_inv = L_inv self.disambig_pattern = disambig_pattern @@ -144,69 +205,66 @@ def tokens(self) -> List[int]: return ans -class BpeLexicon(Lexicon): +class UniqLexicon(Lexicon): def __init__( self, lang_dir: Path, + uniq_filename: str = "uniq_lexicon.txt", disambig_pattern: str = re.compile(r"^#\d+$"), ): """ Refer to the help information in Lexicon.__init__. + + uniq_filename: It is assumed to be inside the given `lang_dir`. + + Each word in the lexicon is assumed to have a unique pronunciation. """ + lang_dir = Path(lang_dir) super().__init__(lang_dir=lang_dir, disambig_pattern=disambig_pattern) - self.ragged_lexicon = self.convert_lexicon_to_ragged( - lang_dir / "lexicon.txt" + self.ragged_lexicon = convert_lexicon_to_ragged( + filename=lang_dir / uniq_filename, + word_table=self.word_table, + token_table=self.token_table, ) + # TODO: should we move it to a certain device ? - def convert_lexicon_to_ragged(self, filename: str) -> k2.RaggedTensor: - """Read a BPE lexicon from file and convert it to a - k2 ragged tensor. - + def texts_to_token_ids( + self, texts: List[str], oov: str = "" + ) -> k2.RaggedTensor: + """ Args: - filename: - Filename of the BPE lexicon, e.g., data/lang/bpe/lexicon.txt + texts: + A list of transcripts. Each transcript contains space(s) + separated words. An example texts is:: + + ['HELLO k2', 'HELLO icefall'] + oov: + The OOV word. If a word in `texts` is not in the lexicon, it is + replaced with `oov`. Returns: - A k2 ragged tensor with two axes [word_id] + Return a ragged int tensor with 2 axes [utterance][token_id] """ - disambig_id = self.word_table["#0"] - # We reuse the same words.txt from the phone based lexicon - # so that we can share the same G.fst. Here, we have to - # exclude some words present only in the phone based lexicon. - excluded_words = ["", "!SIL", ""] - - # epsilon is not a word, but it occupies on position - # - row_splits = [0] - token_ids = [] - - lexicon = read_lexicon(filename) - lexicon = dict(lexicon) - - for i in range(disambig_id): - w = self.word_table[i] - if w in excluded_words: - row_splits.append(row_splits[-1]) - continue - pieces = lexicon[w] - piece_ids = [self.token_table[k] for k in pieces] - - row_splits.append(row_splits[-1] + len(piece_ids)) - token_ids.extend(piece_ids) - - cached_tot_size = row_splits[-1] - row_splits = torch.tensor(row_splits, dtype=torch.int32) - - shape = k2.ragged.create_ragged_shape2( - row_splits=row_splits, cached_tot_size=cached_tot_size - ) - values = torch.tensor(token_ids, dtype=torch.int32) + oov_id = self.word_table[oov] + + word_ids_list = [] + for text in texts: + word_ids = [] + for word in text.split(): + if word in self.word_table: + word_ids.append(self.word_table[word]) + else: + word_ids.append(oov_id) + word_ids_list.append(word_ids) + ragged_indexes = k2.RaggedTensor(word_ids_list, dtype=torch.int32) + ans = self.ragged_lexicon.index(ragged_indexes) + ans = ans.remove_axis(ans.num_axes - 2) + return ans - return k2.RaggedTensor(shape, values) + def words_to_token_ids(self, words: List[str]) -> k2.RaggedTensor: + """Convert a list of words to a ragged tensor containing token IDs. - def words_to_piece_ids(self, words: List[str]) -> k2.RaggedTensor: - """Convert a list of words to a ragged tensor contained - word piece IDs. + We assume there are no OOVs in "words". """ word_ids = [self.word_table[w] for w in words] word_ids = torch.tensor(word_ids, dtype=torch.int32) diff --git a/icefall/mmi_graph_compiler.py b/icefall/mmi_graph_compiler.py new file mode 100644 index 0000000000..43f2a092a1 --- /dev/null +++ b/icefall/mmi_graph_compiler.py @@ -0,0 +1,216 @@ +import logging +from pathlib import Path +from typing import Iterable, List, Tuple, Union + +import k2 +import torch + +from icefall.lexicon import UniqLexicon + + +class MmiTrainingGraphCompiler(object): + def __init__( + self, + lang_dir: Path, + uniq_filename: str = "uniq_lexicon.txt", + device: Union[str, torch.device] = "cpu", + oov: str = "", + ): + """ + Args: + lang_dir: + Path to the lang directory. It is expected to contain the + following files:: + + - tokens.txt + - words.txt + - P.fst.txt + + The above files are generated by the script `prepare.sh`. You + should have run it before running the training code. + uniq_filename: + File name to the lexicon in which every word has exactly one + pronunciation. We assume this file is inside the given `lang_dir`. + + device: + It indicates CPU or CUDA. + oov: + Out of vocabulary word. When a word in the transcript + does not exist in the lexicon, it is replaced with `oov`. + """ + self.lang_dir = Path(lang_dir) + self.lexicon = UniqLexicon(lang_dir, uniq_filename=uniq_filename) + self.device = torch.device(device) + + self.L_inv = self.lexicon.L_inv.to(self.device) + + self.oov_id = self.lexicon.word_table[oov] + + self.build_ctc_topo_P() + + def build_ctc_topo_P(self): + """Built ctc_topo_P, the composition result of + ctc_topo and P, where P is a pre-trained bigram + word piece LM. + """ + # Note: there is no need to save a pre-compiled P and ctc_topo + # as it is very fast to generate them. + logging.info(f"Loading P from {self.lang_dir/'P.fst.txt'}") + with open(self.lang_dir / "P.fst.txt") as f: + # P is not an acceptor because there is + # a back-off state, whose incoming arcs + # have label #0 and aux_label 0 (i.e., ). + P = k2.Fsa.from_openfst(f.read(), acceptor=False) + + first_token_disambig_id = self.lexicon.token_table["#0"] + + # P.aux_labels is not needed in later computations, so + # remove it here. + del P.aux_labels + # CAUTION: The following line is crucial. + # Arcs entering the back-off state have label equal to #0. + # We have to change it to 0 here. + P.labels[P.labels >= first_token_disambig_id] = 0 + + P = k2.remove_epsilon(P) + P = k2.arc_sort(P) + P = P.to(self.device) + # Add epsilon self-loops to P because we want the + # following operation "k2.intersect" to run on GPU. + P_with_self_loops = k2.add_epsilon_self_loops(P) + + max_token_id = max(self.lexicon.tokens) + logging.info( + f"Building ctc_topo (modified=False). max_token_id: {max_token_id}" + ) + ctc_topo = k2.ctc_topo(max_token_id, modified=False, device=self.device) + + ctc_topo_inv = k2.arc_sort(ctc_topo.invert_()) + + logging.info("Building ctc_topo_P") + ctc_topo_P = k2.intersect( + ctc_topo_inv, P_with_self_loops, treat_epsilons_specially=False + ).invert() + + self.ctc_topo_P = k2.arc_sort(ctc_topo_P) + + def compile( + self, texts: Iterable[str], replicate_den: bool = True + ) -> Tuple[k2.Fsa, k2.Fsa]: + """Create numerator and denominator graphs from transcripts + and the bigram phone LM. + + Args: + texts: + A list of transcripts. Within a transcript, words are + separated by spaces. An example `texts` is given below:: + + ["Hello icefall", "LF-MMI training with icefall using k2"] + + replicate_den: + If True, the returned den_graph is replicated to match the number + of FSAs in the returned num_graph; if False, the returned den_graph + contains only a single FSA + Returns: + A tuple (num_graph, den_graph), where + + - `num_graph` is the numerator graph. It is an FsaVec with + shape `(len(texts), None, None)`. + + - `den_graph` is the denominator graph. It is an FsaVec + with the same shape of the `num_graph` if replicate_den is + True; otherwise, it is an FsaVec containing only a single FSA. + """ + transcript_fsa = self.build_transcript_fsa(texts) + + # remove word IDs from transcript_fsa since it is not needed + del transcript_fsa.aux_labels + # NOTE: You can comment out the above statement + # if you want to run test/test_mmi_graph_compiler.py + + transcript_fsa_with_self_loops = k2.remove_epsilon_and_add_self_loops( + transcript_fsa + ) + + transcript_fsa_with_self_loops = k2.arc_sort( + transcript_fsa_with_self_loops + ) + + num = k2.compose( + self.ctc_topo_P, + transcript_fsa_with_self_loops, + treat_epsilons_specially=False, + ) + + # CAUTION: Due to the presence of P, + # the resulting `num` may not be connected + num = k2.connect(num) + + num = k2.arc_sort(num) + + ctc_topo_P_vec = k2.create_fsa_vec([self.ctc_topo_P]) + if replicate_den: + indexes = torch.zeros( + len(texts), dtype=torch.int32, device=self.device + ) + den = k2.index_fsa(ctc_topo_P_vec, indexes) + else: + den = ctc_topo_P_vec + + return num, den + + def build_transcript_fsa(self, texts: List[str]) -> k2.Fsa: + """Convert transcripts to an FsaVec with the help of a lexicon + and word symbol table. + + Args: + texts: + Each element is a transcript containing words separated by space(s). + For instance, it may be 'HELLO icefall', which contains + two words. + + Returns: + Return an FST (FsaVec) corresponding to the transcript. + Its `labels` is token IDs and `aux_labels` is word IDs. + """ + word_ids_list = [] + for text in texts: + word_ids = [] + for word in text.split(): + if word in self.lexicon.word_table: + word_ids.append(self.lexicon.word_table[word]) + else: + word_ids.append(self.oov_id) + word_ids_list.append(word_ids) + + fsa = k2.linear_fsa(word_ids_list, self.device) + fsa = k2.add_epsilon_self_loops(fsa) + + # The reason to use `invert_()` at the end is as follows: + # + # (1) The `labels` of L_inv is word IDs and `aux_labels` is token IDs + # (2) `fsa.labels` is word IDs + # (3) after intersection, the `labels` is still word IDs + # (4) after `invert_()`, the `labels` is token IDs + # and `aux_labels` is word IDs + transcript_fsa = k2.intersect( + self.L_inv, fsa, treat_epsilons_specially=False + ).invert_() + transcript_fsa = k2.arc_sort(transcript_fsa) + return transcript_fsa + + def texts_to_ids(self, texts: List[str]) -> List[List[int]]: + """Convert a list of texts to a list-of-list of piece IDs. + + Args: + texts: + It is a list of strings. Each string consists of space(s) + separated words. An example containing two strings is given below: + + ['HELLO ICEFALL', 'HELLO k2'] + We assume it contains no OOVs. Otherwise, it will raise an + exception. + Returns: + Return a list-of-list of token IDs. + """ + return self.lexicon.texts_to_token_ids(texts).tolist() diff --git a/icefall/utils.py b/icefall/utils.py index 23b4dd6c76..1c4dceb0be 100644 --- a/icefall/utils.py +++ b/icefall/utils.py @@ -19,14 +19,16 @@ import logging import os import subprocess +import sys from collections import defaultdict from contextlib import contextmanager from datetime import datetime from pathlib import Path -from typing import Dict, Iterable, List, TextIO, Tuple, Union +from typing import Any, Dict, Iterable, List, TextIO, Tuple, Union import k2 import kaldialign +import lhotse import torch import torch.distributed as dist @@ -132,17 +134,82 @@ def setup_logger( logging.getLogger("").addHandler(console) -def get_env_info(): - """ - TODO: - """ +def get_git_sha1(): + git_commit = ( + subprocess.run( + ["git", "rev-parse", "--short", "HEAD"], + check=True, + stdout=subprocess.PIPE, + ) + .stdout.decode() + .rstrip("\n") + .strip() + ) + dirty_commit = ( + len( + subprocess.run( + ["git", "diff", "--shortstat"], + check=True, + stdout=subprocess.PIPE, + ) + .stdout.decode() + .rstrip("\n") + .strip() + ) + > 0 + ) + git_commit = ( + git_commit + "-dirty" if dirty_commit else git_commit + "-clean" + ) + return git_commit + + +def get_git_date(): + git_date = ( + subprocess.run( + ["git", "log", "-1", "--format=%ad", "--date=local"], + check=True, + stdout=subprocess.PIPE, + ) + .stdout.decode() + .rstrip("\n") + .strip() + ) + return git_date + + +def get_git_branch_name(): + git_date = ( + subprocess.run( + ["git", "rev-parse", "--abbrev-ref", "HEAD"], + check=True, + stdout=subprocess.PIPE, + ) + .stdout.decode() + .rstrip("\n") + .strip() + ) + return git_date + + +def get_env_info() -> Dict[str, Any]: + """Get the environment information.""" return { - "k2-git-sha1": None, - "k2-version": None, - "lhotse-version": None, - "torch-version": None, - "icefall-sha1": None, - "icefall-version": None, + "k2-version": k2.version.__version__, + "k2-build-type": k2.version.__build_type__, + "k2-with-cuda": k2.with_cuda, + "k2-git-sha1": k2.version.__git_sha1__, + "k2-git-date": k2.version.__git_date__, + "lhotse-version": lhotse.__version__, + "torch-cuda-available": torch.cuda.is_available(), + "torch-cuda-version": torch.version.cuda, + "python-version": sys.version[:3], + "icefall-git-branch": get_git_branch_name(), + "icefall-git-sha1": get_git_sha1(), + "icefall-git-date": get_git_date(), + "icefall-path": str(Path(__file__).resolve().parent.parent), + "k2-path": str(Path(k2.__file__).resolve()), + "lhotse-path": str(Path(lhotse.__file__).resolve()), } diff --git a/test/test_bpe_graph_compiler.py b/test/test_bpe_graph_compiler.py index e58c4f1c63..6c9073c4cf 100755 --- a/test/test_bpe_graph_compiler.py +++ b/test/test_bpe_graph_compiler.py @@ -19,20 +19,21 @@ from pathlib import Path from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler -from icefall.lexicon import BpeLexicon +from icefall.lexicon import UniqLexicon + +ICEFALL_DIR = Path(__file__).resolve().parent.parent def test(): - lang_dir = Path("data/lang/bpe") + lang_dir = ICEFALL_DIR / "egs/librispeech/ASR/data/lang_bpe" if not lang_dir.is_dir(): return - # TODO: generate data for testing compiler = BpeCtcTrainingGraphCompiler(lang_dir) ids = compiler.texts_to_ids(["HELLO", "WORLD ZZZ"]) compiler.compile(ids) - lexicon = BpeLexicon(lang_dir) + lexicon = UniqLexicon(lang_dir, uniq_filename="lexicon.txt") ids0 = lexicon.words_to_piece_ids(["HELLO"]) assert ids[0] == ids0.values().tolist() diff --git a/test/test_bpe_mmi_graph_compiler.py b/test/test_bpe_mmi_graph_compiler.py deleted file mode 100644 index c6009d69b3..0000000000 --- a/test/test_bpe_mmi_graph_compiler.py +++ /dev/null @@ -1,30 +0,0 @@ -#!/usr/bin/env python3 - -import copy -import logging -from pathlib import Path - -import k2 -import torch - -from icefall.bpe_mmi_graph_compiler import BpeMmiTrainingGraphCompiler - - -def test_bpe_mmi_graph_compiler(): - lang_dir = Path("data/lang_bpe") - if lang_dir.is_dir() is False: - return - device = torch.device("cpu") - compiler = BpeMmiTrainingGraphCompiler(lang_dir, device=device) - - texts = ["HELLO WORLD", "MMI TRAINING"] - - num_graphs, den_graphs = compiler.compile(texts) - num_graphs.labels_sym = compiler.lexicon.token_table - num_graphs.aux_labels_sym = copy.deepcopy(compiler.lexicon.token_table) - num_graphs.aux_labels_sym._id2sym[0] = "" - num_graphs[0].draw("num_graphs_0.svg", title="HELLO WORLD") - num_graphs[1].draw("num_graphs_1.svg", title="HELLO WORLD") - print(den_graphs.shape) - print(den_graphs[0].shape) - print(den_graphs[0].num_arcs) diff --git a/test/test_lexicon.py b/test/test_lexicon.py old mode 100644 new mode 100755 index 6801b3a89a..2a16db2260 --- a/test/test_lexicon.py +++ b/test/test_lexicon.py @@ -14,80 +14,135 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +""" +You can run this file in one of the two ways: + (1) cd icefall; pytest test/test_lexicon.py + (2) cd icefall; ./test/test_lexicon.py +""" + +import os +import shutil +import sys from pathlib import Path import k2 -import pytest -import torch - -from icefall.lexicon import BpeLexicon, Lexicon - - -@pytest.fixture -def lang_dir(tmp_path): - phone2id = """ - 0 - a 1 - b 2 - f 3 - o 4 - r 5 - z 6 - SPN 7 - #0 8 - """ - word2id = """ - 0 - foo 1 - bar 2 - baz 3 - 4 - #0 5 +import sentencepiece as spm + +from icefall.lexicon import UniqLexicon + +TMP_DIR = "/tmp/icefall-test-lexicon" +USING_PYTEST = "pytest" in sys.modules +ICEFALL_DIR = Path(__file__).resolve().parent.parent + + +def generate_test_data(): + Path(TMP_DIR).mkdir(exist_ok=True) + sentences = """ +cat tac cat cat +at +tac at ta at at +at cat ct ct ta +cat cat cat cat +at at at at at at at """ - L = k2.Fsa.from_str( - """ - 0 0 7 4 0 - 0 7 -1 -1 0 - 0 1 3 1 0 - 0 3 2 2 0 - 0 5 2 3 0 - 1 2 4 0 0 - 2 0 4 0 0 - 3 4 1 0 0 - 4 0 5 0 0 - 5 6 1 0 0 - 6 0 6 0 0 - 7 - """, - num_aux_labels=1, + transcript = Path(TMP_DIR) / "transcript_words.txt" + with open(transcript, "w") as f: + for line in sentences.strip().split("\n"): + f.write(f"{line}\n") + + words = """ + 0 + 1 +at 2 +cat 3 +ct 4 +ta 5 +tac 6 +#0 7 + 8 + 9 +""" + word_txt = Path(TMP_DIR) / "words.txt" + with open(word_txt, "w") as f: + for line in words.strip().split("\n"): + f.write(f"{line}\n") + + vocab_size = 8 + + os.system( + f""" +cd {ICEFALL_DIR}/egs/librispeech/ASR + +./local/train_bpe_model.py \ + --lang-dir {TMP_DIR} \ + --vocab-size {vocab_size} \ + --transcript {transcript} + +./local/prepare_lang_bpe.py --lang-dir {TMP_DIR} --debug 1 +""" ) - with open(tmp_path / "tokens.txt", "w") as f: - f.write(phone2id) - with open(tmp_path / "words.txt", "w") as f: - f.write(word2id) - torch.save(L.as_dict(), tmp_path / "L.pt") +def delete_test_data(): + shutil.rmtree(TMP_DIR) + + +def uniq_lexicon_test(): + lexicon = UniqLexicon(lang_dir=TMP_DIR, uniq_filename="lexicon.txt") + + # case 1: No OOV + texts = ["cat cat", "at ct", "at tac cat"] + token_ids = lexicon.texts_to_token_ids(texts) + + sp = spm.SentencePieceProcessor() + sp.load(f"{TMP_DIR}/bpe.model") + + expected_token_ids: List[List[int]] = sp.encode(texts, out_type=int) + assert token_ids.tolist() == expected_token_ids + + # case 2: With OOV + texts = ["ca"] + token_ids = lexicon.texts_to_token_ids(texts) + expected_token_ids = sp.encode(texts, out_type=int) + assert token_ids.tolist() != expected_token_ids + # Note: sentencepiece breaks "ca" into "_ c a" + # But there is no word "ca" in the lexicon, so our + # implementation returns the id of "" + print(token_ids, expected_token_ids) + assert token_ids.tolist() == [[sp.unk_id()]] + + # case 3: With OOV + texts = ["foo"] + token_ids = lexicon.texts_to_token_ids(texts) + expected_token_ids = sp.encode(texts, out_type=int) + print(token_ids) + print(expected_token_ids) + + # test ragged lexicon + ragged_lexicon = lexicon.ragged_lexicon.tolist() + word_disambig_id = lexicon.word_table["#0"] + for i in range(2, word_disambig_id): + piece_id = ragged_lexicon[i] + word = lexicon.word_table[i] + assert word == sp.decode(piece_id) + assert piece_id == sp.encode(word) + + +def test_main(): + generate_test_data() - return tmp_path + uniq_lexicon_test() + if USING_PYTEST: + delete_test_data() -def test_lexicon(lang_dir): - lexicon = Lexicon(lang_dir) - assert lexicon.tokens == list(range(1, 8)) +def main(): + test_main() -def test_bpe_lexicon(): - lang_dir = Path("data/lang/bpe") - if not lang_dir.is_dir(): - return - # TODO: Generate test data for BpeLexicon - lexicon = BpeLexicon(lang_dir) - words = ["", "HELLO", "ZZZZ", "WORLD"] - ids = lexicon.words_to_piece_ids(words) - print(ids) - print([lexicon.token_table[i] for i in ids.values().tolist()]) +if __name__ == "__main__" and not USING_PYTEST: + main() diff --git a/test/test_mmi_graph_compiler.py b/test/test_mmi_graph_compiler.py new file mode 100755 index 0000000000..653c57b591 --- /dev/null +++ b/test/test_mmi_graph_compiler.py @@ -0,0 +1,196 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +You can run this file in one of the two ways: + + (1) cd icefall; pytest test/test_mmi_graph_compiler.py + (2) cd icefall; ./test/test_mmi_graph_compiler.py +""" + +import copy +import os +import shutil +import sys +from pathlib import Path + +import k2 +import sentencepiece as spm + +from icefall.mmi_graph_compiler import MmiTrainingGraphCompiler + +TMP_DIR = "/tmp/icefall-test-mmi-graph-compiler" +USING_PYTEST = "pytest" in sys.modules +ICEFALL_DIR = Path(__file__).resolve().parent.parent + + +def generate_test_data(): + Path(TMP_DIR).mkdir(exist_ok=True) + sentences = """ +cat tac cat cat +at at cat at cat cat +tac at ta at at +at cat ct ct ta ct ct cat tac +cat cat cat cat +at at at at at at at + """ + + transcript = Path(TMP_DIR) / "transcript_words.txt" + with open(transcript, "w") as f: + for line in sentences.strip().split("\n"): + f.write(f"{line}\n") + + words = """ + 0 + 1 +at 2 +cat 3 +ct 4 +ta 5 +tac 6 +#0 7 + 8 + 9 +""" + word_txt = Path(TMP_DIR) / "words.txt" + with open(word_txt, "w") as f: + for line in words.strip().split("\n"): + f.write(f"{line}\n") + + vocab_size = 8 + + os.system( + f""" +cd {ICEFALL_DIR}/egs/librispeech/ASR + +./local/train_bpe_model.py \ + --lang-dir {TMP_DIR} \ + --vocab-size {vocab_size} \ + --transcript {transcript} + +./local/prepare_lang_bpe.py --lang-dir {TMP_DIR} --debug 0 + +./local/convert_transcript_words_to_tokens.py \ +--lexicon {TMP_DIR}/lexicon.txt \ +--transcript {transcript} \ +--oov "" \ +> {TMP_DIR}/transcript_tokens.txt + +./shared/make_kn_lm.py \ +-ngram-order 2 \ +-text {TMP_DIR}/transcript_tokens.txt \ +-lm {TMP_DIR}/P.arpa + +python3 -m kaldilm \ +--read-symbol-table="{TMP_DIR}/tokens.txt" \ +--disambig-symbol='#0' \ +--max-order=2 \ +{TMP_DIR}/P.arpa > {TMP_DIR}/P.fst.txt +""" + ) + + +def delete_test_data(): + shutil.rmtree(TMP_DIR) + + +def mmi_graph_compiler_test(): + # Caution: + # You have to uncomment + # del transcript_fsa.aux_labels + # in mmi_graph_compiler.py + # to see the correct aux_labels in *.svg + graph_compiler = MmiTrainingGraphCompiler( + lang_dir=TMP_DIR, uniq_filename="lexicon.txt" + ) + print(graph_compiler.device) + L_inv = graph_compiler.L_inv + L = k2.invert(L_inv) + + L.labels_sym = graph_compiler.lexicon.token_table + L.aux_labels_sym = graph_compiler.lexicon.word_table + L.draw(f"{TMP_DIR}/L.svg", title="L") + + L_inv.labels_sym = graph_compiler.lexicon.word_table + L_inv.aux_labels_sym = graph_compiler.lexicon.token_table + L_inv.draw(f"{TMP_DIR}/L_inv.svg", title="L") + + ctc_topo_P = graph_compiler.ctc_topo_P + ctc_topo_P.labels_sym = copy.deepcopy(graph_compiler.lexicon.token_table) + ctc_topo_P.labels_sym._id2sym[0] = "" + ctc_topo_P.labels_sym._sym2id[""] = 0 + ctc_topo_P.aux_labels_sym = graph_compiler.lexicon.token_table + ctc_topo_P.draw(f"{TMP_DIR}/ctc_topo_P.svg", title="ctc_topo_P") + + print(ctc_topo_P.num_arcs) + print(k2.connect(ctc_topo_P).num_arcs) + + with open(str(TMP_DIR) + "/P.fst.txt") as f: + # P is not an acceptor because there is + # a back-off state, whose incoming arcs + # have label #0 and aux_label 0 (i.e., ). + P = k2.Fsa.from_openfst(f.read(), acceptor=False) + P.labels_sym = graph_compiler.lexicon.token_table + P.aux_labels_sym = graph_compiler.lexicon.token_table + P.draw(f"{TMP_DIR}/P.svg", title="P") + + ctc_topo = k2.ctc_topo(max(graph_compiler.lexicon.tokens), False) + ctc_topo.labels_sym = ctc_topo_P.labels_sym + ctc_topo.aux_labels_sym = graph_compiler.lexicon.token_table + ctc_topo.draw(f"{TMP_DIR}/ctc_topo.svg", title="ctc_topo") + print("p num arcs", P.num_arcs) + print("ctc_topo num arcs", ctc_topo.num_arcs) + print("ctc_topo_P num arcs", ctc_topo_P.num_arcs) + + texts = ["cat at ct", "at ta", "cat tac"] + transcript_fsa = graph_compiler.build_transcript_fsa(texts) + transcript_fsa[0].draw(f"{TMP_DIR}/cat_at_ct.svg", title="cat_at_ct") + transcript_fsa[1].draw(f"{TMP_DIR}/at_ta.svg", title="at_ta") + transcript_fsa[2].draw(f"{TMP_DIR}/cat_tac.svg", title="cat_tac") + + num_graphs, den_graphs = graph_compiler.compile(texts, replicate_den=True) + num_graphs[0].draw(f"{TMP_DIR}/num_cat_at_ct.svg", title="num_cat_at_ct") + num_graphs[1].draw(f"{TMP_DIR}/num_at_ta.svg", title="num_at_ta") + num_graphs[2].draw(f"{TMP_DIR}/num_cat_tac.svg", title="num_cat_tac") + + den_graphs[0].draw(f"{TMP_DIR}/den_cat_at_ct.svg", title="den_cat_at_ct") + den_graphs[2].draw(f"{TMP_DIR}/den_cat_tac.svg", title="den_cat_tac") + + sp = spm.SentencePieceProcessor() + sp.load(f"{TMP_DIR}/bpe.model") + + texts = ["cat at cat", "at tac"] + token_ids = graph_compiler.texts_to_ids(texts) + expected_token_ids = sp.encode(texts) + assert token_ids == expected_token_ids + + +def test_main(): + generate_test_data() + + mmi_graph_compiler_test() + + if USING_PYTEST: + delete_test_data() + + +def main(): + test_main() + + +if __name__ == "__main__" and not USING_PYTEST: + main() From 9e6bd0f07c29b467618cde424fedefe81faf3e6f Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Fri, 24 Sep 2021 20:34:34 +0800 Subject: [PATCH 11/19] Minor fixes. --- egs/librispeech/ASR/conformer_ctc/train.py | 2 +- egs/librispeech/ASR/conformer_mmi/train.py | 245 +++++++++------------ egs/librispeech/ASR/prepare.sh | 2 - icefall/mmi.py | 10 +- icefall/mmi_graph_compiler.py | 5 + 5 files changed, 117 insertions(+), 147 deletions(-) diff --git a/egs/librispeech/ASR/conformer_ctc/train.py b/egs/librispeech/ASR/conformer_ctc/train.py index 80b2d924a7..8c1fc9595b 100755 --- a/egs/librispeech/ASR/conformer_ctc/train.py +++ b/egs/librispeech/ASR/conformer_ctc/train.py @@ -99,7 +99,7 @@ def get_params() -> AttributeDict: """Return a dict containing training parameters. All training related parameters that are not passed from the commandline - is saved in the variable `params`. + are saved in the variable `params`. Commandline options are merged into `params` after they are parsed, so you can also access them via `params`. diff --git a/egs/librispeech/ASR/conformer_mmi/train.py b/egs/librispeech/ASR/conformer_mmi/train.py index f11291bbf8..6decbc1891 100755 --- a/egs/librispeech/ASR/conformer_mmi/train.py +++ b/egs/librispeech/ASR/conformer_mmi/train.py @@ -1,4 +1,21 @@ #!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import argparse import logging @@ -11,21 +28,20 @@ import torch.distributed as dist import torch.multiprocessing as mp import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule from conformer import Conformer from lhotse.utils import fix_random_seed -from tdnn_lstm_ctc.model import TdnnLstm from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.utils import clip_grad_norm_ from torch.utils.tensorboard import SummaryWriter from transformer import Noam -from icefall.bpe_mmi_graph_compiler import BpeMmiTrainingGraphCompiler from icefall.checkpoint import load_checkpoint from icefall.checkpoint import save_checkpoint as save_checkpoint_impl -from icefall.dataset.librispeech import LibriSpeechAsrDataModule from icefall.dist import cleanup_dist, setup_dist from icefall.lexicon import Lexicon from icefall.mmi import LFMMILoss +from icefall.mmi_graph_compiler import MmiTrainingGraphCompiler from icefall.utils import ( AttributeDict, encode_supervisions, @@ -61,28 +77,22 @@ def get_parser(): ) parser.add_argument( - "--use-ali-model", - type=str2bool, - default=True, - help="If true, we assume that you have run tdnn_lstm_ctc/train_bpe.py " - "and you have some checkpoints inside the directory " - "tdnn_lstm_ctc/exp_bpe_500 ." - "It will use tdnn_lstm_ctc/exp_bpe_500/epoch-{ali-model-epoch}.pt " - "as the pre-trained alignment model", + "--num-epochs", + type=int, + default=50, + help="Number of epochs to train.", ) + parser.add_argument( - "--ali-model-epoch", + "--start-epoch", type=int, - default=19, - help="If --use-ali-model is True, load " - "tdnn_lstm_ctc/exp_bpe_500/epoch-{ali-model-epoch}.pt as " - "the alignment model." - "Used only if --use-ali-model is True.", + default=0, + help="""Resume training from from this epoch. + If it is positive, it will load checkpoint from + conformer_mmi/exp/epoch-{start_epoch-1}.pt + """, ) - # TODO: add extra arguments and support DDP training. - # Currently, only single GPU training is implemented. Will add - # DDP training once single GPU training is finished. return parser @@ -90,7 +100,7 @@ def get_params() -> AttributeDict: """Return a dict containing training parameters. All training related parameters that are not passed from the commandline - is saved in the variable `params`. + are saved in the variable `params`. Commandline options are merged into `params` after they are parsed, so you can also access them via `params`. @@ -103,20 +113,6 @@ def get_params() -> AttributeDict: - lang_dir: It contains language related input files such as "lexicon.txt" - - lr: It specifies the initial learning rate - - - feature_dim: The model input dim. It has to match the one used - in computing features. - - - weight_decay: The weight_decay for the optimizer. - - - subsampling_factor: The subsampling factor for the model. - - - start_epoch: If it is not zero, load checkpoint `start_epoch-1` - and continue training from that checkpoint. - - - num_epochs: Number of epochs to train. - - best_train_loss: Best training loss so far. It is used to select the model that has the lowest training loss. It is updated during the training. @@ -135,36 +131,60 @@ def get_params() -> AttributeDict: - log_interval: Print training loss if batch_idx % log_interval` is 0 - - valid_interval: Run validation if batch_idx % valid_interval` is 0 + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - use_feat_batchnorm: Whether to do batch normalization for the + input features. + + - attention_dim: Hidden dim for multi-head attention model. + + - head: Number of heads of multi-head attention model. + + - num_decoder_layers: Number of decoder layer of transformer decoder. + + - weight_decay: The weight_decay for the optimizer. + + - lr_factor: The lr_factor for Noam optimizer. + + - warm_step: The warm_step for Noam optimizer. """ params = AttributeDict( { "exp_dir": Path("conformer_mmi/exp_500"), "lang_dir": Path("data/lang_bpe_500"), - "feature_dim": 80, - "weight_decay": 1e-6, - "subsampling_factor": 4, - "start_epoch": 0, - "num_epochs": 50, "best_train_loss": float("inf"), "best_valid_loss": float("inf"), "best_train_epoch": -1, "best_valid_epoch": -1, "batch_idx_train": 0, - "log_interval": 10, + "log_interval": 50, "reset_interval": 200, - "valid_interval": 10, - "use_pruned_intersect": False, - "den_scale": 1.0, - # - "att_rate": 0.7, + "valid_interval": 3000, + # parameters for conformer + "feature_dim": 80, + "subsampling_factor": 4, + "use_feat_batchnorm": True, "attention_dim": 512, "nhead": 8, "num_decoder_layers": 6, - "is_espnet_structure": True, - "use_feat_batchnorm": True, + # parameters for loss + "beam_size": 10, + "reduction": "sum", + "use_double_scores": True, + "att_rate": 0.7, + # parameters for Noam + "weight_decay": 1e-6, "lr_factor": 5.0, "warm_step": 80000, + "use_pruned_intersect": False, + "den_scale": 1.0, } ) @@ -261,13 +281,12 @@ def save_checkpoint( def compute_loss( params: AttributeDict, model: nn.Module, - ali_model: Optional[nn.Module], batch: dict, - graph_compiler: BpeMmiTrainingGraphCompiler, + graph_compiler: MmiTrainingGraphCompiler, is_training: bool, ): """ - Compute MMI loss given the model and its inputs. + Compute LF-MMI loss given the model and its inputs. Args: params: @@ -278,7 +297,9 @@ def compute_loss( A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` for the content in it. graph_compiler: - It is used to build num_graphs and den_graphs. + It is used to build a decoding graph from a ctc topo and training + transcript. The training transcript is contained in the given `batch`, + while the ctc topo is built when this compiler is instantiated. is_training: True for training. False for validation. When it is True, this function enables autograd during computation; when it is False, it @@ -286,54 +307,34 @@ def compute_loss( """ device = graph_compiler.device feature = batch["inputs"] - # at entry, feature is [N, T, C] + # at entry, feature is (N, T, C) assert feature.ndim == 3 feature = feature.to(device) supervisions = batch["supervisions"] with torch.set_grad_enabled(is_training): nnet_output, encoder_memory, memory_mask = model(feature, supervisions) - # nnet_output is [N, T, C] - if ali_model is not None and params.batch_idx_train < 4000: - feature = feature.permute(0, 2, 1) # [N, T, C]->[N, C, T] - ali_model_output = ali_model(feature) - # subsampling is done slightly differently, may be small length - # differences. - min_len = min(ali_model_output.shape[1], nnet_output.shape[1]) - # scale less than one so it will be encouraged - # to mimic ali_model's output - ali_model_scale = 500.0 / (params.batch_idx_train + 500) - - # Use clone() here or log-softmax backprop will fail. - nnet_output = nnet_output.clone() - - nnet_output[:, :min_len, :] += ( - ali_model_scale * ali_model_output[:, :min_len, :] - ) - - # NOTE: We need `encode_supervisions` to sort sequences with - # different duration in decreasing order, required by - # `k2.intersect_dense` called in LFMMILoss - # - # TODO: If params.use_pruned_intersect is True, there is no - # need to call encode_supervisions - supervision_segments, texts = encode_supervisions( - supervisions, subsampling_factor=params.subsampling_factor - ) + # nnet_output is (N, T, C) - dense_fsa_vec = k2.DenseFsaVec( - nnet_output, - supervision_segments, - allow_truncate=params.subsampling_factor - 1, - ) + # NOTE: We need `encode_supervisions` to sort sequences with + # different duration in decreasing order, required by + # `k2.intersect_dense` called in `LFMMILoss.forward()` + supervision_segments, texts = encode_supervisions( + supervisions, subsampling_factor=params.subsampling_factor + ) - loss_fn = LFMMILoss( - graph_compiler=graph_compiler, - den_scale=params.den_scale, - use_pruned_intersect=params.use_pruned_intersect, - ) + loss_fn = LFMMILoss( + graph_compiler=graph_compiler, + use_pruned_intersect=params.use_pruned_intersect, + den_scale=params.den_scale, + ) - mmi_loss = loss_fn(dense_fsa_vec=dense_fsa_vec, texts=texts) + dense_fsa_vec = k2.DenseFsaVec( + nnet_output, + supervision_segments, + allow_truncate=params.subsampling_factor - 1, + ) + mmi_loss = loss_fn(dense_fsa_vec=dense_fsa_vec, texts=texts) if params.att_rate != 0.0: token_ids = graph_compiler.texts_to_ids(texts) @@ -373,8 +374,7 @@ def compute_loss( def compute_validation_loss( params: AttributeDict, model: nn.Module, - ali_model: Optional[nn.Module], - graph_compiler: BpeMmiTrainingGraphCompiler, + graph_compiler: MmiTrainingGraphCompiler, valid_dl: torch.utils.data.DataLoader, world_size: int = 1, ) -> None: @@ -391,7 +391,6 @@ def compute_validation_loss( loss, mmi_loss, att_loss = compute_loss( params=params, model=model, - ali_model=ali_model, batch=batch, graph_compiler=graph_compiler, is_training=False, @@ -432,9 +431,8 @@ def compute_validation_loss( def train_one_epoch( params: AttributeDict, model: nn.Module, - ali_model: Optional[nn.Module], optimizer: torch.optim.Optimizer, - graph_compiler: BpeMmiTrainingGraphCompiler, + graph_compiler: MmiTrainingGraphCompiler, train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, tb_writer: Optional[SummaryWriter] = None, @@ -451,9 +449,6 @@ def train_one_epoch( It is returned by :func:`get_params`. model: The model for training. - ali_model: - The force alignment model for training. It is from - tdnn_lstm_ctc/train_bpe.py optimizer: The optimizer we are using. graph_compiler: @@ -483,7 +478,6 @@ def train_one_epoch( loss, mmi_loss, att_loss = compute_loss( params=params, model=model, - ali_model=ali_model, batch=batch, graph_compiler=graph_compiler, is_training=True, @@ -494,7 +488,7 @@ def train_one_epoch( optimizer.zero_grad() loss.backward() - clip_grad_norm_(model.parameters(), max_norm=5.0, norm_type=2.0) + clip_grad_norm_(model.parameters(), 5.0, 2.0) optimizer.step() loss_cpu = loss.detach().cpu().item() @@ -519,7 +513,7 @@ def train_one_epoch( f"batch avg mmi loss {mmi_loss_cpu/params.train_frames:.4f}, " f"batch avg att loss {att_loss_cpu/params.train_frames:.4f}, " f"batch avg loss {loss_cpu/params.train_frames:.4f}, " - f"total avg mmi loss: {tot_avg_mmi_loss:.4f}, " + f"total avg mmiloss: {tot_avg_mmi_loss:.4f}, " f"total avg att loss: {tot_avg_att_loss:.4f}, " f"total avg loss: {tot_avg_loss:.4f}, " f"batch size: {batch_size}" @@ -568,7 +562,6 @@ def train_one_epoch( compute_validation_loss( params=params, model=model, - ali_model=ali_model, graph_compiler=graph_compiler, valid_dl=valid_dl, world_size=world_size, @@ -576,10 +569,10 @@ def train_one_epoch( model.train() logging.info( f"Epoch {params.cur_epoch}, " - f"valid mmi loss {params.valid_mmi_loss:.4f}, " - f"valid att loss {params.valid_att_loss:.4f}, " - f"valid loss {params.valid_loss:.4f}, " - f"best valid loss: {params.best_valid_loss:.4f}, " + f"valid mmi loss {params.valid_mmi_loss:.4f}," + f"valid att loss {params.valid_att_loss:.4f}," + f"valid loss {params.valid_loss:.4f}," + f" best valid loss: {params.best_valid_loss:.4f} " f"best valid epoch: {params.best_valid_epoch}" ) if tb_writer is not None: @@ -642,11 +635,13 @@ def run(rank, world_size, args): if torch.cuda.is_available(): device = torch.device("cuda", rank) - graph_compiler = BpeMmiTrainingGraphCompiler( + graph_compiler = MmiTrainingGraphCompiler( params.lang_dir, + uniq_filename="lexicon.txt", device=device, - sos_token="", - eos_token="", + oov="", + sos_id=1, + eos_id=1, ) logging.info("About to create model") @@ -658,7 +653,6 @@ def run(rank, world_size, args): subsampling_factor=params.subsampling_factor, num_decoder_layers=params.num_decoder_layers, vgg_frontend=False, - is_espnet_structure=params.is_espnet_structure, use_feat_batchnorm=params.use_feat_batchnorm, ) @@ -679,32 +673,6 @@ def run(rank, world_size, args): if checkpoints: optimizer.load_state_dict(checkpoints["optimizer"]) - if args.use_ali_model: - ali_model = TdnnLstm( - num_features=params.feature_dim, - num_classes=num_classes, - subsampling_factor=params.subsampling_factor, - ) - - ali_model_fname = Path( - f"tdnn_lstm_ctc/exp_bpe_500/epoch-{args.ali_model_epoch}.pt" - ) - assert ( - ali_model_fname.is_file() - ), f"ali model filename {ali_model_fname} does not exist!" - - ali_model.load_state_dict( - torch.load(ali_model_fname, map_location="cpu")["model"] - ) - ali_model.to(device) - - ali_model.eval() - ali_model.requires_grad_(False) - logging.info(f"Use ali_model: {ali_model_fname}") - else: - ali_model = None - logging.info("No ali_model") - librispeech = LibriSpeechAsrDataModule(args) train_dl = librispeech.train_dataloaders() valid_dl = librispeech.valid_dataloaders() @@ -727,7 +695,6 @@ def run(rank, world_size, args): train_one_epoch( params=params, model=model, - ali_model=ali_model, optimizer=optimizer, graph_compiler=graph_compiler, train_dl=train_dl, diff --git a/egs/librispeech/ASR/prepare.sh b/egs/librispeech/ASR/prepare.sh index 1965dc491c..c1a532fc17 100755 --- a/egs/librispeech/ASR/prepare.sh +++ b/egs/librispeech/ASR/prepare.sh @@ -227,5 +227,3 @@ if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then ./local/compile_hlg.py --lang-dir $lang_dir done fi - -cd data && ln -sfv lang_bpe_500 lang_bpe diff --git a/icefall/mmi.py b/icefall/mmi.py index ec5d07dfeb..f9ba46df97 100644 --- a/icefall/mmi.py +++ b/icefall/mmi.py @@ -4,13 +4,13 @@ import torch from torch import nn -from icefall.bpe_mmi_graph_compiler import BpeMmiTrainingGraphCompiler +from icefall.mmi_graph_compiler import MmiTrainingGraphCompiler def _compute_mmi_loss_exact_optimized( dense_fsa_vec: k2.DenseFsaVec, texts: List[str], - graph_compiler: BpeMmiTrainingGraphCompiler, + graph_compiler: MmiTrainingGraphCompiler, den_scale: float = 1.0, ) -> torch.Tensor: """ @@ -98,7 +98,7 @@ def _compute_mmi_loss_exact_optimized( def _compute_mmi_loss_exact_non_optimized( dense_fsa_vec: k2.DenseFsaVec, texts: List[str], - graph_compiler: BpeMmiTrainingGraphCompiler, + graph_compiler: MmiTrainingGraphCompiler, den_scale: float = 1.0, ) -> torch.Tensor: """ @@ -133,7 +133,7 @@ def _compute_mmi_loss_exact_non_optimized( def _compute_mmi_loss_pruned( dense_fsa_vec: k2.DenseFsaVec, texts: List[str], - graph_compiler: BpeMmiTrainingGraphCompiler, + graph_compiler: MmiTrainingGraphCompiler, den_scale: float = 1.0, ) -> torch.Tensor: """ @@ -184,7 +184,7 @@ class LFMMILoss(nn.Module): def __init__( self, - graph_compiler: BpeMmiTrainingGraphCompiler, + graph_compiler: MmiTrainingGraphCompiler, use_pruned_intersect: bool = False, den_scale: float = 1.0, ): diff --git a/icefall/mmi_graph_compiler.py b/icefall/mmi_graph_compiler.py index 43f2a092a1..0d901227d0 100644 --- a/icefall/mmi_graph_compiler.py +++ b/icefall/mmi_graph_compiler.py @@ -15,6 +15,8 @@ def __init__( uniq_filename: str = "uniq_lexicon.txt", device: Union[str, torch.device] = "cpu", oov: str = "", + sos_id: int = 1, + eos_id: int = 1, ): """ Args: @@ -45,6 +47,8 @@ def __init__( self.L_inv = self.lexicon.L_inv.to(self.device) self.oov_id = self.lexicon.word_table[oov] + self.sos_id = sos_id + self.eos_id = eos_id self.build_ctc_topo_P() @@ -93,6 +97,7 @@ def build_ctc_topo_P(self): ).invert() self.ctc_topo_P = k2.arc_sort(ctc_topo_P) + logging.info(f"ctc_topo_P num_arcs: {self.ctc_topo_P.num_arcs}") def compile( self, texts: Iterable[str], replicate_den: bool = True From 94daaee6ba76796743b0b2d9e0731a7a1812b71a Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Tue, 28 Sep 2021 15:37:47 +0800 Subject: [PATCH 12/19] Use pre-computed alignments in LF-MMI training. --- .../ASR/conformer_mmi/asr_datamodule.py | 4 +- egs/librispeech/ASR/conformer_mmi/train.py | 79 ++++++- icefall/ali.py | 142 +++++++++++ test/test_ali.py | 223 ++++++++++++++++++ 4 files changed, 446 insertions(+), 2 deletions(-) create mode 100644 icefall/ali.py create mode 100755 test/test_ali.py diff --git a/egs/librispeech/ASR/conformer_mmi/asr_datamodule.py b/egs/librispeech/ASR/conformer_mmi/asr_datamodule.py index 8290e71d13..d3eab87a9c 100644 --- a/egs/librispeech/ASR/conformer_mmi/asr_datamodule.py +++ b/egs/librispeech/ASR/conformer_mmi/asr_datamodule.py @@ -162,7 +162,9 @@ def train_dataloaders(self) -> DataLoader: cuts_musan = load_manifest(self.args.feature_dir / "cuts_musan.json.gz") logging.info("About to create train dataset") - transforms = [CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20))] + transforms = [ + CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) + ] if self.args.concatenate_cuts: logging.info( f"Using cut concatenation with duration factor " diff --git a/egs/librispeech/ASR/conformer_mmi/train.py b/egs/librispeech/ASR/conformer_mmi/train.py index 6decbc1891..900d109a82 100755 --- a/egs/librispeech/ASR/conformer_mmi/train.py +++ b/egs/librispeech/ASR/conformer_mmi/train.py @@ -21,7 +21,7 @@ import logging from pathlib import Path from shutil import copyfile -from typing import Optional +from typing import Dict, Optional import k2 import torch @@ -36,6 +36,11 @@ from torch.utils.tensorboard import SummaryWriter from transformer import Noam +from icefall.ali import ( + convert_alignments_to_tensor, + load_alignments, + lookup_alignments, +) from icefall.checkpoint import load_checkpoint from icefall.checkpoint import save_checkpoint as save_checkpoint_impl from icefall.dist import cleanup_dist, setup_dist @@ -93,6 +98,17 @@ def get_parser(): """, ) + parser.add_argument( + "--ali-dir", + type=str, + default="data/ali_500", + help="""This folder is expected to contain + two files, train-960.pt and valid.pt, which + contain framewise alignment information for + the training set and validation set. + """, + ) + return parser @@ -284,6 +300,7 @@ def compute_loss( batch: dict, graph_compiler: MmiTrainingGraphCompiler, is_training: bool, + ali: Optional[Dict[str, torch.Tensor]], ): """ Compute LF-MMI loss given the model and its inputs. @@ -304,6 +321,8 @@ def compute_loss( True for training. False for validation. When it is True, this function enables autograd during computation; when it is False, it disables autograd. + ali: + Precomputed alignments. """ device = graph_compiler.device feature = batch["inputs"] @@ -323,6 +342,30 @@ def compute_loss( supervisions, subsampling_factor=params.subsampling_factor ) + if ali is not None and params.batch_idx_train < 4000: + cut_ids = [cut.id for cut in supervisions["cut"]] + + # As encode_supervisions reorders cuts, we need + # also to reorder cut IDs here + new2old = supervision_segments[:, 0].tolist() + cut_ids = [cut_ids[i] for i in new2old] + + # Check that new2old is just a permutation, + # i.e., each cut contains only one utterance + new2old.sort() + assert new2old == torch.arange(len(new2old)).tolist() + mask = lookup_alignments( + cut_ids=cut_ids, + alignments=ali, + num_classes=nnet_output.shape[2], + ).to(nnet_output) + + min_len = min(nnet_output.shape[1], mask.shape[1]) + ali_scale = 500.0 / (params.batch_idx_train + 500) + + nnet_output = nnet_output.clone() + nnet_output[:, :min_len, :] += ali_scale * mask[:, :min_len, :] + loss_fn = LFMMILoss( graph_compiler=graph_compiler, use_pruned_intersect=params.use_pruned_intersect, @@ -377,6 +420,7 @@ def compute_validation_loss( graph_compiler: MmiTrainingGraphCompiler, valid_dl: torch.utils.data.DataLoader, world_size: int = 1, + ali: Optional[Dict[str, torch.Tensor]] = None, ) -> None: """Run the validation process. The validation loss is saved in `params.valid_loss`. @@ -394,6 +438,7 @@ def compute_validation_loss( batch=batch, graph_compiler=graph_compiler, is_training=False, + ali=ali, ) assert loss.requires_grad is False assert mmi_loss.requires_grad is False @@ -435,6 +480,8 @@ def train_one_epoch( graph_compiler: MmiTrainingGraphCompiler, train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, + train_ali: Optional[Dict[str, torch.Tensor]], + valid_ali: Optional[Dict[str, torch.Tensor]], tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, ) -> None: @@ -457,6 +504,10 @@ def train_one_epoch( Dataloader for the training dataset. valid_dl: Dataloader for the validation dataset. + train_ali: + Precomputed alignments for the training set. + valid_ali: + Precomputed alignments for the validation set. tb_writer: Writer to write log messages to tensorboard. world_size: @@ -481,6 +532,7 @@ def train_one_epoch( batch=batch, graph_compiler=graph_compiler, is_training=True, + ali=train_ali, ) # NOTE: We use reduction==sum and loss is computed over utterances @@ -565,6 +617,7 @@ def train_one_epoch( graph_compiler=graph_compiler, valid_dl=valid_dl, world_size=world_size, + ali=valid_ali, ) model.train() logging.info( @@ -673,12 +726,34 @@ def run(rank, world_size, args): if checkpoints: optimizer.load_state_dict(checkpoints["optimizer"]) + train_960_ali_filename = Path(params.ali_dir) / "train-960.pt" + if params.batch_idx_train < 4000 and train_960_ali_filename.is_file(): + logging.info("Use pre-computed alignments") + subsampling_factor, train_ali = load_alignments(train_960_ali_filename) + assert subsampling_factor == params.subsampling_factor + assert len(train_ali) == 843723, f"{len(train_ali)} vs 843723" + + valid_ali_filename = Path(params.ali_dir) / "valid.pt" + subsampling_factor, valid_ali = load_alignments(valid_ali_filename) + assert subsampling_factor == params.subsampling_factor + + train_ali = convert_alignments_to_tensor(train_ali, device=device) + valid_ali = convert_alignments_to_tensor(valid_ali, device=device) + else: + logging.info("Not using alignments") + train_ali = None + valid_ali = None + librispeech = LibriSpeechAsrDataModule(args) train_dl = librispeech.train_dataloaders() valid_dl = librispeech.valid_dataloaders() for epoch in range(params.start_epoch, params.num_epochs): train_dl.sampler.set_epoch(epoch) + if params.batch_idx_train > 4000 and train_ali is not None: + # Delete the alignments to save memory + train_ali = None + valid_ali = None cur_lr = optimizer._rate if tb_writer is not None: @@ -699,6 +774,8 @@ def run(rank, world_size, args): graph_compiler=graph_compiler, train_dl=train_dl, valid_dl=valid_dl, + train_ali=train_ali, + valid_ali=valid_ali, tb_writer=tb_writer, world_size=world_size, ) diff --git a/icefall/ali.py b/icefall/ali.py new file mode 100644 index 0000000000..c3e4b26624 --- /dev/null +++ b/icefall/ali.py @@ -0,0 +1,142 @@ +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Dict, List, Tuple + +import torch +from torch.nn.utils.rnn import pad_sequence + + +def save_alignments( + alignments: Dict[str, List[int]], + subsampling_factor: int, + filename: str, +) -> None: + """Save alignments to a file. + + Args: + alignments: + A dict containing alignments. Keys of the dict are utterances and + values are the corresponding framewise alignments after subsampling. + subsampling_factor: + The subsampling factor of the model. + filename: + Path to save the alignments. + Returns: + Return None. + """ + ali_dict = { + "subsampling_factor": subsampling_factor, + "alignments": alignments, + } + torch.save(ali_dict, filename) + + +def load_alignments(filename: str) -> Tuple[int, Dict[str, List[int]]]: + """Load alignments from a file. + + Args: + filename: + Path to the file containing alignment information. + The file should be saved by :func:`save_alignments`. + Returns: + Return a tuple containing: + - subsampling_factor: The subsampling_factor used to compute + the alignments. + - alignments: A dict containing utterances and their corresponding + framewise alignment, after subsampling. + """ + ali_dict = torch.load(filename) + subsampling_factor = ali_dict["subsampling_factor"] + alignments = ali_dict["alignments"] + return subsampling_factor, alignments + + +def convert_alignments_to_tensor( + alignments: Dict[str, List[int]], device: torch.device +) -> Dict[str, torch.Tensor]: + """Convert alignments from list of int to a 1-D torch.Tensor. + + Args: + alignments: + A dict containing alignments. Keys are utterance IDs and + values are their corresponding frame-wise alignments. + device: + The device to move the alignments to. + Returns: + Return a dict using 1-D torch.Tensor to store the alignments. + The dtype of the tensor are `torch.int64`. We choose `torch.int64` + because `torch.nn.functional.one_hot` requires that. + """ + ans = {} + for utt_id, ali in alignments.items(): + ali = torch.tensor(ali, dtype=torch.int64, device=device) + ans[utt_id] = ali + return ans + + +def lookup_alignments( + cut_ids: List[str], + alignments: Dict[str, torch.Tensor], + num_classes: int, + log_score: float = -10, +) -> torch.Tensor: + """Return a mask constructed from alignments by a list of cut IDs. + + The returned mask is a 3-D tensor of shape (N, T, C). For each frame, + i.e., each row, of the returned mask, positions not corresponding to + the alignments are filled with `log_score`, while the position + specified by the alignment is filled with 0. For instance, if the alignments + of two utterances are: + + [ [1, 3, 2], [1, 0, 4, 2] ] + num_classes is 5 and log_score is -10, then the returned mask is + + [ + [[-10, 0, -10, -10, -10], + [-10, -10, -10, 0, -10], + [-10, -10, 0, -10, -10], + [0, -10, -10, -10, -10]], + [[-10, 0, -10, -10, -10], + [0, -10, -10, -10, -10], + [-10, -10, -10, -10, 0], + [-10, -10, 0, -10, -10]] + ] + Note: We pad the alignment of the first utterance with 0. + + Args: + cut_ids: + A list of utterance IDs. + alignments: + A dict containing alignments. The keys are utterance IDs and the values + are framewise alignments. + num_classes: + The max token ID + 1 that appears in the alignments. + log_score: + Positions in the returned tensor not corresponding to the alignments + are filled with this value. + Returns: + Return a 3-D torch.float32 tensor of shape (N, T, C). + """ + # We assume all utterances have their alignments. + ali = [alignments[cut_id] for cut_id in cut_ids] + padded_ali = pad_sequence(ali, batch_first=True, padding_value=0) + padded_one_hot = torch.nn.functional.one_hot( + padded_ali, + num_classes=num_classes, + ) + mask = (1 - padded_one_hot) * float(log_score) + return mask diff --git a/test/test_ali.py b/test/test_ali.py new file mode 100755 index 0000000000..e8516e6dc6 --- /dev/null +++ b/test/test_ali.py @@ -0,0 +1,223 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Runt his file using one of the following two ways: +# (1) python3 ./test/test_ali.py +# (2) pytest ./test/test_ali.py + +# The purpose of this file is to show that if we build a mask +# from alignments and add it to a randomly generated nnet_output, +# we can decode the correct transcript. + +from pathlib import Path + +import k2 +import torch +from lhotse import load_manifest +from lhotse.dataset import K2SpeechRecognitionDataset, SingleCutSampler +from torch.nn.utils.rnn import pad_sequence +from torch.utils.data import DataLoader + +from icefall.ali import ( + convert_alignments_to_tensor, + load_alignments, + lookup_alignments, +) +from icefall.decode import get_lattice, one_best_decoding +from icefall.lexicon import Lexicon +from icefall.utils import get_texts + +ICEFALL_DIR = Path(__file__).resolve().parent.parent +egs_dir = ICEFALL_DIR / "egs/librispeech/ASR" +lang_dir = egs_dir / "data/lang_bpe_500" +# cut_json = egs_dir / "data/fbank/cuts_train-clean-100.json.gz" +cut_json = egs_dir / "data/fbank/cuts_train-clean-360.json.gz" +# cut_json = egs_dir / "data/fbank/cuts_train-other-500.json.gz" +ali_filename = ICEFALL_DIR / "egs/librispeech/ASR/data/ali_500/train-960.pt" + +# cut_json = egs_dir / "data/fbank/cuts_test-clean.json.gz" +# ali_filename = ICEFALL_DIR / "egs/librispeech/ASR/data/ali_500/test_clean.pt" + + +def data_exists(): + return ali_filename.exists() and cut_json.exists() and lang_dir.exists() + + +def get_dataloader(): + cuts_train = load_manifest(cut_json) + cuts_train = cuts_train.with_features_path_prefix(egs_dir) + train_sampler = SingleCutSampler( + cuts_train, + max_duration=200, + shuffle=False, + ) + + train = K2SpeechRecognitionDataset(return_cuts=True) + + train_dl = DataLoader( + train, + sampler=train_sampler, + batch_size=None, + num_workers=1, + persistent_workers=False, + ) + return train_dl + + +def test_one_hot(): + a = [1, 3, 2] + b = [1, 0, 4, 2] + c = [torch.tensor(a), torch.tensor(b)] + d = pad_sequence(c, batch_first=True, padding_value=0) + f = torch.nn.functional.one_hot(d, num_classes=5) + e = (1 - f) * -10.0 + expected = torch.tensor( + [ + [ + [-10, 0, -10, -10, -10], + [-10, -10, -10, 0, -10], + [-10, -10, 0, -10, -10], + [0, -10, -10, -10, -10], + ], + [ + [-10, 0, -10, -10, -10], + [0, -10, -10, -10, -10], + [-10, -10, -10, -10, 0], + [-10, -10, 0, -10, -10], + ], + ] + ).to(e.dtype) + assert torch.all(torch.eq(e, expected)) + + +def test(): + """ + The purpose of this test is to show that we can use pre-computed + alignments to construct a mask, adding it to a randomly generated + nnet_output, to decode the correct transcript from the resulting + nnet_output. + """ + if not data_exists(): + return + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + dl = get_dataloader() + + subsampling_factor, ali = load_alignments(ali_filename) + ali = convert_alignments_to_tensor(ali, device=device) + + lexicon = Lexicon(lang_dir) + max_token_id = max(lexicon.tokens) + num_classes = max_token_id + 1 # +1 for the blank + word_table = lexicon.word_table + + HLG = k2.Fsa.from_dict( + torch.load(f"{lang_dir}/HLG.pt", map_location=device) + ) + + for batch in dl: + features = batch["inputs"] + supervisions = batch["supervisions"] + N = features.shape[0] + T = features.shape[1] // subsampling_factor + nnet_output = ( + torch.rand(N, T, num_classes, dtype=torch.float32, device=device) + .softmax(dim=-1) + .log() + ) + cut_ids = [cut.id for cut in supervisions["cut"]] + mask = lookup_alignments( + cut_ids=cut_ids, alignments=ali, num_classes=num_classes + ) + min_len = min(nnet_output.shape[1], mask.shape[1]) + ali_model_scale = 0.8 + + nnet_output[:, :min_len, :] += ali_model_scale * mask[:, :min_len, :] + + supervisions = batch["supervisions"] + + supervision_segments = torch.stack( + ( + supervisions["sequence_idx"], + supervisions["start_frame"] // subsampling_factor, + supervisions["num_frames"] // subsampling_factor, + ), + 1, + ).to(torch.int32) + + lattice = get_lattice( + nnet_output=nnet_output, + HLG=HLG, + supervision_segments=supervision_segments, + search_beam=20, + output_beam=8, + min_active_states=30, + max_active_states=10000, + subsampling_factor=subsampling_factor, + ) + + best_path = one_best_decoding(lattice=lattice, use_double_scores=True) + hyps = get_texts(best_path) + hyps = [[word_table[i] for i in ids] for ids in hyps] + hyps = [" ".join(s) for s in hyps] + print(hyps) + print(supervisions["text"]) + break + + +def show_cut_ids(): + # The purpose of this function is to check that + # for each utterance in the training set, there is + # a corresponding alignment. + # + # After generating a1.txt and b1.txt + # You can use + # wc -l a1.txt b1.txt + # which should show the same number of lines. + # + # cat a1.txt | sort | uniq > a11.txt + # cat b1.txt | sort | uniq > b11.txt + # + # md5sum a11.txt b11.txt + # which should show the identical hash + # + # diff a11.txt b11.txt + # should print nothing + + subsampling_factor, ali = load_alignments(ali_filename) + with open("a1.txt", "w") as f: + for key in ali: + f.write(f"{key}\n") + + # dl = get_dataloader() + cuts_train = ( + load_manifest(egs_dir / "data/fbank/cuts_train-clean-100.json.gz") + + load_manifest(egs_dir / "data/fbank/cuts_train-clean-360.json.gz") + + load_manifest(egs_dir / "data/fbank/cuts_train-other-500.json.gz") + ) + + ans = [] + for cut in cuts_train: + ans.append(cut.id) + with open("b1.txt", "w") as f: + for line in ans: + f.write(f"{line}\n") + + +if __name__ == "__main__": + test() From 28f1aabf99ffa65abeac215e215694e36afd1f19 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Tue, 28 Sep 2021 19:51:31 +0800 Subject: [PATCH 13/19] Minor fixes. --- .../ASR/conformer_mmi/train-with-attention.py | 836 ++++++++++++++++++ egs/librispeech/ASR/conformer_mmi/train.py | 34 +- icefall/mmi.py | 18 +- 3 files changed, 878 insertions(+), 10 deletions(-) create mode 100755 egs/librispeech/ASR/conformer_mmi/train-with-attention.py diff --git a/egs/librispeech/ASR/conformer_mmi/train-with-attention.py b/egs/librispeech/ASR/conformer_mmi/train-with-attention.py new file mode 100755 index 0000000000..a66f776764 --- /dev/null +++ b/egs/librispeech/ASR/conformer_mmi/train-with-attention.py @@ -0,0 +1,836 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import logging +from pathlib import Path +from shutil import copyfile +from typing import Dict, Optional + +import k2 +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule +from conformer import Conformer +from lhotse.utils import fix_random_seed +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.nn.utils import clip_grad_norm_ +from torch.utils.tensorboard import SummaryWriter +from transformer import Noam + +from icefall.ali import ( + convert_alignments_to_tensor, + load_alignments, + lookup_alignments, +) +from icefall.checkpoint import load_checkpoint +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.dist import cleanup_dist, setup_dist +from icefall.lexicon import Lexicon +from icefall.mmi import LFMMILoss +from icefall.mmi_graph_compiler import MmiTrainingGraphCompiler +from icefall.utils import ( + AttributeDict, + encode_supervisions, + setup_logger, + str2bool, +) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=50, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=0, + help="""Resume training from from this epoch. + If it is positive, it will load checkpoint from + conformer_mmi/exp/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--ali-dir", + type=str, + default="data/ali_500", + help="""This folder is expected to contain + two files, train-960.pt and valid.pt, which + contain framewise alignment information for + the training set and validation set. + """, + ) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - exp_dir: It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + + - lang_dir: It contains language related input files such as + "lexicon.txt" + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - use_feat_batchnorm: Whether to do batch normalization for the + input features. + + - attention_dim: Hidden dim for multi-head attention model. + + - head: Number of heads of multi-head attention model. + + - num_decoder_layers: Number of decoder layer of transformer decoder. + + - weight_decay: The weight_decay for the optimizer. + + - lr_factor: The lr_factor for Noam optimizer. + + - warm_step: The warm_step for Noam optimizer. + """ + params = AttributeDict( + { + "exp_dir": Path("conformer_mmi/exp_500_with_attention"), + "lang_dir": Path("data/lang_bpe_500"), + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 3000, + # parameters for conformer + "feature_dim": 80, + "subsampling_factor": 4, + "use_feat_batchnorm": True, + "attention_dim": 512, + "nhead": 8, + # parameters for loss + "beam_size": 6, # will change it to 8 after some batches (see code) + "reduction": "sum", + "use_double_scores": True, + # "att_rate": 0.0, + # "num_decoder_layers": 0, + "att_rate": 0.7, + "num_decoder_layers": 6, + # parameters for Noam + "weight_decay": 1e-6, + "lr_factor": 5.0, + "warm_step": 80000, + "use_pruned_intersect": False, + "den_scale": 1.0, + "use_ali_until": 13000, # use alignments before this number of batches + } + ) + + return params + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, +) -> None: + """Load checkpoint from file. + + If params.start_epoch is positive, it will load the checkpoint from + `params.start_epoch - 1`. Otherwise, this function does nothing. + + Apart from loading state dict for `model`, `optimizer` and `scheduler`, + it also updates `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + optimizer: + The optimizer that we are using. + scheduler: + The learning rate scheduler we are using. + Returns: + Return None. + """ + if params.start_epoch <= 0: + return + + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + saved_params = load_checkpoint( + filename, + model=model, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: nn.Module, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + params=params, + optimizer=optimizer, + scheduler=scheduler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: nn.Module, + batch: dict, + graph_compiler: MmiTrainingGraphCompiler, + is_training: bool, + ali: Optional[Dict[str, torch.Tensor]], +): + """ + Compute LF-MMI loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Conformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + graph_compiler: + It is used to build a decoding graph from a ctc topo and training + transcript. The training transcript is contained in the given `batch`, + while the ctc topo is built when this compiler is instantiated. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + ali: + Precomputed alignments. + """ + device = graph_compiler.device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + with torch.set_grad_enabled(is_training): + nnet_output, encoder_memory, memory_mask = model(feature, supervisions) + # nnet_output is (N, T, C) + + # NOTE: We need `encode_supervisions` to sort sequences with + # different duration in decreasing order, required by + # `k2.intersect_dense` called in `LFMMILoss.forward()` + supervision_segments, texts = encode_supervisions( + supervisions, subsampling_factor=params.subsampling_factor + ) + + if ali is not None and params.batch_idx_train < params.use_ali_until: + cut_ids = [cut.id for cut in supervisions["cut"]] + + # As encode_supervisions reorders cuts, we need + # also to reorder cut IDs here + new2old = supervision_segments[:, 0].tolist() + cut_ids = [cut_ids[i] for i in new2old] + + # Check that new2old is just a permutation, + # i.e., each cut contains only one utterance + new2old.sort() + assert new2old == torch.arange(len(new2old)).tolist() + mask = lookup_alignments( + cut_ids=cut_ids, + alignments=ali, + num_classes=nnet_output.shape[2], + ).to(nnet_output) + + min_len = min(nnet_output.shape[1], mask.shape[1]) + ali_scale = 500.0 / (params.batch_idx_train + 500) + + nnet_output = nnet_output.clone() + nnet_output[:, :min_len, :] += ali_scale * mask[:, :min_len, :] + + if ( + params.batch_idx_train > params.use_ali_until + and params.beam_size < 8 + ): + logging.info("Change beam size to 8") + params.beam_size = 8 + else: + params.beam_size = 6 + + loss_fn = LFMMILoss( + graph_compiler=graph_compiler, + use_pruned_intersect=params.use_pruned_intersect, + den_scale=params.den_scale, + beam_size=params.beam_size, + ) + + dense_fsa_vec = k2.DenseFsaVec( + nnet_output, + supervision_segments, + allow_truncate=params.subsampling_factor - 1, + ) + mmi_loss = loss_fn(dense_fsa_vec=dense_fsa_vec, texts=texts) + + if params.att_rate != 0.0: + token_ids = graph_compiler.texts_to_ids(texts) + with torch.set_grad_enabled(is_training): + if hasattr(model, "module"): + att_loss = model.module.decoder_forward( + encoder_memory, + memory_mask, + token_ids=token_ids, + sos_id=graph_compiler.sos_id, + eos_id=graph_compiler.eos_id, + ) + else: + att_loss = model.decoder_forward( + encoder_memory, + memory_mask, + token_ids=token_ids, + sos_id=graph_compiler.sos_id, + eos_id=graph_compiler.eos_id, + ) + loss = (1.0 - params.att_rate) * mmi_loss + params.att_rate * att_loss + else: + loss = mmi_loss + att_loss = torch.tensor([0]) + + # train_frames and valid_frames are used for printing. + if is_training: + params.train_frames = supervision_segments[:, 2].sum().item() + else: + params.valid_frames = supervision_segments[:, 2].sum().item() + + assert loss.requires_grad == is_training + + return loss, mmi_loss.detach(), att_loss.detach() + + +def compute_validation_loss( + params: AttributeDict, + model: nn.Module, + graph_compiler: MmiTrainingGraphCompiler, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, + ali: Optional[Dict[str, torch.Tensor]] = None, +) -> None: + """Run the validation process. The validation loss + is saved in `params.valid_loss`. + """ + model.eval() + + tot_loss = 0.0 + tot_mmi_loss = 0.0 + tot_att_loss = 0.0 + tot_frames = 0.0 + for batch_idx, batch in enumerate(valid_dl): + loss, mmi_loss, att_loss = compute_loss( + params=params, + model=model, + batch=batch, + graph_compiler=graph_compiler, + is_training=False, + ali=ali, + ) + assert loss.requires_grad is False + assert mmi_loss.requires_grad is False + assert att_loss.requires_grad is False + + loss_cpu = loss.detach().cpu().item() + tot_loss += loss_cpu + + tot_mmi_loss += mmi_loss.detach().cpu().item() + tot_att_loss += att_loss.detach().cpu().item() + + tot_frames += params.valid_frames + + if world_size > 1: + s = torch.tensor( + [tot_loss, tot_mmi_loss, tot_att_loss, tot_frames], + device=loss.device, + ) + dist.all_reduce(s, op=dist.ReduceOp.SUM) + s = s.cpu().tolist() + tot_loss = s[0] + tot_mmi_loss = s[1] + tot_att_loss = s[2] + tot_frames = s[3] + + params.valid_loss = tot_loss / tot_frames + params.valid_mmi_loss = tot_mmi_loss / tot_frames + params.valid_att_loss = tot_att_loss / tot_frames + + if params.valid_loss < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = params.valid_loss + + +def train_one_epoch( + params: AttributeDict, + model: nn.Module, + optimizer: torch.optim.Optimizer, + graph_compiler: MmiTrainingGraphCompiler, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + train_ali: Optional[Dict[str, torch.Tensor]], + valid_ali: Optional[Dict[str, torch.Tensor]], + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + graph_compiler: + It is used to convert transcripts to FSAs. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + train_ali: + Precomputed alignments for the training set. + valid_ali: + Precomputed alignments for the validation set. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + """ + model.train() + + tot_loss = 0.0 # sum of losses over all batches + tot_mmi_loss = 0.0 + tot_att_loss = 0.0 + + tot_frames = 0.0 # sum of frames over all batches + params.tot_loss = 0.0 + params.tot_frames = 0.0 + for batch_idx, batch in enumerate(train_dl): + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + loss, mmi_loss, att_loss = compute_loss( + params=params, + model=model, + batch=batch, + graph_compiler=graph_compiler, + is_training=True, + ali=train_ali, + ) + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + + optimizer.zero_grad() + loss.backward() + clip_grad_norm_(model.parameters(), 5.0, 2.0) + optimizer.step() + + loss_cpu = loss.detach().cpu().item() + mmi_loss_cpu = mmi_loss.detach().cpu().item() + att_loss_cpu = att_loss.detach().cpu().item() + + tot_frames += params.train_frames + tot_loss += loss_cpu + tot_mmi_loss += mmi_loss_cpu + tot_att_loss += att_loss_cpu + + params.tot_frames += params.train_frames + params.tot_loss += loss_cpu + + tot_avg_loss = tot_loss / tot_frames + tot_avg_mmi_loss = tot_mmi_loss / tot_frames + tot_avg_att_loss = tot_att_loss / tot_frames + + if batch_idx % params.log_interval == 0: + logging.info( + f"Epoch {params.cur_epoch}, batch {batch_idx}, " + f"batch avg mmi loss {mmi_loss_cpu/params.train_frames:.4f}, " + f"batch avg att loss {att_loss_cpu/params.train_frames:.4f}, " + f"batch avg loss {loss_cpu/params.train_frames:.4f}, " + f"total avg mmiloss: {tot_avg_mmi_loss:.4f}, " + f"total avg att loss: {tot_avg_att_loss:.4f}, " + f"total avg loss: {tot_avg_loss:.4f}, " + f"batch size: {batch_size}" + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/current_mmi_loss", + mmi_loss_cpu / params.train_frames, + params.batch_idx_train, + ) + tb_writer.add_scalar( + "train/current_att_loss", + att_loss_cpu / params.train_frames, + params.batch_idx_train, + ) + tb_writer.add_scalar( + "train/current_loss", + loss_cpu / params.train_frames, + params.batch_idx_train, + ) + tb_writer.add_scalar( + "train/tot_avg_mmi_loss", + tot_avg_mmi_loss, + params.batch_idx_train, + ) + + tb_writer.add_scalar( + "train/tot_avg_att_loss", + tot_avg_att_loss, + params.batch_idx_train, + ) + tb_writer.add_scalar( + "train/tot_avg_loss", + tot_avg_loss, + params.batch_idx_train, + ) + if batch_idx > 0 and batch_idx % params.reset_interval == 0: + tot_loss = 0.0 # sum of losses over all batches + tot_mmi_loss = 0.0 + tot_att_loss = 0.0 + + tot_frames = 0.0 # sum of frames over all batches + + if batch_idx > 0 and batch_idx % params.valid_interval == 0: + compute_validation_loss( + params=params, + model=model, + graph_compiler=graph_compiler, + valid_dl=valid_dl, + world_size=world_size, + ali=valid_ali, + ) + model.train() + logging.info( + f"Epoch {params.cur_epoch}, " + f"valid mmi loss {params.valid_mmi_loss:.4f}," + f"valid att loss {params.valid_att_loss:.4f}," + f"valid loss {params.valid_loss:.4f}," + f" best valid loss: {params.best_valid_loss:.4f} " + f"best valid epoch: {params.best_valid_epoch}" + ) + if tb_writer is not None: + tb_writer.add_scalar( + "train/valid_mmi_loss", + params.valid_mmi_loss, + params.batch_idx_train, + ) + tb_writer.add_scalar( + "train/valid_att_loss", + params.valid_att_loss, + params.batch_idx_train, + ) + tb_writer.add_scalar( + "train/valid_loss", + params.valid_loss, + params.batch_idx_train, + ) + + params.train_loss = params.tot_loss / params.tot_frames + + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. + args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(42) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + logging.info(params) + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + lexicon = Lexicon(params.lang_dir) + max_token_id = max(lexicon.tokens) + num_classes = max_token_id + 1 # +1 for the blank + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + + graph_compiler = MmiTrainingGraphCompiler( + params.lang_dir, + uniq_filename="lexicon.txt", + device=device, + oov="", + sos_id=1, + eos_id=1, + ) + + logging.info("About to create model") + if params.att_rate == 0: + assert params.num_decoder_layers == 0, f"{params.num_decoder_layers}" + + model = Conformer( + num_features=params.feature_dim, + nhead=params.nhead, + d_model=params.attention_dim, + num_classes=num_classes, + subsampling_factor=params.subsampling_factor, + num_decoder_layers=params.num_decoder_layers, + vgg_frontend=False, + use_feat_batchnorm=params.use_feat_batchnorm, + ) + + checkpoints = load_checkpoint_if_available(params=params, model=model) + + model.to(device) + if world_size > 1: + model = DDP(model, device_ids=[rank]) + + optimizer = Noam( + model.parameters(), + model_size=params.attention_dim, + factor=params.lr_factor, + warm_step=params.warm_step, + weight_decay=params.weight_decay, + ) + + if checkpoints: + optimizer.load_state_dict(checkpoints["optimizer"]) + + train_960_ali_filename = Path(params.ali_dir) / "train-960.pt" + if ( + params.batch_idx_train < params.use_ali_until + and train_960_ali_filename.is_file() + ): + logging.info("Use pre-computed alignments") + subsampling_factor, train_ali = load_alignments(train_960_ali_filename) + assert subsampling_factor == params.subsampling_factor + assert len(train_ali) == 843723, f"{len(train_ali)} vs 843723" + + valid_ali_filename = Path(params.ali_dir) / "valid.pt" + subsampling_factor, valid_ali = load_alignments(valid_ali_filename) + assert subsampling_factor == params.subsampling_factor + + train_ali = convert_alignments_to_tensor(train_ali, device=device) + valid_ali = convert_alignments_to_tensor(valid_ali, device=device) + else: + logging.info("Not using alignments") + train_ali = None + valid_ali = None + + librispeech = LibriSpeechAsrDataModule(args) + train_dl = librispeech.train_dataloaders() + valid_dl = librispeech.valid_dataloaders() + + for epoch in range(params.start_epoch, params.num_epochs): + train_dl.sampler.set_epoch(epoch) + if ( + params.batch_idx_train >= params.use_ali_until + and train_ali is not None + ): + # Delete the alignments to save memory + train_ali = None + valid_ali = None + + cur_lr = optimizer._rate + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + if rank == 0: + logging.info("epoch {}, learning rate {}".format(epoch, cur_lr)) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + optimizer=optimizer, + graph_compiler=graph_compiler, + train_dl=train_dl, + valid_dl=valid_dl, + train_ali=train_ali, + valid_ali=valid_ali, + tb_writer=tb_writer, + world_size=world_size, + ) + + save_checkpoint( + params=params, + model=model, + optimizer=optimizer, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/conformer_mmi/train.py b/egs/librispeech/ASR/conformer_mmi/train.py index 900d109a82..18cc80a9aa 100755 --- a/egs/librispeech/ASR/conformer_mmi/train.py +++ b/egs/librispeech/ASR/conformer_mmi/train.py @@ -189,18 +189,21 @@ def get_params() -> AttributeDict: "use_feat_batchnorm": True, "attention_dim": 512, "nhead": 8, - "num_decoder_layers": 6, # parameters for loss - "beam_size": 10, + "beam_size": 6, # will change it to 8 after some batches (see code) "reduction": "sum", "use_double_scores": True, - "att_rate": 0.7, + "att_rate": 0.0, + "num_decoder_layers": 0, + # "att_rate": 0.7, + # "num_decoder_layers": 6, # parameters for Noam "weight_decay": 1e-6, "lr_factor": 5.0, "warm_step": 80000, "use_pruned_intersect": False, "den_scale": 1.0, + "use_ali_until": 13000, # use alignments before this number of batches } ) @@ -342,7 +345,7 @@ def compute_loss( supervisions, subsampling_factor=params.subsampling_factor ) - if ali is not None and params.batch_idx_train < 4000: + if ali is not None and params.batch_idx_train < params.use_ali_until: cut_ids = [cut.id for cut in supervisions["cut"]] # As encode_supervisions reorders cuts, we need @@ -366,10 +369,20 @@ def compute_loss( nnet_output = nnet_output.clone() nnet_output[:, :min_len, :] += ali_scale * mask[:, :min_len, :] + if ( + params.batch_idx_train > params.use_ali_until + and params.beam_size < 8 + ): + logging.info("Change beam size to 8") + params.beam_size = 8 + else: + params.beam_size = 6 + loss_fn = LFMMILoss( graph_compiler=graph_compiler, use_pruned_intersect=params.use_pruned_intersect, den_scale=params.den_scale, + beam_size=params.beam_size, ) dense_fsa_vec = k2.DenseFsaVec( @@ -698,6 +711,9 @@ def run(rank, world_size, args): ) logging.info("About to create model") + if params.att_rate == 0: + assert params.num_decoder_layers == 0, f"{params.num_decoder_layers}" + model = Conformer( num_features=params.feature_dim, nhead=params.nhead, @@ -727,7 +743,10 @@ def run(rank, world_size, args): optimizer.load_state_dict(checkpoints["optimizer"]) train_960_ali_filename = Path(params.ali_dir) / "train-960.pt" - if params.batch_idx_train < 4000 and train_960_ali_filename.is_file(): + if ( + params.batch_idx_train < params.use_ali_until + and train_960_ali_filename.is_file() + ): logging.info("Use pre-computed alignments") subsampling_factor, train_ali = load_alignments(train_960_ali_filename) assert subsampling_factor == params.subsampling_factor @@ -750,7 +769,10 @@ def run(rank, world_size, args): for epoch in range(params.start_epoch, params.num_epochs): train_dl.sampler.set_epoch(epoch) - if params.batch_idx_train > 4000 and train_ali is not None: + if ( + params.batch_idx_train >= params.use_ali_until + and train_ali is not None + ): # Delete the alignments to save memory train_ali = None valid_ali = None diff --git a/icefall/mmi.py b/icefall/mmi.py index f9ba46df97..2c479fc2c9 100644 --- a/icefall/mmi.py +++ b/icefall/mmi.py @@ -12,6 +12,7 @@ def _compute_mmi_loss_exact_optimized( texts: List[str], graph_compiler: MmiTrainingGraphCompiler, den_scale: float = 1.0, + beam_size: float = 8.0, ) -> torch.Tensor: """ The function name contains `exact`, which means it uses a version of @@ -79,7 +80,7 @@ def _compute_mmi_loss_exact_optimized( num_den_lats = k2.intersect_dense( num_den_reordered_graphs, dense_fsa_vec, - output_beam=10.0, + output_beam=beam_size, a_to_b_map=a_to_b_map, ) @@ -100,6 +101,7 @@ def _compute_mmi_loss_exact_non_optimized( texts: List[str], graph_compiler: MmiTrainingGraphCompiler, den_scale: float = 1.0, + beam_size: float = 8.0, ) -> torch.Tensor: """ See :func:`_compute_mmi_loss_exact_optimized` for the meaning @@ -113,8 +115,12 @@ def _compute_mmi_loss_exact_non_optimized( num_graphs, den_graphs = graph_compiler.compile(texts, replicate_den=True) # TODO: pass output_beam as function argument - num_lats = k2.intersect_dense(num_graphs, dense_fsa_vec, output_beam=10.0) - den_lats = k2.intersect_dense(den_graphs, dense_fsa_vec, output_beam=10.0) + num_lats = k2.intersect_dense( + num_graphs, dense_fsa_vec, output_beam=beam_size + ) + den_lats = k2.intersect_dense( + den_graphs, dense_fsa_vec, output_beam=beam_size + ) num_tot_scores = num_lats.get_tot_scores( log_semiring=True, use_double_scores=True @@ -135,6 +141,7 @@ def _compute_mmi_loss_pruned( texts: List[str], graph_compiler: MmiTrainingGraphCompiler, den_scale: float = 1.0, + beam_size: float = 8.0, ) -> torch.Tensor: """ See :func:`_compute_mmi_loss_exact_optimized` for the meaning @@ -156,7 +163,7 @@ def _compute_mmi_loss_pruned( den_graphs, dense_fsa_vec, search_beam=20.0, - output_beam=8.0, + output_beam=beam_size, min_active_states=30, max_active_states=10000, ) @@ -187,11 +194,13 @@ def __init__( graph_compiler: MmiTrainingGraphCompiler, use_pruned_intersect: bool = False, den_scale: float = 1.0, + beam_size: float = 8.0, ): super().__init__() self.graph_compiler = graph_compiler self.den_scale = den_scale self.use_pruned_intersect = use_pruned_intersect + self.beam_size = beam_size def forward( self, @@ -219,4 +228,5 @@ def forward( texts=texts, graph_compiler=self.graph_compiler, den_scale=self.den_scale, + beam_size=self.beam_size, ) From b8dbad5156e41ecfc848feaf7c218dca9cee823e Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Mon, 18 Oct 2021 14:38:07 +0800 Subject: [PATCH 14/19] Update decoding script. --- egs/librispeech/ASR/conformer_mmi/decode.py | 125 ++++++++++++++---- .../ASR/conformer_mmi/train-with-attention.py | 2 +- 2 files changed, 97 insertions(+), 30 deletions(-) diff --git a/egs/librispeech/ASR/conformer_mmi/decode.py b/egs/librispeech/ASR/conformer_mmi/decode.py index dc2e449c22..79012c98f4 100755 --- a/egs/librispeech/ASR/conformer_mmi/decode.py +++ b/egs/librispeech/ASR/conformer_mmi/decode.py @@ -23,6 +23,7 @@ from typing import Dict, List, Optional, Tuple import k2 +import sentencepiece as spm import torch import torch.nn as nn from asr_datamodule import LibriSpeechAsrDataModule @@ -77,6 +78,9 @@ def get_parser(): default="attention-decoder", help="""Decoding method. Supported values are: + - (0) ctc-decoding. Use CTC decoding. It uses a sentence piece + model, i.e., lang_dir/bpe.model, to convert word pieces to words. + It needs neither a lexicon nor an n-gram LM. - (1) 1best. Extract the best path from the decoding lattice as the decoding result. - (2) nbest. Extract n paths from the decoding lattice; the path @@ -106,7 +110,7 @@ def get_parser(): ) parser.add_argument( - "--lattice-score-scale", + "--nbest-scale", type=float, default=0.5, help="""The scale to be applied to `lattice.scores`. @@ -122,7 +126,7 @@ def get_parser(): type=str2bool, default=False, help="""When enabled, the averaged model is saved to - conformer_mmi/exp/pretrained.pt. Note: only model.state_dict() is saved. + conformer_ctc/exp/pretrained.pt. Note: only model.state_dict() is saved. pretrained.pt contains a dict {"model": model.state_dict()}, which can be loaded by `icefall.checkpoint.load_checkpoint()`. """, @@ -131,17 +135,24 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="conformer_mmi/exp", + default="conformer_mmi/exp_500", help="The experiment dir", ) parser.add_argument( "--lang-dir", type=str, - default="data/lang_bpe", + default="data/lang_bpe_500", help="The lang dir", ) + parser.add_argument( + "--num-decoder-layers", + type=int, + default=6, + help="Number of attention decoder layers", + ) + return parser @@ -156,7 +167,6 @@ def get_params() -> AttributeDict: "feature_dim": 80, "nhead": 8, "attention_dim": 512, - "num_decoder_layers": 6, # parameters for decoding "search_beam": 20, "output_beam": 8, @@ -171,13 +181,15 @@ def get_params() -> AttributeDict: def decode_one_batch( params: AttributeDict, model: nn.Module, - HLG: k2.Fsa, + HLG: Optional[k2.Fsa], + H: Optional[k2.Fsa], + bpe_model: Optional[spm.SentencePieceProcessor], batch: dict, word_table: k2.SymbolTable, sos_id: int, eos_id: int, G: Optional[k2.Fsa] = None, -) -> Dict[str, List[List[int]]]: +) -> Dict[str, List[List[str]]]: """Decode one batch and return the result in a dict. The dict has the following format: @@ -202,7 +214,11 @@ def decode_one_batch( model: The neural model. HLG: - The decoding graph. + The decoding graph. Used only when params.method is NOT ctc-decoding. + H: + The ctc topo. Used only when params.method is ctc-decoding. + bpe_model: + The BPE model. Used only when params.method is ctc-decoding. batch: It is the return value from iterating `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation @@ -221,7 +237,10 @@ def decode_one_batch( Return the decoding result. See above description for the format of the returned dict. """ - device = HLG.device + if HLG is not None: + device = HLG.device + else: + device = H.device feature = batch["inputs"] assert feature.ndim == 3 feature = feature.to(device) @@ -241,9 +260,17 @@ def decode_one_batch( 1, ).to(torch.int32) + if H is None: + assert HLG is not None + decoding_graph = HLG + else: + assert HLG is None + assert bpe_model is not None + decoding_graph = H + lattice = get_lattice( nnet_output=nnet_output, - HLG=HLG, + decoding_graph=decoding_graph, supervision_segments=supervision_segments, search_beam=params.search_beam, output_beam=params.output_beam, @@ -252,6 +279,24 @@ def decode_one_batch( subsampling_factor=params.subsampling_factor, ) + if params.method == "ctc-decoding": + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + # Note: `best_path.aux_labels` contains token IDs, not word IDs + # since we are using H, not HLG here. + # + # token_ids is a lit-of-list of IDs + token_ids = get_texts(best_path) + + # hyps is a list of str, e.g., ['xxx yyy zzz', ...] + hyps = bpe_model.decode(token_ids) + + # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ] + hyps = [s.split() for s in hyps] + key = "ctc-decoding" + return {key: hyps} + if params.method == "nbest-oracle": # Note: You can also pass rescored lattices to it. # We choose the HLG decoded lattice for speed reasons @@ -262,12 +307,12 @@ def decode_one_batch( num_paths=params.num_paths, ref_texts=supervisions["text"], word_table=word_table, - lattice_score_scale=params.lattice_score_scale, + nbest_scale=params.nbest_scale, oov="", ) hyps = get_texts(best_path) hyps = [[word_table[i] for i in ids] for ids in hyps] - key = f"oracle_{params.num_paths}_lattice_score_scale_{params.lattice_score_scale}" # noqa + key = f"oracle_{params.num_paths}_nbest_scale_{params.nbest_scale}" # noqa return {key: hyps} if params.method in ["1best", "nbest"]: @@ -281,9 +326,9 @@ def decode_one_batch( lattice=lattice, num_paths=params.num_paths, use_double_scores=params.use_double_scores, - lattice_score_scale=params.lattice_score_scale, + nbest_scale=params.nbest_scale, ) - key = f"no_rescore-scale-{params.lattice_score_scale}-{params.num_paths}" # noqa + key = f"no_rescore-nbest-scale-{params.nbest_scale}-{params.num_paths}" # noqa hyps = get_texts(best_path) hyps = [[word_table[i] for i in ids] for ids in hyps] @@ -305,7 +350,7 @@ def decode_one_batch( G=G, num_paths=params.num_paths, lm_scale_list=lm_scale_list, - lattice_score_scale=params.lattice_score_scale, + nbest_scale=params.nbest_scale, ) elif params.method == "whole-lattice-rescoring": best_path_dict = rescore_with_whole_lattice( @@ -331,7 +376,7 @@ def decode_one_batch( memory_key_padding_mask=memory_key_padding_mask, sos_id=sos_id, eos_id=eos_id, - lattice_score_scale=params.lattice_score_scale, + nbest_scale=params.nbest_scale, ) else: assert False, f"Unsupported decoding method: {params.method}" @@ -344,7 +389,7 @@ def decode_one_batch( ans[lm_scale_str] = hyps else: for lm_scale in lm_scale_list: - ans[lm_scale_str] = [[] * lattice.shape[0]] + ans["empty"] = [[] * lattice.shape[0]] return ans @@ -352,12 +397,14 @@ def decode_dataset( dl: torch.utils.data.DataLoader, params: AttributeDict, model: nn.Module, - HLG: k2.Fsa, + HLG: Optional[k2.Fsa], + H: Optional[k2.Fsa], + bpe_model: Optional[spm.SentencePieceProcessor], word_table: k2.SymbolTable, sos_id: int, eos_id: int, G: Optional[k2.Fsa] = None, -) -> Dict[str, List[Tuple[List[int], List[int]]]]: +) -> Dict[str, List[Tuple[List[str], List[str]]]]: """Decode dataset. Args: @@ -368,7 +415,11 @@ def decode_dataset( model: The neural model. HLG: - The decoding graph. + The decoding graph. Used only when params.method is NOT ctc-decoding. + H: + The ctc topo. Used only when params.method is ctc-decoding. + bpe_model: + The BPE model. Used only when params.method is ctc-decoding. word_table: It is the word symbol table. sos_id: @@ -403,6 +454,8 @@ def decode_dataset( params=params, model=model, HLG=HLG, + H=H, + bpe_model=bpe_model, batch=batch, word_table=word_table, G=G, @@ -481,11 +534,11 @@ def main(): parser = get_parser() LibriSpeechAsrDataModule.add_arguments(parser) args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + args.lang_dir = Path(args.lang_dir) params = get_params() params.update(vars(args)) - params.exp_dir = Path(params.exp_dir) - params.lang_dir = Path(params.lang_dir) setup_logger(f"{params.exp_dir}/log-{params.method}/log-decode") logging.info("Decoding started") @@ -510,14 +563,26 @@ def main(): sos_id = graph_compiler.sos_id eos_id = graph_compiler.eos_id - HLG = k2.Fsa.from_dict( - torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu") - ) - HLG = HLG.to(device) - assert HLG.requires_grad is False + if params.method == "ctc-decoding": + HLG = None + H = k2.ctc_topo( + max_token=max_token_id, + modified=False, + device=device, + ) + bpe_model = spm.SentencePieceProcessor() + bpe_model.load(str(params.lang_dir / "bpe.model")) + else: + H = None + bpe_model = None + HLG = k2.Fsa.from_dict( + torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu") + ) + HLG = HLG.to(device) + assert HLG.requires_grad is False - if not hasattr(HLG, "lm_scores"): - HLG.lm_scores = HLG.scores.clone() + if not hasattr(HLG, "lm_scores"): + HLG.lm_scores = HLG.scores.clone() if params.method in ( "nbest-rescoring", @@ -607,6 +672,8 @@ def main(): params=params, model=model, HLG=HLG, + H=H, + bpe_model=bpe_model, word_table=lexicon.word_table, G=G, sos_id=sos_id, diff --git a/egs/librispeech/ASR/conformer_mmi/train-with-attention.py b/egs/librispeech/ASR/conformer_mmi/train-with-attention.py index a66f776764..a3e2668147 100755 --- a/egs/librispeech/ASR/conformer_mmi/train-with-attention.py +++ b/egs/librispeech/ASR/conformer_mmi/train-with-attention.py @@ -373,7 +373,7 @@ def compute_loss( params.batch_idx_train > params.use_ali_until and params.beam_size < 8 ): - logging.info("Change beam size to 8") + # logging.info("Change beam size to 8") params.beam_size = 8 else: params.beam_size = 6 From f383666c400dcae1dc04f4081966255d27c8011a Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Mon, 18 Oct 2021 14:53:32 +0800 Subject: [PATCH 15/19] Add doc about how to check and use extracted alignments. --- egs/librispeech/ASR/conformer_ctc/README.md | 24 +++++++++++++++++++++ test/test_ali.py | 12 +++++------ 2 files changed, 30 insertions(+), 6 deletions(-) diff --git a/egs/librispeech/ASR/conformer_ctc/README.md b/egs/librispeech/ASR/conformer_ctc/README.md index 164c3e53e1..cafc63bb6d 100644 --- a/egs/librispeech/ASR/conformer_ctc/README.md +++ b/egs/librispeech/ASR/conformer_ctc/README.md @@ -51,3 +51,27 @@ in `conformer_ctc/train.py`. Search `./conformer_ctc/asr_datamodule.py` for `preserve_id`. **TODO:** Add doc about how to use the extracted alignment in the other pull-request. + +### Step 3: Check your extracted alignments + +There is a file `test_ali.py` in `icefall/test` that can be used to test your +alignments. It uses pre-computed alignments to modify a randomly generated +`nnet_output` and it checks that we can decode the correct transcripts +from the resulting `nnet_output`. + +You should get something like the following if you run that script: + +``` +$ ./test/test_ali.py +['THE GOOD NATURED AUDIENCE IN PITY TO FALLEN MAJESTY SHOWED FOR ONCE GREATER DEFERENCE TO THE KING THAN TO THE MINISTER AND SUNG THE PSALM WHICH THE FORMER HAD CALLED FOR', 'THE OLD SERVANT TOLD HIM QUIETLY AS THEY CREPT BACK TO DWELL THAT THIS PASSAGE THAT LED FROM THE HUT IN THE PLEASANCE TO SHERWOOD AND THAT GEOFFREY FOR THE TIME WAS HIDING WITH THE OUTLAWS IN THE FOREST', 'FOR A WHILE SHE LAY IN HER CHAIR IN HAPPY DREAMY PLEASURE AT SUN AND BIRD AND TREE', "BUT THE ESSENCE OF LUTHER'S LECTURES IS THERE"] +['THE GOOD NATURED AUDIENCE IN PITY TO FALLEN MAJESTY SHOWED FOR ONCE GREATER DEFERENCE TO THE KING THAN TO THE MINISTER AND SUNG THE PSALM WHICH THE FORMER HAD CALLED FOR', 'THE OLD SERVANT TOLD HIM QUIETLY AS THEY CREPT BACK TO GAMEWELL THAT THIS PASSAGE WAY LED FROM THE HUT IN THE PLEASANCE TO SHERWOOD AND THAT GEOFFREY FOR THE TIME WAS HIDING WITH THE OUTLAWS IN THE FOREST', 'FOR A WHILE SHE LAY IN HER CHAIR IN HAPPY DREAMY PLEASURE AT SUN AND BIRD AND TREE', "BUT THE ESSENCE OF LUTHER'S LECTURES IS THERE"] +``` + +### Step 4: Use your alignments in training + +Please refer to `conformer_mmi/train.py` for how usage. Some useful +functions are: + +- `load_alignments()`, it loads alignment saved by `conformer_ctc/ali.py` +- `convert_alignments_to_tensor()`, it converts alignments to PyTorch tensors +- `lookup_alignments()`, it returns the alignments of utterances by giving the cut ID of the utterances. diff --git a/test/test_ali.py b/test/test_ali.py index e8516e6dc6..d8ada33e85 100755 --- a/test/test_ali.py +++ b/test/test_ali.py @@ -45,12 +45,12 @@ egs_dir = ICEFALL_DIR / "egs/librispeech/ASR" lang_dir = egs_dir / "data/lang_bpe_500" # cut_json = egs_dir / "data/fbank/cuts_train-clean-100.json.gz" -cut_json = egs_dir / "data/fbank/cuts_train-clean-360.json.gz" +# cut_json = egs_dir / "data/fbank/cuts_train-clean-360.json.gz" # cut_json = egs_dir / "data/fbank/cuts_train-other-500.json.gz" -ali_filename = ICEFALL_DIR / "egs/librispeech/ASR/data/ali_500/train-960.pt" +# ali_filename = ICEFALL_DIR / "egs/librispeech/ASR/data/ali_500/train-960.pt" -# cut_json = egs_dir / "data/fbank/cuts_test-clean.json.gz" -# ali_filename = ICEFALL_DIR / "egs/librispeech/ASR/data/ali_500/test_clean.pt" +cut_json = egs_dir / "data/fbank/cuts_test-clean.json.gz" +ali_filename = ICEFALL_DIR / "egs/librispeech/ASR/data/ali_500/test_clean.pt" def data_exists(): @@ -62,7 +62,7 @@ def get_dataloader(): cuts_train = cuts_train.with_features_path_prefix(egs_dir) train_sampler = SingleCutSampler( cuts_train, - max_duration=200, + max_duration=40, shuffle=False, ) @@ -162,7 +162,7 @@ def test(): lattice = get_lattice( nnet_output=nnet_output, - HLG=HLG, + decoding_graph=HLG, supervision_segments=supervision_segments, search_beam=20, output_beam=8, From 00dac43130a449fe1140805bdd2adbc865de41b1 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Mon, 18 Oct 2021 14:58:07 +0800 Subject: [PATCH 16/19] Fix style issues. --- .flake8 | 5 +++-- egs/librispeech/ASR/conformer_mmi/train-with-attention.py | 3 ++- egs/librispeech/ASR/conformer_mmi/train.py | 3 ++- test/test_lexicon.py | 2 +- 4 files changed, 8 insertions(+), 5 deletions(-) diff --git a/.flake8 b/.flake8 index 3f1227b9b9..b8f0e4715d 100644 --- a/.flake8 +++ b/.flake8 @@ -4,8 +4,9 @@ statistics=true max-line-length = 80 per-file-ignores = # line too long - egs/librispeech/ASR/conformer_ctc/conformer.py: E501, + egs/librispeech/ASR/*/conformer.py: E501, exclude = .git, - **/data/** + **/data/**, + icefall/shared/make_kn_lm.py diff --git a/egs/librispeech/ASR/conformer_mmi/train-with-attention.py b/egs/librispeech/ASR/conformer_mmi/train-with-attention.py index a3e2668147..8b89940594 100755 --- a/egs/librispeech/ASR/conformer_mmi/train-with-attention.py +++ b/egs/librispeech/ASR/conformer_mmi/train-with-attention.py @@ -203,7 +203,8 @@ def get_params() -> AttributeDict: "warm_step": 80000, "use_pruned_intersect": False, "den_scale": 1.0, - "use_ali_until": 13000, # use alignments before this number of batches + # use alignments before this number of batches + "use_ali_until": 13000, } ) diff --git a/egs/librispeech/ASR/conformer_mmi/train.py b/egs/librispeech/ASR/conformer_mmi/train.py index 18cc80a9aa..6580792ff5 100755 --- a/egs/librispeech/ASR/conformer_mmi/train.py +++ b/egs/librispeech/ASR/conformer_mmi/train.py @@ -203,7 +203,8 @@ def get_params() -> AttributeDict: "warm_step": 80000, "use_pruned_intersect": False, "den_scale": 1.0, - "use_ali_until": 13000, # use alignments before this number of batches + # use alignments before this number of batches + "use_ali_until": 13000, } ) diff --git a/test/test_lexicon.py b/test/test_lexicon.py index 2a16db2260..69867efc7f 100755 --- a/test/test_lexicon.py +++ b/test/test_lexicon.py @@ -26,8 +26,8 @@ import shutil import sys from pathlib import Path +from typing import List -import k2 import sentencepiece as spm from icefall.lexicon import UniqLexicon From 0663b97599fd94aebc3dbad161e8e256093eabfd Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Mon, 18 Oct 2021 15:00:37 +0800 Subject: [PATCH 17/19] Fix typos. --- egs/librispeech/ASR/conformer_ctc/README.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/conformer_ctc/README.md b/egs/librispeech/ASR/conformer_ctc/README.md index cafc63bb6d..dfa522287c 100644 --- a/egs/librispeech/ASR/conformer_ctc/README.md +++ b/egs/librispeech/ASR/conformer_ctc/README.md @@ -27,10 +27,10 @@ avg=15 --bucketing-sampler 0 \ --full-libri 1 \ --exp-dir conformer_ctc/exp \ - --lang-dir data/lang_bpe_5000 \ - --ali-dir data/ali_5000 + --lang-dir data/lang_bpe_500 \ + --ali-dir data/ali_500 ``` -and you will get four files inside the folder `data/ali_5000`: +and you will get four files inside the folder `data/ali_500`: ``` $ ls -lh data/ali_500 From 3ac9b4595d4835c451e8cf03ae354d6a670fadc3 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Mon, 18 Oct 2021 15:04:27 +0800 Subject: [PATCH 18/19] Fix style issues. --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 0d80ed4d22..01ff869db9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,4 +8,5 @@ exclude = ''' \.git | \.github )/ + | make_kn_lm.py ''' From f76ef6e58adad4da72f15fcb1fc6e179dc975081 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Mon, 18 Oct 2021 15:19:57 +0800 Subject: [PATCH 19/19] Disable macOS tests for now. --- .github/workflows/test.yml | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 6b72b5a0c7..b5c8cfcfad 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -29,7 +29,9 @@ jobs: runs-on: ${{ matrix.os }} strategy: matrix: - os: [ubuntu-18.04, macos-10.15] + # os: [ubuntu-18.04, macos-10.15] + # disable macOS test for now. + os: [ubuntu-18.04] python-version: [3.6, 3.7, 3.8, 3.9] torch: ["1.8.1"] k2-version: ["1.9.dev20210919"]