From c6cb1c859fd61e0d80a1f8abf171f19ef886e2e8 Mon Sep 17 00:00:00 2001
From: kaushal-py
Date: Mon, 5 Feb 2024 17:08:55 +0530
Subject: [PATCH] Fix inference bug

---
 .../asr/models/hybrid_rnnt_ctc_bpe_models.py  | 32 ++++++++++++++-----
 .../asr/models/hybrid_rnnt_ctc_models.py      | 16 ++++++++--
 nemo/collections/asr/models/rnnt_models.py    | 11 +++++--
 3 files changed, 46 insertions(+), 13 deletions(-)

diff --git a/nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models.py b/nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models.py
index feafcbc30..fb72bef41 100644
--- a/nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models.py
+++ b/nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models.py
@@ -338,9 +338,14 @@ def change_vocabulary(
         decoding_cls = OmegaConf.create(OmegaConf.to_container(decoding_cls))
         decoding_cfg = OmegaConf.merge(decoding_cls, decoding_cfg)
 
-        self.decoding = RNNTBPEDecoding(
-            decoding_cfg=decoding_cfg, decoder=self.decoder, joint=self.joint, tokenizer=self.tokenizer,
-        )
+        if (self.tokenizer_type == "agg" or self.tokenizer_type == "multilingual") and "multisoftmax" in self.cfg.decoder:
+            self.decoding = RNNTBPEDecoding(
+                decoding_cfg=self.cfg.decoding, decoder=self.decoder, joint=self.joint, tokenizer=self.tokenizer, blank_id=self.ctc_decoder._num_classes // len(self.tokenizer.tokenizers_dict.keys())
+            )
+        else:
+            self.decoding = RNNTBPEDecoding(
+                decoding_cfg=decoding_cfg, decoder=self.decoder, joint=self.joint, tokenizer=self.tokenizer,
+            )
 
         self.wer = RNNTBPEWER(
             decoding=self.decoding,
@@ -405,7 +410,10 @@ def change_vocabulary(
         ctc_decoding_cls = OmegaConf.create(OmegaConf.to_container(ctc_decoding_cls))
         ctc_decoding_cfg = OmegaConf.merge(ctc_decoding_cls, ctc_decoding_cfg)
 
-        self.ctc_decoding = CTCBPEDecoding(decoding_cfg=ctc_decoding_cfg, tokenizer=self.tokenizer)
+        if (self.tokenizer_type == "agg" or self.tokenizer_type == "multilingual") and "multisoftmax" in self.cfg.decoder:
+            self.ctc_decoding = CTCBPEDecoding(self.cfg.aux_ctc.decoding, tokenizer=self.tokenizer, blank_id=self.ctc_decoder._num_classes//len(self.tokenizer.tokenizers_dict.keys()))
+        else:
+            self.ctc_decoding = CTCBPEDecoding(self.cfg.aux_ctc.decoding, tokenizer=self.tokenizer)
 
         self.ctc_wer = WERBPE(
             decoding=self.ctc_decoding,
@@ -444,9 +452,14 @@ def change_decoding_strategy(self, decoding_cfg: DictConfig = None, decoder_type
             decoding_cls = OmegaConf.create(OmegaConf.to_container(decoding_cls))
             decoding_cfg = OmegaConf.merge(decoding_cls, decoding_cfg)
 
-            self.decoding = RNNTBPEDecoding(
-                decoding_cfg=decoding_cfg, decoder=self.decoder, joint=self.joint, tokenizer=self.tokenizer,
-            )
+            if (self.tokenizer_type == "agg" or self.tokenizer_type == "multilingual") and "multisoftmax" in self.cfg.decoder:
+                self.decoding = RNNTBPEDecoding(
+                    decoding_cfg=self.cfg.decoding, decoder=self.decoder, joint=self.joint, tokenizer=self.tokenizer, blank_id=self.ctc_decoder._num_classes // len(self.tokenizer.tokenizers_dict.keys())
+                )
+            else:
+                self.decoding = RNNTBPEDecoding(
+                    decoding_cfg=decoding_cfg, decoder=self.decoder, joint=self.joint, tokenizer=self.tokenizer,
+                )
 
             self.wer = RNNTBPEWER(
                 decoding=self.decoding,
@@ -483,7 +496,10 @@ def change_decoding_strategy(self, decoding_cfg: DictConfig = None, decoder_type
             decoding_cls = OmegaConf.create(OmegaConf.to_container(decoding_cls))
             decoding_cfg = OmegaConf.merge(decoding_cls, decoding_cfg)
 
-            self.ctc_decoding = CTCBPEDecoding(decoding_cfg=decoding_cfg, tokenizer=self.tokenizer)
+            if (self.tokenizer_type == "agg" or self.tokenizer_type == "multilingual") and "multisoftmax" in self.cfg.decoder:
+                self.ctc_decoding = CTCBPEDecoding(self.cfg.aux_ctc.decoding, tokenizer=self.tokenizer, blank_id=self.ctc_decoder._num_classes//len(self.tokenizer.tokenizers_dict.keys()))
+            else:
+                self.ctc_decoding = CTCBPEDecoding(self.cfg.aux_ctc.decoding, tokenizer=self.tokenizer)
 
             self.ctc_wer = WERBPE(
                 decoding=self.ctc_decoding,
diff --git a/nemo/collections/asr/models/hybrid_rnnt_ctc_models.py b/nemo/collections/asr/models/hybrid_rnnt_ctc_models.py
index 6f6bfcb77..47efcd122 100644
--- a/nemo/collections/asr/models/hybrid_rnnt_ctc_models.py
+++ b/nemo/collections/asr/models/hybrid_rnnt_ctc_models.py
@@ -104,6 +104,7 @@ def transcribe(
         augmentor: DictConfig = None,
         verbose: bool = True,
         logprobs: bool = False,
+        language_id: str = None,
     ) -> (List[str], Optional[List['Hypothesis']]):
         """
         Uses greedy decoding to transcribe audio files. Use this method for debugging and prototyping.
@@ -133,6 +134,7 @@ def transcribe(
                 f"{self.cur_decoder} is not supported for cur_decoder. Supported values are ['ctc', 'rnnt']"
             )
         if self.cur_decoder == "rnnt":
+            logging.info("Running with RNN-T decoder..")
             return super().transcribe(
                 paths2audio_files=paths2audio_files,
                 batch_size=batch_size,
@@ -142,7 +144,10 @@ def transcribe(
                 channel_selector=channel_selector,
                 augmentor=augmentor,
                 verbose=verbose,
+                language_id = language_id
             )
+
+        logging.info("Running with CTC decoder..")
 
         if paths2audio_files is None or len(paths2audio_files) == 0:
             return {}
@@ -194,13 +199,18 @@ def transcribe(
                 temporary_datalayer = self._setup_transcribe_dataloader(config)
                 logits_list = []
                 for test_batch in tqdm(temporary_datalayer, desc="Transcribing", disable=not verbose):
+                    signal, signal_len, _, _ = test_batch
+                    if "multisoftmax" not in self.cfg.decoder:
+                        language_ids = None
+                    else:
+                        language_ids = [language_id] * len(signal)
                     encoded, encoded_len = self.forward(
-                        input_signal=test_batch[0].to(device), input_signal_length=test_batch[1].to(device)
+                        input_signal=signal.to(device), input_signal_length=signal_len.to(device)
                     )
 
-                    logits = self.ctc_decoder(encoder_output=encoded)
+                    logits = self.ctc_decoder(encoder_output=encoded, language_ids=language_ids)
                     best_hyp, all_hyp = self.ctc_decoding.ctc_decoder_predictions_tensor(
-                        logits, encoded_len, return_hypotheses=return_hypotheses,
+                        logits, encoded_len, return_hypotheses=return_hypotheses, lang_ids=language_ids,
                     )
 
                     logits = logits.cpu()
diff --git a/nemo/collections/asr/models/rnnt_models.py b/nemo/collections/asr/models/rnnt_models.py
index 30d66df5e..7edbd56b0 100644
--- a/nemo/collections/asr/models/rnnt_models.py
+++ b/nemo/collections/asr/models/rnnt_models.py
@@ -218,6 +218,7 @@ def transcribe(
         channel_selector: Optional[ChannelSelectorType] = None,
         augmentor: DictConfig = None,
         verbose: bool = True,
+        language_id: str = None,
    ) -> Tuple[List[str], Optional[List['Hypothesis']]]:
        """
        Uses greedy decoding to transcribe audio files. Use this method for debugging and prototyping.
@@ -286,15 +287,21 @@ def transcribe(
                     config['augmentor'] = augmentor
 
                 temporary_datalayer = self._setup_transcribe_dataloader(config)
-                for test_batch in tqdm(temporary_datalayer, desc="Transcribing", disable=(not verbose)):
+                for test_batch in tqdm(temporary_datalayer, desc="Transcribing", disable=not verbose):
+                    signal, signal_len, _, _ = test_batch
+                    if "multisoftmax" not in self.cfg.decoder:
+                        language_ids = None
+                    else:
+                        language_ids = [language_id] * len(signal)
                     encoded, encoded_len = self.forward(
-                        input_signal=test_batch[0].to(device), input_signal_length=test_batch[1].to(device)
+                        input_signal=signal.to(device), input_signal_length=signal_len.to(device)
                     )
                     best_hyp, all_hyp = self.decoding.rnnt_decoder_predictions_tensor(
                         encoded,
                         encoded_len,
                         return_hypotheses=return_hypotheses,
                         partial_hypotheses=partial_hypothesis,
+                        lang_ids=language_ids,
                     )
 
                     hypotheses += best_hyp
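
For reference, a minimal usage sketch of the inference path this patch touches. The checkpoint name, audio file, and language code below are placeholders, and the sketch assumes a hybrid RNNT/CTC BPE checkpoint trained with an aggregate ("agg") or multilingual tokenizer and a decoder config containing "multisoftmax", so that the new language_id argument actually takes effect:

    # Usage sketch only: "multilingual_model.nemo", "sample_hi.wav", and the language
    # code "hi" are placeholders, and the checkpoint is assumed to have been trained
    # with an aggregate/multilingual tokenizer and a "multisoftmax" decoder.
    import nemo.collections.asr as nemo_asr

    model = nemo_asr.models.EncDecHybridRNNTCTCBPEModel.restore_from("multilingual_model.nemo")

    # With cur_decoder set to "rnnt" (the default), language_id is forwarded to the
    # RNN-T decoding call added in this patch.
    rnnt_result = model.transcribe(
        paths2audio_files=["sample_hi.wav"], batch_size=1, language_id="hi",
    )

    # Switch to the auxiliary CTC head and decode the same audio with the same language.
    model.change_decoding_strategy(decoding_cfg=None, decoder_type="ctc")
    ctc_result = model.transcribe(
        paths2audio_files=["sample_hi.wav"], batch_size=1, language_id="hi",
    )

    print(rnnt_result)
    print(ctc_result)

When the decoder config has no "multisoftmax" key, transcribe() sets language_ids to None and decoding behaves as before, so language_id can simply be omitted for monolingual checkpoints.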