Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] A LibriTTS recipe on both ASR & Neural Codec Tasks #1746

Draft
wants to merge 31 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
dd82686
init commit
JinZr Sep 4, 2024
6e4a9ea
a little bit coarse commit
JinZr Sep 5, 2024
2df992f
fixed a typo
JinZr Sep 5, 2024
91f7b1c
sort of fixed DDP training issue
JinZr Sep 6, 2024
2e5055a
minor updates
JinZr Sep 6, 2024
0150961
minor fixes
JinZr Sep 6, 2024
8da57a0
black formatted
JinZr Sep 6, 2024
4483c6e
tensorboard should work properly
JinZr Sep 6, 2024
12c7a16
minor updates
JinZr Sep 6, 2024
c236757
* added script for inference
JinZr Sep 7, 2024
d45b400
minor updates
JinZr Sep 8, 2024
c43977e
black formatted
JinZr Sep 8, 2024
1e65a97
added pesq and stoi for reconstruction performance evaluation
JinZr Sep 8, 2024
f9340cc
refactored loss functions
JinZr Oct 5, 2024
e788bb4
making MSD and MPD optional
JinZr Oct 6, 2024
d83ce89
fixed loss normalization & scaling factors
JinZr Oct 6, 2024
58f6562
added scheduler w/ warmup
JinZr Oct 6, 2024
01cc307
fixed loss functions & scaling factors
JinZr Oct 6, 2024
93eedce
Merge branch 'k2-fsa:master' into dev/asr/libritts
JinZr Oct 6, 2024
b65eba2
fixed script for inference
JinZr Oct 7, 2024
266e840
fixed ``+x`` permission
JinZr Oct 7, 2024
32a7d22
minor updates to the scripts
JinZr Oct 7, 2024
f074487
minor updates
JinZr Oct 7, 2024
156af46
applied text norm to valid & test cuts
JinZr Oct 7, 2024
43267e3
black formatted
JinZr Oct 8, 2024
2356621
minor updates
JinZr Oct 9, 2024
df87a0f
updated train.py
JinZr Oct 9, 2024
5492a6a
comments updated
JinZr Oct 12, 2024
cd96f63
added text norm for other decoding scripts
JinZr Oct 12, 2024
74a738f
comments updated
JinZr Oct 12, 2024
7eee6b9
updated default param
JinZr Oct 12, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 18 additions & 8 deletions egs/librispeech/ASR/zipformer/attention_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ def forward(
causal_mask = subsequent_mask(x.shape[0], device=x.device) # (seq_len, seq_len)
attn_mask = torch.logical_or(
padding_mask.unsqueeze(1), # (batch, 1, seq_len)
torch.logical_not(causal_mask).unsqueeze(0) # (1, seq_len, seq_len)
torch.logical_not(causal_mask).unsqueeze(0), # (1, seq_len, seq_len)
) # (batch, seq_len, seq_len)

if memory is not None:
Expand Down Expand Up @@ -367,7 +367,9 @@ def __init__(
self.num_heads = num_heads
self.head_dim = attention_dim // num_heads
assert self.head_dim * num_heads == attention_dim, (
self.head_dim, num_heads, attention_dim
self.head_dim,
num_heads,
attention_dim,
)
self.dropout = dropout
self.name = None # will be overwritten in training code; for diagnostics.
Expand Down Expand Up @@ -437,15 +439,19 @@ def forward(
if key_padding_mask is not None:
assert key_padding_mask.shape == (batch, src_len), key_padding_mask.shape
attn_weights = attn_weights.masked_fill(
key_padding_mask.unsqueeze(1).unsqueeze(2), float("-inf"),
key_padding_mask.unsqueeze(1).unsqueeze(2),
float("-inf"),
)

if attn_mask is not None:
assert (
attn_mask.shape == (batch, 1, src_len)
or attn_mask.shape == (batch, tgt_len, src_len)
assert attn_mask.shape == (batch, 1, src_len) or attn_mask.shape == (
batch,
tgt_len,
src_len,
), attn_mask.shape
attn_weights = attn_weights.masked_fill(attn_mask.unsqueeze(1), float("-inf"))
attn_weights = attn_weights.masked_fill(
attn_mask.unsqueeze(1), float("-inf")
)

attn_weights = attn_weights.view(batch * num_heads, tgt_len, src_len)
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
Expand All @@ -456,7 +462,11 @@ def forward(

# (batch * head, tgt_len, head_dim)
attn_output = torch.bmm(attn_weights, v)
assert attn_output.shape == (batch * num_heads, tgt_len, head_dim), attn_output.shape
assert attn_output.shape == (
batch * num_heads,
tgt_len,
head_dim,
), attn_output.shape

attn_output = attn_output.transpose(0, 1).contiguous()
attn_output = attn_output.view(tgt_len, batch, num_heads * head_dim)
Expand Down
12 changes: 7 additions & 5 deletions egs/librispeech/ASR/zipformer/export-onnx-streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,6 +487,7 @@ def build_inputs_outputs(tensors, i):

add_meta_data(filename=encoder_filename, meta_data=meta_data)


def export_decoder_model_onnx(
decoder_model: OnnxDecoder,
decoder_filename: str,
Expand Down Expand Up @@ -754,30 +755,31 @@ def main():
)
logging.info(f"Exported joiner to {joiner_filename}")

if(params.fp16) :
if params.fp16:
from onnxconverter_common import float16

logging.info("Generate fp16 models")

encoder = onnx.load(encoder_filename)
encoder_fp16 = float16.convert_float_to_float16(encoder, keep_io_types=True)
encoder_filename_fp16 = params.exp_dir / f"encoder-{suffix}.fp16.onnx"
onnx.save(encoder_fp16,encoder_filename_fp16)
onnx.save(encoder_fp16, encoder_filename_fp16)

decoder = onnx.load(decoder_filename)
decoder_fp16 = float16.convert_float_to_float16(decoder, keep_io_types=True)
decoder_filename_fp16 = params.exp_dir / f"decoder-{suffix}.fp16.onnx"
onnx.save(decoder_fp16,decoder_filename_fp16)
onnx.save(decoder_fp16, decoder_filename_fp16)

joiner = onnx.load(joiner_filename)
joiner_fp16 = float16.convert_float_to_float16(joiner, keep_io_types=True)
joiner_filename_fp16 = params.exp_dir / f"joiner-{suffix}.fp16.onnx"
onnx.save(joiner_fp16,joiner_filename_fp16)
onnx.save(joiner_fp16, joiner_filename_fp16)

# Generate int8 quantization models
# See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection

logging.info("Generate int8 quantization models")

encoder_filename_int8 = params.exp_dir / f"encoder-{suffix}.int8.onnx"
quantize_dynamic(
model_input=encoder_filename,
Expand Down
8 changes: 4 additions & 4 deletions egs/librispeech/ASR/zipformer/export-onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -592,23 +592,23 @@ def main():
)
logging.info(f"Exported joiner to {joiner_filename}")

if(params.fp16) :
if params.fp16:
logging.info("Generate fp16 models")

encoder = onnx.load(encoder_filename)
encoder_fp16 = float16.convert_float_to_float16(encoder, keep_io_types=True)
encoder_filename_fp16 = params.exp_dir / f"encoder-{suffix}.fp16.onnx"
onnx.save(encoder_fp16,encoder_filename_fp16)
onnx.save(encoder_fp16, encoder_filename_fp16)

decoder = onnx.load(decoder_filename)
decoder_fp16 = float16.convert_float_to_float16(decoder, keep_io_types=True)
decoder_filename_fp16 = params.exp_dir / f"decoder-{suffix}.fp16.onnx"
onnx.save(decoder_fp16,decoder_filename_fp16)
onnx.save(decoder_fp16, decoder_filename_fp16)

joiner = onnx.load(joiner_filename)
joiner_fp16 = float16.convert_float_to_float16(joiner, keep_io_types=True)
joiner_filename_fp16 = params.exp_dir / f"joiner-{suffix}.fp16.onnx"
onnx.save(joiner_fp16,joiner_filename_fp16)
onnx.save(joiner_fp16, joiner_filename_fp16)

# Generate int8 quantization models
# See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection
Expand Down
1 change: 1 addition & 0 deletions egs/libritts/ASR/local/compile_hlg.py
1 change: 1 addition & 0 deletions egs/libritts/ASR/local/compile_lg.py
160 changes: 160 additions & 0 deletions egs/libritts/ASR/local/compute_fbank_libritts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
#!/usr/bin/env python3
# Copyright 2021-2024 Xiaomi Corp. (authors: Fangjun Kuang,
# Zengwei Yao,)
# 2024 The Chinese Univ. of HK (authors: Zengrui Jin)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""
This file computes fbank features of the LibriTTS dataset.
It looks for manifests in the directory data/manifests.

The generated fbank features are saved in data/fbank.
"""

import argparse
import logging
import os
from pathlib import Path
from typing import Optional

import torch
from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter
from lhotse.recipes.utils import read_manifests_if_cached

from icefall.utils import get_executor, str2bool

# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slow things down.
# Do this outside of main() in case it needs to take effect
# even when we are not invoking the main (e.g. when spawning subprocesses).
torch.set_num_threads(1)
torch.set_num_interop_threads(1)


def get_args():
parser = argparse.ArgumentParser()

parser.add_argument(
"--dataset",
type=str,
help="""Dataset parts to compute fbank. If None, we will use all""",
)
parser.add_argument(
"--perturb-speed",
type=str2bool,
default=True,
help="""Perturb speed with factor 0.9 and 1.1 on train subset.""",
)
parser.add_argument(
"--sampling-rate",
type=int,
default=24000,
help="""Sampling rate of the audio for computing fbank, the default value for LibriTTS is 24000, audio files will be resampled if a different sample rate is provided""",
)

return parser.parse_args()


def compute_fbank_libritts(
dataset: Optional[str] = None,
sampling_rate: int = 24000,
perturb_speed: Optional[bool] = True,
):
src_dir = Path("data/manifests")
output_dir = Path("data/fbank")
num_jobs = min(32, os.cpu_count())

num_mel_bins = 80

if dataset is None:
dataset_parts = (
"dev-clean",
"dev-other",
"test-clean",
"test-other",
"train-clean-100",
"train-clean-360",
"train-other-500",
)
else:
dataset_parts = dataset.split(" ", -1)

prefix = "libritts"
suffix = "jsonl.gz"
manifests = read_manifests_if_cached(
dataset_parts=dataset_parts,
output_dir=src_dir,
prefix=prefix,
suffix=suffix,
)
assert manifests is not None

assert len(manifests) == len(dataset_parts), (
len(manifests),
len(dataset_parts),
list(manifests.keys()),
dataset_parts,
)

extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))

with get_executor() as ex: # Initialize the executor only once.
for partition, m in manifests.items():
cuts_filename = f"{prefix}_cuts_{partition}.{suffix}"
if (output_dir / cuts_filename).is_file():
logging.info(f"{partition} already exists - skipping.")
continue
logging.info(f"Processing {partition}")
cut_set = CutSet.from_manifests(
recordings=m["recordings"],
supervisions=m["supervisions"],
)
if sampling_rate != 24000:
logging.info(f"Resampling audio to {sampling_rate}Hz")
cut_set = cut_set.resample(sampling_rate)
if "train" in partition:
if perturb_speed:
logging.info(f"Doing speed perturb")
cut_set = (
cut_set
+ cut_set.perturb_speed(0.9)
+ cut_set.perturb_speed(1.1)
)

cut_set = cut_set.compute_and_store_features(
extractor=extractor,
storage_path=f"{output_dir}/{prefix}_feats_{partition}",
# when an executor is specified, make more partitions
num_jobs=num_jobs if ex is None else 80,
executor=ex,
storage_type=LilcomChunkyWriter,
)
cut_set.to_file(output_dir / cuts_filename)


if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"

logging.basicConfig(format=formatter, level=logging.INFO)
args = get_args()
logging.info(vars(args))

compute_fbank_libritts(
dataset=args.dataset,
sampling_rate=args.sampling_rate,
perturb_speed=args.perturb_speed,
)
1 change: 1 addition & 0 deletions egs/libritts/ASR/local/compute_fbank_musan.py
Loading
Loading