Skip to content

Commit

Permalink
Refactoring (#4)
Browse files Browse the repository at this point in the history
* Fix an error in TDNN-LSTM training.

* WIP: Refactoring

* Refactor transformer.py

* Remove unused code.

* Minor fixes.
  • Loading branch information
csukuangfj authored Aug 4, 2021
1 parent cf8d762 commit 5a0b9bc
Show file tree
Hide file tree
Showing 23 changed files with 964 additions and 644 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@ path.sh
exp
exp*/
*.pt
download/
16 changes: 11 additions & 5 deletions egs/librispeech/ASR/conformer_ctc/conformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,20 +84,26 @@ def __init__(
# and throws an error without this change.
self.after_norm = identity

def encode(
def run_encoder(
self, x: Tensor, supervisions: Optional[Supervisions] = None
) -> Tuple[Tensor, Optional[Tensor]]:
"""
Args:
x: Tensor of dimension (batch_size, num_features, input_length).
supervisions : Supervison in lhotse format, i.e., batch['supervisions']
x:
The model input. Its shape is [N, T, C].
supervisions:
Supervision in lhotse format.
See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa
CAUTION: It contains length information, i.e., start and number of
frames, before subsampling
It is read directly from the batch, without any sorting. It is used
to compute encoder padding mask, which is used as memory key padding
mask for the decoder.
Returns:
Tensor: Predictor tensor of dimension (input_length, batch_size, d_model).
Tensor: Mask tensor of dimension (batch_size, input_length)
"""
x = x.permute(0, 2, 1) # (B, F, T) -> (B, T, F)

x = self.encoder_embed(x)
x, pos_emb = self.encoder_pos(x)
x = x.permute(1, 0, 2) # (B, T, F) -> (T, B, F)
Expand Down
63 changes: 49 additions & 14 deletions egs/librispeech/ASR/conformer_ctc/decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import torch.nn as nn
from conformer import Conformer

from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.dataset.librispeech import LibriSpeechAsrDataModule
from icefall.decode import (
Expand Down Expand Up @@ -62,7 +63,7 @@ def get_params() -> AttributeDict:
params = AttributeDict(
{
"exp_dir": Path("conformer_ctc/exp"),
"lang_dir": Path("data/lang/bpe"),
"lang_dir": Path("data/lang_bpe"),
"lm_dir": Path("data/lm"),
"feature_dim": 80,
"nhead": 8,
Expand All @@ -85,7 +86,7 @@ def get_params() -> AttributeDict:
# - whole-lattice-rescoring
# - attention-decoder
# "method": "whole-lattice-rescoring",
"method": "1best",
"method": "attention-decoder",
# num_paths is used when method is "nbest", "nbest-rescoring",
# and attention-decoder
"num_paths": 100,
Expand All @@ -100,6 +101,8 @@ def decode_one_batch(
HLG: k2.Fsa,
batch: dict,
lexicon: Lexicon,
sos_id: int,
eos_id: int,
G: Optional[k2.Fsa] = None,
) -> Dict[str, List[List[int]]]:
"""Decode one batch and return the result in a dict. The dict has the
Expand Down Expand Up @@ -133,6 +136,10 @@ def decode_one_batch(
for the format of the `batch`.
lexicon:
It contains word symbol table.
sos_id:
The token ID of the SOS.
eos_id:
The token ID of the EOS.
G:
An LM. It is not None when params.method is "nbest-rescoring"
or "whole-lattice-rescoring". In general, the G in HLG
Expand All @@ -147,15 +154,10 @@ def decode_one_batch(
feature = feature.to(device)
# at entry, feature is [N, T, C]

feature = feature.permute(0, 2, 1) # now feature is [N, C, T]

supervisions = batch["supervisions"]

nnet_output, memory, memory_key_padding_mask = model(feature, supervisions)
# nnet_output is [N, C, T]

nnet_output = nnet_output.permute(0, 2, 1)
# now nnet_output is [N, T, C]
# nnet_output is [N, T, C]

supervision_segments = torch.stack(
(
Expand Down Expand Up @@ -227,6 +229,8 @@ def decode_one_batch(
model=model,
memory=memory,
memory_key_padding_mask=memory_key_padding_mask,
sos_id=sos_id,
eos_id=eos_id,
)
else:
assert False, f"Unsupported decoding method: {params.method}"
Expand All @@ -245,6 +249,8 @@ def decode_dataset(
model: nn.Module,
HLG: k2.Fsa,
lexicon: Lexicon,
sos_id: int,
eos_id: int,
G: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[List[int], List[int]]]]:
"""Decode dataset.
Expand All @@ -260,6 +266,10 @@ def decode_dataset(
The decoding graph.
lexicon:
It contains word symbol table.
sos_id:
The token ID for SOS.
eos_id:
The token ID for EOS.
G:
An LM. It is not None when params.method is "nbest-rescoring"
or "whole-lattice-rescoring". In general, the G in HLG
Expand Down Expand Up @@ -287,6 +297,8 @@ def decode_dataset(
batch=batch,
lexicon=lexicon,
G=G,
sos_id=sos_id,
eos_id=eos_id,
)

for lm_scale, hyps in hyps_dict.items():
Expand Down Expand Up @@ -314,20 +326,31 @@ def save_results(
test_set_name: str,
results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
):
if params.method == "attention-decoder":
# Set it to False since there are too many logs.
enable_log = False
else:
enable_log = True
test_set_wers = dict()
for key, results in results_dict.items():
recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt"
store_transcripts(filename=recog_path, texts=results)
logging.info(f"The transcripts are stored in {recog_path}")
if enable_log:
logging.info(f"The transcripts are stored in {recog_path}")

# The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs.
errs_filename = params.exp_dir / f"errs-{test_set_name}-{key}.txt"
with open(errs_filename, "w") as f:
wer = write_error_stats(f, f"{test_set_name}-{key}", results)
wer = write_error_stats(
f, f"{test_set_name}-{key}", results, enable_log=enable_log
)
test_set_wers[key] = wer

logging.info("Wrote detailed error stats to {}".format(errs_filename))
if enable_log:
logging.info(
"Wrote detailed error stats to {}".format(errs_filename)
)

test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = params.exp_dir / f"wer-summary-{test_set_name}.txt"
Expand Down Expand Up @@ -367,15 +390,22 @@ def main():

logging.info(f"device: {device}")

HLG = k2.Fsa.from_dict(torch.load(f"{params.lm_dir}/HLG_bpe.pt"))
graph_compiler = BpeCtcTrainingGraphCompiler(
params.lang_dir,
device=device,
sos_token="<sos/eos>",
eos_token="<sos/eos>",
)
sos_id = graph_compiler.sos_id
eos_id = graph_compiler.eos_id

HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt"))
HLG = HLG.to(device)
assert HLG.requires_grad is False

if not hasattr(HLG, "lm_scores"):
HLG.lm_scores = HLG.scores.clone()

# HLG = k2.ctc_topo(4999).to(device)

if params.method in (
"nbest-rescoring",
"whole-lattice-rescoring",
Expand Down Expand Up @@ -461,6 +491,8 @@ def main():
HLG=HLG,
lexicon=lexicon,
G=G,
sos_id=sos_id,
eos_id=eos_id,
)

save_results(
Expand All @@ -470,5 +502,8 @@ def main():
logging.info("Done!")


torch.set_num_threads(1)
torch.set_num_interop_threads(1)

if __name__ == "__main__":
main()
144 changes: 144 additions & 0 deletions egs/librispeech/ASR/conformer_ctc/subsampling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
import torch
import torch.nn as nn


class Conv2dSubsampling(nn.Module):
"""Convolutional 2D subsampling (to 1/4 length).
Convert an input of shape [N, T, idim] to an output
with shape [N, T', odim], where
T' = ((T-1)//2 - 1)//2, which approximates T' == T//4
It is based on
https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py # noqa
"""

def __init__(self, idim: int, odim: int) -> None:
"""
Args:
idim:
Input dim. The input shape is [N, T, idim].
Caution: It requires: T >=7, idim >=7
odim:
Output dim. The output shape is [N, ((T-1)//2 - 1)//2, odim]
"""
assert idim >= 7
super().__init__()
self.conv = nn.Sequential(
nn.Conv2d(
in_channels=1, out_channels=odim, kernel_size=3, stride=2
),
nn.ReLU(),
nn.Conv2d(
in_channels=odim, out_channels=odim, kernel_size=3, stride=2
),
nn.ReLU(),
)
self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim)

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Subsample x.
Args:
x:
Its shape is [N, T, idim].
Returns:
Return a tensor of shape [N, ((T-1)//2 - 1)//2, odim]
"""
# On entry, x is [N, T, idim]
x = x.unsqueeze(1) # [N, T, idim] -> [N, 1, T, idim] i.e., [N, C, H, W]
x = self.conv(x)
# Now x is of shape [N, odim, ((T-1)//2 - 1)//2, ((idim-1)//2 - 1)//2]
b, c, t, f = x.size()
x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
# Now x is of shape [N, ((T-1)//2 - 1))//2, odim]
return x


class VggSubsampling(nn.Module):
"""Trying to follow the setup described in the following paper:
https://arxiv.org/pdf/1910.09799.pdf
This paper is not 100% explicit so I am guessing to some extent,
and trying to compare with other VGG implementations.
Convert an input of shape [N, T, idim] to an output
with shape [N, T', odim], where
T' = ((T-1)//2 - 1)//2, which approximates T' = T//4
"""

def __init__(self, idim: int, odim: int) -> None:
"""Construct a VggSubsampling object.
This uses 2 VGG blocks with 2 Conv2d layers each,
subsampling its input by a factor of 4 in the time dimensions.
Args:
idim:
Input dim. The input shape is [N, T, idim].
Caution: It requires: T >=7, idim >=7
odim:
Output dim. The output shape is [N, ((T-1)//2 - 1)//2, odim]
"""
super().__init__()

cur_channels = 1
layers = []
block_dims = [32, 64]

# The decision to use padding=1 for the 1st convolution, then padding=0
# for the 2nd and for the max-pooling, and ceil_mode=True, was driven by
# a back-compatibility concern so that the number of frames at the
# output would be equal to:
# (((T-1)//2)-1)//2.
# We can consider changing this by using padding=1 on the
# 2nd convolution, so the num-frames at the output would be T//4.
for block_dim in block_dims:
layers.append(
torch.nn.Conv2d(
in_channels=cur_channels,
out_channels=block_dim,
kernel_size=3,
padding=1,
stride=1,
)
)
layers.append(torch.nn.ReLU())
layers.append(
torch.nn.Conv2d(
in_channels=block_dim,
out_channels=block_dim,
kernel_size=3,
padding=0,
stride=1,
)
)
layers.append(
torch.nn.MaxPool2d(
kernel_size=2, stride=2, padding=0, ceil_mode=True
)
)
cur_channels = block_dim

self.layers = nn.Sequential(*layers)

self.out = nn.Linear(
block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim
)

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Subsample x.
Args:
x:
Its shape is [N, T, idim].
Returns:
Return a tensor of shape [N, ((T-1)//2 - 1)//2, odim]
"""
x = x.unsqueeze(1)
x = self.layers(x)
b, c, t, f = x.size()
x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
return x
33 changes: 33 additions & 0 deletions egs/librispeech/ASR/conformer_ctc/test_subsampling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
#!/usr/bin/env python3

from subsampling import Conv2dSubsampling
from subsampling import VggSubsampling
import torch


def test_conv2d_subsampling():
N = 3
odim = 2

for T in range(7, 19):
for idim in range(7, 20):
model = Conv2dSubsampling(idim=idim, odim=odim)
x = torch.empty(N, T, idim)
y = model(x)
assert y.shape[0] == N
assert y.shape[1] == ((T - 1) // 2 - 1) // 2
assert y.shape[2] == odim


def test_vgg_subsampling():
N = 3
odim = 2

for T in range(7, 19):
for idim in range(7, 20):
model = VggSubsampling(idim=idim, odim=odim)
x = torch.empty(N, T, idim)
y = model(x)
assert y.shape[0] == N
assert y.shape[1] == ((T - 1) // 2 - 1) // 2
assert y.shape[2] == odim
Loading

0 comments on commit 5a0b9bc

Please sign in to comment.