diff --git a/Dockerfile b/Dockerfile
index 970c34a69..4d298a7fd 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -150,7 +150,9 @@ RUN /usr/bin/test -n "$NEMO_VERSION" && \
     /bin/echo "export BASE_IMAGE=${BASE_IMAGE}" >> /root/.bashrc
 
 # Install NeMo
-RUN --mount=from=nemo-src,target=/tmp/nemo,rw cd /tmp/nemo && pip install ".[all]"
+# FIXME by Kaushal: .[all] did not work
+# RUN --mount=from=nemo-src,target=/tmp/nemo,rw cd /tmp/nemo && pip install ".[all]"
+RUN pip install ".[all]"
 
 # Check install
 RUN python -c "import nemo.collections.nlp as nemo_nlp" && \
diff --git a/nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models.py b/nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models.py
index 16f8ac8dc..895b93d0a 100644
--- a/nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models.py
+++ b/nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models.py
@@ -133,6 +133,10 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
             self.joint.offset_token_ids_by_token_id = self.offset_token_ids_by_token_id
             self.ctc_decoder.language_masks = self.language_masks
 
+        # Create language embeddings
+        language_list = self.tokenizer.tokenizers_dict.keys()
+        self.encoder.add_language_embeddings(language_list)
+
         # Setup wer object
         self.wer = WER(
             decoding=self.decoding,
diff --git a/nemo/collections/asr/models/hybrid_rnnt_ctc_models.py b/nemo/collections/asr/models/hybrid_rnnt_ctc_models.py
index 72c821f99..5ab8b7471 100644
--- a/nemo/collections/asr/models/hybrid_rnnt_ctc_models.py
+++ b/nemo/collections/asr/models/hybrid_rnnt_ctc_models.py
@@ -168,11 +168,11 @@ def _transcribe_forward(self, batch: Any, trcfg: TranscribeConfig):
             return super()._transcribe_forward(batch, trcfg)
 
         # CTC Path
-        encoded, encoded_len = self.forward(input_signal=batch[0], input_signal_length=batch[1])
         if "multisoftmax" not in self.cfg.decoder:
             language_ids = None
         else:
             language_ids = [trcfg.language_id] * len(batch[0])
+        encoded, encoded_len = self.forward(input_signal=batch[0], input_signal_length=batch[1], language_ids=language_ids)
         logits = self.ctc_decoder(encoder_output=encoded, language_ids=language_ids)
         output = dict(logits=logits, encoded_len=encoded_len, language_ids=language_ids)
         return output
diff --git a/nemo/collections/asr/models/rnnt_models.py b/nemo/collections/asr/models/rnnt_models.py
index 380f8297b..65726dc18 100644
--- a/nemo/collections/asr/models/rnnt_models.py
+++ b/nemo/collections/asr/models/rnnt_models.py
@@ -43,7 +43,7 @@ from nemo.collections.common.parts.preprocessing.parsers import make_parser
 from nemo.collections.common.parts.preprocessing.parsers import make_parser
 from nemo.core.classes.common import PretrainedModelInfo, typecheck
 from nemo.core.classes.mixins import AccessMixin
-from nemo.core.neural_types import AcousticEncodedRepresentation, AudioSignal, LengthsType, NeuralType, SpectrogramType
+from nemo.core.neural_types import AcousticEncodedRepresentation, AudioSignal, LengthsType, NeuralType, SpectrogramType, StringType
 from nemo.utils import logging
 
 
@@ -586,6 +586,7 @@ def input_types(self) -> Optional[Dict[str, NeuralType]]:
             "input_signal_length": NeuralType(tuple('B'), LengthsType(), optional=True),
             "processed_signal": NeuralType(('B', 'D', 'T'), SpectrogramType(), optional=True),
             "processed_signal_length": NeuralType(tuple('B'), LengthsType(), optional=True),
+            'language_ids': [NeuralType(('B'), StringType(), optional=True)], # CTEMO
         }
 
     @property
