Commit

Add language id as input
kaushal-py committed Jul 23, 2024
1 parent fce1c09 commit 2734475
Showing 6 changed files with 40 additions and 8 deletions.
4 changes: 3 additions & 1 deletion Dockerfile
@@ -150,7 +150,9 @@ RUN /usr/bin/test -n "$NEMO_VERSION" && \
   /bin/echo "export BASE_IMAGE=${BASE_IMAGE}" >> /root/.bashrc
 
 # Install NeMo
-RUN --mount=from=nemo-src,target=/tmp/nemo,rw cd /tmp/nemo && pip install ".[all]"
+# FIXME by Kaushal: .[all] did not work
+# RUN --mount=from=nemo-src,target=/tmp/nemo,rw cd /tmp/nemo && pip install ".[all]"
+RUN pip install ".[all]"
 
 # Check install
 RUN python -c "import nemo.collections.nlp as nemo_nlp" && \
4 changes: 4 additions & 0 deletions nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models.py
@@ -133,6 +133,10 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
         self.joint.offset_token_ids_by_token_id = self.offset_token_ids_by_token_id
         self.ctc_decoder.language_masks = self.language_masks
 
+        # Create language embeddings
+        language_list = self.tokenizer.tokenizers_dict.keys()
+        self.encoder.add_language_embeddings(language_list)
+
         # Setup wer object
         self.wer = WER(
             decoding=self.decoding,
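
The language list handed to the encoder comes straight from the hybrid model's aggregate tokenizer, whose tokenizers_dict keeps one sub-tokenizer per language. A minimal sketch of that wiring, with a stand-in tokenizer object (everything except tokenizers_dict and the add_language_embeddings call is illustrative):

class FakeAggregateTokenizer:  # stand-in for the model's aggregate tokenizer
    def __init__(self):
        # per-language sub-tokenizers; the values are irrelevant for this sketch
        self.tokenizers_dict = {"en": object(), "hi": object()}

tokenizer = FakeAggregateTokenizer()
language_list = list(tokenizer.tokenizers_dict.keys())
print(language_list)  # ['en', 'hi']
# The commit then calls self.encoder.add_language_embeddings(language_list); since the
# dict keys fix the embedding index of each language, the tokenizer dict needs to be
# built in the same order at training and inference time.
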
2 changes: 1 addition & 1 deletion nemo/collections/asr/models/hybrid_rnnt_ctc_models.py
@@ -168,11 +168,11 @@ def _transcribe_forward(self, batch: Any, trcfg: TranscribeConfig):
             return super()._transcribe_forward(batch, trcfg)
 
         # CTC Path
-        encoded, encoded_len = self.forward(input_signal=batch[0], input_signal_length=batch[1])
         if "multisoftmax" not in self.cfg.decoder:
             language_ids = None
         else:
             language_ids = [trcfg.language_id] * len(batch[0])
+        encoded, encoded_len = self.forward(input_signal=batch[0], input_signal_length=batch[1], language_ids=language_ids)
         logits = self.ctc_decoder(encoder_output=encoded, language_ids=language_ids)
         output = dict(logits=logits, encoded_len=encoded_len, language_ids=language_ids)
 
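
At transcription time every utterance in the batch gets the single language id carried on the transcribe config, and only when the decoder config enables the multi-softmax head. A standalone sketch of that branching, using a stub in place of NeMo's TranscribeConfig (only language_id is modeled):

from dataclasses import dataclass
from typing import List, Optional

@dataclass
class TranscribeConfigStub:  # stand-in for NeMo's TranscribeConfig
    language_id: Optional[str] = None

def build_language_ids(decoder_cfg: dict, trcfg: TranscribeConfigStub, batch_size: int) -> Optional[List[str]]:
    # Mirrors the hunk above: language conditioning is active only when the decoder
    # config carries a "multisoftmax" entry; otherwise the encoder runs unconditioned.
    if "multisoftmax" not in decoder_cfg:
        return None
    return [trcfg.language_id] * batch_size

print(build_language_ids({"multisoftmax": True}, TranscribeConfigStub("hi"), 3))  # ['hi', 'hi', 'hi']
print(build_language_ids({}, TranscribeConfigStub("hi"), 3))                      # None
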
13 changes: 10 additions & 3 deletions nemo/collections/asr/models/rnnt_models.py
Expand Up @@ -43,7 +43,7 @@
from nemo.collections.common.parts.preprocessing.parsers import make_parser
from nemo.core.classes.common import PretrainedModelInfo, typecheck
from nemo.core.classes.mixins import AccessMixin
from nemo.core.neural_types import AcousticEncodedRepresentation, AudioSignal, LengthsType, NeuralType, SpectrogramType
from nemo.core.neural_types import AcousticEncodedRepresentation, AudioSignal, LengthsType, NeuralType, SpectrogramType, StringType
from nemo.utils import logging


@@ -586,6 +586,7 @@ def input_types(self) -> Optional[Dict[str, NeuralType]]:
             "input_signal_length": NeuralType(tuple('B'), LengthsType(), optional=True),
             "processed_signal": NeuralType(('B', 'D', 'T'), SpectrogramType(), optional=True),
             "processed_signal_length": NeuralType(tuple('B'), LengthsType(), optional=True),
+            'language_ids': [NeuralType(('B'), StringType(), optional=True)], # CTEMO
         }
 
     @property
@@ -644,7 +645,7 @@ def forward(
         if self.spec_augmentation is not None and self.training:
             processed_signal = self.spec_augmentation(input_spec=processed_signal, length=processed_signal_length)
 
-        encoded, encoded_len = self.encoder(audio_signal=processed_signal, length=processed_signal_length)
+        encoded, encoded_len = self.encoder(audio_signal=processed_signal, length=processed_signal_length, language_ids=language_ids)
         return encoded, encoded_len
 
     # PTL-specific methods
@@ -873,7 +874,13 @@ def multi_test_epoch_end(self, outputs, dataloader_idx: int = 0):
     """ Transcription related methods """
 
     def _transcribe_forward(self, batch: Any, trcfg: TranscribeConfig):
-        encoded, encoded_len = self.forward(input_signal=batch[0], input_signal_length=batch[1])
+        # CTC Path
+        if "multisoftmax" not in self.cfg.decoder:
+            language_ids = None
+        else:
+            language_ids = [trcfg.language_id] * len(batch[0])
+        encoded, encoded_len = self.forward(input_signal=batch[0], input_signal_length=batch[1], language_ids=language_ids)
+        # encoded, encoded_len = self.forward(input_signal=batch[0], input_signal_length=batch[1])
         output = dict(encoded=encoded, encoded_len=encoded_len)
         return output
 
23 changes: 20 additions & 3 deletions nemo/collections/asr/modules/conformer_encoder.py
@@ -46,7 +46,7 @@
 from nemo.core.classes.exportable import Exportable
 from nemo.core.classes.mixins import AccessMixin, adapter_mixins
 from nemo.core.classes.module import NeuralModule
-from nemo.core.neural_types import AcousticEncodedRepresentation, ChannelType, LengthsType, NeuralType, SpectrogramType
+from nemo.core.neural_types import AcousticEncodedRepresentation, ChannelType, LengthsType, NeuralType, SpectrogramType, StringType
 from nemo.utils import logging
 
 __all__ = ['ConformerEncoder']
@@ -200,6 +200,7 @@ def input_types(self):
                 "cache_last_channel": NeuralType(('D', 'B', 'T', 'D'), ChannelType(), optional=True),
                 "cache_last_time": NeuralType(('D', 'B', 'D', 'T'), ChannelType(), optional=True),
                 "cache_last_channel_len": NeuralType(tuple('B'), LengthsType(), optional=True),
+                'language_ids': [NeuralType(('B'), StringType(), optional=True)], # CTEMO
             }
         )
 
@@ -213,6 +214,7 @@ def input_types_for_export(self):
                 "cache_last_channel": NeuralType(('B', 'D', 'T', 'D'), ChannelType(), optional=True),
                 "cache_last_time": NeuralType(('B', 'D', 'D', 'T'), ChannelType(), optional=True),
                 "cache_last_channel_len": NeuralType(tuple('B'), LengthsType(), optional=True),
+                'language_ids': [NeuralType(('B'), StringType(), optional=True)], # CTEMO
             }
         )
 
@@ -445,6 +447,14 @@ def __init__(
         # will be set in self.forward() if defined in AccessMixin config
         self.interctc_capture_at_layers = None
 
+        # Add language id embedding to network
+        self.language_embeddings = None
+
+    def add_language_embeddings(self, language_list):
+        self.language_to_idx = {language: idx for idx, language in enumerate(language_list)}
+        num_languages = len(language_list)
+        self.language_embeddings = nn.Embedding(num_languages, self.d_model, max_norm=True)
+
     def forward_for_export(
         self, audio_signal, length, cache_last_channel=None, cache_last_time=None, cache_last_channel_len=None
     ):
@@ -493,18 +503,19 @@ def streaming_post_process(self, rets, keep_all_outputs=True):
 
     @typecheck()
     def forward(
-        self, audio_signal, length, cache_last_channel=None, cache_last_time=None, cache_last_channel_len=None
+        self, audio_signal, length, cache_last_channel=None, cache_last_time=None, cache_last_channel_len=None, language_ids=None,
     ):
         return self.forward_internal(
             audio_signal,
             length,
             cache_last_channel=cache_last_channel,
             cache_last_time=cache_last_time,
             cache_last_channel_len=cache_last_channel_len,
+            language_ids=language_ids,
         )
 
     def forward_internal(
-        self, audio_signal, length, cache_last_channel=None, cache_last_time=None, cache_last_channel_len=None
+        self, audio_signal, length, cache_last_channel=None, cache_last_time=None, cache_last_channel_len=None, language_ids=None,
     ):
         self.update_max_seq_length(seq_length=audio_signal.size(2), device=audio_signal.device)
 
@@ -549,6 +560,12 @@ def forward_internal(
             offset = None
 
         audio_signal, pos_emb = self.pos_enc(x=audio_signal, cache_len=cache_len)
+        # breakpoint()
+        if language_ids is not None:
+            language_ints = torch.tensor([self.language_to_idx[language] for language in language_ids], device=audio_signal.device)
+            language_inputs = self.language_embeddings(language_ints).unsqueeze(1).repeat(1, 32, 1)
+            audio_signal = torch.cat((language_inputs, audio_signal), 1)
+        # breakpoint()
 
         # Create the self-attention and padding masks
         pad_mask, att_mask = self._create_masks(
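
The net effect of the encoder changes is that each utterance gets a learned language embedding repeated for 32 frames and prepended along the time axis after positional encoding. A standalone PyTorch sketch of just that step, outside the NeMo module (dimensions are arbitrary; max_norm=1.0 stands in for the diff's max_norm=True):

import torch
import torch.nn as nn

d_model = 8
language_list = ["en", "hi"]  # hypothetical language set
language_to_idx = {lang: i for i, lang in enumerate(language_list)}
language_embeddings = nn.Embedding(len(language_list), d_model, max_norm=1.0)

batch, time = 2, 100
audio_signal = torch.randn(batch, time, d_model)  # (B, T, D), i.e. the post-pos_enc layout
language_ids = ["hi", "en"]  # one language id per utterance

# Look up one embedding per utterance, repeat it for 32 frames, prepend along time.
language_ints = torch.tensor([language_to_idx[lang] for lang in language_ids])
language_inputs = language_embeddings(language_ints).unsqueeze(1).repeat(1, 32, 1)  # (B, 32, D)
audio_signal = torch.cat((language_inputs, audio_signal), dim=1)                    # (B, 32 + T, D)
print(audio_signal.shape)  # torch.Size([2, 132, 8])

Note that the hunk above does not show length or the downstream padding/attention masks being adjusted for the 32 extra frames; if that happens, it is outside the lines changed here.
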
2 changes: 2 additions & 0 deletions nemo/collections/asr/parts/utils/transcribe_utils.py
@@ -250,6 +250,8 @@ def setup_model(cfg: DictConfig, map_location: torch.device) -> Tuple[ASRModel,
         logging.info(f"Restoring model : {imported_class.__name__}")
         asr_model = imported_class.restore_from(
             restore_path=cfg.model_path, map_location=map_location,
+            # FIXME: Kaushal, added for debugging multi-langid
+            strict=False,
         ) # type: ASRModel
         model_name = os.path.splitext(os.path.basename(cfg.model_path))[0]
     else:
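
Because the encoder now owns a language_embeddings table, an older .nemo checkpoint no longer matches the model's state dict exactly; strict=False lets restore_from load what it can and leave the unmatched, newly added parameters at their fresh initialization. A hedged usage sketch (the model class and path are placeholders; strict is the same keyword the hunk above passes through):

import torch
from nemo.collections.asr.models import EncDecHybridRNNTCTCBPEModel  # assumed model class

asr_model = EncDecHybridRNNTCTCBPEModel.restore_from(
    restore_path="/path/to/pretrained.nemo",  # hypothetical checkpoint path
    map_location=torch.device("cpu"),
    strict=False,  # tolerate missing weights for the newly added language embeddings
)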
