diff --git a/egs/aishell/ASR/zipformer/beam_search.py b/egs/aishell/ASR/zipformer/beam_search.py deleted file mode 120000 index 8554e44ccf..0000000000 --- a/egs/aishell/ASR/zipformer/beam_search.py +++ /dev/null @@ -1 +0,0 @@ -../pruned_transducer_stateless2/beam_search.py \ No newline at end of file diff --git a/egs/aishell/ASR/zipformer/beam_search.py b/egs/aishell/ASR/zipformer/beam_search.py new file mode 100644 index 0000000000..5b035abe7d --- /dev/null +++ b/egs/aishell/ASR/zipformer/beam_search.py @@ -0,0 +1,2952 @@ +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang +# Xiaoyu Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warnings +from dataclasses import dataclass, field +from typing import Dict, List, Optional, Tuple, Union + +import k2 +import sentencepiece as spm +import torch +from torch import nn + +from icefall import ContextGraph, ContextState, NgramLm, NgramLmStateCost +from icefall.decode import Nbest, one_best_decoding +from icefall.lm_wrapper import LmScorer +from icefall.rnn_lm.model import RnnLmModel +from icefall.transformer_lm.model import TransformerLM +from icefall.utils import ( + DecodingResults, + add_eos, + add_sos, + get_texts, + get_texts_with_timestamp, +) + + +def fast_beam_search_one_best( + model: nn.Module, + decoding_graph: k2.Fsa, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + beam: float, + max_states: int, + max_contexts: int, + temperature: float = 1.0, + ilme_scale: float = 0.0, + blank_penalty: float = 0.0, + return_timestamps: bool = False, + allow_partial: bool = False, +) -> Union[List[List[int]], DecodingResults]: + """It limits the maximum number of symbols per frame to 1. + + A lattice is first obtained using fast beam search, and then + the shortest path within the lattice is used as the final output. + + Args: + model: + An instance of `Transducer`. + decoding_graph: + Decoding graph used for decoding, may be a TrivialGraph or a LG. + encoder_out: + A tensor of shape (N, T, C) from the encoder. + encoder_out_lens: + A tensor of shape (N,) containing the number of frames in `encoder_out` + before padding. + beam: + Beam value, similar to the beam used in Kaldi.. + max_states: + Max states per stream per frame. + max_contexts: + Max contexts pre stream per frame. + temperature: + Softmax temperature. + return_timestamps: + Whether to return timestamps. + Returns: + If return_timestamps is False, return the decoded result. + Else, return a DecodingResults object containing + decoded result and corresponding timestamps. + """ + lattice = fast_beam_search( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=beam, + max_states=max_states, + max_contexts=max_contexts, + temperature=temperature, + ilme_scale=ilme_scale, + allow_partial=allow_partial, + blank_penalty=blank_penalty, + ) + + best_path = one_best_decoding(lattice) + + if not return_timestamps: + return get_texts(best_path) + else: + return get_texts_with_timestamp(best_path) + + +def fast_beam_search_nbest_LG( + model: nn.Module, + decoding_graph: k2.Fsa, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + beam: float, + max_states: int, + max_contexts: int, + num_paths: int, + nbest_scale: float = 0.5, + use_double_scores: bool = True, + temperature: float = 1.0, + blank_penalty: float = 0.0, + ilme_scale: float = 0.0, + return_timestamps: bool = False, + allow_partial: bool = False, +) -> Union[List[List[int]], DecodingResults]: + """It limits the maximum number of symbols per frame to 1. + + The process to get the results is: + - (1) Use fast beam search to get a lattice + - (2) Select `num_paths` paths from the lattice using k2.random_paths() + - (3) Unique the selected paths + - (4) Intersect the selected paths with the lattice and compute the + shortest path from the intersection result + - (5) The path with the largest score is used as the decoding output. + + Args: + model: + An instance of `Transducer`. + decoding_graph: + Decoding graph used for decoding, may be a TrivialGraph or a LG. + encoder_out: + A tensor of shape (N, T, C) from the encoder. + encoder_out_lens: + A tensor of shape (N,) containing the number of frames in `encoder_out` + before padding. + beam: + Beam value, similar to the beam used in Kaldi.. + max_states: + Max states per stream per frame. + max_contexts: + Max contexts pre stream per frame. + num_paths: + Number of paths to extract from the decoded lattice. + nbest_scale: + It's the scale applied to the lattice.scores. A smaller value + yields more unique paths. + use_double_scores: + True to use double precision for computation. False to use + single precision. + temperature: + Softmax temperature. + return_timestamps: + Whether to return timestamps. + Returns: + If return_timestamps is False, return the decoded result. + Else, return a DecodingResults object containing + decoded result and corresponding timestamps. + """ + lattice = fast_beam_search( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=beam, + max_states=max_states, + max_contexts=max_contexts, + temperature=temperature, + allow_partial=allow_partial, + blank_penalty=blank_penalty, + ilme_scale=ilme_scale, + ) + + nbest = Nbest.from_lattice( + lattice=lattice, + num_paths=num_paths, + use_double_scores=use_double_scores, + nbest_scale=nbest_scale, + ) + + # The following code is modified from nbest.intersect() + word_fsa = k2.invert(nbest.fsa) + if hasattr(lattice, "aux_labels"): + # delete token IDs as it is not needed + del word_fsa.aux_labels + word_fsa.scores.zero_() + word_fsa_with_epsilon_loops = k2.linear_fsa_with_self_loops(word_fsa) + path_to_utt_map = nbest.shape.row_ids(1) + + if hasattr(lattice, "aux_labels"): + # lattice has token IDs as labels and word IDs as aux_labels. + # inv_lattice has word IDs as labels and token IDs as aux_labels + inv_lattice = k2.invert(lattice) + inv_lattice = k2.arc_sort(inv_lattice) + else: + inv_lattice = k2.arc_sort(lattice) + + if inv_lattice.shape[0] == 1: + path_lattice = k2.intersect_device( + inv_lattice, + word_fsa_with_epsilon_loops, + b_to_a_map=torch.zeros_like(path_to_utt_map), + sorted_match_a=True, + ) + else: + path_lattice = k2.intersect_device( + inv_lattice, + word_fsa_with_epsilon_loops, + b_to_a_map=path_to_utt_map, + sorted_match_a=True, + ) + + # path_lattice has word IDs as labels and token IDs as aux_labels + path_lattice = k2.top_sort(k2.connect(path_lattice)) + tot_scores = path_lattice.get_tot_scores( + use_double_scores=use_double_scores, + log_semiring=True, # Note: we always use True + ) + # See https://github.com/k2-fsa/icefall/pull/420 for why + # we always use log_semiring=True + + ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores) + best_hyp_indexes = ragged_tot_scores.argmax() + best_path = k2.index_fsa(nbest.fsa, best_hyp_indexes) + + if not return_timestamps: + return get_texts(best_path) + else: + return get_texts_with_timestamp(best_path) + + +def fast_beam_search_nbest( + model: nn.Module, + decoding_graph: k2.Fsa, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + beam: float, + max_states: int, + max_contexts: int, + num_paths: int, + nbest_scale: float = 0.5, + use_double_scores: bool = True, + temperature: float = 1.0, + blank_penalty: float = 0.0, + return_timestamps: bool = False, + allow_partial: bool = False, +) -> Union[List[List[int]], DecodingResults]: + """It limits the maximum number of symbols per frame to 1. + + The process to get the results is: + - (1) Use fast beam search to get a lattice + - (2) Select `num_paths` paths from the lattice using k2.random_paths() + - (3) Unique the selected paths + - (4) Intersect the selected paths with the lattice and compute the + shortest path from the intersection result + - (5) The path with the largest score is used as the decoding output. + + Args: + model: + An instance of `Transducer`. + decoding_graph: + Decoding graph used for decoding, may be a TrivialGraph or a LG. + encoder_out: + A tensor of shape (N, T, C) from the encoder. + encoder_out_lens: + A tensor of shape (N,) containing the number of frames in `encoder_out` + before padding. + beam: + Beam value, similar to the beam used in Kaldi.. + max_states: + Max states per stream per frame. + max_contexts: + Max contexts pre stream per frame. + num_paths: + Number of paths to extract from the decoded lattice. + nbest_scale: + It's the scale applied to the lattice.scores. A smaller value + yields more unique paths. + use_double_scores: + True to use double precision for computation. False to use + single precision. + temperature: + Softmax temperature. + return_timestamps: + Whether to return timestamps. + Returns: + If return_timestamps is False, return the decoded result. + Else, return a DecodingResults object containing + decoded result and corresponding timestamps. + """ + lattice = fast_beam_search( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=beam, + max_states=max_states, + max_contexts=max_contexts, + blank_penalty=blank_penalty, + temperature=temperature, + allow_partial=allow_partial, + ) + + nbest = Nbest.from_lattice( + lattice=lattice, + num_paths=num_paths, + use_double_scores=use_double_scores, + nbest_scale=nbest_scale, + ) + + # at this point, nbest.fsa.scores are all zeros. + + nbest = nbest.intersect(lattice) + # Now nbest.fsa.scores contains acoustic scores + + max_indexes = nbest.tot_scores().argmax() + + best_path = k2.index_fsa(nbest.fsa, max_indexes) + + if not return_timestamps: + return get_texts(best_path) + else: + return get_texts_with_timestamp(best_path) + + +def fast_beam_search_nbest_oracle( + model: nn.Module, + decoding_graph: k2.Fsa, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + beam: float, + max_states: int, + max_contexts: int, + num_paths: int, + ref_texts: List[List[int]], + use_double_scores: bool = True, + nbest_scale: float = 0.5, + temperature: float = 1.0, + blank_penalty: float = 0.0, + return_timestamps: bool = False, + allow_partial: bool = False, +) -> Union[List[List[int]], DecodingResults]: + """It limits the maximum number of symbols per frame to 1. + + A lattice is first obtained using fast beam search, and then + we select `num_paths` linear paths from the lattice. The path + that has the minimum edit distance with the given reference transcript + is used as the output. + + This is the best result we can achieve for any nbest based rescoring + methods. + + Args: + model: + An instance of `Transducer`. + decoding_graph: + Decoding graph used for decoding, may be a TrivialGraph or a LG. + encoder_out: + A tensor of shape (N, T, C) from the encoder. + encoder_out_lens: + A tensor of shape (N,) containing the number of frames in `encoder_out` + before padding. + beam: + Beam value, similar to the beam used in Kaldi.. + max_states: + Max states per stream per frame. + max_contexts: + Max contexts pre stream per frame. + num_paths: + Number of paths to extract from the decoded lattice. + ref_texts: + A list-of-list of integers containing the reference transcripts. + If the decoding_graph is a trivial_graph, the integer ID is the + BPE token ID. + use_double_scores: + True to use double precision for computation. False to use + single precision. + nbest_scale: + It's the scale applied to the lattice.scores. A smaller value + yields more unique paths. + temperature: + Softmax temperature. + return_timestamps: + Whether to return timestamps. + Returns: + If return_timestamps is False, return the decoded result. + Else, return a DecodingResults object containing + decoded result and corresponding timestamps. + """ + lattice = fast_beam_search( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=beam, + max_states=max_states, + max_contexts=max_contexts, + temperature=temperature, + allow_partial=allow_partial, + blank_penalty=blank_penalty, + ) + + nbest = Nbest.from_lattice( + lattice=lattice, + num_paths=num_paths, + use_double_scores=use_double_scores, + nbest_scale=nbest_scale, + ) + + hyps = nbest.build_levenshtein_graphs() + refs = k2.levenshtein_graph(ref_texts, device=hyps.device) + + levenshtein_alignment = k2.levenshtein_alignment( + refs=refs, + hyps=hyps, + hyp_to_ref_map=nbest.shape.row_ids(1), + sorted_match_ref=True, + ) + + tot_scores = levenshtein_alignment.get_tot_scores( + use_double_scores=False, log_semiring=False + ) + ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores) + + max_indexes = ragged_tot_scores.argmax() + + best_path = k2.index_fsa(nbest.fsa, max_indexes) + + if not return_timestamps: + return get_texts(best_path) + else: + return get_texts_with_timestamp(best_path) + + +def fast_beam_search( + model: nn.Module, + decoding_graph: k2.Fsa, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + beam: float, + max_states: int, + max_contexts: int, + temperature: float = 1.0, + subtract_ilme: bool = False, + ilme_scale: float = 0.1, + allow_partial: bool = False, + blank_penalty: float = 0.0, +) -> k2.Fsa: + """It limits the maximum number of symbols per frame to 1. + + Args: + model: + An instance of `Transducer`. + decoding_graph: + Decoding graph used for decoding, may be a TrivialGraph or a LG. + encoder_out: + A tensor of shape (N, T, C) from the encoder. + encoder_out_lens: + A tensor of shape (N,) containing the number of frames in `encoder_out` + before padding. + beam: + Beam value, similar to the beam used in Kaldi.. + max_states: + Max states per stream per frame. + max_contexts: + Max contexts pre stream per frame. + temperature: + Softmax temperature. + Returns: + Return an FsaVec with axes [utt][state][arc] containing the decoded + lattice. Note: When the input graph is a TrivialGraph, the returned + lattice is actually an acceptor. + """ + assert encoder_out.ndim == 3 + + context_size = model.decoder.context_size + vocab_size = model.decoder.vocab_size + + B, T, C = encoder_out.shape + + config = k2.RnntDecodingConfig( + vocab_size=vocab_size, + decoder_history_len=context_size, + beam=beam, + max_contexts=max_contexts, + max_states=max_states, + ) + individual_streams = [] + for i in range(B): + individual_streams.append(k2.RnntDecodingStream(decoding_graph)) + decoding_streams = k2.RnntDecodingStreams(individual_streams, config) + + encoder_out = model.joiner.encoder_proj(encoder_out) + + for t in range(T): + # shape is a RaggedShape of shape (B, context) + # contexts is a Tensor of shape (shape.NumElements(), context_size) + shape, contexts = decoding_streams.get_contexts() + # `nn.Embedding()` in torch below v1.7.1 supports only torch.int64 + contexts = contexts.to(torch.int64) + # decoder_out is of shape (shape.NumElements(), 1, decoder_out_dim) + decoder_out = model.decoder(contexts, need_pad=False) + decoder_out = model.joiner.decoder_proj(decoder_out) + # current_encoder_out is of shape + # (shape.NumElements(), 1, joiner_dim) + # fmt: off + current_encoder_out = torch.index_select( + encoder_out[:, t:t + 1, :], 0, shape.row_ids(1).to(torch.int64) + ) + # fmt: on + logits = model.joiner( + current_encoder_out.unsqueeze(2), + decoder_out.unsqueeze(1), + project_input=False, + ) + logits = logits.squeeze(1).squeeze(1) + + if blank_penalty != 0: + logits[:, 0] -= blank_penalty + + log_probs = (logits / temperature).log_softmax(dim=-1) + + if ilme_scale != 0: + ilme_logits = model.joiner( + torch.zeros_like( + current_encoder_out, device=current_encoder_out.device + ).unsqueeze(2), + decoder_out.unsqueeze(1), + project_input=False, + ) + ilme_logits = ilme_logits.squeeze(1).squeeze(1) + if blank_penalty != 0: + ilme_logits[:, 0] -= blank_penalty + ilme_log_probs = (ilme_logits / temperature).log_softmax(dim=-1) + log_probs -= ilme_scale * ilme_log_probs + + decoding_streams.advance(log_probs) + decoding_streams.terminate_and_flush_to_streams() + lattice = decoding_streams.format_output( + encoder_out_lens.tolist(), allow_partial=allow_partial + ) + + return lattice + + +def greedy_search( + model: nn.Module, + encoder_out: torch.Tensor, + max_sym_per_frame: int, + blank_penalty: float = 0.0, + return_timestamps: bool = False, +) -> Union[List[int], DecodingResults]: + """Greedy search for a single utterance. + Args: + model: + An instance of `Transducer`. + encoder_out: + A tensor of shape (N, T, C) from the encoder. Support only N==1 for now. + max_sym_per_frame: + Maximum number of symbols per frame. If it is set to 0, the WER + would be 100%. + return_timestamps: + Whether to return timestamps. + Returns: + If return_timestamps is False, return the decoded result. + Else, return a DecodingResults object containing + decoded result and corresponding timestamps. + """ + assert encoder_out.ndim == 3 + + # support only batch_size == 1 for now + assert encoder_out.size(0) == 1, encoder_out.size(0) + + blank_id = model.decoder.blank_id + context_size = model.decoder.context_size + unk_id = getattr(model, "unk_id", blank_id) + + device = next(model.parameters()).device + + decoder_input = torch.tensor( + [-1] * (context_size - 1) + [blank_id], device=device, dtype=torch.int64 + ).reshape(1, context_size) + + decoder_out = model.decoder(decoder_input, need_pad=False) + decoder_out = model.joiner.decoder_proj(decoder_out) + + encoder_out = model.joiner.encoder_proj(encoder_out) + + T = encoder_out.size(1) + t = 0 + hyp = [blank_id] * context_size + + # timestamp[i] is the frame index after subsampling + # on which hyp[i] is decoded + timestamp = [] + + # Maximum symbols per utterance. + max_sym_per_utt = 1000 + + # symbols per frame + sym_per_frame = 0 + + # symbols per utterance decoded so far + sym_per_utt = 0 + + while t < T and sym_per_utt < max_sym_per_utt: + if sym_per_frame >= max_sym_per_frame: + sym_per_frame = 0 + t += 1 + continue + + # fmt: off + current_encoder_out = encoder_out[:, t:t+1, :].unsqueeze(2) + # fmt: on + logits = model.joiner( + current_encoder_out, decoder_out.unsqueeze(1), project_input=False + ) + # logits is (1, 1, 1, vocab_size) + + if blank_penalty != 0: + logits[:, :, :, 0] -= blank_penalty + + y = logits.argmax().item() + if y not in (blank_id, unk_id): + hyp.append(y) + timestamp.append(t) + decoder_input = torch.tensor([hyp[-context_size:]], device=device).reshape( + 1, context_size + ) + + decoder_out = model.decoder(decoder_input, need_pad=False) + decoder_out = model.joiner.decoder_proj(decoder_out) + + sym_per_utt += 1 + sym_per_frame += 1 + else: + sym_per_frame = 0 + t += 1 + hyp = hyp[context_size:] # remove blanks + + if not return_timestamps: + return hyp + else: + return DecodingResults( + hyps=[hyp], + timestamps=[timestamp], + ) + + +def greedy_search_batch( + model: nn.Module, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + blank_penalty: float = 0, + return_timestamps: bool = False, +) -> Union[List[List[int]], DecodingResults]: + """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. + Args: + model: + The transducer model. + encoder_out: + Output from the encoder. Its shape is (N, T, C), where N >= 1. + encoder_out_lens: + A 1-D tensor of shape (N,), containing number of valid frames in + encoder_out before padding. + return_timestamps: + Whether to return timestamps. + Returns: + If return_timestamps is False, return the decoded result. + Else, return a DecodingResults object containing + decoded result and corresponding timestamps. + """ + assert encoder_out.ndim == 3 + assert encoder_out.size(0) >= 1, encoder_out.size(0) + + packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( + input=encoder_out, + lengths=encoder_out_lens.cpu(), + batch_first=True, + enforce_sorted=False, + ) + + device = next(model.parameters()).device + + blank_id = model.decoder.blank_id + unk_id = getattr(model, "unk_id", blank_id) + context_size = model.decoder.context_size + + batch_size_list = packed_encoder_out.batch_sizes.tolist() + unsorted_indices = packed_encoder_out.unsorted_indices.tolist() + N = encoder_out.size(0) + assert torch.all(encoder_out_lens > 0), encoder_out_lens + assert N == batch_size_list[0], (N, batch_size_list) + + hyps = [[-1] * context_size + [blank_id] for _ in range(N)] + + # timestamp[n][i] is the frame index after subsampling + # on which hyp[n][i] is decoded + timestamps = [[] for _ in range(N)] + # scores[n][i] is the logits on which hyp[n][i] is decoded + scores = [[] for _ in range(N)] + + decoder_input = torch.tensor(hyps, device=device, dtype=torch.int64)[ + :, -context_size: + ] # (N, context_size) + + k = torch.zeros(N, 1, device=device, dtype=torch.int64) + + decoder_out = model.decoder(decoder_input, k, need_pad=False) + decoder_out = model.joiner.decoder_proj(decoder_out) + # decoder_out: (N, 1, decoder_out_dim) + + encoder_out = model.joiner.encoder_proj(packed_encoder_out.data) + + offset = 0 + for t, batch_size in enumerate(batch_size_list): + start = offset + end = offset + batch_size + current_encoder_out = encoder_out.data[start:end] + current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1) + # current_encoder_out's shape: (batch_size, 1, 1, encoder_out_dim) + offset = end + + decoder_out = decoder_out[:batch_size] + + logits = model.joiner( + current_encoder_out, decoder_out.unsqueeze(1), project_input=False + ) + # logits'shape (batch_size, 1, 1, vocab_size) + + logits = logits.squeeze(1).squeeze(1) # (batch_size, vocab_size) + assert logits.ndim == 2, logits.shape + + if blank_penalty != 0: + logits[:, 0] -= blank_penalty + + y = logits.argmax(dim=1).tolist() + emitted = False + for i, v in enumerate(y): + if v not in (blank_id, unk_id): + hyps[i].append(v) + timestamps[i].append(t) + scores[i].append(logits[i, v].item()) + emitted = True + if emitted: + # update decoder output + decoder_input = [h[-context_size:] for h in hyps[:batch_size]] + k = torch.where( + torch.tensor( + [h[-context_size - 1] for h in hyps[:batch_size]], + device=device, + dtype=torch.int64, + ) + == decoder_input[:, -1], + k + 1, + torch.zeros(N, 1, device=device, dtype=torch.int64), + ) + decoder_input = torch.tensor( + decoder_input, + device=device, + dtype=torch.int64, + ) + decoder_out = model.decoder(decoder_input, k, need_pad=False) + decoder_out = model.joiner.decoder_proj(decoder_out) + + sorted_ans = [h[context_size + 1 :] for h in hyps] + ans = [] + ans_timestamps = [] + ans_scores = [] + for i in range(N): + ans.append(sorted_ans[unsorted_indices[i]]) + ans_timestamps.append(timestamps[unsorted_indices[i]]) + ans_scores.append(scores[unsorted_indices[i]]) + + if not return_timestamps: + return ans + else: + return DecodingResults( + hyps=ans, + timestamps=ans_timestamps, + scores=ans_scores, + ) + + +@dataclass +class Hypothesis: + # The predicted tokens so far. + # Newly predicted tokens are appended to `ys`. + ys: List[int] + + # The log prob of ys. + # It contains only one entry. + log_prob: torch.Tensor + + # timestamp[i] is the frame index after subsampling + # on which ys[i] is decoded + timestamp: List[int] = field(default_factory=list) + + # the lm score for next token given the current ys + lm_score: Optional[torch.Tensor] = None + + # the RNNLM states (h and c in LSTM) + state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None + + # N-gram LM state + state_cost: Optional[NgramLmStateCost] = None + + # Context graph state + context_state: Optional[ContextState] = None + + @property + def key(self) -> str: + """Return a string representation of self.ys""" + return "_".join(map(str, self.ys)) + + +class HypothesisList(object): + def __init__(self, data: Optional[Dict[str, Hypothesis]] = None) -> None: + """ + Args: + data: + A dict of Hypotheses. Its key is its `value.key`. + """ + if data is None: + self._data = {} + else: + self._data = data + + @property + def data(self) -> Dict[str, Hypothesis]: + return self._data + + def add(self, hyp: Hypothesis) -> None: + """Add a Hypothesis to `self`. + + If `hyp` already exists in `self`, its probability is updated using + `log-sum-exp` with the existed one. + + Args: + hyp: + The hypothesis to be added. + """ + key = hyp.key + if key in self: + old_hyp = self._data[key] # shallow copy + torch.logaddexp(old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob) + else: + self._data[key] = hyp + + def get_most_probable(self, length_norm: bool = False) -> Hypothesis: + """Get the most probable hypothesis, i.e., the one with + the largest `log_prob`. + + Args: + length_norm: + If True, the `log_prob` of a hypothesis is normalized by the + number of tokens in it. + Returns: + Return the hypothesis that has the largest `log_prob`. + """ + if length_norm: + return max(self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys)) + else: + return max(self._data.values(), key=lambda hyp: hyp.log_prob) + + def remove(self, hyp: Hypothesis) -> None: + """Remove a given hypothesis. + + Caution: + `self` is modified **in-place**. + + Args: + hyp: + The hypothesis to be removed from `self`. + Note: It must be contained in `self`. Otherwise, + an exception is raised. + """ + key = hyp.key + assert key in self, f"{key} does not exist" + del self._data[key] + + def filter(self, threshold: torch.Tensor) -> "HypothesisList": + """Remove all Hypotheses whose log_prob is less than threshold. + + Caution: + `self` is not modified. Instead, a new HypothesisList is returned. + + Returns: + Return a new HypothesisList containing all hypotheses from `self` + with `log_prob` being greater than the given `threshold`. + """ + ans = HypothesisList() + for _, hyp in self._data.items(): + if hyp.log_prob > threshold: + ans.add(hyp) # shallow copy + return ans + + def topk(self, k: int, length_norm: bool = False) -> "HypothesisList": + """Return the top-k hypothesis. + + Args: + length_norm: + If True, the `log_prob` of a hypothesis is normalized by the + number of tokens in it. + """ + hyps = list(self._data.items()) + + if length_norm: + hyps = sorted( + hyps, key=lambda h: h[1].log_prob / len(h[1].ys), reverse=True + )[:k] + else: + hyps = sorted(hyps, key=lambda h: h[1].log_prob, reverse=True)[:k] + + ans = HypothesisList(dict(hyps)) + return ans + + def __contains__(self, key: str): + return key in self._data + + def __iter__(self): + return iter(self._data.values()) + + def __len__(self) -> int: + return len(self._data) + + def __str__(self) -> str: + s = [] + for key in self: + s.append(key) + return ", ".join(s) + + +def get_hyps_shape(hyps: List[HypothesisList]) -> k2.RaggedShape: + """Return a ragged shape with axes [utt][num_hyps]. + + Args: + hyps: + len(hyps) == batch_size. It contains the current hypothesis for + each utterance in the batch. + Returns: + Return a ragged shape with 2 axes [utt][num_hyps]. Note that + the shape is on CPU. + """ + num_hyps = [len(h) for h in hyps] + + # torch.cumsum() is inclusive sum, so we put a 0 at the beginning + # to get exclusive sum later. + num_hyps.insert(0, 0) + + num_hyps = torch.tensor(num_hyps) + row_splits = torch.cumsum(num_hyps, dim=0, dtype=torch.int32) + ans = k2.ragged.create_ragged_shape2( + row_splits=row_splits, cached_tot_size=row_splits[-1].item() + ) + return ans + + +def modified_beam_search( + model: nn.Module, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + context_graph: Optional[ContextGraph] = None, + beam: int = 4, + temperature: float = 1.0, + blank_penalty: float = 0.0, + return_timestamps: bool = False, +) -> Union[List[List[int]], DecodingResults]: + """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded. + + Args: + model: + The transducer model. + encoder_out: + Output from the encoder. Its shape is (N, T, C). + encoder_out_lens: + A 1-D tensor of shape (N,), containing number of valid frames in + encoder_out before padding. + beam: + Number of active paths during the beam search. + temperature: + Softmax temperature. + return_timestamps: + Whether to return timestamps. + Returns: + If return_timestamps is False, return the decoded result. + Else, return a DecodingResults object containing + decoded result and corresponding timestamps. + """ + assert encoder_out.ndim == 3, encoder_out.shape + assert encoder_out.size(0) >= 1, encoder_out.size(0) + + packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( + input=encoder_out, + lengths=encoder_out_lens.cpu(), + batch_first=True, + enforce_sorted=False, + ) + + blank_id = model.decoder.blank_id + unk_id = getattr(model, "unk_id", blank_id) + context_size = model.decoder.context_size + device = next(model.parameters()).device + + batch_size_list = packed_encoder_out.batch_sizes.tolist() + N = encoder_out.size(0) + assert torch.all(encoder_out_lens > 0), encoder_out_lens + assert N == batch_size_list[0], (N, batch_size_list) + + B = [HypothesisList() for _ in range(N)] + for i in range(N): + B[i].add( + Hypothesis( + ys=[-1] * (context_size - 1) + [blank_id], + log_prob=torch.zeros(1, dtype=torch.float32, device=device), + context_state=None if context_graph is None else context_graph.root, + timestamp=[], + ) + ) + + encoder_out = model.joiner.encoder_proj(packed_encoder_out.data) + + offset = 0 + finalized_B = [] + for t, batch_size in enumerate(batch_size_list): + start = offset + end = offset + batch_size + current_encoder_out = encoder_out.data[start:end] + current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1) + # current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim) + offset = end + + finalized_B = B[batch_size:] + finalized_B + B = B[:batch_size] + + hyps_shape = get_hyps_shape(B).to(device) + + A = [list(b) for b in B] + + B = [HypothesisList() for _ in range(batch_size)] + + ys_log_probs = torch.cat( + [hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps] + ) # (num_hyps, 1) + + decoder_input = torch.tensor( + [hyp.ys[-context_size:] for hyps in A for hyp in hyps], + device=device, + dtype=torch.int64, + ) # (num_hyps, context_size) + + decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) + decoder_out = model.joiner.decoder_proj(decoder_out) + # decoder_out is of shape (num_hyps, 1, 1, joiner_dim) + + # Note: For torch 1.7.1 and below, it requires a torch.int64 tensor + # as index, so we use `to(torch.int64)` below. + current_encoder_out = torch.index_select( + current_encoder_out, + dim=0, + index=hyps_shape.row_ids(1).to(torch.int64), + ) # (num_hyps, 1, 1, encoder_out_dim) + + logits = model.joiner( + current_encoder_out, + decoder_out, + project_input=False, + ) # (num_hyps, 1, 1, vocab_size) + + logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size) + + if blank_penalty != 0: + logits[:, 0] -= blank_penalty + + log_probs = (logits / temperature).log_softmax(dim=-1) # (num_hyps, vocab_size) + + log_probs.add_(ys_log_probs) + + vocab_size = log_probs.size(-1) + + log_probs = log_probs.reshape(-1) + + row_splits = hyps_shape.row_splits(1) * vocab_size + log_probs_shape = k2.ragged.create_ragged_shape2( + row_splits=row_splits, cached_tot_size=log_probs.numel() + ) + ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) + + for i in range(batch_size): + topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + topk_hyp_indexes = (topk_indexes // vocab_size).tolist() + topk_token_indexes = (topk_indexes % vocab_size).tolist() + + for k in range(len(topk_hyp_indexes)): + hyp_idx = topk_hyp_indexes[k] + hyp = A[i][hyp_idx] + new_ys = hyp.ys[:] + new_token = topk_token_indexes[k] + new_timestamp = hyp.timestamp[:] + context_score = 0 + new_context_state = None if context_graph is None else hyp.context_state + if new_token not in (blank_id, unk_id): + new_ys.append(new_token) + new_timestamp.append(t) + if context_graph is not None: + ( + context_score, + new_context_state, + ) = context_graph.forward_one_step(hyp.context_state, new_token) + + new_log_prob = topk_log_probs[k] + context_score + + new_hyp = Hypothesis( + ys=new_ys, + log_prob=new_log_prob, + timestamp=new_timestamp, + context_state=new_context_state, + ) + B[i].add(new_hyp) + + B = B + finalized_B + + # finalize context_state, if the matched contexts do not reach final state + # we need to add the score on the corresponding backoff arc + if context_graph is not None: + finalized_B = [HypothesisList() for _ in range(len(B))] + for i, hyps in enumerate(B): + for hyp in list(hyps): + context_score, new_context_state = context_graph.finalize( + hyp.context_state + ) + finalized_B[i].add( + Hypothesis( + ys=hyp.ys, + log_prob=hyp.log_prob + context_score, + timestamp=hyp.timestamp, + context_state=new_context_state, + ) + ) + B = finalized_B + + best_hyps = [b.get_most_probable(length_norm=True) for b in B] + + sorted_ans = [h.ys[context_size:] for h in best_hyps] + sorted_timestamps = [h.timestamp for h in best_hyps] + ans = [] + ans_timestamps = [] + unsorted_indices = packed_encoder_out.unsorted_indices.tolist() + for i in range(N): + ans.append(sorted_ans[unsorted_indices[i]]) + ans_timestamps.append(sorted_timestamps[unsorted_indices[i]]) + + if not return_timestamps: + return ans + else: + return DecodingResults( + hyps=ans, + timestamps=ans_timestamps, + ) + + +def modified_beam_search_lm_rescore( + model: nn.Module, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + LM: LmScorer, + lm_scale_list: List[int], + beam: int = 4, + temperature: float = 1.0, + return_timestamps: bool = False, +) -> Union[List[List[int]], DecodingResults]: + """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded. + Rescore the final results with RNNLM and return the one with the highest score + + Args: + model: + The transducer model. + encoder_out: + Output from the encoder. Its shape is (N, T, C). + encoder_out_lens: + A 1-D tensor of shape (N,), containing number of valid frames in + encoder_out before padding. + beam: + Number of active paths during the beam search. + temperature: + Softmax temperature. + LM: + A neural network language model + return_timestamps: + Whether to return timestamps. + Returns: + If return_timestamps is False, return the decoded result. + Else, return a DecodingResults object containing + decoded result and corresponding timestamps. + """ + assert encoder_out.ndim == 3, encoder_out.shape + assert encoder_out.size(0) >= 1, encoder_out.size(0) + + packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( + input=encoder_out, + lengths=encoder_out_lens.cpu(), + batch_first=True, + enforce_sorted=False, + ) + + blank_id = model.decoder.blank_id + unk_id = getattr(model, "unk_id", blank_id) + context_size = model.decoder.context_size + device = next(model.parameters()).device + + batch_size_list = packed_encoder_out.batch_sizes.tolist() + N = encoder_out.size(0) + assert torch.all(encoder_out_lens > 0), encoder_out_lens + assert N == batch_size_list[0], (N, batch_size_list) + + B = [HypothesisList() for _ in range(N)] + for i in range(N): + B[i].add( + Hypothesis( + ys=[-1] * (context_size - 1) + [blank_id], + log_prob=torch.zeros(1, dtype=torch.float32, device=device), + timestamp=[], + ) + ) + + encoder_out = model.joiner.encoder_proj(packed_encoder_out.data) + + offset = 0 + finalized_B = [] + for t, batch_size in enumerate(batch_size_list): + start = offset + end = offset + batch_size + current_encoder_out = encoder_out.data[start:end] + current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1) + # current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim) + offset = end + + finalized_B = B[batch_size:] + finalized_B + B = B[:batch_size] + + hyps_shape = get_hyps_shape(B).to(device) + + A = [list(b) for b in B] + B = [HypothesisList() for _ in range(batch_size)] + + ys_log_probs = torch.cat( + [hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps] + ) # (num_hyps, 1) + + decoder_input = torch.tensor( + [hyp.ys[-context_size:] for hyps in A for hyp in hyps], + device=device, + dtype=torch.int64, + ) # (num_hyps, context_size) + + decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) + decoder_out = model.joiner.decoder_proj(decoder_out) + # decoder_out is of shape (num_hyps, 1, 1, joiner_dim) + + # Note: For torch 1.7.1 and below, it requires a torch.int64 tensor + # as index, so we use `to(torch.int64)` below. + current_encoder_out = torch.index_select( + current_encoder_out, + dim=0, + index=hyps_shape.row_ids(1).to(torch.int64), + ) # (num_hyps, 1, 1, encoder_out_dim) + + logits = model.joiner( + current_encoder_out, + decoder_out, + project_input=False, + ) # (num_hyps, 1, 1, vocab_size) + + logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size) + + log_probs = (logits / temperature).log_softmax(dim=-1) # (num_hyps, vocab_size) + + log_probs.add_(ys_log_probs) + + vocab_size = log_probs.size(-1) + + log_probs = log_probs.reshape(-1) + + row_splits = hyps_shape.row_splits(1) * vocab_size + log_probs_shape = k2.ragged.create_ragged_shape2( + row_splits=row_splits, cached_tot_size=log_probs.numel() + ) + ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) + + for i in range(batch_size): + topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + topk_hyp_indexes = (topk_indexes // vocab_size).tolist() + topk_token_indexes = (topk_indexes % vocab_size).tolist() + + for k in range(len(topk_hyp_indexes)): + hyp_idx = topk_hyp_indexes[k] + hyp = A[i][hyp_idx] + + new_ys = hyp.ys[:] + new_token = topk_token_indexes[k] + new_timestamp = hyp.timestamp[:] + if new_token not in (blank_id, unk_id): + new_ys.append(new_token) + new_timestamp.append(t) + + new_log_prob = topk_log_probs[k] + new_hyp = Hypothesis( + ys=new_ys, log_prob=new_log_prob, timestamp=new_timestamp + ) + B[i].add(new_hyp) + + B = B + finalized_B + + # get the am_scores for n-best list + hyps_shape = get_hyps_shape(B) + am_scores = torch.tensor([hyp.log_prob.item() for b in B for hyp in b]) + am_scores = k2.RaggedTensor(value=am_scores, shape=hyps_shape).to(device) + + # now LM rescore + # prepare input data to LM + candidate_seqs = [hyp.ys[context_size:] for b in B for hyp in b] + possible_seqs = k2.RaggedTensor(candidate_seqs) + row_splits = possible_seqs.shape.row_splits(1) + sentence_token_lengths = row_splits[1:] - row_splits[:-1] + possible_seqs_with_sos = add_sos(possible_seqs, sos_id=1) + possible_seqs_with_eos = add_eos(possible_seqs, eos_id=1) + sentence_token_lengths += 1 + + x = possible_seqs_with_sos.pad(mode="constant", padding_value=blank_id) + y = possible_seqs_with_eos.pad(mode="constant", padding_value=blank_id) + x = x.to(device).to(torch.int64) + y = y.to(device).to(torch.int64) + sentence_token_lengths = sentence_token_lengths.to(device).to(torch.int64) + + lm_scores = LM.lm(x=x, y=y, lengths=sentence_token_lengths) + assert lm_scores.ndim == 2 + lm_scores = -1 * lm_scores.sum(dim=1) + + ans = {} + unsorted_indices = packed_encoder_out.unsorted_indices.tolist() + + # get the best hyp with different lm_scale + for lm_scale in lm_scale_list: + key = f"nnlm_scale_{lm_scale:.2f}" + tot_scores = am_scores.values + lm_scores * lm_scale + ragged_tot_scores = k2.RaggedTensor(shape=am_scores.shape, value=tot_scores) + max_indexes = ragged_tot_scores.argmax().tolist() + unsorted_hyps = [candidate_seqs[idx] for idx in max_indexes] + hyps = [] + for idx in unsorted_indices: + hyps.append(unsorted_hyps[idx]) + + ans[key] = hyps + return ans + + +def modified_beam_search_lm_rescore_LODR( + model: nn.Module, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + LM: LmScorer, + LODR_lm: NgramLm, + sp: spm.SentencePieceProcessor, + lm_scale_list: List[int], + beam: int = 4, + temperature: float = 1.0, + return_timestamps: bool = False, +) -> Union[List[List[int]], DecodingResults]: + """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded. + Rescore the final results with RNNLM and return the one with the highest score + + Args: + model: + The transducer model. + encoder_out: + Output from the encoder. Its shape is (N, T, C). + encoder_out_lens: + A 1-D tensor of shape (N,), containing number of valid frames in + encoder_out before padding. + beam: + Number of active paths during the beam search. + temperature: + Softmax temperature. + LM: + A neural network language model + return_timestamps: + Whether to return timestamps. + Returns: + If return_timestamps is False, return the decoded result. + Else, return a DecodingResults object containing + decoded result and corresponding timestamps. + """ + assert encoder_out.ndim == 3, encoder_out.shape + assert encoder_out.size(0) >= 1, encoder_out.size(0) + + packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( + input=encoder_out, + lengths=encoder_out_lens.cpu(), + batch_first=True, + enforce_sorted=False, + ) + + blank_id = model.decoder.blank_id + unk_id = getattr(model, "unk_id", blank_id) + context_size = model.decoder.context_size + device = next(model.parameters()).device + + batch_size_list = packed_encoder_out.batch_sizes.tolist() + N = encoder_out.size(0) + assert torch.all(encoder_out_lens > 0), encoder_out_lens + assert N == batch_size_list[0], (N, batch_size_list) + + B = [HypothesisList() for _ in range(N)] + for i in range(N): + B[i].add( + Hypothesis( + ys=[-1] * (context_size - 1) + [blank_id], + log_prob=torch.zeros(1, dtype=torch.float32, device=device), + timestamp=[], + ) + ) + + encoder_out = model.joiner.encoder_proj(packed_encoder_out.data) + + offset = 0 + finalized_B = [] + for t, batch_size in enumerate(batch_size_list): + start = offset + end = offset + batch_size + current_encoder_out = encoder_out.data[start:end] + current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1) + # current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim) + offset = end + + finalized_B = B[batch_size:] + finalized_B + B = B[:batch_size] + + hyps_shape = get_hyps_shape(B).to(device) + + A = [list(b) for b in B] + B = [HypothesisList() for _ in range(batch_size)] + + ys_log_probs = torch.cat( + [hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps] + ) # (num_hyps, 1) + + decoder_input = torch.tensor( + [hyp.ys[-context_size:] for hyps in A for hyp in hyps], + device=device, + dtype=torch.int64, + ) # (num_hyps, context_size) + + decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) + decoder_out = model.joiner.decoder_proj(decoder_out) + # decoder_out is of shape (num_hyps, 1, 1, joiner_dim) + + # Note: For torch 1.7.1 and below, it requires a torch.int64 tensor + # as index, so we use `to(torch.int64)` below. + current_encoder_out = torch.index_select( + current_encoder_out, + dim=0, + index=hyps_shape.row_ids(1).to(torch.int64), + ) # (num_hyps, 1, 1, encoder_out_dim) + + logits = model.joiner( + current_encoder_out, + decoder_out, + project_input=False, + ) # (num_hyps, 1, 1, vocab_size) + + logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size) + + log_probs = (logits / temperature).log_softmax(dim=-1) # (num_hyps, vocab_size) + + log_probs.add_(ys_log_probs) + + vocab_size = log_probs.size(-1) + + log_probs = log_probs.reshape(-1) + + row_splits = hyps_shape.row_splits(1) * vocab_size + log_probs_shape = k2.ragged.create_ragged_shape2( + row_splits=row_splits, cached_tot_size=log_probs.numel() + ) + ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) + + for i in range(batch_size): + topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + topk_hyp_indexes = (topk_indexes // vocab_size).tolist() + topk_token_indexes = (topk_indexes % vocab_size).tolist() + + for k in range(len(topk_hyp_indexes)): + hyp_idx = topk_hyp_indexes[k] + hyp = A[i][hyp_idx] + + new_ys = hyp.ys[:] + new_token = topk_token_indexes[k] + new_timestamp = hyp.timestamp[:] + if new_token not in (blank_id, unk_id): + new_ys.append(new_token) + new_timestamp.append(t) + + new_log_prob = topk_log_probs[k] + new_hyp = Hypothesis( + ys=new_ys, log_prob=new_log_prob, timestamp=new_timestamp + ) + B[i].add(new_hyp) + + B = B + finalized_B + + # get the am_scores for n-best list + hyps_shape = get_hyps_shape(B) + am_scores = torch.tensor([hyp.log_prob.item() for b in B for hyp in b]) + am_scores = k2.RaggedTensor(value=am_scores, shape=hyps_shape).to(device) + + # now LM rescore + # prepare input data to LM + candidate_seqs = [hyp.ys[context_size:] for b in B for hyp in b] + possible_seqs = k2.RaggedTensor(candidate_seqs) + row_splits = possible_seqs.shape.row_splits(1) + sentence_token_lengths = row_splits[1:] - row_splits[:-1] + possible_seqs_with_sos = add_sos(possible_seqs, sos_id=1) + possible_seqs_with_eos = add_eos(possible_seqs, eos_id=1) + sentence_token_lengths += 1 + + x = possible_seqs_with_sos.pad(mode="constant", padding_value=blank_id) + y = possible_seqs_with_eos.pad(mode="constant", padding_value=blank_id) + x = x.to(device).to(torch.int64) + y = y.to(device).to(torch.int64) + sentence_token_lengths = sentence_token_lengths.to(device).to(torch.int64) + + lm_scores = LM.lm(x=x, y=y, lengths=sentence_token_lengths) + assert lm_scores.ndim == 2 + lm_scores = -1 * lm_scores.sum(dim=1) + + # now LODR scores + import math + + LODR_scores = [] + for seq in candidate_seqs: + tokens = " ".join(sp.id_to_piece(seq)) + LODR_scores.append(LODR_lm.score(tokens)) + LODR_scores = torch.tensor(LODR_scores).to(device) * math.log( + 10 + ) # arpa scores are 10-based + assert lm_scores.shape == LODR_scores.shape + + ans = {} + unsorted_indices = packed_encoder_out.unsorted_indices.tolist() + + LODR_scale_list = [0.05 * i for i in range(1, 20)] + # get the best hyp with different lm_scale and lodr_scale + for lm_scale in lm_scale_list: + for lodr_scale in LODR_scale_list: + key = f"nnlm_scale_{lm_scale:.2f}_lodr_scale_{lodr_scale:.2f}" + tot_scores = ( + am_scores.values / lm_scale + lm_scores - LODR_scores * lodr_scale + ) + ragged_tot_scores = k2.RaggedTensor(shape=am_scores.shape, value=tot_scores) + max_indexes = ragged_tot_scores.argmax().tolist() + unsorted_hyps = [candidate_seqs[idx] for idx in max_indexes] + hyps = [] + for idx in unsorted_indices: + hyps.append(unsorted_hyps[idx]) + + ans[key] = hyps + return ans + + +def _deprecated_modified_beam_search( + model: nn.Module, + encoder_out: torch.Tensor, + beam: int = 4, + return_timestamps: bool = False, +) -> Union[List[int], DecodingResults]: + """It limits the maximum number of symbols per frame to 1. + + It decodes only one utterance at a time. We keep it only for reference. + The function :func:`modified_beam_search` should be preferred as it + supports batch decoding. + + + Args: + model: + An instance of `Transducer`. + encoder_out: + A tensor of shape (N, T, C) from the encoder. Support only N==1 for now. + beam: + Beam size. + return_timestamps: + Whether to return timestamps. + + Returns: + If return_timestamps is False, return the decoded result. + Else, return a DecodingResults object containing + decoded result and corresponding timestamps. + """ + + assert encoder_out.ndim == 3 + + # support only batch_size == 1 for now + assert encoder_out.size(0) == 1, encoder_out.size(0) + blank_id = model.decoder.blank_id + unk_id = getattr(model, "unk_id", blank_id) + context_size = model.decoder.context_size + + device = next(model.parameters()).device + + T = encoder_out.size(1) + + B = HypothesisList() + B.add( + Hypothesis( + ys=[-1] * (context_size - 1) + [blank_id], + log_prob=torch.zeros(1, dtype=torch.float32, device=device), + timestamp=[], + ) + ) + encoder_out = model.joiner.encoder_proj(encoder_out) + + for t in range(T): + # fmt: off + current_encoder_out = encoder_out[:, t:t+1, :].unsqueeze(2) + # current_encoder_out is of shape (1, 1, 1, encoder_out_dim) + # fmt: on + A = list(B) + B = HypothesisList() + + ys_log_probs = torch.cat([hyp.log_prob.reshape(1, 1) for hyp in A]) + # ys_log_probs is of shape (num_hyps, 1) + + decoder_input = torch.tensor( + [hyp.ys[-context_size:] for hyp in A], + device=device, + dtype=torch.int64, + ) + # decoder_input is of shape (num_hyps, context_size) + + decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) + decoder_out = model.joiner.decoder_proj(decoder_out) + # decoder_output is of shape (num_hyps, 1, 1, joiner_dim) + + current_encoder_out = current_encoder_out.expand( + decoder_out.size(0), 1, 1, -1 + ) # (num_hyps, 1, 1, encoder_out_dim) + + logits = model.joiner( + current_encoder_out, + decoder_out, + project_input=False, + ) + # logits is of shape (num_hyps, 1, 1, vocab_size) + logits = logits.squeeze(1).squeeze(1) + + # now logits is of shape (num_hyps, vocab_size) + log_probs = logits.log_softmax(dim=-1) + + log_probs.add_(ys_log_probs) + + log_probs = log_probs.reshape(-1) + topk_log_probs, topk_indexes = log_probs.topk(beam) + + # topk_hyp_indexes are indexes into `A` + topk_hyp_indexes = topk_indexes // logits.size(-1) + topk_token_indexes = topk_indexes % logits.size(-1) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + topk_hyp_indexes = topk_hyp_indexes.tolist() + topk_token_indexes = topk_token_indexes.tolist() + + for i in range(len(topk_hyp_indexes)): + hyp = A[topk_hyp_indexes[i]] + new_ys = hyp.ys[:] + new_timestamp = hyp.timestamp[:] + new_token = topk_token_indexes[i] + if new_token not in (blank_id, unk_id): + new_ys.append(new_token) + new_timestamp.append(t) + new_log_prob = topk_log_probs[i] + new_hyp = Hypothesis( + ys=new_ys, log_prob=new_log_prob, timestamp=new_timestamp + ) + B.add(new_hyp) + + best_hyp = B.get_most_probable(length_norm=True) + ys = best_hyp.ys[context_size:] # [context_size:] to remove blanks + + if not return_timestamps: + return ys + else: + return DecodingResults(hyps=[ys], timestamps=[best_hyp.timestamp]) + + +def beam_search( + model: nn.Module, + encoder_out: torch.Tensor, + beam: int = 4, + temperature: float = 1.0, + blank_penalty: float = 0.0, + return_timestamps: bool = False, +) -> Union[List[int], DecodingResults]: + """ + It implements Algorithm 1 in https://arxiv.org/pdf/1211.3711.pdf + + espnet/nets/beam_search_transducer.py#L247 is used as a reference. + + Args: + model: + An instance of `Transducer`. + encoder_out: + A tensor of shape (N, T, C) from the encoder. Support only N==1 for now. + beam: + Beam size. + temperature: + Softmax temperature. + return_timestamps: + Whether to return timestamps. + + Returns: + If return_timestamps is False, return the decoded result. + Else, return a DecodingResults object containing + decoded result and corresponding timestamps. + """ + assert encoder_out.ndim == 3 + + # support only batch_size == 1 for now + assert encoder_out.size(0) == 1, encoder_out.size(0) + blank_id = model.decoder.blank_id + unk_id = getattr(model, "unk_id", blank_id) + context_size = model.decoder.context_size + + device = next(model.parameters()).device + + decoder_input = torch.tensor( + [blank_id] * context_size, + device=device, + dtype=torch.int64, + ).reshape(1, context_size) + + decoder_out = model.decoder(decoder_input, need_pad=False) + decoder_out = model.joiner.decoder_proj(decoder_out) + + encoder_out = model.joiner.encoder_proj(encoder_out) + + T = encoder_out.size(1) + t = 0 + + B = HypothesisList() + B.add( + Hypothesis( + ys=[-1] * (context_size - 1) + [blank_id], log_prob=0.0, timestamp=[] + ) + ) + + max_sym_per_utt = 20000 + + sym_per_utt = 0 + + decoder_cache: Dict[str, torch.Tensor] = {} + + while t < T and sym_per_utt < max_sym_per_utt: + # fmt: off + current_encoder_out = encoder_out[:, t:t+1, :].unsqueeze(2) + # fmt: on + A = B + B = HypothesisList() + + joint_cache: Dict[str, torch.Tensor] = {} + + # TODO(fangjun): Implement prefix search to update the `log_prob` + # of hypotheses in A + + while True: + y_star = A.get_most_probable() + A.remove(y_star) + + cached_key = y_star.key + + if cached_key not in decoder_cache: + decoder_input = torch.tensor( + [y_star.ys[-context_size:]], + device=device, + dtype=torch.int64, + ).reshape(1, context_size) + + decoder_out = model.decoder(decoder_input, need_pad=False) + decoder_out = model.joiner.decoder_proj(decoder_out) + decoder_cache[cached_key] = decoder_out + else: + decoder_out = decoder_cache[cached_key] + + cached_key += f"-t-{t}" + if cached_key not in joint_cache: + logits = model.joiner( + current_encoder_out, + decoder_out.unsqueeze(1), + project_input=False, + ) + + if blank_penalty != 0: + logits[:, :, :, 0] -= blank_penalty + + # TODO(fangjun): Scale the blank posterior + log_prob = (logits / temperature).log_softmax(dim=-1) + # log_prob is (1, 1, 1, vocab_size) + log_prob = log_prob.squeeze() + # Now log_prob is (vocab_size,) + joint_cache[cached_key] = log_prob + else: + log_prob = joint_cache[cached_key] + + # First, process the blank symbol + skip_log_prob = log_prob[blank_id] + new_y_star_log_prob = y_star.log_prob + skip_log_prob + + # ys[:] returns a copy of ys + B.add( + Hypothesis( + ys=y_star.ys[:], + log_prob=new_y_star_log_prob, + timestamp=y_star.timestamp[:], + ) + ) + + # Second, process other non-blank labels + values, indices = log_prob.topk(beam + 1) + for i, v in zip(indices.tolist(), values.tolist()): + if i in (blank_id, unk_id): + continue + new_ys = y_star.ys + [i] + new_log_prob = y_star.log_prob + v + new_timestamp = y_star.timestamp + [t] + A.add( + Hypothesis( + ys=new_ys, + log_prob=new_log_prob, + timestamp=new_timestamp, + ) + ) + + # Check whether B contains more than "beam" elements more probable + # than the most probable in A + A_most_probable = A.get_most_probable() + + kept_B = B.filter(A_most_probable.log_prob) + + if len(kept_B) >= beam: + B = kept_B.topk(beam) + break + + t += 1 + + best_hyp = B.get_most_probable(length_norm=True) + ys = best_hyp.ys[context_size:] # [context_size:] to remove blanks + + if not return_timestamps: + return ys + else: + return DecodingResults(hyps=[ys], timestamps=[best_hyp.timestamp]) + + +def fast_beam_search_with_nbest_rescoring( + model: nn.Module, + decoding_graph: k2.Fsa, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + beam: float, + max_states: int, + max_contexts: int, + ngram_lm_scale_list: List[float], + num_paths: int, + G: k2.Fsa, + sp: spm.SentencePieceProcessor, + word_table: k2.SymbolTable, + oov_word: str = "", + use_double_scores: bool = True, + nbest_scale: float = 0.5, + temperature: float = 1.0, + return_timestamps: bool = False, +) -> Dict[str, Union[List[List[int]], DecodingResults]]: + """It limits the maximum number of symbols per frame to 1. + A lattice is first obtained using fast beam search, num_path are selected + and rescored using a given language model. The shortest path within the + lattice is used as the final output. + + Args: + model: + An instance of `Transducer`. + decoding_graph: + Decoding graph used for decoding, may be a TrivialGraph or a LG. + encoder_out: + A tensor of shape (N, T, C) from the encoder. + encoder_out_lens: + A tensor of shape (N,) containing the number of frames in `encoder_out` + before padding. + beam: + Beam value, similar to the beam used in Kaldi. + max_states: + Max states per stream per frame. + max_contexts: + Max contexts pre stream per frame. + ngram_lm_scale_list: + A list of floats representing LM score scales. + num_paths: + Number of paths to extract from the decoded lattice. + G: + An FsaVec containing only a single FSA. It is an n-gram LM. + sp: + The BPE model. + word_table: + The word symbol table. + oov_word: + OOV words are replaced with this word. + use_double_scores: + True to use double precision for computation. False to use + single precision. + nbest_scale: + It's the scale applied to the lattice.scores. A smaller value + yields more unique paths. + temperature: + Softmax temperature. + return_timestamps: + Whether to return timestamps. + Returns: + Return the decoded result in a dict, where the key has the form + 'ngram_lm_scale_xx' and the value is the decoded results + optionally with timestamps. `xx` is the ngram LM scale value + used during decoding, i.e., 0.1. + """ + lattice = fast_beam_search( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=beam, + max_states=max_states, + max_contexts=max_contexts, + temperature=temperature, + ) + + nbest = Nbest.from_lattice( + lattice=lattice, + num_paths=num_paths, + use_double_scores=use_double_scores, + nbest_scale=nbest_scale, + ) + # at this point, nbest.fsa.scores are all zeros. + + nbest = nbest.intersect(lattice) + # Now nbest.fsa.scores contains acoustic scores + + am_scores = nbest.tot_scores() + + # Now we need to compute the LM scores of each path. + # (1) Get the token IDs of each Path. We assume the decoding_graph + # is an acceptor, i.e., lattice is also an acceptor + tokens_shape = nbest.fsa.arcs.shape().remove_axis(1) # [path][arc] + + tokens = k2.RaggedTensor(tokens_shape, nbest.fsa.labels.contiguous()) + tokens = tokens.remove_values_leq(0) # remove -1 and 0 + + token_list: List[List[int]] = tokens.tolist() + word_list: List[List[str]] = sp.decode(token_list) + + assert isinstance(oov_word, str), oov_word + assert oov_word in word_table, oov_word + oov_word_id = word_table[oov_word] + + word_ids_list: List[List[int]] = [] + + for words in word_list: + this_word_ids = [] + for w in words.split(): + if w in word_table: + this_word_ids.append(word_table[w]) + else: + this_word_ids.append(oov_word_id) + word_ids_list.append(this_word_ids) + + word_fsas = k2.linear_fsa(word_ids_list, device=lattice.device) + word_fsas_with_self_loops = k2.add_epsilon_self_loops(word_fsas) + + num_unique_paths = len(word_ids_list) + + b_to_a_map = torch.zeros( + num_unique_paths, + dtype=torch.int32, + device=lattice.device, + ) + + rescored_word_fsas = k2.intersect_device( + a_fsas=G, + b_fsas=word_fsas_with_self_loops, + b_to_a_map=b_to_a_map, + sorted_match_a=True, + ret_arc_maps=False, + ) + + rescored_word_fsas = k2.remove_epsilon_self_loops(rescored_word_fsas) + rescored_word_fsas = k2.top_sort(k2.connect(rescored_word_fsas)) + ngram_lm_scores = rescored_word_fsas.get_tot_scores( + use_double_scores=True, + log_semiring=False, + ) + + ans: Dict[str, Union[List[List[int]], DecodingResults]] = {} + for s in ngram_lm_scale_list: + key = f"ngram_lm_scale_{s}" + tot_scores = am_scores.values + s * ngram_lm_scores + ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores) + max_indexes = ragged_tot_scores.argmax() + best_path = k2.index_fsa(nbest.fsa, max_indexes) + + if not return_timestamps: + ans[key] = get_texts(best_path) + else: + ans[key] = get_texts_with_timestamp(best_path) + + return ans + + +def fast_beam_search_with_nbest_rnn_rescoring( + model: nn.Module, + decoding_graph: k2.Fsa, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + beam: float, + max_states: int, + max_contexts: int, + ngram_lm_scale_list: List[float], + num_paths: int, + G: k2.Fsa, + sp: spm.SentencePieceProcessor, + word_table: k2.SymbolTable, + rnn_lm_model: torch.nn.Module, + rnn_lm_scale_list: List[float], + oov_word: str = "", + use_double_scores: bool = True, + nbest_scale: float = 0.5, + temperature: float = 1.0, + return_timestamps: bool = False, +) -> Dict[str, Union[List[List[int]], DecodingResults]]: + """It limits the maximum number of symbols per frame to 1. + A lattice is first obtained using fast beam search, num_path are selected + and rescored using a given language model and a rnn-lm. + The shortest path within the lattice is used as the final output. + + Args: + model: + An instance of `Transducer`. + decoding_graph: + Decoding graph used for decoding, may be a TrivialGraph or a LG. + encoder_out: + A tensor of shape (N, T, C) from the encoder. + encoder_out_lens: + A tensor of shape (N,) containing the number of frames in `encoder_out` + before padding. + beam: + Beam value, similar to the beam used in Kaldi. + max_states: + Max states per stream per frame. + max_contexts: + Max contexts pre stream per frame. + ngram_lm_scale_list: + A list of floats representing LM score scales. + num_paths: + Number of paths to extract from the decoded lattice. + G: + An FsaVec containing only a single FSA. It is an n-gram LM. + sp: + The BPE model. + word_table: + The word symbol table. + rnn_lm_model: + A rnn-lm model used for LM rescoring + rnn_lm_scale_list: + A list of floats representing RNN score scales. + oov_word: + OOV words are replaced with this word. + use_double_scores: + True to use double precision for computation. False to use + single precision. + nbest_scale: + It's the scale applied to the lattice.scores. A smaller value + yields more unique paths. + temperature: + Softmax temperature. + return_timestamps: + Whether to return timestamps. + Returns: + Return the decoded result in a dict, where the key has the form + 'ngram_lm_scale_xx' and the value is the decoded results + optionally with timestamps. `xx` is the ngram LM scale value + used during decoding, i.e., 0.1. + """ + lattice = fast_beam_search( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=beam, + max_states=max_states, + max_contexts=max_contexts, + temperature=temperature, + ) + + nbest = Nbest.from_lattice( + lattice=lattice, + num_paths=num_paths, + use_double_scores=use_double_scores, + nbest_scale=nbest_scale, + ) + # at this point, nbest.fsa.scores are all zeros. + + nbest = nbest.intersect(lattice) + # Now nbest.fsa.scores contains acoustic scores + + am_scores = nbest.tot_scores() + + # Now we need to compute the LM scores of each path. + # (1) Get the token IDs of each Path. We assume the decoding_graph + # is an acceptor, i.e., lattice is also an acceptor + tokens_shape = nbest.fsa.arcs.shape().remove_axis(1) # [path][arc] + + tokens = k2.RaggedTensor(tokens_shape, nbest.fsa.labels.contiguous()) + tokens = tokens.remove_values_leq(0) # remove -1 and 0 + + token_list: List[List[int]] = tokens.tolist() + word_list: List[List[str]] = sp.decode(token_list) + + assert isinstance(oov_word, str), oov_word + assert oov_word in word_table, oov_word + oov_word_id = word_table[oov_word] + + word_ids_list: List[List[int]] = [] + + for words in word_list: + this_word_ids = [] + for w in words.split(): + if w in word_table: + this_word_ids.append(word_table[w]) + else: + this_word_ids.append(oov_word_id) + word_ids_list.append(this_word_ids) + + word_fsas = k2.linear_fsa(word_ids_list, device=lattice.device) + word_fsas_with_self_loops = k2.add_epsilon_self_loops(word_fsas) + + num_unique_paths = len(word_ids_list) + + b_to_a_map = torch.zeros( + num_unique_paths, + dtype=torch.int32, + device=lattice.device, + ) + + rescored_word_fsas = k2.intersect_device( + a_fsas=G, + b_fsas=word_fsas_with_self_loops, + b_to_a_map=b_to_a_map, + sorted_match_a=True, + ret_arc_maps=False, + ) + + rescored_word_fsas = k2.remove_epsilon_self_loops(rescored_word_fsas) + rescored_word_fsas = k2.top_sort(k2.connect(rescored_word_fsas)) + ngram_lm_scores = rescored_word_fsas.get_tot_scores( + use_double_scores=True, + log_semiring=False, + ) + + # Now RNN-LM + blank_id = model.decoder.blank_id + sos_id = sp.piece_to_id("sos_id") + eos_id = sp.piece_to_id("eos_id") + + sos_tokens = add_sos(tokens, sos_id) + tokens_eos = add_eos(tokens, eos_id) + sos_tokens_row_splits = sos_tokens.shape.row_splits(1) + sentence_lengths = sos_tokens_row_splits[1:] - sos_tokens_row_splits[:-1] + + x_tokens = sos_tokens.pad(mode="constant", padding_value=blank_id) + y_tokens = tokens_eos.pad(mode="constant", padding_value=blank_id) + + x_tokens = x_tokens.to(torch.int64) + y_tokens = y_tokens.to(torch.int64) + sentence_lengths = sentence_lengths.to(torch.int64) + + rnn_lm_nll = rnn_lm_model(x=x_tokens, y=y_tokens, lengths=sentence_lengths) + assert rnn_lm_nll.ndim == 2 + assert rnn_lm_nll.shape[0] == len(token_list) + rnn_lm_scores = -1 * rnn_lm_nll.sum(dim=1) + + ans: Dict[str, List[List[int]]] = {} + for n_scale in ngram_lm_scale_list: + for rnn_scale in rnn_lm_scale_list: + key = f"ngram_lm_scale_{n_scale}_rnn_lm_scale_{rnn_scale}" + tot_scores = ( + am_scores.values + n_scale * ngram_lm_scores + rnn_scale * rnn_lm_scores + ) + ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores) + max_indexes = ragged_tot_scores.argmax() + best_path = k2.index_fsa(nbest.fsa, max_indexes) + + if not return_timestamps: + ans[key] = get_texts(best_path) + else: + ans[key] = get_texts_with_timestamp(best_path) + + return ans + + +def modified_beam_search_ngram_rescoring( + model: nn.Module, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + ngram_lm: NgramLm, + ngram_lm_scale: float, + beam: int = 4, + temperature: float = 1.0, +) -> List[List[int]]: + """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded. + + Args: + model: + The transducer model. + encoder_out: + Output from the encoder. Its shape is (N, T, C). + encoder_out_lens: + A 1-D tensor of shape (N,), containing number of valid frames in + encoder_out before padding. + beam: + Number of active paths during the beam search. + temperature: + Softmax temperature. + Returns: + Return a list-of-list of token IDs. ans[i] is the decoding results + for the i-th utterance. + """ + assert encoder_out.ndim == 3, encoder_out.shape + assert encoder_out.size(0) >= 1, encoder_out.size(0) + + packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( + input=encoder_out, + lengths=encoder_out_lens.cpu(), + batch_first=True, + enforce_sorted=False, + ) + + blank_id = model.decoder.blank_id + unk_id = getattr(model, "unk_id", blank_id) + context_size = model.decoder.context_size + device = next(model.parameters()).device + lm_scale = ngram_lm_scale + + batch_size_list = packed_encoder_out.batch_sizes.tolist() + N = encoder_out.size(0) + assert torch.all(encoder_out_lens > 0), encoder_out_lens + assert N == batch_size_list[0], (N, batch_size_list) + + B = [HypothesisList() for _ in range(N)] + for i in range(N): + B[i].add( + Hypothesis( + ys=[-1] * (context_size - 1) + [blank_id], + log_prob=torch.zeros(1, dtype=torch.float32, device=device), + state_cost=NgramLmStateCost(ngram_lm), + ) + ) + + encoder_out = model.joiner.encoder_proj(packed_encoder_out.data) + + offset = 0 + finalized_B = [] + for batch_size in batch_size_list: + start = offset + end = offset + batch_size + current_encoder_out = encoder_out.data[start:end] + current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1) + # current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim) + offset = end + + finalized_B = B[batch_size:] + finalized_B + B = B[:batch_size] + + hyps_shape = get_hyps_shape(B).to(device) + + A = [list(b) for b in B] + B = [HypothesisList() for _ in range(batch_size)] + + ys_log_probs = torch.cat( + [ + hyp.log_prob.reshape(1, 1) + hyp.state_cost.lm_score * lm_scale + for hyps in A + for hyp in hyps + ] + ) # (num_hyps, 1) + + decoder_input = torch.tensor( + [hyp.ys[-context_size:] for hyps in A for hyp in hyps], + device=device, + dtype=torch.int64, + ) # (num_hyps, context_size) + + decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) + decoder_out = model.joiner.decoder_proj(decoder_out) + # decoder_out is of shape (num_hyps, 1, 1, joiner_dim) + + # Note: For torch 1.7.1 and below, it requires a torch.int64 tensor + # as index, so we use `to(torch.int64)` below. + current_encoder_out = torch.index_select( + current_encoder_out, + dim=0, + index=hyps_shape.row_ids(1).to(torch.int64), + ) # (num_hyps, 1, 1, encoder_out_dim) + + logits = model.joiner( + current_encoder_out, + decoder_out, + project_input=False, + ) # (num_hyps, 1, 1, vocab_size) + + logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size) + + log_probs = (logits / temperature).log_softmax(dim=-1) # (num_hyps, vocab_size) + + log_probs.add_(ys_log_probs) + vocab_size = log_probs.size(-1) + log_probs = log_probs.reshape(-1) + + row_splits = hyps_shape.row_splits(1) * vocab_size + log_probs_shape = k2.ragged.create_ragged_shape2( + row_splits=row_splits, cached_tot_size=log_probs.numel() + ) + ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) + + for i in range(batch_size): + topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + topk_hyp_indexes = (topk_indexes // vocab_size).tolist() + topk_token_indexes = (topk_indexes % vocab_size).tolist() + + for k in range(len(topk_hyp_indexes)): + hyp_idx = topk_hyp_indexes[k] + hyp = A[i][hyp_idx] + + new_ys = hyp.ys[:] + new_token = topk_token_indexes[k] + if new_token not in (blank_id, unk_id): + new_ys.append(new_token) + state_cost = hyp.state_cost.forward_one_step(new_token) + else: + state_cost = hyp.state_cost + + # We only keep AM scores in new_hyp.log_prob + new_log_prob = topk_log_probs[k] - hyp.state_cost.lm_score * lm_scale + + new_hyp = Hypothesis( + ys=new_ys, log_prob=new_log_prob, state_cost=state_cost + ) + B[i].add(new_hyp) + + B = B + finalized_B + best_hyps = [b.get_most_probable(length_norm=True) for b in B] + + sorted_ans = [h.ys[context_size:] for h in best_hyps] + ans = [] + unsorted_indices = packed_encoder_out.unsorted_indices.tolist() + for i in range(N): + ans.append(sorted_ans[unsorted_indices[i]]) + + return ans + + +def modified_beam_search_LODR( + model: nn.Module, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + LODR_lm: NgramLm, + LODR_lm_scale: float, + LM: LmScorer, + beam: int = 4, + context_graph: Optional[ContextGraph] = None, +) -> List[List[int]]: + """This function implements LODR (https://arxiv.org/abs/2203.16776) with + `modified_beam_search`. It uses a bi-gram language model as the estimate + of the internal language model and subtracts its score during shallow fusion + with an external language model. This implementation uses a RNNLM as the + external language model. + + Args: + model (Transducer): + The transducer model + encoder_out (torch.Tensor): + Encoder output in (N,T,C) + encoder_out_lens (torch.Tensor): + A 1-D tensor of shape (N,), containing the number of + valid frames in encoder_out before padding. + LODR_lm: + A low order n-gram LM, whose score will be subtracted during shallow fusion + LODR_lm_scale: + The scale of the LODR_lm + LM: + A neural net LM, e.g an RNNLM or transformer LM + beam (int, optional): + Beam size. Defaults to 4. + + Returns: + Return a list-of-list of token IDs. ans[i] is the decoding results + for the i-th utterance. + + """ + assert encoder_out.ndim == 3, encoder_out.shape + assert encoder_out.size(0) >= 1, encoder_out.size(0) + assert LM is not None + lm_scale = LM.lm_scale + + packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( + input=encoder_out, + lengths=encoder_out_lens.cpu(), + batch_first=True, + enforce_sorted=False, + ) + + blank_id = model.decoder.blank_id + sos_id = getattr(LM, "sos_id", 1) + unk_id = getattr(model, "unk_id", blank_id) + context_size = model.decoder.context_size + device = next(model.parameters()).device + + batch_size_list = packed_encoder_out.batch_sizes.tolist() + N = encoder_out.size(0) + assert torch.all(encoder_out_lens > 0), encoder_out_lens + assert N == batch_size_list[0], (N, batch_size_list) + + # get initial lm score and lm state by scoring the "sos" token + sos_token = torch.tensor([[sos_id]]).to(torch.int64).to(device) + lens = torch.tensor([1]).to(device) + init_score, init_states = LM.score_token(sos_token, lens) + + B = [HypothesisList() for _ in range(N)] + for i in range(N): + B[i].add( + Hypothesis( + ys=[-1] * (context_size - 1) + [blank_id], + log_prob=torch.zeros(1, dtype=torch.float32, device=device), + state=init_states, # state of the NN LM + lm_score=init_score.reshape(-1), + state_cost=NgramLmStateCost( + LODR_lm + ), # state of the source domain ngram + context_state=None if context_graph is None else context_graph.root, + ) + ) + + encoder_out = model.joiner.encoder_proj(packed_encoder_out.data) + + offset = 0 + finalized_B = [] + for batch_size in batch_size_list: + start = offset + end = offset + batch_size + current_encoder_out = encoder_out.data[start:end] # get batch + current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1) + # current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim) + offset = end + + finalized_B = B[batch_size:] + finalized_B + B = B[:batch_size] + + hyps_shape = get_hyps_shape(B).to(device) + + A = [list(b) for b in B] + B = [HypothesisList() for _ in range(batch_size)] + + ys_log_probs = torch.cat( + [hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps] + ) + + decoder_input = torch.tensor( + [hyp.ys[-context_size:] for hyps in A for hyp in hyps], + device=device, + dtype=torch.int64, + ) # (num_hyps, context_size) + + decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) + decoder_out = model.joiner.decoder_proj(decoder_out) + + current_encoder_out = torch.index_select( + current_encoder_out, + dim=0, + index=hyps_shape.row_ids(1).to(torch.int64), + ) # (num_hyps, 1, 1, encoder_out_dim) + + logits = model.joiner( + current_encoder_out, + decoder_out, + project_input=False, + ) # (num_hyps, 1, 1, vocab_size) + + logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size) + + log_probs = logits.log_softmax(dim=-1) # (num_hyps, vocab_size) + + log_probs.add_(ys_log_probs) + + vocab_size = log_probs.size(-1) + + log_probs = log_probs.reshape(-1) + + row_splits = hyps_shape.row_splits(1) * vocab_size + log_probs_shape = k2.ragged.create_ragged_shape2( + row_splits=row_splits, cached_tot_size=log_probs.numel() + ) + ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) + """ + for all hyps with a non-blank new token, score this token. + It is a little confusing here because this for-loop + looks very similar to the one below. Here, we go through all + top-k tokens and only add the non-blanks ones to the token_list. + LM will score those tokens given the LM states. Note that + the variable `scores` is the LM score after seeing the new + non-blank token. + """ + token_list = [] + hs = [] + cs = [] + for i in range(batch_size): + topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + topk_hyp_indexes = (topk_indexes // vocab_size).tolist() + topk_token_indexes = (topk_indexes % vocab_size).tolist() + for k in range(len(topk_hyp_indexes)): + hyp_idx = topk_hyp_indexes[k] + hyp = A[i][hyp_idx] + + new_token = topk_token_indexes[k] + if new_token not in (blank_id, unk_id): + if LM.lm_type == "rnn": + token_list.append([new_token]) + # store the LSTM states + hs.append(hyp.state[0]) + cs.append(hyp.state[1]) + else: + # for transformer LM + token_list.append( + [sos_id] + hyp.ys[context_size:] + [new_token] + ) + + # forward NN LM to get new states and scores + if len(token_list) != 0: + x_lens = torch.tensor([len(tokens) for tokens in token_list]).to(device) + if LM.lm_type == "rnn": + tokens_to_score = ( + torch.tensor(token_list).to(torch.int64).to(device).reshape(-1, 1) + ) + hs = torch.cat(hs, dim=1).to(device) + cs = torch.cat(cs, dim=1).to(device) + state = (hs, cs) + else: + # for transformer LM + tokens_list = [torch.tensor(tokens) for tokens in token_list] + tokens_to_score = ( + torch.nn.utils.rnn.pad_sequence( + tokens_list, batch_first=True, padding_value=0.0 + ) + .to(device) + .to(torch.int64) + ) + + state = None + + scores, lm_states = LM.score_token(tokens_to_score, x_lens, state) + + count = 0 # index, used to locate score and lm states + for i in range(batch_size): + topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + topk_hyp_indexes = (topk_indexes // vocab_size).tolist() + topk_token_indexes = (topk_indexes % vocab_size).tolist() + + for k in range(len(topk_hyp_indexes)): + hyp_idx = topk_hyp_indexes[k] + hyp = A[i][hyp_idx] + + ys = hyp.ys[:] + + # current score of hyp + lm_score = hyp.lm_score + state = hyp.state + + hyp_log_prob = topk_log_probs[k] # get score of current hyp + new_token = topk_token_indexes[k] + + context_score = 0 + new_context_state = None if context_graph is None else hyp.context_state + if new_token not in (blank_id, unk_id): + if context_graph is not None: + ( + context_score, + new_context_state, + ) = context_graph.forward_one_step(hyp.context_state, new_token) + + ys.append(new_token) + state_cost = hyp.state_cost.forward_one_step(new_token) + + # calculate the score of the latest token + current_ngram_score = state_cost.lm_score - hyp.state_cost.lm_score + + assert current_ngram_score <= 0.0, ( + state_cost.lm_score, + hyp.state_cost.lm_score, + ) + # score = score + TDLM_score - LODR_score + # LODR_LM_scale should be a negative number here + hyp_log_prob += ( + lm_score[new_token] * lm_scale + + LODR_lm_scale * current_ngram_score + + context_score + ) # add the lm score + + lm_score = scores[count] + if LM.lm_type == "rnn": + state = ( + lm_states[0][:, count, :].unsqueeze(1), + lm_states[1][:, count, :].unsqueeze(1), + ) + count += 1 + else: + state_cost = hyp.state_cost + + new_hyp = Hypothesis( + ys=ys, + log_prob=hyp_log_prob, + state=state, + lm_score=lm_score, + state_cost=state_cost, + context_state=new_context_state, + ) + B[i].add(new_hyp) + + B = B + finalized_B + + # finalize context_state, if the matched contexts do not reach final state + # we need to add the score on the corresponding backoff arc + if context_graph is not None: + finalized_B = [HypothesisList() for _ in range(len(B))] + for i, hyps in enumerate(B): + for hyp in list(hyps): + context_score, new_context_state = context_graph.finalize( + hyp.context_state + ) + finalized_B[i].add( + Hypothesis( + ys=hyp.ys, + log_prob=hyp.log_prob + context_score, + timestamp=hyp.timestamp, + context_state=new_context_state, + ) + ) + B = finalized_B + + best_hyps = [b.get_most_probable(length_norm=True) for b in B] + + sorted_ans = [h.ys[context_size:] for h in best_hyps] + ans = [] + unsorted_indices = packed_encoder_out.unsorted_indices.tolist() + for i in range(N): + ans.append(sorted_ans[unsorted_indices[i]]) + + return ans + + +def modified_beam_search_lm_shallow_fusion( + model: nn.Module, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + LM: LmScorer, + beam: int = 4, + return_timestamps: bool = False, +) -> List[List[int]]: + """Modified_beam_search + NN LM shallow fusion + + Args: + model (Transducer): + The transducer model + encoder_out (torch.Tensor): + Encoder output in (N,T,C) + encoder_out_lens (torch.Tensor): + A 1-D tensor of shape (N,), containing the number of + valid frames in encoder_out before padding. + sp: + Sentence piece generator. + LM (LmScorer): + A neural net LM, e.g RNN or Transformer + beam (int, optional): + Beam size. Defaults to 4. + + Returns: + Return a list-of-list of token IDs. ans[i] is the decoding results + for the i-th utterance. + """ + assert encoder_out.ndim == 3, encoder_out.shape + assert encoder_out.size(0) >= 1, encoder_out.size(0) + assert LM is not None + lm_scale = LM.lm_scale + + packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( + input=encoder_out, + lengths=encoder_out_lens.cpu(), + batch_first=True, + enforce_sorted=False, + ) + + blank_id = model.decoder.blank_id + sos_id = getattr(LM, "sos_id", 1) + unk_id = getattr(model, "unk_id", blank_id) + context_size = model.decoder.context_size + device = next(model.parameters()).device + + batch_size_list = packed_encoder_out.batch_sizes.tolist() + N = encoder_out.size(0) + assert torch.all(encoder_out_lens > 0), encoder_out_lens + assert N == batch_size_list[0], (N, batch_size_list) + + # get initial lm score and lm state by scoring the "sos" token + sos_token = torch.tensor([[sos_id]]).to(torch.int64).to(device) + lens = torch.tensor([1]).to(device) + init_score, init_states = LM.score_token(sos_token, lens) + + B = [HypothesisList() for _ in range(N)] + for i in range(N): + B[i].add( + Hypothesis( + ys=[-1] * (context_size - 1) + [blank_id], + log_prob=torch.zeros(1, dtype=torch.float32, device=device), + state=init_states, + lm_score=init_score.reshape(-1), + timestamp=[], + ) + ) + + encoder_out = model.joiner.encoder_proj(packed_encoder_out.data) + + offset = 0 + finalized_B = [] + for t, batch_size in enumerate(batch_size_list): + start = offset + end = offset + batch_size + current_encoder_out = encoder_out.data[start:end] # get batch + current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1) + # current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim) + offset = end + + finalized_B = B[batch_size:] + finalized_B + B = B[:batch_size] + + hyps_shape = get_hyps_shape(B).to(device) + + A = [list(b) for b in B] + B = [HypothesisList() for _ in range(batch_size)] + + ys_log_probs = torch.cat( + [hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps] + ) + + lm_scores = torch.cat( + [hyp.lm_score.reshape(1, -1) for hyps in A for hyp in hyps] + ) + + decoder_input = torch.tensor( + [hyp.ys[-context_size:] for hyps in A for hyp in hyps], + device=device, + dtype=torch.int64, + ) # (num_hyps, context_size) + + decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) + decoder_out = model.joiner.decoder_proj(decoder_out) + + current_encoder_out = torch.index_select( + current_encoder_out, + dim=0, + index=hyps_shape.row_ids(1).to(torch.int64), + ) # (num_hyps, 1, 1, encoder_out_dim) + + logits = model.joiner( + current_encoder_out, + decoder_out, + project_input=False, + ) # (num_hyps, 1, 1, vocab_size) + + logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size) + + log_probs = logits.log_softmax(dim=-1) # (num_hyps, vocab_size) + + log_probs.add_(ys_log_probs) + + vocab_size = log_probs.size(-1) + + log_probs = log_probs.reshape(-1) + + row_splits = hyps_shape.row_splits(1) * vocab_size + log_probs_shape = k2.ragged.create_ragged_shape2( + row_splits=row_splits, cached_tot_size=log_probs.numel() + ) + ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) + """ + for all hyps with a non-blank new token, score this token. + It is a little confusing here because this for-loop + looks very similar to the one below. Here, we go through all + top-k tokens and only add the non-blanks ones to the token_list. + `LM` will score those tokens given the LM states. Note that + the variable `scores` is the LM score after seeing the new + non-blank token. + """ + token_list = [] # a list of list + hs = [] + cs = [] + for i in range(batch_size): + topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + topk_hyp_indexes = (topk_indexes // vocab_size).tolist() + topk_token_indexes = (topk_indexes % vocab_size).tolist() + for k in range(len(topk_hyp_indexes)): + hyp_idx = topk_hyp_indexes[k] + hyp = A[i][hyp_idx] + + new_token = topk_token_indexes[k] + if new_token not in (blank_id, unk_id): + if LM.lm_type == "rnn": + token_list.append([new_token]) + # store the LSTM states + hs.append(hyp.state[0]) + cs.append(hyp.state[1]) + else: + # for transformer LM + token_list.append( + [sos_id] + hyp.ys[context_size:] + [new_token] + ) + + if len(token_list) != 0: + x_lens = torch.tensor([len(tokens) for tokens in token_list]).to(device) + if LM.lm_type == "rnn": + tokens_to_score = ( + torch.tensor(token_list).to(torch.int64).to(device).reshape(-1, 1) + ) + hs = torch.cat(hs, dim=1).to(device) + cs = torch.cat(cs, dim=1).to(device) + state = (hs, cs) + else: + # for transformer LM + tokens_list = [torch.tensor(tokens) for tokens in token_list] + tokens_to_score = ( + torch.nn.utils.rnn.pad_sequence( + tokens_list, batch_first=True, padding_value=0.0 + ) + .to(device) + .to(torch.int64) + ) + + state = None + + scores, lm_states = LM.score_token(tokens_to_score, x_lens, state) + + count = 0 # index, used to locate score and lm states + for i in range(batch_size): + topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + topk_hyp_indexes = (topk_indexes // vocab_size).tolist() + topk_token_indexes = (topk_indexes % vocab_size).tolist() + + for k in range(len(topk_hyp_indexes)): + hyp_idx = topk_hyp_indexes[k] + hyp = A[i][hyp_idx] + + ys = hyp.ys[:] + + lm_score = hyp.lm_score + state = hyp.state + + hyp_log_prob = topk_log_probs[k] # get score of current hyp + new_token = topk_token_indexes[k] + new_timestamp = hyp.timestamp[:] + if new_token not in (blank_id, unk_id): + ys.append(new_token) + new_timestamp.append(t) + + hyp_log_prob += lm_score[new_token] * lm_scale # add the lm score + + lm_score = scores[count] + if LM.lm_type == "rnn": + state = ( + lm_states[0][:, count, :].unsqueeze(1), + lm_states[1][:, count, :].unsqueeze(1), + ) + count += 1 + + new_hyp = Hypothesis( + ys=ys, + log_prob=hyp_log_prob, + state=state, + lm_score=lm_score, + timestamp=new_timestamp, + ) + B[i].add(new_hyp) + + B = B + finalized_B + best_hyps = [b.get_most_probable(length_norm=True) for b in B] + + sorted_ans = [h.ys[context_size:] for h in best_hyps] + sorted_timestamps = [h.timestamp for h in best_hyps] + ans = [] + ans_timestamps = [] + unsorted_indices = packed_encoder_out.unsorted_indices.tolist() + for i in range(N): + ans.append(sorted_ans[unsorted_indices[i]]) + ans_timestamps.append(sorted_timestamps[unsorted_indices[i]]) + + if not return_timestamps: + return ans + else: + return DecodingResults( + hyps=ans, + timestamps=ans_timestamps, + ) diff --git a/egs/aishell/ASR/zipformer_distinctK/__init__.py b/egs/aishell/ASR/zipformer_distinctK/__init__.py new file mode 120000 index 0000000000..87c684e002 --- /dev/null +++ b/egs/aishell/ASR/zipformer_distinctK/__init__.py @@ -0,0 +1 @@ +/Users/yifanyeung/Downloads/icefall/egs/aishell/ASR/zipformer/__init__.py \ No newline at end of file diff --git a/egs/aishell/ASR/zipformer_distinctK/asr_datamodule.py b/egs/aishell/ASR/zipformer_distinctK/asr_datamodule.py new file mode 120000 index 0000000000..ec900d29c7 --- /dev/null +++ b/egs/aishell/ASR/zipformer_distinctK/asr_datamodule.py @@ -0,0 +1 @@ +/Users/yifanyeung/Downloads/icefall/egs/aishell/ASR/tdnn_lstm_ctc/asr_datamodule.py \ No newline at end of file diff --git a/egs/aishell/ASR/zipformer_distinctK/beam_search.py b/egs/aishell/ASR/zipformer_distinctK/beam_search.py new file mode 120000 index 0000000000..88c1c3f5db --- /dev/null +++ b/egs/aishell/ASR/zipformer_distinctK/beam_search.py @@ -0,0 +1 @@ +/Users/yifanyeung/Downloads/icefall/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py \ No newline at end of file diff --git a/egs/aishell/ASR/zipformer_distinctK/decode.py b/egs/aishell/ASR/zipformer_distinctK/decode.py new file mode 100755 index 0000000000..1968904aea --- /dev/null +++ b/egs/aishell/ASR/zipformer_distinctK/decode.py @@ -0,0 +1,814 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao +# Mingshuang Luo) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: +(1) greedy search +./zipformer/decode.py \ + --epoch 35 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --lang-dir data/lang_char \ + --max-duration 600 \ + --decoding-method greedy_search + +(2) modified beam search +./zipformer/decode.py \ + --epoch 35 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --lang-dir data/lang_char \ + --max-duration 600 \ + --decoding-method modified_beam_search \ + --beam-size 4 + +(3) fast beam search (trivial_graph) +./zipformer/decode.py \ + --epoch 35 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --lang-dir data/lang_char \ + --max-duration 600 \ + --decoding-method fast_beam_search \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 + +(4) fast beam search (LG) +./zipformer/decode.py \ + --epoch 30 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --lang-dir data/lang_char \ + --max-duration 600 \ + --decoding-method fast_beam_search_LG \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 + +(5) fast beam search (nbest oracle WER) +./zipformer/decode.py \ + --epoch 35 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --lang-dir data/lang_char \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_oracle \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 +""" + + +import argparse +import logging +import math +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import torch +import torch.nn as nn +from asr_datamodule import AishellAsrDataModule +from beam_search import ( + beam_search, + fast_beam_search_nbest, + fast_beam_search_nbest_LG, + fast_beam_search_nbest_oracle, + fast_beam_search_one_best, + greedy_search, + greedy_search_batch, + modified_beam_search, +) +from lhotse.cut import Cut +from train import add_model_arguments, get_model, get_params + +from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + make_pad_mask, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--lang-dir", + type=Path, + default="data/lang_char", + help="The lang dir containing word table and LG graph", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - modified_beam_search + - fast_beam_search + - fast_beam_search_LG + - fast_beam_search_nbest_oracle + If you use fast_beam_search_LG, you have to specify + `--lang-dir`, which should contain `LG.pt`. + """, + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --decoding-method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=20.0, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search, + fast_beam_search, fast_beam_search_LG, + and fast_beam_search_nbest_oracle + """, + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=0.01, + help=""" + Used only when --decoding_method is fast_beam_search_LG. + It specifies the scale for n-gram LM scores. + """, + ) + + parser.add_argument( + "--ilme-scale", + type=float, + default=0.2, + help=""" + Used only when --decoding_method is fast_beam_search_LG. + It specifies the scale for the internal language model estimation. + """, + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=8, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search, fast_beam_search_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=64, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search, fast_beam_search_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. + Used only when --decoding_method is greedy_search""", + ) + + parser.add_argument( + "--num-paths", + type=int, + default=200, + help="""Number of paths for nbest decoding. + Used only when the decoding method is fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help="""Scale applied to lattice scores when computing nbest paths. + Used only when the decoding method is and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--blank-penalty", + type=float, + default=0.0, + help=""" + The penalty applied on blank symbol during decoding. + Note: It is a positive value that would be applied to logits like + this `logits[:, 0] -= blank_penalty` (suppose logits.shape is + [batch_size, vocab] and blank id is 0). + """, + ) + + add_model_arguments(parser) + + return parser + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + lexicon: Lexicon, + graph_compiler: CharCtcTrainingGraphCompiler, + batch: dict, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + + - key: It indicates the setting used for decoding. For example, + if greedy_search is used, it would be "greedy_search" + If beam search with a beam size of 7 is used, it would be + "beam_7" + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or LG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return the decoding result. See above description for the format of + the returned dict. + """ + device = next(model.parameters()).device + feature = batch["inputs"] + assert feature.ndim == 3 + + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + if params.causal: + # this seems to cause insertions at the end of the utterance if used with zipformer. + pad_len = 30 + feature_lens += pad_len + feature = torch.nn.functional.pad( + feature, + pad=(0, 0, 0, pad_len), + value=LOG_EPS, + ) + + x, x_lens = model.encoder_embed(feature, feature_lens) + + src_key_padding_mask = make_pad_mask(x_lens) + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + + encoder_out, encoder_out_lens = model.encoder(x, x_lens, src_key_padding_mask) + encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + + hyps = [] + + if params.decoding_method == "fast_beam_search": + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + blank_penalty=params.blank_penalty, + ) + for i in range(encoder_out.size(0)): + hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) + elif params.decoding_method == "fast_beam_search_LG": + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + blank_penalty=params.blank_penalty, + ilme_scale=params.ilme_scale, + ) + for hyp in hyp_tokens: + sentence = "".join([lexicon.word_table[i] for i in hyp]) + hyps.append(list(sentence)) + elif params.decoding_method == "fast_beam_search_nbest_oracle": + hyp_tokens = fast_beam_search_nbest_oracle( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + ref_texts=graph_compiler.texts_to_ids(supervisions["text"]), + nbest_scale=params.nbest_scale, + blank_penalty=params.blank_penalty, + ) + for i in range(encoder_out.size(0)): + hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + blank_penalty=params.blank_penalty, + ) + for i in range(encoder_out.size(0)): + hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) + elif params.decoding_method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + blank_penalty=params.blank_penalty, + beam=params.beam_size, + ) + for i in range(encoder_out.size(0)): + hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) + else: + batch_size = encoder_out.size(0) + + for i in range(batch_size): + # fmt: off + encoder_out_i = encoder_out[i:i + 1, :encoder_out_lens[i]] + # fmt: on + if params.decoding_method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + blank_penalty=params.blank_penalty, + ) + elif params.decoding_method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + blank_penalty=params.blank_penalty, + ) + else: + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) + hyps.append([lexicon.token_table[idx] for idx in hyp]) + + key = f"blank_penalty_{params.blank_penalty}" + if params.decoding_method == "greedy_search": + return {"greedy_search_" + key: hyps} + elif "fast_beam_search" in params.decoding_method: + key += f"_beam_{params.beam}_" + key += f"max_contexts_{params.max_contexts}_" + key += f"max_states_{params.max_states}" + if "nbest" in params.decoding_method: + key += f"_num_paths_{params.num_paths}_" + key += f"nbest_scale_{params.nbest_scale}" + if "LG" in params.decoding_method: + key += f"_ilme_scale_{params.ilme_scale}" + key += f"_ngram_lm_scale_{params.ngram_lm_scale}" + + return {key: hyps} + else: + return {f"beam_size_{params.beam_size}_" + key: hyps} + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + lexicon: Lexicon, + graph_compiler: CharCtcTrainingGraphCompiler, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or LG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + if params.decoding_method == "greedy_search": + log_interval = 50 + else: + log_interval = 20 + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + texts = [list("".join(text.split())) for text in texts] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + hyps_dict = decode_one_batch( + params=params, + model=model, + lexicon=lexicon, + graph_compiler=graph_compiler, + decoding_graph=decoding_graph, + batch=batch, + ) + + for name, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + this_batch.append((cut_id, ref_text, hyp_words)) + + results[name].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[List[int], List[int]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = ( + params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + ) + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. + errs_filename = ( + params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=True + ) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = ( + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + AishellAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + assert params.decoding_method in ( + "greedy_search", + "beam_search", + "modified_beam_search", + "fast_beam_search", + "fast_beam_search_LG", + "fast_beam_search_nbest_oracle", + ) + params.res_dir = params.exp_dir / params.decoding_method + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + if params.causal: + assert ( + "," not in params.chunk_size + ), "chunk_size should be one value in decoding." + assert ( + "," not in params.left_context_frames + ), "left_context_frames should be one value in decoding." + params.suffix += f"-chunk-{params.chunk_size}" + params.suffix += f"-left-context-{params.left_context_frames}" + + if "fast_beam_search" in params.decoding_method: + params.suffix += f"-beam-{params.beam}" + params.suffix += f"-max-contexts-{params.max_contexts}" + params.suffix += f"-max-states-{params.max_states}" + if "nbest" in params.decoding_method: + params.suffix += f"-nbest-scale-{params.nbest_scale}" + params.suffix += f"-num-paths-{params.num_paths}" + if "LG" in params.decoding_method: + params.suffix += f"_ilme_scale_{params.ilme_scale}" + params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" + elif "beam_search" in params.decoding_method: + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + else: + params.suffix += f"-context-{params.context_size}" + params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" + params.suffix += f"-blank-penalty-{params.blank_penalty}" + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + lexicon = Lexicon(params.lang_dir) + params.blank_id = lexicon.token_table[""] + params.vocab_size = max(lexicon.tokens) + 1 + + graph_compiler = CharCtcTrainingGraphCompiler( + lexicon=lexicon, + device=device, + ) + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + + if "fast_beam_search" in params.decoding_method: + if "LG" in params.decoding_method: + lexicon = Lexicon(params.lang_dir) + lg_filename = params.lang_dir / "LG.pt" + logging.info(f"Loading {lg_filename}") + decoding_graph = k2.Fsa.from_dict( + torch.load(lg_filename, map_location=device) + ) + decoding_graph.scores *= params.ngram_lm_scale + else: + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + else: + decoding_graph = None + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # we need cut ids to display recognition results. + args.return_cuts = True + aishell = AishellAsrDataModule(args) + + def remove_short_utt(c: Cut): + T = ((c.num_frames - 7) // 2 + 1) // 2 + if T <= 0: + logging.warning( + f"Exclude cut with ID {c.id} from decoding, num_frames : {c.num_frames}." + ) + return T > 0 + + dev_cuts = aishell.valid_cuts() + dev_cuts = dev_cuts.filter(remove_short_utt) + dev_dl = aishell.valid_dataloaders(dev_cuts) + + test_cuts = aishell.test_cuts() + test_cuts = test_cuts.filter(remove_short_utt) + test_dl = aishell.test_dataloaders(test_cuts) + + test_sets = ["dev", "test"] + test_dls = [dev_dl, test_dl] + + for test_set, test_dl in zip(test_sets, test_dls): + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + lexicon=lexicon, + graph_compiler=graph_compiler, + decoding_graph=decoding_graph, + ) + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/aishell/ASR/zipformer_distinctK/decode_stream.py b/egs/aishell/ASR/zipformer_distinctK/decode_stream.py new file mode 120000 index 0000000000..4069916e80 --- /dev/null +++ b/egs/aishell/ASR/zipformer_distinctK/decode_stream.py @@ -0,0 +1 @@ +/Users/yifanyeung/Downloads/icefall/egs/librispeech/ASR/zipformer/decode_stream.py \ No newline at end of file diff --git a/egs/aishell/ASR/zipformer_distinctK/decoder.py b/egs/aishell/ASR/zipformer_distinctK/decoder.py new file mode 120000 index 0000000000..7d3b7e87a6 --- /dev/null +++ b/egs/aishell/ASR/zipformer_distinctK/decoder.py @@ -0,0 +1 @@ +/Users/yifanyeung/Downloads/icefall/egs/librispeech/ASR/zipformer/decoder.py \ No newline at end of file diff --git a/egs/aishell/ASR/zipformer_distinctK/encoder_interface.py b/egs/aishell/ASR/zipformer_distinctK/encoder_interface.py new file mode 120000 index 0000000000..051ceecbd0 --- /dev/null +++ b/egs/aishell/ASR/zipformer_distinctK/encoder_interface.py @@ -0,0 +1 @@ +/Users/yifanyeung/Downloads/icefall/egs/librispeech/ASR/transducer_stateless/encoder_interface.py \ No newline at end of file diff --git a/egs/aishell/ASR/zipformer_distinctK/export-onnx-streaming.py b/egs/aishell/ASR/zipformer_distinctK/export-onnx-streaming.py new file mode 120000 index 0000000000..7e7c35494f --- /dev/null +++ b/egs/aishell/ASR/zipformer_distinctK/export-onnx-streaming.py @@ -0,0 +1 @@ +/Users/yifanyeung/Downloads/icefall/egs/librispeech/ASR/zipformer/export-onnx-streaming.py \ No newline at end of file diff --git a/egs/aishell/ASR/zipformer_distinctK/export-onnx.py b/egs/aishell/ASR/zipformer_distinctK/export-onnx.py new file mode 120000 index 0000000000..9b41d235fc --- /dev/null +++ b/egs/aishell/ASR/zipformer_distinctK/export-onnx.py @@ -0,0 +1 @@ +/Users/yifanyeung/Downloads/icefall/egs/librispeech/ASR/zipformer/export-onnx.py \ No newline at end of file diff --git a/egs/aishell/ASR/zipformer_distinctK/export.py b/egs/aishell/ASR/zipformer_distinctK/export.py new file mode 120000 index 0000000000..5d07dcc848 --- /dev/null +++ b/egs/aishell/ASR/zipformer_distinctK/export.py @@ -0,0 +1 @@ +/Users/yifanyeung/Downloads/icefall/egs/librispeech/ASR/zipformer/export.py \ No newline at end of file diff --git a/egs/aishell/ASR/zipformer_distinctK/jit_pretrained.py b/egs/aishell/ASR/zipformer_distinctK/jit_pretrained.py new file mode 120000 index 0000000000..bb3c9e4c5f --- /dev/null +++ b/egs/aishell/ASR/zipformer_distinctK/jit_pretrained.py @@ -0,0 +1 @@ +/Users/yifanyeung/Downloads/icefall/egs/librispeech/ASR/zipformer/jit_pretrained.py \ No newline at end of file diff --git a/egs/aishell/ASR/zipformer_distinctK/jit_pretrained_streaming.py b/egs/aishell/ASR/zipformer_distinctK/jit_pretrained_streaming.py new file mode 120000 index 0000000000..ad64abb15c --- /dev/null +++ b/egs/aishell/ASR/zipformer_distinctK/jit_pretrained_streaming.py @@ -0,0 +1 @@ +/Users/yifanyeung/Downloads/icefall/egs/librispeech/ASR/zipformer/jit_pretrained_streaming.py \ No newline at end of file diff --git a/egs/aishell/ASR/zipformer_distinctK/joiner.py b/egs/aishell/ASR/zipformer_distinctK/joiner.py new file mode 120000 index 0000000000..737aaa7b1a --- /dev/null +++ b/egs/aishell/ASR/zipformer_distinctK/joiner.py @@ -0,0 +1 @@ +/Users/yifanyeung/Downloads/icefall/egs/librispeech/ASR/zipformer/joiner.py \ No newline at end of file diff --git a/egs/aishell/ASR/zipformer_distinctK/model.py b/egs/aishell/ASR/zipformer_distinctK/model.py new file mode 120000 index 0000000000..d733d7f10c --- /dev/null +++ b/egs/aishell/ASR/zipformer_distinctK/model.py @@ -0,0 +1 @@ +/Users/yifanyeung/Downloads/icefall/egs/librispeech/ASR/zipformer/model.py \ No newline at end of file diff --git a/egs/aishell/ASR/zipformer_distinctK/onnx_check.py b/egs/aishell/ASR/zipformer_distinctK/onnx_check.py new file mode 120000 index 0000000000..0c6c54bba7 --- /dev/null +++ b/egs/aishell/ASR/zipformer_distinctK/onnx_check.py @@ -0,0 +1 @@ +/Users/yifanyeung/Downloads/icefall/egs/librispeech/ASR/zipformer/onnx_check.py \ No newline at end of file diff --git a/egs/aishell/ASR/zipformer_distinctK/onnx_decode.py b/egs/aishell/ASR/zipformer_distinctK/onnx_decode.py new file mode 120000 index 0000000000..57df3da26d --- /dev/null +++ b/egs/aishell/ASR/zipformer_distinctK/onnx_decode.py @@ -0,0 +1 @@ +/Users/yifanyeung/Downloads/icefall/egs/aishell/ASR/zipformer/onnx_decode.py \ No newline at end of file diff --git a/egs/aishell/ASR/zipformer_distinctK/onnx_pretrained-streaming.py b/egs/aishell/ASR/zipformer_distinctK/onnx_pretrained-streaming.py new file mode 120000 index 0000000000..a13ef0e163 --- /dev/null +++ b/egs/aishell/ASR/zipformer_distinctK/onnx_pretrained-streaming.py @@ -0,0 +1 @@ +/Users/yifanyeung/Downloads/icefall/egs/librispeech/ASR/zipformer/onnx_pretrained-streaming.py \ No newline at end of file diff --git a/egs/aishell/ASR/zipformer_distinctK/onnx_pretrained.py b/egs/aishell/ASR/zipformer_distinctK/onnx_pretrained.py new file mode 120000 index 0000000000..3f7fa6592f --- /dev/null +++ b/egs/aishell/ASR/zipformer_distinctK/onnx_pretrained.py @@ -0,0 +1 @@ +/Users/yifanyeung/Downloads/icefall/egs/librispeech/ASR/zipformer/onnx_pretrained.py \ No newline at end of file diff --git a/egs/aishell/ASR/zipformer_distinctK/optim.py b/egs/aishell/ASR/zipformer_distinctK/optim.py new file mode 120000 index 0000000000..514c8e7bfa --- /dev/null +++ b/egs/aishell/ASR/zipformer_distinctK/optim.py @@ -0,0 +1 @@ +/Users/yifanyeung/Downloads/icefall/egs/librispeech/ASR/zipformer/optim.py \ No newline at end of file diff --git a/egs/aishell/ASR/zipformer_distinctK/pretrained.py b/egs/aishell/ASR/zipformer_distinctK/pretrained.py new file mode 120000 index 0000000000..241bd92eb9 --- /dev/null +++ b/egs/aishell/ASR/zipformer_distinctK/pretrained.py @@ -0,0 +1 @@ +/Users/yifanyeung/Downloads/icefall/egs/librispeech/ASR/zipformer/pretrained.py \ No newline at end of file diff --git a/egs/aishell/ASR/zipformer_distinctK/scaling.py b/egs/aishell/ASR/zipformer_distinctK/scaling.py new file mode 120000 index 0000000000..4346db1906 --- /dev/null +++ b/egs/aishell/ASR/zipformer_distinctK/scaling.py @@ -0,0 +1 @@ +/Users/yifanyeung/Downloads/icefall/egs/librispeech/ASR/zipformer/scaling.py \ No newline at end of file diff --git a/egs/aishell/ASR/zipformer_distinctK/scaling_converter.py b/egs/aishell/ASR/zipformer_distinctK/scaling_converter.py new file mode 120000 index 0000000000..b763511586 --- /dev/null +++ b/egs/aishell/ASR/zipformer_distinctK/scaling_converter.py @@ -0,0 +1 @@ +/Users/yifanyeung/Downloads/icefall/egs/librispeech/ASR/zipformer/scaling_converter.py \ No newline at end of file diff --git a/egs/aishell/ASR/zipformer_distinctK/streaming_beam_search.py b/egs/aishell/ASR/zipformer_distinctK/streaming_beam_search.py new file mode 120000 index 0000000000..d8dad610e9 --- /dev/null +++ b/egs/aishell/ASR/zipformer_distinctK/streaming_beam_search.py @@ -0,0 +1 @@ +/Users/yifanyeung/Downloads/icefall/egs/librispeech/ASR/zipformer/streaming_beam_search.py \ No newline at end of file diff --git a/egs/aishell/ASR/zipformer_distinctK/streaming_decode.py b/egs/aishell/ASR/zipformer_distinctK/streaming_decode.py new file mode 100755 index 0000000000..c3820447ac --- /dev/null +++ b/egs/aishell/ASR/zipformer_distinctK/streaming_decode.py @@ -0,0 +1,880 @@ +#!/usr/bin/env python3 +# Copyright 2022-2023 Xiaomi Corporation (Authors: Wei Kang, +# Fangjun Kuang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Usage: +./zipformer/streaming_decode.py \ + --epoch 28 \ + --avg 15 \ + --causal 1 \ + --chunk-size 16 \ + --left-context-frames 256 \ + --exp-dir ./zipformer/exp \ + --decoding-method greedy_search \ + --num-decode-streams 2000 +""" + +import argparse +import logging +import math +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import numpy as np +import torch +from asr_datamodule import AishellAsrDataModule +from decode_stream import DecodeStream +from kaldifeat import Fbank, FbankOptions +from lhotse import CutSet +from streaming_beam_search import ( + fast_beam_search_one_best, + greedy_search, + modified_beam_search, +) +from torch import Tensor, nn +from torch.nn.utils.rnn import pad_sequence +from train import add_model_arguments, get_model, get_params + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + make_pad_mask, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=28, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--lang-dir", + type=str, + default="data/lang_char", + help="Path to the lang dir(containing lexicon, tokens, etc.)", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Supported decoding methods are: + greedy_search + modified_beam_search + fast_beam_search + """, + ) + + parser.add_argument( + "--num_active_paths", + type=int, + default=4, + help="""An interger indicating how many candidates we will keep for each + frame. Used only when --decoding-method is modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=4, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search""", + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=4, + help="""Used only when --decoding-method is + fast_beam_search""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=32, + help="""Used only when --decoding-method is + fast_beam_search""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--blank-penalty", + type=float, + default=0.0, + help=""" + The penalty applied on blank symbol during decoding. + Note: It is a positive value that would be applied to logits like + this `logits[:, 0] -= blank_penalty` (suppose logits.shape is + [batch_size, vocab] and blank id is 0). + """, + ) + + parser.add_argument( + "--num-decode-streams", + type=int, + default=2000, + help="The number of streams that can be decoded parallel.", + ) + + add_model_arguments(parser) + + return parser + + +def get_init_states( + model: nn.Module, + batch_size: int = 1, + device: torch.device = torch.device("cpu"), +) -> List[torch.Tensor]: + """ + Returns a list of cached tensors of all encoder layers. For layer-i, states[i*6:(i+1)*6] + is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2). + states[-2] is the cached left padding for ConvNeXt module, + of shape (batch_size, num_channels, left_pad, num_freqs) + states[-1] is processed_lens of shape (batch,), which records the number + of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch. + """ + states = model.encoder.get_init_states(batch_size, device) + + embed_states = model.encoder_embed.get_init_states(batch_size, device) + states.append(embed_states) + + processed_lens = torch.zeros(batch_size, dtype=torch.int32, device=device) + states.append(processed_lens) + + return states + + +def stack_states(state_list: List[List[torch.Tensor]]) -> List[torch.Tensor]: + """Stack list of zipformer states that correspond to separate utterances + into a single emformer state, so that it can be used as an input for + zipformer when those utterances are formed into a batch. + + Args: + state_list: + Each element in state_list corresponding to the internal state + of the zipformer model for a single utterance. For element-n, + state_list[n] is a list of cached tensors of all encoder layers. For layer-i, + state_list[n][i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1, + cached_val2, cached_conv1, cached_conv2). + state_list[n][-2] is the cached left padding for ConvNeXt module, + of shape (batch_size, num_channels, left_pad, num_freqs) + state_list[n][-1] is processed_lens of shape (batch,), which records the number + of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch. + + Note: + It is the inverse of :func:`unstack_states`. + """ + batch_size = len(state_list) + assert (len(state_list[0]) - 2) % 6 == 0, len(state_list[0]) + tot_num_layers = (len(state_list[0]) - 2) // 6 + + batch_states = [] + for layer in range(tot_num_layers): + layer_offset = layer * 6 + # cached_key: (left_context_len, batch_size, key_dim) + cached_key = torch.cat( + [state_list[i][layer_offset] for i in range(batch_size)], dim=1 + ) + # cached_nonlin_attn: (num_heads, batch_size, left_context_len, head_dim) + cached_nonlin_attn = torch.cat( + [state_list[i][layer_offset + 1] for i in range(batch_size)], dim=1 + ) + # cached_val1: (left_context_len, batch_size, value_dim) + cached_val1 = torch.cat( + [state_list[i][layer_offset + 2] for i in range(batch_size)], dim=1 + ) + # cached_val2: (left_context_len, batch_size, value_dim) + cached_val2 = torch.cat( + [state_list[i][layer_offset + 3] for i in range(batch_size)], dim=1 + ) + # cached_conv1: (#batch, channels, left_pad) + cached_conv1 = torch.cat( + [state_list[i][layer_offset + 4] for i in range(batch_size)], dim=0 + ) + # cached_conv2: (#batch, channels, left_pad) + cached_conv2 = torch.cat( + [state_list[i][layer_offset + 5] for i in range(batch_size)], dim=0 + ) + batch_states += [ + cached_key, + cached_nonlin_attn, + cached_val1, + cached_val2, + cached_conv1, + cached_conv2, + ] + + cached_embed_left_pad = torch.cat( + [state_list[i][-2] for i in range(batch_size)], dim=0 + ) + batch_states.append(cached_embed_left_pad) + + processed_lens = torch.cat([state_list[i][-1] for i in range(batch_size)], dim=0) + batch_states.append(processed_lens) + + return batch_states + + +def unstack_states(batch_states: List[Tensor]) -> List[List[Tensor]]: + """Unstack the zipformer state corresponding to a batch of utterances + into a list of states, where the i-th entry is the state from the i-th + utterance in the batch. + + Note: + It is the inverse of :func:`stack_states`. + + Args: + batch_states: A list of cached tensors of all encoder layers. For layer-i, + states[i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, + cached_conv1, cached_conv2). + state_list[-2] is the cached left padding for ConvNeXt module, + of shape (batch_size, num_channels, left_pad, num_freqs) + states[-1] is processed_lens of shape (batch,), which records the number + of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch. + + Returns: + state_list: A list of list. Each element in state_list corresponding to the internal state + of the zipformer model for a single utterance. + """ + assert (len(batch_states) - 2) % 6 == 0, len(batch_states) + tot_num_layers = (len(batch_states) - 2) // 6 + + processed_lens = batch_states[-1] + batch_size = processed_lens.shape[0] + + state_list = [[] for _ in range(batch_size)] + + for layer in range(tot_num_layers): + layer_offset = layer * 6 + # cached_key: (left_context_len, batch_size, key_dim) + cached_key_list = batch_states[layer_offset].chunk(chunks=batch_size, dim=1) + # cached_nonlin_attn: (num_heads, batch_size, left_context_len, head_dim) + cached_nonlin_attn_list = batch_states[layer_offset + 1].chunk( + chunks=batch_size, dim=1 + ) + # cached_val1: (left_context_len, batch_size, value_dim) + cached_val1_list = batch_states[layer_offset + 2].chunk( + chunks=batch_size, dim=1 + ) + # cached_val2: (left_context_len, batch_size, value_dim) + cached_val2_list = batch_states[layer_offset + 3].chunk( + chunks=batch_size, dim=1 + ) + # cached_conv1: (#batch, channels, left_pad) + cached_conv1_list = batch_states[layer_offset + 4].chunk( + chunks=batch_size, dim=0 + ) + # cached_conv2: (#batch, channels, left_pad) + cached_conv2_list = batch_states[layer_offset + 5].chunk( + chunks=batch_size, dim=0 + ) + for i in range(batch_size): + state_list[i] += [ + cached_key_list[i], + cached_nonlin_attn_list[i], + cached_val1_list[i], + cached_val2_list[i], + cached_conv1_list[i], + cached_conv2_list[i], + ] + + cached_embed_left_pad_list = batch_states[-2].chunk(chunks=batch_size, dim=0) + for i in range(batch_size): + state_list[i].append(cached_embed_left_pad_list[i]) + + processed_lens_list = batch_states[-1].chunk(chunks=batch_size, dim=0) + for i in range(batch_size): + state_list[i].append(processed_lens_list[i]) + + return state_list + + +def streaming_forward( + features: Tensor, + feature_lens: Tensor, + model: nn.Module, + states: List[Tensor], + chunk_size: int, + left_context_len: int, +) -> Tuple[Tensor, Tensor, List[Tensor]]: + """ + Returns encoder outputs, output lengths, and updated states. + """ + cached_embed_left_pad = states[-2] + (x, x_lens, new_cached_embed_left_pad,) = model.encoder_embed.streaming_forward( + x=features, + x_lens=feature_lens, + cached_left_pad=cached_embed_left_pad, + ) + assert x.size(1) == chunk_size, (x.size(1), chunk_size) + + src_key_padding_mask = make_pad_mask(x_lens) + + # processed_mask is used to mask out initial states + processed_mask = torch.arange(left_context_len, device=x.device).expand( + x.size(0), left_context_len + ) + processed_lens = states[-1] # (batch,) + # (batch, left_context_size) + processed_mask = (processed_lens.unsqueeze(1) <= processed_mask).flip(1) + # Update processed lengths + new_processed_lens = processed_lens + x_lens + + # (batch, left_context_size + chunk_size) + src_key_padding_mask = torch.cat([processed_mask, src_key_padding_mask], dim=1) + + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + encoder_states = states[:-2] + ( + encoder_out, + encoder_out_lens, + new_encoder_states, + ) = model.encoder.streaming_forward( + x=x, + x_lens=x_lens, + states=encoder_states, + src_key_padding_mask=src_key_padding_mask, + ) + encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + + new_states = new_encoder_states + [ + new_cached_embed_left_pad, + new_processed_lens, + ] + return encoder_out, encoder_out_lens, new_states + + +def decode_one_chunk( + params: AttributeDict, + model: nn.Module, + decode_streams: List[DecodeStream], +) -> List[int]: + """Decode one chunk frames of features for each decode_streams and + return the indexes of finished streams in a List. + + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + decode_streams: + A List of DecodeStream, each belonging to a utterance. + Returns: + Return a List containing which DecodeStreams are finished. + """ + device = model.device + chunk_size = int(params.chunk_size) + left_context_len = int(params.left_context_frames) + + features = [] + feature_lens = [] + states = [] + processed_lens = [] # Used in fast-beam-search + + for stream in decode_streams: + feat, feat_len = stream.get_feature_frames(chunk_size * 2) + features.append(feat) + feature_lens.append(feat_len) + states.append(stream.states) + processed_lens.append(stream.done_frames) + + feature_lens = torch.tensor(feature_lens, device=device) + features = pad_sequence(features, batch_first=True, padding_value=LOG_EPS) + + # Make sure the length after encoder_embed is at least 1. + # The encoder_embed subsample features (T - 7) // 2 + # The ConvNeXt module needs (7 - 1) // 2 = 3 frames of right padding after subsampling + tail_length = chunk_size * 2 + 7 + 2 * 3 + if features.size(1) < tail_length: + pad_length = tail_length - features.size(1) + feature_lens += pad_length + features = torch.nn.functional.pad( + features, + (0, 0, 0, pad_length), + mode="constant", + value=LOG_EPS, + ) + + states = stack_states(states) + + encoder_out, encoder_out_lens, new_states = streaming_forward( + features=features, + feature_lens=feature_lens, + model=model, + states=states, + chunk_size=chunk_size, + left_context_len=left_context_len, + ) + + encoder_out = model.joiner.encoder_proj(encoder_out) + + if params.decoding_method == "greedy_search": + greedy_search( + model=model, + encoder_out=encoder_out, + streams=decode_streams, + blank_penalty=params.blank_penalty, + ) + elif params.decoding_method == "fast_beam_search": + processed_lens = torch.tensor(processed_lens, device=device) + processed_lens = processed_lens + encoder_out_lens + fast_beam_search_one_best( + model=model, + encoder_out=encoder_out, + processed_lens=processed_lens, + streams=decode_streams, + beam=params.beam, + max_states=params.max_states, + max_contexts=params.max_contexts, + blank_penalty=params.blank_penalty, + ) + elif params.decoding_method == "modified_beam_search": + modified_beam_search( + model=model, + streams=decode_streams, + encoder_out=encoder_out, + num_active_paths=params.num_active_paths, + blank_penalty=params.blank_penalty, + ) + else: + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") + + states = unstack_states(new_states) + + finished_streams = [] + for i in range(len(decode_streams)): + decode_streams[i].states = states[i] + decode_streams[i].done_frames += encoder_out_lens[i] + if decode_streams[i].done: + finished_streams.append(i) + + return finished_streams + + +def decode_dataset( + cuts: CutSet, + params: AttributeDict, + model: nn.Module, + lexicon: Lexicon, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[List[str], List[str]]]]: + """Decode dataset. + + Args: + cuts: + Lhotse Cutset containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + lexicon: + The Lexicon. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + device = model.device + + opts = FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = 16000 + opts.mel_opts.num_bins = 80 + + log_interval = 100 + + decode_results = [] + # Contain decode streams currently running. + decode_streams = [] + for num, cut in enumerate(cuts): + # each utterance has a DecodeStream. + initial_states = get_init_states(model=model, batch_size=1, device=device) + decode_stream = DecodeStream( + params=params, + cut_id=cut.id, + initial_states=initial_states, + decoding_graph=decoding_graph, + device=device, + ) + + audio: np.ndarray = cut.load_audio() + # audio.shape: (1, num_samples) + assert len(audio.shape) == 2 + assert audio.shape[0] == 1, "Should be single channel" + assert audio.dtype == np.float32, audio.dtype + + # The trained model is using normalized samples + if audio.max() > 1: + logging.warning( + f"The audio should be normalized to [-1, 1], audio.max : {audio.max()}." + f"Clipping to [-1, 1]." + ) + audio = np.clip(audio, -1, 1) + + samples = torch.from_numpy(audio).squeeze(0) + + fbank = Fbank(opts) + feature = fbank(samples.to(device)) + decode_stream.set_features(feature, tail_pad_len=30) + decode_stream.ground_truth = cut.supervisions[0].text + + decode_streams.append(decode_stream) + + while len(decode_streams) >= params.num_decode_streams: + finished_streams = decode_one_chunk( + params=params, model=model, decode_streams=decode_streams + ) + for i in sorted(finished_streams, reverse=True): + decode_results.append( + ( + decode_streams[i].id, + list(decode_streams[i].ground_truth.strip()), + [ + lexicon.token_table[idx] + for idx in decode_streams[i].decoding_result() + ], + ) + ) + del decode_streams[i] + + if num % log_interval == 0: + logging.info(f"Cuts processed until now is {num}.") + + # decode final chunks of last sequences + while len(decode_streams): + finished_streams = decode_one_chunk( + params=params, model=model, decode_streams=decode_streams + ) + for i in sorted(finished_streams, reverse=True): + decode_results.append( + ( + decode_streams[i].id, + decode_streams[i].ground_truth.split(), + [ + lexicon.token_table[idx] + for idx in decode_streams[i].decoding_result() + ], + ) + ) + del decode_streams[i] + + key = f"blank_penalty_{params.blank_penalty}" + if params.decoding_method == "greedy_search": + key = f"greedy_search_{key}" + elif params.decoding_method == "fast_beam_search": + key = ( + f"beam_{params.beam}_" + f"max_contexts_{params.max_contexts}_" + f"max_states_{params.max_states}_{key}" + ) + elif params.decoding_method == "modified_beam_search": + key = f"num_active_paths_{params.num_active_paths}_{key}" + else: + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") + return {key: decode_results} + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[List[str], List[str]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = ( + params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + ) + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. + errs_filename = ( + params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=True + ) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = ( + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + AishellAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + params.res_dir = params.exp_dir / "streaming" / params.decoding_method + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + assert params.causal, params.causal + assert "," not in params.chunk_size, "chunk_size should be one value in decoding." + assert ( + "," not in params.left_context_frames + ), "left_context_frames should be one value in decoding." + params.suffix += f"-chunk-{params.chunk_size}" + params.suffix += f"-left-context-{params.left_context_frames}" + params.suffix += f"-blank-penalty-{params.blank_penalty}" + + # for fast_beam_search + if params.decoding_method == "fast_beam_search": + params.suffix += f"-beam-{params.beam}" + params.suffix += f"-max-contexts-{params.max_contexts}" + params.suffix += f"-max-states-{params.max_states}" + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + lexicon = Lexicon(params.lang_dir) + params.blank_id = lexicon.token_table[""] + params.vocab_size = max(lexicon.tokens) + 1 + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if start >= 0: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + model.device = device + + decoding_graph = None + if params.decoding_method == "fast_beam_search": + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + aishell = AishellAsrDataModule(args) + + dev_cuts = aishell.valid_cuts() + test_cuts = aishell.test_cuts() + + test_sets = ["dev", "test"] + test_cuts = [dev_cuts, test_cuts] + + for test_set, test_cut in zip(test_sets, test_cuts): + results_dict = decode_dataset( + cuts=test_cut, + params=params, + model=model, + lexicon=lexicon, + decoding_graph=decoding_graph, + ) + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/aishell/ASR/zipformer_distinctK/subsampling.py b/egs/aishell/ASR/zipformer_distinctK/subsampling.py new file mode 120000 index 0000000000..333b5b72e7 --- /dev/null +++ b/egs/aishell/ASR/zipformer_distinctK/subsampling.py @@ -0,0 +1 @@ +/Users/yifanyeung/Downloads/icefall/egs/librispeech/ASR/zipformer/subsampling.py \ No newline at end of file diff --git a/egs/aishell/ASR/zipformer_distinctK/train.py b/egs/aishell/ASR/zipformer_distinctK/train.py new file mode 100755 index 0000000000..d381649e42 --- /dev/null +++ b/egs/aishell/ASR/zipformer_distinctK/train.py @@ -0,0 +1,1350 @@ +#!/usr/bin/env python3 +# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Mingshuang Luo, +# Zengwei Yao, +# Daniel Povey) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7" + +./zipformer/train.py \ + --world-size 8 \ + --num-epochs 12 \ + --start-epoch 1 \ + --exp-dir zipformer/exp \ + --training-subset L + --lr-epochs 1.5 \ + --max-duration 350 + +# For mix precision training: + +./zipformer/train.py \ + --world-size 8 \ + --num-epochs 12 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp \ + --training-subset L \ + --lr-epochs 1.5 \ + --max-duration 750 + +""" + + +import argparse +import copy +import logging +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, Optional, Tuple, Union + +import k2 +import optim +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from asr_datamodule import AishellAsrDataModule +from decoder import Decoder +from joiner import Joiner +from lhotse.cut import Cut +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from model import AsrModel +from optim import Eden, ScaledAdam +from scaling import ScheduledFloat +from subsampling import Conv2dSubsampling +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter +from zipformer import Zipformer2 + +from icefall import diagnostics +from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + MetricsTracker, + get_parameter_groups_with_lrs, + setup_logger, + str2bool, +) + +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] + + +def get_adjusted_batch_count(params: AttributeDict) -> float: + # returns the number of batches we would have used so far if we had used the reference + # duration. This is for purposes of set_batch_count(). + return ( + params.batch_idx_train + * (params.max_duration * params.world_size) + / params.ref_duration + ) + + +def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: + if isinstance(model, DDP): + # get underlying nn.Module + model = model.module + for name, module in model.named_modules(): + if hasattr(module, "batch_count"): + module.batch_count = batch_count + if hasattr(module, "name"): + module.name = name + + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--num-encoder-layers", + type=str, + default="2,2,3,4,3,2", + help="Number of zipformer encoder layers per stack, comma separated.", + ) + + parser.add_argument( + "--downsampling-factor", + type=str, + default="1,2,4,8,4,2", + help="Downsampling factor for each stack of encoder layers.", + ) + + parser.add_argument( + "--feedforward-dim", + type=str, + default="512,768,1024,1536,1024,768", + help="""Feedforward dimension of the zipformer encoder layers, per stack, comma separated.""", + ) + + parser.add_argument( + "--num-heads", + type=str, + default="4,4,4,8,4,4", + help="""Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.""", + ) + + parser.add_argument( + "--encoder-dim", + type=str, + default="192,256,384,512,384,256", + help="""Embedding dimension in encoder stacks: a single int or comma-separated list.""", + ) + + parser.add_argument( + "--query-head-dim", + type=str, + default="32", + help="""Query/key dimension per head in encoder stacks: a single int or comma-separated list.""", + ) + + parser.add_argument( + "--value-head-dim", + type=str, + default="12", + help="""Value dimension per head in encoder stacks: a single int or comma-separated list.""", + ) + + parser.add_argument( + "--pos-head-dim", + type=str, + default="4", + help="""Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list.""", + ) + + parser.add_argument( + "--pos-dim", + type=int, + default="48", + help="Positional-encoding embedding dimension", + ) + + parser.add_argument( + "--encoder-unmasked-dim", + type=str, + default="192,192,256,256,256,192", + help="""Unmasked dimensions in the encoders, relates to augmentation during training. A single int or comma-separated list. Must be <= each corresponding encoder_dim.""", + ) + + parser.add_argument( + "--cnn-module-kernel", + type=str, + default="31,31,15,15,15,31", + help="""Sizes of convolutional kernels in convolution modules in each encoder stack: a single int or comma-separated list.""", + ) + + parser.add_argument( + "--decoder-dim", + type=int, + default=512, + help="Embedding dimension in the decoder model.", + ) + + parser.add_argument( + "--joiner-dim", + type=int, + default=512, + help="""Dimension used in the joiner model. + Outputs from the encoder and decoder model are projected + to this dimension before adding. + """, + ) + + parser.add_argument( + "--causal", + type=str2bool, + default=False, + help="If True, use causal version of model.", + ) + + parser.add_argument( + "--chunk-size", + type=str, + default="16,32,64,-1", + help="""Chunk sizes (at 50Hz frame rate) will be chosen randomly from this list during training. Must be just -1 if --causal=False""", + ) + + parser.add_argument( + "--left-context-frames", + type=str, + default="64,128,256,-1", + help="""Maximum left-contexts for causal training, measured in frames which will + be converted to a number of chunks. If splitting into chunks, + chunk left-context frames will be chosen randomly from this list; else not relevant.""", + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=30, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--lang-dir", + type=str, + default="data/lang_char", + help="""The lang dir + It contains language related input files such as + "lexicon.txt" + """, + ) + + parser.add_argument( + "--base-lr", type=float, default=0.045, help="The base learning rate." + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=7500, + help="""Number of steps that affects how rapidly the learning rate + decreases. We suggest not to change this.""", + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=3.5, + help="""Number of epochs that affects how rapidly the learning rate decreases. + """, + ) + + parser.add_argument( + "--ref-duration", + type=float, + default=600, + help="""Reference batch duration for purposes of adjusting batch counts for setting various schedules inside the model""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="""The context size in the decoder. 1 means bigram; 2 means tri-gram""", + ) + + parser.add_argument( + "--prune-range", + type=int, + default=5, + help="""The prune range for rnnt loss, it means how many symbols(context) + we are using to compute the loss""", + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.25, + help="""The scale to smooth the loss with lm + (output of prediction network) part.""", + ) + + parser.add_argument( + "--am-scale", + type=float, + default=0.0, + help="""The scale to smooth the loss with am (output of encoder network) part.""", + ) + + parser.add_argument( + "--simple-loss-scale", + type=float, + default=0.5, + help="""To get pruning ranges, we will calculate a simple version + loss(joiner is just addition), this simple loss also uses for + training (as a regularization item). We will scale the simple loss + with this parameter before adding to the final loss.""", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=4000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 0. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=30, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=200, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + add_model_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - encoder_dim: Hidden dim for multi-head attention model. + + - num_decoder_layers: Number of decoder layer of transformer decoder. + + - warm_step: The warmup period that dictates the decay of the + scale on "simple" (un-pruned) loss. + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 3000, + # parameters for zipformer + "feature_dim": 80, + "subsampling_factor": 4, # not passed in, this is fixed. + "warm_step": 2000, + "env_info": get_env_info(), + } + ) + + return params + + +def _to_int_tuple(s: str): + return tuple(map(int, s.split(","))) + + +def get_encoder_embed(params: AttributeDict) -> nn.Module: + # encoder_embed converts the input of shape (N, T, num_features) + # to the shape (N, (T - 7) // 2, encoder_dims). + # That is, it does two things simultaneously: + # (1) subsampling: T -> (T - 7) // 2 + # (2) embedding: num_features -> encoder_dims + # In the normal configuration, we will downsample once more at the end + # by a factor of 2, and most of the encoder stacks will run at a lower + # sampling rate. + encoder_embed = Conv2dSubsampling( + in_channels=params.feature_dim, + out_channels=_to_int_tuple(params.encoder_dim)[0], + dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), + ) + return encoder_embed + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + encoder = Zipformer2( + output_downsampling_factor=2, + downsampling_factor=_to_int_tuple(params.downsampling_factor), + num_encoder_layers=_to_int_tuple(params.num_encoder_layers), + encoder_dim=_to_int_tuple(params.encoder_dim), + encoder_unmasked_dim=_to_int_tuple(params.encoder_unmasked_dim), + query_head_dim=_to_int_tuple(params.query_head_dim), + pos_head_dim=_to_int_tuple(params.pos_head_dim), + value_head_dim=_to_int_tuple(params.value_head_dim), + pos_dim=params.pos_dim, + num_heads=_to_int_tuple(params.num_heads), + feedforward_dim=_to_int_tuple(params.feedforward_dim), + cnn_module_kernel=_to_int_tuple(params.cnn_module_kernel), + dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), + warmup_batches=4000.0, + causal=params.causal, + chunk_size=_to_int_tuple(params.chunk_size), + left_context_frames=_to_int_tuple(params.left_context_frames), + ) + return encoder + + +def get_decoder_model(params: AttributeDict) -> nn.Module: + decoder = Decoder( + vocab_size=params.vocab_size, + decoder_dim=params.decoder_dim, + blank_id=params.blank_id, + context_size=params.context_size, + ) + return decoder + + +def get_joiner_model(params: AttributeDict) -> nn.Module: + joiner = Joiner( + encoder_dim=max(_to_int_tuple(params.encoder_dim)), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return joiner + + +def get_model(params: AttributeDict) -> nn.Module: + encoder_embed = get_encoder_embed(params) + encoder = get_encoder_model(params) + decoder = get_decoder_model(params) + joiner = get_joiner_model(params) + + model = AsrModel( + encoder_embed=encoder_embed, + encoder=encoder, + decoder=decoder, + joiner=joiner, + encoder_dim=int(max(params.encoder_dim.split(","))), + decoder_dim=params.decoder_dim, + vocab_size=params.vocab_size, + ) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + if "cur_batch_idx" in saved_params: + params["cur_batch_idx"] = saved_params["cur_batch_idx"] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + graph_compiler: CharCtcTrainingGraphCompiler, + batch: dict, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute CTC loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Zipformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + warmup: a floating point value which increases throughout training; + values >= 1.0 are fully warmed up and have all modules present. + """ + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + batch_idx_train = params.batch_idx_train + warm_step = params.warm_step + + texts = batch["supervisions"]["text"] + y = graph_compiler.texts_to_ids(texts) + y = k2.RaggedTensor(y).to(device) + + with torch.set_grad_enabled(is_training): + simple_loss, pruned_loss, _ = model( + x=feature, + x_lens=feature_lens, + y=y, + prune_range=params.prune_range, + am_scale=params.am_scale, + lm_scale=params.lm_scale, + ) + + s = params.simple_loss_scale + # take down the scale on the simple loss from 1.0 at the start + # to params.simple_loss scale by warm_step. + simple_loss_scale = ( + s + if batch_idx_train >= warm_step + else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) + ) + pruned_loss_scale = ( + 1.0 + if batch_idx_train >= warm_step + else 0.1 + 0.9 * (batch_idx_train / warm_step) + ) + + loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + + # Note: We use reduction=sum while computing the loss. + info["loss"] = loss.detach().cpu().item() + info["simple_loss"] = simple_loss.detach().cpu().item() + info["pruned_loss"] = pruned_loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + graph_compiler: CharCtcTrainingGraphCompiler, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + graph_compiler=graph_compiler, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + graph_compiler: CharCtcTrainingGraphCompiler, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + + tot_loss = MetricsTracker() + + cur_batch_idx = params.get("cur_batch_idx", 0) + + saved_bad_model = False + + def save_bad_model(suffix: str = ""): + save_checkpoint_impl( + filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=0, + ) + + for batch_idx, batch in enumerate(train_dl): + if batch_idx % 10 == 0: + set_batch_count(model, get_adjusted_batch_count(params)) + if batch_idx < cur_batch_idx: + continue + cur_batch_idx = batch_idx + + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + graph_compiler=graph_compiler, + batch=batch, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + scaler.scale(loss).backward() + scheduler.step_batch(params.batch_idx_train) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + except: # noqa + save_bad_model() + display_and_save_batch(batch, params=params, graph_compiler=graph_compiler) + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + params.cur_batch_idx = batch_idx + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + del params.cur_batch_idx + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. + cur_grad_scale = scaler._scale.item() + + if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + if not saved_bad_model: + save_bad_model(suffix="-first-warning") + saved_bad_model = True + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + save_bad_model() + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + if batch_idx % params.log_interval == 0: + cur_lr = max(scheduler.get_last_lr()) + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", cur_grad_scale, params.batch_idx_train + ) + + if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + graph_compiler=graph_compiler, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. + args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + lexicon = Lexicon(params.lang_dir) + graph_compiler = CharCtcTrainingGraphCompiler( + lexicon=lexicon, + device=device, + ) + + params.blank_id = lexicon.token_table[""] + params.vocab_size = max(lexicon.tokens) + 1 + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + + optimizer = ScaledAdam( + get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True), + lr=params.base_lr, # should have no effect + clipping_scale=2.0, + ) + + scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 512 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.inf_check: + register_inf_check_hooks(model) + + aishell = AishellAsrDataModule(args) + + train_cuts = aishell.train_cuts() + valid_cuts = aishell.valid_cuts() + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 15 seconds + # + # Caution: There is a reason to select 15.0 here. Please see + # ../local/display_manifest_statistics.py + # + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + if c.duration < 1.0 or c.duration > 12.0: + # logging.warning( + # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + # ) + return False + + # In pruned RNN-T, we require that T >= S + # where T is the number of feature frames after subsampling + # and S is the number of tokens in the utterance + + # In ./zipformer.py, the conv module uses the following expression + # for subsampling + T = ((c.num_frames - 7) // 2 + 1) // 2 + tokens = graph_compiler.texts_to_ids([c.supervisions[0].text])[0] + + if T < len(tokens): + logging.warning( + f"Exclude cut with ID {c.id} from training. " + f"Number of frames (before subsampling): {c.num_frames}. " + f"Number of frames (after subsampling): {T}. " + f"Text: {c.supervisions[0].text}. " + f"Tokens: {tokens}. " + f"Number of tokens: {len(tokens)}" + ) + return False + + return True + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + train_dl = aishell.train_dataloaders( + train_cuts, sampler_state_dict=sampler_state_dict + ) + + valid_dl = aishell.valid_dataloaders(valid_cuts) + + if False and not params.print_diagnostics: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + graph_compiler=graph_compiler, + params=params, + ) + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + graph_compiler=graph_compiler, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, + graph_compiler: CharCtcTrainingGraphCompiler, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. + graph_compiler: + The compiler to encode texts to ids. + """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + supervisions = batch["supervisions"] + features = batch["inputs"] + + logging.info(f"features shape: {features.shape}") + + texts = supervisions["text"] + y = graph_compiler.texts_to_ids(texts) + num_tokens = sum(len(i) for i in y) + logging.info(f"num tokens: {num_tokens}") + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + graph_compiler: CharCtcTrainingGraphCompiler, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." + ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, _ = compute_loss( + params=params, + model=model, + graph_compiler=graph_compiler, + batch=batch, + is_training=True, + ) + loss.backward() + optimizer.zero_grad() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + display_and_save_batch(batch, params=params, graph_compiler=graph_compiler) + raise + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + +def main(): + parser = get_parser() + AishellAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.lang_dir = Path(args.lang_dir) + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/aishell/ASR/zipformer_distinctK/zipformer.py b/egs/aishell/ASR/zipformer_distinctK/zipformer.py new file mode 120000 index 0000000000..7686de1592 --- /dev/null +++ b/egs/aishell/ASR/zipformer_distinctK/zipformer.py @@ -0,0 +1 @@ +/Users/yifanyeung/Downloads/icefall/egs/librispeech/ASR/zipformer/zipformer.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py index 7fcd242fcd..1d79eef9d9 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py @@ -1,5 +1,4 @@ -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang -# Xiaoyu Yang) +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) # # See ../../../../LICENSE for clarification regarding multiple authors # @@ -16,52 +15,36 @@ # limitations under the License. import warnings -from dataclasses import dataclass, field -from typing import Dict, List, Optional, Tuple, Union +from dataclasses import dataclass +from typing import Dict, List, Optional import k2 -import sentencepiece as spm import torch -from torch import nn +from model import Transducer -from icefall import ContextGraph, ContextState, NgramLm, NgramLmStateCost from icefall.decode import Nbest, one_best_decoding -from icefall.lm_wrapper import LmScorer -from icefall.rnn_lm.model import RnnLmModel -from icefall.transformer_lm.model import TransformerLM -from icefall.utils import ( - DecodingResults, - add_eos, - add_sos, - get_texts, - get_texts_with_timestamp, -) +from icefall.utils import get_texts def fast_beam_search_one_best( - model: nn.Module, + model: Transducer, decoding_graph: k2.Fsa, encoder_out: torch.Tensor, encoder_out_lens: torch.Tensor, beam: float, max_states: int, max_contexts: int, - temperature: float = 1.0, - ilme_scale: float = 0.0, - blank_penalty: float = 0.0, - return_timestamps: bool = False, - allow_partial: bool = False, -) -> Union[List[List[int]], DecodingResults]: +) -> List[List[int]]: """It limits the maximum number of symbols per frame to 1. - A lattice is first obtained using fast beam search, and then + A lattice is first obtained using modified beam search, and then the shortest path within the lattice is used as the final output. Args: model: An instance of `Transducer`. decoding_graph: - Decoding graph used for decoding, may be a TrivialGraph or a LG. + Decoding graph used for decoding, may be a TrivialGraph or a HLG. encoder_out: A tensor of shape (N, T, C) from the encoder. encoder_out_lens: @@ -73,14 +56,8 @@ def fast_beam_search_one_best( Max states per stream per frame. max_contexts: Max contexts pre stream per frame. - temperature: - Softmax temperature. - return_timestamps: - Whether to return timestamps. Returns: - If return_timestamps is False, return the decoded result. - Else, return a DecodingResults object containing - decoded result and corresponding timestamps. + Return the decoded result. """ lattice = fast_beam_search( model=model, @@ -90,248 +67,15 @@ def fast_beam_search_one_best( beam=beam, max_states=max_states, max_contexts=max_contexts, - temperature=temperature, - ilme_scale=ilme_scale, - allow_partial=allow_partial, - blank_penalty=blank_penalty, ) best_path = one_best_decoding(lattice) - - if not return_timestamps: - return get_texts(best_path) - else: - return get_texts_with_timestamp(best_path) - - -def fast_beam_search_nbest_LG( - model: nn.Module, - decoding_graph: k2.Fsa, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, - beam: float, - max_states: int, - max_contexts: int, - num_paths: int, - nbest_scale: float = 0.5, - use_double_scores: bool = True, - temperature: float = 1.0, - blank_penalty: float = 0.0, - ilme_scale: float = 0.0, - return_timestamps: bool = False, - allow_partial: bool = False, -) -> Union[List[List[int]], DecodingResults]: - """It limits the maximum number of symbols per frame to 1. - - The process to get the results is: - - (1) Use fast beam search to get a lattice - - (2) Select `num_paths` paths from the lattice using k2.random_paths() - - (3) Unique the selected paths - - (4) Intersect the selected paths with the lattice and compute the - shortest path from the intersection result - - (5) The path with the largest score is used as the decoding output. - - Args: - model: - An instance of `Transducer`. - decoding_graph: - Decoding graph used for decoding, may be a TrivialGraph or a LG. - encoder_out: - A tensor of shape (N, T, C) from the encoder. - encoder_out_lens: - A tensor of shape (N,) containing the number of frames in `encoder_out` - before padding. - beam: - Beam value, similar to the beam used in Kaldi.. - max_states: - Max states per stream per frame. - max_contexts: - Max contexts pre stream per frame. - num_paths: - Number of paths to extract from the decoded lattice. - nbest_scale: - It's the scale applied to the lattice.scores. A smaller value - yields more unique paths. - use_double_scores: - True to use double precision for computation. False to use - single precision. - temperature: - Softmax temperature. - return_timestamps: - Whether to return timestamps. - Returns: - If return_timestamps is False, return the decoded result. - Else, return a DecodingResults object containing - decoded result and corresponding timestamps. - """ - lattice = fast_beam_search( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=beam, - max_states=max_states, - max_contexts=max_contexts, - temperature=temperature, - allow_partial=allow_partial, - blank_penalty=blank_penalty, - ilme_scale=ilme_scale, - ) - - nbest = Nbest.from_lattice( - lattice=lattice, - num_paths=num_paths, - use_double_scores=use_double_scores, - nbest_scale=nbest_scale, - ) - - # The following code is modified from nbest.intersect() - word_fsa = k2.invert(nbest.fsa) - if hasattr(lattice, "aux_labels"): - # delete token IDs as it is not needed - del word_fsa.aux_labels - word_fsa.scores.zero_() - word_fsa_with_epsilon_loops = k2.linear_fsa_with_self_loops(word_fsa) - path_to_utt_map = nbest.shape.row_ids(1) - - if hasattr(lattice, "aux_labels"): - # lattice has token IDs as labels and word IDs as aux_labels. - # inv_lattice has word IDs as labels and token IDs as aux_labels - inv_lattice = k2.invert(lattice) - inv_lattice = k2.arc_sort(inv_lattice) - else: - inv_lattice = k2.arc_sort(lattice) - - if inv_lattice.shape[0] == 1: - path_lattice = k2.intersect_device( - inv_lattice, - word_fsa_with_epsilon_loops, - b_to_a_map=torch.zeros_like(path_to_utt_map), - sorted_match_a=True, - ) - else: - path_lattice = k2.intersect_device( - inv_lattice, - word_fsa_with_epsilon_loops, - b_to_a_map=path_to_utt_map, - sorted_match_a=True, - ) - - # path_lattice has word IDs as labels and token IDs as aux_labels - path_lattice = k2.top_sort(k2.connect(path_lattice)) - tot_scores = path_lattice.get_tot_scores( - use_double_scores=use_double_scores, - log_semiring=True, # Note: we always use True - ) - # See https://github.com/k2-fsa/icefall/pull/420 for why - # we always use log_semiring=True - - ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores) - best_hyp_indexes = ragged_tot_scores.argmax() - best_path = k2.index_fsa(nbest.fsa, best_hyp_indexes) - - if not return_timestamps: - return get_texts(best_path) - else: - return get_texts_with_timestamp(best_path) - - -def fast_beam_search_nbest( - model: nn.Module, - decoding_graph: k2.Fsa, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, - beam: float, - max_states: int, - max_contexts: int, - num_paths: int, - nbest_scale: float = 0.5, - use_double_scores: bool = True, - temperature: float = 1.0, - blank_penalty: float = 0.0, - return_timestamps: bool = False, - allow_partial: bool = False, -) -> Union[List[List[int]], DecodingResults]: - """It limits the maximum number of symbols per frame to 1. - - The process to get the results is: - - (1) Use fast beam search to get a lattice - - (2) Select `num_paths` paths from the lattice using k2.random_paths() - - (3) Unique the selected paths - - (4) Intersect the selected paths with the lattice and compute the - shortest path from the intersection result - - (5) The path with the largest score is used as the decoding output. - - Args: - model: - An instance of `Transducer`. - decoding_graph: - Decoding graph used for decoding, may be a TrivialGraph or a LG. - encoder_out: - A tensor of shape (N, T, C) from the encoder. - encoder_out_lens: - A tensor of shape (N,) containing the number of frames in `encoder_out` - before padding. - beam: - Beam value, similar to the beam used in Kaldi.. - max_states: - Max states per stream per frame. - max_contexts: - Max contexts pre stream per frame. - num_paths: - Number of paths to extract from the decoded lattice. - nbest_scale: - It's the scale applied to the lattice.scores. A smaller value - yields more unique paths. - use_double_scores: - True to use double precision for computation. False to use - single precision. - temperature: - Softmax temperature. - return_timestamps: - Whether to return timestamps. - Returns: - If return_timestamps is False, return the decoded result. - Else, return a DecodingResults object containing - decoded result and corresponding timestamps. - """ - lattice = fast_beam_search( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=beam, - max_states=max_states, - max_contexts=max_contexts, - blank_penalty=blank_penalty, - temperature=temperature, - allow_partial=allow_partial, - ) - - nbest = Nbest.from_lattice( - lattice=lattice, - num_paths=num_paths, - use_double_scores=use_double_scores, - nbest_scale=nbest_scale, - ) - - # at this point, nbest.fsa.scores are all zeros. - - nbest = nbest.intersect(lattice) - # Now nbest.fsa.scores contains acoustic scores - - max_indexes = nbest.tot_scores().argmax() - - best_path = k2.index_fsa(nbest.fsa, max_indexes) - - if not return_timestamps: - return get_texts(best_path) - else: - return get_texts_with_timestamp(best_path) + hyps = get_texts(best_path) + return hyps def fast_beam_search_nbest_oracle( - model: nn.Module, + model: Transducer, decoding_graph: k2.Fsa, encoder_out: torch.Tensor, encoder_out_lens: torch.Tensor, @@ -342,14 +86,10 @@ def fast_beam_search_nbest_oracle( ref_texts: List[List[int]], use_double_scores: bool = True, nbest_scale: float = 0.5, - temperature: float = 1.0, - blank_penalty: float = 0.0, - return_timestamps: bool = False, - allow_partial: bool = False, -) -> Union[List[List[int]], DecodingResults]: +) -> List[List[int]]: """It limits the maximum number of symbols per frame to 1. - A lattice is first obtained using fast beam search, and then + A lattice is first obtained using modified beam search, and then we select `num_paths` linear paths from the lattice. The path that has the minimum edit distance with the given reference transcript is used as the output. @@ -361,7 +101,7 @@ def fast_beam_search_nbest_oracle( model: An instance of `Transducer`. decoding_graph: - Decoding graph used for decoding, may be a TrivialGraph or a LG. + Decoding graph used for decoding, may be a TrivialGraph or a HLG. encoder_out: A tensor of shape (N, T, C) from the encoder. encoder_out_lens: @@ -385,14 +125,9 @@ def fast_beam_search_nbest_oracle( nbest_scale: It's the scale applied to the lattice.scores. A smaller value yields more unique paths. - temperature: - Softmax temperature. - return_timestamps: - Whether to return timestamps. + Returns: - If return_timestamps is False, return the decoded result. - Else, return a DecodingResults object containing - decoded result and corresponding timestamps. + Return the decoded result. """ lattice = fast_beam_search( model=model, @@ -402,9 +137,6 @@ def fast_beam_search_nbest_oracle( beam=beam, max_states=max_states, max_contexts=max_contexts, - temperature=temperature, - allow_partial=allow_partial, - blank_penalty=blank_penalty, ) nbest = Nbest.from_lattice( @@ -433,25 +165,18 @@ def fast_beam_search_nbest_oracle( best_path = k2.index_fsa(nbest.fsa, max_indexes) - if not return_timestamps: - return get_texts(best_path) - else: - return get_texts_with_timestamp(best_path) + hyps = get_texts(best_path) + return hyps def fast_beam_search( - model: nn.Module, + model: Transducer, decoding_graph: k2.Fsa, encoder_out: torch.Tensor, encoder_out_lens: torch.Tensor, beam: float, max_states: int, max_contexts: int, - temperature: float = 1.0, - subtract_ilme: bool = False, - ilme_scale: float = 0.1, - allow_partial: bool = False, - blank_penalty: float = 0.0, ) -> k2.Fsa: """It limits the maximum number of symbols per frame to 1. @@ -459,7 +184,7 @@ def fast_beam_search( model: An instance of `Transducer`. decoding_graph: - Decoding graph used for decoding, may be a TrivialGraph or a LG. + Decoding graph used for decoding, may be a TrivialGraph or a HLG. encoder_out: A tensor of shape (N, T, C) from the encoder. encoder_out_lens: @@ -471,8 +196,6 @@ def fast_beam_search( Max states per stream per frame. max_contexts: Max contexts pre stream per frame. - temperature: - Softmax temperature. Returns: Return an FsaVec with axes [utt][state][arc] containing the decoded lattice. Note: When the input graph is a TrivialGraph, the returned @@ -497,7 +220,8 @@ def fast_beam_search( individual_streams.append(k2.RnntDecodingStream(decoding_graph)) decoding_streams = k2.RnntDecodingStreams(individual_streams, config) - encoder_out = model.joiner.encoder_proj(encoder_out) + encoder_out_len = torch.ones(1, dtype=torch.int32) + decoder_out_len = torch.ones(1, dtype=torch.int32) for t in range(T): # shape is a RaggedShape of shape (B, context) @@ -507,7 +231,6 @@ def fast_beam_search( contexts = contexts.to(torch.int64) # decoder_out is of shape (shape.NumElements(), 1, decoder_out_dim) decoder_out = model.decoder(contexts, need_pad=False) - decoder_out = model.joiner.decoder_proj(decoder_out) # current_encoder_out is of shape # (shape.NumElements(), 1, joiner_dim) # fmt: off @@ -516,47 +239,22 @@ def fast_beam_search( ) # fmt: on logits = model.joiner( - current_encoder_out.unsqueeze(2), - decoder_out.unsqueeze(1), - project_input=False, - ) - logits = logits.squeeze(1).squeeze(1) - - if blank_penalty != 0: - logits[:, 0] -= blank_penalty - - log_probs = (logits / temperature).log_softmax(dim=-1) - - if ilme_scale != 0: - ilme_logits = model.joiner( - torch.zeros_like( - current_encoder_out, device=current_encoder_out.device - ).unsqueeze(2), - decoder_out.unsqueeze(1), - project_input=False, - ) - ilme_logits = ilme_logits.squeeze(1).squeeze(1) - if blank_penalty != 0: - ilme_logits[:, 0] -= blank_penalty - ilme_log_probs = (ilme_logits / temperature).log_softmax(dim=-1) - log_probs -= ilme_scale * ilme_log_probs - + current_encoder_out, + decoder_out, + encoder_out_len.expand(decoder_out.size(0)), + decoder_out_len.expand(decoder_out.size(0)), + ) # (N, vocab_size) + log_probs = logits.log_softmax(dim=-1) decoding_streams.advance(log_probs) decoding_streams.terminate_and_flush_to_streams() - lattice = decoding_streams.format_output( - encoder_out_lens.tolist(), allow_partial=allow_partial - ) + lattice = decoding_streams.format_output(encoder_out_lens.tolist()) return lattice def greedy_search( - model: nn.Module, - encoder_out: torch.Tensor, - max_sym_per_frame: int, - blank_penalty: float = 0.0, - return_timestamps: bool = False, -) -> Union[List[int], DecodingResults]: + model: Transducer, encoder_out: torch.Tensor, max_sym_per_frame: int +) -> List[int]: """Greedy search for a single utterance. Args: model: @@ -566,12 +264,8 @@ def greedy_search( max_sym_per_frame: Maximum number of symbols per frame. If it is set to 0, the WER would be 100%. - return_timestamps: - Whether to return timestamps. Returns: - If return_timestamps is False, return the decoded result. - Else, return a DecodingResults object containing - decoded result and corresponding timestamps. + Return the decoded result. """ assert encoder_out.ndim == 3 @@ -580,27 +274,19 @@ def greedy_search( blank_id = model.decoder.blank_id context_size = model.decoder.context_size - unk_id = getattr(model, "unk_id", blank_id) - device = next(model.parameters()).device + device = model.device decoder_input = torch.tensor( - [-1] * (context_size - 1) + [blank_id], device=device, dtype=torch.int64 + [blank_id] * context_size, device=device, dtype=torch.int64 ).reshape(1, context_size) decoder_out = model.decoder(decoder_input, need_pad=False) - decoder_out = model.joiner.decoder_proj(decoder_out) - - encoder_out = model.joiner.encoder_proj(encoder_out) T = encoder_out.size(1) t = 0 hyp = [blank_id] * context_size - # timestamp[i] is the frame index after subsampling - # on which hyp[i] is decoded - timestamp = [] - # Maximum symbols per utterance. max_sym_per_utt = 1000 @@ -610,6 +296,9 @@ def greedy_search( # symbols per utterance decoded so far sym_per_utt = 0 + encoder_out_len = torch.tensor([1]) + decoder_out_len = torch.tensor([1]) + while t < T and sym_per_utt < max_sym_per_utt: if sym_per_frame >= max_sym_per_frame: sym_per_frame = 0 @@ -617,26 +306,21 @@ def greedy_search( continue # fmt: off - current_encoder_out = encoder_out[:, t:t+1, :].unsqueeze(2) + current_encoder_out = encoder_out[:, t:t+1, :] # fmt: on logits = model.joiner( - current_encoder_out, decoder_out.unsqueeze(1), project_input=False + current_encoder_out, decoder_out, encoder_out_len, decoder_out_len ) - # logits is (1, 1, 1, vocab_size) - - if blank_penalty != 0: - logits[:, :, :, 0] -= blank_penalty + # logits is (1, vocab_size) y = logits.argmax().item() - if y not in (blank_id, unk_id): + if y != blank_id: hyp.append(y) - timestamp.append(t) decoder_input = torch.tensor([hyp[-context_size:]], device=device).reshape( 1, context_size ) decoder_out = model.decoder(decoder_input, need_pad=False) - decoder_out = model.joiner.decoder_proj(decoder_out) sym_per_utt += 1 sym_per_frame += 1 @@ -645,22 +329,14 @@ def greedy_search( t += 1 hyp = hyp[context_size:] # remove blanks - if not return_timestamps: - return hyp - else: - return DecodingResults( - hyps=[hyp], - timestamps=[timestamp], - ) + return hyp def greedy_search_batch( - model: nn.Module, + model: Transducer, encoder_out: torch.Tensor, encoder_out_lens: torch.Tensor, - blank_penalty: float = 0, - return_timestamps: bool = False, -) -> Union[List[List[int]], DecodingResults]: +) -> List[List[int]]: """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. Args: model: @@ -670,12 +346,9 @@ def greedy_search_batch( encoder_out_lens: A 1-D tensor of shape (N,), containing number of valid frames in encoder_out before padding. - return_timestamps: - Whether to return timestamps. Returns: - If return_timestamps is False, return the decoded result. - Else, return a DecodingResults object containing - decoded result and corresponding timestamps. + Return a list-of-list of token IDs containing the decoded results. + len(ans) equals to encoder_out.size(0). """ assert encoder_out.ndim == 3 assert encoder_out.size(0) >= 1, encoder_out.size(0) @@ -690,7 +363,6 @@ def greedy_search_batch( device = next(model.parameters()).device blank_id = model.decoder.blank_id - unk_id = getattr(model, "unk_id", blank_id) context_size = model.decoder.context_size batch_size_list = packed_encoder_out.batch_sizes.tolist() @@ -698,56 +370,47 @@ def greedy_search_batch( assert torch.all(encoder_out_lens > 0), encoder_out_lens assert N == batch_size_list[0], (N, batch_size_list) - hyps = [[-1] * (context_size - 1) + [blank_id] for _ in range(N)] - - # timestamp[n][i] is the frame index after subsampling - # on which hyp[n][i] is decoded - timestamps = [[] for _ in range(N)] - # scores[n][i] is the logits on which hyp[n][i] is decoded - scores = [[] for _ in range(N)] + hyps = [[blank_id] * context_size for _ in range(N)] decoder_input = torch.tensor( hyps, device=device, dtype=torch.int64, ) # (N, context_size) - decoder_out = model.decoder(decoder_input, need_pad=False) - decoder_out = model.joiner.decoder_proj(decoder_out) # decoder_out: (N, 1, decoder_out_dim) - encoder_out = model.joiner.encoder_proj(packed_encoder_out.data) + encoder_out_len = torch.ones(1, dtype=torch.int32) + decoder_out_len = torch.ones(1, dtype=torch.int32) + + encoder_out = packed_encoder_out.data offset = 0 - for t, batch_size in enumerate(batch_size_list): + for batch_size in batch_size_list: start = offset end = offset + batch_size current_encoder_out = encoder_out.data[start:end] - current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1) - # current_encoder_out's shape: (batch_size, 1, 1, encoder_out_dim) + current_encoder_out = current_encoder_out.unsqueeze(1) + # current_encoder_out's shape: (batch_size, 1, encoder_out_dim) offset = end decoder_out = decoder_out[:batch_size] logits = model.joiner( - current_encoder_out, decoder_out.unsqueeze(1), project_input=False - ) - # logits'shape (batch_size, 1, 1, vocab_size) + current_encoder_out, + decoder_out, + encoder_out_len.expand(batch_size), + decoder_out_len.expand(batch_size), + ) # (batch_size, vocab_size) - logits = logits.squeeze(1).squeeze(1) # (batch_size, vocab_size) assert logits.ndim == 2, logits.shape - - if blank_penalty != 0: - logits[:, 0] -= blank_penalty - y = logits.argmax(dim=1).tolist() emitted = False for i, v in enumerate(y): - if v not in (blank_id, unk_id): + if v != blank_id: hyps[i].append(v) - timestamps[i].append(t) - scores[i].append(logits[i, v].item()) emitted = True + if emitted: # update decoder output decoder_input = [h[-context_size:] for h in hyps[:batch_size]] @@ -755,28 +418,19 @@ def greedy_search_batch( decoder_input, device=device, dtype=torch.int64, - ) - decoder_out = model.decoder(decoder_input, need_pad=False) - decoder_out = model.joiner.decoder_proj(decoder_out) + ) # (batch_size, context_size) + decoder_out = model.decoder( + decoder_input, + need_pad=False, + ) # (batch_size, 1, decoder_out_dim) sorted_ans = [h[context_size:] for h in hyps] ans = [] - ans_timestamps = [] - ans_scores = [] unsorted_indices = packed_encoder_out.unsorted_indices.tolist() for i in range(N): ans.append(sorted_ans[unsorted_indices[i]]) - ans_timestamps.append(timestamps[unsorted_indices[i]]) - ans_scores.append(scores[unsorted_indices[i]]) - if not return_timestamps: - return ans - else: - return DecodingResults( - hyps=ans, - timestamps=ans_timestamps, - scores=ans_scores, - ) + return ans @dataclass @@ -789,22 +443,6 @@ class Hypothesis: # It contains only one entry. log_prob: torch.Tensor - # timestamp[i] is the frame index after subsampling - # on which ys[i] is decoded - timestamp: List[int] = field(default_factory=list) - - # the lm score for next token given the current ys - lm_score: Optional[torch.Tensor] = None - - # the RNNLM states (h and c in LSTM) - state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None - - # N-gram LM state - state_cost: Optional[NgramLmStateCost] = None - - # Context graph state - context_state: Optional[ContextState] = None - @property def key(self) -> str: """Return a string representation of self.ys""" @@ -892,22 +530,11 @@ def filter(self, threshold: torch.Tensor) -> "HypothesisList": ans.add(hyp) # shallow copy return ans - def topk(self, k: int, length_norm: bool = False) -> "HypothesisList": - """Return the top-k hypothesis. - - Args: - length_norm: - If True, the `log_prob` of a hypothesis is normalized by the - number of tokens in it. - """ + def topk(self, k: int) -> "HypothesisList": + """Return the top-k hypothesis.""" hyps = list(self._data.items()) - if length_norm: - hyps = sorted( - hyps, key=lambda h: h[1].log_prob / len(h[1].ys), reverse=True - )[:k] - else: - hyps = sorted(hyps, key=lambda h: h[1].log_prob, reverse=True)[:k] + hyps = sorted(hyps, key=lambda h: h[1].log_prob, reverse=True)[:k] ans = HypothesisList(dict(hyps)) return ans @@ -928,7 +555,96 @@ def __str__(self) -> str: return ", ".join(s) -def get_hyps_shape(hyps: List[HypothesisList]) -> k2.RaggedShape: +def run_decoder( + ys: List[int], + model: Transducer, + decoder_cache: Dict[str, torch.Tensor], +) -> torch.Tensor: + """Run the neural decoder model for a given hypothesis. + + Args: + ys: + The current hypothesis. + model: + The transducer model. + decoder_cache: + Cache to save computations. + Returns: + Return a 1-D tensor of shape (decoder_out_dim,) containing + output of `model.decoder`. + """ + context_size = model.decoder.context_size + key = "_".join(map(str, ys[-context_size:])) + if key in decoder_cache: + return decoder_cache[key] + + device = model.device + + decoder_input = torch.tensor( + [ys[-context_size:]], + device=device, + dtype=torch.int64, + ).reshape(1, context_size) + + decoder_out = model.decoder(decoder_input, need_pad=False) + decoder_cache[key] = decoder_out + + return decoder_out + + +def run_joiner( + key: str, + model: Transducer, + encoder_out: torch.Tensor, + decoder_out: torch.Tensor, + encoder_out_len: torch.Tensor, + decoder_out_len: torch.Tensor, + joint_cache: Dict[str, torch.Tensor], +): + """Run the joint network given outputs from the encoder and decoder. + + Args: + key: + A key into the `joint_cache`. + model: + The transducer model. + encoder_out: + A tensor of shape (1, 1, encoder_out_dim). + decoder_out: + A tensor of shape (1, 1, decoder_out_dim). + encoder_out_len: + A tensor with value [1]. + decoder_out_len: + A tensor with value [1]. + joint_cache: + A dict to save computations. + Returns: + Return a tensor from the output of log-softmax. + Its shape is (vocab_size,). + """ + if key in joint_cache: + return joint_cache[key] + + logits = model.joiner( + encoder_out, + decoder_out, + encoder_out_len, + decoder_out_len, + ) + + # TODO(fangjun): Scale the blank posterior + log_prob = logits.log_softmax(dim=-1) + # log_prob is (1, 1, 1, vocab_size) + + log_prob = log_prob.squeeze() + # Now log_prob is (vocab_size,) + + joint_cache[key] = log_prob + + return log_prob + + +def _get_hyps_shape(hyps: List[HypothesisList]) -> k2.RaggedShape: """Return a ragged shape with axes [utt][num_hyps]. Args: @@ -954,16 +670,12 @@ def get_hyps_shape(hyps: List[HypothesisList]) -> k2.RaggedShape: def modified_beam_search( - model: nn.Module, + model: Transducer, encoder_out: torch.Tensor, encoder_out_lens: torch.Tensor, - context_graph: Optional[ContextGraph] = None, beam: int = 4, - temperature: float = 1.0, - blank_penalty: float = 0.0, - return_timestamps: bool = False, -) -> Union[List[List[int]], DecodingResults]: - """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded. +) -> List[List[int]]: + """Beam search in batch mode with --max-sym-per-frame=1 being hardcodded. Args: model: @@ -975,14 +687,9 @@ def modified_beam_search( encoder_out before padding. beam: Number of active paths during the beam search. - temperature: - Softmax temperature. - return_timestamps: - Whether to return timestamps. Returns: - If return_timestamps is False, return the decoded result. - Else, return a DecodingResults object containing - decoded result and corresponding timestamps. + Return a list-of-list of token IDs. ans[i] is the decoding results + for the i-th utterance. """ assert encoder_out.ndim == 3, encoder_out.shape assert encoder_out.size(0) >= 1, encoder_out.size(0) @@ -995,7 +702,6 @@ def modified_beam_search( ) blank_id = model.decoder.blank_id - unk_id = getattr(model, "unk_id", blank_id) context_size = model.decoder.context_size device = next(model.parameters()).device @@ -1008,32 +714,31 @@ def modified_beam_search( for i in range(N): B[i].add( Hypothesis( - ys=[-1] * (context_size - 1) + [blank_id], + ys=[blank_id] * context_size, log_prob=torch.zeros(1, dtype=torch.float32, device=device), - context_state=None if context_graph is None else context_graph.root, - timestamp=[], ) ) - encoder_out = model.joiner.encoder_proj(packed_encoder_out.data) + encoder_out_len = torch.tensor([1]) + decoder_out_len = torch.tensor([1]) + encoder_out = packed_encoder_out.data offset = 0 finalized_B = [] - for t, batch_size in enumerate(batch_size_list): + for batch_size in batch_size_list: start = offset end = offset + batch_size current_encoder_out = encoder_out.data[start:end] - current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1) - # current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim) + current_encoder_out = current_encoder_out.unsqueeze(1) + # current_encoder_out's shape is: (batch_size, 1, encoder_out_dim) offset = end finalized_B = B[batch_size:] + finalized_B B = B[:batch_size] - hyps_shape = get_hyps_shape(B).to(device) + hyps_shape = _get_hyps_shape(B).to(device) A = [list(b) for b in B] - B = [HypothesisList() for _ in range(batch_size)] ys_log_probs = torch.cat( @@ -1046,9 +751,8 @@ def modified_beam_search( dtype=torch.int64, ) # (num_hyps, context_size) - decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) - decoder_out = model.joiner.decoder_proj(decoder_out) - # decoder_out is of shape (num_hyps, 1, 1, joiner_dim) + decoder_out = model.decoder(decoder_input, need_pad=False) + # decoder_output is of shape (num_hyps, 1, decoder_output_dim) # Note: For torch 1.7.1 and below, it requires a torch.int64 tensor # as index, so we use `to(torch.int64)` below. @@ -1056,20 +760,17 @@ def modified_beam_search( current_encoder_out, dim=0, index=hyps_shape.row_ids(1).to(torch.int64), - ) # (num_hyps, 1, 1, encoder_out_dim) + ) # (num_hyps, 1, encoder_out_dim) logits = model.joiner( current_encoder_out, decoder_out, - project_input=False, - ) # (num_hyps, 1, 1, vocab_size) - - logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size) - - if blank_penalty != 0: - logits[:, 0] -= blank_penalty + encoder_out_len.expand(decoder_out.size(0)), + decoder_out_len.expand(decoder_out.size(0)), + ) + # logits is of shape (num_hyps, vocab_size) - log_probs = (logits / temperature).log_softmax(dim=-1) # (num_hyps, vocab_size) + log_probs = logits.log_softmax(dim=-1) # (num_hyps, vocab_size) log_probs.add_(ys_log_probs) @@ -1094,540 +795,76 @@ def modified_beam_search( for k in range(len(topk_hyp_indexes)): hyp_idx = topk_hyp_indexes[k] hyp = A[i][hyp_idx] + new_ys = hyp.ys[:] new_token = topk_token_indexes[k] - new_timestamp = hyp.timestamp[:] - context_score = 0 - new_context_state = None if context_graph is None else hyp.context_state - if new_token not in (blank_id, unk_id): + if new_token != blank_id: new_ys.append(new_token) - new_timestamp.append(t) - if context_graph is not None: - ( - context_score, - new_context_state, - ) = context_graph.forward_one_step(hyp.context_state, new_token) - - new_log_prob = topk_log_probs[k] + context_score - - new_hyp = Hypothesis( - ys=new_ys, - log_prob=new_log_prob, - timestamp=new_timestamp, - context_state=new_context_state, - ) + + new_log_prob = topk_log_probs[k] + new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob) B[i].add(new_hyp) B = B + finalized_B - - # finalize context_state, if the matched contexts do not reach final state - # we need to add the score on the corresponding backoff arc - if context_graph is not None: - finalized_B = [HypothesisList() for _ in range(len(B))] - for i, hyps in enumerate(B): - for hyp in list(hyps): - context_score, new_context_state = context_graph.finalize( - hyp.context_state - ) - finalized_B[i].add( - Hypothesis( - ys=hyp.ys, - log_prob=hyp.log_prob + context_score, - timestamp=hyp.timestamp, - context_state=new_context_state, - ) - ) - B = finalized_B - best_hyps = [b.get_most_probable(length_norm=True) for b in B] sorted_ans = [h.ys[context_size:] for h in best_hyps] - sorted_timestamps = [h.timestamp for h in best_hyps] ans = [] - ans_timestamps = [] unsorted_indices = packed_encoder_out.unsorted_indices.tolist() for i in range(N): ans.append(sorted_ans[unsorted_indices[i]]) - ans_timestamps.append(sorted_timestamps[unsorted_indices[i]]) - if not return_timestamps: - return ans - else: - return DecodingResults( - hyps=ans, - timestamps=ans_timestamps, - ) + return ans -def modified_beam_search_lm_rescore( - model: nn.Module, +def _deprecated_modified_beam_search( + model: Transducer, encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, - LM: LmScorer, - lm_scale_list: List[int], beam: int = 4, - temperature: float = 1.0, - return_timestamps: bool = False, -) -> Union[List[List[int]], DecodingResults]: - """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded. - Rescore the final results with RNNLM and return the one with the highest score +) -> List[int]: + """It limits the maximum number of symbols per frame to 1. + + It decodes only one utterance at a time. We keep it only for reference. + The function :func:`modified_beam_search` should be preferred as it + supports batch decoding. Args: model: - The transducer model. + An instance of `Transducer`. encoder_out: - Output from the encoder. Its shape is (N, T, C). - encoder_out_lens: - A 1-D tensor of shape (N,), containing number of valid frames in - encoder_out before padding. + A tensor of shape (N, T, C) from the encoder. Support only N==1 for now. beam: - Number of active paths during the beam search. - temperature: - Softmax temperature. - LM: - A neural network language model - return_timestamps: - Whether to return timestamps. + Beam size. Returns: - If return_timestamps is False, return the decoded result. - Else, return a DecodingResults object containing - decoded result and corresponding timestamps. + Return the decoded result. """ - assert encoder_out.ndim == 3, encoder_out.shape - assert encoder_out.size(0) >= 1, encoder_out.size(0) - packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( - input=encoder_out, - lengths=encoder_out_lens.cpu(), - batch_first=True, - enforce_sorted=False, - ) + assert encoder_out.ndim == 3 + # support only batch_size == 1 for now + assert encoder_out.size(0) == 1, encoder_out.size(0) blank_id = model.decoder.blank_id - unk_id = getattr(model, "unk_id", blank_id) context_size = model.decoder.context_size - device = next(model.parameters()).device - batch_size_list = packed_encoder_out.batch_sizes.tolist() - N = encoder_out.size(0) - assert torch.all(encoder_out_lens > 0), encoder_out_lens - assert N == batch_size_list[0], (N, batch_size_list) - - B = [HypothesisList() for _ in range(N)] - for i in range(N): - B[i].add( - Hypothesis( - ys=[-1] * (context_size - 1) + [blank_id], - log_prob=torch.zeros(1, dtype=torch.float32, device=device), - timestamp=[], - ) - ) - - encoder_out = model.joiner.encoder_proj(packed_encoder_out.data) - - offset = 0 - finalized_B = [] - for t, batch_size in enumerate(batch_size_list): - start = offset - end = offset + batch_size - current_encoder_out = encoder_out.data[start:end] - current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1) - # current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim) - offset = end - - finalized_B = B[batch_size:] + finalized_B - B = B[:batch_size] - - hyps_shape = get_hyps_shape(B).to(device) - - A = [list(b) for b in B] - B = [HypothesisList() for _ in range(batch_size)] - - ys_log_probs = torch.cat( - [hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps] - ) # (num_hyps, 1) - - decoder_input = torch.tensor( - [hyp.ys[-context_size:] for hyps in A for hyp in hyps], - device=device, - dtype=torch.int64, - ) # (num_hyps, context_size) - - decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) - decoder_out = model.joiner.decoder_proj(decoder_out) - # decoder_out is of shape (num_hyps, 1, 1, joiner_dim) - - # Note: For torch 1.7.1 and below, it requires a torch.int64 tensor - # as index, so we use `to(torch.int64)` below. - current_encoder_out = torch.index_select( - current_encoder_out, - dim=0, - index=hyps_shape.row_ids(1).to(torch.int64), - ) # (num_hyps, 1, 1, encoder_out_dim) - - logits = model.joiner( - current_encoder_out, - decoder_out, - project_input=False, - ) # (num_hyps, 1, 1, vocab_size) - - logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size) - - log_probs = (logits / temperature).log_softmax(dim=-1) # (num_hyps, vocab_size) - - log_probs.add_(ys_log_probs) - - vocab_size = log_probs.size(-1) - - log_probs = log_probs.reshape(-1) - - row_splits = hyps_shape.row_splits(1) * vocab_size - log_probs_shape = k2.ragged.create_ragged_shape2( - row_splits=row_splits, cached_tot_size=log_probs.numel() - ) - ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) - - for i in range(batch_size): - topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) - - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - topk_hyp_indexes = (topk_indexes // vocab_size).tolist() - topk_token_indexes = (topk_indexes % vocab_size).tolist() - - for k in range(len(topk_hyp_indexes)): - hyp_idx = topk_hyp_indexes[k] - hyp = A[i][hyp_idx] - - new_ys = hyp.ys[:] - new_token = topk_token_indexes[k] - new_timestamp = hyp.timestamp[:] - if new_token not in (blank_id, unk_id): - new_ys.append(new_token) - new_timestamp.append(t) - - new_log_prob = topk_log_probs[k] - new_hyp = Hypothesis( - ys=new_ys, log_prob=new_log_prob, timestamp=new_timestamp - ) - B[i].add(new_hyp) - - B = B + finalized_B - - # get the am_scores for n-best list - hyps_shape = get_hyps_shape(B) - am_scores = torch.tensor([hyp.log_prob.item() for b in B for hyp in b]) - am_scores = k2.RaggedTensor(value=am_scores, shape=hyps_shape).to(device) - - # now LM rescore - # prepare input data to LM - candidate_seqs = [hyp.ys[context_size:] for b in B for hyp in b] - possible_seqs = k2.RaggedTensor(candidate_seqs) - row_splits = possible_seqs.shape.row_splits(1) - sentence_token_lengths = row_splits[1:] - row_splits[:-1] - possible_seqs_with_sos = add_sos(possible_seqs, sos_id=1) - possible_seqs_with_eos = add_eos(possible_seqs, eos_id=1) - sentence_token_lengths += 1 - - x = possible_seqs_with_sos.pad(mode="constant", padding_value=blank_id) - y = possible_seqs_with_eos.pad(mode="constant", padding_value=blank_id) - x = x.to(device).to(torch.int64) - y = y.to(device).to(torch.int64) - sentence_token_lengths = sentence_token_lengths.to(device).to(torch.int64) - - lm_scores = LM.lm(x=x, y=y, lengths=sentence_token_lengths) - assert lm_scores.ndim == 2 - lm_scores = -1 * lm_scores.sum(dim=1) - - ans = {} - unsorted_indices = packed_encoder_out.unsorted_indices.tolist() - - # get the best hyp with different lm_scale - for lm_scale in lm_scale_list: - key = f"nnlm_scale_{lm_scale:.2f}" - tot_scores = am_scores.values + lm_scores * lm_scale - ragged_tot_scores = k2.RaggedTensor(shape=am_scores.shape, value=tot_scores) - max_indexes = ragged_tot_scores.argmax().tolist() - unsorted_hyps = [candidate_seqs[idx] for idx in max_indexes] - hyps = [] - for idx in unsorted_indices: - hyps.append(unsorted_hyps[idx]) - - ans[key] = hyps - return ans - - -def modified_beam_search_lm_rescore_LODR( - model: nn.Module, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, - LM: LmScorer, - LODR_lm: NgramLm, - sp: spm.SentencePieceProcessor, - lm_scale_list: List[int], - beam: int = 4, - temperature: float = 1.0, - return_timestamps: bool = False, -) -> Union[List[List[int]], DecodingResults]: - """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded. - Rescore the final results with RNNLM and return the one with the highest score - - Args: - model: - The transducer model. - encoder_out: - Output from the encoder. Its shape is (N, T, C). - encoder_out_lens: - A 1-D tensor of shape (N,), containing number of valid frames in - encoder_out before padding. - beam: - Number of active paths during the beam search. - temperature: - Softmax temperature. - LM: - A neural network language model - return_timestamps: - Whether to return timestamps. - Returns: - If return_timestamps is False, return the decoded result. - Else, return a DecodingResults object containing - decoded result and corresponding timestamps. - """ - assert encoder_out.ndim == 3, encoder_out.shape - assert encoder_out.size(0) >= 1, encoder_out.size(0) - - packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( - input=encoder_out, - lengths=encoder_out_lens.cpu(), - batch_first=True, - enforce_sorted=False, - ) - - blank_id = model.decoder.blank_id - unk_id = getattr(model, "unk_id", blank_id) - context_size = model.decoder.context_size - device = next(model.parameters()).device - - batch_size_list = packed_encoder_out.batch_sizes.tolist() - N = encoder_out.size(0) - assert torch.all(encoder_out_lens > 0), encoder_out_lens - assert N == batch_size_list[0], (N, batch_size_list) - - B = [HypothesisList() for _ in range(N)] - for i in range(N): - B[i].add( - Hypothesis( - ys=[-1] * (context_size - 1) + [blank_id], - log_prob=torch.zeros(1, dtype=torch.float32, device=device), - timestamp=[], - ) - ) - - encoder_out = model.joiner.encoder_proj(packed_encoder_out.data) - - offset = 0 - finalized_B = [] - for t, batch_size in enumerate(batch_size_list): - start = offset - end = offset + batch_size - current_encoder_out = encoder_out.data[start:end] - current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1) - # current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim) - offset = end - - finalized_B = B[batch_size:] + finalized_B - B = B[:batch_size] - - hyps_shape = get_hyps_shape(B).to(device) - - A = [list(b) for b in B] - B = [HypothesisList() for _ in range(batch_size)] - - ys_log_probs = torch.cat( - [hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps] - ) # (num_hyps, 1) - - decoder_input = torch.tensor( - [hyp.ys[-context_size:] for hyps in A for hyp in hyps], - device=device, - dtype=torch.int64, - ) # (num_hyps, context_size) - - decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) - decoder_out = model.joiner.decoder_proj(decoder_out) - # decoder_out is of shape (num_hyps, 1, 1, joiner_dim) - - # Note: For torch 1.7.1 and below, it requires a torch.int64 tensor - # as index, so we use `to(torch.int64)` below. - current_encoder_out = torch.index_select( - current_encoder_out, - dim=0, - index=hyps_shape.row_ids(1).to(torch.int64), - ) # (num_hyps, 1, 1, encoder_out_dim) - - logits = model.joiner( - current_encoder_out, - decoder_out, - project_input=False, - ) # (num_hyps, 1, 1, vocab_size) - - logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size) - - log_probs = (logits / temperature).log_softmax(dim=-1) # (num_hyps, vocab_size) - - log_probs.add_(ys_log_probs) - - vocab_size = log_probs.size(-1) - - log_probs = log_probs.reshape(-1) - - row_splits = hyps_shape.row_splits(1) * vocab_size - log_probs_shape = k2.ragged.create_ragged_shape2( - row_splits=row_splits, cached_tot_size=log_probs.numel() - ) - ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) - - for i in range(batch_size): - topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) - - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - topk_hyp_indexes = (topk_indexes // vocab_size).tolist() - topk_token_indexes = (topk_indexes % vocab_size).tolist() - - for k in range(len(topk_hyp_indexes)): - hyp_idx = topk_hyp_indexes[k] - hyp = A[i][hyp_idx] - - new_ys = hyp.ys[:] - new_token = topk_token_indexes[k] - new_timestamp = hyp.timestamp[:] - if new_token not in (blank_id, unk_id): - new_ys.append(new_token) - new_timestamp.append(t) - - new_log_prob = topk_log_probs[k] - new_hyp = Hypothesis( - ys=new_ys, log_prob=new_log_prob, timestamp=new_timestamp - ) - B[i].add(new_hyp) - - B = B + finalized_B - - # get the am_scores for n-best list - hyps_shape = get_hyps_shape(B) - am_scores = torch.tensor([hyp.log_prob.item() for b in B for hyp in b]) - am_scores = k2.RaggedTensor(value=am_scores, shape=hyps_shape).to(device) - - # now LM rescore - # prepare input data to LM - candidate_seqs = [hyp.ys[context_size:] for b in B for hyp in b] - possible_seqs = k2.RaggedTensor(candidate_seqs) - row_splits = possible_seqs.shape.row_splits(1) - sentence_token_lengths = row_splits[1:] - row_splits[:-1] - possible_seqs_with_sos = add_sos(possible_seqs, sos_id=1) - possible_seqs_with_eos = add_eos(possible_seqs, eos_id=1) - sentence_token_lengths += 1 - - x = possible_seqs_with_sos.pad(mode="constant", padding_value=blank_id) - y = possible_seqs_with_eos.pad(mode="constant", padding_value=blank_id) - x = x.to(device).to(torch.int64) - y = y.to(device).to(torch.int64) - sentence_token_lengths = sentence_token_lengths.to(device).to(torch.int64) - - lm_scores = LM.lm(x=x, y=y, lengths=sentence_token_lengths) - assert lm_scores.ndim == 2 - lm_scores = -1 * lm_scores.sum(dim=1) - - # now LODR scores - import math - - LODR_scores = [] - for seq in candidate_seqs: - tokens = " ".join(sp.id_to_piece(seq)) - LODR_scores.append(LODR_lm.score(tokens)) - LODR_scores = torch.tensor(LODR_scores).to(device) * math.log( - 10 - ) # arpa scores are 10-based - assert lm_scores.shape == LODR_scores.shape - - ans = {} - unsorted_indices = packed_encoder_out.unsorted_indices.tolist() - - LODR_scale_list = [0.05 * i for i in range(1, 20)] - # get the best hyp with different lm_scale and lodr_scale - for lm_scale in lm_scale_list: - for lodr_scale in LODR_scale_list: - key = f"nnlm_scale_{lm_scale:.2f}_lodr_scale_{lodr_scale:.2f}" - tot_scores = ( - am_scores.values / lm_scale + lm_scores - LODR_scores * lodr_scale - ) - ragged_tot_scores = k2.RaggedTensor(shape=am_scores.shape, value=tot_scores) - max_indexes = ragged_tot_scores.argmax().tolist() - unsorted_hyps = [candidate_seqs[idx] for idx in max_indexes] - hyps = [] - for idx in unsorted_indices: - hyps.append(unsorted_hyps[idx]) - - ans[key] = hyps - return ans - - -def _deprecated_modified_beam_search( - model: nn.Module, - encoder_out: torch.Tensor, - beam: int = 4, - return_timestamps: bool = False, -) -> Union[List[int], DecodingResults]: - """It limits the maximum number of symbols per frame to 1. - - It decodes only one utterance at a time. We keep it only for reference. - The function :func:`modified_beam_search` should be preferred as it - supports batch decoding. - - - Args: - model: - An instance of `Transducer`. - encoder_out: - A tensor of shape (N, T, C) from the encoder. Support only N==1 for now. - beam: - Beam size. - return_timestamps: - Whether to return timestamps. - - Returns: - If return_timestamps is False, return the decoded result. - Else, return a DecodingResults object containing - decoded result and corresponding timestamps. - """ - - assert encoder_out.ndim == 3 - - # support only batch_size == 1 for now - assert encoder_out.size(0) == 1, encoder_out.size(0) - blank_id = model.decoder.blank_id - unk_id = getattr(model, "unk_id", blank_id) - context_size = model.decoder.context_size - - device = next(model.parameters()).device + device = model.device T = encoder_out.size(1) B = HypothesisList() B.add( Hypothesis( - ys=[-1] * (context_size - 1) + [blank_id], + ys=[blank_id] * context_size, log_prob=torch.zeros(1, dtype=torch.float32, device=device), - timestamp=[], ) ) - encoder_out = model.joiner.encoder_proj(encoder_out) + + encoder_out_len = torch.tensor([1]) + decoder_out_len = torch.tensor([1]) for t in range(T): # fmt: off - current_encoder_out = encoder_out[:, t:t+1, :].unsqueeze(2) - # current_encoder_out is of shape (1, 1, 1, encoder_out_dim) + current_encoder_out = encoder_out[:, t:t+1, :] + # current_encoder_out is of shape (1, 1, encoder_out_dim) # fmt: on A = list(B) B = HypothesisList() @@ -1638,27 +875,21 @@ def _deprecated_modified_beam_search( decoder_input = torch.tensor( [hyp.ys[-context_size:] for hyp in A], device=device, - dtype=torch.int64, ) # decoder_input is of shape (num_hyps, context_size) - decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) - decoder_out = model.joiner.decoder_proj(decoder_out) - # decoder_output is of shape (num_hyps, 1, 1, joiner_dim) + decoder_out = model.decoder(decoder_input, need_pad=False) + # decoder_output is of shape (num_hyps, 1, decoder_output_dim) - current_encoder_out = current_encoder_out.expand( - decoder_out.size(0), 1, 1, -1 - ) # (num_hyps, 1, 1, encoder_out_dim) + current_encoder_out = current_encoder_out.expand(decoder_out.size(0), 1, -1) logits = model.joiner( current_encoder_out, decoder_out, - project_input=False, + encoder_out_len.expand(decoder_out.size(0)), + decoder_out_len.expand(decoder_out.size(0)), ) - # logits is of shape (num_hyps, 1, 1, vocab_size) - logits = logits.squeeze(1).squeeze(1) - - # now logits is of shape (num_hyps, vocab_size) + # logits is of shape (num_hyps, vocab_size) log_probs = logits.log_softmax(dim=-1) log_probs.add_(ys_log_probs) @@ -1678,34 +909,24 @@ def _deprecated_modified_beam_search( for i in range(len(topk_hyp_indexes)): hyp = A[topk_hyp_indexes[i]] new_ys = hyp.ys[:] - new_timestamp = hyp.timestamp[:] new_token = topk_token_indexes[i] - if new_token not in (blank_id, unk_id): + if new_token != blank_id: new_ys.append(new_token) - new_timestamp.append(t) new_log_prob = topk_log_probs[i] - new_hyp = Hypothesis( - ys=new_ys, log_prob=new_log_prob, timestamp=new_timestamp - ) + new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob) B.add(new_hyp) best_hyp = B.get_most_probable(length_norm=True) ys = best_hyp.ys[context_size:] # [context_size:] to remove blanks - if not return_timestamps: - return ys - else: - return DecodingResults(hyps=[ys], timestamps=[best_hyp.timestamp]) + return ys def beam_search( - model: nn.Module, + model: Transducer, encoder_out: torch.Tensor, beam: int = 4, - temperature: float = 1.0, - blank_penalty: float = 0.0, - return_timestamps: bool = False, -) -> Union[List[int], DecodingResults]: +) -> List[int]: """ It implements Algorithm 1 in https://arxiv.org/pdf/1211.3711.pdf @@ -1718,36 +939,23 @@ def beam_search( A tensor of shape (N, T, C) from the encoder. Support only N==1 for now. beam: Beam size. - temperature: - Softmax temperature. - return_timestamps: - Whether to return timestamps. - Returns: - If return_timestamps is False, return the decoded result. - Else, return a DecodingResults object containing - decoded result and corresponding timestamps. + Return the decoded result. """ assert encoder_out.ndim == 3 # support only batch_size == 1 for now assert encoder_out.size(0) == 1, encoder_out.size(0) blank_id = model.decoder.blank_id - unk_id = getattr(model, "unk_id", blank_id) context_size = model.decoder.context_size - device = next(model.parameters()).device + device = model.device - decoder_input = torch.tensor( - [blank_id] * context_size, - device=device, - dtype=torch.int64, - ).reshape(1, context_size) + decoder_input = torch.tensor([blank_id] * context_size, device=device).reshape( + 1, context_size + ) decoder_out = model.decoder(decoder_input, need_pad=False) - decoder_out = model.joiner.decoder_proj(decoder_out) - - encoder_out = model.joiner.encoder_proj(encoder_out) T = encoder_out.size(1) t = 0 @@ -1755,7 +963,8 @@ def beam_search( B = HypothesisList() B.add( Hypothesis( - ys=[-1] * (context_size - 1) + [blank_id], log_prob=0.0, timestamp=[] + ys=[blank_id] * context_size, + log_prob=torch.zeros(1, dtype=torch.float32, device=device), ) ) @@ -1763,87 +972,58 @@ def beam_search( sym_per_utt = 0 + encoder_out_len = torch.tensor([1]) + decoder_out_len = torch.tensor([1]) + decoder_cache: Dict[str, torch.Tensor] = {} while t < T and sym_per_utt < max_sym_per_utt: # fmt: off - current_encoder_out = encoder_out[:, t:t+1, :].unsqueeze(2) + current_encoder_out = encoder_out[:, t:t+1, :] # fmt: on A = B B = HypothesisList() joint_cache: Dict[str, torch.Tensor] = {} - # TODO(fangjun): Implement prefix search to update the `log_prob` - # of hypotheses in A - while True: y_star = A.get_most_probable() A.remove(y_star) - cached_key = y_star.key - - if cached_key not in decoder_cache: - decoder_input = torch.tensor( - [y_star.ys[-context_size:]], - device=device, - dtype=torch.int64, - ).reshape(1, context_size) - - decoder_out = model.decoder(decoder_input, need_pad=False) - decoder_out = model.joiner.decoder_proj(decoder_out) - decoder_cache[cached_key] = decoder_out - else: - decoder_out = decoder_cache[cached_key] - - cached_key += f"-t-{t}" - if cached_key not in joint_cache: - logits = model.joiner( - current_encoder_out, - decoder_out.unsqueeze(1), - project_input=False, - ) - - if blank_penalty != 0: - logits[:, :, :, 0] -= blank_penalty - - # TODO(fangjun): Scale the blank posterior - log_prob = (logits / temperature).log_softmax(dim=-1) - # log_prob is (1, 1, 1, vocab_size) - log_prob = log_prob.squeeze() - # Now log_prob is (vocab_size,) - joint_cache[cached_key] = log_prob - else: - log_prob = joint_cache[cached_key] + decoder_out = run_decoder( + ys=y_star.ys, model=model, decoder_cache=decoder_cache + ) + + key = "_".join(map(str, y_star.ys[-context_size:])) + key += f"-t-{t}" + log_prob = run_joiner( + key=key, + model=model, + encoder_out=current_encoder_out, + decoder_out=decoder_out, + encoder_out_len=encoder_out_len, + decoder_out_len=decoder_out_len, + joint_cache=joint_cache, + ) # First, process the blank symbol skip_log_prob = log_prob[blank_id] new_y_star_log_prob = y_star.log_prob + skip_log_prob # ys[:] returns a copy of ys - B.add( - Hypothesis( - ys=y_star.ys[:], - log_prob=new_y_star_log_prob, - timestamp=y_star.timestamp[:], - ) - ) + B.add(Hypothesis(ys=y_star.ys[:], log_prob=new_y_star_log_prob)) # Second, process other non-blank labels values, indices = log_prob.topk(beam + 1) - for i, v in zip(indices.tolist(), values.tolist()): - if i in (blank_id, unk_id): + for idx in range(values.size(0)): + i = indices[idx].item() + if i == blank_id: continue + new_ys = y_star.ys + [i] - new_log_prob = y_star.log_prob + v - new_timestamp = y_star.timestamp + [t] - A.add( - Hypothesis( - ys=new_ys, - log_prob=new_log_prob, - timestamp=new_timestamp, - ) - ) + + new_log_prob = y_star.log_prob + values[idx] + A.add(Hypothesis(ys=new_ys, log_prob=new_log_prob)) # Check whether B contains more than "beam" elements more probable # than the most probable in A @@ -1859,1084 +1039,4 @@ def beam_search( best_hyp = B.get_most_probable(length_norm=True) ys = best_hyp.ys[context_size:] # [context_size:] to remove blanks - - if not return_timestamps: - return ys - else: - return DecodingResults(hyps=[ys], timestamps=[best_hyp.timestamp]) - - -def fast_beam_search_with_nbest_rescoring( - model: nn.Module, - decoding_graph: k2.Fsa, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, - beam: float, - max_states: int, - max_contexts: int, - ngram_lm_scale_list: List[float], - num_paths: int, - G: k2.Fsa, - sp: spm.SentencePieceProcessor, - word_table: k2.SymbolTable, - oov_word: str = "", - use_double_scores: bool = True, - nbest_scale: float = 0.5, - temperature: float = 1.0, - return_timestamps: bool = False, -) -> Dict[str, Union[List[List[int]], DecodingResults]]: - """It limits the maximum number of symbols per frame to 1. - A lattice is first obtained using fast beam search, num_path are selected - and rescored using a given language model. The shortest path within the - lattice is used as the final output. - - Args: - model: - An instance of `Transducer`. - decoding_graph: - Decoding graph used for decoding, may be a TrivialGraph or a LG. - encoder_out: - A tensor of shape (N, T, C) from the encoder. - encoder_out_lens: - A tensor of shape (N,) containing the number of frames in `encoder_out` - before padding. - beam: - Beam value, similar to the beam used in Kaldi. - max_states: - Max states per stream per frame. - max_contexts: - Max contexts pre stream per frame. - ngram_lm_scale_list: - A list of floats representing LM score scales. - num_paths: - Number of paths to extract from the decoded lattice. - G: - An FsaVec containing only a single FSA. It is an n-gram LM. - sp: - The BPE model. - word_table: - The word symbol table. - oov_word: - OOV words are replaced with this word. - use_double_scores: - True to use double precision for computation. False to use - single precision. - nbest_scale: - It's the scale applied to the lattice.scores. A smaller value - yields more unique paths. - temperature: - Softmax temperature. - return_timestamps: - Whether to return timestamps. - Returns: - Return the decoded result in a dict, where the key has the form - 'ngram_lm_scale_xx' and the value is the decoded results - optionally with timestamps. `xx` is the ngram LM scale value - used during decoding, i.e., 0.1. - """ - lattice = fast_beam_search( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=beam, - max_states=max_states, - max_contexts=max_contexts, - temperature=temperature, - ) - - nbest = Nbest.from_lattice( - lattice=lattice, - num_paths=num_paths, - use_double_scores=use_double_scores, - nbest_scale=nbest_scale, - ) - # at this point, nbest.fsa.scores are all zeros. - - nbest = nbest.intersect(lattice) - # Now nbest.fsa.scores contains acoustic scores - - am_scores = nbest.tot_scores() - - # Now we need to compute the LM scores of each path. - # (1) Get the token IDs of each Path. We assume the decoding_graph - # is an acceptor, i.e., lattice is also an acceptor - tokens_shape = nbest.fsa.arcs.shape().remove_axis(1) # [path][arc] - - tokens = k2.RaggedTensor(tokens_shape, nbest.fsa.labels.contiguous()) - tokens = tokens.remove_values_leq(0) # remove -1 and 0 - - token_list: List[List[int]] = tokens.tolist() - word_list: List[List[str]] = sp.decode(token_list) - - assert isinstance(oov_word, str), oov_word - assert oov_word in word_table, oov_word - oov_word_id = word_table[oov_word] - - word_ids_list: List[List[int]] = [] - - for words in word_list: - this_word_ids = [] - for w in words.split(): - if w in word_table: - this_word_ids.append(word_table[w]) - else: - this_word_ids.append(oov_word_id) - word_ids_list.append(this_word_ids) - - word_fsas = k2.linear_fsa(word_ids_list, device=lattice.device) - word_fsas_with_self_loops = k2.add_epsilon_self_loops(word_fsas) - - num_unique_paths = len(word_ids_list) - - b_to_a_map = torch.zeros( - num_unique_paths, - dtype=torch.int32, - device=lattice.device, - ) - - rescored_word_fsas = k2.intersect_device( - a_fsas=G, - b_fsas=word_fsas_with_self_loops, - b_to_a_map=b_to_a_map, - sorted_match_a=True, - ret_arc_maps=False, - ) - - rescored_word_fsas = k2.remove_epsilon_self_loops(rescored_word_fsas) - rescored_word_fsas = k2.top_sort(k2.connect(rescored_word_fsas)) - ngram_lm_scores = rescored_word_fsas.get_tot_scores( - use_double_scores=True, - log_semiring=False, - ) - - ans: Dict[str, Union[List[List[int]], DecodingResults]] = {} - for s in ngram_lm_scale_list: - key = f"ngram_lm_scale_{s}" - tot_scores = am_scores.values + s * ngram_lm_scores - ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores) - max_indexes = ragged_tot_scores.argmax() - best_path = k2.index_fsa(nbest.fsa, max_indexes) - - if not return_timestamps: - ans[key] = get_texts(best_path) - else: - ans[key] = get_texts_with_timestamp(best_path) - - return ans - - -def fast_beam_search_with_nbest_rnn_rescoring( - model: nn.Module, - decoding_graph: k2.Fsa, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, - beam: float, - max_states: int, - max_contexts: int, - ngram_lm_scale_list: List[float], - num_paths: int, - G: k2.Fsa, - sp: spm.SentencePieceProcessor, - word_table: k2.SymbolTable, - rnn_lm_model: torch.nn.Module, - rnn_lm_scale_list: List[float], - oov_word: str = "", - use_double_scores: bool = True, - nbest_scale: float = 0.5, - temperature: float = 1.0, - return_timestamps: bool = False, -) -> Dict[str, Union[List[List[int]], DecodingResults]]: - """It limits the maximum number of symbols per frame to 1. - A lattice is first obtained using fast beam search, num_path are selected - and rescored using a given language model and a rnn-lm. - The shortest path within the lattice is used as the final output. - - Args: - model: - An instance of `Transducer`. - decoding_graph: - Decoding graph used for decoding, may be a TrivialGraph or a LG. - encoder_out: - A tensor of shape (N, T, C) from the encoder. - encoder_out_lens: - A tensor of shape (N,) containing the number of frames in `encoder_out` - before padding. - beam: - Beam value, similar to the beam used in Kaldi. - max_states: - Max states per stream per frame. - max_contexts: - Max contexts pre stream per frame. - ngram_lm_scale_list: - A list of floats representing LM score scales. - num_paths: - Number of paths to extract from the decoded lattice. - G: - An FsaVec containing only a single FSA. It is an n-gram LM. - sp: - The BPE model. - word_table: - The word symbol table. - rnn_lm_model: - A rnn-lm model used for LM rescoring - rnn_lm_scale_list: - A list of floats representing RNN score scales. - oov_word: - OOV words are replaced with this word. - use_double_scores: - True to use double precision for computation. False to use - single precision. - nbest_scale: - It's the scale applied to the lattice.scores. A smaller value - yields more unique paths. - temperature: - Softmax temperature. - return_timestamps: - Whether to return timestamps. - Returns: - Return the decoded result in a dict, where the key has the form - 'ngram_lm_scale_xx' and the value is the decoded results - optionally with timestamps. `xx` is the ngram LM scale value - used during decoding, i.e., 0.1. - """ - lattice = fast_beam_search( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=beam, - max_states=max_states, - max_contexts=max_contexts, - temperature=temperature, - ) - - nbest = Nbest.from_lattice( - lattice=lattice, - num_paths=num_paths, - use_double_scores=use_double_scores, - nbest_scale=nbest_scale, - ) - # at this point, nbest.fsa.scores are all zeros. - - nbest = nbest.intersect(lattice) - # Now nbest.fsa.scores contains acoustic scores - - am_scores = nbest.tot_scores() - - # Now we need to compute the LM scores of each path. - # (1) Get the token IDs of each Path. We assume the decoding_graph - # is an acceptor, i.e., lattice is also an acceptor - tokens_shape = nbest.fsa.arcs.shape().remove_axis(1) # [path][arc] - - tokens = k2.RaggedTensor(tokens_shape, nbest.fsa.labels.contiguous()) - tokens = tokens.remove_values_leq(0) # remove -1 and 0 - - token_list: List[List[int]] = tokens.tolist() - word_list: List[List[str]] = sp.decode(token_list) - - assert isinstance(oov_word, str), oov_word - assert oov_word in word_table, oov_word - oov_word_id = word_table[oov_word] - - word_ids_list: List[List[int]] = [] - - for words in word_list: - this_word_ids = [] - for w in words.split(): - if w in word_table: - this_word_ids.append(word_table[w]) - else: - this_word_ids.append(oov_word_id) - word_ids_list.append(this_word_ids) - - word_fsas = k2.linear_fsa(word_ids_list, device=lattice.device) - word_fsas_with_self_loops = k2.add_epsilon_self_loops(word_fsas) - - num_unique_paths = len(word_ids_list) - - b_to_a_map = torch.zeros( - num_unique_paths, - dtype=torch.int32, - device=lattice.device, - ) - - rescored_word_fsas = k2.intersect_device( - a_fsas=G, - b_fsas=word_fsas_with_self_loops, - b_to_a_map=b_to_a_map, - sorted_match_a=True, - ret_arc_maps=False, - ) - - rescored_word_fsas = k2.remove_epsilon_self_loops(rescored_word_fsas) - rescored_word_fsas = k2.top_sort(k2.connect(rescored_word_fsas)) - ngram_lm_scores = rescored_word_fsas.get_tot_scores( - use_double_scores=True, - log_semiring=False, - ) - - # Now RNN-LM - blank_id = model.decoder.blank_id - sos_id = sp.piece_to_id("sos_id") - eos_id = sp.piece_to_id("eos_id") - - sos_tokens = add_sos(tokens, sos_id) - tokens_eos = add_eos(tokens, eos_id) - sos_tokens_row_splits = sos_tokens.shape.row_splits(1) - sentence_lengths = sos_tokens_row_splits[1:] - sos_tokens_row_splits[:-1] - - x_tokens = sos_tokens.pad(mode="constant", padding_value=blank_id) - y_tokens = tokens_eos.pad(mode="constant", padding_value=blank_id) - - x_tokens = x_tokens.to(torch.int64) - y_tokens = y_tokens.to(torch.int64) - sentence_lengths = sentence_lengths.to(torch.int64) - - rnn_lm_nll = rnn_lm_model(x=x_tokens, y=y_tokens, lengths=sentence_lengths) - assert rnn_lm_nll.ndim == 2 - assert rnn_lm_nll.shape[0] == len(token_list) - rnn_lm_scores = -1 * rnn_lm_nll.sum(dim=1) - - ans: Dict[str, List[List[int]]] = {} - for n_scale in ngram_lm_scale_list: - for rnn_scale in rnn_lm_scale_list: - key = f"ngram_lm_scale_{n_scale}_rnn_lm_scale_{rnn_scale}" - tot_scores = ( - am_scores.values + n_scale * ngram_lm_scores + rnn_scale * rnn_lm_scores - ) - ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores) - max_indexes = ragged_tot_scores.argmax() - best_path = k2.index_fsa(nbest.fsa, max_indexes) - - if not return_timestamps: - ans[key] = get_texts(best_path) - else: - ans[key] = get_texts_with_timestamp(best_path) - - return ans - - -def modified_beam_search_ngram_rescoring( - model: nn.Module, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, - ngram_lm: NgramLm, - ngram_lm_scale: float, - beam: int = 4, - temperature: float = 1.0, -) -> List[List[int]]: - """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded. - - Args: - model: - The transducer model. - encoder_out: - Output from the encoder. Its shape is (N, T, C). - encoder_out_lens: - A 1-D tensor of shape (N,), containing number of valid frames in - encoder_out before padding. - beam: - Number of active paths during the beam search. - temperature: - Softmax temperature. - Returns: - Return a list-of-list of token IDs. ans[i] is the decoding results - for the i-th utterance. - """ - assert encoder_out.ndim == 3, encoder_out.shape - assert encoder_out.size(0) >= 1, encoder_out.size(0) - - packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( - input=encoder_out, - lengths=encoder_out_lens.cpu(), - batch_first=True, - enforce_sorted=False, - ) - - blank_id = model.decoder.blank_id - unk_id = getattr(model, "unk_id", blank_id) - context_size = model.decoder.context_size - device = next(model.parameters()).device - lm_scale = ngram_lm_scale - - batch_size_list = packed_encoder_out.batch_sizes.tolist() - N = encoder_out.size(0) - assert torch.all(encoder_out_lens > 0), encoder_out_lens - assert N == batch_size_list[0], (N, batch_size_list) - - B = [HypothesisList() for _ in range(N)] - for i in range(N): - B[i].add( - Hypothesis( - ys=[-1] * (context_size - 1) + [blank_id], - log_prob=torch.zeros(1, dtype=torch.float32, device=device), - state_cost=NgramLmStateCost(ngram_lm), - ) - ) - - encoder_out = model.joiner.encoder_proj(packed_encoder_out.data) - - offset = 0 - finalized_B = [] - for batch_size in batch_size_list: - start = offset - end = offset + batch_size - current_encoder_out = encoder_out.data[start:end] - current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1) - # current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim) - offset = end - - finalized_B = B[batch_size:] + finalized_B - B = B[:batch_size] - - hyps_shape = get_hyps_shape(B).to(device) - - A = [list(b) for b in B] - B = [HypothesisList() for _ in range(batch_size)] - - ys_log_probs = torch.cat( - [ - hyp.log_prob.reshape(1, 1) + hyp.state_cost.lm_score * lm_scale - for hyps in A - for hyp in hyps - ] - ) # (num_hyps, 1) - - decoder_input = torch.tensor( - [hyp.ys[-context_size:] for hyps in A for hyp in hyps], - device=device, - dtype=torch.int64, - ) # (num_hyps, context_size) - - decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) - decoder_out = model.joiner.decoder_proj(decoder_out) - # decoder_out is of shape (num_hyps, 1, 1, joiner_dim) - - # Note: For torch 1.7.1 and below, it requires a torch.int64 tensor - # as index, so we use `to(torch.int64)` below. - current_encoder_out = torch.index_select( - current_encoder_out, - dim=0, - index=hyps_shape.row_ids(1).to(torch.int64), - ) # (num_hyps, 1, 1, encoder_out_dim) - - logits = model.joiner( - current_encoder_out, - decoder_out, - project_input=False, - ) # (num_hyps, 1, 1, vocab_size) - - logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size) - - log_probs = (logits / temperature).log_softmax(dim=-1) # (num_hyps, vocab_size) - - log_probs.add_(ys_log_probs) - vocab_size = log_probs.size(-1) - log_probs = log_probs.reshape(-1) - - row_splits = hyps_shape.row_splits(1) * vocab_size - log_probs_shape = k2.ragged.create_ragged_shape2( - row_splits=row_splits, cached_tot_size=log_probs.numel() - ) - ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) - - for i in range(batch_size): - topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) - - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - topk_hyp_indexes = (topk_indexes // vocab_size).tolist() - topk_token_indexes = (topk_indexes % vocab_size).tolist() - - for k in range(len(topk_hyp_indexes)): - hyp_idx = topk_hyp_indexes[k] - hyp = A[i][hyp_idx] - - new_ys = hyp.ys[:] - new_token = topk_token_indexes[k] - if new_token not in (blank_id, unk_id): - new_ys.append(new_token) - state_cost = hyp.state_cost.forward_one_step(new_token) - else: - state_cost = hyp.state_cost - - # We only keep AM scores in new_hyp.log_prob - new_log_prob = topk_log_probs[k] - hyp.state_cost.lm_score * lm_scale - - new_hyp = Hypothesis( - ys=new_ys, log_prob=new_log_prob, state_cost=state_cost - ) - B[i].add(new_hyp) - - B = B + finalized_B - best_hyps = [b.get_most_probable(length_norm=True) for b in B] - - sorted_ans = [h.ys[context_size:] for h in best_hyps] - ans = [] - unsorted_indices = packed_encoder_out.unsorted_indices.tolist() - for i in range(N): - ans.append(sorted_ans[unsorted_indices[i]]) - - return ans - - -def modified_beam_search_LODR( - model: nn.Module, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, - LODR_lm: NgramLm, - LODR_lm_scale: float, - LM: LmScorer, - beam: int = 4, - context_graph: Optional[ContextGraph] = None, -) -> List[List[int]]: - """This function implements LODR (https://arxiv.org/abs/2203.16776) with - `modified_beam_search`. It uses a bi-gram language model as the estimate - of the internal language model and subtracts its score during shallow fusion - with an external language model. This implementation uses a RNNLM as the - external language model. - - Args: - model (Transducer): - The transducer model - encoder_out (torch.Tensor): - Encoder output in (N,T,C) - encoder_out_lens (torch.Tensor): - A 1-D tensor of shape (N,), containing the number of - valid frames in encoder_out before padding. - LODR_lm: - A low order n-gram LM, whose score will be subtracted during shallow fusion - LODR_lm_scale: - The scale of the LODR_lm - LM: - A neural net LM, e.g an RNNLM or transformer LM - beam (int, optional): - Beam size. Defaults to 4. - - Returns: - Return a list-of-list of token IDs. ans[i] is the decoding results - for the i-th utterance. - - """ - assert encoder_out.ndim == 3, encoder_out.shape - assert encoder_out.size(0) >= 1, encoder_out.size(0) - assert LM is not None - lm_scale = LM.lm_scale - - packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( - input=encoder_out, - lengths=encoder_out_lens.cpu(), - batch_first=True, - enforce_sorted=False, - ) - - blank_id = model.decoder.blank_id - sos_id = getattr(LM, "sos_id", 1) - unk_id = getattr(model, "unk_id", blank_id) - context_size = model.decoder.context_size - device = next(model.parameters()).device - - batch_size_list = packed_encoder_out.batch_sizes.tolist() - N = encoder_out.size(0) - assert torch.all(encoder_out_lens > 0), encoder_out_lens - assert N == batch_size_list[0], (N, batch_size_list) - - # get initial lm score and lm state by scoring the "sos" token - sos_token = torch.tensor([[sos_id]]).to(torch.int64).to(device) - lens = torch.tensor([1]).to(device) - init_score, init_states = LM.score_token(sos_token, lens) - - B = [HypothesisList() for _ in range(N)] - for i in range(N): - B[i].add( - Hypothesis( - ys=[-1] * (context_size - 1) + [blank_id], - log_prob=torch.zeros(1, dtype=torch.float32, device=device), - state=init_states, # state of the NN LM - lm_score=init_score.reshape(-1), - state_cost=NgramLmStateCost( - LODR_lm - ), # state of the source domain ngram - context_state=None if context_graph is None else context_graph.root, - ) - ) - - encoder_out = model.joiner.encoder_proj(packed_encoder_out.data) - - offset = 0 - finalized_B = [] - for batch_size in batch_size_list: - start = offset - end = offset + batch_size - current_encoder_out = encoder_out.data[start:end] # get batch - current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1) - # current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim) - offset = end - - finalized_B = B[batch_size:] + finalized_B - B = B[:batch_size] - - hyps_shape = get_hyps_shape(B).to(device) - - A = [list(b) for b in B] - B = [HypothesisList() for _ in range(batch_size)] - - ys_log_probs = torch.cat( - [hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps] - ) - - decoder_input = torch.tensor( - [hyp.ys[-context_size:] for hyps in A for hyp in hyps], - device=device, - dtype=torch.int64, - ) # (num_hyps, context_size) - - decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) - decoder_out = model.joiner.decoder_proj(decoder_out) - - current_encoder_out = torch.index_select( - current_encoder_out, - dim=0, - index=hyps_shape.row_ids(1).to(torch.int64), - ) # (num_hyps, 1, 1, encoder_out_dim) - - logits = model.joiner( - current_encoder_out, - decoder_out, - project_input=False, - ) # (num_hyps, 1, 1, vocab_size) - - logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size) - - log_probs = logits.log_softmax(dim=-1) # (num_hyps, vocab_size) - - log_probs.add_(ys_log_probs) - - vocab_size = log_probs.size(-1) - - log_probs = log_probs.reshape(-1) - - row_splits = hyps_shape.row_splits(1) * vocab_size - log_probs_shape = k2.ragged.create_ragged_shape2( - row_splits=row_splits, cached_tot_size=log_probs.numel() - ) - ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) - """ - for all hyps with a non-blank new token, score this token. - It is a little confusing here because this for-loop - looks very similar to the one below. Here, we go through all - top-k tokens and only add the non-blanks ones to the token_list. - LM will score those tokens given the LM states. Note that - the variable `scores` is the LM score after seeing the new - non-blank token. - """ - token_list = [] - hs = [] - cs = [] - for i in range(batch_size): - topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) - - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - topk_hyp_indexes = (topk_indexes // vocab_size).tolist() - topk_token_indexes = (topk_indexes % vocab_size).tolist() - for k in range(len(topk_hyp_indexes)): - hyp_idx = topk_hyp_indexes[k] - hyp = A[i][hyp_idx] - - new_token = topk_token_indexes[k] - if new_token not in (blank_id, unk_id): - if LM.lm_type == "rnn": - token_list.append([new_token]) - # store the LSTM states - hs.append(hyp.state[0]) - cs.append(hyp.state[1]) - else: - # for transformer LM - token_list.append( - [sos_id] + hyp.ys[context_size:] + [new_token] - ) - - # forward NN LM to get new states and scores - if len(token_list) != 0: - x_lens = torch.tensor([len(tokens) for tokens in token_list]).to(device) - if LM.lm_type == "rnn": - tokens_to_score = ( - torch.tensor(token_list).to(torch.int64).to(device).reshape(-1, 1) - ) - hs = torch.cat(hs, dim=1).to(device) - cs = torch.cat(cs, dim=1).to(device) - state = (hs, cs) - else: - # for transformer LM - tokens_list = [torch.tensor(tokens) for tokens in token_list] - tokens_to_score = ( - torch.nn.utils.rnn.pad_sequence( - tokens_list, batch_first=True, padding_value=0.0 - ) - .to(device) - .to(torch.int64) - ) - - state = None - - scores, lm_states = LM.score_token(tokens_to_score, x_lens, state) - - count = 0 # index, used to locate score and lm states - for i in range(batch_size): - topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) - - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - topk_hyp_indexes = (topk_indexes // vocab_size).tolist() - topk_token_indexes = (topk_indexes % vocab_size).tolist() - - for k in range(len(topk_hyp_indexes)): - hyp_idx = topk_hyp_indexes[k] - hyp = A[i][hyp_idx] - - ys = hyp.ys[:] - - # current score of hyp - lm_score = hyp.lm_score - state = hyp.state - - hyp_log_prob = topk_log_probs[k] # get score of current hyp - new_token = topk_token_indexes[k] - - context_score = 0 - new_context_state = None if context_graph is None else hyp.context_state - if new_token not in (blank_id, unk_id): - if context_graph is not None: - ( - context_score, - new_context_state, - ) = context_graph.forward_one_step(hyp.context_state, new_token) - - ys.append(new_token) - state_cost = hyp.state_cost.forward_one_step(new_token) - - # calculate the score of the latest token - current_ngram_score = state_cost.lm_score - hyp.state_cost.lm_score - - assert current_ngram_score <= 0.0, ( - state_cost.lm_score, - hyp.state_cost.lm_score, - ) - # score = score + TDLM_score - LODR_score - # LODR_LM_scale should be a negative number here - hyp_log_prob += ( - lm_score[new_token] * lm_scale - + LODR_lm_scale * current_ngram_score - + context_score - ) # add the lm score - - lm_score = scores[count] - if LM.lm_type == "rnn": - state = ( - lm_states[0][:, count, :].unsqueeze(1), - lm_states[1][:, count, :].unsqueeze(1), - ) - count += 1 - else: - state_cost = hyp.state_cost - - new_hyp = Hypothesis( - ys=ys, - log_prob=hyp_log_prob, - state=state, - lm_score=lm_score, - state_cost=state_cost, - context_state=new_context_state, - ) - B[i].add(new_hyp) - - B = B + finalized_B - - # finalize context_state, if the matched contexts do not reach final state - # we need to add the score on the corresponding backoff arc - if context_graph is not None: - finalized_B = [HypothesisList() for _ in range(len(B))] - for i, hyps in enumerate(B): - for hyp in list(hyps): - context_score, new_context_state = context_graph.finalize( - hyp.context_state - ) - finalized_B[i].add( - Hypothesis( - ys=hyp.ys, - log_prob=hyp.log_prob + context_score, - timestamp=hyp.timestamp, - context_state=new_context_state, - ) - ) - B = finalized_B - - best_hyps = [b.get_most_probable(length_norm=True) for b in B] - - sorted_ans = [h.ys[context_size:] for h in best_hyps] - ans = [] - unsorted_indices = packed_encoder_out.unsorted_indices.tolist() - for i in range(N): - ans.append(sorted_ans[unsorted_indices[i]]) - - return ans - - -def modified_beam_search_lm_shallow_fusion( - model: nn.Module, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, - LM: LmScorer, - beam: int = 4, - return_timestamps: bool = False, -) -> List[List[int]]: - """Modified_beam_search + NN LM shallow fusion - - Args: - model (Transducer): - The transducer model - encoder_out (torch.Tensor): - Encoder output in (N,T,C) - encoder_out_lens (torch.Tensor): - A 1-D tensor of shape (N,), containing the number of - valid frames in encoder_out before padding. - sp: - Sentence piece generator. - LM (LmScorer): - A neural net LM, e.g RNN or Transformer - beam (int, optional): - Beam size. Defaults to 4. - - Returns: - Return a list-of-list of token IDs. ans[i] is the decoding results - for the i-th utterance. - """ - assert encoder_out.ndim == 3, encoder_out.shape - assert encoder_out.size(0) >= 1, encoder_out.size(0) - assert LM is not None - lm_scale = LM.lm_scale - - packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( - input=encoder_out, - lengths=encoder_out_lens.cpu(), - batch_first=True, - enforce_sorted=False, - ) - - blank_id = model.decoder.blank_id - sos_id = getattr(LM, "sos_id", 1) - unk_id = getattr(model, "unk_id", blank_id) - context_size = model.decoder.context_size - device = next(model.parameters()).device - - batch_size_list = packed_encoder_out.batch_sizes.tolist() - N = encoder_out.size(0) - assert torch.all(encoder_out_lens > 0), encoder_out_lens - assert N == batch_size_list[0], (N, batch_size_list) - - # get initial lm score and lm state by scoring the "sos" token - sos_token = torch.tensor([[sos_id]]).to(torch.int64).to(device) - lens = torch.tensor([1]).to(device) - init_score, init_states = LM.score_token(sos_token, lens) - - B = [HypothesisList() for _ in range(N)] - for i in range(N): - B[i].add( - Hypothesis( - ys=[-1] * (context_size - 1) + [blank_id], - log_prob=torch.zeros(1, dtype=torch.float32, device=device), - state=init_states, - lm_score=init_score.reshape(-1), - timestamp=[], - ) - ) - - encoder_out = model.joiner.encoder_proj(packed_encoder_out.data) - - offset = 0 - finalized_B = [] - for t, batch_size in enumerate(batch_size_list): - start = offset - end = offset + batch_size - current_encoder_out = encoder_out.data[start:end] # get batch - current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1) - # current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim) - offset = end - - finalized_B = B[batch_size:] + finalized_B - B = B[:batch_size] - - hyps_shape = get_hyps_shape(B).to(device) - - A = [list(b) for b in B] - B = [HypothesisList() for _ in range(batch_size)] - - ys_log_probs = torch.cat( - [hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps] - ) - - lm_scores = torch.cat( - [hyp.lm_score.reshape(1, -1) for hyps in A for hyp in hyps] - ) - - decoder_input = torch.tensor( - [hyp.ys[-context_size:] for hyps in A for hyp in hyps], - device=device, - dtype=torch.int64, - ) # (num_hyps, context_size) - - decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) - decoder_out = model.joiner.decoder_proj(decoder_out) - - current_encoder_out = torch.index_select( - current_encoder_out, - dim=0, - index=hyps_shape.row_ids(1).to(torch.int64), - ) # (num_hyps, 1, 1, encoder_out_dim) - - logits = model.joiner( - current_encoder_out, - decoder_out, - project_input=False, - ) # (num_hyps, 1, 1, vocab_size) - - logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size) - - log_probs = logits.log_softmax(dim=-1) # (num_hyps, vocab_size) - - log_probs.add_(ys_log_probs) - - vocab_size = log_probs.size(-1) - - log_probs = log_probs.reshape(-1) - - row_splits = hyps_shape.row_splits(1) * vocab_size - log_probs_shape = k2.ragged.create_ragged_shape2( - row_splits=row_splits, cached_tot_size=log_probs.numel() - ) - ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) - """ - for all hyps with a non-blank new token, score this token. - It is a little confusing here because this for-loop - looks very similar to the one below. Here, we go through all - top-k tokens and only add the non-blanks ones to the token_list. - `LM` will score those tokens given the LM states. Note that - the variable `scores` is the LM score after seeing the new - non-blank token. - """ - token_list = [] # a list of list - hs = [] - cs = [] - for i in range(batch_size): - topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) - - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - topk_hyp_indexes = (topk_indexes // vocab_size).tolist() - topk_token_indexes = (topk_indexes % vocab_size).tolist() - for k in range(len(topk_hyp_indexes)): - hyp_idx = topk_hyp_indexes[k] - hyp = A[i][hyp_idx] - - new_token = topk_token_indexes[k] - if new_token not in (blank_id, unk_id): - if LM.lm_type == "rnn": - token_list.append([new_token]) - # store the LSTM states - hs.append(hyp.state[0]) - cs.append(hyp.state[1]) - else: - # for transformer LM - token_list.append( - [sos_id] + hyp.ys[context_size:] + [new_token] - ) - - if len(token_list) != 0: - x_lens = torch.tensor([len(tokens) for tokens in token_list]).to(device) - if LM.lm_type == "rnn": - tokens_to_score = ( - torch.tensor(token_list).to(torch.int64).to(device).reshape(-1, 1) - ) - hs = torch.cat(hs, dim=1).to(device) - cs = torch.cat(cs, dim=1).to(device) - state = (hs, cs) - else: - # for transformer LM - tokens_list = [torch.tensor(tokens) for tokens in token_list] - tokens_to_score = ( - torch.nn.utils.rnn.pad_sequence( - tokens_list, batch_first=True, padding_value=0.0 - ) - .to(device) - .to(torch.int64) - ) - - state = None - - scores, lm_states = LM.score_token(tokens_to_score, x_lens, state) - - count = 0 # index, used to locate score and lm states - for i in range(batch_size): - topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) - - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - topk_hyp_indexes = (topk_indexes // vocab_size).tolist() - topk_token_indexes = (topk_indexes % vocab_size).tolist() - - for k in range(len(topk_hyp_indexes)): - hyp_idx = topk_hyp_indexes[k] - hyp = A[i][hyp_idx] - - ys = hyp.ys[:] - - lm_score = hyp.lm_score - state = hyp.state - - hyp_log_prob = topk_log_probs[k] # get score of current hyp - new_token = topk_token_indexes[k] - new_timestamp = hyp.timestamp[:] - if new_token not in (blank_id, unk_id): - ys.append(new_token) - new_timestamp.append(t) - - hyp_log_prob += lm_score[new_token] * lm_scale # add the lm score - - lm_score = scores[count] - if LM.lm_type == "rnn": - state = ( - lm_states[0][:, count, :].unsqueeze(1), - lm_states[1][:, count, :].unsqueeze(1), - ) - count += 1 - - new_hyp = Hypothesis( - ys=ys, - log_prob=hyp_log_prob, - state=state, - lm_score=lm_score, - timestamp=new_timestamp, - ) - B[i].add(new_hyp) - - B = B + finalized_B - best_hyps = [b.get_most_probable(length_norm=True) for b in B] - - sorted_ans = [h.ys[context_size:] for h in best_hyps] - sorted_timestamps = [h.timestamp for h in best_hyps] - ans = [] - ans_timestamps = [] - unsorted_indices = packed_encoder_out.unsorted_indices.tolist() - for i in range(N): - ans.append(sorted_ans[unsorted_indices[i]]) - ans_timestamps.append(sorted_timestamps[unsorted_indices[i]]) - - if not return_timestamps: - return ans - else: - return DecodingResults( - hyps=ans, - timestamps=ans_timestamps, - ) + return ys diff --git a/egs/librispeech/ASR/zipformer/decoder.py b/egs/librispeech/ASR/zipformer/decoder.py index 7ce44495bf..fb1c2a9066 100644 --- a/egs/librispeech/ASR/zipformer/decoder.py +++ b/egs/librispeech/ASR/zipformer/decoder.py @@ -70,6 +70,16 @@ def __init__( prob=0.05, ) + self.balancer2 = Balancer( + decoder_dim, + channel_dim=-1, + min_positive=0.0, + max_positive=1.0, + min_abs=0.5, + max_abs=1.0, + prob=0.05, + ) + self.blank_id = blank_id assert context_size >= 1, context_size @@ -85,26 +95,49 @@ def __init__( groups=decoder_dim // 4, # group size == 4 bias=False, ) - self.balancer2 = Balancer( - decoder_dim, - channel_dim=-1, - min_positive=0.0, - max_positive=1.0, - min_abs=0.5, - max_abs=1.0, - prob=0.05, - ) else: # To avoid `RuntimeError: Module 'Decoder' has no attribute 'conv'` # when inference with torch.jit.script and context_size == 1 self.conv = nn.Identity() - self.balancer2 = nn.Identity() - def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor: + self.repeat_param = nn.Parameter(torch.randn(decoder_dim)) + + def add_repeat_param( + self, + embedding_out: torch.Tensor, + k: torch.Tensor, + ) -> torch.Tensor: + """ + Add the repeat parameter to the embedding_out tensor. + + Args: + embedding_out: + A tensor of shape (N, U, decoder_dim). + k: + A tensor of shape (N, U). + Should be (N, S + 1) during training. + Should be (N, 1) during inference. + is_training: + Whether it is training. + Returns: + Return a tensor of shape (N, U, decoder_dim). + """ + return embedding_out + (k / (1 + k)).unsqueeze(2) * self.repeat_param + + def forward( + self, + y: torch.Tensor, + k: torch.Tensor, + need_pad: bool = True + ) -> torch.Tensor: """ Args: y: A 2-D tensor of shape (N, U). + k: + A 2-D tensor, statistic given the context_size with respect to utt. + Should be (N, S + 1) during training. + Should be (N, 1) during inference. need_pad: True to left pad the input. Should be True during training. False to not pad the input. Should be False during inference. @@ -128,7 +161,12 @@ def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor: assert embedding_out.size(-1) == self.context_size embedding_out = self.conv(embedding_out) embedding_out = embedding_out.permute(0, 2, 1) - embedding_out = F.relu(embedding_out) - embedding_out = self.balancer2(embedding_out) + + embedding_out = self.add_repeat_param( + embedding_out=embedding_out, + k=k, + ) + embedding_out = F.relu(embedding_out) + embedding_out = self.balancer2(embedding_out) return embedding_out diff --git a/egs/librispeech/ASR/zipformer/model.py b/egs/librispeech/ASR/zipformer/model.py index f2f86af47a..b945819c69 100644 --- a/egs/librispeech/ASR/zipformer/model.py +++ b/egs/librispeech/ASR/zipformer/model.py @@ -111,6 +111,32 @@ def __init__( nn.LogSoftmax(dim=-1), ) + def compute_k( + self, + y: torch.Tensor, + context_size: int = 2, + ) -> torch.Tensor: + """ + Args: + y: + A 2-D tensor of shape (N, U). + context_size: + Number of previous words to use to predict the next word. + 1 means bigram; 2 means trigram. n means (n+1)-gram. + Returns: + Return a tensor of shape (N, U). + """ + y_shift = torch.nn.functional.pad( + y, (context_size, 0), mode="constant", value=-1 + )[:, :-context_size] + mask = y_shift != y + + T_arange = torch.arange(y.size(1)).expand_as(y).to(device=y.device) + cummax_out = (T_arange * mask).cummax(dim=-1)[0] + k = T_arange - cummax_out + + return k + def forward_encoder( self, x: torch.Tensor, x_lens: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor]: @@ -207,8 +233,11 @@ def forward_transducer( # sos_y_padded: [B, S + 1], start with SOS. sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id) + # compute k + k = self.compute_k(sos_y_padded, context_size=self.decoder.context_size) + # decoder_out: [B, S + 1, decoder_dim] - decoder_out = self.decoder(sos_y_padded) + decoder_out = self.decoder(sos_y_padded, k) # Note: y does not start with SOS # y_padded : [B, S]