@@ -644,7 +645,7 @@ def forward(
         if self.spec_augmentation is not None and self.training:
             processed_signal = self.spec_augmentation(input_spec=processed_signal, length=processed_signal_length)
 
-        encoded, encoded_len = self.encoder(audio_signal=processed_signal, length=processed_signal_length)
+        encoded, encoded_len = self.encoder(audio_signal=processed_signal, length=processed_signal_length, language_ids=language_ids)
         return encoded, encoded_len
 
     # PTL-specific methods
@@ -873,7 +874,13 @@ def multi_test_epoch_end(self, outputs, dataloader_idx: int = 0):
     """ Transcription related methods """
 
     def _transcribe_forward(self, batch: Any, trcfg: TranscribeConfig):
-        encoded, encoded_len = self.forward(input_signal=batch[0], input_signal_length=batch[1])
+        # CTC Path
+        if "multisoftmax" not in self.cfg.decoder:
+            language_ids = None
+        else:
+            language_ids = [trcfg.language_id] * len(batch[0])
+        encoded, encoded_len = self.forward(input_signal=batch[0], input_signal_length=batch[1], language_ids=language_ids)
+        # encoded, encoded_len = self.forward(input_signal=batch[0], input_signal_length=batch[1])
         output = dict(encoded=encoded, encoded_len=encoded_len)
         return output
 
diff --git a/nemo/collections/asr/modules/conformer_encoder.py b/nemo/collections/asr/modules/conformer_encoder.py
index b9642b3ea..85017295f 100644
--- a/nemo/collections/asr/modules/conformer_encoder.py
+++ b/nemo/collections/asr/modules/conformer_encoder.py
@@ -46,7 +46,7 @@ from nemo.core.classes.exportable import Exportable
 from nemo.core.classes.exportable import Exportable
 from nemo.core.classes.mixins import AccessMixin, adapter_mixins
 from nemo.core.classes.module import NeuralModule
-from nemo.core.neural_types import AcousticEncodedRepresentation, ChannelType, LengthsType, NeuralType, SpectrogramType
+from nemo.core.neural_types import AcousticEncodedRepresentation, ChannelType, LengthsType, NeuralType, SpectrogramType, StringType
 from nemo.utils import logging
 
 __all__ = ['ConformerEncoder']
@@ -200,6 +200,7 @@ def input_types(self):
                 "cache_last_channel": NeuralType(('D', 'B', 'T', 'D'), ChannelType(), optional=True),
                 "cache_last_time": NeuralType(('D', 'B', 'D', 'T'), ChannelType(), optional=True),
                 "cache_last_channel_len": NeuralType(tuple('B'), LengthsType(), optional=True),
+                'language_ids': [NeuralType(('B'), StringType(), optional=True)], # CTEMO
             }
         )
 
@@ -213,6 +214,7 @@ def input_types_for_export(self):
                 "cache_last_channel": NeuralType(('B', 'D', 'T', 'D'), ChannelType(), optional=True),
                 "cache_last_time": NeuralType(('B', 'D', 'D', 'T'), ChannelType(), optional=True),
                 "cache_last_channel_len": NeuralType(tuple('B'), LengthsType(), optional=True),
+                'language_ids': [NeuralType(('B'), StringType(), optional=True)], # CTEMO
             }
         )
 
@@ -445,6 +447,14 @@ def __init__(
         # will be set in self.forward() if defined in AccessMixin config
         self.interctc_capture_at_layers = None
 
+        # Add language id embedding to network
+        self.language_embeddings = None
+
+    def add_language_embeddings(self, language_list):
+        self.language_to_idx = {language: idx for idx, language in enumerate(language_list)}
+        num_languages = len(language_list)
+        self.language_embeddings = nn.Embedding(num_languages, self.d_model, max_norm=True)
+
     def forward_for_export(
         self, audio_signal, length, cache_last_channel=None, cache_last_time=None, cache_last_channel_len=None
     ):
@@ -493,7 +503,7 @@ def streaming_post_process(self, rets, keep_all_outputs=True):
 
     @typecheck()
     def forward(
-        self, audio_signal, length, cache_last_channel=None, cache_last_time=None, cache_last_channel_len=None
+        self, audio_signal, length, cache_last_channel=None, cache_last_time=None, cache_last_channel_len=None, language_ids=None,
     ):
         return self.forward_internal(
             audio_signal,
@@ -501,10 +511,11 @@ def forward(
             cache_last_channel=cache_last_channel,
             cache_last_time=cache_last_time,
             cache_last_channel_len=cache_last_channel_len,
+            language_ids=language_ids,
         )
 
     def forward_internal(
-        self, audio_signal, length, cache_last_channel=None, cache_last_time=None, cache_last_channel_len=None
+        self, audio_signal, length, cache_last_channel=None, cache_last_time=None, cache_last_channel_len=None, language_ids=None,
     ):
         self.update_max_seq_length(seq_length=audio_signal.size(2), device=audio_signal.device)
 
@@ -549,6 +560,12 @@ def forward_internal(
             offset = None
 
         audio_signal, pos_emb = self.pos_enc(x=audio_signal, cache_len=cache_len)
+        # breakpoint()
+        if language_ids is not None:
+            language_ints = torch.tensor([self.language_to_idx[language] for language in language_ids], device=audio_signal.device)
+            language_inputs = self.language_embeddings(language_ints).unsqueeze(1).repeat(1, 32, 1)
+            audio_signal = torch.cat((language_inputs, audio_signal), 1)
+        # breakpoint()
 
         # Create the self-attention and padding masks
         pad_mask, att_mask = self._create_masks(
diff --git a/nemo/collections/asr/parts/utils/transcribe_utils.py b/nemo/collections/asr/parts/utils/transcribe_utils.py
index 980500e9e..ed385f305 100644
--- a/nemo/collections/asr/parts/utils/transcribe_utils.py
+++ b/nemo/collections/asr/parts/utils/transcribe_utils.py
@@ -250,6 +250,8 @@ def setup_model(cfg: DictConfig, map_location: torch.device) -> Tuple[ASRModel,
             logging.info(f"Restoring model : {imported_class.__name__}")
             asr_model = imported_class.restore_from(
                 restore_path=cfg.model_path, map_location=map_location,
+                # FIXME: Kaushal, added for debugging multi-langid
+                strict=False,
             )  # type: ASRModel
             model_name = os.path.splitext(os.path.basename(cfg.model_path))[0]
         else: