Skip to content

Commit

Permalink
split save_results() -> save_asr_output() + save_wer_results()
Browse files Browse the repository at this point in the history
- the idea is to support a `--skip-scoring` argument passed to a decoding
  script
- implemented for Transducer decoding (non-streaming and streaming)
- the same can also be done for CTC decoding (not done yet)

- also added `--label` for an extra label in `streaming_decode.py`
- also added `set_caching_enabled(True)`, which has no effect on
  librispeech but leads to faster runtime on DBs with long
  recordings (assuming the `librispeech/zipformer` scripts serve as the
  example scripts for other setups)

- I still need to run the code to make sure it works properly
- at this point, I am interested in getting early feedback from you...

Thank you.
  • Loading branch information
KarelVesely84 committed Aug 8, 2024
1 parent 3b257dd commit 9de503f
Show file tree
Hide file tree
Showing 2 changed files with 149 additions and 71 deletions.
132 changes: 83 additions & 49 deletions egs/librispeech/ASR/zipformer/decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@
modified_beam_search_lm_shallow_fusion,
modified_beam_search_LODR,
)
from lhotse import set_caching_enabled
from train import add_model_arguments, get_model, get_params

from icefall import ContextGraph, LmScorer, NgramLm
Expand Down Expand Up @@ -369,6 +370,14 @@ def get_parser():
modified_beam_search_LODR.
""",
)

parser.add_argument(
"--skip-scoring",
type=str2bool,
default=False,
help="""Skip scoring, but still save the ASR output (for eval sets).""",
)

add_model_arguments(parser)

return parser
Expand Down Expand Up @@ -590,21 +599,23 @@ def decode_one_batch(
)
hyps.append(sp.decode(hyp).split())

# prefix = ( "greedy_search" | "fast_beam_search_nbest" | "modified_beam_search" )
prefix = f"{params.decoding_method}"
if params.decoding_method == "greedy_search":
return {"greedy_search": hyps}
elif "fast_beam_search" in params.decoding_method:
key = f"beam_{params.beam}_"
key += f"max_contexts_{params.max_contexts}_"
key += f"max_states_{params.max_states}"
prefix += f"_beam-{params.beam}"
prefix += f"_max-contexts-{params.max_contexts}"
prefix += f"_max-states-{params.max_states}"
if "nbest" in params.decoding_method:
key += f"_num_paths_{params.num_paths}_"
key += f"nbest_scale_{params.nbest_scale}"
prefix += f"_num-paths-{params.num_paths}"
prefix += f"_nbest-scale-{params.nbest_scale}"
if "LG" in params.decoding_method:
key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
prefix += f"_ngram-lm-scale-{params.ngram_lm_scale}"

return {key: hyps}
return {prefix: hyps}
elif "modified_beam_search" in params.decoding_method:
prefix = f"beam_size_{params.beam_size}"
prefix += f"_beam-size-{params.beam_size}"
if params.decoding_method in (
"modified_beam_search_lm_rescore",
"modified_beam_search_lm_rescore_LODR",
Expand All @@ -617,10 +628,11 @@ def decode_one_batch(
return ans
else:
if params.has_contexts:
prefix += f"-context-score-{params.context_score}"
prefix += f"_context-score-{params.context_score}"
return {prefix: hyps}
else:
return {f"beam_size_{params.beam_size}": hyps}
prefix += f"_beam-size-{params.beam_size}"
return {prefix: hyps}


def decode_dataset(
Expand Down Expand Up @@ -707,46 +719,58 @@ def decode_dataset(
return results


def save_results(
def save_asr_output(
params: AttributeDict,
test_set_name: str,
results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
):
test_set_wers = dict()
"""
Save text produced by ASR.
"""
for key, results in results_dict.items():
recog_path = (
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
)

recogs_filename = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"

results = sorted(results)
store_transcripts(filename=recog_path, texts=results)
logging.info(f"The transcripts are stored in {recog_path}")
store_transcripts(filename=recogs_filename, texts=results)

logging.info(f"The transcripts are stored in {recogs_filename}")


def save_wer_results(
params: AttributeDict,
test_set_name: str,
results_dict: Dict[str, List[Tuple[str, List[str], List[str], Tuple]]],
):
"""
Save WER and per-utterance word alignments.
"""
test_set_wers = dict()
for key, results in results_dict.items():
# The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs.
errs_filename = (
params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_filename, "w") as f:
errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
with open(errs_filename, "w", encoding="utf8") as fd:
wer = write_error_stats(
f, f"{test_set_name}-{key}", results, enable_log=True
fd, f"{test_set_name}-{key}", results, enable_log=True
)
test_set_wers[key] = wer

logging.info("Wrote detailed error stats to {}".format(errs_filename))
logging.info(f"Wrote detailed error stats to {errs_filename}")

test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = (
params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_info, "w") as f:
print("settings\tWER", file=f)

wer_filename = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"

with open(wer_filename, "w", encoding="utf8") as fd:
print("settings\tWER", file=fd)
for key, val in test_set_wers:
print("{}\t{}".format(key, val), file=f)
print(f"{key}\t{val}", file=fd)

s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
note = "\tbest for {}".format(test_set_name)
s = f"\nFor {test_set_name}, WER of different settings are:\n"
note = f"\tbest for {test_set_name}"
for key, val in test_set_wers:
s += "{}\t{}{}\n".format(key, val, note)
s += f"{key}\t{val}{note}\n"
note = ""
logging.info(s)

Expand All @@ -762,6 +786,9 @@ def main():
params = get_params()
params.update(vars(args))

# enable AudioCache
set_caching_enabled(True) # lhotse

assert params.decoding_method in (
"greedy_search",
"beam_search",
Expand All @@ -783,9 +810,9 @@ def main():
params.has_contexts = False

if params.iter > 0:
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
params.suffix = f"iter-{params.iter}_avg-{params.avg}"
else:
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
params.suffix = f"epoch-{params.epoch}_avg-{params.avg}"

if params.causal:
assert (
Expand All @@ -794,40 +821,40 @@ def main():
assert (
"," not in params.left_context_frames
), "left_context_frames should be one value in decoding."
params.suffix += f"-chunk-{params.chunk_size}"
params.suffix += f"-left-context-{params.left_context_frames}"
params.suffix += f"_chunk-{params.chunk_size}"
params.suffix += f"_left-context-{params.left_context_frames}"

if "fast_beam_search" in params.decoding_method:
params.suffix += f"-beam-{params.beam}"
params.suffix += f"-max-contexts-{params.max_contexts}"
params.suffix += f"-max-states-{params.max_states}"
params.suffix += f"_beam-{params.beam}"
params.suffix += f"_max-contexts-{params.max_contexts}"
params.suffix += f"_max-states-{params.max_states}"
if "nbest" in params.decoding_method:
params.suffix += f"-nbest-scale-{params.nbest_scale}"
params.suffix += f"-num-paths-{params.num_paths}"
params.suffix += f"_nbest-scale-{params.nbest_scale}"
params.suffix += f"_num-paths-{params.num_paths}"
if "LG" in params.decoding_method:
params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
params.suffix += f"_ngram-lm-scale-{params.ngram_lm_scale}"
elif "beam_search" in params.decoding_method:
params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
params.suffix += f"__{params.decoding_method}__beam-size-{params.beam_size}"
if params.decoding_method in (
"modified_beam_search",
"modified_beam_search_LODR",
):
if params.has_contexts:
params.suffix += f"-context-score-{params.context_score}"
else:
params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
params.suffix += f"_context-{params.context_size}"
params.suffix += f"_max-sym-per-frame-{params.max_sym_per_frame}"

if params.use_shallow_fusion:
params.suffix += f"-{params.lm_type}-lm-scale-{params.lm_scale}"
params.suffix += f"_{params.lm_type}-lm-scale-{params.lm_scale}"

if "LODR" in params.decoding_method:
params.suffix += (
f"-LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}"
f"_LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}"
)

if params.use_averaged_model:
params.suffix += "-use-averaged-model"
params.suffix += "_use-averaged-model"

setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
logging.info("Decoding started")
Expand Down Expand Up @@ -1038,12 +1065,19 @@ def main():
ngram_lm_scale=ngram_lm_scale,
)

save_results(
save_asr_output(
params=params,
test_set_name=test_set,
results_dict=results_dict,
)

if not params.skip_scoring:
save_wer_results(
params=params,
test_set_name=test_set,
results_dict=results_dict,
)

logging.info("Done!")


Expand Down
Loading

0 comments on commit 9de503f

Please sign in to comment.