Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SPGISpeech recipe #334

Merged
merged 20 commits into from
May 16, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions egs/spgispeech/ASR/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# SPGISpeech

SPGISpeech consists of 5,000 hours of recorded company earnings calls and their respective
transcriptions. The original calls were split into slices ranging from 5 to 15 seconds in
length to allow easy training for speech recognition systems. Calls represent a broad
cross-section of international business English; SPGISpeech contains approximately 50,000
speakers, one of the largest numbers of any speech corpus, and offers a variety of L1 and
L2 English accents. The format of each WAV file is single channel, 16kHz, 16 bit audio.

Transcription text represents the output of several stages of manual post-processing.
As such, the text contains polished English orthography following a detailed style guide,
including proper casing, punctuation, and denormalized non-standard words such as numbers
and acronyms, making SPGISpeech suited for training fully formatted end-to-end models.

Official reference:

O’Neill, P.K., Lavrukhin, V., Majumdar, S., Noroozi, V., Zhang, Y., Kuchaiev, O., Balam,
J., Dovzhenko, Y., Freyberg, K., Shulman, M.D., Ginsburg, B., Watanabe, S., & Kucsko, G.
(2021). SPGISpeech: 5, 000 hours of transcribed financial audio for fully formatted
end-to-end speech recognition. ArXiv, abs/2104.02014.

ArXiv link: https://arxiv.org/abs/2104.02014

## Performance Record

| Decoding method | val WER | val CER |
|---------------------------|------------|---------|
| greedy search | 2.40 | 0.99 |
| modified beam search | 2.24 | 0.91 |
| fast beam search | 2.35 | 0.97 |

See [RESULTS](/egs/spgispeech/ASR/RESULTS.md) for details.
73 changes: 73 additions & 0 deletions egs/spgispeech/ASR/RESULTS.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
## Results

### SPGISpeech BPE training results (Pruned Transducer)

#### 2022-05-11

#### Conformer encoder + embedding decoder

Conformer encoder + non-current decoder. The decoder
contains only an embedding layer, a Conv1d (with kernel size 2) and a linear
layer (to transform tensor dim).

The WERs are

| | dev | val | comment |
|---------------------------|------------|------------|------------------------------------------|
| greedy search | 2.46 | 2.40 | --avg-last-n 10 --max-duration 500 |
| modified beam search | 2.28 | 2.24 | --avg-last-n 10 --max-duration 500 --beam-size 4 |
| fast beam search | 2.38 | 2.35 | --avg-last-n 10 --max-duration 500 --beam-size 4 --max-contexts 4 --max-states 8 |

**NOTE:** SPGISpeech transcripts can be prepared in `ortho` or `norm` ways, which refer to whether the
transcripts are orthographic or normalized. These WERs correspond to the normalized transcription
scenario.

The training command for reproducing is given below:

```
export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
./pruned_transducer_stateless2/train.py \
--world-size 8 \
--num-epochs 20 \
--start-epoch 0 \
--exp-dir pruned_transducer_stateless2/exp \
--max-duration 200 \
--prune-range 5 \
--lr-factor 5 \
--lm-scale 0.25 \
--use-fp16 True
```

The decoding command is:
```
# greedy search
./pruned_transducer_stateless2/decode.py \
--iter 696000 --avg 10 \
--exp-dir ./pruned_transducer_stateless2/exp \
--max-duration 100 \
--decoding-method greedy_search
# modified beam search
./pruned_transducer_stateless2/decode.py \
--iter 696000 --avg 10 \
--exp-dir ./pruned_transducer_stateless2/exp \
--max-duration 100 \
--decoding-method modified_beam_search \
--beam-size 4
# fast beam search
./pruned_transducer_stateless2/decode.py \
--iter 696000 --avg 10 \
--exp-dir ./pruned_transducer_stateless2/exp \
--max-duration 1500 \
--decoding-method fast_beam_search \
--beam 4 \
--max-contexts 4 \
--max-states 8
```

Pretrained model is available at <https://huggingface.co/desh2608/icefall-asr-spgispeech-pruned-transducer-stateless2>

