Skip to content

Commit

Permalink
add inference script with a pretrained model
Browse files Browse the repository at this point in the history
  • Loading branch information
marcoyang1998 committed Mar 20, 2024
1 parent 1921692 commit 9c4db1b
Showing 1 changed file with 20 additions and 197 deletions.
217 changes: 20 additions & 197 deletions egs/audioset/AT/zipformer/pretrained.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#!/usr/bin/env python3
# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang, Zengwei Yao)
# Copyright 2024 Xiaomi Corp. (authors: Xiaoyu Yang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
Expand All @@ -21,83 +21,17 @@
Note: This is a example for librispeech dataset, if you are using different
dataset, you should change the argument values according to your dataset.
- For non-streaming model:
./zipformer/export.py \
--exp-dir ./zipformer/exp \
--tokens data/lang_bpe_500/tokens.txt \
--epoch 30 \
--avg 9
- For streaming model:
./zipformer/export.py \
--exp-dir ./zipformer/exp \
--causal 1 \
--tokens data/lang_bpe_500/tokens.txt \
--epoch 30 \
--avg 9
Usage of this script:
- For non-streaming model:
(1) greedy search
./zipformer/pretrained.py \
--checkpoint ./zipformer/exp/pretrained.pt \
--tokens data/lang_bpe_500/tokens.txt \
--method greedy_search \
/path/to/foo.wav \
/path/to/bar.wav
(2) modified beam search
./zipformer/pretrained.py \
--checkpoint ./zipformer/exp/pretrained.pt \
--tokens ./data/lang_bpe_500/tokens.txt \
--method modified_beam_search \
/path/to/foo.wav \
/path/to/bar.wav
(3) fast beam search
./zipformer/pretrained.py \
--checkpoint ./zipformer/exp/pretrained.pt \
--tokens ./data/lang_bpe_500/tokens.txt \
--method fast_beam_search \
/path/to/foo.wav \
/path/to/bar.wav
- For streaming model:
(1) greedy search
./zipformer/pretrained.py \
--checkpoint ./zipformer/exp/pretrained.pt \
--causal 1 \
--chunk-size 16 \
--left-context-frames 128 \
--tokens ./data/lang_bpe_500/tokens.txt \
--method greedy_search \
/path/to/foo.wav \
/path/to/bar.wav
(2) modified beam search
./zipformer/pretrained.py \
--checkpoint ./zipformer/exp/pretrained.pt \
--causal 1 \
--chunk-size 16 \
--left-context-frames 128 \
--tokens ./data/lang_bpe_500/tokens.txt \
--method modified_beam_search \
/path/to/foo.wav \
/path/to/bar.wav
(3) fast beam search
./zipformer/pretrained.py \
--checkpoint ./zipformer/exp/pretrained.pt \
--causal 1 \
--chunk-size 16 \
--left-context-frames 128 \
--tokens ./data/lang_bpe_500/tokens.txt \
--method fast_beam_search \
/path/to/foo.wav \
/path/to/bar.wav
Expand All @@ -109,6 +43,7 @@


import argparse
import csv
import logging
import math
from typing import List
Expand All @@ -117,11 +52,6 @@
import kaldifeat
import torch
import torchaudio
from beam_search import (
fast_beam_search_one_best,
greedy_search_batch,
modified_beam_search,
)
from export import num_tokens
from torch.nn.utils.rnn import pad_sequence
from train import add_model_arguments, get_model, get_params
Expand All @@ -144,20 +74,9 @@ def get_parser():
)

parser.add_argument(
"--tokens",
type=str,
help="""Path to tokens.txt.""",
)

parser.add_argument(
"--method",
"--label-dict",
type=str,
default="greedy_search",
help="""Possible values are:
- greedy_search
- modified_beam_search
- fast_beam_search
""",
help="""class_labels_indices.csv.""",
)

parser.add_argument(
Expand All @@ -177,55 +96,6 @@ def get_parser():
help="The sample rate of the input sound file",
)

parser.add_argument(
"--beam-size",
type=int,
default=4,
help="""An integer indicating how many candidates we will keep for each
frame. Used only when --method is beam_search or
modified_beam_search.""",
)

parser.add_argument(
"--beam",
type=float,
default=4,
help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi.
Used only when --method is fast_beam_search""",
)

parser.add_argument(
"--max-contexts",
type=int,
default=4,
help="""Used only when --method is fast_beam_search""",
)

parser.add_argument(
"--max-states",
type=int,
default=8,
help="""Used only when --method is fast_beam_search""",
)

parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
)

parser.add_argument(
"--max-sym-per-frame",
type=int,
default=1,
help="""Maximum number of symbols per frame. Used only when
--method is greedy_search.
""",
)

add_model_arguments(parser)

return parser
Expand Down Expand Up @@ -263,12 +133,6 @@ def main():

params.update(vars(args))

token_table = k2.SymbolTable.from_file(params.tokens)

params.blank_id = token_table["<blk>"]
params.unk_id = token_table["<unk>"]
params.vocab_size = num_tokens(token_table) + 1

logging.info(f"{params}")

device = torch.device("cpu")
Expand All @@ -277,14 +141,6 @@ def main():

logging.info(f"device: {device}")

if params.causal:
assert (
"," not in params.chunk_size
), "chunk_size should be one value in decoding."
assert (
"," not in params.left_context_frames
), "left_context_frames should be one value in decoding."

logging.info("Creating model")
model = get_model(params)

Expand All @@ -296,6 +152,15 @@ def main():
model.to(device)
model.eval()

# get the label dictionary
label_dict = {}
with open(params.label_dict, "r") as f:
reader = csv.reader(f, delimiter=",")
for i, row in enumerate(reader):
if i == 0:
continue
label_dict[int(row[0])] = row[2]

logging.info("Constructing Fbank computer")
opts = kaldifeat.FbankOptions()
opts.device = device
Expand All @@ -320,57 +185,15 @@ def main():
features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
feature_lengths = torch.tensor(feature_lengths, device=device)

# model forward
# model forward and predict the audio events
encoder_out, encoder_out_lens = model.forward_encoder(features, feature_lengths)
logits = model.forward_audio_tagging(encoder_out, encoder_out_lens)

hyps = []
msg = f"Using {params.method}"
logging.info(msg)

def token_ids_to_words(token_ids: List[int]) -> str:
text = ""
for i in token_ids:
text += token_table[i]
return text.replace("▁", " ").strip()

if params.method == "fast_beam_search":
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
hyp_tokens = fast_beam_search_one_best(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
)
for hyp in hyp_tokens:
hyps.append(token_ids_to_words(hyp))
elif params.method == "modified_beam_search":
hyp_tokens = modified_beam_search(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
)

for hyp in hyp_tokens:
hyps.append(token_ids_to_words(hyp))
elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
)
for hyp in hyp_tokens:
hyps.append(token_ids_to_words(hyp))
else:
raise ValueError(f"Unsupported method: {params.method}")

s = "\n"
for filename, hyp in zip(params.sound_files, hyps):
s += f"{filename}:\n{hyp}\n\n"
logging.info(s)
results = []
for i, logit in enumerate(logits):
topk_prob, topk_index = logit.sigmoid().topk(5)
topk_labels = [label_dict[index.item()] for index in topk_index]
print(f"Top 5 predicted labels of the {i} th audio are {topk_labels} with probability of {topk_prob.tolist()}")

logging.info("Decoding Done")

Expand Down

0 comments on commit 9c4db1b

Please sign in to comment.