From f769ac8bbd55448a024f267e51eb55f8512e0af8 Mon Sep 17 00:00:00 2001 From: Shucong Zhang/Embedded AI /SRUK/Engineer/Samsung Electronics Date: Fri, 30 Aug 2024 12:45:19 +0100 Subject: [PATCH 1/2] layernom flag and ssl and conformer sb1.0 --- .../MP3S/CommonVoice/LSTM/hparams/ssl_cy.yaml | 195 ++++++ .../MP3S/CommonVoice/LSTM/hparams/ssl_eu.yaml | 195 ++++++ benchmarks/MP3S/CommonVoice/LSTM/train.py | 305 ++++++++++ .../MP3S/LibriSpeech/LSTM/hparams/ssl.yaml | 188 ++++++ benchmarks/MP3S/LibriSpeech/LSTM/train.py | 344 +++++++++++ .../LibriSpeech/contextnet/hparam/ssl.yaml | 184 ++++++ .../MP3S/LibriSpeech/contextnet/train.py | 342 +++++++++++ .../MP3S/SLURP/LSTM_linear/hparams/ssl.yaml | 201 +++++++ benchmarks/MP3S/SLURP/LSTM_linear/train.py | 318 ++++++++++ .../VoxCeleb1/ecapa_tdnn/hparams/ssl.yaml | 195 ++++++ benchmarks/MP3S/VoxCeleb1/ecapa_tdnn/train.py | 468 ++++++++++++++ .../hparams/summarymixing_wav2vec2.yaml | 180 ++++++ .../wav2vec2/make_librilight_csv.py | 90 +++ .../wav2vec2/train_sb_wav2vec2_mel.py | 431 +++++++++++++ .../conformer_summarymixing_transducer.yaml | 1 + .../hparams/conformer_summarymixing.yaml | 351 +++++++++++ .../conformer_summarymixing_transducer.yaml | 1 + .../lobes/models/transformer/Conformer.py | 67 ++- .../lobes/models/transformer/Transformer.py | 4 + .../models/transformer/TransformerASR.py | 4 + speechbrain/lobes/models/wav2vec.py | 569 ++++++++++++++++++ speechbrain/nnet/summary_mixing.py | 14 + 22 files changed, 4634 insertions(+), 13 deletions(-) create mode 100644 benchmarks/MP3S/CommonVoice/LSTM/hparams/ssl_cy.yaml create mode 100644 benchmarks/MP3S/CommonVoice/LSTM/hparams/ssl_eu.yaml create mode 100644 benchmarks/MP3S/CommonVoice/LSTM/train.py create mode 100644 benchmarks/MP3S/LibriSpeech/LSTM/hparams/ssl.yaml create mode 100644 benchmarks/MP3S/LibriSpeech/LSTM/train.py create mode 100644 benchmarks/MP3S/LibriSpeech/contextnet/hparam/ssl.yaml create mode 100644 benchmarks/MP3S/LibriSpeech/contextnet/train.py create mode 100644 benchmarks/MP3S/SLURP/LSTM_linear/hparams/ssl.yaml create mode 100644 benchmarks/MP3S/SLURP/LSTM_linear/train.py create mode 100644 benchmarks/MP3S/VoxCeleb1/ecapa_tdnn/hparams/ssl.yaml create mode 100644 benchmarks/MP3S/VoxCeleb1/ecapa_tdnn/train.py create mode 100644 recipes/Libri-Light/self-supervised-learning/wav2vec2/hparams/summarymixing_wav2vec2.yaml create mode 100644 recipes/Libri-Light/self-supervised-learning/wav2vec2/make_librilight_csv.py create mode 100644 recipes/Libri-Light/self-supervised-learning/wav2vec2/train_sb_wav2vec2_mel.py create mode 100644 recipes/LibriSpeech/ASR/transformer/hparams/conformer_summarymixing.yaml create mode 100644 speechbrain/lobes/models/wav2vec.py diff --git a/benchmarks/MP3S/CommonVoice/LSTM/hparams/ssl_cy.yaml b/benchmarks/MP3S/CommonVoice/LSTM/hparams/ssl_cy.yaml new file mode 100644 index 0000000..5b808ae --- /dev/null +++ b/benchmarks/MP3S/CommonVoice/LSTM/hparams/ssl_cy.yaml @@ -0,0 +1,195 @@ +# ################################ +# SummaryMixing © 2023 by Samsung Electronics is licensed under CC BY-NC 4.0. 
+ +# This is configurations of SummaryMixing wav2vec 2.0 downstream ASR on CommonVoice cy with a LSTM downstream model + +# Usage: Install SpeechBrain MP3S +# Create a folder benchmarks/MP3S/CommonVoice/LSTM +# Copy this file under benchmarks/MP3S/CommonVoice/LSTM/hparams +# SummaryMixing: https://arxiv.org/abs/2307.07421 +# SummaryMixing SSL: https://arxiv.org/pdf/2407.13377 + +# Authors +# * Titouan Parcollet 2023, 2024 +# * Shucong Zhang 2023, 2024 +# * Rogier van Dalen 2023, 2024 +# * Sourav Bhattacharya 2023, 2024 +# ################################ + +# Seed needs to be set at top of yaml, before objects with parameters are made +seed: 1986 +__set_seed: !apply:torch.manual_seed [!ref ] +language: cy # use 'it' for Italian, 'rw' for Kinyarwanda, 'en' for english +output_folder: !ref results/CommonVoice// +wer_file: !ref /wer.txt +save_folder: !ref /save +train_log: !ref /train_log.txt + +# Data files +data_folder: !PLACEHOLDER # e.g, /local/cv-corpus-11.0-2022-09-21/ +train_tsv_file: !ref /train.tsv # Standard CommonVoice .tsv files +dev_tsv_file: !ref /dev.tsv # Standard CommonVoice .tsv files +test_tsv_file: !ref /test.tsv # Standard CommonVoice .tsv files +accented_letters: True +train_csv: !ref /train.csv +valid_csv: !ref /dev.csv +test_csv: !ref /test.csv +skip_prep: False # Skip data preparation + +avoid_if_longer_than: 10.0 + +num_layers_ssl: 13 #Number of layers in the SSL model (should be 25 for large ) +pretrained_path: !PLACEHOLDER # e,g./path/to/pre-trained_SummaryMixing_w2v2 +encoder_dim: 768 + +# Training parameters +number_of_epochs: 20 +lr: 0.0004 +lr_weights: 0.02 +sorting: ascending +auto_mix_prec: False +sample_rate: 16000 +token_type: bpe # ["unigram", "bpe", "char"] +character_coverage: 1.0 + + +# With data_parallel batch_size is split into N jobs +# With DDP batch_size is multiplied by N jobs +# Must be 3 per GPU to fit 32GB of VRAM +batch_size: 4 +test_batch_size: 4 + +# Dataloader options +train_dataloader_opts: + batch_size: !ref +dataloader_options: + batch_size: !ref + num_workers: 4 +test_dataloader_options: + batch_size: !ref + num_workers: 4 + + +valid_dataloader_opts: + batch_size: !ref + +# Model parameters +activation: !name:torch.nn.Sigmoid +dnn_layers: 1 +dnn_neurons: 768 +freeze_encoder: True + +# Outputs +output_neurons: 100 # BPE size, index(blank/eos/bos) = 0 + + +# Functions and classes +# +epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter + limit: !ref + +augmentation: !new:speechbrain.lobes.augment.TimeDomainSpecAugment + sample_rate: !ref + speeds: [95, 100, 105] + +encoder: !new:speechbrain.lobes.models.transformer.Conformer.ConformerEncoder + d_model: 768 + num_layers: 12 + nhead: 8 + d_ffn: 3072 + dropout: 0.1 + layerdrop_prob: 0.0 + attention_type: SummaryMixing + local_proj_hid_dim: [768] + local_proj_out_dim: 768 + summary_hid_dim: [768] + mode: SummaryMixing + output_hidden_states: True + +latentextractor_kernels: [3, 3] +latentextractor_strides: [2, 1] +extractor_dim: 512 +embedding_dim: 768 + +CNN: !new:speechbrain.lobes.models.wav2vec.W2VLatentExtractor + kernel_sizes: !ref + strides: !ref + out_channels: [512, 512] + input_dim: 80 + +weighted_ssl_model: !new:speechbrain.lobes.models.wav2vec.WeightedSSLModel + pretrained_path: !ref + num_layers: 13 + latent_encoder: !ref + CNN: !ref + in_dim: 512 + embedding_dim: 768 + dropout_encoder_input: 0.1 + output_hidden_states: True + +enc: !new:speechbrain.nnet.RNN.LSTM + input_shape: [Null, Null, !ref ] + num_layers: 2 + bidirectional: True + dropout: 0.2 + hidden_size: 1024 + 
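+# The LSTM probe above is bidirectional with hidden_size 1024, so each frame is
+# represented by 2 * 1024 = 2048 features; ctc_lin below therefore uses
+# input_size: 2048 and projects to output_neurons CTC classes (blank index 0).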
+ctc_lin: !new:speechbrain.nnet.linear.Linear + input_size: 2048 + n_neurons: !ref + +log_softmax: !new:speechbrain.nnet.activations.Softmax + apply_log: True + +ctc_cost: !name:speechbrain.nnet.losses.ctc_loss + blank_index: !ref + +modules: + enc: !ref + ctc_lin: !ref + weighted_ssl_model: !ref + +model: !new:torch.nn.ModuleList + - [!ref , !ref ] + +model_opt_class: !name:torch.optim.Adam + lr: !ref + +weights_opt_class: !name:torch.optim.Adam + lr: !ref + +lr_annealing_model: !new:speechbrain.nnet.schedulers.NewBobScheduler + initial_value: !ref + improvement_threshold: 0.0025 + annealing_factor: 0.8 + patient: 0 + +lr_annealing_weights: !new:speechbrain.nnet.schedulers.NewBobScheduler + initial_value: !ref + improvement_threshold: 0.0025 + annealing_factor: 0.9 + patient: 0 + +label_encoder: !new:speechbrain.dataio.encoder.CTCTextEncoder + +checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer + checkpoints_dir: !ref + recoverables: + model: !ref + ssl_model: !ref + scheduler_model: !ref + scheduler_encoder: !ref + counter: !ref + tokenizer: !ref + +blank_index: 0 +unk_index: 1 + + +train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger + save_file: !ref + +error_rate_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats + +cer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats + split_tokens: True diff --git a/benchmarks/MP3S/CommonVoice/LSTM/hparams/ssl_eu.yaml b/benchmarks/MP3S/CommonVoice/LSTM/hparams/ssl_eu.yaml new file mode 100644 index 0000000..04a439c --- /dev/null +++ b/benchmarks/MP3S/CommonVoice/LSTM/hparams/ssl_eu.yaml @@ -0,0 +1,195 @@ +# ################################ +# SummaryMixing © 2023 by Samsung Electronics is licensed under CC BY-NC 4.0. + +# This is configurations of SummaryMixing wav2vec 2.0 downstream ASR on CommonVoice eu with a LSTM downstream model + +# Usage: Install SpeechBrain MP3S +# Create a folder benchmarks/MP3S/CommonVoice/LSTM +# Copy this file under benchmarks/MP3S/CommonVoice/LSTM/hparams +# SummaryMixing: https://arxiv.org/abs/2307.07421 +# SummaryMixing SSL: https://arxiv.org/pdf/2407.13377 + +# Authors +# * Titouan Parcollet 2023, 2024 +# * Shucong Zhang 2023, 2024 +# * Rogier van Dalen 2023, 2024 +# * Sourav Bhattacharya 2023, 2024 +# ################################ + +# Seed needs to be set at top of yaml, before objects with parameters are made +seed: 1986 +__set_seed: !apply:torch.manual_seed [!ref ] +language: eu # use 'it' for Italian, 'rw' for Kinyarwanda, 'en' for english +output_folder: !ref results/CommonVoice// +wer_file: !ref /wer.txt +save_folder: !ref /save +train_log: !ref /train_log.txt + +# Data files +data_folder: !PLACEHOLDER # e.g, /local/cv-corpus-11.0-2022-09-21/ +train_tsv_file: !ref /train.tsv # Standard CommonVoice .tsv files +dev_tsv_file: !ref /dev.tsv # Standard CommonVoice .tsv files +test_tsv_file: !ref /test.tsv # Standard CommonVoice .tsv files +accented_letters: True +train_csv: !ref /train.csv +valid_csv: !ref /dev.csv +test_csv: !ref /test.csv +skip_prep: False # Skip data preparation + +avoid_if_longer_than: 10.0 + +num_layers_ssl: 13 #Number of layers in the SSL model (should be 25 for large ) +pretrained_path: !PLACEHOLDER # e,g./path/to/pre-trained_SummaryMixing_w2v2 +encoder_dim: 768 + +# Training parameters +number_of_epochs: 20 +lr: 0.0005 +lr_weights: 0.025 +sorting: ascending +auto_mix_prec: False +sample_rate: 16000 +token_type: bpe # ["unigram", "bpe", "char"] +character_coverage: 1.0 + + +# With data_parallel batch_size is split into N jobs +# With DDP 
batch_size is multiplied by N jobs +# Must be 3 per GPU to fit 32GB of VRAM +batch_size: 4 +test_batch_size: 4 + +# Dataloader options +train_dataloader_opts: + batch_size: !ref +dataloader_options: + batch_size: !ref + num_workers: 4 +test_dataloader_options: + batch_size: !ref + num_workers: 4 + + +valid_dataloader_opts: + batch_size: !ref + +# Model parameters +activation: !name:torch.nn.Sigmoid +dnn_layers: 1 +dnn_neurons: 768 +freeze_encoder: True + +# Outputs +output_neurons: 100 # BPE size, index(blank/eos/bos) = 0 + + +# Functions and classes +# +epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter + limit: !ref + +augmentation: !new:speechbrain.lobes.augment.TimeDomainSpecAugment + sample_rate: !ref + speeds: [95, 100, 105] + +encoder: !new:speechbrain.lobes.models.transformer.Conformer.ConformerEncoder + d_model: 768 + num_layers: 12 + nhead: 8 + d_ffn: 3072 + dropout: 0.1 + layerdrop_prob: 0.0 + attention_type: SummaryMixing + local_proj_hid_dim: [768] + local_proj_out_dim: 768 + summary_hid_dim: [768] + mode: SummaryMixing + output_hidden_states: True + +latentextractor_kernels: [3, 3] +latentextractor_strides: [2, 1] +extractor_dim: 512 +embedding_dim: 768 + +CNN: !new:speechbrain.lobes.models.wav2vec.W2VLatentExtractor + kernel_sizes: !ref + strides: !ref + out_channels: [512, 512] + input_dim: 80 + +weighted_ssl_model: !new:speechbrain.lobes.models.wav2vec.WeightedSSLModel + pretrained_path: !ref + num_layers: 13 + latent_encoder: !ref + CNN: !ref + in_dim: 512 + embedding_dim: 768 + dropout_encoder_input: 0.1 + output_hidden_states: True + +enc: !new:speechbrain.nnet.RNN.LSTM + input_shape: [Null, Null, !ref ] + num_layers: 2 + bidirectional: True + dropout: 0.2 + hidden_size: 1024 + +ctc_lin: !new:speechbrain.nnet.linear.Linear + input_size: 2048 + n_neurons: !ref + +log_softmax: !new:speechbrain.nnet.activations.Softmax + apply_log: True + +ctc_cost: !name:speechbrain.nnet.losses.ctc_loss + blank_index: !ref + +modules: + enc: !ref + ctc_lin: !ref + weighted_ssl_model: !ref + +model: !new:torch.nn.ModuleList + - [!ref , !ref ] + +model_opt_class: !name:torch.optim.Adam + lr: !ref + +weights_opt_class: !name:torch.optim.Adam + lr: !ref + +lr_annealing_model: !new:speechbrain.nnet.schedulers.NewBobScheduler + initial_value: !ref + improvement_threshold: 0.0025 + annealing_factor: 0.8 + patient: 0 + +lr_annealing_weights: !new:speechbrain.nnet.schedulers.NewBobScheduler + initial_value: !ref + improvement_threshold: 0.0025 + annealing_factor: 0.9 + patient: 0 + +label_encoder: !new:speechbrain.dataio.encoder.CTCTextEncoder + +checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer + checkpoints_dir: !ref + recoverables: + model: !ref + ssl_model: !ref + scheduler_model: !ref + scheduler_encoder: !ref + counter: !ref + tokenizer: !ref + +blank_index: 0 +unk_index: 1 + + +train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger + save_file: !ref + +error_rate_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats + +cer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats + split_tokens: True diff --git a/benchmarks/MP3S/CommonVoice/LSTM/train.py b/benchmarks/MP3S/CommonVoice/LSTM/train.py new file mode 100644 index 0000000..ddb3a39 --- /dev/null +++ b/benchmarks/MP3S/CommonVoice/LSTM/train.py @@ -0,0 +1,305 @@ +""" SummaryMixing © 2023 by Samsung Electronics is licensed under CC BY-NC 4.0. + +Script for training an ASR model evaluating an SSL representation +model on one language from the CommonVoice dataset. 
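The SSL encoder stays frozen; its hidden-layer outputs are combined through learned
layer weights, and only those weights, the LSTM probe, and the CTC head are trained.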
A SentencePiece tokenizer +with number of tokens equal to is learned in a first phase +on the considered language. +""" + +import sys +import torch +import logging +import speechbrain as sb +from speechbrain.utils.distributed import run_on_main +from hyperpyyaml import load_hyperpyyaml +import torchaudio +from speechbrain.tokenizers.SentencePiece import SentencePiece +from speechbrain.utils.data_utils import undo_padding + +logger = logging.getLogger(__name__) + + +# Define training procedure +class ASR(sb.Brain): + def compute_forward(self, batch, stage): + """Forward computations from the waveform batches to the output probabilities.""" + batch = batch.to(self.device) + wavs, wav_lens = batch.sig + wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device) + # Forward pass + feats = self.modules.weighted_ssl_model(wavs, wav_lens) + y = self.modules.enc(feats) + y = y[0] # As it is an RNN output + # Compute outputs + p_tokens = None + logits = self.modules.ctc_lin(y) + p_ctc = self.hparams.log_softmax(logits) + if stage != sb.Stage.TRAIN: + p_tokens = sb.decoders.ctc_greedy_decode( + p_ctc, wav_lens, blank_id=self.hparams.blank_index + ) + return p_ctc, wav_lens, p_tokens + + def compute_objectives(self, predictions, batch, stage): + """Computes the loss (CTC+NLL) given predictions and targets.""" + + p_ctc, wav_lens, predicted_tokens = predictions + ids = batch.id + tokens, tokens_lens = batch.tokens + loss_ctc = self.hparams.ctc_cost(p_ctc, tokens, wav_lens, tokens_lens) + loss = loss_ctc + + if stage != sb.Stage.TRAIN: + # Decode token terms to words + predicted_words = self.tokenizer( + predicted_tokens, task="decode_from_list" + ) + + # Convert indices to words + target_words = undo_padding(tokens, tokens_lens) + target_words = self.tokenizer(target_words, task="decode_from_list") + + self.wer_metric.append(ids, predicted_words, target_words) + self.cer_metric.append(ids, predicted_words, target_words) + return loss + + def fit_batch(self, batch): + """Train the parameters given a single batch in input""" + predictions = self.compute_forward(batch, sb.Stage.TRAIN) + loss = self.compute_objectives(predictions, batch, sb.Stage.TRAIN) + loss.backward() + self.model_optimizer.step() + self.weights_optimizer.step() + self.model_optimizer.zero_grad() + self.weights_optimizer.zero_grad() + return loss.detach() + + def evaluate_batch(self, batch, stage): + """Computations needed for validation/test batches""" + predictions = self.compute_forward(batch, stage=stage) + with torch.no_grad(): + loss = self.compute_objectives(predictions, batch, stage=stage) + return loss.detach() + + def on_stage_start(self, stage, epoch): + """Gets called at the beginning of each epoch""" + if stage != sb.Stage.TRAIN: + self.cer_metric = self.hparams.cer_computer() + self.wer_metric = self.hparams.error_rate_computer() + + def on_stage_end(self, stage, stage_loss, epoch): + """Gets called at the end of an epoch.""" + # Compute/store important stats + stage_stats = {"loss": stage_loss} + if stage == sb.Stage.TRAIN: + self.train_stats = stage_stats + else: + stage_stats["CER"] = self.cer_metric.summarize("error_rate") + stage_stats["WER"] = self.wer_metric.summarize("error_rate") + + # Perform end-of-iteration things, like annealing, logging, etc. 
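+        # Two NewBob schedulers are stepped on the validation loss: one for the
+        # downstream model and one for the learned SSL layer weights; the updated
+        # rates are applied to their optimizers before logging and checkpointing
+        # on the best WER.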
+ if stage == sb.Stage.VALID: + old_lr_model, new_lr_model = self.hparams.lr_annealing_model( + stage_stats["loss"] + ) + old_lr_weights, new_lr_weights = self.hparams.lr_annealing_weights( + stage_stats["loss"] + ) + sb.nnet.schedulers.update_learning_rate( + self.model_optimizer, new_lr_model + ) + sb.nnet.schedulers.update_learning_rate( + self.weights_optimizer, new_lr_weights + ) + self.hparams.train_logger.log_stats( + stats_meta={"epoch": epoch, "lr_model": old_lr_model}, + train_stats=self.train_stats, + valid_stats=stage_stats, + ) + self.checkpointer.save_and_keep_only( + meta={"WER": stage_stats["WER"]}, min_keys=["WER"], + ) + elif stage == sb.Stage.TEST: + self.hparams.train_logger.log_stats( + stats_meta={"Epoch loaded": self.hparams.epoch_counter.current}, + test_stats=stage_stats, + ) + with open(self.hparams.wer_file, "w") as w: + self.wer_metric.write_stats(w) + + def init_optimizers(self): + "Initializes the weights optimizer and model optimizer" + self.weights_optimizer = self.hparams.weights_opt_class( + [self.modules.weighted_ssl_model.weights] + ) + self.model_optimizer = self.hparams.model_opt_class( + self.hparams.model.parameters() + ) + # Initializing the weights + if self.checkpointer is not None: + self.checkpointer.add_recoverable("modelopt", self.model_optimizer) + self.checkpointer.add_recoverable( + "weights_opt", self.weights_optimizer + ) + + +# Define custom data procedure +def dataio_prepare(hparams, tokenizer): + """This function prepares the datasets to be used in the brain class. + It also defines the data processing pipeline through user-defined functions.""" + + # 1. Define datasets + data_folder = hparams["data_folder"] + + train_data = sb.dataio.dataset.DynamicItemDataset.from_csv( + csv_path=hparams["train_csv"], replacements={"data_root": data_folder}, + ) + + if hparams["sorting"] == "ascending": + # we sort training data to speed up training and get better results. + train_data = train_data.filtered_sorted( + sort_key="duration", + key_max_value={"duration": hparams["avoid_if_longer_than"]}, + ) + # when sorting do not shuffle in dataloader ! otherwise is pointless + hparams["dataloader_options"]["shuffle"] = False + + elif hparams["sorting"] == "descending": + train_data = train_data.filtered_sorted( + sort_key="duration", + reverse=True, + key_max_value={"duration": hparams["avoid_if_longer_than"]}, + ) + # when sorting do not shuffle in dataloader ! otherwise is pointless + hparams["dataloader_options"]["shuffle"] = False + + elif hparams["sorting"] == "random": + pass + + else: + raise NotImplementedError( + "sorting must be random, ascending or descending" + ) + + valid_data = sb.dataio.dataset.DynamicItemDataset.from_csv( + csv_path=hparams["valid_csv"], replacements={"data_root": data_folder}, + ) + # We also sort the validation data so it is faster to validate + valid_data = valid_data.filtered_sorted(sort_key="duration") + + test_data = sb.dataio.dataset.DynamicItemDataset.from_csv( + csv_path=hparams["test_csv"], replacements={"data_root": data_folder}, + ) + + # We also sort the validation data so it is faster to validate + test_data = test_data.filtered_sorted(sort_key="duration") + + datasets = [train_data, valid_data, test_data] + + # 2. 
Define audio pipeline: + @sb.utils.data_pipeline.takes("wav") + @sb.utils.data_pipeline.provides("sig") + def audio_pipeline(wav): + info = torchaudio.info(wav) + sig = sb.dataio.dataio.read_audio(wav) + resampled = torchaudio.transforms.Resample( + info.sample_rate, hparams["sample_rate"], + )(sig) + return resampled + + sb.dataio.dataset.add_dynamic_item(datasets, audio_pipeline) + + # 3. Define text pipeline: + @sb.utils.data_pipeline.takes("wrd") + @sb.utils.data_pipeline.provides("tokens_list", "tokens") + def text_pipeline(wrd): + tokens_list = tokenizer.sp.encode_as_ids(wrd) + yield tokens_list + tokens = torch.LongTensor(tokens_list) + yield tokens + + sb.dataio.dataset.add_dynamic_item(datasets, text_pipeline) + + # 4. Set output: + sb.dataio.dataset.set_output_keys( + datasets, ["id", "sig", "tokens"], + ) + return train_data, valid_data, test_data + + +if __name__ == "__main__": + + # CLI: + hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:]) + + # If distributed_launch=True then + # create ddp_group with the right communication protocol + sb.utils.distributed.ddp_init_group(run_opts) + + with open(hparams_file) as fin: + hparams = load_hyperpyyaml(fin, overrides) + + # Create experiment directory + sb.create_experiment_directory( + experiment_directory=hparams["output_folder"], + hyperparams_to_save=hparams_file, + overrides=overrides, + ) + + # Dataset preparation + from common_voice_prepare import prepare_common_voice # noqa + + # multi-gpu (ddp) save data preparation + # Due to DDP, we do the preparation ONLY on the main python process + run_on_main( + prepare_common_voice, + kwargs={ + "data_folder": hparams["data_folder"], + "save_folder": hparams["save_folder"], + "train_tsv_file": hparams["train_tsv_file"], + "dev_tsv_file": hparams["dev_tsv_file"], + "test_tsv_file": hparams["test_tsv_file"], + "accented_letters": hparams["accented_letters"], + "language": hparams["language"], + "skip_prep": hparams["skip_prep"], + }, + ) + + # Defining tokenizer and loading it + tokenizer = SentencePiece( + model_dir=hparams["save_folder"], + vocab_size=hparams["output_neurons"], + annotation_train=hparams["train_csv"], + annotation_read="wrd", + model_type=hparams["token_type"], + character_coverage=hparams["character_coverage"], + ) + + # Create the datasets objects as well as tokenization and encoding :-D + train_data, valid_data, test_data = dataio_prepare(hparams, tokenizer) + + # Trainer initialization + asr_brain = ASR( + modules=hparams["modules"], + hparams=hparams, + run_opts=run_opts, + checkpointer=hparams["checkpointer"], + ) + + asr_brain.tokenizer = tokenizer + # Training + asr_brain.fit( + asr_brain.hparams.epoch_counter, + train_data, + valid_data, + train_loader_kwargs=hparams["train_dataloader_opts"], + valid_loader_kwargs=hparams["valid_dataloader_opts"], + ) + + # Testing + asr_brain.hparams.wer_file = hparams["output_folder"] + "/wer_test.txt" + asr_brain.evaluate( + test_data, + min_key="WER", + test_loader_kwargs=hparams["test_dataloader_options"], + ) diff --git a/benchmarks/MP3S/LibriSpeech/LSTM/hparams/ssl.yaml b/benchmarks/MP3S/LibriSpeech/LSTM/hparams/ssl.yaml new file mode 100644 index 0000000..f3dbbae --- /dev/null +++ b/benchmarks/MP3S/LibriSpeech/LSTM/hparams/ssl.yaml @@ -0,0 +1,188 @@ +# ################################ +# SummaryMixing © 2023 by Samsung Electronics is licensed under CC BY-NC 4.0. 
+ +# This is configurations of SummaryMixing wav2vec 2.0 downstream ASR on LibriSpeech with a LSTM downstream model + +# Usage: Install SpeechBrain MP3S +# Create a folder benchmarks/MP3S/LibriSpeech/LSTM +# Copy this file under benchmarks/MP3S/LibriSpeech/LSTM/hparams +# SummaryMixing: https://arxiv.org/abs/2307.07421 +# SummaryMixing SSL: https://arxiv.org/pdf/2407.13377 + +# Authors +# * Titouan Parcollet 2023, 2024 +# * Shucong Zhang 2023, 2024 +# * Rogier van Dalen 2023, 2024 +# * Sourav Bhattacharya 2023, 2024 +# ################################ + +# Seed needs to be set at top of yaml, before objects with parameters are made +seed: 1986 +__set_seed: !apply:torch.manual_seed [!ref ] +output_folder: !ref results/LibriSpeech/ +wer_file: !ref /wer.txt +save_folder: !ref /save +train_log: !ref /train_log.txt + + +# Data files +data_folder: !PLACEHOLDER # e,g./path/to/LibriSpeech +# noise/ris dataset will automatically be downloaded +# data_folder_rirs: !ref +train_splits: ["train-clean-100"] +dev_splits: ["dev-clean"] +test_splits: ["test-clean", "test-other"] + +skip_prep: False +ckpt_interval_minutes: 25 # save checkpoint every N min +train_csv: !ref /train-clean-100.csv +valid_csv: !ref /dev-clean.csv +test_csv: + - !ref /test-clean.csv + - !ref /test-other.csv + +num_layers_ssl: 13 #Number of layers in the SSL model (should be 25 for large ) +pretrained_path: !PLACEHOLDER # e,g./path/to/pre-trained_SummaryMixing_w2v2 +encoder_dim: 768 + +# Training parameters +number_of_epochs: 20 +lr: 0.0002 +lr_weights: 0.01 +sorting: ascending +auto_mix_prec: False +sample_rate: 16000 +language_modelling: False +#ngram_lm_path: !PLACEHOLDER #path/to/4-gram.arpa + +# With data_parallel batch_size is split into N jobs +# With DDP batch_size is multiplied by N jobs +# Must be 3 per GPU to fit 32GB of VRAM +batch_size: 4 +test_batch_size: 4 + +# Dataloader options +train_dataloader_opts: + batch_size: !ref + +valid_dataloader_opts: + batch_size: !ref + +test_dataloader_opts: + batch_size: !ref + +# Model parameters +activation: !name:torch.nn.Sigmoid +dnn_layers: 1 +dnn_neurons: 768 +freeze_encoder: True + +# Outputs +output_neurons: 30 # BPE size, index(blank/eos/bos) = 0 + +# Functions and classes +# +epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter + limit: !ref + +encoder: !new:speechbrain.lobes.models.transformer.Conformer.ConformerEncoder + d_model: 768 + num_layers: 12 + nhead: 8 + d_ffn: 3072 + dropout: 0.1 + layerdrop_prob: 0.0 + attention_type: SummaryMixing + local_proj_hid_dim: [768] + local_proj_out_dim: 768 + summary_hid_dim: [768] + mode: SummaryMixing + output_hidden_states: True + +latentextractor_kernels: [3, 3] +latentextractor_strides: [2, 1] +extractor_dim: 512 +embedding_dim: 768 + +CNN: !new:speechbrain.lobes.models.wav2vec.W2VLatentExtractor + kernel_sizes: !ref + strides: !ref + out_channels: [512, 512] + input_dim: 80 + +weighted_ssl_model: !new:speechbrain.lobes.models.wav2vec.WeightedSSLModel + pretrained_path: !ref + num_layers: 13 + latent_encoder: !ref + CNN: !ref + in_dim: 512 + embedding_dim: 768 + dropout_encoder_input: 0.1 + output_hidden_states: True + +enc: !new:speechbrain.nnet.RNN.LSTM + input_shape: [Null, Null, !ref ] + num_layers: 2 + bidirectional: True + dropout: 0.2 + hidden_size: 1024 + +ctc_lin: !new:speechbrain.nnet.linear.Linear + input_size: 2048 + n_neurons: !ref + +log_softmax: !new:speechbrain.nnet.activations.Softmax + apply_log: True + +ctc_cost: !name:speechbrain.nnet.losses.ctc_loss + blank_index: !ref + +modules: + enc: !ref + 
ctc_lin: !ref + weighted_ssl_model: !ref + +model: !new:torch.nn.ModuleList + - [!ref , !ref ] + +model_opt_class: !name:torch.optim.Adam + lr: !ref + +weights_opt_class: !name:torch.optim.Adam + lr: !ref + +lr_annealing_model: !new:speechbrain.nnet.schedulers.NewBobScheduler + initial_value: !ref + improvement_threshold: 0.0025 + annealing_factor: 0.8 + patient: 0 + +lr_annealing_weights: !new:speechbrain.nnet.schedulers.NewBobScheduler + initial_value: !ref + improvement_threshold: 0.0025 + annealing_factor: 0.9 + patient: 0 + +label_encoder: !new:speechbrain.dataio.encoder.CTCTextEncoder + +checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer + checkpoints_dir: !ref + recoverables: + model: !ref + ssl_model: !ref + scheduler_model: !ref + scheduler_encoder: !ref + counter: !ref + tokenizer: !ref + +blank_index: 0 +unk_index: 1 + + +train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger + save_file: !ref + +error_rate_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats + +cer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats + split_tokens: True diff --git a/benchmarks/MP3S/LibriSpeech/LSTM/train.py b/benchmarks/MP3S/LibriSpeech/LSTM/train.py new file mode 100644 index 0000000..d763448 --- /dev/null +++ b/benchmarks/MP3S/LibriSpeech/LSTM/train.py @@ -0,0 +1,344 @@ +""" SummaryMixing © 2023 by Samsung Electronics is licensed under CC BY-NC 4.0. + +Recipe for training an SSL-based ctc ASR system with librispeech. + Decoding is performed with ctc greedy or LM-rescored decoder. +""" + +import os +import sys +import torch +import logging +import speechbrain as sb +from speechbrain.utils.distributed import run_on_main +from hyperpyyaml import load_hyperpyyaml +from pathlib import Path + +logger = logging.getLogger(__name__) + + +# Define training procedure +class ASR(sb.Brain): + def compute_forward(self, batch, stage): + """Forward computations from the waveform batches to the output probabilities.""" + batch = batch.to(self.device) + wavs, wav_lens = batch.sig + wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device) + # Forward pass + feats = self.modules.weighted_ssl_model(wavs, wav_lens) + y = self.modules.enc(feats) + y = y[0] # As it is an RNN output + # Compute outputs + p_tokens = None + logits = self.modules.ctc_lin(y) + p_ctc = self.hparams.log_softmax(logits) + if stage != sb.Stage.TRAIN: + p_tokens = sb.decoders.ctc_greedy_decode( + p_ctc, wav_lens, blank_id=self.hparams.blank_index + ) + return p_ctc, wav_lens, p_tokens + + def compute_objectives(self, predictions, batch, stage): + """Computes the loss (CTC+NLL) given predictions and targets.""" + + p_ctc, wav_lens, predicted_tokens = predictions + ids = batch.id + tokens, tokens_lens = batch.tokens + loss_ctc = self.hparams.ctc_cost(p_ctc, tokens, wav_lens, tokens_lens) + loss = loss_ctc + + if stage == sb.Stage.VALID: + # Decode token terms to words + predicted_words = [ + "".join(self.tokenizer.decode_ndim(utt_seq)).split(" ") + for utt_seq in predicted_tokens + ] + target_words = [wrd.split(" ") for wrd in batch.wrd] + self.wer_metric.append(ids, predicted_words, target_words) + self.cer_metric.append(ids, predicted_words, target_words) + + elif stage == sb.Stage.TEST: + if self.hparams.language_modelling: + predicted_words = [] + for logs in p_ctc: + text = decoder.decode(logs.detach().cpu().numpy()) + predicted_words.append(text.split(" ")) + else: + predicted_words = [ + "".join(self.tokenizer.decode_ndim(utt_seq)).split(" ") + for utt_seq in predicted_tokens + ] + + 
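# Score the hypotheses against the reference transcriptions for WER and CER.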
target_words = [wrd.split(" ") for wrd in batch.wrd] + self.wer_metric.append(ids, predicted_words, target_words) + self.cer_metric.append(ids, predicted_words, target_words) + + return loss + + def fit_batch(self, batch): + """Train the parameters given a single batch in input""" + predictions = self.compute_forward(batch, sb.Stage.TRAIN) + loss = self.compute_objectives(predictions, batch, sb.Stage.TRAIN) + loss.backward() + self.model_optimizer.step() + self.weights_optimizer.step() + + self.model_optimizer.zero_grad() + self.weights_optimizer.zero_grad() + return loss.detach() + + def evaluate_batch(self, batch, stage): + """Computations needed for validation/test batches""" + predictions = self.compute_forward(batch, stage=stage) + with torch.no_grad(): + loss = self.compute_objectives(predictions, batch, stage=stage) + return loss.detach() + + def on_stage_start(self, stage, epoch): + """Gets called at the beginning of each epoch""" + if stage != sb.Stage.TRAIN: + self.cer_metric = self.hparams.cer_computer() + self.wer_metric = self.hparams.error_rate_computer() + + def on_stage_end(self, stage, stage_loss, epoch): + """Gets called at the end of an epoch.""" + # Compute/store important stats + stage_stats = {"loss": stage_loss} + if stage == sb.Stage.TRAIN: + self.train_stats = stage_stats + else: + stage_stats["CER"] = self.cer_metric.summarize("error_rate") + stage_stats["WER"] = self.wer_metric.summarize("error_rate") + + # Perform end-of-iteration things, like annealing, logging, etc. + if stage == sb.Stage.VALID: + old_lr_model, new_lr_model = self.hparams.lr_annealing_model( + stage_stats["loss"] + ) + old_lr_weights, new_lr_weights = self.hparams.lr_annealing_weights( + stage_stats["loss"] + ) + sb.nnet.schedulers.update_learning_rate( + self.model_optimizer, new_lr_model + ) + sb.nnet.schedulers.update_learning_rate( + self.weights_optimizer, new_lr_weights + ) + + self.hparams.train_logger.log_stats( + stats_meta={"epoch": epoch, "lr_model": old_lr_model}, + train_stats=self.train_stats, + valid_stats=stage_stats, + ) + self.checkpointer.save_and_keep_only( + meta={"WER": stage_stats["WER"]}, min_keys=["WER"], + ) + elif stage == sb.Stage.TEST: + self.hparams.train_logger.log_stats( + stats_meta={"Epoch loaded": self.hparams.epoch_counter.current}, + test_stats=stage_stats, + ) + with open(self.hparams.wer_file, "w") as w: + self.wer_metric.write_stats(w) + + def init_optimizers(self): + "Initializes the weights optimizer and model optimizer" + self.weights_optimizer = self.hparams.weights_opt_class( + [self.modules.weighted_ssl_model.weights] + ) + self.model_optimizer = self.hparams.model_opt_class( + self.hparams.model.parameters() + ) + # Initializing the weights + if self.checkpointer is not None: + self.checkpointer.add_recoverable("modelopt", self.model_optimizer) + self.checkpointer.add_recoverable( + "weights_opt", self.weights_optimizer + ) + + +def dataio_prepare(hparams): + """This function prepares the datasets to be used in the brain class. + It also defines the data processing pipeline through user-defined functions.""" + data_folder = hparams["data_folder"] + + train_data = sb.dataio.dataset.DynamicItemDataset.from_csv( + csv_path=hparams["train_csv"], replacements={"data_root": data_folder}, + ) + + if hparams["sorting"] == "ascending": + # we sort training data to speed up training and get better results. + train_data = train_data.filtered_sorted(sort_key="duration") + # when sorting do not shuffle in dataloader ! 
otherwise is pointless + hparams["train_dataloader_opts"]["shuffle"] = False + + elif hparams["sorting"] == "descending": + train_data = train_data.filtered_sorted( + sort_key="duration", reverse=True + ) + # when sorting do not shuffle in dataloader ! otherwise is pointless + hparams["train_dataloader_opts"]["shuffle"] = False + + elif hparams["sorting"] == "random": + pass + + else: + raise NotImplementedError( + "sorting must be random, ascending or descending" + ) + + valid_data = sb.dataio.dataset.DynamicItemDataset.from_csv( + csv_path=hparams["valid_csv"], replacements={"data_root": data_folder}, + ) + valid_data = valid_data.filtered_sorted(sort_key="duration") + + # test is separate + test_datasets = {} + for csv_file in hparams["test_csv"]: + name = Path(csv_file).stem + test_datasets[name] = sb.dataio.dataset.DynamicItemDataset.from_csv( + csv_path=csv_file, replacements={"data_root": data_folder} + ) + test_datasets[name] = test_datasets[name].filtered_sorted( + sort_key="duration" + ) + + datasets = [train_data, valid_data] + [i for k, i in test_datasets.items()] + + # 2. Define audio pipeline: + @sb.utils.data_pipeline.takes("wav") + @sb.utils.data_pipeline.provides("sig") + def audio_pipeline(wav): + sig = sb.dataio.dataio.read_audio(wav) + return sig + + sb.dataio.dataset.add_dynamic_item(datasets, audio_pipeline) + label_encoder = sb.dataio.encoder.CTCTextEncoder() + + # 3. Define text pipeline: + @sb.utils.data_pipeline.takes("wrd") + @sb.utils.data_pipeline.provides( + "wrd", "char_list", "tokens_list", "tokens" + ) + def text_pipeline(wrd): + yield wrd + char_list = list(wrd) + yield char_list + tokens_list = label_encoder.encode_sequence(char_list) + yield tokens_list + tokens = torch.LongTensor(tokens_list) + yield tokens + + sb.dataio.dataset.add_dynamic_item(datasets, text_pipeline) + + lab_enc_file = os.path.join(hparams["save_folder"], "label_encoder.txt") + special_labels = { + "blank_label": hparams["blank_index"], + "unk_label": hparams["unk_index"], + } + label_encoder.load_or_create( + path=lab_enc_file, + from_didatasets=[train_data], + output_key="char_list", + special_labels=special_labels, + sequence_input=True, + ) + + # 4. 
Set output: + sb.dataio.dataset.set_output_keys( + datasets, ["id", "sig", "wrd", "char_list", "tokens"], + ) + return train_data, valid_data, test_datasets, label_encoder + + +if __name__ == "__main__": + + # CLI: + hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:]) + + # If distributed_launch=True then + # create ddp_group with the right communication protocol + sb.utils.distributed.ddp_init_group(run_opts) + + with open(hparams_file) as fin: + hparams = load_hyperpyyaml(fin, overrides) + + # Create experiment directory + sb.create_experiment_directory( + experiment_directory=hparams["output_folder"], + hyperparams_to_save=hparams_file, + overrides=overrides, + ) + + # Dataset prep (parsing Librispeech) + from librispeech_prepare import prepare_librispeech # noqa + + # multi-gpu (ddp) save data preparation + run_on_main( + prepare_librispeech, + kwargs={ + "data_folder": hparams["data_folder"], + "tr_splits": hparams["train_splits"], + "dev_splits": hparams["dev_splits"], + "te_splits": hparams["test_splits"], + "save_folder": hparams["output_folder"], + "merge_lst": hparams["train_splits"], + "merge_name": "train.csv", + "skip_prep": hparams["skip_prep"], + }, + ) + + # here we create the datasets objects as well as tokenization and encoding + train_data, valid_data, test_datasets, label_encoder = dataio_prepare( + hparams + ) + # Loading the labels for the LM decoding and the CTC decoder + if "language_modelling" in hparams: + + if hparams["language_modelling"]: + from pyctcdecode import build_ctcdecoder + + ind2lab = label_encoder.ind2lab + labels = [ind2lab[x] for x in range(len(ind2lab))] + labels = [""] + labels[ + 1: + ] # Replace the token with a blank character, needed for PyCTCdecode + decoder = build_ctcdecoder( + labels, + kenlm_model_path=hparams[ + "ngram_lm_path" + ], # either .arpa or .bin file + alpha=0.5, # tuned on a val set + beta=1.0, # tuned on a val set + ) + else: + hparams["language_modelling"] = False + + # Trainer initialization + asr_brain = ASR( + modules=hparams["modules"], + hparams=hparams, + run_opts=run_opts, + checkpointer=hparams["checkpointer"], + ) + + # Loading the SSL model + # We dynamicaly add the tokenizer to our brain class. + # NB: This tokenizer corresponds to the one used for the LM!! + asr_brain.tokenizer = label_encoder + # Training + asr_brain.fit( + asr_brain.hparams.epoch_counter, + train_data, + valid_data, + train_loader_kwargs=hparams["train_dataloader_opts"], + valid_loader_kwargs=hparams["valid_dataloader_opts"], + ) + + # Testing + for k in test_datasets.keys(): # keys are test_clean, test_other etc + asr_brain.hparams.wer_file = os.path.join( + hparams["output_folder"], "wer_{}.txt".format(k) + ) + asr_brain.evaluate( + test_datasets[k], test_loader_kwargs=hparams["test_dataloader_opts"] + ) diff --git a/benchmarks/MP3S/LibriSpeech/contextnet/hparam/ssl.yaml b/benchmarks/MP3S/LibriSpeech/contextnet/hparam/ssl.yaml new file mode 100644 index 0000000..6e29b74 --- /dev/null +++ b/benchmarks/MP3S/LibriSpeech/contextnet/hparam/ssl.yaml @@ -0,0 +1,184 @@ +# ################################ +# SummaryMixing © 2023 by Samsung Electronics is licensed under CC BY-NC 4.0. 
+ +# This is configurations of SummaryMixing wav2vec 2.0 downstream ASR on LibriSpeech with a ContextNet downstream model + +# Usage: Install SpeechBrain MP3S +# Create a folder benchmarks/MP3S/LibriSpeech/contextnet +# Copy this file under benchmarks/MP3S/LibriSpeech/contextnet/hparams +# SummaryMixing: https://arxiv.org/abs/2307.07421 +# SummaryMixing SSL: https://arxiv.org/pdf/2407.13377 + +# Authors +# * Titouan Parcollet 2023, 2024 +# * Shucong Zhang 2023, 2024 +# * Rogier van Dalen 2023, 2024 +# * Sourav Bhattacharya 2023, 2024 +# ################################ + +# Seed needs to be set at top of yaml, before objects with parameters are made +seed: 1986 +__set_seed: !apply:torch.manual_seed [!ref ] +output_folder: !ref results/LibriSpeech/ +wer_file: !ref /wer.txt +save_folder: !ref /save +train_log: !ref /train_log.txt + +# Data files +data_folder: !PLACEHOLDER # e,g./path/to/LibriSpeech +# noise/ris dataset will automatically be downloaded +# data_folder_rirs: !ref +train_splits: ["train-clean-100"] +dev_splits: ["dev-clean"] +test_splits: ["test-clean", "test-other"] +skip_prep: False +ckpt_interval_minutes: 25 # save checkpoint every N min +train_csv: !ref /train-clean-100.csv +valid_csv: !ref /dev-clean.csv +test_csv: + - !ref /test-clean.csv + - !ref /test-other.csv + +num_layers_ssl: 13 +pretrained_path: !PLACEHOLDER # e,g./path/to/pre-trained_SummaryMixing_w2v2 +encoder_dim: 768 + +# Training parameters +number_of_epochs: 20 +lr: 0.0002 +lr_weights: 0.01 +sorting: ascending +auto_mix_prec: False +sample_rate: 16000 +language_modelling: False +#ngram_lm_path: !PLACEHOLDER # path/to/4-gram.arpa + +# With data_parallel batch_size is split into N jobs +# With DDP batch_size is multiplied by N jobs +# Must be 3 per GPU to fit 32GB of VRAM +batch_size: 4 +test_batch_size: 4 + +# Dataloader options +train_dataloader_opts: + batch_size: !ref + +valid_dataloader_opts: + batch_size: !ref + +test_dataloader_opts: + batch_size: !ref + +# Model parameters +activation: !name:torch.nn.Sigmoid +dnn_layers: 1 +dnn_neurons: 768 +freeze_encoder: True + +# Outputs +output_neurons: 30 +# Functions and classes +# +epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter + limit: !ref + +encoder: !new:speechbrain.lobes.models.transformer.Conformer.ConformerEncoder + d_model: 768 + num_layers: 12 + nhead: 8 + d_ffn: 3072 + dropout: 0.1 + layerdrop_prob: 0.0 + attention_type: SummaryMixing + local_proj_hid_dim: [768] + local_proj_out_dim: 768 + summary_hid_dim: [768] + mode: SummaryMixing + output_hidden_states: True + +latentextractor_kernels: [3, 3] +latentextractor_strides: [2, 1] +extractor_dim: 512 +embedding_dim: 768 + +CNN: !new:speechbrain.lobes.models.wav2vec.W2VLatentExtractor + kernel_sizes: !ref + strides: !ref + out_channels: [512, 512] + input_dim: 80 + +weighted_ssl_model: !new:speechbrain.lobes.models.wav2vec.WeightedSSLModel + pretrained_path: !ref + num_layers: 13 + latent_encoder: !ref + CNN: !ref + in_dim: 512 + embedding_dim: 768 + dropout_encoder_input: 0.1 + output_hidden_states: True + +enc: !new:speechbrain.lobes.models.ContextNet.ContextNet + input_shape: [null, null, !ref ] + strides: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] + +# only unitary strides to keep the frame rate + + +ctc_lin: !new:speechbrain.nnet.linear.Linear + input_size: 640 + n_neurons: !ref + +log_softmax: !new:speechbrain.nnet.activations.Softmax + apply_log: True + +ctc_cost: !name:speechbrain.nnet.losses.ctc_loss + blank_index: !ref + +modules: + enc: !ref + ctc_lin: !ref + 
weighted_ssl_model: !ref + +model: !new:torch.nn.ModuleList + - [!ref , !ref ] + +model_opt_class: !name:torch.optim.Adam + lr: !ref + +weights_opt_class: !name:torch.optim.Adam + lr: !ref + +lr_annealing_model: !new:speechbrain.nnet.schedulers.NewBobScheduler + initial_value: !ref + improvement_threshold: 0.0025 + annealing_factor: 0.8 + patient: 0 + +lr_annealing_weights: !new:speechbrain.nnet.schedulers.NewBobScheduler + initial_value: !ref + improvement_threshold: 0.0025 + annealing_factor: 0.9 + patient: 0 + +label_encoder: !new:speechbrain.dataio.encoder.CTCTextEncoder +checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer + checkpoints_dir: !ref + recoverables: + model: !ref + ssl_model: !ref + scheduler_model: !ref + scheduler_encoder: !ref + counter: !ref + tokenizer: !ref + +blank_index: 0 +unk_index: 1 + + +train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger + save_file: !ref + +error_rate_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats + +cer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats + split_tokens: True diff --git a/benchmarks/MP3S/LibriSpeech/contextnet/train.py b/benchmarks/MP3S/LibriSpeech/contextnet/train.py new file mode 100644 index 0000000..9d0359f --- /dev/null +++ b/benchmarks/MP3S/LibriSpeech/contextnet/train.py @@ -0,0 +1,342 @@ +""" SummaryMixing © 2023 by Samsung Electronics is licensed under CC BY-NC 4.0. + +Recipe for training an SSL-based ctc ASR system with librispeech. + Decoding is performed with ctc greedy or LM-rescored decoder. +""" + +import os +import sys +import torch +import logging +import speechbrain as sb +from speechbrain.utils.distributed import run_on_main +from hyperpyyaml import load_hyperpyyaml +from pathlib import Path + +logger = logging.getLogger(__name__) + + +# Define training procedure +class ASR(sb.Brain): + def compute_forward(self, batch, stage): + """Forward computations from the waveform batches to the output probabilities.""" + batch = batch.to(self.device) + wavs, wav_lens = batch.sig + wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device) + # Forward pass + feats = self.modules.weighted_ssl_model(wavs, wav_lens) + y = self.modules.enc(feats) + # Compute outputs + p_tokens = None + logits = self.modules.ctc_lin(y) + p_ctc = self.hparams.log_softmax(logits) + if stage != sb.Stage.TRAIN: + p_tokens = sb.decoders.ctc_greedy_decode( + p_ctc, wav_lens, blank_id=self.hparams.blank_index + ) + return p_ctc, wav_lens, p_tokens + + def compute_objectives(self, predictions, batch, stage): + """Computes the loss (CTC+NLL) given predictions and targets.""" + + p_ctc, wav_lens, predicted_tokens = predictions + ids = batch.id + tokens, tokens_lens = batch.tokens + loss_ctc = self.hparams.ctc_cost(p_ctc, tokens, wav_lens, tokens_lens) + loss = loss_ctc + + if stage == sb.Stage.VALID: + # Decode token terms to words + predicted_words = [ + "".join(self.tokenizer.decode_ndim(utt_seq)).split(" ") + for utt_seq in predicted_tokens + ] + target_words = [wrd.split(" ") for wrd in batch.wrd] + self.wer_metric.append(ids, predicted_words, target_words) + self.cer_metric.append(ids, predicted_words, target_words) + + elif stage == sb.Stage.TEST: + if self.hparams.language_modelling: + predicted_words = [] + for logs in p_ctc: + text = decoder.decode(logs.detach().cpu().numpy()) + predicted_words.append(text.split(" ")) + else: + predicted_words = [ + "".join(self.tokenizer.decode_ndim(utt_seq)).split(" ") + for utt_seq in predicted_tokens + ] + + target_words = [wrd.split(" ") for wrd in 
batch.wrd] + self.wer_metric.append(ids, predicted_words, target_words) + self.cer_metric.append(ids, predicted_words, target_words) + + return loss + + def fit_batch(self, batch): + """Train the parameters given a single batch in input""" + predictions = self.compute_forward(batch, sb.Stage.TRAIN) + loss = self.compute_objectives(predictions, batch, sb.Stage.TRAIN) + loss.backward() + self.model_optimizer.step() + self.weights_optimizer.step() + + self.model_optimizer.zero_grad() + self.weights_optimizer.zero_grad() + return loss.detach() + + def evaluate_batch(self, batch, stage): + """Computations needed for validation/test batches""" + predictions = self.compute_forward(batch, stage=stage) + with torch.no_grad(): + loss = self.compute_objectives(predictions, batch, stage=stage) + return loss.detach() + + def on_stage_start(self, stage, epoch): + """Gets called at the beginning of each epoch""" + if stage != sb.Stage.TRAIN: + self.cer_metric = self.hparams.cer_computer() + self.wer_metric = self.hparams.error_rate_computer() + + def on_stage_end(self, stage, stage_loss, epoch): + """Gets called at the end of an epoch.""" + # Compute/store important stats + stage_stats = {"loss": stage_loss} + if stage == sb.Stage.TRAIN: + self.train_stats = stage_stats + else: + stage_stats["CER"] = self.cer_metric.summarize("error_rate") + stage_stats["WER"] = self.wer_metric.summarize("error_rate") + + # Perform end-of-iteration things, like annealing, logging, etc. + if stage == sb.Stage.VALID: + old_lr_model, new_lr_model = self.hparams.lr_annealing_model( + stage_stats["loss"] + ) + old_lr_weights, new_lr_weights = self.hparams.lr_annealing_weights( + stage_stats["loss"] + ) + sb.nnet.schedulers.update_learning_rate( + self.model_optimizer, new_lr_model + ) + sb.nnet.schedulers.update_learning_rate( + self.weights_optimizer, new_lr_weights + ) + + self.hparams.train_logger.log_stats( + stats_meta={"epoch": epoch, "lr_model": old_lr_model}, + train_stats=self.train_stats, + valid_stats=stage_stats, + ) + self.checkpointer.save_and_keep_only( + meta={"WER": stage_stats["WER"]}, min_keys=["WER"], + ) + elif stage == sb.Stage.TEST: + self.hparams.train_logger.log_stats( + stats_meta={"Epoch loaded": self.hparams.epoch_counter.current}, + test_stats=stage_stats, + ) + with open(self.hparams.wer_file, "w") as w: + self.wer_metric.write_stats(w) + + def init_optimizers(self): + "Initializes the weights optimizer and model optimizer" + self.weights_optimizer = self.hparams.weights_opt_class( + [self.modules.weighted_ssl_model.weights] + ) + self.model_optimizer = self.hparams.model_opt_class( + self.hparams.model.parameters() + ) + # Initializing the weights + if self.checkpointer is not None: + self.checkpointer.add_recoverable("modelopt", self.model_optimizer) + self.checkpointer.add_recoverable( + "weights_opt", self.weights_optimizer + ) + + +def dataio_prepare(hparams): + """This function prepares the datasets to be used in the brain class. + It also defines the data processing pipeline through user-defined functions.""" + data_folder = hparams["data_folder"] + + train_data = sb.dataio.dataset.DynamicItemDataset.from_csv( + csv_path=hparams["train_csv"], replacements={"data_root": data_folder}, + ) + + if hparams["sorting"] == "ascending": + # we sort training data to speed up training and get better results. + train_data = train_data.filtered_sorted(sort_key="duration") + # when sorting do not shuffle in dataloader ! 
otherwise is pointless + hparams["train_dataloader_opts"]["shuffle"] = False + + elif hparams["sorting"] == "descending": + train_data = train_data.filtered_sorted( + sort_key="duration", reverse=True + ) + # when sorting do not shuffle in dataloader ! otherwise is pointless + hparams["train_dataloader_opts"]["shuffle"] = False + + elif hparams["sorting"] == "random": + pass + + else: + raise NotImplementedError( + "sorting must be random, ascending or descending" + ) + + valid_data = sb.dataio.dataset.DynamicItemDataset.from_csv( + csv_path=hparams["valid_csv"], replacements={"data_root": data_folder}, + ) + valid_data = valid_data.filtered_sorted(sort_key="duration") + + # test is separate + test_datasets = {} + for csv_file in hparams["test_csv"]: + name = Path(csv_file).stem + test_datasets[name] = sb.dataio.dataset.DynamicItemDataset.from_csv( + csv_path=csv_file, replacements={"data_root": data_folder} + ) + test_datasets[name] = test_datasets[name].filtered_sorted( + sort_key="duration" + ) + + datasets = [train_data, valid_data] + [i for k, i in test_datasets.items()] + + # 2. Define audio pipeline: + @sb.utils.data_pipeline.takes("wav") + @sb.utils.data_pipeline.provides("sig") + def audio_pipeline(wav): + sig = sb.dataio.dataio.read_audio(wav) + return sig + + sb.dataio.dataset.add_dynamic_item(datasets, audio_pipeline) + label_encoder = sb.dataio.encoder.CTCTextEncoder() + + # 3. Define text pipeline: + @sb.utils.data_pipeline.takes("wrd") + @sb.utils.data_pipeline.provides( + "wrd", "char_list", "tokens_list", "tokens" + ) + def text_pipeline(wrd): + yield wrd + char_list = list(wrd) + yield char_list + tokens_list = label_encoder.encode_sequence(char_list) + yield tokens_list + tokens = torch.LongTensor(tokens_list) + yield tokens + + sb.dataio.dataset.add_dynamic_item(datasets, text_pipeline) + + lab_enc_file = os.path.join(hparams["save_folder"], "label_encoder.txt") + special_labels = { + "blank_label": hparams["blank_index"], + "unk_label": hparams["unk_index"], + } + label_encoder.load_or_create( + path=lab_enc_file, + from_didatasets=[train_data], + output_key="char_list", + special_labels=special_labels, + sequence_input=True, + ) + + # 4. 
Set output: + sb.dataio.dataset.set_output_keys( + datasets, ["id", "sig", "wrd", "char_list", "tokens"], + ) + return train_data, valid_data, test_datasets, label_encoder + + +if __name__ == "__main__": + + # CLI: + hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:]) + + # If distributed_launch=True then + # create ddp_group with the right communication protocol + sb.utils.distributed.ddp_init_group(run_opts) + + with open(hparams_file) as fin: + hparams = load_hyperpyyaml(fin, overrides) + + # Create experiment directory + sb.create_experiment_directory( + experiment_directory=hparams["output_folder"], + hyperparams_to_save=hparams_file, + overrides=overrides, + ) + + # Dataset prep (parsing Librispeech) + from librispeech_prepare import prepare_librispeech # noqa + + # multi-gpu (ddp) save data preparation + run_on_main( + prepare_librispeech, + kwargs={ + "data_folder": hparams["data_folder"], + "tr_splits": hparams["train_splits"], + "dev_splits": hparams["dev_splits"], + "te_splits": hparams["test_splits"], + "save_folder": hparams["output_folder"], + "merge_lst": hparams["train_splits"], + "merge_name": "train.csv", + "skip_prep": hparams["skip_prep"], + }, + ) + + # here we create the datasets objects as well as tokenization and encoding + train_data, valid_data, test_datasets, label_encoder = dataio_prepare( + hparams + ) + # Loading the labels for the LM decoding and the CTC decoder + if "language_modelling" in hparams: + + if hparams["language_modelling"]: + from pyctcdecode import build_ctcdecoder + + ind2lab = label_encoder.ind2lab + labels = [ind2lab[x] for x in range(len(ind2lab))] + labels = [""] + labels[ + 1: + ] # Replace the token with a blank character, needed for PyCTCdecode + decoder = build_ctcdecoder( + labels, + kenlm_model_path=hparams[ + "ngram_lm_path" + ], # either .arpa or .bin file + alpha=0.5, # tuned on a val set + beta=1.0, # tuned on a val set + ) + else: + hparams["language_modelling"] = False + + # Trainer initialization + asr_brain = ASR( + modules=hparams["modules"], + hparams=hparams, + run_opts=run_opts, + checkpointer=hparams["checkpointer"], + ) + + # We dynamicaly add the tokenizer to our brain class. + # NB: This tokenizer corresponds to the one used for the LM!! + asr_brain.tokenizer = label_encoder + # Training + asr_brain.fit( + asr_brain.hparams.epoch_counter, + train_data, + valid_data, + train_loader_kwargs=hparams["train_dataloader_opts"], + valid_loader_kwargs=hparams["valid_dataloader_opts"], + ) + + # Testing + for k in test_datasets.keys(): # keys are test_clean, test_other etc + asr_brain.hparams.wer_file = os.path.join( + hparams["output_folder"], "wer_{}.txt".format(k) + ) + asr_brain.evaluate( + test_datasets[k], test_loader_kwargs=hparams["test_dataloader_opts"] + ) diff --git a/benchmarks/MP3S/SLURP/LSTM_linear/hparams/ssl.yaml b/benchmarks/MP3S/SLURP/LSTM_linear/hparams/ssl.yaml new file mode 100644 index 0000000..63fbc68 --- /dev/null +++ b/benchmarks/MP3S/SLURP/LSTM_linear/hparams/ssl.yaml @@ -0,0 +1,201 @@ +# ################################ +# SummaryMixing © 2023 by Samsung Electronics is licensed under CC BY-NC 4.0. 
+ +# This is configurations of SummaryMixing wav2vec 2.0 downstream Intent classification on SLURP + +# Usage: Install SpeechBrain MP3S +# Create a folder benchmarks/MP3S/SLURP/LSTM_linear +# Copy this file under benchmarks/MP3S/SLURP/LSTM_linear/hparams +# SummaryMixing: https://arxiv.org/abs/2307.07421 +# SummaryMixing SSL: https://arxiv.org/pdf/2407.13377 + +# Authors +# * Titouan Parcollet 2023, 2024 +# * Shucong Zhang 2023, 2024 +# * Rogier van Dalen 2023, 2024 +# * Sourav Bhattacharya 2023, 2024 +# ################################ + +# Seed needs to be set at top of yaml, before objects with parameters are made +seed: 1986 +__set_seed: !apply:torch.manual_seed [!ref ] +output_folder: !ref results/SLURP/ +save_folder: !ref /save +train_log: !ref /train_log.txt + +# Data files +# The SLURP dataset will be automatically downloaded in the specified folder +data_folder: !PLACEHOLDER +# data_folder_rirs: !ref +train_splits: ["train_real"] +csv_train: !ref /train-type=direct.csv +csv_valid: !ref /devel-type=direct.csv +csv_test: !ref /test-type=direct.csv +skip_prep: False + +compute_cost: !name:speechbrain.nnet.losses.nll_loss + + +num_layers_ssl: 13 #Number of layers in the SSL model (should be 25 for large ) +pretrained_path: !PLACEHOLDER # e,g./path/to/pre-trained_SummaryMixing_w2v2 +encoder_dim: 768 + +# Training parameters +number_of_epochs: 20 +batch_size: 2 +test_batch_size: 2 +lr: 0.0002 +lr_weights: 0.01 +# token_type: unigram # ["unigram", "bpe", "char"] +sorting: random +ckpt_interval_minutes: 5 # save checkpoint every N min + +# Model parameters +sample_rate: 16000 +emb_size: 128 +dec_neurons: 512 +output_neurons: 18 # index(eos/bos) = 0 + +# Dataloader options +train_dataloader_opts: + batch_size: !ref + shuffle: True + num_workers: 2 # 2 on linux but 0 works on windows + drop_last: False + +valid_dataloader_opts: + batch_size: !ref + +test_dataloader_opts: + batch_size: !ref + +enc: !new:speechbrain.nnet.containers.Sequential + input_shape: [null, null, !ref ] + lstm: !new:speechbrain.nnet.RNN.LSTM + input_size: !ref + bidirectional: True + hidden_size: !ref + num_layers: 2 + linear: !new:speechbrain.nnet.linear.Linear + input_size: !ref * 2 + n_neurons: !ref + +# Decoding parameters +bos_index: 0 +eos_index: 0 +min_decode_ratio: 0.0 +max_decode_ratio: 10.0 +slu_beam_size: 80 +eos_threshold: 1.5 +temperature: 1.25 + +dataloader_opts: + batch_size: !ref + shuffle: True + +epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter + limit: !ref + + +# Models +encoder: !new:speechbrain.lobes.models.transformer.Conformer.ConformerEncoder + d_model: 768 + num_layers: 12 + nhead: 8 + d_ffn: 3072 + dropout: 0.1 + layerdrop_prob: 0.0 + attention_type: SummaryMixing + local_proj_hid_dim: [768] + local_proj_out_dim: 768 + summary_hid_dim: [768] + mode: SummaryMixing + output_hidden_states: True + +latentextractor_kernels: [3, 3] +latentextractor_strides: [2, 1] +extractor_dim: 512 +embedding_dim: 768 + +CNN: !new:speechbrain.lobes.models.wav2vec.W2VLatentExtractor + kernel_sizes: !ref + strides: !ref + out_channels: [512, 512] + input_dim: 80 + +weighted_ssl_model: !new:speechbrain.lobes.models.wav2vec.WeightedSSLModel + pretrained_path: !ref + num_layers: 13 + latent_encoder: !ref + CNN: !ref + in_dim: 512 + embedding_dim: 768 + dropout_encoder_input: 0.1 + output_hidden_states: True + +avg_pool: !new:speechbrain.nnet.pooling.StatisticsPooling + return_std: False + +output_mlp: !new:speechbrain.nnet.linear.Linear + input_size: !ref + n_neurons: 18 + bias: False + +augmentation: 
!new:speechbrain.lobes.augment.TimeDomainSpecAugment + sample_rate: !ref + speeds: [95, 100, 105] + +modules: + enc: !ref + avg_pool: !ref + output_mlp: !ref + weighted_ssl_model: !ref + +model: !new:torch.nn.ModuleList + - [!ref , !ref ] + +tokenizer: !new:sentencepiece.SentencePieceProcessor + +error_stats: !name:speechbrain.utils.metric_stats.MetricStats + metric: !name:speechbrain.nnet.losses.classification_error + reduction: batch + +model_opt_class: !name:torch.optim.Adam + lr: !ref + +weights_opt_class: !name:torch.optim.Adam + lr: !ref + +lr_annealing_model: !new:speechbrain.nnet.schedulers.NewBobScheduler + initial_value: !ref + improvement_threshold: 0.0025 + annealing_factor: 0.8 + patient: 0 + +lr_annealing_weights: !new:speechbrain.nnet.schedulers.NewBobScheduler + initial_value: !ref + improvement_threshold: 0.0025 + annealing_factor: 0.9 + +checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer + checkpoints_dir: !ref + recoverables: + model: !ref + ssl_model: !ref + scheduler_model: !ref + scheduler_encoder: !ref + counter: !ref + +log_softmax: !new:speechbrain.nnet.activations.Softmax + apply_log: True + +seq_cost: !name:speechbrain.nnet.losses.nll_loss + label_smoothing: 0.1 + +train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger + save_file: !ref + +error_rate_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats + +cer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats + split_tokens: True diff --git a/benchmarks/MP3S/SLURP/LSTM_linear/train.py b/benchmarks/MP3S/SLURP/LSTM_linear/train.py new file mode 100644 index 0000000..307f648 --- /dev/null +++ b/benchmarks/MP3S/SLURP/LSTM_linear/train.py @@ -0,0 +1,318 @@ +""" SummaryMixing © 2023 by Samsung Electronics is licensed under CC BY-NC 4.0. + +Recipe for "direct" (speech -> scenario) "Intent" classification using SLURP Dataset. +18 Scenarios classes are present in SLURP (calendar, email) +We encode input waveforms into features using a SSL encoder. +The probing is done using a RNN layer followed by a linear classifier. 
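+
+A typical invocation is sketched below; the placeholder paths are assumptions
+and must point to your own SLURP data and pre-trained SummaryMixing wav2vec 2.0:
+    python train.py hparams/ssl.yaml --data_folder=/path/to/SLURP --pretrained_path=/path/to/SummaryMixing_w2v2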
+""" + + +import os +import sys +from hyperpyyaml import load_hyperpyyaml +import speechbrain as sb +from speechbrain.utils.distributed import run_on_main + + +class IntentIdBrain(sb.Brain): + def compute_forward(self, batch, stage): + """Computation pipeline based on a encoder + emotion classifier.""" + + batch = batch.to(self.device) + wavs, wav_lens = batch.sig + wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device) + + feats = self.modules.weighted_ssl_model(wavs, wav_lens) + + # last dim will be used for AdaptativeAVG pool + outputs = self.modules.enc(feats) + outputs = self.hparams.avg_pool(outputs, wav_lens) + outputs = outputs.view(outputs.shape[0], -1) + outputs = self.modules.output_mlp(outputs) + + outputs = self.hparams.log_softmax(outputs) + return outputs + + def compute_objectives(self, predictions, batch, stage): + """Computes the loss using speaker-id as label.""" + scenario_id, _ = batch.scenario_encoded + scenario_id = scenario_id.squeeze(1) + loss = self.hparams.compute_cost(predictions, scenario_id) + if stage != sb.Stage.TRAIN: + self.error_metrics.append(batch.id, predictions, scenario_id) + + return loss + + def fit_batch(self, batch): + """Trains the parameters given a single batch in input""" + + predictions = self.compute_forward(batch, sb.Stage.TRAIN) + loss = self.compute_objectives(predictions, batch, sb.Stage.TRAIN) + loss.backward() + self.model_optimizer.step() + self.weights_optimizer.step() + + self.model_optimizer.zero_grad() + self.weights_optimizer.zero_grad() + + return loss.detach() + + def on_stage_start(self, stage, epoch=None): + """Gets called at the beginning of each epoch. + Arguments + --------- + stage : sb.Stage + One of sb.Stage.TRAIN, sb.Stage.VALID, or sb.Stage.TEST. + epoch : int + The currently-starting epoch. This is passed + `None` during the test stage. + """ + + # Set up statistics trackers for this stage + self.loss_metric = sb.utils.metric_stats.MetricStats( + metric=sb.nnet.losses.nll_loss + ) + + # Set up evaluation-only statistics trackers + if stage != sb.Stage.TRAIN: + self.error_metrics = self.hparams.error_stats() + + def on_stage_end(self, stage, stage_loss, epoch=None): + """Gets called at the end of an epoch. + Arguments + --------- + stage : sb.Stage + One of sb.Stage.TRAIN, sb.Stage.VALID, sb.Stage.TEST + stage_loss : float + The average loss for all of the data processed in this stage. + epoch : int + The currently-starting epoch. This is passed + `None` during the test stage. + """ + + # Store the train loss until the validation stage. + if stage == sb.Stage.TRAIN: + self.train_loss = stage_loss + + # Summarize the statistics from the stage for record-keeping. + else: + stats = { + "loss": stage_loss, + "error_rate": self.error_metrics.summarize("average"), + } + + # At the end of validation... + if stage == sb.Stage.VALID: + old_lr, new_lr = self.hparams.lr_annealing_model( + stats["error_rate"] + ) + sb.nnet.schedulers.update_learning_rate( + self.model_optimizer, new_lr + ) + + ( + old_lr_encoder, + new_lr_encoder, + ) = self.hparams.lr_annealing_weights(stats["error_rate"]) + sb.nnet.schedulers.update_learning_rate( + self.weights_optimizer, new_lr_encoder + ) + + # The train_logger writes a summary to stdout and to the logfile. 
+ self.hparams.train_logger.log_stats( + {"Epoch": epoch, "lr": old_lr, "wave2vec_lr": old_lr_encoder}, + train_stats={"loss": self.train_loss}, + valid_stats=stats, + ) + + # Save the current checkpoint and delete previous checkpoints, + self.checkpointer.save_and_keep_only( + meta=stats, min_keys=["error_rate"] + ) + + # We also write statistics about test data to stdout and to logfile. + if stage == sb.Stage.TEST: + self.hparams.train_logger.log_stats( + {"Epoch loaded": self.hparams.epoch_counter.current}, + test_stats=stats, + ) + + def init_optimizers(self): + "Initializes the weights optimizer and model optimizer" + self.weights_optimizer = self.hparams.weights_opt_class( + [self.modules.weighted_ssl_model.weights] + ) + self.model_optimizer = self.hparams.model_opt_class( + self.hparams.model.parameters() + ) + + if self.checkpointer is not None: + self.checkpointer.add_recoverable("modelopt", self.model_optimizer) + self.checkpointer.add_recoverable( + "weights_opt", self.weights_optimizer + ) + + +def dataio_prep(hparams): + """This function prepares the datasets to be used in the brain class. + It also defines the data processing pipeline through user-defined + functions. We expect `prepare_mini_librispeech` to have been called before + this, so that the `train.json`, `valid.json`, and `valid.json` manifest + files are available. + Arguments + --------- + hparams : dict + This dictionary is loaded from the `train.yaml` file, and it includes + all the hyperparameters needed for dataset construction and loading. + Returns + ------- + datasets : dict + Contains two keys, "train" and "valid" that correspond + to the appropriate DynamicItemDataset object. + """ + data_folder = hparams["data_folder"] + + train_data = sb.dataio.dataset.DynamicItemDataset.from_csv( + csv_path=hparams["csv_train"], replacements={"data_root": data_folder}, + ) + + if hparams["sorting"] == "ascending": + # we sort training data to speed up training and get better results. + train_data = train_data.filtered_sorted(sort_key="duration") + # when sorting do not shuffle in dataloader ! otherwise is pointless + hparams["dataloader_opts"]["shuffle"] = False + + elif hparams["sorting"] == "descending": + train_data = train_data.filtered_sorted( + sort_key="duration", reverse=True + ) + # when sorting do not shuffle in dataloader ! otherwise is pointless + hparams["dataloader_opts"]["shuffle"] = False + + elif hparams["sorting"] == "random": + pass + + else: + raise NotImplementedError( + "sorting must be random, ascending or descending" + ) + + valid_data = sb.dataio.dataset.DynamicItemDataset.from_csv( + csv_path=hparams["csv_valid"], replacements={"data_root": data_folder}, + ) + valid_data = valid_data.filtered_sorted(sort_key="duration") + + test_data = sb.dataio.dataset.DynamicItemDataset.from_csv( + csv_path=hparams["csv_test"], replacements={"data_root": data_folder}, + ) + test_data = test_data.filtered_sorted(sort_key="duration") + + datasets = [train_data, valid_data, test_data] + + # Define audio pipeline + @sb.utils.data_pipeline.takes("wav") + @sb.utils.data_pipeline.provides("sig") + def audio_pipeline(wav): + """Load the signal, and pass it and its length to the corruption class. + This is done on the CPU in the `collate_fn`.""" + sig = sb.dataio.dataio.read_audio(wav) + return sig + + # Initialization of the label encoder. The label encoder assignes to each + # of the observed label a unique index (e.g, 'spk01': 0, 'spk02': 1, ..) 
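+ # A minimal sketch of the mapping it ends up holding (hypothetical indices):
+ # {'calendar': 0, 'email': 1, ...}. The mapping is saved to label_encoder.txt
+ # in the save folder by load_or_create() further below.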
+ label_encoder = sb.dataio.encoder.CategoricalEncoder() + + # Define label pipeline: + @sb.utils.data_pipeline.takes("semantics") + @sb.utils.data_pipeline.provides("scenario", "scenario_encoded") + def label_pipeline(semantics): + scenario = semantics.split("'")[3] + yield scenario + scenario_encoded = label_encoder.encode_label_torch(scenario) + yield scenario_encoded + + sb.dataio.dataset.add_dynamic_item(datasets, audio_pipeline) + + sb.dataio.dataset.add_dynamic_item(datasets, label_pipeline) + # Define datasets. We also connect the dataset with the data processing + # functions defined above. + sb.dataio.dataset.set_output_keys( + datasets, ["id", "sig", "scenario", "scenario_encoded"], + ) + # Load or compute the label encoder (with multi-GPU DDP support) + # Please, take a look into the lab_enc_file to see the label to index + # mappinng. + + lab_enc_file = os.path.join(hparams["save_folder"], "label_encoder.txt") + label_encoder.load_or_create( + path=lab_enc_file, from_didatasets=[datasets[0]], output_key="scenario", + ) + + return {"train": datasets[0], "valid": datasets[1], "test": datasets[2]} + + +# RECIPE BEGINS! +if __name__ == "__main__": + # Reading command line arguments. + hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:]) + + # Initialize ddp (useful only for multi-GPU DDP training). + sb.utils.distributed.ddp_init_group(run_opts) + + # Load hyperparameters file with command-line overrides. + with open(hparams_file) as fin: + hparams = load_hyperpyyaml(fin, overrides) + + # Create experiment directory + sb.create_experiment_directory( + experiment_directory=hparams["output_folder"], + hyperparams_to_save=hparams_file, + overrides=overrides, + ) + + from prepare import prepare_SLURP # noqa + + # multi-gpu (ddp) save data preparation + run_on_main( + prepare_SLURP, + kwargs={ + "data_folder": hparams["data_folder"], + "save_folder": hparams["output_folder"], + "train_splits": hparams["train_splits"], + "slu_type": "direct", + "skip_prep": hparams["skip_prep"], + }, + ) + + # Data preparation, to be run on only one process. + # Create dataset objects "train", "valid", and "test". + datasets = dataio_prep(hparams) + + # freeze the feature extractor part when unfreezing + + # Initialize the Brain object to prepare for mask training. + ic_id_brain = IntentIdBrain( + modules=hparams["modules"], + hparams=hparams, + run_opts=run_opts, + checkpointer=hparams["checkpointer"], + ) + + # The `fit()` method iterates the training loop, calling the methods + # necessary to update the parameters of the model. Since all objects + # with changing state are managed by the Checkpointer, training can be + # stopped at any point, and will be resumed on next call. + ic_id_brain.fit( + epoch_counter=ic_id_brain.hparams.epoch_counter, + train_set=datasets["train"], + valid_set=datasets["valid"], + train_loader_kwargs=hparams["train_dataloader_opts"], + valid_loader_kwargs=hparams["valid_dataloader_opts"], + ) + + # Load the best checkpoint for evaluation + test_stats = ic_id_brain.evaluate( + test_set=datasets["test"], + min_key="error_rate", + test_loader_kwargs=hparams["test_dataloader_opts"], + ) diff --git a/benchmarks/MP3S/VoxCeleb1/ecapa_tdnn/hparams/ssl.yaml b/benchmarks/MP3S/VoxCeleb1/ecapa_tdnn/hparams/ssl.yaml new file mode 100644 index 0000000..64ccf6c --- /dev/null +++ b/benchmarks/MP3S/VoxCeleb1/ecapa_tdnn/hparams/ssl.yaml @@ -0,0 +1,195 @@ +# ################################ +# SummaryMixing © 2023 by Samsung Electronics is licensed under CC BY-NC 4.0. 
+ +# This is configurations of SummaryMixing wav2vec 2.0 downstream ASV on VoxCeleb + +# Usage: Install SpeechBrain MP3S +# Create a folder benchmarks/MP3S/VoxCeleb1/ecapa_tdnn +# Copy this file under benchmarks/MP3S/VoxCeleb1/ecapa_tdnn/hparams +# SummaryMixing: https://arxiv.org/abs/2307.07421 +# SummaryMixing SSL: https://arxiv.org/pdf/2407.13377 + +# Authors +# * Titouan Parcollet 2023, 2024 +# * Shucong Zhang 2023, 2024 +# * Rogier van Dalen 2023, 2024 +# * Sourav Bhattacharya 2023, 2024 +# ################################ + +# Basic parameters +seed: 1986 +__set_seed: !apply:torch.manual_seed [!ref ] +output_folder: !ref results/VoxCeleb1/ +save_folder: !ref /save +train_log: !ref /train_log.txt +# Data files +data_folder: !PLACEHOLDER # e.g. /path/to/Voxceleb +train_annotation: !ref /train.csv +valid_annotation: !ref /dev.csv +num_layers_ssl: 13 +pretrained_path: !PLACEHOLDER # e,g./path/to/pre-trained_SummaryMixing_w2v2 + +verification_file: !PLACEHOLDER #path/to/veri_test2.txt + +skip_prep: False +ckpt_interval_minutes: 15 # save checkpoint every N min +pretrain: True +# Training parameters +number_of_epochs: 15 +batch_size: 8 +lr: 0.001 +lr_final: 0.0001 +mask_length: 10 +mask_prob: 0.65 +lr_weights: 0.01 +sample_rate: 16000 +shuffle: True +random_chunk: True +sentence_len: 3 + +# Feature parameters + +encoder_dim: 768 + +# Number of speakers +out_n_neurons: 1211 #1211 for vox1 # 5994 for vox2, 7205 for vox1+vox2 + +freeze_wav2vec: True +dataloader_options: + batch_size: !ref + shuffle: !ref + drop_last: True + +embedding_model: !new:speechbrain.lobes.models.ECAPA_TDNN.ECAPA_TDNN + input_size: !ref + channels: [512, 512, 512, 512, 1536] + kernel_sizes: [5, 3, 3, 3, 1] + dilations: [1, 2, 3, 4, 1] + groups: [1, 1, 1, 1, 1] + attention_channels: 64 + lin_neurons: 512 + +classifier: !new:speechbrain.lobes.models.ECAPA_TDNN.Classifier + input_size: 512 + out_neurons: !ref + +epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter + limit: !ref + + +encoder: !new:speechbrain.lobes.models.transformer.Conformer.ConformerEncoder + d_model: 768 + num_layers: 12 + nhead: 8 + d_ffn: 3072 + dropout: 0.1 + layerdrop_prob: 0.0 + attention_type: SummaryMixing + local_proj_hid_dim: [768] + local_proj_out_dim: 768 + summary_hid_dim: [768] + mode: SummaryMixing + output_hidden_states: True + +latentextractor_kernels: [3, 3] +latentextractor_strides: [2, 1] +extractor_dim: 512 +embedding_dim: 768 + +CNN: !new:speechbrain.lobes.models.wav2vec.W2VLatentExtractor + kernel_sizes: !ref + strides: !ref + out_channels: [512, 512] + input_dim: 80 + +weighted_ssl_model: !new:speechbrain.lobes.models.wav2vec.WeightedSSLModel + pretrained_path: !ref + num_layers: 13 + latent_encoder: !ref + CNN: !ref + in_dim: 512 + embedding_dim: 768 + dropout_encoder_input: 0.1 + output_hidden_states: True + + +modules: + weighted_ssl_model: !ref + embedding_model: !ref + classifier: !ref + +model: !new:torch.nn.ModuleList + - [!ref , !ref ] + + +# Cost + optimization +compute_error: !name:speechbrain.nnet.losses.classification_error + +compute_cost: !new:speechbrain.nnet.losses.LogSoftmaxWrapper + loss_fn: !new:speechbrain.nnet.losses.AdditiveAngularMargin + margin: 0.2 + scale: 30 + + +model_opt_class: !name:torch.optim.Adam + lr: !ref + weight_decay: 0.000002 + +lr_annealing: !new:speechbrain.nnet.schedulers.LinearScheduler + initial_value: !ref + final_value: !ref + epoch_count: !ref + +weights_opt_class: !name:torch.optim.Adam + lr: !ref + +lr_annealing_weights: !new:speechbrain.nnet.schedulers.NewBobScheduler + 
initial_value: !ref + improvement_threshold: 0.0025 + annealing_factor: 0.9 + patient: 0 + +# Used to load the SB checkpoint of w2v2 +pretrainer: !new:speechbrain.utils.parameter_transfer.Pretrainer + collect_in: !ref + + +# Logging + checkpoints +train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger + save_file: !ref + +error_stats: !name:speechbrain.utils.metric_stats.MetricStats + metric: !name:speechbrain.nnet.losses.classification_error + reduction: batch + +checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer + checkpoints_dir: !ref + recoverables: + embedding_model: !ref + scheduler_wav2vec: !ref + ssl_model: !ref + classifier: !ref + counter: !ref + +train_data: !ref /train.csv +enrol_data: !ref /enrol.csv +test_data: !ref /test.csv + +verif_batch_size: 2 +n_train_snts: 300000 # used for normalization stats + + +# Dataloader options +train_dataloader_opts: + batch_size: !ref + +enrol_dataloader_opts: + batch_size: !ref + +test_dataloader_opts: + batch_size: !ref + + +mean_var_norm_emb: !new:speechbrain.processing.features.InputNormalization + norm_type: global + std_norm: False diff --git a/benchmarks/MP3S/VoxCeleb1/ecapa_tdnn/train.py b/benchmarks/MP3S/VoxCeleb1/ecapa_tdnn/train.py new file mode 100644 index 0000000..5ba9a03 --- /dev/null +++ b/benchmarks/MP3S/VoxCeleb1/ecapa_tdnn/train.py @@ -0,0 +1,468 @@ +""" SummaryMixing © 2023 by Samsung Electronics is licensed under CC BY-NC 4.0. + +Recipe for training then testing speaker embeddings using the VoxCeleb Dataset. +The embeddings are learned using the ECAPA-TDNN architecture +""" + +import os +from tqdm import tqdm +import sys +import logging +import random +import torch +import torchaudio +import speechbrain as sb +from speechbrain.utils.data_utils import download_file +from hyperpyyaml import load_hyperpyyaml +from speechbrain.utils.metric_stats import EER, minDCF +from speechbrain.utils.distributed import run_on_main + + +def compute_embedding(wavs, wav_lens): + """Compute speaker embeddings. + + Arguments + --------- + wavs : Torch.Tensor + Tensor containing the speech waveform (batch, time). + Make sure the sample rate is fs=16000 Hz. + wav_lens: Torch.Tensor + Tensor containing the relative length for each sentence + in the length (e.g., [0.8 0.6 1.0]) + """ + with torch.no_grad(): + wavs, wav_lens = ( + wavs.to(speaker_brain.device), + wav_lens.to(speaker_brain.device), + ) + feats = speaker_brain.modules.weighted_ssl_model(wavs, wav_lens) + embeddings = speaker_brain.modules.embedding_model(feats, wav_lens) + return embeddings.squeeze(1) + + +def compute_embedding_loop(data_loader): + """Computes the embeddings of all the waveforms specified in the + dataloader. + """ + embedding_dict = {} + + with torch.no_grad(): + for batch in tqdm(data_loader, dynamic_ncols=True): + batch = batch.to(hparams["device"]) + seg_ids = batch.id + wavs, lens = batch.sig + + found = False + for seg_id in seg_ids: + if seg_id not in embedding_dict: + found = True + if not found: + continue + wavs, lens = wavs.to(hparams["device"]), lens.to(hparams["device"]) + emb = compute_embedding(wavs, lens).unsqueeze(1) + for i, seg_id in enumerate(seg_ids): + embedding_dict[seg_id] = emb[i].detach().clone() + return embedding_dict + + +def get_verification_scores(veri_test): + """ Computes positive and negative scores given the verification split. 
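+
+ Each line of `veri_test` is expected to look like
+ "<label> <enrol_utterance> <test_utterance>", where label is 1 for a genuine
+ pair and 0 for an impostor pair (see the parsing below).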
+ """ + scores = [] + positive_scores = [] + negative_scores = [] + + save_file = os.path.join(hparams["output_folder"], "scores.txt") + s_file = open(save_file, "w") + + # Cosine similarity initialization + similarity = torch.nn.CosineSimilarity(dim=-1, eps=1e-6) + + # creating cohort for score normalization + if "score_norm" in hparams: + train_cohort = torch.stack(list(train_dict.values())) + + for i, line in enumerate(veri_test): + + # Reading verification file (enrol_file test_file label) + lab_pair = int(line.split(" ")[0].rstrip().split(".")[0].strip()) + enrol_id = line.split(" ")[1].rstrip().split(".")[0].strip() + test_id = line.split(" ")[2].rstrip().split(".")[0].strip() + enrol = enrol_dict[enrol_id] + test = test_dict[test_id] + + if "score_norm" in hparams: + # Getting norm stats for enrol impostors + enrol_rep = enrol.repeat(train_cohort.shape[0], 1, 1) + score_e_c = similarity(enrol_rep, train_cohort) + + if "cohort_size" in hparams: + score_e_c = torch.topk( + score_e_c, k=hparams["cohort_size"], dim=0 + )[0] + + mean_e_c = torch.mean(score_e_c, dim=0) + std_e_c = torch.std(score_e_c, dim=0) + + # Getting norm stats for test impostors + test_rep = test.repeat(train_cohort.shape[0], 1, 1) + score_t_c = similarity(test_rep, train_cohort) + + if "cohort_size" in hparams: + score_t_c = torch.topk( + score_t_c, k=hparams["cohort_size"], dim=0 + )[0] + + mean_t_c = torch.mean(score_t_c, dim=0) + std_t_c = torch.std(score_t_c, dim=0) + + # Compute the score for the given sentence + score = similarity(enrol, test)[0] + + # Perform score normalization + if "score_norm" in hparams: + if hparams["score_norm"] == "z-norm": + score = (score - mean_e_c) / std_e_c + elif hparams["score_norm"] == "t-norm": + score = (score - mean_t_c) / std_t_c + elif hparams["score_norm"] == "s-norm": + score_e = (score - mean_e_c) / std_e_c + score_t = (score - mean_t_c) / std_t_c + score = 0.5 * (score_e + score_t) + + # write score file + s_file.write("%s %s %i %f\n" % (enrol_id, test_id, lab_pair, score)) + scores.append(score) + + if lab_pair == 1: + positive_scores.append(score) + else: + negative_scores.append(score) + + s_file.close() + return positive_scores, negative_scores + + +def dataio_prep_verif(params): + "Creates the dataloaders and their data processing pipelines." + + data_folder = params["data_folder"] + + # 1. Declarations: + + # Train data (used for normalization) + train_data = sb.dataio.dataset.DynamicItemDataset.from_csv( + csv_path=params["train_data"], replacements={"data_root": data_folder}, + ) + train_data = train_data.filtered_sorted( + sort_key="duration", select_n=params["n_train_snts"] + ) + + # Enrol data + enrol_data = sb.dataio.dataset.DynamicItemDataset.from_csv( + csv_path=params["enrol_data"], replacements={"data_root": data_folder}, + ) + enrol_data = enrol_data.filtered_sorted(sort_key="duration") + + # Test data + test_data = sb.dataio.dataset.DynamicItemDataset.from_csv( + csv_path=params["test_data"], replacements={"data_root": data_folder}, + ) + test_data = test_data.filtered_sorted(sort_key="duration") + + datasets = [train_data, enrol_data, test_data] + + # 2. 
Define audio pipeline: + @sb.utils.data_pipeline.takes("wav", "start", "stop") + @sb.utils.data_pipeline.provides("sig") + def audio_pipeline(wav, start, stop): + start = int(start) + stop = int(stop) + num_frames = stop - start + sig, fs = torchaudio.load( + wav, num_frames=num_frames, frame_offset=start + ) + sig = sig.transpose(0, 1).squeeze(1) + return sig + + sb.dataio.dataset.add_dynamic_item(datasets, audio_pipeline) + + # 3. Set output: + sb.dataio.dataset.set_output_keys(datasets, ["id", "sig"]) + + # 4 Create dataloaders + train_dataloader = sb.dataio.dataloader.make_dataloader( + train_data, **params["train_dataloader_opts"] + ) + enrol_dataloader = sb.dataio.dataloader.make_dataloader( + enrol_data, **params["enrol_dataloader_opts"] + ) + test_dataloader = sb.dataio.dataloader.make_dataloader( + test_data, **params["test_dataloader_opts"] + ) + + return train_dataloader, enrol_dataloader, test_dataloader + + +class SpeakerBrain(sb.core.Brain): + """Class for speaker embedding training" + """ + + def compute_forward(self, batch, stage): + """Computation pipeline based on a encoder + speaker classifier. + Data augmentation and environmental corruption are applied to the + input speech. + """ + batch = batch.to(self.device) + wavs, lens = batch.sig + feats = self.modules.weighted_ssl_model(wavs, lens) + # Embeddings + speaker classifier + embeddings = self.modules.embedding_model(feats) + outputs = self.modules.classifier(embeddings) + return outputs, lens + + def compute_objectives(self, predictions, batch, stage): + """Computes the loss using speaker-id as label. + """ + predictions, lens = predictions + uttid = batch.id + spkid, _ = batch.spk_id_encoded + + loss = self.hparams.compute_cost(predictions, spkid, lens) + + if stage == sb.Stage.TRAIN and hasattr( + self.hparams.lr_annealing, "on_batch_end" + ): + self.hparams.lr_annealing.on_batch_end(self.model_optimizer) + + if stage != sb.Stage.TRAIN: + self.error_metrics.append(uttid, predictions, spkid, lens) + + return loss + + def fit_batch(self, batch): + """Train the parameters given a single batch in input""" + predictions = self.compute_forward(batch, sb.Stage.TRAIN) + loss = self.compute_objectives(predictions, batch, sb.Stage.TRAIN) + loss.backward() + self.model_optimizer.step() + self.weights_optimizer.step() + + self.model_optimizer.zero_grad() + self.weights_optimizer.zero_grad() + return loss.detach() + + def on_stage_start(self, stage, epoch=None): + """Gets called at the beginning of an epoch.""" + if stage != sb.Stage.TRAIN: + self.error_metrics = self.hparams.error_stats() + + def on_stage_end(self, stage, stage_loss, epoch=None): + """Gets called at the end of an epoch.""" + # Compute/store important stats + stage_stats = {"loss": stage_loss} + if stage == sb.Stage.TRAIN: + self.train_stats = stage_stats + else: + stage_stats["ErrorRate"] = self.error_metrics.summarize("average") + + # Perform end-of-iteration things, like annealing, logging, etc. 
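+ # Note that checkpoints are ranked on the speaker-classification ErrorRate;
+ # the verification metrics (EER / minDCF) are only computed after training,
+ # in the main script below.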
+ if stage == sb.Stage.VALID: + old_lr, new_lr = self.hparams.lr_annealing(epoch) + sb.nnet.schedulers.update_learning_rate( + self.model_optimizer, new_lr + ) + + self.hparams.train_logger.log_stats( + stats_meta={"epoch": epoch, "lr": old_lr}, + train_stats=self.train_stats, + valid_stats=stage_stats, + ) + self.checkpointer.save_and_keep_only( + meta={"ErrorRate": stage_stats["ErrorRate"]}, + min_keys=["ErrorRate"], + ) + + def init_optimizers(self): + "Initializes the weights optimizer and model optimizer" + self.weights_optimizer = self.hparams.weights_opt_class( + [self.modules.weighted_ssl_model.weights] + ) + self.model_optimizer = self.hparams.model_opt_class( + self.hparams.model.parameters() + ) + # Initializing the weights + if self.checkpointer is not None: + self.checkpointer.add_recoverable("modelopt", self.model_optimizer) + self.checkpointer.add_recoverable( + "weights_opt", self.weights_optimizer + ) + + +def dataio_prep(hparams): + "Creates the datasets and their data processing pipelines." + + data_folder = hparams["data_folder"] + + # 1. Declarations: + train_data = sb.dataio.dataset.DynamicItemDataset.from_csv( + csv_path=hparams["train_annotation"], + replacements={"data_root": data_folder}, + ) + + valid_data = sb.dataio.dataset.DynamicItemDataset.from_csv( + csv_path=hparams["valid_annotation"], + replacements={"data_root": data_folder}, + ) + + datasets = [train_data, valid_data] + label_encoder = sb.dataio.encoder.CategoricalEncoder() + + snt_len_sample = int(hparams["sample_rate"] * hparams["sentence_len"]) + + # 2. Define audio pipeline: + @sb.utils.data_pipeline.takes("wav", "start", "stop", "duration") + @sb.utils.data_pipeline.provides("sig") + def audio_pipeline(wav, start, stop, duration): + if hparams["random_chunk"]: + duration_sample = int(duration * hparams["sample_rate"]) + start = random.randint(0, duration_sample - snt_len_sample) + stop = start + snt_len_sample + else: + start = int(start) + stop = int(stop) + num_frames = stop - start + sig, fs = torchaudio.load( + wav, num_frames=num_frames, frame_offset=start + ) + sig = sig.transpose(0, 1).squeeze(1) + return sig + + sb.dataio.dataset.add_dynamic_item(datasets, audio_pipeline) + + # 3. Define text pipeline: + @sb.utils.data_pipeline.takes("spk_id") + @sb.utils.data_pipeline.provides("spk_id", "spk_id_encoded") + def label_pipeline(spk_id): + yield spk_id + spk_id_encoded = label_encoder.encode_sequence_torch([spk_id]) + yield spk_id_encoded + + sb.dataio.dataset.add_dynamic_item(datasets, label_pipeline) + + # 3. Fit encoder: + # Load or compute the label encoder (with multi-GPU DDP support) + lab_enc_file = os.path.join(hparams["save_folder"], "label_encoder.txt") + label_encoder.load_or_create( + path=lab_enc_file, from_didatasets=[train_data], output_key="spk_id", + ) + + # 4. 
Set output: + sb.dataio.dataset.set_output_keys(datasets, ["id", "sig", "spk_id_encoded"]) + + return train_data, valid_data, label_encoder + + +if __name__ == "__main__": + + logger = logging.getLogger(__name__) + # This flag enables the inbuilt cudnn auto-tuner + torch.backends.cudnn.benchmark = True + + # CLI: + hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:]) + + # Initialize ddp (useful only for multi-GPU DDP training) + sb.utils.distributed.ddp_init_group(run_opts) + + # Load hyperparameters file with command-line overrides + with open(hparams_file) as fin: + hparams = load_hyperpyyaml(fin, overrides) + + # Download verification list (to exlude verification sentences from train) + veri_file_path = os.path.join( + hparams["save_folder"], os.path.basename(hparams["verification_file"]) + ) + download_file(hparams["verification_file"], veri_file_path) + + # Dataset prep (parsing VoxCeleb and annotation into csv files) + from voxceleb_prepare import prepare_voxceleb # noqa + + prepare_voxceleb( + data_folder=hparams["data_folder"], + save_folder=hparams["save_folder"], + verification_pairs_file=veri_file_path, + splits=["train", "dev", "test"], + split_ratio=[90, 10], + seg_dur=hparams["sentence_len"], + source=hparams["voxceleb_source"] + if "voxceleb_source" in hparams + else None, + ) + + # Loading wav2vec2.0 + if not hparams["pretrain"]: + run_on_main(hparams["pretrainer"].collect_files) + hparams["pretrainer"].load_collected() + + # Dataset IO prep: creating Dataset objects and proper encodings for phones + train_data, valid_data, label_encoder = dataio_prep(hparams) + + # Create experiment directory + sb.core.create_experiment_directory( + experiment_directory=hparams["output_folder"], + hyperparams_to_save=hparams_file, + overrides=overrides, + ) + + # Brain class initialization + speaker_brain = SpeakerBrain( + modules=hparams["modules"], + hparams=hparams, + run_opts=run_opts, + checkpointer=hparams["checkpointer"], + ) + + # Training + speaker_brain.fit( + speaker_brain.hparams.epoch_counter, + train_data, + valid_data, + train_loader_kwargs=hparams["dataloader_options"], + valid_loader_kwargs=hparams["dataloader_options"], + ) + + # Now preparing for test : + hparams["device"] = speaker_brain.device + + speaker_brain.modules.eval() + train_dataloader, enrol_dataloader, test_dataloader = dataio_prep_verif( + hparams + ) + # Computing enrollment and test embeddings + logger.info("Computing enroll/test embeddings...") + + # First run + enrol_dict = compute_embedding_loop(enrol_dataloader) + test_dict = compute_embedding_loop(test_dataloader) + + if "score_norm" in hparams: + train_dict = compute_embedding_loop(train_dataloader) + + # Compute the EER + logger.info("Computing EER..") + # Reading standard verification split + with open(veri_file_path) as f: + veri_test = [line.rstrip() for line in f] + + positive_scores, negative_scores = get_verification_scores(veri_test) + del enrol_dict, test_dict + + eer, th = EER(torch.tensor(positive_scores), torch.tensor(negative_scores)) + logger.info("EER(%%)=%f", eer * 100) + + min_dcf, th = minDCF( + torch.tensor(positive_scores), torch.tensor(negative_scores) + ) + # Testing + logger.info("minDCF=%f", min_dcf * 100) diff --git a/recipes/Libri-Light/self-supervised-learning/wav2vec2/hparams/summarymixing_wav2vec2.yaml b/recipes/Libri-Light/self-supervised-learning/wav2vec2/hparams/summarymixing_wav2vec2.yaml new file mode 100644 index 0000000..0b37fbb --- /dev/null +++ 
b/recipes/Libri-Light/self-supervised-learning/wav2vec2/hparams/summarymixing_wav2vec2.yaml @@ -0,0 +1,180 @@ +# ################################ +# SummaryMixing © 2023 by Samsung Electronics is licensed under CC BY-NC 4.0. + +# This is the pre-training configurations of SummaryMixing wav2vec 2.0. + +# Usage: Install SpeechBrain +# Create a folder recipes/Libri-Light/self-supervised-learning/wav2vec2 +# Copy this file under recipes/Libri-Light/self-supervised-learning/wav2vec2/hparams +# SummaryMixing: https://arxiv.org/abs/2307.07421 +# SummaryMixing SSL: https://arxiv.org/pdf/2407.13377 + +# Authors +# * Titouan Parcollet 2023, 2024 +# * Shucong Zhang 2023, 2024 +# * Rogier van Dalen 2023, 2024 +# * Sourav Bhattacharya 2023, 2024 +# ################################ + +# NOTE! to pre-train on Libri-Light, you need to firstly download the Libri-Light data +# through the toolkit in the Libri-Light github repo +# then, use the vad script from the Libri-Light github repo to apply vad to the data +# then, use "make_librilight_csv.py" to make the csv file for SpeechBrain pretraining +# after prepared the data and the csv file, use "python train_sb_wav2vec2_mel.py hparams/summarymixing_wav2vec2.yaml" + +data_folder: # path to the libri-light folder +output_folder: results/summix_w2v2/libri-light-medium +save_folder: !ref /save +# Logging file for every N optimizer steps (many lines) +train_steps_log: !ref /train_steps_log.txt +# Logging file per epoch +train_stage_log: !ref /train_stage_log.txt + +train_csv: # path to train.csv NOTE! please refer to make_librilight_csv.py to see how to create train.csv for libri-light +valid_csv: # path to LibriSpeech dev-clean.csv We use dev-clean to monitor the pre-training, not to do model selection +skip_prep: True # train_csv and valid_csv should be created before runing the SSL pre-training + +avoid_if_longer_than: 30.0 +avoid_if_shorter_than: 10.0 # after vad most utterances are quite long +log_interval: 1000 # Logging every N optimizer steps +precision: fp16 # bf16, fp16 or fp32 +max_grad_norm: 100. + +# The training will either stops at number_of_epochs or optimizer_step_limit +# I.e. the first that is reached. +number_of_epochs: 3000 +optimizer_step_limit: 300000 + +# Dynamic Batching parameters +max_batch_length: 360 +num_buckets: 70 +shuffle: True # if true re-creates batches at each epoch shuffling examples. +batch_ordering: random +grad_accumulation_factor: 4 # we use 4 GPUs. 
360s batch * 4 grad_acc * 4 GPUs = 1.6h +dynamic_batch_sampler_train: + max_batch_length: !ref + num_buckets: !ref + shuffle: !ref + batch_ordering: !ref + +train_dataloader_options: + num_workers: 4 + +test_dataloader_options: + batch_size: 8 # DynamicBatching not used at testing time + num_workers: 4 + +# Training parameters +lr: 0.0005 +warmup: 30000 +# This is equivalent to optimizer_step_limit - warmup +# Necessary to do to have a linear warmup and linear decay directly +# If cooldown < optimizer_step_limit - warmup then a third step with a slower +# decay is applied in the middle (see the implementation of the scheduler) +cooldown: 270000 + +# Loss parameters +diversity_loss_weight: 0.1 +mask_prob: 0.65 +mask_length: 10 +num_negatives: 100 + +# Model parameters +frontend_type: mel_cnn_base +embedding_dim: 768 +extractor_dim: 512 +final_dim: 256 +encoder_layerdrop: 0.05 +latentextractor_kernels: [3, 3] +latentextractor_strides: [2, 1] + +# Feature parameters +sample_rate: 16000 +n_fft: 400 +n_mels: 80 +hop_length: 10 +win_length: 25 + +optimizer: !name:torch.optim.AdamW + lr: !ref + weight_decay: 0.05 + eps: 0.000001 + +epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter + limit: !ref + +CNN: !new:speechbrain.lobes.models.wav2vec.W2VLatentExtractor + kernel_sizes: !ref + strides: !ref + out_channels: [!ref , !ref ] + input_dim: 80 + +encoder: !new:speechbrain.lobes.models.transformer.Conformer.ConformerEncoder + d_model: !ref + num_layers: 12 + nhead: 8 + d_ffn: 3072 + dropout: 0.1 + layerdrop_prob: !ref + attention_type: SummaryMixing #SummaryMixing regularMHA + local_proj_hid_dim: [!ref ] + local_proj_out_dim: !ref + summary_hid_dim: [!ref ] + mode: "SummaryMixing" + +encoder_wrapper: !new:speechbrain.lobes.models.wav2vec.EncoderWrapper + in_dim: !ref + embedding_dim: !ref + latent_encoder: !ref + dropout_encoder_input: 0.1 + +target_quantiser: !new:speechbrain.lobes.models.wav2vec.W2VTargetQuantiser + in_dim: !ref + out_dim: !ref + +feat_proj: !new:torch.nn.Linear + in_features: !ref + out_features: !ref + +modules: + compute_features: !ref + normalize: !ref + CNN: !ref + latent_encoder: !ref + feat_proj: !ref + target_quantiser: !ref + +loss: !new:speechbrain.nnet.losses.ContrastiveLoss + logit_temp: 0.1 + +lr_scheduler: !new:speechbrain.nnet.schedulers.WarmCoolDecayLRSchedule + lr: !ref + warmup: !ref + cooldown: !ref + total_steps: !ref + +normalize: !new:speechbrain.processing.features.InputNormalization + norm_type: sentence + +compute_features: !new:speechbrain.lobes.features.Fbank + sample_rate: !ref + n_fft: !ref + n_mels: !ref + hop_length: !ref + win_length: !ref + +checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer + checkpoints_dir: !ref + recoverables: + CNN: !ref + latent_encoder: !ref + feat_proj: !ref + target_quantiser: !ref + scheduler: !ref + counter: !ref + +train_steps_logger: !new:speechbrain.utils.train_logger.FileTrainLogger + save_file: !ref + +train_stage_logger: !new:speechbrain.utils.train_logger.FileTrainLogger + save_file: !ref \ No newline at end of file diff --git a/recipes/Libri-Light/self-supervised-learning/wav2vec2/make_librilight_csv.py b/recipes/Libri-Light/self-supervised-learning/wav2vec2/make_librilight_csv.py new file mode 100644 index 0000000..d569d79 --- /dev/null +++ b/recipes/Libri-Light/self-supervised-learning/wav2vec2/make_librilight_csv.py @@ -0,0 +1,90 @@ +""" SummaryMixing © 2023 by Samsung Electronics is licensed under CC BY-NC 4.0. 
+ +Usage: Install SpeechBrain + Create a folder recipes/Libri-Light/self-supervised-learning/wav2vec2 + Copy this file under recipes/Libri-Light/self-supervised-learning/wav2vec2 + +This is script create speebrain csv files for the Libri-Light dataset +1. download the Libri-Light dataset through the toolkit in the Libri-Light github repo +2. use the vad script from the Libri-Light repo to do the vad +3. use "python make_librilight_csv.py path_to_vad_output path_to_save_csv" to generate the train.csv for the SSL pretraining + + +SummaryMixing: https://arxiv.org/abs/2307.07421 +SummaryMixing SSL: https://arxiv.org/pdf/2407.13377 + +Authors + * Titouan Parcollet 2023, 2024 + * Shucong Zhang 2023, 2024 + * Rogier van Dalen 2023, 2024 + * Sourav Bhattacharya 2023, 2024 +""" + + +import pathlib +import torchaudio +import tqdm +import multiprocessing +import csv +from pathlib import Path +import sys +import os + +def make_csv_for_each(subpath_1_csv_file_folder, max_length=20.2): + subpath_1, csv_file_folder = subpath_1_csv_file_folder + for i, flac_file in enumerate(subpath_1.glob('**/*.flac')): + flac_file_name = flac_file.stem + waveform, sample_rate = torchaudio.load(str(flac_file)) + num_frames = waveform.size(1) + duration_seconds = num_frames / sample_rate + if duration_seconds > max_length: + continue + audio_length_seconds = waveform.shape[1] / sample_rate + csv_file = f"{csv_file_folder}/{flac_file.parent.stem}.csv" + with open(csv_file, mode='a', newline='') as csvfile: + csv_writer = csv.writer( + csvfile, delimiter=",", quotechar='"', quoting=csv.QUOTE_MINIMAL + ) + csv_writer.writerow([flac_file_name,audio_length_seconds,str(flac_file)]) + + +def processes_folder(data_path, csv_file_folder): + os.makedirs(csv_file_folder, exist_ok=True) + os.makedirs(f"{csv_file_folder}/tmp", exist_ok=True) + list_dir = pathlib.Path(data_path) + tasks = [] + for x in list_dir.iterdir(): + tasks.append((x, f"{csv_file_folder}/tmp")) + with multiprocessing.Pool(processes=128) as pool: + for _ in tqdm.tqdm(pool.imap_unordered(make_csv_for_each, tasks), total=len(tasks)): + pass + +def merge_csv_files(csv_file_folder): + file_list = [str(x) for x in Path(f"{csv_file_folder}/tmp").glob('*.csv')] + output_file = f"{csv_file_folder}/train.csv" + fieldnames = ["ID", "duration", "wav"] + + with open(output_file, mode='a', newline='', encoding='utf-8') as outfile: + csv_writer = csv.writer( + outfile, delimiter=",", quotechar='"', quoting=csv.QUOTE_MINIMAL + ) + csv_writer.writerow(fieldnames) + for file_path in tqdm.tqdm(file_list): + with open(file_path, mode='r', encoding='utf-8') as infile: + reader = csv.reader(infile) + + # filter out bad rows + for row in reader: + if len(row) == 3 and os.path.exists(row[-1]): + new_row = [row[-1], row[1], row[2]] + csv_writer.writerow(new_row) + else: + print(f"bad row {row}") + + import shutil + shutil.rmtree(f"{csv_file_folder}/tmp") + + +if __name__ == "__main__": + processes_folder(sys.argv[1], sys.argv[2]) + merge_csv_files(sys.argv[2]) diff --git a/recipes/Libri-Light/self-supervised-learning/wav2vec2/train_sb_wav2vec2_mel.py b/recipes/Libri-Light/self-supervised-learning/wav2vec2/train_sb_wav2vec2_mel.py new file mode 100644 index 0000000..a3fc91b --- /dev/null +++ b/recipes/Libri-Light/self-supervised-learning/wav2vec2/train_sb_wav2vec2_mel.py @@ -0,0 +1,431 @@ +""" SummaryMixing © 2023 by Samsung Electronics is licensed under CC BY-NC 4.0. + +This is the pre-training recipes of SummaryMixing wav2vec 2.0. 
+ +Usage: Install SpeechBrain + Create a folder recipes/Libri-Light/self-supervised-learning/wav2vec2 + Copy this file under recipes/Libri-Light/self-supervised-learning/wav2vec2 + +SummaryMixing: https://arxiv.org/abs/2307.07421 +SummaryMixing SSL: https://arxiv.org/pdf/2407.13377 + +Authors + * Titouan Parcollet 2023, 2024 + * Shucong Zhang 2023, 2024 + * Rogier van Dalen 2023, 2024 + * Sourav Bhattacharya 2023, 2024 +""" + +import logging +import sys +import time +from functools import partial + +import speechbrain as sb +import torch +import torch.nn.functional as F +from torch.nn.parallel import DistributedDataParallel +from hyperpyyaml import load_hyperpyyaml + +from speechbrain import Stage +from speechbrain.utils.distributed import run_on_main +from speechbrain.dataio.dataloader import SaveableDataLoader +from speechbrain.dataio.sampler import DynamicBatchSampler +from speechbrain.lobes.models.wav2vec import w2v_mask_collate_fn +from speechbrain.lobes.models.wav2vec import sample_negatives +from speechbrain.core import AMPConfig + +logger = logging.getLogger(__name__) + + +class W2V2Brain(sb.core.Brain): + def compute_forward(self, batch, stage): + """Computes forward pass through wav2vec model and returns encoded and + target embeddings as well as other metrics of interest. + """ + wavs, wav_lens, mask = batch + wavs, wav_lens, mask = ( + wavs.to(self.device), + wav_lens.to(self.device), + mask.to(self.device), + ) + batch_size = wavs.size(0) + + # Mormalisation already done in dataloader + # 1. Go through features extractor + if ( + self.hparams.frontend_type == "w2v2" + ): + latents = self.modules.latent_extractor( + wavs, normalize_signal=False + ) + elif ( + self.hparams.frontend_type == "mel_cnn_base" + ): + with torch.no_grad(): + latents = self.modules.compute_features(wavs) + + latents = self.modules.normalize( + latents, wav_lens, + ).detach() + latents = self.modules.CNN(latents) + latents = latents.view(batch_size, latents.shape[1], -1) + + + # 2. Go through latent (Transformer). + results = self.modules.latent_encoder( + latents, mask=mask, wav_lens=wav_lens, + ) + + embeddings = results["embeddings"] + + # 3. Mask some of the latent and projection + embeddings = embeddings[mask] + embeddings = self.modules.feat_proj(embeddings) + results["embeddings"] = embeddings.view( + batch_size, -1, embeddings.size(1) + ) + + latents = latents[mask].view(batch_size, -1, latents.size(2)) + + # 4. Apply the quantiser as well + targets, meta = self.modules.target_quantiser(latents) + results.update(meta) + results["targets"] = targets + return results + + def compute_objectives(self, forward_outputs, batch, stage): + """Samples negatives, computes contrastive loss and accuracy. 
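+ Negatives are drawn from the quantised targets with sample_negatives
+ (num_negatives candidates per positive) and scored against the masked
+ embeddings by the ContrastiveLoss defined in the hparams.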
+ """ + + embeddings = forward_outputs["embeddings"] + targets = forward_outputs["targets"] + + negs = sample_negatives(targets, self.hparams.num_negatives) + + loss, accuracy = self.hparams.loss(embeddings, targets, negs) + + # This is only used for logging purpose + if stage != sb.Stage.TRAIN: + self.acc_metric.append(accuracy) + + objectives = { + "loss": loss, + "accuracy": accuracy, + "num_masked": forward_outputs["num_masked"], + "ratio_masked": forward_outputs["ratio_masked"], + } + if ( + "diversity_loss" in forward_outputs + ): # only quantised model has these + objectives.update( + { + "diversity_loss": forward_outputs["diversity_loss"], + "prob_perplex": forward_outputs["prob_perplex"], + "code_perplex": forward_outputs["code_perplex"], + "num_vars": forward_outputs["num_vars"], + "temp": forward_outputs["temp"], + } + ) + + # Compute the loss given the original equation from the paper + loss = objectives["loss"] + if self.hparams.diversity_loss_weight == 0.0: + objectives["backprop_loss"] = loss + else: + objectives["backprop_loss"] = ( + loss + + objectives["diversity_loss"] + * self.hparams.diversity_loss_weight + * objectives["num_masked"] + ) + return objectives + + def fit_batch(self, batch): + amp = AMPConfig.from_name(self.precision) + should_step = (self.step % self.grad_accumulation_factor) == 0 + + # Managing automatic mixed precision + with self.no_sync(not should_step): + if self.use_amp: + with torch.autocast( + dtype=amp.dtype, device_type=torch.device(self.device).type, + ): + outputs = self.compute_forward(batch, Stage.TRAIN) + objectives = self.compute_objectives( + outputs, batch, Stage.TRAIN + ) + else: + outputs = self.compute_forward(batch, Stage.TRAIN) + objectives = self.compute_objectives( + outputs, batch, Stage.TRAIN + ) + + self.scaler.scale( + objectives["backprop_loss"] / self.grad_accumulation_factor + ).backward() + + objectives["total_loss"] = objectives["backprop_loss"].detach() + + if should_step: + self.optimizers_step() + self.on_fit_batch_end(objectives) + + return objectives["backprop_loss"].detach() + + def on_fit_batch_end(self, objectives): + """ Called after fit_batch(), updates learning rate and does per-step logging. """ + if isinstance(self.modules.target_quantiser, DistributedDataParallel): + w2v_model = self.modules.target_quantiser.module + else: + w2v_model = self.modules.target_quantiser + + w2v_model.quantiser.update_temp(self.optimizer_step) + + self.hparams.lr_scheduler(self.optimizer, self.optimizer_step) + + # Perform step-wise logging + if ( + hasattr(self.hparams, "log_interval") + and self.optimizer_step % self.hparams.log_interval == 0 + ): + + # Create a dictionary and fill it with everything we + # want to log such as contrastive loss, diversity loss, + # learning rate etc. + log_dct = { + k: (v.item() if isinstance(v, torch.Tensor) else v) + for k, v in objectives.items() + } + current_lr = self.optimizer.param_groups[0]["lr"] + log_dct["steps"] = self.optimizer_step + log_dct["lr"] = current_lr + log_dct["avg_loss"] = self.avg_train_loss + + if hasattr(self, "time_last_log"): + run_time_since_last_log = time.time() - self.time_last_log + log_dct["run_time"] = run_time_since_last_log + self.time_last_log = time.time() + + if sb.utils.distributed.if_main_process(): + self.hparams.train_steps_logger.log_stats(stats_meta=log_dct,) + + def evaluate_batch(self, batch, stage): + """ Returns accuracy on contrastive objective. 
""" + out = self.compute_forward(batch, stage=stage) + objectives = self.compute_objectives(out, batch, stage=stage) + return objectives["backprop_loss"].detach().cpu() + + def on_stage_start(self, stage, epoch): + """Gets called at the beginning of each epoch""" + if stage != sb.Stage.TRAIN: + self.acc_metric = [] + + def on_stage_end(self, stage, stage_loss, epoch=None): + + stage_stats = {"loss": stage_loss} + if stage == sb.Stage.TRAIN: + self.train_stats = stage_stats + + if stage == sb.Stage.VALID: + print(self.acc_metric) + stage_stats["accuracy"] = sum(self.acc_metric) / len( + self.acc_metric + ) + + self.hparams.train_stage_logger.log_stats( + stats_meta={ + "epoch": epoch, + "steps": self.optimizer_step, + "lr": self.optimizer.param_groups[0]["lr"], + }, + train_stats=self.train_stats, + valid_stats=stage_stats, + ) + + self.checkpointer.save_and_keep_only( + end_of_epoch=True, + num_to_keep=5, + meta={"valid_loss": stage_loss}, + ) + + +def dataio_prepare(hparams): + data_folder = hparams["data_folder"] + + train_data = sb.dataio.dataset.DynamicItemDataset.from_csv( + csv_path=hparams["train_csv"], replacements={"data_root": data_folder}, + ) + + # We remove longer and shorter files from the train. + train_data = train_data.filtered_sorted( + sort_key="duration", + key_max_value={"duration": hparams["avoid_if_longer_than"]}, + key_min_value={"duration": hparams["avoid_if_shorter_than"]}, + ) + + valid_data = sb.dataio.dataset.DynamicItemDataset.from_csv( + csv_path=hparams["valid_csv"], replacements={"data_root": data_folder}, + ) + + datasets = [train_data, valid_data] + + def get_output_lengths(input_lengths): + """ Function to get the output length of the feature extractor this is + necessery to compute the masks of wav2vec2. + """ + + def _conv_out_length(input_length, kernel_size, stride): + return torch.floor((input_length - kernel_size) / stride + 1) + + for kernel_size, stride in zip( + hparams["latentextractor_kernels"], + hparams["latentextractor_strides"], + ): + input_lengths = _conv_out_length(input_lengths, kernel_size, stride) + return input_lengths.to(torch.long) + + def get_output_lengths_w2v2_mel(input_lengths, hop_length): + """ Function to get the output length of the feature extractor this is + necessery to compute the masks of wav2vec2. + """ + + def _conv_out_length(input_length, kernel_size, stride): + return torch.floor((input_length - kernel_size) / stride + 1) + + input_lengths = torch.floor( + ((input_lengths / 16000) / (hop_length / 1000) + 1) + ).to(torch.long) + for kernel_size, stride in zip( + hparams["latentextractor_kernels"], + hparams["latentextractor_strides"], + ): + input_lengths = _conv_out_length(input_lengths, kernel_size, stride) + + return input_lengths.to(torch.long) + + def get_output_lengths_mel(input_lengths, hop_length): + """ Function to get the output length of the feature extractor this is + necessery to compute the masks of wav2vec2. + """ + + input_lengths = torch.floor( + ((input_lengths / 16000) / (hop_length / 1000) + 1) + ).to(torch.long) + + return input_lengths + + def get_output_lengths_mel_cnn(input_lengths, hop_length): + """ Function to get the output length of the feature extractor this is + necessery to compute the masks of wav2vec2. 
+ """ + # 2D CNN 32 31 + # input_lengths = torch.floor( + # ((input_lengths / 16000) / (hop_length / 1000) / 2 + 1) + # ).to(torch.long) + + # 2D CNN 32 32 + input_lengths = torch.floor( + ((input_lengths / 16000) / (hop_length / 1000) / 4 + 1) + ).to(torch.long) + return input_lengths + + @sb.utils.data_pipeline.takes("wav") + @sb.utils.data_pipeline.provides("sig") + def audio_pipeline(wav): + sig = sb.dataio.dataio.read_audio(wav) + assert sig.dim() == 1, sig.dim() + + # Audio normalization + with torch.no_grad(): + sig = F.layer_norm(sig, sig.shape) + return sig + + sb.dataio.dataset.add_dynamic_item(datasets, audio_pipeline) + sb.dataio.dataset.set_output_keys(datasets, ["id", "sig"]) + + # We create the DynamicBatch Sampler + dynamic_hparams = hparams["dynamic_batch_sampler_train"] + + train_sampler = DynamicBatchSampler( + train_data, **dynamic_hparams, length_func=lambda x: x["duration"], + ) + + # We define the custom collation function that is necessary for w2v2 to + # generate masks. + if hparams["frontend_type"] == "w2v2": + w2v_mask_collate_fn_partial = partial( + w2v_mask_collate_fn, + get_out_len_fn=get_output_lengths, + mask_prob=hparams["mask_prob"], + mask_length=hparams["mask_length"], + ) + elif hparams["frontend_type"] == "mel_cnn_base": + w2v_mask_collate_fn_partial = partial( + w2v_mask_collate_fn, + hop_length=hparams["hop_length"], + get_out_len_fn=get_output_lengths_w2v2_mel, + mask_prob=hparams["mask_prob"], + mask_length=hparams["mask_length"], + ) + + train_loader_kwargs = { + "batch_sampler": train_sampler, + "collate_fn": w2v_mask_collate_fn_partial, + "num_workers": hparams["train_dataloader_options"]["num_workers"], + "pin_memory": True, + } + + valid_loader = SaveableDataLoader( + valid_data, + collate_fn=w2v_mask_collate_fn_partial, + num_workers=hparams["test_dataloader_options"]["num_workers"], + batch_size=hparams["test_dataloader_options"]["batch_size"], + pin_memory=True, + ) + + return train_data, valid_loader, train_loader_kwargs + + +def main(): + logger.setLevel(logging.INFO) + hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:]) + + sb.utils.distributed.ddp_init_group(run_opts) + + with open(hparams_file) as fin: + hparams = load_hyperpyyaml(fin, overrides) + hparams.update(run_opts) + + sb.create_experiment_directory( + experiment_directory=hparams["output_folder"], + hyperparams_to_save=hparams_file, + overrides=overrides, + ) + + # Update precision to bf16 if the device is CPU and precision is fp16 + if run_opts.get("device") == "cpu" and hparams.get("precision") == "fp16": + hparams["precision"] = "bf16" + + # Part that matters starts here. 
+ train_dataset, valid_loader, train_loader_kwargs = dataio_prepare(hparams) + + brain = W2V2Brain( + modules=hparams["modules"], + opt_class=hparams["optimizer"], + hparams=hparams, + run_opts=run_opts, + checkpointer=hparams["checkpointer"], + ) + + brain.fit( + brain.hparams.epoch_counter, + train_dataset, + valid_loader, + train_loader_kwargs=train_loader_kwargs, + progressbar=True, + ) + + +if __name__ == "__main__": + main() diff --git a/recipes/LibriSpeech/ASR/transducer/hparams/conformer_summarymixing_transducer.yaml b/recipes/LibriSpeech/ASR/transducer/hparams/conformer_summarymixing_transducer.yaml index 162f283..7f699ce 100644 --- a/recipes/LibriSpeech/ASR/transducer/hparams/conformer_summarymixing_transducer.yaml +++ b/recipes/LibriSpeech/ASR/transducer/hparams/conformer_summarymixing_transducer.yaml @@ -268,6 +268,7 @@ Transformer: !new:speechbrain.lobes.models.transformer.TransformerASR.Transforme local_proj_hid_dim: !ref local_proj_out_dim: !ref summary_hid_dim: !ref + use_layernorm: False mode: !ref normalize_before: True causal: False diff --git a/recipes/LibriSpeech/ASR/transformer/hparams/conformer_summarymixing.yaml b/recipes/LibriSpeech/ASR/transformer/hparams/conformer_summarymixing.yaml new file mode 100644 index 0000000..01707f0 --- /dev/null +++ b/recipes/LibriSpeech/ASR/transformer/hparams/conformer_summarymixing.yaml @@ -0,0 +1,351 @@ +# ############################################################################ +# +# SummaryMixing © 2023 by Samsung Electronics is licensed under CC BY-NC 4.0. +# +# Model: E2E ASR with Transformer +# Encoder: branchformer Encoder with SummaryMixing +# Decoder: Transformer Decoder + (CTC/ATT joint) beamsearch + TransformerLM +# Tokens: unigram +# losses: CTC + KLdiv (Label Smoothing loss) +# Training: Librispeech 960h +# Authors: Titouan Parcollet, Shucong Zhang, Rogier van Dalen, and Sourav +# Bhattacharya +# ############################################################################ +# Seed needs to be set at top of yaml, before objects with parameters are made + +seed: 3407 +__set_seed: !apply:torch.manual_seed [!ref ] +output_folder: !ref results/conformer_large/ +output_wer_folder: !ref / +save_folder: !ref /save +train_log: !ref /train_log.txt + +# Language model (LM) pretraining +# NB: To avoid mismatch, the speech recognizer must be trained with the same +# tokenizer used for LM training. Here, we download everything from the +# speechbrain HuggingFace repository. However, a local path pointing to a +# directory containing the lm.ckpt and tokenizer.ckpt may also be specified +# instead. E.g if you want to use your own LM / tokenizer. +pretrained_lm_tokenizer_path: speechbrain/asr-transformer-transformerlm-librispeech + +# Data files +data_folder: !PLACEHOLDER # e.g., /path/to/LibriSpeech +# If RIRS_NOISES dir exists in /localscratch/xxx_corpus/RIRS_NOISES +# then data_folder_rirs should be /localscratch/xxx_corpus +# otherwise the dataset will automatically be downloaded +# data_folder_rirs: !ref +train_splits: ["train-clean-100", "train-clean-360", "train-other-500"] +dev_splits: ["dev-clean"] +test_splits: ["dev-clean", "test-clean", "test-other"] +skip_prep: False +train_csv: !ref /train.csv +valid_csv: !ref /dev-clean.csv +test_csv: + - !ref /dev-clean.csv + - !ref /test-clean.csv + - !ref /test-other.csv + +####################### Training Parameters #################################### + +# To make Transformers converge, the global bath size should be large enough. 
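+# (As a worked example of the formula given below, with a hypothetical setup:
+# batch_size 16 x 4 GPUs x grad_accumulation_factor 2 = 128.)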
+# The global batch size is computed as batch_size * n_gpus * grad_accumulation_factor. +# Empirically, we found that this value should be >= 128. +# Please, set your parameters accordingly. +number_of_epochs: 120 +batch_size: 16 # This works for 2x GPUs with 32GB +ctc_weight: 0.3 +grad_accumulation_factor: 1 +max_grad_norm: 5.0 +loss_reduction: 'batchmean' +sorting: random +num_workers: 4 +precision: fp32 # bf16, fp16 or fp32 +avg_checkpoints: 10 # Number of checkpoints to average for evaluation + +# stages related parameters +lr_adam: 0.0008 + +# Feature parameters +sample_rate: 16000 +n_fft: 512 +n_mels: 80 +win_length: 32 + +# This setup works well for A100 80GB GPU, adapts it to your needs. +# Or turn it off (but training speed will decrease) +dynamic_batching: True +max_batch_length_train: 500 +max_batch_length_val: 100 # we reduce it as the beam is much wider (VRAM) +num_bucket: 200 +shuffle: True # if true re-creates batches at each epoch shuffling examples. +batch_ordering: random +max_batch_ex: 256 + +dynamic_batch_sampler_train: + max_batch_length: !ref + num_buckets: !ref + shuffle: !ref + batch_ordering: !ref + max_batch_ex: !ref + +dynamic_batch_sampler_valid: + max_batch_length: !ref + num_buckets: !ref + shuffle: !ref + batch_ordering: !ref + max_batch_ex: !ref + +# Dataloader options +train_dataloader_opts: + batch_size: !ref + shuffle: True + num_workers: !ref + +valid_dataloader_opts: + batch_size: 1 + +test_dataloader_opts: + batch_size: 1 + +####################### Model Parameters ####################################### + +# Transformer +d_model: 512 +nhead: 8 +num_encoder_layers: 12 +num_decoder_layers: 6 +d_ffn: 2048 +transformer_dropout: 0.1 +activation: !name:torch.nn.GELU +output_neurons: 5000 +attention_type: SummaryMixing # SummaryMixing, regularMHA or RelPosMHAXL +mode: SummaryMixing # SummaryMixing or SummaryMixing-lite +local_proj_hid_dim: [512] +local_proj_out_dim: 512 +summary_hid_dim: [512] + +# Outputs +blank_index: 0 +label_smoothing: 0.1 +pad_index: 0 +bos_index: 1 +eos_index: 2 + +# Decoding parameters +min_decode_ratio: 0.0 +max_decode_ratio: 1.0 +valid_search_interval: 10 +valid_beam_size: 10 +test_beam_size: 66 +lm_weight: 0.66 +ctc_weight_decode: 0.40 + +############################## Models ########################################## + +CNN: !new:speechbrain.lobes.models.convolution.ConvolutionFrontEnd + input_shape: (8, 10, 80) + num_blocks: 2 + num_layers_per_block: 1 + out_channels: (64, 32) + kernel_sizes: (3, 3) + strides: (2, 2) + residuals: (False, False) + +Transformer: !new:speechbrain.lobes.models.transformer.TransformerASR.TransformerASR # yamllint disable-line rule:line-length + input_size: 640 + tgt_vocab: !ref + d_model: !ref + nhead: !ref + num_encoder_layers: !ref + num_decoder_layers: !ref + d_ffn: !ref + dropout: !ref + activation: !ref + encoder_module: conformer + attention_type: SummaryMixing + normalize_before: True + causal: False + mode: SummaryMixing # SummaryMixing or SummaryMixing-lite + local_proj_hid_dim: !ref + local_proj_out_dim: !ref + summary_hid_dim: !ref + +# This is the TransformerLM that is used according to the Huggingface repository +# Visit the HuggingFace model corresponding to the pretrained_lm_tokenizer_path +# For more details about the model! +# NB: It has to match the pre-trained TransformerLM!! 
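+# (Hypothetical command-line override for a local copy, using the standard
+# HyperPyYAML override mechanism:
+# --pretrained_lm_tokenizer_path=/path/to/folder_with_lm.ckpt_and_tokenizer.ckpt)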
+lm_model: !new:speechbrain.lobes.models.transformer.TransformerLM.TransformerLM # yamllint disable-line rule:line-length + vocab: !ref + d_model: 768 + nhead: 12 + num_encoder_layers: 12 + num_decoder_layers: 0 + d_ffn: 3072 + dropout: 0.0 + activation: !name:torch.nn.GELU + normalize_before: False + +tokenizer: !new:sentencepiece.SentencePieceProcessor + +ctc_lin: !new:speechbrain.nnet.linear.Linear + input_size: !ref + n_neurons: !ref + +seq_lin: !new:speechbrain.nnet.linear.Linear + input_size: !ref + n_neurons: !ref + +normalize: !new:speechbrain.processing.features.InputNormalization + norm_type: global + update_until_epoch: 4 + +modules: + CNN: !ref + Transformer: !ref + seq_lin: !ref + ctc_lin: !ref + normalize: !ref + +# define two optimizers here for two-stage training +Adam: !name:torch.optim.AdamW + lr: !ref + betas: (0.9, 0.98) + eps: 0.000000001 + +model: !new:torch.nn.ModuleList + - [!ref , !ref , !ref , !ref ] + +####################### Decoding & optimiser ########################### + +ctc_scorer: !new:speechbrain.decoders.scorer.CTCScorer + eos_index: !ref + blank_index: !ref + ctc_fc: !ref + + +transformerlm_scorer: !new:speechbrain.decoders.scorer.TransformerLMScorer + language_model: !ref + temperature: 1.15 + +scorer_test_search: !new:speechbrain.decoders.scorer.ScorerBuilder + full_scorers: [!ref , !ref ] + weights: + ctc: !ref + transformerlm: !ref + +scorer_valid_search: !new:speechbrain.decoders.scorer.ScorerBuilder + full_scorers: [!ref ] + weights: + ctc: !ref + +valid_search: !new:speechbrain.decoders.S2STransformerBeamSearcher + modules: [!ref , !ref ] + bos_index: !ref + eos_index: !ref + min_decode_ratio: !ref + max_decode_ratio: !ref + beam_size: !ref + using_eos_threshold: False + length_normalization: True + scorer: !ref + +test_search: !new:speechbrain.decoders.S2STransformerBeamSearcher + modules: [!ref , !ref ] + bos_index: !ref + eos_index: !ref + min_decode_ratio: !ref + max_decode_ratio: !ref + beam_size: !ref + temperature: 1.15 + using_eos_threshold: False + length_normalization: True + scorer: !ref + +log_softmax: !new:torch.nn.LogSoftmax + dim: -1 + +ctc_cost: !name:speechbrain.nnet.losses.ctc_loss + blank_index: !ref + reduction: !ref + +seq_cost: !name:speechbrain.nnet.losses.kldiv_loss + label_smoothing: !ref + reduction: !ref + +noam_annealing: !new:speechbrain.nnet.schedulers.NoamScheduler + lr_initial: !ref + n_warmup_steps: 30000 + +checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer + checkpoints_dir: !ref + recoverables: + model: !ref + noam_scheduler: !ref + normalizer: !ref + counter: !ref + +epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter + limit: !ref + +############################## Augmentations ################################### + +# Speed perturbation +speed_perturb: !new:speechbrain.augment.time_domain.SpeedPerturb + orig_freq: !ref + speeds: [95, 100, 105] + +# Time Drop +time_drop: !new:speechbrain.augment.freq_domain.SpectrogramDrop + drop_length_low: 15 + drop_length_high: 25 + drop_count_low: 4 + drop_count_high: 4 + replace: "mean" + +# Freq Drop +freq_drop: !new:speechbrain.augment.freq_domain.SpectrogramDrop + drop_length_low: 10 + drop_length_high: 20 + drop_count_low: 4 + drop_count_high: 4 + replace: "mean" + dim: 2 + +# Time warp +time_warp: !new:speechbrain.augment.freq_domain.Warping + +fea_augment: !new:speechbrain.augment.augmenter.Augmenter + min_augmentations: 3 + max_augmentations: 3 + augment_prob: 1.0 + augmentations: [ + !ref , + !ref , + !ref ] + +compute_features: 
!new:speechbrain.lobes.features.Fbank + sample_rate: !ref + n_fft: !ref + n_mels: !ref + win_length: !ref + +############################## Logging and Pretrainer ########################## + +train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger + save_file: !ref + +error_rate_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats +acc_computer: !name:speechbrain.utils.Accuracy.AccuracyStats + +# The pretrainer allows a mapping between pretrained files and instances that +# are declared in the yaml. E.g here, we will download the file lm.ckpt +# and it will be loaded into "lm" which is pointing to the defined +# before. +pretrainer: !new:speechbrain.utils.parameter_transfer.Pretrainer + collect_in: !ref + loadables: + lm: !ref + tokenizer: !ref + paths: + lm: !ref /lm.ckpt + tokenizer: !ref /tokenizer.ckpt diff --git a/recipes/VoxPopuli/ASR/transducer/hparams/conformer_summarymixing_transducer.yaml b/recipes/VoxPopuli/ASR/transducer/hparams/conformer_summarymixing_transducer.yaml index dd3a8d8..43af295 100644 --- a/recipes/VoxPopuli/ASR/transducer/hparams/conformer_summarymixing_transducer.yaml +++ b/recipes/VoxPopuli/ASR/transducer/hparams/conformer_summarymixing_transducer.yaml @@ -187,6 +187,7 @@ Transformer: !new:speechbrain.lobes.models.transformer.TransformerASR.Transforme local_proj_out_dim: !ref summary_hid_dim: !ref mode: !ref + use_layernorm: False normalize_before: True causal: False max_length: 6000 # For absolute positional encoding diff --git a/speechbrain/lobes/models/transformer/Conformer.py b/speechbrain/lobes/models/transformer/Conformer.py index 2d6be06..9a07e68 100755 --- a/speechbrain/lobes/models/transformer/Conformer.py +++ b/speechbrain/lobes/models/transformer/Conformer.py @@ -1,15 +1,22 @@ """ SummaryMixing © 2023 by Samsung Electronics is licensed under CC BY-NC 4.0. +This library connects SummaryMixing to the standard SpeechBrain lobes for Branchformer ASR. +Large parts of the code come from the SpeechBrain repository. + +Usage: Install SpeechBrain and copy this file under speechbrain/lobes/models/transformer/ +Source: https://arxiv.org/abs/2307.07421 + Authors -------- -* Jianyuan Zhong 2020 -* Samuele Cornell 2021 -* Sylvain de Langen 2023 + * Titouan Parcollet 2023 + * Shucong Zhang 2023 + * Rogier van Dalen 2023 + * Sourav Bhattacharya 2023 """ import warnings from dataclasses import dataclass from typing import List, Optional +import numpy as np import torch import torch.nn as nn @@ -372,6 +379,8 @@ class ConformerEncoderLayer(nn.Module): One of "SummaryMixing" or "SummaryMixing-lite". Changes the SummaryMixing cell according to the definition of the article. "SummaryMixing-lite" removes the local project branch. + use_layernorm: bool, optional + Using layernorm for the local and the global branch in SummaryMixing or not. Example ------- @@ -401,6 +410,7 @@ def __init__( local_proj_out_dim=512, summary_hid_dim=[1024], mode="SummaryMixing", + use_layernorm=True, ): super().__init__() @@ -442,6 +452,7 @@ def __init__( summary_out_dim=d_model, activation=activation, global_dropout=dropout, + use_layernorm=use_layernorm, mode=mode, ) self.masked_false_or_true = False @@ -679,6 +690,12 @@ class ConformerEncoder(nn.Module): One of "SummaryMixing" or "SummaryMixing-lite". Changes the SummaryMixing cell according to the definition of the article. "SummaryMixing-lite" removes the local project branch. + use_layernorm: bool, optional + Using layernorm for the local and the global branch in SummaryMixing or not. 
+ layerdrop_prob: float, optional + probability for layer dropout + output_hidden_states: bool, optional + if output the hidden states of all the layers Example @@ -710,6 +727,9 @@ def __init__( local_proj_out_dim=512, summary_hid_dim=[1024], mode="SummaryMixing", + use_layernorm: Optional[bool] = True, + layerdrop_prob=0.0, + output_hidden_states=False, ): super().__init__() @@ -730,6 +750,7 @@ def __init__( local_proj_hid_dim=local_proj_hid_dim, local_proj_out_dim=local_proj_out_dim, summary_hid_dim=summary_hid_dim, + use_layernorm=use_layernorm, mode=mode, ) for i in range(num_layers) @@ -737,6 +758,9 @@ def __init__( ) self.norm = LayerNorm(d_model, eps=1e-6) self.attention_type = attention_type + self.layerdrop_prob = layerdrop_prob + self.rng = np.random.default_rng() + self.output_hidden_states = output_hidden_states def forward( self, @@ -771,18 +795,35 @@ def forward( ) output = src + if self.layerdrop_prob > 0.0: + keep_probs = self.rng.random(len(self.layers)) + else: + keep_probs = None attention_lst = [] - for enc_layer in self.layers: - output, attention = enc_layer( - output, - src_mask=src_mask, - src_key_padding_mask=src_key_padding_mask, - pos_embs=pos_embs, - dynchunktrain_config=dynchunktrain_config, - ) - attention_lst.append(attention) + if self.output_hidden_states: + hidden_lst = [] + for i, enc_layer in enumerate(self.layers): + if ( + not self.training + or self.layerdrop_prob == 0.0 + or keep_probs[i] > self.layerdrop_prob + ): + output, attention = enc_layer( + output, + src_mask=src_mask, + src_key_padding_mask=src_key_padding_mask, + pos_embs=pos_embs, + dynchunktrain_config=dynchunktrain_config, + ) + attention_lst.append(attention) + if self.output_hidden_states: + hidden_lst.append(output) output = self.norm(output) + if self.output_hidden_states: + hidden_lst[-1] = output + return output, hidden_lst, attention_lst + return output, attention_lst def forward_streaming( diff --git a/speechbrain/lobes/models/transformer/Transformer.py b/speechbrain/lobes/models/transformer/Transformer.py index bab78d9..7dbee2b 100644 --- a/speechbrain/lobes/models/transformer/Transformer.py +++ b/speechbrain/lobes/models/transformer/Transformer.py @@ -122,6 +122,8 @@ class TransformerInterface(nn.Module): One of "SummaryMixing" or "SummaryMixing-lite". Changes the SummaryMixing cell according to the definition of the article. "SummaryMixing-lite" removes the local project branch. + use_layernorm: bool, optional + Using layernorm for the local and the global branch in SummaryMixing or not. masked_false_or_true: bool If True, masked elements will be set to True, False otherwise. """ @@ -159,6 +161,7 @@ def __init__( summary_hid_dim: Optional[list] = [1024], summary_out_dim: Optional[int] = 1024, mode: Optional[str] = "SummaryMixing", + use_layernorm: Optional[bool] = True, masked_false_or_true: Optional[bool] = True, ): super().__init__() @@ -232,6 +235,7 @@ def __init__( local_proj_hid_dim=local_proj_hid_dim, local_proj_out_dim=local_proj_out_dim, summary_hid_dim=summary_hid_dim, + use_layernorm=use_layernorm, mode=mode, ) assert normalize_before, "normalize_before must be True for Conformer" diff --git a/speechbrain/lobes/models/transformer/TransformerASR.py b/speechbrain/lobes/models/transformer/TransformerASR.py index a64b01a..266d0ad 100755 --- a/speechbrain/lobes/models/transformer/TransformerASR.py +++ b/speechbrain/lobes/models/transformer/TransformerASR.py @@ -261,6 +261,8 @@ class TransformerASR(TransformerInterface): One of "SummaryMixing" or "SummaryMixing-lite". 
Changes the SummaryMixing cell according to the definition of the article. "SummaryMixing-lite" removes the local project branch. + use_layernorm: bool, optional + Using layernorm for the local and the global branch in SummaryMixing or not. masked_false_or_true: bool If True, masked elements will be set to True, False otherwise. @@ -307,6 +309,7 @@ def __init__( summary_hid_dim: Optional[list] = [1024], summary_out_dim: Optional[int] = 1024, mode: Optional[str] = "SummaryMixing", + use_layernorm: Optional[bool] = True, masked_false_or_true: Optional[bool] = True, ): super().__init__( @@ -335,6 +338,7 @@ def __init__( summary_hid_dim=summary_hid_dim, summary_out_dim=summary_out_dim, mode=mode, + use_layernorm=use_layernorm, masked_false_or_true=masked_false_or_true, ) diff --git a/speechbrain/lobes/models/wav2vec.py b/speechbrain/lobes/models/wav2vec.py new file mode 100644 index 0000000..f3f4838 --- /dev/null +++ b/speechbrain/lobes/models/wav2vec.py @@ -0,0 +1,569 @@ +""" SummaryMixing © 2023 by Samsung Electronics is licensed under CC BY-NC 4.0. + +This library contains SummaryMixing wav2vec 2.0. for pre-training and downstream tasks + +Usage: Install SpeechBrain + Copy this file under speechbrain/lobes/models + +SummaryMixing: https://arxiv.org/abs/2307.07421 +SummaryMixing SSL: + +Authors + * Titouan Parcollet 2023, 2024 + * Shucong Zhang 2023, 2024 + * Rogier van Dalen 2023, 2024 + * Sourav Bhattacharya 2023, 2024 +""" + +import logging +import torch +import torch.nn.functional as F +import torch.nn as nn +import random +import numpy as np + +from speechbrain.lobes.models.transformer.Transformer import PositionalEncoding +from speechbrain.utils.data_utils import batch_pad_right +from speechbrain.dataio.dataio import length_to_mask +from speechbrain.lobes.models.convolution import ConvolutionFrontEnd +from speechbrain.nnet.CNN import Conv1d +from speechbrain.nnet.normalization import LayerNorm +from speechbrain.nnet.quantisers import GumbelVectorQuantizer + +from speechbrain.processing.features import InputNormalization +from speechbrain.lobes.features import Fbank + +logger = logging.getLogger() + + +class W2VLatentExtractor(nn.Module): + """Convolution based feature extractor from raw audio. + Channel numbers increasing is based on https://arxiv.org/abs/2109.06870 + Arguments + --------- + out_channels : list of ints + Out channels of convolutional layers. + kernel_sizes : list of ints + Kernels of convolutional layers. + strides : list of ints + Strides of convolutional layers. + dropout : float + Dropout of CNN. + Example + ------- + >>> extractor = W2VLatentExtractor() + >>> inputs = torch.rand(10, 5000) + >>> outputs = extractor(inputs) + >>> outputs.shape + torch.Size([10, 14, 512]) + """ + + def __init__( + self, + out_channels=[512, 512, 512, 512, 512, 512, 512], + kernel_sizes=[11, 3, 3, 3, 3, 3, 3], + strides=[5, 2, 2, 2, 2, 2, 2], + dropout=0.0, + conv_init="kaiming", + input_dim=None, + pretrained_path=None, + ): + super().__init__() + + assert len(out_channels) == len(kernel_sizes) == len(strides) + + num_blocks = len(out_channels) + self.kernel_sizes = kernel_sizes + self.strides = strides + self.out_dim = out_channels[-1] + # ! Note this does conv, norm, gelu, dropout. 
while fairseq does conv, dropout, norm, gelu + # Also fairseq layernorm is forced to fp32 + if input_dim is None: + inp_shape = ( + None, + 16000, + ) + else: + inp_shape = (None, 16000, input_dim) + self.extractor = ConvolutionFrontEnd( + inp_shape, + num_blocks=num_blocks, + num_layers_per_block=1, + out_channels=out_channels, + kernel_sizes=kernel_sizes, + strides=strides, + dilations=[1] * num_blocks, + residuals=[False] * num_blocks, + conv_module=Conv1d, + activation=nn.GELU, + norm=LayerNorm, + dropout=dropout, + conv_bias=False, + padding="valid", + conv_init=conv_init, + ) + self.norm = nn.LayerNorm(out_channels[-1]) + + if pretrained_path: + ckpt = torch.load(pretrained_path) + self.load_state_dict(ckpt) + + def forward(self, x, normalize_signal=True): + """ Calculates latents from audio input. + """ + if normalize_signal: + x = F.layer_norm(x, x.shape[1:]) + + latents = self.extractor(x) + return self.norm(latents) + + def get_output_lengths(self, input_lengths: torch.LongTensor): + """ Calculates output lengths for given input lengths. """ + + def _conv_out_length(input_length, kernel_size, stride): + return torch.floor((input_length - kernel_size) / stride + 1) + + for kernel_size, stride in zip(self.kernel_sizes, self.strides): + input_lengths = _conv_out_length(input_lengths, kernel_size, stride) + return input_lengths.to(torch.long) + + +class W2VTargetQuantiser(nn.Module): + """ Wraps ``nnet.quantiser.GumbelVectorQuantizer``, see for documentation on + arguments. + + Example + ------- + >>> quantiser = W2VTargetQuantiser() + >>> inputs = torch.rand(10, 12, 512) + >>> output, meta = quantiser(inputs) + >>> output.shape + torch.Size([10, 12, 256]) + """ + + def __init__( + self, + in_dim=512, + out_dim=256, + quantiser=GumbelVectorQuantizer, + num_vars=320, + temperature_decay=(2.0, 0.25, 0.999995,), + ): + super().__init__() + self.quantiser = quantiser( + in_dim, num_vars, temperature_decay, 2, out_dim + ) + self.proj = nn.Linear(out_dim, out_dim) + + def forward(self, x): + """ Returns quantised targets plus meta information. """ + x = self.quantiser(x) + targets = self.proj(x["x"]) + code_perplex = x["code_perplexity"] + prob_perplex = x["prob_perplex"] + num_vars = x["num_vars"] + temp = x["temp"] + diversity_loss = (num_vars - prob_perplex) / num_vars + meta = { + "diversity_loss": diversity_loss, + "code_perplex": code_perplex, + "prob_perplex": prob_perplex, + "num_vars": num_vars, + "temp": temp, + } + return targets, meta + + +class EncoderWrapper(nn.Module): + """A wrapper that adds positional information, + masks the input and then runs the latent encoder. + Arguments + --------- + in_dim : int + Last dimension of input tensor. + embedding_dim : int + Dimension to project input to and that the latent encoder will use. + latent_encoder : torch.nn.module + Initialized latent encoder object. + positional_encoding : torch.nn.module + Uninitialized nn.module for adding positional information, will use ``embedding_dim``. + dropout_encoder_input : float + Dropout on encoder input. 
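Beyond the constructor arguments listed here, the class description above says the wrapper adds positional information, masks the input with the learned mask embedding, and then runs the latent encoder. The snippet below is a rough, generic sketch of that masking step as it is usually done in wav2vec 2.0 style models; it is not the verbatim forward pass of this class, and the shapes and mask rate are illustrative.

```python
import torch

B, T, C = 4, 50, 768                      # batch, frames, embedding_dim
latents = torch.randn(B, T, C)            # output of the latent extractor
mask = torch.rand(B, T) < 0.15            # True where a frame is masked
mask_emb = torch.nn.Parameter(torch.FloatTensor(C).uniform_())

masked = latents.clone()
masked[mask] = mask_emb.to(masked.dtype)  # replace masked frames by the learned embedding
# `masked` is then combined with positional encodings and fed to the encoder;
# the SSL objective is computed on the masked positions.
```

In this file, the boolean mask itself is produced by `w2v_mask_collate_fn` further below.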
+ + Example + ------- + >>> from speechbrain.lobes.models.transformer.Transformer import TransformerEncoder + >>> encoder = TransformerEncoder(d_model=768, num_layers=4, nhead=4, d_ffn=1024) + >>> wrapper = EncoderWrapper(1024, 768, encoder) + >>> inputs = torch.rand(10, 12, 1024) + >>> outputs = wrapper(inputs) + >>> outputs["embeddings"].shape + torch.Size([10, 12, 768]) + """ + + def __init__( + self, + in_dim, + embedding_dim, + latent_encoder, + positional_encoding=PositionalEncoding, + dropout_encoder_input=0.05, + output_hidden_states=False, + pretrained_path=None, + ): + super().__init__() + self.input_projector = nn.Linear(in_dim, embedding_dim) + self.latent_encoder = latent_encoder + self.positional_encoding = positional_encoding( + embedding_dim, max_len=3500 + ) + self.dropout_encoder_input = nn.Dropout(dropout_encoder_input) + self.mask_emb = nn.Parameter( + torch.FloatTensor(embedding_dim).uniform_(), requires_grad=True + ) + self.output_hidden_states = output_hidden_states + if pretrained_path: + ckpt = torch.load(pretrained_path) + self.load_state_dict(ckpt) + + def forward( + self, latents, wav_lens=None, padding_mask=None, mask=None, + ): + """ + Arguments + --------- + latents : torch.Tensor, shape (B, T, C) + Batch of latent representations (AKA frames) output from latent extractor. + wav_lens : torch.Tensor, shape (B,) + The actual (unpadded) relative lengths for each sample of the batch (0 + neg_indcs[neg_indcs >= targets] += 1 + + neg_indcs = neg_indcs + torch.arange(B).unsqueeze(1) * high + y = y.view(-1, C) + negs = y[neg_indcs.view(-1)] + negs = negs.view(B, T, num_neg, C).permute(2, 0, 1, 3) # to N, B, T, C + return negs + + +def w2v_mask_collate_fn(samples_lst, get_out_len_fn, mask_prob, mask_length, hop_length=None): + """ This creates a batch from a list of samples and also creates + the boolean mask that will be used to mask the inputs of the latent + encoder. To create the mask we need to know the output shape after the + latent extractor, therefore the argument `get_out_len_fn`. + One could also create masks per sample (when loading the audio file) and + then collate them but at that time one doesn't know the length of the + shortest sample in the batch (which determines the number of masked frames) + so it's better this way. + + Arguments + --------- + samples_lst : list + List of samples returned by the audio_pipeline. + get_out_len_fn : function + Function that calculates length of sample after it passes through feature extractor. + mask_prob : float + Approximate percentage of frames to mask. + mask_length : int + Number of contiguous frames that will be masked. + + Returns + ------- + wavs_padded : torch.Tensor, shape (B, T) + Audio arrays with right-sided padding. + wav_lens : torch.Tensor, shape (B,) + For each sample the percentage of the array that is not padding. + mask : torch.Tensor, shape (B, T) + Boolean mask to mask frames. 
+ """ + wav_lst, latent_length_lst = [], [] + ids = [] + + for sample in samples_lst: + ids.append(sample["id"]) + sig = sample["sig"] + wav_lst.append(sig) + if hop_length is not None: + latent_length = get_out_len_fn( + torch.as_tensor(sig.size(-1)), hop_length + ) + else: + latent_length = get_out_len_fn(torch.as_tensor(sig.size(-1))) + latent_length_lst.append(latent_length.item()) + bs = len(wav_lst) + wavs_padded, wav_lens = batch_pad_right(wav_lst) + + batch_time_len = max(latent_length_lst) + mask = compute_mask( + (bs, batch_time_len,), latent_length_lst, mask_prob, mask_length + ) + return ( + torch.as_tensor(wavs_padded), + torch.as_tensor(wav_lens), + torch.as_tensor(mask, dtype=torch.bool), + ) + +class WeightedSSLModel(torch.nn.Module): + """This lobe enables the integration of use of weighted sum representations + from different layers in a SSL encoder. + + The model can be used as a fixed feature extractor for SSL benchmarking. It + will download automatically the model from HuggingFace or use a local path. + + More details in recipes/SSL_benchmark + + Arguments + --------- + hub : str + HuggingFace hub name: e.g "facebook/wav2vec2-large-lv60" + num_layers: int + Number of internal layers: e.g 13 for "Base" models. + layernorm: bool + Whether layer representations should be layernormed before sum + Example + ------- + >>> inputs = torch.rand([10, 600]) + >>> model_hub = "facebook/wav2vec2-base-960h" + >>> num_layers = 13 + >>> model = WeightedSSLModel(model_hub, num_layers) + >>> outputs = model(inputs) + """ + + def __init__( + self, + num_layers, + pretrained_path, + latent_encoder, + CNN, + dropout=0.0, + conv_init="kaiming", + in_dim=512, + embedding_dim=768, + positional_encoding=PositionalEncoding, + dropout_encoder_input=0.0, + output_hidden_states=True, + layernorm=False, + full_tune=False, + sample_rate=16000, + n_fft=400, + n_mels=80, + hop_length=10, + win_length=25,): + super().__init__() + + self.compute_features = Fbank( + sample_rate=16000, + n_fft=400, + n_mels=80, + hop_length=10, + win_length=25, + ) + self.inputnorm = InputNormalization(norm_type="sentence") + self.latent_extractor = CNN + + self.encoder_wrapper = EncoderWrapper( + in_dim, + embedding_dim, + latent_encoder, + positional_encoding, + dropout_encoder_input, + output_hidden_states, + ) + + + + latent_extractor_path = f"{pretrained_path}/CNN.ckpt" + latent_encoder_path = f"{pretrained_path}/latent_encoder.ckpt" + + latent_extractor_ckpt = torch.load(latent_extractor_path) + latent_encoder_ckpt = torch.load(latent_encoder_path) + + self.latent_extractor.load_state_dict(latent_extractor_ckpt) + self.encoder_wrapper.load_state_dict(latent_encoder_ckpt) + self.output_hidden_states = output_hidden_states + + self.num_layers = num_layers + # Initializing the learnable weights + zero_init = torch.cat([torch.zeros(self.num_layers)]) + self.weights = torch.nn.Parameter(zero_init, requires_grad=True) + self.layernorm = layernorm + self.full_tune = full_tune + + def forward(self, wav, wav_lens=None): + """This method outputs a weighted sum of the layers representations of the SSL encoder + Arguments + --------- + wav : tensor + The wavs + """ + # SB mel + if not self.full_tune: + with torch.no_grad(): + latents = self.compute_features(wav) + latents = self.inputnorm(latents, wav_lens).detach() + latents = self.latent_extractor(latents) + latents = latents.view(latents.shape[0], latents.shape[1], -1) + + feats = self.encoder_wrapper(latents, wav_lens=wav_lens)[ + "embeddings" + ] + + hidden_states = 
torch.stack(feats, dim=0).detach() + else: + with torch.no_grad(): + latents = self.compute_features(wav) + latents = self.inputnorm(latents, wav_lens).detach() + latents = self.latent_extractor(latents) + latents = latents.view(latents.shape[0], latents.shape[1], -1) + + feats = self.encoder_wrapper(latents, wav_lens=wav_lens)[ + "embeddings" + ] + + hidden_states = torch.stack(feats, dim=0) + + + # First dimension should be equal to the number of layers in the hparams + assert ( + self.num_layers == hidden_states.shape[0] + ), f"Num layers {self.num_layers} not equal to num hidden states {hidden_states.shape[0]}" + norm_weights = torch.nn.functional.softmax(self.weights, dim=-1) + # Layernorming the layers representations if asked + if self.layernorm: + hidden_states = [ + F.layer_norm(t, (t.shape[-1],)) for t in hidden_states + ] + # Summing the weighted layers + weighted_feats = hidden_states[0] * norm_weights[0] + for i in range(1, len(hidden_states)): + weighted_feats += hidden_states[i] * norm_weights[i] + + return weighted_feats \ No newline at end of file diff --git a/speechbrain/nnet/summary_mixing.py b/speechbrain/nnet/summary_mixing.py index 6500bca..9f0c9c9 100644 --- a/speechbrain/nnet/summary_mixing.py +++ b/speechbrain/nnet/summary_mixing.py @@ -62,6 +62,8 @@ class SummaryMixing(nn.Module): according to the definition of the article. "SummaryMixing-lite" removes the local project branch. "SummaryMixing-expdecay" is another alternative using an exponential decay for the window, it's slower. + use_layernorm: bool, optional + Using layernorm for the local and the global branch in SummaryMixing or not. Example @@ -84,6 +86,7 @@ def __init__( activation: Optional[nn.Module] = nn.GELU, global_dropout: Optional[float] = 0.1, mode: Optional[str] = "SummaryMixing", + use_layernorm: Optional[bool] = True, ): super(SummaryMixing, self).__init__() @@ -107,6 +110,7 @@ def __init__( self.local_dnn_blocks = local_proj_hid_dim + [local_proj_out_dim] self.summary_dnn_blocks = summary_hid_dim + [summary_out_dim] self.mode = mode + self.use_layernorm = use_layernorm self.dropout = nn.Dropout(global_dropout) if self.mode == "SummaryMixing" or self.mode == "SummaryMixing-expdecay": @@ -156,6 +160,10 @@ def __init__( data=torch.tensor(0.995), requires_grad=False ) + if self.use_layernorm: + self.local_norm = nn.LayerNorm(local_proj_out_dim) + self.summary_norm = nn.LayerNorm(summary_out_dim) + self.apply(self._init_parameters) def forward(self, x, sum_mask=None, src_padding_mask=None): @@ -206,6 +214,9 @@ def _forward_mixing(self, x, sum_mask, src_padding_mask): # f() (Eq. 1b) local_summary = self.local_proj(x) * src_padding_mask + if self.use_layernorm: + local_summary = self.local_norm(local_summary) + # s() (Eq. 
2 and 1c)
+        time_summary = self.summary_proj(x) * src_padding_mask
+
@@ -234,6 +245,9 @@ def _forward_mixing(self, x, sum_mask, src_padding_mask):
             sum_mask, dim=1
         ).unsqueeze(-1)
 
+        if self.use_layernorm:
+            time_summary = self.summary_norm(time_summary)
+
         return self.summary_local_merging(
             self.dropout(torch.cat([local_summary, time_summary], dim=-1))
         )

From e77fe17c99ce4aa5763340498194ce0b1fd7f64d Mon Sep 17 00:00:00 2001
From: shucongzhang <104781888+shucongzhang@users.noreply.github.com>
Date: Fri, 30 Aug 2024 13:19:43 +0100
Subject: [PATCH 2/2] Update README.md description of SSL, streaming and
 Conformer SB 1.0 results

---
 README.md | 35 +++++++++++++++++++++++++++++++----
 1 file changed, 31 insertions(+), 4 deletions(-)

diff --git a/README.md b/README.md
index bce2557..5123bce 100644
--- a/README.md
+++ b/README.md
@@ -1,13 +1,30 @@
 # SummaryMixing for SpeechBrain v1.0
 
 *Halve your VRAM requirements and train 30% faster any speech model achieving equivalents or better Word Error Rates and SLU accuracies with SummaryMixing Conformers and Branchformers.*
 
+*Reduce your self-supervised learning (SSL) pre-training time and VRAM requirements by 20%-30% with equivalent or better downstream performance on speech processing tasks.*
+
+## In brief
+SummaryMixing is the first alternative to MHSA able to beat it on speech tasks while reducing its complexity significantly (from quadratic to linear).
+
+This repository implements SummaryMixing, a simpler, faster and much cheaper replacement to self-attention in Conformers and Branchformers for automatic speech recognition, keyword spotting and intent classification (see: the [publication](https://arxiv.org/abs/2307.07421) for further details).
+
+This repository also implements SummaryMixing for SSL pre-training (see: the [publication](https://arxiv.org/pdf/2407.13377) for further details) and a streaming transducer.
+
+The code is fully compatible with the [SpeechBrain](https://speechbrain.github.io/) toolkit -- copy and paste is all you need to start using SummaryMixing in your setup.
+
 ## !! A word about using SummaryMixing with SpeechBrain V1.0 !!
 
-The main branch of this repository will keep tracking the latest version of SpeechBrain available. Unfortunately the results reported in our [publication](https://arxiv.org/abs/2307.07421) and bellow in the Table were obtained with SpeechBrain v0.5 and may not be exactly reproduced with the current code. If you want the exact same results, please use our dedicated
+The main branch of this repository will keep tracking the latest version of SpeechBrain available. The SSL results in our [publication](https://arxiv.org/pdf/2407.13377) and the streaming transducer results were obtained with SpeechBrain v1.0. For the Conformer attention-CTC models with SpeechBrain v1.0, below are the results:
+
+| Encoder | Variant | Dev-clean | Test-clean | Test-other |
+|------------------|----------------------|--------------------|---------------------|---------------------|
+| | | **WER \%** | **WER \%** | **WER \%** |
+| Conformer | Self-attention | 1.9 | 2.0 | 4.6 |
+| Conformer | SummaryMixing | 1.9 | 2.0 | 4.6 |
+
+Unfortunately the results reported in our [publication](https://arxiv.org/abs/2307.07421) and below in the Table were obtained with SpeechBrain v0.5 and may not be exactly reproduced with the current code.
If you want the exact same results, please use our dedicated [branch](https://github.com/SamsungLabs/SummaryMixing/tree/speechbrain_v0.5) that contains the code compatible with SpeechBrain v0.5! -## In brief -This repository implements SummaryMixing, a simpler, faster and much cheaper replacement to self-attention in Conformers and Branchformers for automatic speech recognition, keyword spotting and intent classification (see: the [publication](https://arxiv.org/abs/2307.07421) for further details). The code is fully compatible with the [SpeechBrain](https://speechbrain.github.io/) toolkit with version 0.5 -- copy and paste is all you need to start using SummaryMixing in your setup. If you wish to run with the latest version of SpeechBrain (v1.0+), please go to the main branch of this repository. SummaryMixing is the first alternative to MHSA able to beat it on speech tasks while reducing its complexity significantly (from quadratic to linear). ## A glance at SummaryMixing @@ -44,13 +61,23 @@ Please cite SummaryMixing as follows: ```bibtex @misc{summarymixing, title={{SummaryMixing}: A Linear-Complexity Alternative to Self-Attention for Speech Recognition and Understanding}, - author={Titouan Parcollet and Rogier van Dalen and and Shucong Zhang and Sourav Bhattacharya}, + author={Titouan Parcollet and Rogier van Dalen and Shucong Zhang and Sourav Bhattacharya}, year={2023}, eprint={2307.07421}, archivePrefix={arXiv}, primaryClass={eess.AS}, note={arXiv:2307.07421} } + +@misc{linear_ssl, + title={Linear-Complexity Self-Supervised Learning for Speech Processing}, + author={Shucong Zhang and Titouan Parcollet and Rogier van Dalen and Sourav Bhattacharya}, + year={2024}, + eprint={2407.13377}, + archivePrefix={arXiv}, + primaryClass={eess.AS}, + note={arXiv:2407.13377} +} ``` ## Licence