From 812f58c5f50c02589c08668d9afe6e4f8c6d0d74 Mon Sep 17 00:00:00 2001 From: Zhongqiang Huang Date: Fri, 8 Nov 2024 20:28:21 -0800 Subject: [PATCH] Add whisper masking (#146) - Added masking in whisper encoder to ensure consistency in training and inference. - Simplified release_config.yaml to serve as an example configuration. --- ultravox/model/data_processing.py | 1 + ultravox/model/data_processing_test.py | 1 + ultravox/model/ultravox_model.py | 40 ++++++++++++++--- ultravox/model/ultravox_processing.py | 11 +++++ ultravox/training/configs/llama3_whisper.yaml | 6 --- .../training/configs/llama3_whisper_kd.yaml | 39 ---------------- ultravox/training/configs/release_config.yaml | 45 +------------------ 7 files changed, 50 insertions(+), 93 deletions(-) delete mode 100644 ultravox/training/configs/llama3_whisper.yaml delete mode 100644 ultravox/training/configs/llama3_whisper_kd.yaml diff --git a/ultravox/model/data_processing.py b/ultravox/model/data_processing.py index 5e196496..7a585f07 100644 --- a/ultravox/model/data_processing.py +++ b/ultravox/model/data_processing.py @@ -66,6 +66,7 @@ def _process(self, sample: datasets.VoiceSample) -> Dict[str, Any]: inputs["audio_values"].squeeze_(0) inputs["audio_token_start_idx"].squeeze_(0) inputs["audio_token_len"].squeeze_(0) + inputs["audio_len"].squeeze_(0) # No need to shift the labels as the model does it internally labels = input_ids.clone() diff --git a/ultravox/model/data_processing_test.py b/ultravox/model/data_processing_test.py index 80601135..48f0dd9f 100644 --- a/ultravox/model/data_processing_test.py +++ b/ultravox/model/data_processing_test.py @@ -31,6 +31,7 @@ def fake_process(text, audio, return_tensors="pt", sampling_rate=16000): "audio_values": torch.tensor([[[0.1, 0.2, 0.3]]]), "audio_token_start_idx": torch.tensor([1]), "audio_token_len": torch.tensor([2]), + "audio_len": torch.tensor([10]), } diff --git a/ultravox/model/ultravox_model.py b/ultravox/model/ultravox_model.py index feab7aee..1a190df9 100644 --- a/ultravox/model/ultravox_model.py +++ b/ultravox/model/ultravox_model.py @@ -148,6 +148,7 @@ def forward( labels: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, audio_token_start_idx: Optional[torch.Tensor] = None, + audio_len: Optional[torch.Tensor] = None, audio_token_len: Optional[torch.Tensor] = None, past_key_values: Optional[Union[Tuple, transformers.cache_utils.Cache]] = None, # the alt_* fields are needed for KL divergence loss @@ -189,7 +190,7 @@ def forward( # B x A/3200 x D audio_tower_output = self.audio_tower.forward( - audio_values.to(self.audio_tower.dtype) + audio_values.to(self.audio_tower.dtype), audio_len=audio_len ).last_hidden_state audio_tower_output = audio_tower_output.to(inputs_embeds.dtype) @@ -235,6 +236,7 @@ def prepare_inputs_for_generation( audio_values: Optional[torch.FloatTensor] = None, audio_token_start_idx: Optional[torch.Tensor] = None, audio_token_len: Optional[torch.Tensor] = None, + audio_len: Optional[torch.Tensor] = None, past_key_values: Optional[Union[Tuple, transformers.cache_utils.Cache]] = None, attention_mask: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None, @@ -263,6 +265,7 @@ def prepare_inputs_for_generation( audio_token_start_idx - prefill_start_idx ) model_input["audio_token_len"] = audio_token_len + model_input["audio_len"] = audio_len return model_input @@ -508,7 +511,9 @@ def forward(self, audio_features: torch.Tensor) -> torch.Tensor: return hidden_states -class 
ModifiedWhisperEncoder(whisper.WhisperEncoder): +class ModifiedWhisperEncoder( + whisper.WhisperEncoder, transformers.modeling_utils.ModuleUtilsMixin +): """ Encoder portion of OpenAI's Whisper model. @@ -527,7 +532,7 @@ class ModifiedWhisperEncoder(whisper.WhisperEncoder): def forward( self, input_features, - attention_mask=None, + audio_len=None, head_mask=None, output_attentions=None, output_hidden_states=None, @@ -570,6 +575,31 @@ def forward( encoder_states = () if output_hidden_states else None all_attentions = () if output_attentions else None + # Create attention mask based on audio lengths to mask out padding tokens + # For each sample in batch: + # - Convert raw audio length to feature length after convolutions + # - Create boolean mask that is True for valid positions and False for padding + # - Convert to extended attention mask format expected by transformer layers + # (1.0 for positions to attend to, large negative for positions to ignore) + # This masking ensures consistent behavior between training and inference + # by preventing the model from attending to padding tokens in both cases + attention_mask = None + if audio_len != None: + audio_feature_len = self._get_feat_extract_output_lengths(audio_len) + batch_size = hidden_states.shape[0] + max_seq_len = hidden_states.shape[1] + attention_mask = ( + torch.arange(max_seq_len, device=hidden_states.device)[None, :] + .expand(batch_size, -1) + .lt(audio_feature_len.view(batch_size, 1)) + ) + attention_mask = self.get_extended_attention_mask( + attention_mask, + None, + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + # check if head_mask has a correct number of layers specified if desired if head_mask is not None: assert head_mask.size()[0] == ( @@ -593,14 +623,14 @@ def forward( layer_outputs = self._gradient_checkpointing_func( encoder_layer.__call__, hidden_states, - None, + attention_mask, (head_mask[idx] if head_mask is not None else None), output_attentions, ) else: layer_outputs = encoder_layer( hidden_states, - None, + attention_mask, layer_head_mask=( head_mask[idx] if head_mask is not None else None ), diff --git a/ultravox/model/ultravox_processing.py b/ultravox/model/ultravox_processing.py index 3da068f6..5cab9e8c 100644 --- a/ultravox/model/ultravox_processing.py +++ b/ultravox/model/ultravox_processing.py @@ -154,6 +154,7 @@ def __call__( sampling_rate=sampling_rate, padding="longest", max_length=audio_len, + return_attention_mask=True, **kwargs, ) if "input_features" in x: @@ -161,6 +162,16 @@ def __call__( else: data["audio_values"] = x.input_values + # data["audio_len"] is the number of frames in the audio, used for creating attention masks in whisper encoder + if ( + self.audio_padding == "max_length" + ): # audio is padded to max length, so we rely on the attention mask to determine audio_len + data["audio_len"] = ( + x.attention_mask.sum(-1) - 1 + ) # Whisper attention mask includes an extra 1 at the end that needs to be subtracted + else: # audio is not padded, so we can directly use the audio length + data["audio_len"] = [torch.as_tensor(data["audio_values"]).shape[-1]] + if text is not None: assert isinstance( text, str diff --git a/ultravox/training/configs/llama3_whisper.yaml b/ultravox/training/configs/llama3_whisper.yaml deleted file mode 100644 index f8e77bb3..00000000 --- a/ultravox/training/configs/llama3_whisper.yaml +++ /dev/null @@ -1,6 +0,0 @@ -# SLM with ultravox & llama3 -exp_name: "llama3_whisper_s" - -# Make sure to accept the license agreement on huggingface hub 
-text_model: "meta-llama/Meta-Llama-3-8B-Instruct" -audio_model: "openai/whisper-small" \ No newline at end of file diff --git a/ultravox/training/configs/llama3_whisper_kd.yaml b/ultravox/training/configs/llama3_whisper_kd.yaml deleted file mode 100644 index d951f02d..00000000 --- a/ultravox/training/configs/llama3_whisper_kd.yaml +++ /dev/null @@ -1,39 +0,0 @@ -# SLM with ultravox & llama3, trained wtih knowledge distillation. -exp_name: "llama3_whisper_s" - -# Make sure to accept the license agreement on huggingface hub -text_model: "meta-llama/Meta-Llama-3-8B-Instruct" -audio_model: "openai/whisper-small" - - -loss_config: - # Choose from ["KL_Divergence", "CrossEntropy"], default is "KL_Divergence" - loss_function: "KL_Divergence" - -# Temporarily remove heysquad_human from val_sets as it causes the training to fail. -val_sets: ["anyinstruct", "soda", "peoplespeech"] - -batch_size: 4 -max_steps: 1000 - -data_sets: [] -data_dicts: - - path: "fixie-ai/librispeech_asr" - name: "clean" - splits: - - "train.100" - - "train.360" - user_template: "Continue the following text using less than 50 words:\n\n<|audio|>" - assistant_template: "{{ continuation }}" - transcript_template: "{{ text }}" - weight: 2 - num_samples: 100_000 - - path: "fixie-ai/librispeech_asr" - name: "other" - splits: - - "train.500" - user_template: "Continue the following text using less than 50 words:\n\n<|audio|>" - assistant_template: "{{ continuation }}" - transcript_template: "{{ text }}" - weight: 1 - num_samples: 100_000 diff --git a/ultravox/training/configs/release_config.yaml b/ultravox/training/configs/release_config.yaml index 329f9c89..0ec26331 100644 --- a/ultravox/training/configs/release_config.yaml +++ b/ultravox/training/configs/release_config.yaml @@ -1,6 +1,4 @@ # SLM with ultravox & llama3.1, trained wtih knowledge distillation. -exp_name: "ultravox-v0_4" - # Make sure to accept the license agreement on huggingface hub text_model: "meta-llama/Meta-Llama-3.1-8B-Instruct" audio_model: "openai/whisper-medium" @@ -12,51 +10,12 @@ loss_config: train_sets: - name: librispeech-clean-continuation - name: librispeech-other-continuation - - name: peoplespeech-clean-continuation - weight: 8 - name: commonvoice-en-continuation - weight: 8 - - name: commonvoice-ar-continuation - weight: 0.2 - - name: commonvoice-de-continuation - weight: 4 - - name: commonvoice-es-continuation - weight: 3 - - name: commonvoice-fr-continuation - weight: 4 - - name: commonvoice-it-continuation - weight: 1.2 - - name: commonvoice-ja-continuation - weight: 0.1 - - name: commonvoice-pt-continuation - weight: 0.2 - - name: commonvoice-ru-continuation - weight: 0.2 - - name: librispeech-clean-transcription - - name: librispeech-other-transcription - - name: peoplespeech-clean-transcription - weight: 0.8 - - name: commonvoice-en-transcription - weight: 0.8 - - name: commonvoice-ar-transcription - weight: 0.02 - - name: commonvoice-de-transcription - weight: 0.4 - - name: commonvoice-es-transcription - weight: 0.3 - - name: commonvoice-fr-transcription - weight: 0.4 - - name: commonvoice-it-transcription - weight: 0.12 - - name: commonvoice-ja-transcription - weight: 0.01 - - name: commonvoice-pt-transcription - weight: 0.02 - - name: commonvoice-ru-transcription - weight: 0.02 # Temporarily remove heysquad_human from val_sets as it causes the training to fail. val_sets: + - name: covost2-en-de + - name: covost2-zh-en - name: peoplespeech-clean-transcription batch_size: 24
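
Note (illustration, not part of the applied diff): the snippet below is a minimal, self-contained sketch of the masking path this patch adds — deriving audio_len from the feature extractor's attention mask and expanding it into the additive mask the encoder layers consume. Helper names (whisper_feat_lengths, audio_len_from_extractor_mask, build_encoder_attention_mask) are illustrative rather than taken from the codebase; the feature-length formula assumes Whisper's standard conv downsampling of (n - 1) // 2 + 1, and the extended mask follows the usual HuggingFace convention of 0.0 for attended positions and a large negative value for padding.

import torch


def whisper_feat_lengths(mel_lengths: torch.Tensor) -> torch.Tensor:
    # Whisper's second encoder conv has stride 2, so the sequence length seen by
    # the transformer layers is (n - 1) // 2 + 1 mel frames (assumption: standard
    # Whisper conv stack, matching _get_feat_extract_output_lengths).
    return (mel_lengths - 1) // 2 + 1


def audio_len_from_extractor_mask(feat_attention_mask: torch.Tensor) -> torch.Tensor:
    # Mirrors the processor change above: with max-length padding, the feature
    # extractor's attention mask carries one extra trailing 1, so the valid
    # mel-frame count per sample is sum(-1) - 1.
    return feat_attention_mask.sum(-1) - 1


def build_encoder_attention_mask(
    audio_len: torch.Tensor, max_seq_len: int, dtype: torch.dtype = torch.float32
) -> torch.Tensor:
    # Boolean key-padding mask (True = valid frame) expanded to the additive
    # [batch, 1, 1, seq] format the encoder layers expect: 0.0 where attended,
    # a large negative value where padded.
    batch_size = audio_len.shape[0]
    feat_len = whisper_feat_lengths(audio_len)
    valid = (
        torch.arange(max_seq_len)[None, :].expand(batch_size, -1)
        < feat_len[:, None]
    )
    extended = torch.zeros(batch_size, 1, 1, max_seq_len, dtype=dtype)
    extended.masked_fill_(~valid[:, None, None, :], torch.finfo(dtype).min)
    return extended


if __name__ == "__main__":
    # Toy example: two clips padded to 3000 mel frames (1500 encoder positions).
    audio_len = torch.tensor([3000, 1200])
    mask = build_encoder_attention_mask(audio_len, max_seq_len=1500)
    print(mask.shape)                         # torch.Size([2, 1, 1, 1500])
    print((mask[1, 0, 0] == 0).sum().item())  # 600 valid positions for the short clip

The boolean-to-additive expansion here is the same transformation that get_extended_attention_mask performs inside ModifiedWhisperEncoder.forward, which is why padded frames are ignored identically during training and generation.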