The tensorboard training log can be found at
<https://tensorboard.dev/experiment/ExSoBmrPRx6liMTGLu0Tgw/#scalars>
Empty file.
1 change: 1 addition & 0 deletions egs/spgispeech/ASR/local/compile_hlg.py
104 changes: 104 additions & 0 deletions egs/spgispeech/ASR/local/compute_fbank_musan.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""
This file computes fbank features of the musan dataset.
It looks for manifests in the directory data/manifests.
The generated fbank features are saved in data/fbank.
"""

import logging
from pathlib import Path

import torch
from lhotse import LilcomChunkyWriter, CutSet, combine
from lhotse.features.kaldifeat import (
KaldifeatFbank,
KaldifeatFbankConfig,
KaldifeatMelOptions,
KaldifeatFrameOptions,
)
from lhotse.recipes.utils import read_manifests_if_cached

from icefall.utils import get_executor

# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slow things down.
# Do this outside of main() in case it needs to take effect
# even when we are not invoking the main (e.g. when spawning subprocesses).
torch.set_num_threads(1)
torch.set_num_interop_threads(1)


def compute_fbank_musan():
src_dir = Path("data/manifests")
output_dir = Path("data/fbank")

sampling_rate = 16000
num_mel_bins = 80

extractor = KaldifeatFbank(
KaldifeatFbankConfig(
frame_opts=KaldifeatFrameOptions(sampling_rate=sampling_rate),
mel_opts=KaldifeatMelOptions(num_bins=num_mel_bins),
device="cuda",
)
)

dataset_parts = (
"music",
"speech",
"noise",
)
manifests = read_manifests_if_cached(
dataset_parts=dataset_parts, output_dir=src_dir
)
assert manifests is not None

musan_cuts_path = src_dir / "cuts_musan.jsonl.gz"

if musan_cuts_path.is_file():
logging.info(f"{musan_cuts_path} already exists - skipping")
return

logging.info("Extracting features for Musan")

# create chunks of Musan with duration 5 - 10 seconds
musan_cuts = (
CutSet.from_manifests(
recordings=combine(part["recordings"] for part in manifests.values())
)
.cut_into_windows(10.0)
.filter(lambda c: c.duration > 5)
.compute_and_store_features_batch(
extractor=extractor,
storage_path=output_dir / f"feats_musan",
manifest_path=src_dir / f"cuts_musan.jsonl.gz",
batch_duration=500,
num_workers=4,
storage_type=LilcomChunkyWriter,
)
)


if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"

logging.basicConfig(format=formatter, level=logging.INFO)
compute_fbank_musan()
145 changes: 145 additions & 0 deletions egs/spgispeech/ASR/local/compute_fbank_spgispeech.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
#!/usr/bin/env python3
# Copyright 2022 Johns Hopkins University (authors: Desh Raj)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""
This file computes fbank features of the SPGISpeech dataset.
It looks for manifests in the directory data/manifests.
The generated fbank features are saved in data/fbank.
"""
import argparse
import logging
from pathlib import Path
from tqdm import tqdm

import torch
from lhotse import load_manifest_lazy, LilcomChunkyWriter
from lhotse.features.kaldifeat import (
KaldifeatFbank,
KaldifeatFbankConfig,
KaldifeatMelOptions,
KaldifeatFrameOptions,
)
from lhotse.manipulation import combine

# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slow things down.
# Do this outside of main() in case it needs to take effect
# even when we are not invoking the main (e.g. when spawning subprocesses).
torch.set_num_threads(1)
torch.set_num_interop_threads(1)


def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--num-splits",
type=int,
default=20,
help="Number of splits for the train set.",
)
parser.add_argument(
"--start",
type=int,
default=0,
help="Start index of the train set split.",
)
parser.add_argument(
"--stop",
type=int,
default=-1,
help="Stop index of the train set split.",
)
parser.add_argument(
"--test",
action="store_true",
help="If set, only compute features for the dev and val set.",
)
parser.add_argument(
"--train",
action="store_true",
help="If set, only compute features for the train set.",
)

return parser.parse_args()


def compute_fbank_spgispeech(args):
assert args.train or args.test, "Either train or test must be set."

src_dir = Path("data/manifests")
output_dir = Path("data/fbank")

sampling_rate = 16000
num_mel_bins = 80

extractor = KaldifeatFbank(
KaldifeatFbankConfig(
frame_opts=KaldifeatFrameOptions(sampling_rate=sampling_rate),
mel_opts=KaldifeatMelOptions(num_bins=num_mel_bins),
device="cuda",
)
)

if args.train:
logging.info(f"Processing train")
cut_set = load_manifest_lazy(src_dir / f"cuts_train_raw.jsonl.gz")
chunk_size = len(cut_set) // args.num_splits
cut_sets = cut_set.split_lazy(
output_dir=src_dir / f"cuts_train_raw_split{args.num_splits}",
chunk_size=chunk_size,
)
start = args.start
stop = min(args.stop, args.num_splits) if args.stop > 0 else args.num_splits
num_digits = len(str(args.num_splits))
for i in range(start, stop):
idx = f"{i + 1}".zfill(num_digits)
logging.info(f"Processing train split {i}")
cs = cut_sets[i].compute_and_store_features_batch(
extractor=extractor,
storage_path=output_dir / f"feats_train_{idx}",
manifest_path=src_dir / f"cuts_train_{idx}.jsonl.gz",
batch_duration=500,
num_workers=4,
storage_type=LilcomChunkyWriter,
)

if args.test:
for partition in ["dev", "val"]:
if (output_dir / f"cuts_{partition}.jsonl.gz").is_file():
logging.info(f"{partition} already exists - skipping.")
continue
logging.info(f"Processing {partition}")
cut_set = load_manifest_lazy(src_dir / f"cuts_{partition}_raw.jsonl.gz")
cut_set = cut_set.compute_and_store_features_batch(
extractor=extractor,
storage_path=output_dir / f"feats_{partition}",
manifest_path=src_dir / f"cuts_{partition}.jsonl.gz",
batch_duration=500,
num_workers=4,
storage_type=LilcomChunkyWriter,
)


if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"

logging.basicConfig(format=formatter, level=logging.INFO)

args = get_args()
compute_fbank_spgispeech(args)
1 change: 1 addition & 0 deletions egs/spgispeech/ASR/local/prepare_lang.py
1 change: 1 addition & 0 deletions egs/spgispeech/ASR/local/prepare_lang_bpe.py
Loading