Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

librispeech_decode : split save_results() -> save_asr_output() + save_wer_results() #1712

Merged
merged 1 commit into from
Aug 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 65 additions & 29 deletions egs/librispeech/ASR/zipformer/ctc_decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@
import torch
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from lhotse import set_caching_enabled
from train import add_model_arguments, get_model, get_params

from icefall.checkpoint import (
Expand Down Expand Up @@ -296,6 +297,13 @@ def get_parser():
""",
)

parser.add_argument(
"--skip-scoring",
type=str2bool,
default=False,
help="""Skip scoring, but still save the ASR output (for eval sets)."""
)

add_model_arguments(parser)

return parser
Expand Down Expand Up @@ -455,7 +463,7 @@ def decode_one_batch(
# hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ]
hyps = [s.split() for s in hyps]
key = "ctc-decoding"
return {key: hyps}
return {key: hyps} # note: returns words

if params.decoding_method == "attention-decoder-rescoring-no-ngram":
best_path_dict = rescore_with_attention_decoder_no_ngram(
Expand Down Expand Up @@ -492,27 +500,27 @@ def decode_one_batch(
)
hyps = get_texts(best_path)
hyps = [[word_table[i] for i in ids] for ids in hyps]
key = f"oracle_{params.num_paths}_nbest_scale_{params.nbest_scale}" # noqa
key = f"oracle_{params.num_paths}_nbest-scale-{params.nbest_scale}" # noqa
return {key: hyps}

if params.decoding_method in ["1best", "nbest"]:
if params.decoding_method == "1best":
best_path = one_best_decoding(
lattice=lattice, use_double_scores=params.use_double_scores
)
key = "no_rescore"
key = "no-rescore"
else:
best_path = nbest_decoding(
lattice=lattice,
num_paths=params.num_paths,
use_double_scores=params.use_double_scores,
nbest_scale=params.nbest_scale,
)
key = f"no_rescore-nbest-scale-{params.nbest_scale}-{params.num_paths}" # noqa
key = f"no-rescore_nbest-scale-{params.nbest_scale}-{params.num_paths}" # noqa

hyps = get_texts(best_path)
hyps = [[word_table[i] for i in ids] for ids in hyps]
return {key: hyps}
return {key: hyps} # note: returns BPE tokens

assert params.decoding_method in [
"nbest-rescoring",
Expand Down Expand Up @@ -646,7 +654,27 @@ def decode_dataset(
return results


def save_results(
def save_asr_output(
    params: AttributeDict,
    test_set_name: str,
    results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
):
    """
    Store the transcripts produced by ASR for one test set.

    For every entry of `results_dict`, the result list is sorted and handed
    to `store_transcripts`, then the destination path is logged.

    Args:
      params:
        Provides `res_dir` (output directory) and `suffix` (file-name
        suffix); both are used to build the output path.
      test_set_name:
        Name of the test set; it becomes part of the output file name.
      results_dict:
        Maps a decoding-setting key to the list of result tuples obtained
        with that setting.

    NOTE(review): the output path does not contain the dictionary key, so
    every entry of `results_dict` is written to the same path. Confirm that
    this is intended when more than one key is present.
    """
    for _setting, entries in results_dict.items():
        recogs_filename = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"

        # sorted() builds a new list, so the caller's list is not reordered.
        store_transcripts(filename=recogs_filename, texts=sorted(entries))

        logging.info(f"The transcripts are stored in {recogs_filename}")


def save_wer_results(
params: AttributeDict,
test_set_name: str,
results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
Expand All @@ -661,32 +689,30 @@ def save_results(

test_set_wers = dict()
for key, results in results_dict.items():
recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
results = sorted(results)
store_transcripts(filename=recog_path, texts=results)
if enable_log:
logging.info(f"The transcripts are stored in {recog_path}")

# The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs.
errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
with open(errs_filename, "w") as f:
wer = write_error_stats(f, f"{test_set_name}-{key}", results)
with open(errs_filename, "w", encoding="utf8") as fd:
wer = write_error_stats(
fd, f"{test_set_name}_{key}", results, enable_log=enable_log
)
test_set_wers[key] = wer
if enable_log:
logging.info("Wrote detailed error stats to {}".format(errs_filename))

logging.info(f"Wrote detailed error stats to {errs_filename}")

test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
with open(errs_info, "w") as f:
print("settings\tWER", file=f)

wer_filename = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"

with open(wer_filename, "w", encoding="utf8") as fd:
print("settings\tWER", file=fd)
for key, val in test_set_wers:
print("{}\t{}".format(key, val), file=f)
print(f"{key}\t{val}", file=fd)

s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
note = "\tbest for {}".format(test_set_name)
s = f"\nFor {test_set_name}, WER of different settings are:\n"
note = f"\tbest for {test_set_name}"
for key, val in test_set_wers:
s += "{}\t{}{}\n".format(key, val, note)
s += f"{key}\t{val}{note}\n"
note = ""
logging.info(s)

Expand All @@ -705,6 +731,9 @@ def main():
params.update(get_decoding_params())
params.update(vars(args))

# enable AudioCache
set_caching_enabled(True) # lhotse

assert params.decoding_method in (
"ctc-greedy-search",
"ctc-decoding",
Expand All @@ -719,9 +748,9 @@ def main():
params.res_dir = params.exp_dir / params.decoding_method

if params.iter > 0:
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
params.suffix = f"iter-{params.iter}_avg-{params.avg}"
else:
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
params.suffix = f"epoch-{params.epoch}_avg-{params.avg}"

if params.causal:
assert (
Expand All @@ -730,11 +759,11 @@ def main():
assert (
"," not in params.left_context_frames
), "left_context_frames should be one value in decoding."
params.suffix += f"-chunk-{params.chunk_size}"
params.suffix += f"-left-context-{params.left_context_frames}"
params.suffix += f"_chunk-{params.chunk_size}"
params.suffix += f"_left-context-{params.left_context_frames}"

if params.use_averaged_model:
params.suffix += "-use-averaged-model"
params.suffix += "_use-averaged-model"

setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
logging.info("Decoding started")
Expand Down Expand Up @@ -940,12 +969,19 @@ def main():
G=G,
)

save_results(
save_asr_output(
params=params,
test_set_name=test_set,
results_dict=results_dict,
)

if not params.skip_scoring:
save_wer_results(
params=params,
test_set_name=test_set,
results_dict=results_dict,
)

logging.info("Done!")


Expand Down
Loading
Loading