From 109354b6b8199fa27cd8d4310b59a2e45da1d537 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Mon, 2 Oct 2023 14:00:06 +0800 Subject: [PATCH] Add CTC HLG decoding for zipformer (#1287) --- ...onformer-ctc.sh => run-pre-trained-ctc.sh} | 60 ++- ...nformer-ctc.yml => run-pretrained-ctc.yml} | 10 +- .../jit_pretrained_decode_with_H.py | 4 +- .../jit_pretrained_decode_with_HL.py | 8 +- .../jit_pretrained_decode_with_HLG.py | 8 +- .../ASR/zipformer/export-onnx-ctc.py | 436 ++++++++++++++++++ .../ASR/zipformer/onnx_pretrained_ctc.py | 213 +++++++++ .../ASR/zipformer/onnx_pretrained_ctc_H.py | 277 +++++++++++ .../ASR/zipformer/onnx_pretrained_ctc_HL.py | 275 +++++++++++ .../ASR/zipformer/onnx_pretrained_ctc_HLG.py | 275 +++++++++++ 10 files changed, 1545 insertions(+), 21 deletions(-) rename .github/scripts/{run-pre-trained-conformer-ctc.sh => run-pre-trained-ctc.sh} (79%) rename .github/workflows/{run-pretrained-conformer-ctc.yml => run-pretrained-ctc.yml} (91%) create mode 100755 egs/librispeech/ASR/zipformer/export-onnx-ctc.py create mode 100755 egs/librispeech/ASR/zipformer/onnx_pretrained_ctc.py create mode 100755 egs/librispeech/ASR/zipformer/onnx_pretrained_ctc_H.py create mode 100755 egs/librispeech/ASR/zipformer/onnx_pretrained_ctc_HL.py create mode 100755 egs/librispeech/ASR/zipformer/onnx_pretrained_ctc_HLG.py diff --git a/.github/scripts/run-pre-trained-conformer-ctc.sh b/.github/scripts/run-pre-trained-ctc.sh similarity index 79% rename from .github/scripts/run-pre-trained-conformer-ctc.sh rename to .github/scripts/run-pre-trained-ctc.sh index ea400c628a..7d6449c9aa 100755 --- a/.github/scripts/run-pre-trained-conformer-ctc.sh +++ b/.github/scripts/run-pre-trained-ctc.sh @@ -10,7 +10,57 @@ log() { pushd egs/librispeech/ASR -# repo_url=https://github.com/csukuangfj/icefall-asr-conformer-ctc-bpe-500 +repo_url=https://huggingface.co/csukuangfj/sherpa-onnx-zipformer-ctc-en-2023-10-02 +log "Downloading pre-trained model from $repo_url" +git lfs install +git clone $repo_url +repo=$(basename $repo_url) + +log "Display test files" +tree $repo/ +ls -lh $repo/test_wavs/*.wav + +log "CTC greedy search" + +./zipformer/onnx_pretrained_ctc.py \ + --nn-model $repo/model.onnx \ + --tokens $repo/tokens.txt \ + $repo/test_wavs/0.wav \ + $repo/test_wavs/1.wav \ + $repo/test_wavs/2.wav + +log "CTC H decoding" + +./zipformer/onnx_pretrained_ctc_H.py \ + --nn-model $repo/model.onnx \ + --tokens $repo/tokens.txt \ + --H $repo/H.fst \ + $repo/test_wavs/0.wav \ + $repo/test_wavs/1.wav \ + $repo/test_wavs/2.wav + +log "CTC HL decoding" + +./zipformer/onnx_pretrained_ctc_HL.py \ + --nn-model $repo/model.onnx \ + --words $repo/words.txt \ + --HL $repo/HL.fst \ + $repo/test_wavs/0.wav \ + $repo/test_wavs/1.wav \ + $repo/test_wavs/2.wav + +log "CTC HLG decoding" + +./zipformer/onnx_pretrained_ctc_HLG.py \ + --nn-model $repo/model.onnx \ + --words $repo/words.txt \ + --HLG $repo/HLG.fst \ + $repo/test_wavs/0.wav \ + $repo/test_wavs/1.wav \ + $repo/test_wavs/2.wav + +rm -rf $repo + repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-conformer-ctc-jit-bpe-500-2021-11-09 log "Downloading pre-trained model from $repo_url" GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url @@ -128,7 +178,9 @@ repo=$(basename $repo_url) pushd $repo git lfs pull --include "exp/pretrained.pt" -git lfs pull --include "data/lm/G_3_gram_char.fst.txt" +git lfs pull --include "data/lang_char/H.fst" +git lfs pull --include "data/lang_char/HL.fst" +git lfs pull --include "data/lang_char/HLG.fst" popd @@ -153,10 +205,6 @@ popd ls -lh $repo/exp -log "Generating H.fst, HL.fst" - -./local/prepare_lang_fst.py --lang-dir $repo/data/lang_char --ngram-G $repo/data/lm/G_3_gram_char.fst.txt - ls -lh $repo/data/lang_char log "Decoding with H on CPU with OpenFst" diff --git a/.github/workflows/run-pretrained-conformer-ctc.yml b/.github/workflows/run-pretrained-ctc.yml similarity index 91% rename from .github/workflows/run-pretrained-conformer-ctc.yml rename to .github/workflows/run-pretrained-ctc.yml index 54845159d1..074a63dfc9 100644 --- a/.github/workflows/run-pretrained-conformer-ctc.yml +++ b/.github/workflows/run-pretrained-ctc.yml @@ -14,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -name: run-pre-trained-conformer-ctc +name: run-pre-trained-ctc on: push: @@ -31,12 +31,12 @@ on: default: 'y' concurrency: - group: run_pre_trained_conformer_ctc-${{ github.ref }} + group: run_pre_trained_ctc-${{ github.ref }} cancel-in-progress: true jobs: - run_pre_trained_conformer_ctc: - if: github.event.label.name == 'ready' || github.event_name == 'push' || github.event.inputs.test-run == 'y' + run_pre_trained_ctc: + if: github.event.label.name == 'ready' || github.event_name == 'push' || github.event.inputs.test-run == 'y' || github.event.label.name == 'ctc' runs-on: ${{ matrix.os }} strategy: matrix: @@ -84,4 +84,4 @@ jobs: export PYTHONPATH=$PWD:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH - .github/scripts/run-pre-trained-conformer-ctc.sh + .github/scripts/run-pre-trained-ctc.sh diff --git a/egs/librispeech/ASR/conformer_ctc/jit_pretrained_decode_with_H.py b/egs/librispeech/ASR/conformer_ctc/jit_pretrained_decode_with_H.py index 8dd856a4ef..4bdec9e114 100755 --- a/egs/librispeech/ASR/conformer_ctc/jit_pretrained_decode_with_H.py +++ b/egs/librispeech/ASR/conformer_ctc/jit_pretrained_decode_with_H.py @@ -145,7 +145,7 @@ def decode( decoder.decode(decodable) if not decoder.reached_final(): - print(f"failed to decode {filename}") + logging.info(f"failed to decode {filename}") return [""] ok, best_path = decoder.get_best_path() @@ -157,7 +157,7 @@ def decode( total_weight, ) = kaldifst.get_linear_symbol_sequence(best_path) if not ok: - print(f"failed to get linear symbol sequence for {filename}") + logging.info(f"failed to get linear symbol sequence for {filename}") return [""] # tokens are incremented during graph construction diff --git a/egs/librispeech/ASR/conformer_ctc/jit_pretrained_decode_with_HL.py b/egs/librispeech/ASR/conformer_ctc/jit_pretrained_decode_with_HL.py index 796e196613..d5a1dba3ca 100755 --- a/egs/librispeech/ASR/conformer_ctc/jit_pretrained_decode_with_HL.py +++ b/egs/librispeech/ASR/conformer_ctc/jit_pretrained_decode_with_HL.py @@ -132,8 +132,8 @@ def decode( contains output from log_softmax. HL: The HL graph. - word2token: - A map mapping token ID to word string. + id2word: + A map mapping word ID to word string. Returns: Return a list of decoded words. """ @@ -145,7 +145,7 @@ def decode( decoder.decode(decodable) if not decoder.reached_final(): - print(f"failed to decode {filename}") + logging.info(f"failed to decode {filename}") return [""] ok, best_path = decoder.get_best_path() @@ -157,7 +157,7 @@ def decode( total_weight, ) = kaldifst.get_linear_symbol_sequence(best_path) if not ok: - print(f"failed to get linear symbol sequence for {filename}") + logging.info(f"failed to get linear symbol sequence for {filename}") return [""] # are shifted by 1 during graph construction diff --git a/egs/librispeech/ASR/conformer_ctc/jit_pretrained_decode_with_HLG.py b/egs/librispeech/ASR/conformer_ctc/jit_pretrained_decode_with_HLG.py index 0024d5c9cd..216677a230 100755 --- a/egs/librispeech/ASR/conformer_ctc/jit_pretrained_decode_with_HLG.py +++ b/egs/librispeech/ASR/conformer_ctc/jit_pretrained_decode_with_HLG.py @@ -131,8 +131,8 @@ def decode( contains output from log_softmax. HLG: The HLG graph. - word2token: - A map mapping token ID to word string. + id2word: + A map mapping word ID to word string. Returns: Return a list of decoded words. """ @@ -144,7 +144,7 @@ def decode( decoder.decode(decodable) if not decoder.reached_final(): - print(f"failed to decode {filename}") + logging.info(f"failed to decode {filename}") return [""] ok, best_path = decoder.get_best_path() @@ -156,7 +156,7 @@ def decode( total_weight, ) = kaldifst.get_linear_symbol_sequence(best_path) if not ok: - print(f"failed to get linear symbol sequence for {filename}") + logging.info(f"failed to get linear symbol sequence for {filename}") return [""] # are shifted by 1 during graph construction diff --git a/egs/librispeech/ASR/zipformer/export-onnx-ctc.py b/egs/librispeech/ASR/zipformer/export-onnx-ctc.py new file mode 100755 index 0000000000..3345d20d3f --- /dev/null +++ b/egs/librispeech/ASR/zipformer/export-onnx-ctc.py @@ -0,0 +1,436 @@ +#!/usr/bin/env python3 +# +# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang) + +""" +This script exports a CTC model from PyTorch to ONNX. + +Note that the model is trained using both transducer and CTC loss. This script +exports only the CTC head. + +We use the pre-trained model from +https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-transducer-ctc-2023-06-13 +as an example to show how to use this file. + +1. Download the pre-trained model + +cd egs/librispeech/ASR + +repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-transducer-ctc-2023-06-13 +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "exp/pretrained.pt" + +cd exp +ln -s pretrained.pt epoch-99.pt +popd + +2. Export the model to ONNX + +./zipformer/export-onnx-ctc.py \ + --use-transducer 0 \ + --use-ctc 1 \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + --use-averaged-model 0 \ + --epoch 99 \ + --avg 1 \ + --exp-dir $repo/exp \ + --num-encoder-layers "2,2,3,4,3,2" \ + --downsampling-factor "1,2,4,8,4,2" \ + --feedforward-dim "512,768,1024,1536,1024,768" \ + --num-heads "4,4,4,8,4,4" \ + --encoder-dim "192,256,384,512,384,256" \ + --query-head-dim 32 \ + --value-head-dim 12 \ + --pos-head-dim 4 \ + --pos-dim 48 \ + --encoder-unmasked-dim "192,192,256,256,256,192" \ + --cnn-module-kernel "31,31,15,15,15,31" \ + --decoder-dim 512 \ + --joiner-dim 512 \ + --causal False \ + --chunk-size 16 \ + --left-context-frames 128 + +It will generate the following 2 files inside $repo/exp: + + - model.onnx + - model.int8.onnx + +See ./onnx_pretrained_ctc.py for how to +use the exported ONNX models. +""" + +import argparse +import logging +from pathlib import Path +from typing import Dict, Tuple + +import k2 +import onnx +import torch +import torch.nn as nn +from decoder import Decoder +from onnxruntime.quantization import QuantType, quantize_dynamic +from scaling_converter import convert_scaled_to_non_scaled +from train import add_model_arguments, get_model, get_params +from zipformer import Zipformer2 + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import make_pad_mask, num_tokens, str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=28, + help="""It specifies the checkpoint to use for averaging. + Note: Epoch counts from 0. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--tokens", + type=str, + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + add_model_arguments(parser) + + return parser + + +def add_meta_data(filename: str, meta_data: Dict[str, str]): + """Add meta data to an ONNX model. It is changed in-place. + + Args: + filename: + Filename of the ONNX model to be changed. + meta_data: + Key-value pairs. + """ + model = onnx.load(filename) + for key, value in meta_data.items(): + meta = model.metadata_props.add() + meta.key = key + meta.value = value + + onnx.save(model, filename) + + +class OnnxModel(nn.Module): + """A wrapper for encoder_embed, Zipformer, and ctc_output layer""" + + def __init__( + self, + encoder: Zipformer2, + encoder_embed: nn.Module, + ctc_output: nn.Module, + ): + """ + Args: + encoder: + A Zipformer encoder. + encoder_embed: + The first downsampling layer for zipformer. + """ + super().__init__() + self.encoder = encoder + self.encoder_embed = encoder_embed + self.ctc_output = ctc_output + + def forward( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Please see the help information of Zipformer.forward + + Args: + x: + A 3-D tensor of shape (N, T, C) + x_lens: + A 1-D tensor of shape (N,). Its dtype is torch.int64 + Returns: + Return a tuple containing: + - log_probs, a 3-D tensor of shape (N, T', vocab_size) + - log_probs_len, a 1-D int64 tensor of shape (N,) + """ + x, x_lens = self.encoder_embed(x, x_lens) + src_key_padding_mask = make_pad_mask(x_lens) + x = x.permute(1, 0, 2) + encoder_out, log_probs_len = self.encoder(x, x_lens, src_key_padding_mask) + encoder_out = encoder_out.permute(1, 0, 2) + log_probs = self.ctc_output(encoder_out) + + return log_probs, log_probs_len + + +def export_ctc_model_onnx( + model: OnnxModel, + filename: str, + opset_version: int = 11, +) -> None: + """Export the given model to ONNX format. + The exported model has two inputs: + + - x, a tensor of shape (N, T, C); dtype is torch.float32 + - x_lens, a tensor of shape (N,); dtype is torch.int64 + + and it has two outputs: + + - log_probs, a tensor of shape (N, T', joiner_dim) + - log_probs_len, a tensor of shape (N,) + + Args: + model: + The input model + filename: + The filename to save the exported ONNX model. + opset_version: + The opset version to use. + """ + x = torch.zeros(1, 100, 80, dtype=torch.float32) + x_lens = torch.tensor([100], dtype=torch.int64) + + model = torch.jit.trace(model, (x, x_lens)) + + torch.onnx.export( + model, + (x, x_lens), + filename, + verbose=False, + opset_version=opset_version, + input_names=["x", "x_lens"], + output_names=["log_probs", "log_probs_len"], + dynamic_axes={ + "x": {0: "N", 1: "T"}, + "x_lens": {0: "N"}, + "log_probs": {0: "N", 1: "T"}, + "log_probs_len": {0: "N"}, + }, + ) + + meta_data = { + "model_type": "zipformer2_ctc", + "version": "1", + "model_author": "k2-fsa", + "comment": "non-streaming zipformer2 CTC", + } + logging.info(f"meta_data: {meta_data}") + + add_meta_data(filename=filename, meta_data=meta_data) + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + token_table = k2.SymbolTable.from_file(params.tokens) + params.blank_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + model.to(device) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict( + average_checkpoints(filenames, device=device), strict=False + ) + elif params.avg == 1: + load_checkpoint( + f"{params.exp_dir}/epoch-{params.epoch}.pt", model, strict=False + ) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict( + average_checkpoints(filenames, device=device), strict=False + ) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ), + strict=False, + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ), + strict=False, + ) + + model.to("cpu") + model.eval() + + convert_scaled_to_non_scaled(model, inplace=True, is_onnx=True) + + model = OnnxModel( + encoder=model.encoder, + encoder_embed=model.encoder_embed, + ctc_output=model.ctc_output, + ) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"num parameters: {num_param}") + + opset_version = 13 + + logging.info("Exporting ctc model") + filename = params.exp_dir / f"model.onnx" + export_ctc_model_onnx( + model, + filename, + opset_version=opset_version, + ) + logging.info(f"Exported to {filename}") + + # Generate int8 quantization models + # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection + + logging.info("Generate int8 quantization models") + + filename_int8 = params.exp_dir / f"model.int8.onnx" + quantize_dynamic( + model_input=filename, + model_output=filename_int8, + op_types_to_quantize=["MatMul"], + weight_type=QuantType.QInt8, + ) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/zipformer/onnx_pretrained_ctc.py b/egs/librispeech/ASR/zipformer/onnx_pretrained_ctc.py new file mode 100755 index 0000000000..eb5cee9cd5 --- /dev/null +++ b/egs/librispeech/ASR/zipformer/onnx_pretrained_ctc.py @@ -0,0 +1,213 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# +""" +This script loads ONNX models and uses them to decode waves. + +We use the pre-trained model from +https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-transducer-ctc-2023-06-13 +as an example to show how to use this file. + +1. Please follow ./export-onnx-ctc.py to get the onnx model. + +2. Run this file + +./zipformer/onnx_pretrained_ctc.py \ + --nn-model /path/to/model.onnx \ + --tokens /path/to/data/lang_bpe_500/tokens.txt \ + 1089-134686-0001.wav \ + 1221-135766-0001.wav \ + 1221-135766-0002.wav +""" + +import argparse +import logging +import math +from typing import List, Tuple + +import k2 +import kaldifeat +import onnxruntime as ort +import torch +import torchaudio +from torch.nn.utils.rnn import pad_sequence + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--nn-model", + type=str, + required=True, + help="Path to the onnx model. ", + ) + + parser.add_argument( + "--tokens", + type=str, + help="""Path to tokens.txt.""", + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="The sample rate of the input sound file", + ) + + return parser + + +class OnnxModel: + def __init__( + self, + nn_model: str, + ): + session_opts = ort.SessionOptions() + session_opts.inter_op_num_threads = 1 + session_opts.intra_op_num_threads = 1 + + self.session_opts = session_opts + + self.init_model(nn_model) + + def init_model(self, nn_model: str): + self.model = ort.InferenceSession( + nn_model, + sess_options=self.session_opts, + providers=["CPUExecutionProvider"], + ) + meta = self.model.get_modelmeta().custom_metadata_map + print(meta) + + def __call__( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + x: + A 3-D float tensor of shape (N, T, C) + x_lens: + A 1-D int64 tensor of shape (N,) + Returns: + Return a tuple containing: + - A float tensor containing log_probs of shape (N, T, C) + - A int64 tensor containing log_probs_len of shape (N) + """ + out = self.model.run( + [ + self.model.get_outputs()[0].name, + self.model.get_outputs()[1].name, + ], + { + self.model.get_inputs()[0].name: x.numpy(), + self.model.get_inputs()[1].name: x_lens.numpy(), + }, + ) + return torch.from_numpy(out[0]), torch.from_numpy(out[1]) + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + # We use only the first channel + ans.append(wave[0].contiguous()) + return ans + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + logging.info(vars(args)) + model = OnnxModel( + nn_model=args.nn_model, + ) + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = "cpu" + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = args.sample_rate + opts.mel_opts.num_bins = 80 + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {args.sound_files}") + waves = read_sound_files( + filenames=args.sound_files, + expected_sample_rate=args.sample_rate, + ) + + logging.info("Decoding started") + features = fbank(waves) + feature_lengths = [f.size(0) for f in features] + features = pad_sequence( + features, + batch_first=True, + padding_value=math.log(1e-10), + ) + + feature_lengths = torch.tensor(feature_lengths, dtype=torch.int64) + log_probs, log_probs_len = model(features, feature_lengths) + + token_table = k2.SymbolTable.from_file(args.tokens) + + def token_ids_to_words(token_ids: List[int]) -> str: + text = "" + for i in token_ids: + text += token_table[i] + return text.replace("▁", " ").strip() + + blank_id = 0 + s = "\n" + for i in range(log_probs.size(0)): + # greedy search + indexes = log_probs[i, : log_probs_len[i]].argmax(dim=-1) + token_ids = torch.unique_consecutive(indexes) + + token_ids = token_ids[token_ids != blank_id] + words = token_ids_to_words(token_ids.tolist()) + s += f"{args.sound_files[i]}:\n{words}\n\n" + + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/zipformer/onnx_pretrained_ctc_H.py b/egs/librispeech/ASR/zipformer/onnx_pretrained_ctc_H.py new file mode 100755 index 0000000000..683a7dc20e --- /dev/null +++ b/egs/librispeech/ASR/zipformer/onnx_pretrained_ctc_H.py @@ -0,0 +1,277 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# +""" +This script loads ONNX models and uses them to decode waves. + +We use the pre-trained model from +https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-transducer-ctc-2023-06-13 +as an example to show how to use this file. + +1. Please follow ./export-onnx-ctc.py to get the onnx model. + +2. Run this file + +./zipformer/onnx_pretrained_ctc_H.py \ + --nn-model /path/to/model.onnx \ + --tokens /path/to/data/lang_bpe_500/tokens.txt \ + --H /path/to/H.fst \ + 1089-134686-0001.wav \ + 1221-135766-0001.wav \ + 1221-135766-0002.wav + +You can find exported ONNX models at +https://huggingface.co/csukuangfj/sherpa-onnx-zipformer-ctc-en-2023-10-02 +""" + +import argparse +import logging +import math +from typing import List, Tuple + +import k2 +import kaldifeat +from typing import Dict +import kaldifst +import onnxruntime as ort +import torch +import torchaudio +from kaldi_decoder import DecodableCtc, FasterDecoder, FasterDecoderOptions +from torch.nn.utils.rnn import pad_sequence + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--nn-model", + type=str, + required=True, + help="Path to the onnx model. ", + ) + + parser.add_argument( + "--tokens", + type=str, + help="""Path to tokens.txt.""", + ) + + parser.add_argument( + "--H", + type=str, + help="""Path to H.fst.""", + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="The sample rate of the input sound file", + ) + + return parser + + +class OnnxModel: + def __init__( + self, + nn_model: str, + ): + session_opts = ort.SessionOptions() + session_opts.inter_op_num_threads = 1 + session_opts.intra_op_num_threads = 1 + + self.session_opts = session_opts + + self.init_model(nn_model) + + def init_model(self, nn_model: str): + self.model = ort.InferenceSession( + nn_model, + sess_options=self.session_opts, + providers=["CPUExecutionProvider"], + ) + meta = self.model.get_modelmeta().custom_metadata_map + print(meta) + + def __call__( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + x: + A 3-D float tensor of shape (N, T, C) + x_lens: + A 1-D int64 tensor of shape (N,) + Returns: + Return a tuple containing: + - A float tensor containing log_probs of shape (N, T, C) + - A int64 tensor containing log_probs_len of shape (N) + """ + out = self.model.run( + [ + self.model.get_outputs()[0].name, + self.model.get_outputs()[1].name, + ], + { + self.model.get_inputs()[0].name: x.numpy(), + self.model.get_inputs()[1].name: x_lens.numpy(), + }, + ) + return torch.from_numpy(out[0]), torch.from_numpy(out[1]) + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + # We use only the first channel + ans.append(wave[0].contiguous()) + return ans + + +def decode( + filename: str, + log_probs: torch.Tensor, + H: kaldifst, + id2token: Dict[int, str], +) -> List[str]: + """ + Args: + filename: + Path to the filename for decoding. Used for debugging. + log_probs: + A 2-D float32 tensor of shape (num_frames, vocab_size). It + contains output from log_softmax. + H: + The H graph. + id2word: + A map mapping token ID to word string. + Returns: + Return a list of decoded words. + """ + logging.info(f"{filename}, {log_probs.shape}") + decodable = DecodableCtc(log_probs.cpu()) + + decoder_opts = FasterDecoderOptions(max_active=3000) + decoder = FasterDecoder(H, decoder_opts) + decoder.decode(decodable) + + if not decoder.reached_final(): + logging.info(f"failed to decode {filename}") + return [""] + + ok, best_path = decoder.get_best_path() + + ( + ok, + isymbols_out, + osymbols_out, + total_weight, + ) = kaldifst.get_linear_symbol_sequence(best_path) + if not ok: + logging.info(f"failed to get linear symbol sequence for {filename}") + return [""] + + # tokens are incremented during graph construction + # are shifted by 1 during graph construction + hyps = [id2token[i - 1] for i in osymbols_out if i != 1] + hyps = "".join(hyps).split("\u2581") # unicode codepoint of ▁ + + return hyps + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + logging.info(vars(args)) + model = OnnxModel( + nn_model=args.nn_model, + ) + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = "cpu" + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = args.sample_rate + opts.mel_opts.num_bins = 80 + + logging.info(f"Loading H from {args.H}") + H = kaldifst.StdVectorFst.read(args.H) + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {args.sound_files}") + waves = read_sound_files( + filenames=args.sound_files, + expected_sample_rate=args.sample_rate, + ) + + logging.info("Decoding started") + features = fbank(waves) + feature_lengths = [f.size(0) for f in features] + features = pad_sequence( + features, + batch_first=True, + padding_value=math.log(1e-10), + ) + + feature_lengths = torch.tensor(feature_lengths, dtype=torch.int64) + log_probs, log_probs_len = model(features, feature_lengths) + + token_table = k2.SymbolTable.from_file(args.tokens) + + hyps = [] + for i in range(log_probs.shape[0]): + hyp = decode( + filename=args.sound_files[i], + log_probs=log_probs[i, : log_probs_len[i]], + H=H, + id2token=token_table, + ) + hyps.append(hyp) + + s = "\n" + for filename, hyp in zip(args.sound_files, hyps): + words = " ".join(hyp) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/zipformer/onnx_pretrained_ctc_HL.py b/egs/librispeech/ASR/zipformer/onnx_pretrained_ctc_HL.py new file mode 100755 index 0000000000..0b94bfa653 --- /dev/null +++ b/egs/librispeech/ASR/zipformer/onnx_pretrained_ctc_HL.py @@ -0,0 +1,275 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# +""" +This script loads ONNX models and uses them to decode waves. + +We use the pre-trained model from +https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-transducer-ctc-2023-06-13 +as an example to show how to use this file. + +1. Please follow ./export-onnx-ctc.py to get the onnx model. + +2. Run this file + +./zipformer/onnx_pretrained_ctc_HL.py \ + --nn-model /path/to/model.onnx \ + --words /path/to/data/lang_bpe_500/words.txt \ + --HL /path/to/HL.fst \ + 1089-134686-0001.wav \ + 1221-135766-0001.wav \ + 1221-135766-0002.wav + +You can find exported ONNX models at +https://huggingface.co/csukuangfj/sherpa-onnx-zipformer-ctc-en-2023-10-02 +""" + +import argparse +import logging +import math +from typing import List, Tuple + +import k2 +import kaldifeat +from typing import Dict +import kaldifst +import onnxruntime as ort +import torch +import torchaudio +from kaldi_decoder import DecodableCtc, FasterDecoder, FasterDecoderOptions +from torch.nn.utils.rnn import pad_sequence + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--nn-model", + type=str, + required=True, + help="Path to the onnx model. ", + ) + + parser.add_argument( + "--words", + type=str, + help="""Path to words.txt.""", + ) + + parser.add_argument( + "--HL", + type=str, + help="""Path to HL.fst.""", + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="The sample rate of the input sound file", + ) + + return parser + + +class OnnxModel: + def __init__( + self, + nn_model: str, + ): + session_opts = ort.SessionOptions() + session_opts.inter_op_num_threads = 1 + session_opts.intra_op_num_threads = 1 + + self.session_opts = session_opts + + self.init_model(nn_model) + + def init_model(self, nn_model: str): + self.model = ort.InferenceSession( + nn_model, + sess_options=self.session_opts, + providers=["CPUExecutionProvider"], + ) + meta = self.model.get_modelmeta().custom_metadata_map + print(meta) + + def __call__( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + x: + A 3-D float tensor of shape (N, T, C) + x_lens: + A 1-D int64 tensor of shape (N,) + Returns: + Return a tuple containing: + - A float tensor containing log_probs of shape (N, T, C) + - A int64 tensor containing log_probs_len of shape (N) + """ + out = self.model.run( + [ + self.model.get_outputs()[0].name, + self.model.get_outputs()[1].name, + ], + { + self.model.get_inputs()[0].name: x.numpy(), + self.model.get_inputs()[1].name: x_lens.numpy(), + }, + ) + return torch.from_numpy(out[0]), torch.from_numpy(out[1]) + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + # We use only the first channel + ans.append(wave[0].contiguous()) + return ans + + +def decode( + filename: str, + log_probs: torch.Tensor, + HL: kaldifst, + id2word: Dict[int, str], +) -> List[str]: + """ + Args: + filename: + Path to the filename for decoding. Used for debugging. + log_probs: + A 2-D float32 tensor of shape (num_frames, vocab_size). It + contains output from log_softmax. + HL: + The HL graph. + id2word: + A map mapping word ID to word string. + Returns: + Return a list of decoded words. + """ + logging.info(f"{filename}, {log_probs.shape}") + decodable = DecodableCtc(log_probs.cpu()) + + decoder_opts = FasterDecoderOptions(max_active=3000) + decoder = FasterDecoder(HL, decoder_opts) + decoder.decode(decodable) + + if not decoder.reached_final(): + logging.info(f"failed to decode {filename}") + return [""] + + ok, best_path = decoder.get_best_path() + + ( + ok, + isymbols_out, + osymbols_out, + total_weight, + ) = kaldifst.get_linear_symbol_sequence(best_path) + if not ok: + logging.info(f"failed to get linear symbol sequence for {filename}") + return [""] + + # are shifted by 1 during graph construction + hyps = [id2word[i] for i in osymbols_out] + + return hyps + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + logging.info(vars(args)) + model = OnnxModel( + nn_model=args.nn_model, + ) + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = "cpu" + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = args.sample_rate + opts.mel_opts.num_bins = 80 + + logging.info(f"Loading HL from {args.HL}") + HL = kaldifst.StdVectorFst.read(args.HL) + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {args.sound_files}") + waves = read_sound_files( + filenames=args.sound_files, + expected_sample_rate=args.sample_rate, + ) + + logging.info("Decoding started") + features = fbank(waves) + feature_lengths = [f.size(0) for f in features] + features = pad_sequence( + features, + batch_first=True, + padding_value=math.log(1e-10), + ) + + feature_lengths = torch.tensor(feature_lengths, dtype=torch.int64) + log_probs, log_probs_len = model(features, feature_lengths) + + word_table = k2.SymbolTable.from_file(args.words) + + hyps = [] + for i in range(log_probs.shape[0]): + hyp = decode( + filename=args.sound_files[i], + log_probs=log_probs[i, : log_probs_len[i]], + HL=HL, + id2word=word_table, + ) + hyps.append(hyp) + + s = "\n" + for filename, hyp in zip(args.sound_files, hyps): + words = " ".join(hyp) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/zipformer/onnx_pretrained_ctc_HLG.py b/egs/librispeech/ASR/zipformer/onnx_pretrained_ctc_HLG.py new file mode 100755 index 0000000000..93569142ab --- /dev/null +++ b/egs/librispeech/ASR/zipformer/onnx_pretrained_ctc_HLG.py @@ -0,0 +1,275 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# +""" +This script loads ONNX models and uses them to decode waves. + +We use the pre-trained model from +https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-transducer-ctc-2023-06-13 +as an example to show how to use this file. + +1. Please follow ./export-onnx-ctc.py to get the onnx model. + +2. Run this file + +./zipformer/onnx_pretrained_ctc_HLG.py \ + --nn-model /path/to/model.onnx \ + --words /path/to/data/lang_bpe_500/words.txt \ + --HLG /path/to/HLG.fst \ + 1089-134686-0001.wav \ + 1221-135766-0001.wav \ + 1221-135766-0002.wav + +You can find exported ONNX models at +https://huggingface.co/csukuangfj/sherpa-onnx-zipformer-ctc-en-2023-10-02 +""" + +import argparse +import logging +import math +from typing import List, Tuple + +import k2 +import kaldifeat +from typing import Dict +import kaldifst +import onnxruntime as ort +import torch +import torchaudio +from kaldi_decoder import DecodableCtc, FasterDecoder, FasterDecoderOptions +from torch.nn.utils.rnn import pad_sequence + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--nn-model", + type=str, + required=True, + help="Path to the onnx model. ", + ) + + parser.add_argument( + "--words", + type=str, + help="""Path to words.txt.""", + ) + + parser.add_argument( + "--HLG", + type=str, + help="""Path to HLG.fst.""", + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="The sample rate of the input sound file", + ) + + return parser + + +class OnnxModel: + def __init__( + self, + nn_model: str, + ): + session_opts = ort.SessionOptions() + session_opts.inter_op_num_threads = 1 + session_opts.intra_op_num_threads = 1 + + self.session_opts = session_opts + + self.init_model(nn_model) + + def init_model(self, nn_model: str): + self.model = ort.InferenceSession( + nn_model, + sess_options=self.session_opts, + providers=["CPUExecutionProvider"], + ) + meta = self.model.get_modelmeta().custom_metadata_map + print(meta) + + def __call__( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + x: + A 3-D float tensor of shape (N, T, C) + x_lens: + A 1-D int64 tensor of shape (N,) + Returns: + Return a tuple containing: + - A float tensor containing log_probs of shape (N, T, C) + - A int64 tensor containing log_probs_len of shape (N) + """ + out = self.model.run( + [ + self.model.get_outputs()[0].name, + self.model.get_outputs()[1].name, + ], + { + self.model.get_inputs()[0].name: x.numpy(), + self.model.get_inputs()[1].name: x_lens.numpy(), + }, + ) + return torch.from_numpy(out[0]), torch.from_numpy(out[1]) + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + # We use only the first channel + ans.append(wave[0].contiguous()) + return ans + + +def decode( + filename: str, + log_probs: torch.Tensor, + HLG: kaldifst, + id2word: Dict[int, str], +) -> List[str]: + """ + Args: + filename: + Path to the filename for decoding. Used for debugging. + log_probs: + A 2-D float32 tensor of shape (num_frames, vocab_size). It + contains output from log_softmax. + HLG: + The HLG graph. + id2word: + A map mapping word ID to word string. + Returns: + Return a list of decoded words. + """ + logging.info(f"{filename}, {log_probs.shape}") + decodable = DecodableCtc(log_probs.cpu()) + + decoder_opts = FasterDecoderOptions(max_active=3000) + decoder = FasterDecoder(HLG, decoder_opts) + decoder.decode(decodable) + + if not decoder.reached_final(): + logging.info(f"failed to decode {filename}") + return [""] + + ok, best_path = decoder.get_best_path() + + ( + ok, + isymbols_out, + osymbols_out, + total_weight, + ) = kaldifst.get_linear_symbol_sequence(best_path) + if not ok: + logging.info(f"failed to get linear symbol sequence for {filename}") + return [""] + + # are shifted by 1 during graph construction + hyps = [id2word[i] for i in osymbols_out] + + return hyps + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + logging.info(vars(args)) + model = OnnxModel( + nn_model=args.nn_model, + ) + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = "cpu" + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = args.sample_rate + opts.mel_opts.num_bins = 80 + + logging.info(f"Loading HLG from {args.HLG}") + HLG = kaldifst.StdVectorFst.read(args.HLG) + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {args.sound_files}") + waves = read_sound_files( + filenames=args.sound_files, + expected_sample_rate=args.sample_rate, + ) + + logging.info("Decoding started") + features = fbank(waves) + feature_lengths = [f.size(0) for f in features] + features = pad_sequence( + features, + batch_first=True, + padding_value=math.log(1e-10), + ) + + feature_lengths = torch.tensor(feature_lengths, dtype=torch.int64) + log_probs, log_probs_len = model(features, feature_lengths) + + word_table = k2.SymbolTable.from_file(args.words) + + hyps = [] + for i in range(log_probs.shape[0]): + hyp = decode( + filename=args.sound_files[i], + log_probs=log_probs[i, : log_probs_len[i]], + HLG=HLG, + id2word=word_table, + ) + hyps.append(hyp) + + s = "\n" + for filename, hyp in zip(args.sound_files, hyps): + words = " ".join(hyp) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main()