Skip to content

Commit

Permalink
Merge pull request #1 from AI4Bharat/multi-softmax
Browse files (browse the repository at this point in the history)
Fix inference bug
  • Loading branch information
tahirjmakhdoomi authored Feb 6, 2024
2 parents b6ca450 + c6cb1c8 commit 452f4f4
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 5 deletions.
16 changes: 13 additions & 3 deletions nemo/collections/asr/models/hybrid_rnnt_ctc_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ def transcribe(
augmentor: DictConfig = None,
verbose: bool = True,
logprobs: bool = False,
language_id: str = None,
) -> (List[str], Optional[List['Hypothesis']]):
"""
Uses greedy decoding to transcribe audio files. Use this method for debugging and prototyping.
Expand Down Expand Up @@ -133,6 +134,7 @@ def transcribe(
f"{self.cur_decoder} is not supported for cur_decoder. Supported values are ['ctc', 'rnnt']"
)
if self.cur_decoder == "rnnt":
logging.info("Running with RNN-T decoder..")
return super().transcribe(
paths2audio_files=paths2audio_files,
batch_size=batch_size,
Expand All @@ -142,7 +144,10 @@ def transcribe(
channel_selector=channel_selector,
augmentor=augmentor,
verbose=verbose,
language_id = language_id
)

logging.info("Running with CTC decoder..")

if paths2audio_files is None or len(paths2audio_files) == 0:
return {}
Expand Down Expand Up @@ -194,13 +199,18 @@ def transcribe(
temporary_datalayer = self._setup_transcribe_dataloader(config)
logits_list = []
for test_batch in tqdm(temporary_datalayer, desc="Transcribing", disable=not verbose):
+ signal, signal_len, _, _ = test_batch
+ if "multisoftmax" not in self.cfg.decoder:
+ language_ids = None
+ else:
+ language_ids = [language_id] * len(signal)
encoded, encoded_len = self.forward(
- input_signal=test_batch[0].to(device), input_signal_length=test_batch[1].to(device)
+ input_signal=signal.to(device), input_signal_length=signal_len.to(device)
)

- logits = self.ctc_decoder(encoder_output=encoded)
+ logits = self.ctc_decoder(encoder_output=encoded, language_ids=language_ids)
best_hyp, all_hyp = self.ctc_decoding.ctc_decoder_predictions_tensor(
- logits, encoded_len, return_hypotheses=return_hypotheses,
+ logits, encoded_len, return_hypotheses=return_hypotheses, lang_ids=language_ids,
)
logits = logits.cpu()

Expand Down
11 changes: 9 additions & 2 deletions nemo/collections/asr/models/rnnt_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,7 @@ def transcribe(
channel_selector: Optional[ChannelSelectorType] = None,
augmentor: DictConfig = None,
verbose: bool = True,
language_id: str = None,
) -> Tuple[List[str], Optional[List['Hypothesis']]]:
"""
Uses greedy decoding to transcribe audio files. Use this method for debugging and prototyping.
Expand Down Expand Up @@ -286,15 +287,21 @@ def transcribe(
config['augmentor'] = augmentor

temporary_datalayer = self._setup_transcribe_dataloader(config)
- for test_batch in tqdm(temporary_datalayer, desc="Transcribing", disable=(not verbose)):
+ for test_batch in tqdm(temporary_datalayer, desc="Transcribing", disable=not verbose):
+ signal, signal_len, _, _ = test_batch
+ if "multisoftmax" not in self.cfg.decoder:
+ language_ids = None
+ else:
+ language_ids = [language_id] * len(signal)
encoded, encoded_len = self.forward(
- input_signal=test_batch[0].to(device), input_signal_length=test_batch[1].to(device)
+ input_signal=signal.to(device), input_signal_length=signal_len.to(device)
)
best_hyp, all_hyp = self.decoding.rnnt_decoder_predictions_tensor(
encoded,
encoded_len,
return_hypotheses=return_hypotheses,
partial_hypotheses=partial_hypothesis,
+ lang_ids=language_ids,
)

hypotheses += best_hyp
Expand Down

0 comments on commit 452f4f4

Please sign in to comment.