Add documents to ctc prefix beam search
pkufool committed Oct 8, 2024
1 parent 33fa9e8 commit 3a40c07
Showing 1 changed file with 144 additions and 18 deletions.
162 changes: 144 additions & 18 deletions icefall/decode.py
@@ -1513,10 +1513,12 @@ class Hypothesis:
# Newly predicted tokens are appended to `ys`.
ys: List[int] = field(default_factory=list)

# The log prob of ys.
# The log prob of ys that ends with a blank token.
# It contains only one entry.
log_prob_blank: torch.Tensor = torch.zeros(1, dtype=torch.float32)

# The log prob of ys that ends with a non-blank token.
# It contains only one entry.
log_prob_non_blank: torch.Tensor = torch.tensor(
[float("-inf")], dtype=torch.float32
)
@@ -1526,16 +1528,18 @@ class Hypothesis:
timestamp: List[int] = field(default_factory=list)

# The lm score of ys
# May contain external LM score (including LODR score) and contextual biasing score
# It contains only one entry
lm_score: torch.Tensor = torch.zeros(1, dtype=torch.float32)

# The LM log_probs for the next token given the history ys.
# The number of elements should be equal to vocabulary size.
lm_log_probs: Optional[torch.Tensor] = None

# the RNNLM states (h and c in LSTM)
state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None

# N-gram LM state
# LODR (N-gram LM) state
LODR_state: Optional[NgramLmStateCost] = None

# N-gram LM state
@@ -1544,10 +1548,12 @@ class Hypothesis:
# Context graph state
context_state: Optional[ContextState] = None

# This is the total score of the current path: acoustic score plus external LM score.
@property
def tot_score(self) -> torch.Tensor:
return self.log_prob + self.lm_score

# This is only the probability from the model output (i.e., the external LM score is not included).
@property
def log_prob(self) -> torch.Tensor:
return torch.logaddexp(self.log_prob_non_blank, self.log_prob_blank)
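
As a quick illustration of how these fields combine (a standalone sketch, not part of this commit; all values are made up): a prefix's `log_prob` is the log-sum-exp of its blank-ending and non-blank-ending scores, and `tot_score` adds the external LM score on top.

import torch

# Hypothetical scores for one prefix, all in log space.
log_prob_blank = torch.tensor([-2.3])      # prefix paths ending with blank
log_prob_non_blank = torch.tensor([-1.6])  # prefix paths ending with a non-blank token
lm_score = torch.tensor([-0.5])            # external LM / contextual biasing score

log_prob = torch.logaddexp(log_prob_non_blank, log_prob_blank)  # acoustic-only score
tot_score = log_prob + lm_score                                 # score used for ranking
print(log_prob.item(), tot_score.item())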
@@ -1614,14 +1620,14 @@ def add(self, hyp: Hypothesis) -> None:

def get_most_probable(self, length_norm: bool = False) -> Hypothesis:
"""Get the most probable hypothesis, i.e., the one with
the largest `log_prob`.
the largest `tot_score`.
Args:
length_norm:
If True, the `log_prob` of a hypothesis is normalized by the
If True, the `tot_score` of a hypothesis is normalized by the
number of tokens in it.
Returns:
Return the hypothesis that has the largest `log_prob`.
Return the hypothesis that has the largest `tot_score`.
"""
if length_norm:
return max(self._data.values(), key=lambda hyp: hyp.tot_score / len(hyp.ys))
@@ -1645,14 +1651,14 @@ def remove(self, hyp: Hypothesis) -> None:
del self._data[key]

def filter(self, threshold: torch.Tensor) -> "HypothesisList":
"""Remove all Hypotheses whose log_prob is less than threshold.
"""Remove all Hypotheses whose tot_score is less than threshold.
Caution:
`self` is not modified. Instead, a new HypothesisList is returned.
Returns:
Return a new HypothesisList containing all hypotheses from `self`
with `log_prob` being greater than the given `threshold`.
with `tot_score` being greater than the given `threshold`.
"""
ans = HypothesisList()
for _, hyp in self._data.items():
@@ -1665,7 +1671,7 @@ def topk(self, k: int, length_norm: bool = False) -> "HypothesisList":
Args:
length_norm:
If True, the `log_prob` of a hypothesis is normalized by the
If True, the `tot_score` of a hypothesis is normalized by the
number of tokens in it.
"""
hyps = list(self._data.items())
@@ -1725,15 +1731,39 @@ def get_hyps_shape(hyps: List[HypothesisList]) -> k2.RaggedShape:


def _step_worker(
log_probs,
indexes,
B,
beam,
blank_id,
log_probs: torch.Tensor,
indexes: torch.Tensor,
B: HypothesisList,
beam: int = 4,
blank_id: int = 0,
lm_scale: float = 0,
LODR_lm_scale: float = 0,
context_graph: Optional[ContextGraph] = None,
):
) -> HypothesisList:
"""The worker to decode one step.
Args:
log_probs:
The top-k log_probs of the current step (i.e. the scores of the tokens kept
by first-pass pruning); the shape is (beam,).
indexes:
The indexes of the top-k log_probs above; the shape is (beam,).
B:
An instance of HypothesisList containing the kept hypotheses.
beam:
The number of hypotheses to be kept at each step.
blank_id:
The id of the blank token in the vocabulary.
lm_scale:
The scale of the neural network LM.
LODR_lm_scale:
The scale of the LODR (low-order n-gram) LM.
context_graph:
A ContextGraph instance containing contextual phrases.
Return:
Returns the updated HypothesisList.
"""
A = list(B)
B = HypothesisList()
for h in range(len(A)):
Expand Down Expand Up @@ -1812,7 +1842,34 @@ def _step_worker(
return B
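
For reference, the collapsed body above boils down to the standard CTC prefix beam search recursion. Below is an illustrative, self-contained reimplementation of one step (not the committed code: it ignores the LM, LODR, and context-graph terms and uses plain floats keyed by prefix tuples):

import math
from typing import Dict, List, Tuple

def _logaddexp(a: float, b: float) -> float:
    # log(exp(a) + exp(b)) computed stably.
    if a == float("-inf"):
        return b
    if b == float("-inf"):
        return a
    return max(a, b) + math.log1p(math.exp(-abs(a - b)))

def prefix_search_step(
    beams: Dict[Tuple[int, ...], Tuple[float, float]],
    log_probs: List[float],
    indexes: List[int],
    beam: int = 4,
    blank_id: int = 0,
) -> Dict[Tuple[int, ...], Tuple[float, float]]:
    # beams maps a prefix to (log_prob_blank, log_prob_non_blank).
    new_beams: Dict[Tuple[int, ...], Tuple[float, float]] = {}
    neg_inf = float("-inf")
    for prefix, (p_b, p_nb) in beams.items():
        for lp, tok in zip(log_probs, indexes):
            if tok == blank_id:
                # Emit blank: the prefix is unchanged and now ends with blank.
                b, nb = new_beams.get(prefix, (neg_inf, neg_inf))
                new_beams[prefix] = (_logaddexp(b, _logaddexp(p_b, p_nb) + lp), nb)
            elif prefix and tok == prefix[-1]:
                # Repeated token: merged into the previous occurrence ...
                b, nb = new_beams.get(prefix, (neg_inf, neg_inf))
                new_beams[prefix] = (b, _logaddexp(nb, p_nb + lp))
                # ... or a genuinely new token, only reachable via a preceding blank.
                ext = prefix + (tok,)
                b, nb = new_beams.get(ext, (neg_inf, neg_inf))
                new_beams[ext] = (b, _logaddexp(nb, p_b + lp))
            else:
                # A different token always extends the prefix.
                ext = prefix + (tok,)
                b, nb = new_beams.get(ext, (neg_inf, neg_inf))
                new_beams[ext] = (b, _logaddexp(nb, _logaddexp(p_b, p_nb) + lp))
    # Keep the `beam` best prefixes ranked by total (blank + non-blank) probability.
    kept = sorted(new_beams.items(), key=lambda kv: _logaddexp(*kv[1]), reverse=True)[:beam]
    return dict(kept)

# Example: two frames over a toy vocabulary where 0 is the blank token.
beams = {(): (0.0, float("-inf"))}
beams = prefix_search_step(beams, [-0.1, -2.5], [0, 2])
beams = prefix_search_step(beams, [-0.3, -1.4], [2, 1])
print(beams)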


def _batch_worker(topk_values, topk_indexes, B, encoder_out_lens, beam, blank_id):
def _sequence_worker(
topk_values: torch.Tensor,
topk_indexes: torch.Tensor,
B: HypothesisList,
encoder_out_lens: torch.Tensor,
beam: int = 4,
blank_id: int = 0,
) -> HypothesisList:
"""The worker to decode one sequence.
Args:
topk_values:
The top-k log_probs of the model output (i.e. the scores of the tokens kept
by first-pass pruning); the shape is (T, beam).
topk_indexes:
The indexes of the topk_values above; the shape is (T, beam).
B:
An instance of HypothesisList containing the kept hypotheses.
encoder_out_lens:
The number of valid frames (after subsampling) in this sequence.
beam:
The number of hypotheses to be kept at each step.
blank_id:
The id of the blank token in the vocabulary.
Return:
Returns the updated HypothesisList.
"""
B.add(Hypothesis())
for j in range(encoder_out_lens):
log_probs, indexes = topk_values[j], topk_indexes[j]
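
The rest of this worker is collapsed in the view above; as a reading aid, here is what the full loop presumably looks like, written as a hypothetical standalone sketch that reuses `Hypothesis` and `_step_worker` from this module (an assumption based on the documented signatures, not the committed code):

def _sequence_worker_sketch(topk_values, topk_indexes, B, encoder_out_lens,
                            beam=4, blank_id=0):
    # Assumption: decode a single sequence by applying _step_worker frame by frame.
    B.add(Hypothesis())
    for j in range(encoder_out_lens):
        log_probs, indexes = topk_values[j], topk_indexes[j]
        B = _step_worker(log_probs, indexes, B, beam, blank_id)
    return B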
@@ -1828,6 +1885,24 @@ def ctc_prefix_beam_search(
process_pool: Optional[Pool] = None,
return_nbest: Optional[bool] = False,
) -> Union[List[List[int]], List[HypothesisList]]:
"""Implement prefix search decoding in "Connectionist Temporal Classification:
Labelling Unsegmented Sequence Data with Recurrent Neural Networks".
Args:
ctc_output:
The output of the CTC head (log probabilities); the shape is (B, T, V).
encoder_out_lens:
The lengths (in frames) of the sequences after subsampling; the shape is (B,).
beam:
The number of hypotheses to be kept at each step.
blank_id:
The id of the blank token in the vocabulary.
process_pool:
The process pool for parallel decoding; if not provided, all of your
CPU cores will be used by default.
return_nbest:
If true, return a list of HypothesisList; otherwise return a list of lists of decoded token ids.
"""
batch_size, num_frames, vocab_size = ctc_output.shape

# TODO: using a larger beam for first pass pruning
@@ -1850,7 +1925,7 @@ def ctc_prefix_beam_search(
blank_id,
)
)
async_results = pool.starmap_async(_batch_worker, arguments)
async_results = pool.starmap_async(_sequence_worker, arguments)
B = list(async_results.get())
if process_pool is None:
pool.close()
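
A usage sketch for the function above (hypothetical shapes and values, not taken from this commit):

import torch

# Fake CTC posteriors: batch of 2 utterances, 50 frames, vocabulary of 500 tokens.
ctc_output = torch.randn(2, 50, 500).log_softmax(dim=-1)
encoder_out_lens = torch.tensor([50, 42])

hyps = ctc_prefix_beam_search(
    ctc_output=ctc_output, encoder_out_lens=encoder_out_lens, beam=4, blank_id=0
)
# By default this returns a list with one list of decoded token ids per utterance;
# pass return_nbest=True to get the full HypothesisList for each utterance instead.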
@@ -1872,6 +1947,32 @@ def ctc_prefix_beam_search_shallow_fussion(
LM: Optional[LmScorer] = None,
context_graph: Optional[ContextGraph] = None,
) -> List[List[int]]:
"""Implement prefix search decoding in "Connectionist Temporal Classification:
Labelling Unsegmented Sequence Data with Recurrent Neural Networks" and add
nervous language model shallow fussion, it also supports contextual
biasing with a given grammar.
Args:
ctc_output:
The output of ctc head (log probability), the shape is (B, T, V)
encoder_out_lens:
The lengths (frames) of sequences after subsampling, the shape is (B,)
beam:
The number of hypothesis to be kept at each step.
blank_id:
The id of blank in the vocabulary.
LODR_lm:
A low order n-gram LM, whose score will be subtracted during shallow fusion
LODR_lm_scale:
The scale of the LODR_lm
LM:
A neural net LM, e.g an RNNLM or transformer LM
context_graph:
A ContextGraph instance containing contextual phrases.
Return:
Returns a list of list of decoded token ids.
"""
batch_size, num_frames, vocab_size = ctc_output.shape
# TODO: using a larger beam for first pass pruning
topk_values, topk_indexes = ctc_output.topk(beam) # (B, T, beam)
@@ -1926,14 +2027,14 @@ def ctc_prefix_beam_search_shallow_fussion(
)
if LM is None:
continue
# update lm_score
# update lm_log_probs
token_list = [] # a list of lists
hs = []
cs = []
indexes = [] # (batch_idx, key)
for batch_idx, hyps in enumerate(B):
for hyp in hyps:
if hyp.lm_log_probs is None:
if hyp.lm_log_probs is None: # those hyps whose prefix has changed
if LM.lm_type == "rnn":
token_list.append([hyp.ys[-1]])
# store the LSTM states
Expand Down Expand Up @@ -2000,7 +2101,32 @@ def ctc_prefix_beam_search_attention_decoder_rescoring(
beam: int = 8,
blank_id: int = 0,
attention_scale: Optional[float] = None,
process_pool: Optional[Pool] = None,
):
"""Implement prefix search decoding in "Connectionist Temporal Classification:
Labelling Unsegmented Sequence Data with Recurrent Neural Networks" and add
attention decoder rescoring.
Args:
ctc_output:
The output of the CTC head (log probabilities); the shape is (B, T, V).
attention_decoder:
The attention decoder.
encoder_out:
The output of the encoder; the shape is (B, T, D).
encoder_out_lens:
The lengths (in frames) of the sequences after subsampling; the shape is (B,).
beam:
The number of hypotheses to be kept at each step.
blank_id:
The id of the blank token in the vocabulary.
attention_scale:
The scale of the attention decoder score; if not provided, a default list
of scales will be searched (see the code below).
process_pool:
The process pool for parallel decoding; if not provided, all of your
CPU cores will be used by default.
"""
# List[HypothesisList]
nbest = ctc_prefix_beam_search(
ctc_output=ctc_output,
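The remainder of the rescoring function is collapsed in this view. Based on the docstring above, it presumably combines each n-best hypothesis's CTC score with its attention-decoder score over a grid of scales; a toy sketch of that combination (illustrative only, with made-up scores and an assumed scale grid):

# Hypothetical n-best scores for one utterance (3 candidates), in log space.
ctc_scores = [-12.3, -13.1, -14.0]    # tot_score from ctc_prefix_beam_search
attn_scores = [-20.5, -18.2, -21.7]   # log-likelihoods from the attention decoder

best_per_scale = {}
for scale in [0.1, 0.5, 1.0, 2.0]:    # assumed search grid for attention_scale
    tot = [c + scale * a for c, a in zip(ctc_scores, attn_scores)]
    best_per_scale[scale] = max(range(len(tot)), key=tot.__getitem__)
print(best_per_scale)   # index of the best candidate for each scale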
