From 1730fce688aa4cb6c3162ed860e29c6a72da1604 Mon Sep 17 00:00:00 2001 From: Karel Vesely Date: Tue, 13 Aug 2024 17:02:14 +0200 Subject: [PATCH] split `save_results()` -> `save_asr_output()` + `save_wer_results()` (#1712) - the idea is to support `--skip-scoring` argument passed to a decoding script - created for Transducer decoding (non-streaming, streaming) - it can be done also for CTC decoding... (not yet) - also added `--label` for extra label in `streaming_decode.py` - and also added `set_caching_enabled(True)`, which has no effect on librispeech, but it leads to faster runtime on DBs with long recordings (assuming `librispeech/zipformer` scripts are the example scripts for other setups) --- egs/librispeech/ASR/zipformer/ctc_decode.py | 94 +++++++++---- egs/librispeech/ASR/zipformer/decode.py | 132 +++++++++++------- .../ASR/zipformer/streaming_decode.py | 88 +++++++++--- 3 files changed, 214 insertions(+), 100 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/ctc_decode.py b/egs/librispeech/ASR/zipformer/ctc_decode.py index 435a79e7fc..9db4299592 100755 --- a/egs/librispeech/ASR/zipformer/ctc_decode.py +++ b/egs/librispeech/ASR/zipformer/ctc_decode.py @@ -120,6 +120,7 @@ import torch import torch.nn as nn from asr_datamodule import LibriSpeechAsrDataModule +from lhotse import set_caching_enabled from train import add_model_arguments, get_model, get_params from icefall.checkpoint import ( @@ -296,6 +297,13 @@ def get_parser(): """, ) + parser.add_argument( + "--skip-scoring", + type=str2bool, + default=False, + help="""Skip scoring, but still save the ASR output (for eval sets).""" + ) + add_model_arguments(parser) return parser @@ -455,7 +463,7 @@ def decode_one_batch( # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ] hyps = [s.split() for s in hyps] key = "ctc-decoding" - return {key: hyps} + return {key: hyps} # note: returns words if params.decoding_method == "attention-decoder-rescoring-no-ngram": best_path_dict = rescore_with_attention_decoder_no_ngram( @@ -492,7 +500,7 @@ def decode_one_batch( ) hyps = get_texts(best_path) hyps = [[word_table[i] for i in ids] for ids in hyps] - key = f"oracle_{params.num_paths}_nbest_scale_{params.nbest_scale}" # noqa + key = f"oracle_{params.num_paths}_nbest-scale-{params.nbest_scale}" # noqa return {key: hyps} if params.decoding_method in ["1best", "nbest"]: @@ -500,7 +508,7 @@ def decode_one_batch( best_path = one_best_decoding( lattice=lattice, use_double_scores=params.use_double_scores ) - key = "no_rescore" + key = "no-rescore" else: best_path = nbest_decoding( lattice=lattice, @@ -508,11 +516,11 @@ def decode_one_batch( use_double_scores=params.use_double_scores, nbest_scale=params.nbest_scale, ) - key = f"no_rescore-nbest-scale-{params.nbest_scale}-{params.num_paths}" # noqa + key = f"no-rescore_nbest-scale-{params.nbest_scale}-{params.num_paths}" # noqa hyps = get_texts(best_path) hyps = [[word_table[i] for i in ids] for ids in hyps] - return {key: hyps} + return {key: hyps} # note: returns BPE tokens assert params.decoding_method in [ "nbest-rescoring", @@ -646,7 +654,27 @@ def decode_dataset( return results -def save_results( +def save_asr_output( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], +): + """ + Save text produced by ASR. + """ + for key, results in results_dict.items(): + + recogs_filename = ( + params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" + ) + + results = sorted(results) + store_transcripts(filename=recogs_filename, texts=results) + + logging.info(f"The transcripts are stored in {recogs_filename}") + + +def save_wer_results( params: AttributeDict, test_set_name: str, results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], @@ -661,32 +689,30 @@ def save_results( test_set_wers = dict() for key, results in results_dict.items(): - recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - results = sorted(results) - store_transcripts(filename=recog_path, texts=results) - if enable_log: - logging.info(f"The transcripts are stored in {recog_path}") - # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" - with open(errs_filename, "w") as f: - wer = write_error_stats(f, f"{test_set_name}-{key}", results) + with open(errs_filename, "w", encoding="utf8") as fd: + wer = write_error_stats( + fd, f"{test_set_name}_{key}", results, enable_log=enable_log + ) test_set_wers[key] = wer - if enable_log: - logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + logging.info(f"Wrote detailed error stats to {errs_filename}") test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" - with open(errs_info, "w") as f: - print("settings\tWER", file=f) + + wer_filename = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" + + with open(wer_filename, "w", encoding="utf8") as fd: + print("settings\tWER", file=fd) for key, val in test_set_wers: - print("{}\t{}".format(key, val), file=f) + print(f"{key}\t{val}", file=fd) - s = "\nFor {}, WER of different settings are:\n".format(test_set_name) - note = "\tbest for {}".format(test_set_name) + s = f"\nFor {test_set_name}, WER of different settings are:\n" + note = f"\tbest for {test_set_name}" for key, val in test_set_wers: - s += "{}\t{}{}\n".format(key, val, note) + s += f"{key}\t{val}{note}\n" note = "" logging.info(s) @@ -705,6 +731,9 @@ def main(): params.update(get_decoding_params()) params.update(vars(args)) + # enable AudioCache + set_caching_enabled(True) # lhotse + assert params.decoding_method in ( "ctc-greedy-search", "ctc-decoding", @@ -719,9 +748,9 @@ def main(): params.res_dir = params.exp_dir / params.decoding_method if params.iter > 0: - params.suffix = f"iter-{params.iter}-avg-{params.avg}" + params.suffix = f"iter-{params.iter}_avg-{params.avg}" else: - params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + params.suffix = f"epoch-{params.epoch}_avg-{params.avg}" if params.causal: assert ( @@ -730,11 +759,11 @@ def main(): assert ( "," not in params.left_context_frames ), "left_context_frames should be one value in decoding." - params.suffix += f"-chunk-{params.chunk_size}" - params.suffix += f"-left-context-{params.left_context_frames}" + params.suffix += f"_chunk-{params.chunk_size}" + params.suffix += f"_left-context-{params.left_context_frames}" if params.use_averaged_model: - params.suffix += "-use-averaged-model" + params.suffix += "_use-averaged-model" setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") logging.info("Decoding started") @@ -940,12 +969,19 @@ def main(): G=G, ) - save_results( + save_asr_output( params=params, test_set_name=test_set, results_dict=results_dict, ) + if not params.skip_scoring: + save_wer_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + logging.info("Done!") diff --git a/egs/librispeech/ASR/zipformer/decode.py b/egs/librispeech/ASR/zipformer/decode.py index df2d555a09..cbfb3728e6 100755 --- a/egs/librispeech/ASR/zipformer/decode.py +++ b/egs/librispeech/ASR/zipformer/decode.py @@ -121,6 +121,7 @@ modified_beam_search_lm_shallow_fusion, modified_beam_search_LODR, ) +from lhotse import set_caching_enabled from train import add_model_arguments, get_model, get_params from icefall import ContextGraph, LmScorer, NgramLm @@ -369,6 +370,14 @@ def get_parser(): modified_beam_search_LODR. """, ) + + parser.add_argument( + "--skip-scoring", + type=str2bool, + default=False, + help="""Skip scoring, but still save the ASR output (for eval sets).""", + ) + add_model_arguments(parser) return parser @@ -590,21 +599,23 @@ def decode_one_batch( ) hyps.append(sp.decode(hyp).split()) + # prefix = ( "greedy_search" | "fast_beam_search_nbest" | "modified_beam_search" ) + prefix = f"{params.decoding_method}" if params.decoding_method == "greedy_search": return {"greedy_search": hyps} elif "fast_beam_search" in params.decoding_method: - key = f"beam_{params.beam}_" - key += f"max_contexts_{params.max_contexts}_" - key += f"max_states_{params.max_states}" + prefix += f"_beam-{params.beam}" + prefix += f"_max-contexts-{params.max_contexts}" + prefix += f"_max-states-{params.max_states}" if "nbest" in params.decoding_method: - key += f"_num_paths_{params.num_paths}_" - key += f"nbest_scale_{params.nbest_scale}" + prefix += f"_num-paths-{params.num_paths}" + prefix += f"_nbest-scale-{params.nbest_scale}" if "LG" in params.decoding_method: - key += f"_ngram_lm_scale_{params.ngram_lm_scale}" + prefix += f"_ngram-lm-scale-{params.ngram_lm_scale}" - return {key: hyps} + return {prefix: hyps} elif "modified_beam_search" in params.decoding_method: - prefix = f"beam_size_{params.beam_size}" + prefix += f"_beam-size-{params.beam_size}" if params.decoding_method in ( "modified_beam_search_lm_rescore", "modified_beam_search_lm_rescore_LODR", @@ -617,10 +628,11 @@ def decode_one_batch( return ans else: if params.has_contexts: - prefix += f"-context-score-{params.context_score}" + prefix += f"_context-score-{params.context_score}" return {prefix: hyps} else: - return {f"beam_size_{params.beam_size}": hyps} + prefix += f"_beam-size-{params.beam_size}" + return {prefix: hyps} def decode_dataset( @@ -707,46 +719,58 @@ def decode_dataset( return results -def save_results( +def save_asr_output( params: AttributeDict, test_set_name: str, results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], ): - test_set_wers = dict() + """ + Save text produced by ASR. + """ for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" - ) + + recogs_filename = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" + results = sorted(results) - store_transcripts(filename=recog_path, texts=results) - logging.info(f"The transcripts are stored in {recog_path}") + store_transcripts(filename=recogs_filename, texts=results) + + logging.info(f"The transcripts are stored in {recogs_filename}") + +def save_wer_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str], Tuple]]], +): + """ + Save WER and per-utterance word alignments. + """ + test_set_wers = dict() + for key, results in results_dict.items(): # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" - ) - with open(errs_filename, "w") as f: + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" + with open(errs_filename, "w", encoding="utf8") as fd: wer = write_error_stats( - f, f"{test_set_name}-{key}", results, enable_log=True + fd, f"{test_set_name}-{key}", results, enable_log=True ) test_set_wers[key] = wer - logging.info("Wrote detailed error stats to {}".format(errs_filename)) + logging.info(f"Wrote detailed error stats to {errs_filename}") test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" - ) - with open(errs_info, "w") as f: - print("settings\tWER", file=f) + + wer_filename = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" + + with open(wer_filename, "w", encoding="utf8") as fd: + print("settings\tWER", file=fd) for key, val in test_set_wers: - print("{}\t{}".format(key, val), file=f) + print(f"{key}\t{val}", file=fd) - s = "\nFor {}, WER of different settings are:\n".format(test_set_name) - note = "\tbest for {}".format(test_set_name) + s = f"\nFor {test_set_name}, WER of different settings are:\n" + note = f"\tbest for {test_set_name}" for key, val in test_set_wers: - s += "{}\t{}{}\n".format(key, val, note) + s += f"{key}\t{val}{note}\n" note = "" logging.info(s) @@ -762,6 +786,9 @@ def main(): params = get_params() params.update(vars(args)) + # enable AudioCache + set_caching_enabled(True) # lhotse + assert params.decoding_method in ( "greedy_search", "beam_search", @@ -783,9 +810,9 @@ def main(): params.has_contexts = False if params.iter > 0: - params.suffix = f"iter-{params.iter}-avg-{params.avg}" + params.suffix = f"iter-{params.iter}_avg-{params.avg}" else: - params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + params.suffix = f"epoch-{params.epoch}_avg-{params.avg}" if params.causal: assert ( @@ -794,20 +821,20 @@ def main(): assert ( "," not in params.left_context_frames ), "left_context_frames should be one value in decoding." - params.suffix += f"-chunk-{params.chunk_size}" - params.suffix += f"-left-context-{params.left_context_frames}" + params.suffix += f"_chunk-{params.chunk_size}" + params.suffix += f"_left-context-{params.left_context_frames}" if "fast_beam_search" in params.decoding_method: - params.suffix += f"-beam-{params.beam}" - params.suffix += f"-max-contexts-{params.max_contexts}" - params.suffix += f"-max-states-{params.max_states}" + params.suffix += f"_beam-{params.beam}" + params.suffix += f"_max-contexts-{params.max_contexts}" + params.suffix += f"_max-states-{params.max_states}" if "nbest" in params.decoding_method: - params.suffix += f"-nbest-scale-{params.nbest_scale}" - params.suffix += f"-num-paths-{params.num_paths}" + params.suffix += f"_nbest-scale-{params.nbest_scale}" + params.suffix += f"_num-paths-{params.num_paths}" if "LG" in params.decoding_method: - params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" + params.suffix += f"_ngram-lm-scale-{params.ngram_lm_scale}" elif "beam_search" in params.decoding_method: - params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + params.suffix += f"__{params.decoding_method}__beam-size-{params.beam_size}" if params.decoding_method in ( "modified_beam_search", "modified_beam_search_LODR", @@ -815,19 +842,19 @@ def main(): if params.has_contexts: params.suffix += f"-context-score-{params.context_score}" else: - params.suffix += f"-context-{params.context_size}" - params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" + params.suffix += f"_context-{params.context_size}" + params.suffix += f"_max-sym-per-frame-{params.max_sym_per_frame}" if params.use_shallow_fusion: - params.suffix += f"-{params.lm_type}-lm-scale-{params.lm_scale}" + params.suffix += f"_{params.lm_type}-lm-scale-{params.lm_scale}" if "LODR" in params.decoding_method: params.suffix += ( - f"-LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}" + f"_LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}" ) if params.use_averaged_model: - params.suffix += "-use-averaged-model" + params.suffix += "_use-averaged-model" setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") logging.info("Decoding started") @@ -1038,12 +1065,19 @@ def main(): ngram_lm_scale=ngram_lm_scale, ) - save_results( + save_asr_output( params=params, test_set_name=test_set, results_dict=results_dict, ) + if not params.skip_scoring: + save_wer_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + logging.info("Done!") diff --git a/egs/librispeech/ASR/zipformer/streaming_decode.py b/egs/librispeech/ASR/zipformer/streaming_decode.py index 360523b8eb..ebcafbf873 100755 --- a/egs/librispeech/ASR/zipformer/streaming_decode.py +++ b/egs/librispeech/ASR/zipformer/streaming_decode.py @@ -43,7 +43,7 @@ from asr_datamodule import LibriSpeechAsrDataModule from decode_stream import DecodeStream from kaldifeat import Fbank, FbankOptions -from lhotse import CutSet +from lhotse import CutSet, set_caching_enabled from streaming_beam_search import ( fast_beam_search_one_best, greedy_search, @@ -76,6 +76,13 @@ def get_parser(): formatter_class=argparse.ArgumentDefaultsHelpFormatter ) + parser.add_argument( + "--label", + type=str, + default="", + help="""Extra label of the decoding run.""", + ) + parser.add_argument( "--epoch", type=int, @@ -188,6 +195,14 @@ def get_parser(): help="The number of streams that can be decoded parallel.", ) + parser.add_argument( + "--skip-scoring", + type=str2bool, + default=False, + help="""Skip scoring, but still save the ASR output (for eval sets).""" + ) + + add_model_arguments(parser) return parser @@ -640,46 +655,60 @@ def decode_dataset( return {key: decode_results} -def save_results( +def save_asr_output( params: AttributeDict, test_set_name: str, results_dict: Dict[str, List[Tuple[List[str], List[str]]]], ): - test_set_wers = dict() + """ + Save text produced by ASR. + """ for key, results in results_dict.items(): - recog_path = ( + recogs_filename = ( params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" ) results = sorted(results) - store_transcripts(filename=recog_path, texts=results) - logging.info(f"The transcripts are stored in {recog_path}") + store_transcripts(filename=recogs_filename, texts=results) + logging.info(f"The transcripts are stored in {recogs_filename}") + +def save_wer_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[List[str], List[str]]]], +): + """ + Save WER and per-utterance word alignments. + """ + test_set_wers = dict() + for key, results in results_dict.items(): # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. errs_filename = ( params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" ) - with open(errs_filename, "w") as f: + with open(errs_filename, "w", encoding="utf8") as fd: wer = write_error_stats( - f, f"{test_set_name}-{key}", results, enable_log=True + fd, f"{test_set_name}-{key}", results, enable_log=True ) test_set_wers[key] = wer - logging.info("Wrote detailed error stats to {}".format(errs_filename)) + logging.info(f"Wrote detailed error stats to {errs_filename}") test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( + + wer_filename = ( params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) - with open(errs_info, "w") as f: - print("settings\tWER", file=f) + with open(wer_filename, "w", encoding="utf8") as fd: + print("settings\tWER", file=fd) for key, val in test_set_wers: - print("{}\t{}".format(key, val), file=f) + print(f"{key}\t{val}", file=fd) - s = "\nFor {}, WER of different settings are:\n".format(test_set_name) - note = "\tbest for {}".format(test_set_name) + s = f"\nFor {test_set_name}, WER of different settings are:\n" + note = f"\tbest for {test_set_name}" for key, val in test_set_wers: - s += "{}\t{}{}\n".format(key, val, note) + s += f"{key}\t{val}{note}\n" note = "" logging.info(s) @@ -694,6 +723,9 @@ def main(): params = get_params() params.update(vars(args)) + # enable AudioCache + set_caching_enabled(True) # lhotse + params.res_dir = params.exp_dir / "streaming" / params.decoding_method if params.iter > 0: @@ -706,18 +738,21 @@ def main(): assert ( "," not in params.left_context_frames ), "left_context_frames should be one value in decoding." - params.suffix += f"-chunk-{params.chunk_size}" - params.suffix += f"-left-context-{params.left_context_frames}" + params.suffix += f"_chunk-{params.chunk_size}" + params.suffix += f"_left-context-{params.left_context_frames}" # for fast_beam_search if params.decoding_method == "fast_beam_search": - params.suffix += f"-beam-{params.beam}" - params.suffix += f"-max-contexts-{params.max_contexts}" - params.suffix += f"-max-states-{params.max_states}" + params.suffix += f"_beam-{params.beam}" + params.suffix += f"_max-contexts-{params.max_contexts}" + params.suffix += f"_max-states-{params.max_states}" if params.use_averaged_model: params.suffix += "-use-averaged-model" + if params.label: + params.suffix += f"-{params.label}" + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") logging.info("Decoding started") @@ -845,12 +880,21 @@ def main(): decoding_graph=decoding_graph, ) - save_results( + + save_asr_output( params=params, test_set_name=test_set, results_dict=results_dict, ) + + if not params.skip_scoring: + save_wer_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + logging.info("Done!")