Commit

Fix inference bug
kaushal-py committed Feb 5, 2024
1 parent 3bddb03 commit c6cb1c8
Showing 3 changed files with 46 additions and 13 deletions.
32 changes: 24 additions & 8 deletions nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models.py
@@ -338,9 +338,14 @@ def change_vocabulary(
         decoding_cls = OmegaConf.create(OmegaConf.to_container(decoding_cls))
         decoding_cfg = OmegaConf.merge(decoding_cls, decoding_cfg)

-        self.decoding = RNNTBPEDecoding(
-            decoding_cfg=decoding_cfg, decoder=self.decoder, joint=self.joint, tokenizer=self.tokenizer,
-        )
+        if (self.tokenizer_type == "agg" or self.tokenizer_type == "multilingual") and "multisoftmax" in self.cfg.decoder:
+            self.decoding = RNNTBPEDecoding(
+                decoding_cfg=self.cfg.decoding, decoder=self.decoder, joint=self.joint, tokenizer=self.tokenizer, blank_id=self.ctc_decoder._num_classes // len(self.tokenizer.tokenizers_dict.keys())
+            )
+        else:
+            self.decoding = RNNTBPEDecoding(
+                decoding_cfg=decoding_cfg, decoder=self.decoder, joint=self.joint, tokenizer=self.tokenizer,
+            )

         self.wer = RNNTBPEWER(
             decoding=self.decoding,
@@ -405,7 +410,10 @@ def change_vocabulary(
         ctc_decoding_cls = OmegaConf.create(OmegaConf.to_container(ctc_decoding_cls))
         ctc_decoding_cfg = OmegaConf.merge(ctc_decoding_cls, ctc_decoding_cfg)

-        self.ctc_decoding = CTCBPEDecoding(decoding_cfg=ctc_decoding_cfg, tokenizer=self.tokenizer)
+        if (self.tokenizer_type == "agg" or self.tokenizer_type == "multilingual") and "multisoftmax" in self.cfg.decoder:
+            self.ctc_decoding = CTCBPEDecoding(self.cfg.aux_ctc.decoding, tokenizer=self.tokenizer, blank_id=self.ctc_decoder._num_classes//len(self.tokenizer.tokenizers_dict.keys()))
+        else:
+            self.ctc_decoding = CTCBPEDecoding(self.cfg.aux_ctc.decoding, tokenizer=self.tokenizer)

         self.ctc_wer = WERBPE(
             decoding=self.ctc_decoding,
@@ -444,9 +452,14 @@ def change_decoding_strategy(self, decoding_cfg: DictConfig = None, decoder_type
             decoding_cls = OmegaConf.create(OmegaConf.to_container(decoding_cls))
             decoding_cfg = OmegaConf.merge(decoding_cls, decoding_cfg)

-            self.decoding = RNNTBPEDecoding(
-                decoding_cfg=decoding_cfg, decoder=self.decoder, joint=self.joint, tokenizer=self.tokenizer,
-            )
+            if (self.tokenizer_type == "agg" or self.tokenizer_type == "multilingual") and "multisoftmax" in self.cfg.decoder:
+                self.decoding = RNNTBPEDecoding(
+                    decoding_cfg=self.cfg.decoding, decoder=self.decoder, joint=self.joint, tokenizer=self.tokenizer, blank_id=self.ctc_decoder._num_classes // len(self.tokenizer.tokenizers_dict.keys())
+                )
+            else:
+                self.decoding = RNNTBPEDecoding(
+                    decoding_cfg=decoding_cfg, decoder=self.decoder, joint=self.joint, tokenizer=self.tokenizer,
+                )

             self.wer = RNNTBPEWER(
                 decoding=self.decoding,
@@ -483,7 +496,10 @@ def change_decoding_strategy(self, decoding_cfg: DictConfig = None, decoder_type
             decoding_cls = OmegaConf.create(OmegaConf.to_container(decoding_cls))
             decoding_cfg = OmegaConf.merge(decoding_cls, decoding_cfg)

-            self.ctc_decoding = CTCBPEDecoding(decoding_cfg=decoding_cfg, tokenizer=self.tokenizer)
+            if (self.tokenizer_type == "agg" or self.tokenizer_type == "multilingual") and "multisoftmax" in self.cfg.decoder:
+                self.ctc_decoding = CTCBPEDecoding(self.cfg.aux_ctc.decoding, tokenizer=self.tokenizer, blank_id=self.ctc_decoder._num_classes//len(self.tokenizer.tokenizers_dict.keys()))
+            else:
+                self.ctc_decoding = CTCBPEDecoding(self.cfg.aux_ctc.decoding, tokenizer=self.tokenizer)

             self.ctc_wer = WERBPE(
                 decoding=self.ctc_decoding,
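
In the new branches above, a per-language blank index is derived by splitting the CTC decoder's output classes evenly across the languages of the aggregate ("agg") or multilingual tokenizer. A minimal standalone sketch of that arithmetic; the concrete numbers below are illustrative assumptions, not values taken from this change:

    # Hypothetical numbers for illustration only.
    num_languages = 3        # e.g. len(self.tokenizer.tokenizers_dict.keys())
    ctc_num_classes = 387    # e.g. self.ctc_decoder._num_classes (assumed: 129 outputs per language)
    blank_id = ctc_num_classes // num_languages
    print(blank_id)          # 129 -- the per-language share passed as blank_id to RNNTBPEDecoding / CTCBPEDecoding
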
16 changes: 13 additions & 3 deletions nemo/collections/asr/models/hybrid_rnnt_ctc_models.py
@@ -104,6 +104,7 @@ def transcribe(
         augmentor: DictConfig = None,
         verbose: bool = True,
         logprobs: bool = False,
+        language_id: str = None,
     ) -> (List[str], Optional[List['Hypothesis']]):
         """
         Uses greedy decoding to transcribe audio files. Use this method for debugging and prototyping.
@@ -133,6 +134,7 @@ def transcribe(
                 f"{self.cur_decoder} is not supported for cur_decoder. Supported values are ['ctc', 'rnnt']"
             )
         if self.cur_decoder == "rnnt":
+            logging.info("Running with RNN-T decoder..")
             return super().transcribe(
                 paths2audio_files=paths2audio_files,
                 batch_size=batch_size,
@@ -142,7 +144,10 @@
                 channel_selector=channel_selector,
                 augmentor=augmentor,
                 verbose=verbose,
+                language_id=language_id,
             )

+        logging.info("Running with CTC decoder..")
+
         if paths2audio_files is None or len(paths2audio_files) == 0:
             return {}
@@ -194,13 +199,18 @@
             temporary_datalayer = self._setup_transcribe_dataloader(config)
             logits_list = []
             for test_batch in tqdm(temporary_datalayer, desc="Transcribing", disable=not verbose):
+                signal, signal_len, _, _ = test_batch
+                if "multisoftmax" not in self.cfg.decoder:
+                    language_ids = None
+                else:
+                    language_ids = [language_id] * len(signal)
                 encoded, encoded_len = self.forward(
-                    input_signal=test_batch[0].to(device), input_signal_length=test_batch[1].to(device)
+                    input_signal=signal.to(device), input_signal_length=signal_len.to(device)
                 )

-                logits = self.ctc_decoder(encoder_output=encoded)
+                logits = self.ctc_decoder(encoder_output=encoded, language_ids=language_ids)
                 best_hyp, all_hyp = self.ctc_decoding.ctc_decoder_predictions_tensor(
-                    logits, encoded_len, return_hypotheses=return_hypotheses,
+                    logits, encoded_len, return_hypotheses=return_hypotheses, lang_ids=language_ids,
                 )
                 logits = logits.cpu()

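
Together with the rnnt_models.py change below, transcribe() now accepts a language_id that is expanded into per-utterance language_ids whenever the decoder config contains multisoftmax. A rough usage sketch; the model class, checkpoint filename, audio path, and language code here are illustrative assumptions, not part of this commit:

    # Illustrative sketch; checkpoint name, audio file, and language code are assumptions.
    import nemo.collections.asr as nemo_asr

    asr_model = nemo_asr.models.EncDecHybridRNNTCTCBPEModel.restore_from("multilingual_hybrid.nemo")
    transcripts = asr_model.transcribe(
        paths2audio_files=["utt_hi_0001.wav"],
        batch_size=1,
        language_id="hi",  # forwarded per utterance as language_ids when the decoder uses multisoftmax
    )
    print(transcripts)
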
11 changes: 9 additions & 2 deletions nemo/collections/asr/models/rnnt_models.py
@@ -218,6 +218,7 @@ def transcribe(
         channel_selector: Optional[ChannelSelectorType] = None,
         augmentor: DictConfig = None,
         verbose: bool = True,
+        language_id: str = None,
     ) -> Tuple[List[str], Optional[List['Hypothesis']]]:
         """
         Uses greedy decoding to transcribe audio files. Use this method for debugging and prototyping.
@@ -286,15 +287,21 @@
                 config['augmentor'] = augmentor

         temporary_datalayer = self._setup_transcribe_dataloader(config)
-        for test_batch in tqdm(temporary_datalayer, desc="Transcribing", disable=(not verbose)):
+        for test_batch in tqdm(temporary_datalayer, desc="Transcribing", disable=not verbose):
+            signal, signal_len, _, _ = test_batch
+            if "multisoftmax" not in self.cfg.decoder:
+                language_ids = None
+            else:
+                language_ids = [language_id] * len(signal)
             encoded, encoded_len = self.forward(
-                input_signal=test_batch[0].to(device), input_signal_length=test_batch[1].to(device)
+                input_signal=signal.to(device), input_signal_length=signal_len.to(device)
             )
             best_hyp, all_hyp = self.decoding.rnnt_decoder_predictions_tensor(
                 encoded,
                 encoded_len,
                 return_hypotheses=return_hypotheses,
                 partial_hypotheses=partial_hypothesis,
+                lang_ids=language_ids,
             )

             hypotheses += best_hyp
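
Both transcribe() variants build the per-batch language list the same way: the single language_id is repeated once per utterance when multisoftmax is configured, and left as None otherwise. A tiny standalone sketch of that expansion; the helper name and values are hypothetical:

    # Hypothetical helper mirroring the in-loop logic above.
    def expand_language_ids(language_id, batch_size, uses_multisoftmax):
        # One language tag per utterance in the batch, or None when the decoder
        # has no per-language softmax heads.
        return [language_id] * batch_size if uses_multisoftmax else None

    print(expand_language_ids("ta", 4, True))   # ['ta', 'ta', 'ta', 'ta']
    print(expand_language_ids("ta", 4, False))  # None
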
