Skip to content

Commit

Permalink
Fix multisoftmax num_classes/blank_id computation for a single language
Browse files Browse the repository at this point in the history
  • Loading branch information
tahirjmakhdoomi committed Feb 21, 2024
1 parent 29b9f7b commit 438aa07
Showing 1 changed file with 10 additions and 9 deletions.
19 changes: 10 additions & 9 deletions nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,21 +104,21 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
self.language_masks[language] = [(token_language == language) for _, token_language in self.tokenizer.langs_by_token_id.items()]
self.language_masks[language].append(True) # Insert blank token
self.ctc_loss = CTCLoss(
num_classes=self.ctc_decoder._num_classes // len(self.tokenizer.tokenizers_dict.keys()),
num_classes=(self.ctc_decoder._num_classes-1 )// len(self.tokenizer.tokenizers_dict.keys()),
zero_infinity=True,
reduction=self.cfg.aux_ctc.get("ctc_reduction", "mean_batch"),
)
# Setup RNNT Loss
loss_name, loss_kwargs = self.extract_rnnt_loss_cfg(self.cfg.get("loss", None))
self.loss = RNNTLoss(
num_classes=self.ctc_decoder._num_classes // len(self.tokenizer.tokenizers_dict.keys()),
num_classes=(self.ctc_decoder._num_classes-1) // len(self.tokenizer.tokenizers_dict.keys()),
loss_name=loss_name,
loss_kwargs=loss_kwargs,
reduction=self.cfg.get("rnnt_reduction", "mean_batch"),
)
# Setup decoding object
self.decoding = RNNTBPEDecoding(
decoding_cfg=self.cfg.decoding, decoder=self.decoder, joint=self.joint, tokenizer=self.tokenizer, blank_id=self.ctc_decoder._num_classes // len(self.tokenizer.tokenizers_dict.keys())
decoding_cfg=self.cfg.decoding, decoder=self.decoder, joint=self.joint, tokenizer=self.tokenizer, blank_id=(self.ctc_decoder._num_classes-1) // len(self.tokenizer.tokenizers_dict.keys())
)

self.decoder.language_masks = self.language_masks
Expand Down Expand Up @@ -148,10 +148,11 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
with open_dict(self.cfg.aux_ctc):
self.cfg.aux_ctc.decoding = ctc_decoding_cfg
if (self.tokenizer_type == "agg" or self.tokenizer_type == "multilingual") and "multisoftmax" in cfg.decoder:
breakpoint()
if ctc_decoding_cfg.strategy == 'pyctcdecode':
self.decoding = CTCBPEDecoding(self.cfg.decoding, tokenizer=self.tokenizer, blank_id=self.ctc_decoder._num_classes//len(self.tokenizer.tokenizers_dict.keys()),lang='any')
self.decoding = CTCBPEDecoding(self.cfg.decoding, tokenizer=self.tokenizer, blank_id=(self.ctc_decoder._num_classes-1)//len(self.tokenizer.tokenizers_dict.keys()),lang='any')
else:
self.ctc_decoding = CTCBPEDecoding(self.cfg.aux_ctc.decoding, tokenizer=self.tokenizer, blank_id=self.ctc_decoder._num_classes//len(self.tokenizer.tokenizers_dict.keys()))
self.ctc_decoding = CTCBPEDecoding(self.cfg.aux_ctc.decoding, tokenizer=self.tokenizer, blank_id=(self.ctc_decoder._num_classes-1)//len(self.tokenizer.tokenizers_dict.keys()))
else:
self.ctc_decoding = CTCBPEDecoding(self.cfg.aux_ctc.decoding, tokenizer=self.tokenizer)

Expand Down Expand Up @@ -343,7 +344,7 @@ def change_vocabulary(

if (self.tokenizer_type == "agg" or self.tokenizer_type == "multilingual") and "multisoftmax" in self.cfg.decoder:
self.decoding = RNNTBPEDecoding(
decoding_cfg=self.cfg.decoding, decoder=self.decoder, joint=self.joint, tokenizer=self.tokenizer, blank_id=self.ctc_decoder._num_classes // len(self.tokenizer.tokenizers_dict.keys())
decoding_cfg=self.cfg.decoding, decoder=self.decoder, joint=self.joint, tokenizer=self.tokenizer, blank_id=(self.ctc_decoder._num_classes-1) // len(self.tokenizer.tokenizers_dict.keys())
)
else:
self.decoding = RNNTBPEDecoding(
Expand Down Expand Up @@ -414,7 +415,7 @@ def change_vocabulary(
ctc_decoding_cfg = OmegaConf.merge(ctc_decoding_cls, ctc_decoding_cfg)

if (self.tokenizer_type == "agg" or self.tokenizer_type == "multilingual") and "multisoftmax" in self.cfg.decoder:
self.ctc_decoding = CTCBPEDecoding(self.cfg.aux_ctc.decoding, tokenizer=self.tokenizer, blank_id=self.ctc_decoder._num_classes//len(self.tokenizer.tokenizers_dict.keys()))
self.ctc_decoding = CTCBPEDecoding(self.cfg.aux_ctc.decoding, tokenizer=self.tokenizer, blank_id=(self.ctc_decoder._num_classes-1)//len(self.tokenizer.tokenizers_dict.keys()))
else:
self.ctc_decoding = CTCBPEDecoding(self.cfg.aux_ctc.decoding, tokenizer=self.tokenizer)

Expand Down Expand Up @@ -457,7 +458,7 @@ def change_decoding_strategy(self, decoding_cfg: DictConfig = None, decoder_type

if (self.tokenizer_type == "agg" or self.tokenizer_type == "multilingual") and "multisoftmax" in self.cfg.decoder:
self.decoding = RNNTBPEDecoding(
decoding_cfg=self.cfg.decoding, decoder=self.decoder, joint=self.joint, tokenizer=self.tokenizer, blank_id=self.ctc_decoder._num_classes // len(self.tokenizer.tokenizers_dict.keys())
decoding_cfg=self.cfg.decoding, decoder=self.decoder, joint=self.joint, tokenizer=self.tokenizer, blank_id=(self.ctc_decoder._num_classes-1) // len(self.tokenizer.tokenizers_dict.keys())
)
else:
self.decoding = RNNTBPEDecoding(
Expand Down Expand Up @@ -500,7 +501,7 @@ def change_decoding_strategy(self, decoding_cfg: DictConfig = None, decoder_type
decoding_cfg = OmegaConf.merge(decoding_cls, decoding_cfg)

if (self.tokenizer_type == "agg" or self.tokenizer_type == "multilingual") and "multisoftmax" in self.cfg.decoder:
self.ctc_decoding = CTCBPEDecoding(self.cfg.aux_ctc.decoding, tokenizer=self.tokenizer, blank_id=self.ctc_decoder._num_classes//len(self.tokenizer.tokenizers_dict.keys()))
self.ctc_decoding = CTCBPEDecoding(self.cfg.aux_ctc.decoding, tokenizer=self.tokenizer, blank_id=(self.ctc_decoder._num_classes-1)//len(self.tokenizer.tokenizers_dict.keys()))
else:
self.ctc_decoding = CTCBPEDecoding(self.cfg.aux_ctc.decoding, tokenizer=self.tokenizer)

Expand Down

0 comments on commit 438aa07

Please sign in to comment.