Skip to content

Commit

Permalink
Add support for HMS inference
Browse files Browse the repository at this point in the history
  • Loading branch information
kaushal-py committed Apr 22, 2024
1 parent 3fee261 commit 0f8bc67
Show file tree
Hide file tree
Showing 8 changed files with 29 additions and 20 deletions.
3 changes: 3 additions & 0 deletions nemo/collections/asr/models/ctc_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,8 @@ def transcribe(
channel_selector: Optional[ChannelSelectorType] = None,
augmentor: DictConfig = None,
verbose: bool = True,
logprobs: bool = False,
language_id: str = None,
override_config: Optional[TranscribeConfig] = None,
) -> TranscriptionReturnType:
"""
Expand Down Expand Up @@ -159,6 +161,7 @@ def transcribe(
channel_selector=channel_selector,
augmentor=augmentor,
verbose=verbose,
language_id=language_id,
override_config=override_config,
)

Expand Down
11 changes: 6 additions & 5 deletions nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,7 +434,7 @@ def change_vocabulary(

logging.info(f"Changed tokenizer of the CTC decoder to {self.ctc_decoder.vocabulary} vocabulary.")

def change_decoding_strategy(self, decoding_cfg: DictConfig = None, decoder_type: str = None):
def change_decoding_strategy(self, decoding_cfg: DictConfig = None, decoder_type: str = None, lang_id: str=None):
"""
Changes decoding strategy used during RNNT decoding process.
Args:
Expand All @@ -458,7 +458,8 @@ def change_decoding_strategy(self, decoding_cfg: DictConfig = None, decoder_type

if (self.tokenizer_type == "agg" or self.tokenizer_type == "multilingual") and "multisoftmax" in self.cfg.decoder: #CTEMO
self.decoding = RNNTBPEDecoding(
decoding_cfg=decoding_cfg, decoder=self.decoder, joint=self.joint, tokenizer=self.tokenizer, blank_id=self.ctc_decoder._num_classes // len(self.tokenizer.tokenizers_dict.keys())
decoding_cfg=decoding_cfg, decoder=self.decoder, joint=self.joint, tokenizer=self.tokenizer, blank_id=self.ctc_decoder._num_classes // len(self.tokenizer.tokenizers_dict.keys()),
# lang_id=lang_id
)
else:
self.decoding = RNNTBPEDecoding(
Expand Down Expand Up @@ -488,7 +489,7 @@ def change_decoding_strategy(self, decoding_cfg: DictConfig = None, decoder_type

self.cur_decoder = "rnnt"
logging.info(f"Changed decoding strategy of the RNNT decoder to \n{OmegaConf.to_yaml(self.cfg.decoding)}")

elif decoder_type == 'ctc':
if not hasattr(self, 'ctc_decoding'):
raise ValueError("The model does not have the ctc_decoding module and does not support ctc decoding.")
Expand All @@ -503,9 +504,9 @@ def change_decoding_strategy(self, decoding_cfg: DictConfig = None, decoder_type
decoding_cfg = OmegaConf.merge(decoding_cls, decoding_cfg)

if (self.tokenizer_type == "agg" or self.tokenizer_type == "multilingual") and "multisoftmax" in self.cfg.decoder: #CTEMO
self.ctc_decoding = CTCBPEDecoding(decoding_cfg=decoding_cfg, tokenizer=self.tokenizer, blank_id=self.ctc_decoder._num_classes//len(self.tokenizer.tokenizers_dict.keys()))
self.ctc_decoding = CTCBPEDecoding(decoding_cfg=decoding_cfg, tokenizer=self.tokenizer, blank_id=self.ctc_decoder._num_classes//len(self.tokenizer.tokenizers_dict.keys()), lang_id=lang_id)
else:
self.ctc_decoding = CTCBPEDecoding(decoding_cfg=decoding_cfg, tokenizer=self.tokenizer)
self.ctc_decoding = CTCBPEDecoding(decoding_cfg=decoding_cfg, tokenizer=self.tokenizer, lang_id=lang_id)

self.ctc_wer = WER(
decoding=self.ctc_decoding,
Expand Down
2 changes: 1 addition & 1 deletion nemo/collections/asr/models/hybrid_rnnt_ctc_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ def _transcribe_forward(self, batch: Any, trcfg: TranscribeConfig):
if "multisoftmax" not in self.cfg.decoder:
language_ids = None
else:
language_ids = [language_id] * len(batch[0])
language_ids = [trcfg.language_id] * len(batch[0])
logits = self.ctc_decoder(encoder_output=encoded, language_ids=language_ids)
output = dict(logits=logits, encoded_len=encoded_len, language_ids=language_ids)

Expand Down
7 changes: 7 additions & 0 deletions nemo/collections/asr/models/rnnt_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,7 @@ def transcribe(
verbose=verbose,
override_config=override_config,
logprobs=logprobs,
language_id=language_id,
# Additional arguments
partial_hypothesis=partial_hypothesis,
)
Expand Down Expand Up @@ -880,12 +881,18 @@ def _transcribe_output_processing(
) -> Tuple[List['Hypothesis'], List['Hypothesis']]:
encoded = outputs.pop('encoded')
encoded_len = outputs.pop('encoded_len')

if "multisoftmax" not in self.cfg.decoder:
language_ids = None
else:
language_ids = [trcfg.language_id] * len(encoded)

best_hyp, all_hyp = self.decoding.rnnt_decoder_predictions_tensor(
encoded,
encoded_len,
return_hypotheses=trcfg.return_hypotheses,
partial_hypotheses=trcfg.partial_hypothesis,
lang_ids=language_ids,
)

# cleanup memory
Expand Down
1 change: 1 addition & 0 deletions nemo/collections/asr/modules/rnnt.py
Original file line number Diff line number Diff line change
Expand Up @@ -1597,6 +1597,7 @@ def joint_after_projection(self, f: torch.Tensor, g: torch.Tensor, language_ids=
# Forward adapter modules on joint hidden
if self.is_adapter_available():
inp = self.forward_enabled_adapters(inp)

if language_ids is not None: #CTEMO

# Do partial forward of joint net (skipping the final linear)
Expand Down
3 changes: 3 additions & 0 deletions nemo/collections/asr/parts/mixins/transcription.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ class TranscribeConfig:
augmentor: Optional[DictConfig] = None
verbose: bool = True
logprobs: bool = False
language_id: str = None

# Utility
partial_hypothesis: Optional[List[Any]] = None
Expand Down Expand Up @@ -196,6 +197,7 @@ def transcribe(
verbose: bool = True,
override_config: Optional[TranscribeConfig] = None,
logprobs: bool = False,
language_id: str = None,
**config_kwargs,
) -> GenericTranscriptionType:
"""
Expand Down Expand Up @@ -245,6 +247,7 @@ def transcribe(
augmentor=augmentor,
verbose=verbose,
logprobs=logprobs,
language_id=language_id,
**config_kwargs,
)
else:
Expand Down
9 changes: 7 additions & 2 deletions nemo/collections/asr/parts/submodules/ctc_decoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -1211,7 +1211,7 @@ class CTCBPEDecoding(AbstractCTCDecoding):
tokenizer: NeMo tokenizer object, which inherits from TokenizerSpec.
"""

def __init__(self, decoding_cfg, tokenizer: TokenizerSpec, blank_id = None): #CTEMO
def __init__(self, decoding_cfg, tokenizer: TokenizerSpec, blank_id = None, lang_id: str = None): #CTEMO
if blank_id is None:
blank_id = tokenizer.tokenizer.vocab_size
self.tokenizer = tokenizer
Expand All @@ -1223,7 +1223,12 @@ def __init__(self, decoding_cfg, tokenizer: TokenizerSpec, blank_id = None): #CT
if hasattr(self.tokenizer.tokenizer, 'get_vocab'):
vocab_dict = self.tokenizer.tokenizer.get_vocab()
if isinstance(self.tokenizer.tokenizer, DummyTokenizer): # AggregateTokenizer.DummyTokenizer
vocab = vocab_dict
if lang_id is not None:
tokenizer = self.tokenizer.tokenizers_dict[lang_id]
vocab_dict = tokenizer.tokenizer.get_vocab()
vocab = list(vocab_dict.keys())
else:
vocab = vocab_dict
else:
vocab = list(vocab_dict.keys())
self.decoding.set_vocabulary(vocab)
Expand Down
13 changes: 1 addition & 12 deletions nemo/collections/common/tokenizers/multilingual_tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,23 +17,12 @@
import numpy as np

from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec
from nemo.collections.common.tokenizers.aggregate_tokenizer import DummyTokenizer
from nemo.utils import logging

__all__ = ['MultilingualTokenizer']


class DummyTokenizer:
    """Minimal tokenizer-like wrapper around a plain vocabulary.

    Exposes just enough of the monolingual-tokenizer surface
    (``vocab``, ``vocab_size``, ``get_vocab``) so a bare vocabulary
    can stand in where a tokenizer object is expected. Additional
    compatibility methods could be added here if callers need them.
    """

    def __init__(self, vocab):
        # Cache the size up front alongside the vocabulary itself.
        self.vocab_size = len(vocab)
        self.vocab = vocab

    def get_vocab(self):
        """Return the vocabulary this dummy tokenizer wraps."""
        return self.vocab


class MultilingualTokenizer(TokenizerSpec):
'''
    MultilingualTokenizer, allowing one to combine multiple regular monolingual tokenizers into one tokenizer.
Expand Down

0 comments on commit 0f8bc67

Please sign in to comment.