diff --git a/egs/audioset/AT/zipformer/pretrained.py b/egs/audioset/AT/zipformer/pretrained.py
index de06528932..e3961736c4 100755
--- a/egs/audioset/AT/zipformer/pretrained.py
+++ b/egs/audioset/AT/zipformer/pretrained.py
@@ -1,5 +1,5 @@
 #!/usr/bin/env python3
-# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang, Zengwei Yao)
+# Copyright 2024 Xiaomi Corp. (authors: Xiaoyu Yang)
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
 #
@@ -21,7 +21,6 @@
 Note: This is a example for librispeech dataset, if you are using different
 dataset, you should change the argument values according to your dataset.
 
-- For non-streaming model:
 
 ./zipformer/export.py \
   --exp-dir ./zipformer/exp \
@@ -29,75 +28,10 @@
   --epoch 30 \
   --avg 9
 
-- For streaming model:
-
-./zipformer/export.py \
-  --exp-dir ./zipformer/exp \
-  --causal 1 \
-  --tokens data/lang_bpe_500/tokens.txt \
-  --epoch 30 \
-  --avg 9
-
 Usage of this script:
 
-- For non-streaming model:
-
-(1) greedy search
-./zipformer/pretrained.py \
-  --checkpoint ./zipformer/exp/pretrained.pt \
-  --tokens data/lang_bpe_500/tokens.txt \
-  --method greedy_search \
-  /path/to/foo.wav \
-  /path/to/bar.wav
-
-(2) modified beam search
-./zipformer/pretrained.py \
-  --checkpoint ./zipformer/exp/pretrained.pt \
-  --tokens ./data/lang_bpe_500/tokens.txt \
-  --method modified_beam_search \
-  /path/to/foo.wav \
-  /path/to/bar.wav
-
-(3) fast beam search
-./zipformer/pretrained.py \
-  --checkpoint ./zipformer/exp/pretrained.pt \
-  --tokens ./data/lang_bpe_500/tokens.txt \
-  --method fast_beam_search \
-  /path/to/foo.wav \
-  /path/to/bar.wav
-
-- For streaming model:
-
-(1) greedy search
-./zipformer/pretrained.py \
-  --checkpoint ./zipformer/exp/pretrained.pt \
-  --causal 1 \
-  --chunk-size 16 \
-  --left-context-frames 128 \
-  --tokens ./data/lang_bpe_500/tokens.txt \
-  --method greedy_search \
-  /path/to/foo.wav \
-  /path/to/bar.wav
-
-(2) modified beam search
-./zipformer/pretrained.py \
-  --checkpoint ./zipformer/exp/pretrained.pt \
-  --causal 1 \
-  --chunk-size 16 \
-  --left-context-frames 128 \
-  --tokens ./data/lang_bpe_500/tokens.txt \
-  --method modified_beam_search \
-  /path/to/foo.wav \
-  /path/to/bar.wav
-
-(3) fast beam search
 ./zipformer/pretrained.py \
   --checkpoint ./zipformer/exp/pretrained.pt \
-  --causal 1 \
-  --chunk-size 16 \
-  --left-context-frames 128 \
-  --tokens ./data/lang_bpe_500/tokens.txt \
-  --method fast_beam_search \
   /path/to/foo.wav \
   /path/to/bar.wav
 
@@ -109,6 +43,7 @@
 
 import argparse
+import csv
 import logging
 import math
 from typing import List
 
@@ -117,11 +52,6 @@
 import kaldifeat
 import torch
 import torchaudio
-from beam_search import (
-    fast_beam_search_one_best,
-    greedy_search_batch,
-    modified_beam_search,
-)
 from export import num_tokens
 from torch.nn.utils.rnn import pad_sequence
 from train import add_model_arguments, get_model, get_params
@@ -144,20 +74,9 @@ def get_parser():
     )
 
     parser.add_argument(
-        "--tokens",
-        type=str,
-        help="""Path to tokens.txt.""",
-    )
-
-    parser.add_argument(
-        "--method",
+        "--label-dict",
         type=str,
-        default="greedy_search",
-        help="""Possible values are:
-          - greedy_search
-          - modified_beam_search
-          - fast_beam_search
-        """,
+        help="""Path to class_labels_indices.csv.""",
     )
 
     parser.add_argument(
@@ -177,55 +96,6 @@ def get_parser():
         help="The sample rate of the input sound file",
     )
 
-    parser.add_argument(
-        "--beam-size",
-        type=int,
-        default=4,
-        help="""An integer indicating how many candidates we will keep for each
-        frame. Used only when --method is beam_search or
-        modified_beam_search.""",
-    )
-
-    parser.add_argument(
-        "--beam",
-        type=float,
-        default=4,
-        help="""A floating point value to calculate the cutoff score during beam
-        search (i.e., `cutoff = max-score - beam`), which is the same as the
-        `beam` in Kaldi.
-        Used only when --method is fast_beam_search""",
-    )
-
-    parser.add_argument(
-        "--max-contexts",
-        type=int,
-        default=4,
-        help="""Used only when --method is fast_beam_search""",
-    )
-
-    parser.add_argument(
-        "--max-states",
-        type=int,
-        default=8,
-        help="""Used only when --method is fast_beam_search""",
-    )
-
-    parser.add_argument(
-        "--context-size",
-        type=int,
-        default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
-    )
-
-    parser.add_argument(
-        "--max-sym-per-frame",
-        type=int,
-        default=1,
-        help="""Maximum number of symbols per frame. Used only when
-        --method is greedy_search.
-        """,
-    )
-
     add_model_arguments(parser)
 
     return parser
@@ -263,12 +133,6 @@ def main():
 
     params.update(vars(args))
 
-    token_table = k2.SymbolTable.from_file(params.tokens)
-
-    params.blank_id = token_table["<blk>"]
-    params.unk_id = token_table["<unk>"]
-    params.vocab_size = num_tokens(token_table) + 1
-
     logging.info(f"{params}")
 
     device = torch.device("cpu")
@@ -277,14 +141,6 @@
 
     logging.info(f"device: {device}")
 
-    if params.causal:
-        assert (
-            "," not in params.chunk_size
-        ), "chunk_size should be one value in decoding."
-        assert (
-            "," not in params.left_context_frames
-        ), "left_context_frames should be one value in decoding."
-
     logging.info("Creating model")
     model = get_model(params)
 
@@ -296,6 +152,15 @@
     model.to(device)
     model.eval()
 
+    # get the label dictionary
+    label_dict = {}
+    with open(params.label_dict, "r") as f:
+        reader = csv.reader(f, delimiter=",")
+        for i, row in enumerate(reader):
+            if i == 0:
+                continue
+            label_dict[int(row[0])] = row[2]
+
     logging.info("Constructing Fbank computer")
     opts = kaldifeat.FbankOptions()
     opts.device = device
@@ -320,57 +185,15 @@
 
     features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
     feature_lengths = torch.tensor(feature_lengths, device=device)
 
-    # model forward
+    # model forward and predict the audio events
     encoder_out, encoder_out_lens = model.forward_encoder(features, feature_lengths)
+    logits = model.forward_audio_tagging(encoder_out, encoder_out_lens)
 
-    hyps = []
-    msg = f"Using {params.method}"
-    logging.info(msg)
-
-    def token_ids_to_words(token_ids: List[int]) -> str:
-        text = ""
-        for i in token_ids:
-            text += token_table[i]
-        return text.replace("▁", " ").strip()
-
-    if params.method == "fast_beam_search":
-        decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
-        hyp_tokens = fast_beam_search_one_best(
-            model=model,
-            decoding_graph=decoding_graph,
-            encoder_out=encoder_out,
-            encoder_out_lens=encoder_out_lens,
-            beam=params.beam,
-            max_contexts=params.max_contexts,
-            max_states=params.max_states,
-        )
-        for hyp in hyp_tokens:
-            hyps.append(token_ids_to_words(hyp))
-    elif params.method == "modified_beam_search":
-        hyp_tokens = modified_beam_search(
-            model=model,
-            encoder_out=encoder_out,
-            encoder_out_lens=encoder_out_lens,
-            beam=params.beam_size,
-        )
-
-        for hyp in hyp_tokens:
-            hyps.append(token_ids_to_words(hyp))
-    elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
-        hyp_tokens = greedy_search_batch(
-            model=model,
-            encoder_out=encoder_out,
-            encoder_out_lens=encoder_out_lens,
-        )
-        for hyp in hyp_tokens:
-            hyps.append(token_ids_to_words(hyp))
-    else:
-        raise ValueError(f"Unsupported method: {params.method}")
-
-    s = "\n"
-    for filename, hyp in zip(params.sound_files, hyps):
-        s += f"{filename}:\n{hyp}\n\n"
-    logging.info(s)
+    for i, logit in enumerate(logits):
+        topk_prob, topk_index = logit.sigmoid().topk(5)
+        topk_labels = [label_dict[index.item()] for index in topk_index]
+        print(f"Top 5 predicted labels of audio {i} are {topk_labels} "
+              f"with probabilities {topk_prob.tolist()}")
 
     logging.info("Decoding Done")
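
For reviewers who want to sanity-check the new tagging logic outside of icefall, here is a minimal standalone sketch of what the patch does after the encoder: build the index -> display-name map from class_labels_indices.csv (columns: index,mid,display_name) and reduce each clip's per-class logits to the top-5 labels via sigmoid + topk. The CSV path and the random logits below are illustrative placeholders, not part of the patch.

import csv
from typing import Dict

import torch


def load_label_dict(csv_path: str) -> Dict[int, str]:
    """Map class index -> display name, skipping the CSV header row."""
    label_dict = {}
    with open(csv_path, "r") as f:
        reader = csv.reader(f, delimiter=",")
        for i, row in enumerate(reader):
            if i == 0:  # header row: index,mid,display_name
                continue
            label_dict[int(row[0])] = row[2]
    return label_dict


if __name__ == "__main__":
    # Placeholder path to the AudioSet label map.
    label_dict = load_label_dict("data/class_labels_indices.csv")

    # Stand-in for model.forward_audio_tagging(encoder_out, encoder_out_lens):
    # one row of per-class logits per audio clip.
    logits = torch.randn(2, len(label_dict))

    for i, logit in enumerate(logits):
        # Multi-label tagging: sigmoid gives independent per-class
        # probabilities; topk keeps the five most confident classes.
        topk_prob, topk_index = logit.sigmoid().topk(5)
        topk_labels = [label_dict[idx.item()] for idx in topk_index]
        print(f"Top 5 predicted labels of audio {i} are {topk_labels} "
              f"with probabilities {topk_prob.tolist()}")

Note that sigmoid (not softmax) matches the multi-label nature of AudioSet tagging: each event class fires independently, so the five probabilities need not sum to one.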