From 144eee0b3281ef77772a6d1f683b64e8e00ccd83 Mon Sep 17 00:00:00 2001 From: "s1.zhang@samsung.com" Date: Thu, 20 Jun 2024 17:16:30 +0000 Subject: [PATCH] scripts for SummaryMixing SSL --- README.md | 44 +- .../MP3S/CommonVoice/LSTM/hparams/ssl_cy.yaml | 195 +++ .../MP3S/CommonVoice/LSTM/hparams/ssl_eu.yaml | 195 +++ benchmarks/MP3S/CommonVoice/LSTM/train.py | 305 +++++ .../MP3S/LibriSpeech/LSTM/hparams/ssl.yaml | 188 +++ benchmarks/MP3S/LibriSpeech/LSTM/train.py | 344 ++++++ .../LibriSpeech/contextnet/hparam/ssl.yaml | 184 +++ .../MP3S/LibriSpeech/contextnet/train.py | 342 ++++++ .../MP3S/SLURP/LSTM_linear/hparams/ssl.yaml | 201 ++++ benchmarks/MP3S/SLURP/LSTM_linear/train.py | 318 +++++ .../VoxCeleb1/ecapa_tdnn/hparams/ssl.yaml | 195 +++ benchmarks/MP3S/VoxCeleb1/ecapa_tdnn/train.py | 468 ++++++++ .../hparams/branchformer_summarymixing.yaml | 306 ----- .../hparams/branchformer_summarymixing.yaml | 277 ----- .../hparams/summarymixing_wav2vec2.yaml | 180 +++ .../wav2vec2/make_librilight_csv.py | 90 ++ .../wav2vec2/train_sb_wav2vec2_mel.py | 478 ++++++++ .../hparams/branchformer_summarymixing.yaml | 359 ------ speechbrain/lobes/models/VanillaNN.py | 8 +- .../lobes/models/transformer/Branchformer.py | 489 -------- .../lobes/models/transformer/Conformer.py | 136 ++- .../lobes/models/transformer/Transformer.py | 1044 ----------------- .../models/transformer/TransformerASR.py | 665 ----------- speechbrain/lobes/models/wav2vec.py | 569 +++++++++ speechbrain/nnet/summary_mixing.py | 276 ++++- 25 files changed, 4583 insertions(+), 3273 deletions(-) create mode 100644 benchmarks/MP3S/CommonVoice/LSTM/hparams/ssl_cy.yaml create mode 100644 benchmarks/MP3S/CommonVoice/LSTM/hparams/ssl_eu.yaml create mode 100644 benchmarks/MP3S/CommonVoice/LSTM/train.py create mode 100644 benchmarks/MP3S/LibriSpeech/LSTM/hparams/ssl.yaml create mode 100644 benchmarks/MP3S/LibriSpeech/LSTM/train.py create mode 100644 benchmarks/MP3S/LibriSpeech/contextnet/hparam/ssl.yaml create mode 100644 benchmarks/MP3S/LibriSpeech/contextnet/train.py create mode 100644 benchmarks/MP3S/SLURP/LSTM_linear/hparams/ssl.yaml create mode 100644 benchmarks/MP3S/SLURP/LSTM_linear/train.py create mode 100644 benchmarks/MP3S/VoxCeleb1/ecapa_tdnn/hparams/ssl.yaml create mode 100644 benchmarks/MP3S/VoxCeleb1/ecapa_tdnn/train.py delete mode 100644 recipes/AISHELL-1/ASR/transformer/hparams/branchformer_summarymixing.yaml delete mode 100644 recipes/CommonVoice/ASR/transformer/hparams/branchformer_summarymixing.yaml create mode 100644 recipes/Libri-Light/self-supervised-learning/wav2vec2/hparams/summarymixing_wav2vec2.yaml create mode 100644 recipes/Libri-Light/self-supervised-learning/wav2vec2/make_librilight_csv.py create mode 100644 recipes/Libri-Light/self-supervised-learning/wav2vec2/train_sb_wav2vec2_mel.py delete mode 100644 recipes/LibriSpeech/ASR/transformer/hparams/branchformer_summarymixing.yaml delete mode 100644 speechbrain/lobes/models/transformer/Branchformer.py delete mode 100644 speechbrain/lobes/models/transformer/Transformer.py delete mode 100755 speechbrain/lobes/models/transformer/TransformerASR.py create mode 100644 speechbrain/lobes/models/wav2vec.py diff --git a/README.md b/README.md index bce2557..7a54638 100644 --- a/README.md +++ b/README.md @@ -1,41 +1,27 @@ -# SummaryMixing for SpeechBrain v1.0 -*Halve your VRAM requirements and train 30% faster any speech model achieving equivalents or better Word Error Rates and SLU accuracies with SummaryMixing Conformers and Branchformers.* - -## !! 
A word about using SummaryMixing with SpeechBrain V1.0 !!
-
-The main branch of this repository will keep tracking the latest version of SpeechBrain available. Unfortunately the results reported in our [publication](https://arxiv.org/abs/2307.07421) and bellow in the Table were obtained with SpeechBrain v0.5 and may not be exactly reproduced with the current code. If you want the exact same results, please use our dedicated
-[branch](https://github.com/SamsungLabs/SummaryMixing/tree/speechbrain_v0.5) that contains the code compatible with SpeechBrain v0.5!
+# SummaryMixing wav2vec 2.0
+We equip wav2vec 2.0 (w2v2) with SummaryMixing, our linear-time alternative to quadratic-cost self-attention. Compared to self-attention-based w2v2, SummaryMixing-based w2v2 greatly reduces the cost of self-supervised pre-training and gives better or equivalent performance on downstream tasks.

 ## In brief
-This repository implements SummaryMixing, a simpler, faster and much cheaper replacement to self-attention in Conformers and Branchformers for automatic speech recognition, keyword spotting and intent classification (see: the [publication](https://arxiv.org/abs/2307.07421) for further details). The code is fully compatible with the [SpeechBrain](https://speechbrain.github.io/) toolkit with version 0.5 -- copy and paste is all you need to start using SummaryMixing in your setup. If you wish to run with the latest version of SpeechBrain (v1.0+), please go to the main branch of this repository. SummaryMixing is the first alternative to MHSA able to beat it on speech tasks while reducing its complexity significantly (from quadratic to linear).
+This repository implements SummaryMixing w2v2. The code is fully compatible with the [SpeechBrain](https://speechbrain.github.io/) toolkit -- copy and paste is all you need to start using SummaryMixing in your setup.

 ## A glance at SummaryMixing

 SummaryMixing is a linear-time alternative to self-attention (SA) for speech processing models such as Transformers, Conformers or Branchformers. Instead of computing pair-wise scores between tokens (leading to quadratic-time complexity for SA), it summarises a whole utterance with a mean over vectors for all time steps. SummaryMixing is based on the recent [findings](https://arxiv.org/pdf/2207.02971.pdf) demonstrating that self-attention could be useless for speech recognition, as the attention weights of trained ASR systems are almost uniformly distributed across the tokens composing a sequence. SummaryMixing is also a generalisation of the recent [HyperMixer](https://arxiv.org/abs/2203.03691) and [HyperConformer](https://arxiv.org/abs/2305.18281) to better and simpler mixing functions. In a SummaryMixing cell, which takes the same inputs and produces the same outputs as self-attention, contributions from each time step are first transformed and then averaged globally before being fed back to each time step. This is visible in Figure 1 of the [article](https://arxiv.org/abs/2307.07421). Therefore, the time complexity is reduced to linear.
+
+In this branch, we use SummaryMixing for self-supervised learning by equipping w2v2 with it. For a detailed description, please refer to this [article]().
+
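To make the mechanism described above concrete, here is a minimal, illustrative PyTorch sketch of a SummaryMixing-style cell. It is an editor's sketch based on the description in this section, not the implementation shipped in this patch (see `speechbrain/nnet/summary_mixing.py` in the diff); the class name, hidden sizes and masking details are hypothetical.

```python
import torch
import torch.nn as nn


class SummaryMixingSketch(nn.Module):
    """Linear-time token mixing: transform each frame, average over time,
    and feed the per-utterance summary back to every frame."""

    def __init__(self, d_model: int, hidden: int = 768):
        super().__init__()
        # local per-time-step branch f(x_t)
        self.local = nn.Sequential(
            nn.Linear(d_model, hidden), nn.GELU(), nn.Linear(hidden, d_model)
        )
        # summary branch s(x_t), averaged over the whole utterance
        self.summary = nn.Sequential(
            nn.Linear(d_model, hidden), nn.GELU(), nn.Linear(hidden, d_model)
        )
        # combiner applied to the concatenation [f(x_t) ; mean_t s(x_t)]
        self.combine = nn.Sequential(nn.Linear(2 * d_model, d_model), nn.GELU())

    def forward(self, x: torch.Tensor, rel_lens: torch.Tensor = None) -> torch.Tensor:
        # x: (batch, time, d_model); rel_lens: relative lengths in (0, 1]
        local = self.local(x)
        summ = self.summary(x)
        if rel_lens is not None:
            # ignore padded frames when averaging
            T = x.shape[1]
            mask = (
                torch.arange(T, device=x.device)[None, :] < (rel_lens * T)[:, None]
            ).unsqueeze(-1)
            mean = (summ * mask).sum(1, keepdim=True) / mask.sum(1, keepdim=True)
        else:
            mean = summ.mean(1, keepdim=True)  # one summary vector per utterance
        # broadcast the summary back to every time step: cost is linear in time
        return self.combine(torch.cat([local, mean.expand_as(local)], dim=-1))
```

With an input of shape `(batch, time, d_model)` the output has the same shape, so a cell like this can stand in for a self-attention block inside a Conformer or Branchformer layer; this is how the `attention_type: SummaryMixing` option in the hparams files below is meant to be read.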
 ### A few results
-A SummaryMixing-equipped Conformer outperforms a self-attention equivalent model on Librispeech test-clean (2.1% vs 2.3%) and test-other (5.1% vs 5.4%). This is done with a 30% training reduction as well as less than half of the memory budget (from 46GB to 21GB). Such gains are also visible with CommonVoice, AISHELL-1 and Tedlium2. This gain is also visible at decoding time as the real-time factor remains stable (does not increase) with the sentence length for a SummaryMixing Branchformer while the same model with self-attention would see its RTF following a quadratic increase. The SpeechBrain configuration files in this repository can reproduce these numbers.
-
-The following Table gives an idea of the results observed with Librispeech. More results on CommonVoice, AISHELL, Tedlium, SLURP, and Google Speech Command are available in the [article](https://arxiv.org/abs/2307.07421).
-| Encoder | Variant | Dev-clean | Test-clean | Test-other | GPU | VRAM |
-|------------------|----------------------|--------------------|---------------------|---------------------|----------------|---------------|
-| | | **WER \%** | **WER \%** | **WER \%** | **hours** | **GB** |
-| ContextNet | N.A. | 3.3 | 2.3 | 5.9 | 160 | 25 |
-| Transformer | Self-attention | 3.3 | 2.3 | 5.5 | 129 | 40 |
-| Conformer | Self-attention | 2.8 | 2.3 | 5.4 | 137 | 46 |
-| Branchformer | Self-attention | 2.9 | 2.2 | 5.1 | 132 | 45 |
-| | CNN Only | 3.1 | 2.4 | 5.7 | 83 | 22 |
-| | HyperMixer | 3.1 | 2.3 | 5.6 | 126 | 30 |
-| | FastFormer | 3.0 | 2.2 | 5.4 | 96 | 23 |
-| | **Proposed** |
-| Conformer | SummaryMixing | 2.8 | 2.1 | 5.1 | 98 | 21 |
-| Branchformers | SummaryMixing-lite | 3.0 | 2.2 | 5.2 | 98 | 23 |
-| | SummaryMixing | 2.9 | 2.2 | 5.1 | 105 | 26 |
-| | +Summary Decoder | 3.1 | 2.3 | 5.3 | 104 | 26 |
-
-
-RTF performance
+In the experiments of the [article](), SummaryMixing-equipped w2v2 reduces the pre-training time and memory budget by 18% and 23%, respectively, with better or equivalent results on the downstream tasks of automatic speech recognition, intent classification, emotion recognition, and automatic speaker verification. The following table gives the results of SummaryMixing-based and attention-based SSL models on CommonVoice Welsh ASR and SLURP intent classification. For the results of other downstream tasks, please refer to the [article](). The SpeechBrain configuration files in this repository can reproduce these numbers.
+
+
+| Context Encoder | Size | Pre-trained on    | CommonVoice Welsh 15.8 WER (%) | SLURP Intent Classification Acc. (%) |
+|-----------------|------|-------------------|--------------------------------|--------------------------------------|
+| Self-attention  | 166M | LibriLight 4.3k h | 50.8                           | 78.1                                 |
+| SummaryMixing   | 155M | LibriLight 4.3k h | 48.3                           | 80.5                                 |
+|-----------------|------|-------------------|--------------------------------|--------------------------------------|
+| w2v2 base       | 95M  | LibriSpeech 960 h | 54.5                           | 77.7                                 |
+| w2v2 large      | 317M | LibriLight 60k h  | 45.4                           | 79.0                                 |

 ## Citation

diff --git a/benchmarks/MP3S/CommonVoice/LSTM/hparams/ssl_cy.yaml b/benchmarks/MP3S/CommonVoice/LSTM/hparams/ssl_cy.yaml
new file mode 100644
index 0000000..29613a2
--- /dev/null
+++ b/benchmarks/MP3S/CommonVoice/LSTM/hparams/ssl_cy.yaml
@@ -0,0 +1,195 @@
+# ################################
+# SummaryMixing © 2023 by Samsung Electronics is licensed under CC BY-NC 4.0.
+ +# This is configurations of SummaryMixing wav2vec 2.0 downstream ASR on CommonVoice cy with a LSTM downstream model + +# Usage: Install SpeechBrain MP3S +# Create a folder benchmarks/MP3S/CommonVoice/LSTM +# Copy this file under benchmarks/MP3S/CommonVoice/LSTM/hparams +# SummaryMixing: https://arxiv.org/abs/2307.07421 +# SummaryMixing SSL: + +# Authors +# * Titouan Parcollet 2023, 2024 +# * Shucong Zhang 2023, 2024 +# * Rogier van Dalen 2023, 2024 +# * Sourav Bhattacharya 2023, 2024 +# ################################ + +# Seed needs to be set at top of yaml, before objects with parameters are made +seed: 1986 +__set_seed: !apply:torch.manual_seed [!ref ] +language: cy # use 'it' for Italian, 'rw' for Kinyarwanda, 'en' for english +output_folder: !ref results/CommonVoice// +wer_file: !ref /wer.txt +save_folder: !ref /save +train_log: !ref /train_log.txt + +# Data files +data_folder: !PLACEHOLDER # e.g, /local/cv-corpus-11.0-2022-09-21/ +train_tsv_file: !ref /train.tsv # Standard CommonVoice .tsv files +dev_tsv_file: !ref /dev.tsv # Standard CommonVoice .tsv files +test_tsv_file: !ref /test.tsv # Standard CommonVoice .tsv files +accented_letters: True +train_csv: !ref /train.csv +valid_csv: !ref /dev.csv +test_csv: !ref /test.csv +skip_prep: False # Skip data preparation + +avoid_if_longer_than: 10.0 + +num_layers_ssl: 13 #Number of layers in the SSL model (should be 25 for large ) +pretrained_path: !PLACEHOLDER # e,g./path/to/pre-trained_SummaryMixing_w2v2 +encoder_dim: 768 + +# Training parameters +number_of_epochs: 20 +lr: 0.0004 +lr_weights: 0.02 +sorting: ascending +auto_mix_prec: False +sample_rate: 16000 +token_type: bpe # ["unigram", "bpe", "char"] +character_coverage: 1.0 + + +# With data_parallel batch_size is split into N jobs +# With DDP batch_size is multiplied by N jobs +# Must be 3 per GPU to fit 32GB of VRAM +batch_size: 4 +test_batch_size: 4 + +# Dataloader options +train_dataloader_opts: + batch_size: !ref +dataloader_options: + batch_size: !ref + num_workers: 4 +test_dataloader_options: + batch_size: !ref + num_workers: 4 + + +valid_dataloader_opts: + batch_size: !ref + +# Model parameters +activation: !name:torch.nn.Sigmoid +dnn_layers: 1 +dnn_neurons: 768 +freeze_encoder: True + +# Outputs +output_neurons: 100 # BPE size, index(blank/eos/bos) = 0 + + +# Functions and classes +# +epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter + limit: !ref + +augmentation: !new:speechbrain.lobes.augment.TimeDomainSpecAugment + sample_rate: !ref + speeds: [95, 100, 105] + +encoder: !new:speechbrain.lobes.models.transformer.Conformer.ConformerEncoder + d_model: 768 + num_layers: 12 + nhead: 8 + d_ffn: 3072 + dropout: 0.1 + layerdrop_prob: 0.0 + attention_type: SummaryMixing + local_proj_hid_dim: [768] + local_proj_out_dim: 768 + summary_hid_dim: [768] + mode: SummaryMixing + output_hidden_states: True + +latentextractor_kernels: [3, 3] +latentextractor_strides: [2, 1] +extractor_dim: 512 +embedding_dim: 768 + +CNN: !new:speechbrain.lobes.models.wav2vec.W2VLatentExtractor + kernel_sizes: !ref + strides: !ref + out_channels: [512, 512] + input_dim: 80 + +weighted_ssl_model: !new:speechbrain.lobes.models.wav2vec.WeightedSSLModel + pretrained_path: !ref + num_layers: 13 + latent_encoder: !ref + CNN: !ref + in_dim: 512 + embedding_dim: 768 + dropout_encoder_input: 0.1 + output_hidden_states: True + +enc: !new:speechbrain.nnet.RNN.LSTM + input_shape: [Null, Null, !ref ] + num_layers: 2 + bidirectional: True + dropout: 0.2 + hidden_size: 1024 + +ctc_lin: 
!new:speechbrain.nnet.linear.Linear + input_size: 2048 + n_neurons: !ref + +log_softmax: !new:speechbrain.nnet.activations.Softmax + apply_log: True + +ctc_cost: !name:speechbrain.nnet.losses.ctc_loss + blank_index: !ref + +modules: + enc: !ref + ctc_lin: !ref + weighted_ssl_model: !ref + +model: !new:torch.nn.ModuleList + - [!ref , !ref ] + +model_opt_class: !name:torch.optim.Adam + lr: !ref + +weights_opt_class: !name:torch.optim.Adam + lr: !ref + +lr_annealing_model: !new:speechbrain.nnet.schedulers.NewBobScheduler + initial_value: !ref + improvement_threshold: 0.0025 + annealing_factor: 0.8 + patient: 0 + +lr_annealing_weights: !new:speechbrain.nnet.schedulers.NewBobScheduler + initial_value: !ref + improvement_threshold: 0.0025 + annealing_factor: 0.9 + patient: 0 + +label_encoder: !new:speechbrain.dataio.encoder.CTCTextEncoder + +checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer + checkpoints_dir: !ref + recoverables: + model: !ref + ssl_model: !ref + scheduler_model: !ref + scheduler_encoder: !ref + counter: !ref + tokenizer: !ref + +blank_index: 0 +unk_index: 1 + + +train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger + save_file: !ref + +error_rate_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats + +cer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats + split_tokens: True diff --git a/benchmarks/MP3S/CommonVoice/LSTM/hparams/ssl_eu.yaml b/benchmarks/MP3S/CommonVoice/LSTM/hparams/ssl_eu.yaml new file mode 100644 index 0000000..faca8f7 --- /dev/null +++ b/benchmarks/MP3S/CommonVoice/LSTM/hparams/ssl_eu.yaml @@ -0,0 +1,195 @@ +# ################################ +# SummaryMixing © 2023 by Samsung Electronics is licensed under CC BY-NC 4.0. + +# This is configurations of SummaryMixing wav2vec 2.0 downstream ASR on CommonVoice eu with a LSTM downstream model + +# Usage: Install SpeechBrain MP3S +# Create a folder benchmarks/MP3S/CommonVoice/LSTM +# Copy this file under benchmarks/MP3S/CommonVoice/LSTM/hparams +# SummaryMixing: https://arxiv.org/abs/2307.07421 +# SummaryMixing SSL: + +# Authors +# * Titouan Parcollet 2023, 2024 +# * Shucong Zhang 2023, 2024 +# * Rogier van Dalen 2023, 2024 +# * Sourav Bhattacharya 2023, 2024 +# ################################ + +# Seed needs to be set at top of yaml, before objects with parameters are made +seed: 1986 +__set_seed: !apply:torch.manual_seed [!ref ] +language: eu # use 'it' for Italian, 'rw' for Kinyarwanda, 'en' for english +output_folder: !ref results/CommonVoice// +wer_file: !ref /wer.txt +save_folder: !ref /save +train_log: !ref /train_log.txt + +# Data files +data_folder: !PLACEHOLDER # e.g, /local/cv-corpus-11.0-2022-09-21/ +train_tsv_file: !ref /train.tsv # Standard CommonVoice .tsv files +dev_tsv_file: !ref /dev.tsv # Standard CommonVoice .tsv files +test_tsv_file: !ref /test.tsv # Standard CommonVoice .tsv files +accented_letters: True +train_csv: !ref /train.csv +valid_csv: !ref /dev.csv +test_csv: !ref /test.csv +skip_prep: False # Skip data preparation + +avoid_if_longer_than: 10.0 + +num_layers_ssl: 13 #Number of layers in the SSL model (should be 25 for large ) +pretrained_path: !PLACEHOLDER # e,g./path/to/pre-trained_SummaryMixing_w2v2 +encoder_dim: 768 + +# Training parameters +number_of_epochs: 20 +lr: 0.0005 +lr_weights: 0.025 +sorting: ascending +auto_mix_prec: False +sample_rate: 16000 +token_type: bpe # ["unigram", "bpe", "char"] +character_coverage: 1.0 + + +# With data_parallel batch_size is split into N jobs +# With DDP batch_size is multiplied by N jobs +# Must be 
3 per GPU to fit 32GB of VRAM +batch_size: 4 +test_batch_size: 4 + +# Dataloader options +train_dataloader_opts: + batch_size: !ref +dataloader_options: + batch_size: !ref + num_workers: 4 +test_dataloader_options: + batch_size: !ref + num_workers: 4 + + +valid_dataloader_opts: + batch_size: !ref + +# Model parameters +activation: !name:torch.nn.Sigmoid +dnn_layers: 1 +dnn_neurons: 768 +freeze_encoder: True + +# Outputs +output_neurons: 100 # BPE size, index(blank/eos/bos) = 0 + + +# Functions and classes +# +epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter + limit: !ref + +augmentation: !new:speechbrain.lobes.augment.TimeDomainSpecAugment + sample_rate: !ref + speeds: [95, 100, 105] + +encoder: !new:speechbrain.lobes.models.transformer.Conformer.ConformerEncoder + d_model: 768 + num_layers: 12 + nhead: 8 + d_ffn: 3072 + dropout: 0.1 + layerdrop_prob: 0.0 + attention_type: SummaryMixing + local_proj_hid_dim: [768] + local_proj_out_dim: 768 + summary_hid_dim: [768] + mode: SummaryMixing + output_hidden_states: True + +latentextractor_kernels: [3, 3] +latentextractor_strides: [2, 1] +extractor_dim: 512 +embedding_dim: 768 + +CNN: !new:speechbrain.lobes.models.wav2vec.W2VLatentExtractor + kernel_sizes: !ref + strides: !ref + out_channels: [512, 512] + input_dim: 80 + +weighted_ssl_model: !new:speechbrain.lobes.models.wav2vec.WeightedSSLModel + pretrained_path: !ref + num_layers: 13 + latent_encoder: !ref + CNN: !ref + in_dim: 512 + embedding_dim: 768 + dropout_encoder_input: 0.1 + output_hidden_states: True + +enc: !new:speechbrain.nnet.RNN.LSTM + input_shape: [Null, Null, !ref ] + num_layers: 2 + bidirectional: True + dropout: 0.2 + hidden_size: 1024 + +ctc_lin: !new:speechbrain.nnet.linear.Linear + input_size: 2048 + n_neurons: !ref + +log_softmax: !new:speechbrain.nnet.activations.Softmax + apply_log: True + +ctc_cost: !name:speechbrain.nnet.losses.ctc_loss + blank_index: !ref + +modules: + enc: !ref + ctc_lin: !ref + weighted_ssl_model: !ref + +model: !new:torch.nn.ModuleList + - [!ref , !ref ] + +model_opt_class: !name:torch.optim.Adam + lr: !ref + +weights_opt_class: !name:torch.optim.Adam + lr: !ref + +lr_annealing_model: !new:speechbrain.nnet.schedulers.NewBobScheduler + initial_value: !ref + improvement_threshold: 0.0025 + annealing_factor: 0.8 + patient: 0 + +lr_annealing_weights: !new:speechbrain.nnet.schedulers.NewBobScheduler + initial_value: !ref + improvement_threshold: 0.0025 + annealing_factor: 0.9 + patient: 0 + +label_encoder: !new:speechbrain.dataio.encoder.CTCTextEncoder + +checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer + checkpoints_dir: !ref + recoverables: + model: !ref + ssl_model: !ref + scheduler_model: !ref + scheduler_encoder: !ref + counter: !ref + tokenizer: !ref + +blank_index: 0 +unk_index: 1 + + +train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger + save_file: !ref + +error_rate_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats + +cer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats + split_tokens: True diff --git a/benchmarks/MP3S/CommonVoice/LSTM/train.py b/benchmarks/MP3S/CommonVoice/LSTM/train.py new file mode 100644 index 0000000..ddb3a39 --- /dev/null +++ b/benchmarks/MP3S/CommonVoice/LSTM/train.py @@ -0,0 +1,305 @@ +""" SummaryMixing © 2023 by Samsung Electronics is licensed under CC BY-NC 4.0. + +Script for training an ASR model evaluating an SSL representation +model on one language from the CommonVoice dataset. 
A SentencePiece tokenizer +with number of tokens equal to is learned in a first phase +on the considered language. +""" + +import sys +import torch +import logging +import speechbrain as sb +from speechbrain.utils.distributed import run_on_main +from hyperpyyaml import load_hyperpyyaml +import torchaudio +from speechbrain.tokenizers.SentencePiece import SentencePiece +from speechbrain.utils.data_utils import undo_padding + +logger = logging.getLogger(__name__) + + +# Define training procedure +class ASR(sb.Brain): + def compute_forward(self, batch, stage): + """Forward computations from the waveform batches to the output probabilities.""" + batch = batch.to(self.device) + wavs, wav_lens = batch.sig + wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device) + # Forward pass + feats = self.modules.weighted_ssl_model(wavs, wav_lens) + y = self.modules.enc(feats) + y = y[0] # As it is an RNN output + # Compute outputs + p_tokens = None + logits = self.modules.ctc_lin(y) + p_ctc = self.hparams.log_softmax(logits) + if stage != sb.Stage.TRAIN: + p_tokens = sb.decoders.ctc_greedy_decode( + p_ctc, wav_lens, blank_id=self.hparams.blank_index + ) + return p_ctc, wav_lens, p_tokens + + def compute_objectives(self, predictions, batch, stage): + """Computes the loss (CTC+NLL) given predictions and targets.""" + + p_ctc, wav_lens, predicted_tokens = predictions + ids = batch.id + tokens, tokens_lens = batch.tokens + loss_ctc = self.hparams.ctc_cost(p_ctc, tokens, wav_lens, tokens_lens) + loss = loss_ctc + + if stage != sb.Stage.TRAIN: + # Decode token terms to words + predicted_words = self.tokenizer( + predicted_tokens, task="decode_from_list" + ) + + # Convert indices to words + target_words = undo_padding(tokens, tokens_lens) + target_words = self.tokenizer(target_words, task="decode_from_list") + + self.wer_metric.append(ids, predicted_words, target_words) + self.cer_metric.append(ids, predicted_words, target_words) + return loss + + def fit_batch(self, batch): + """Train the parameters given a single batch in input""" + predictions = self.compute_forward(batch, sb.Stage.TRAIN) + loss = self.compute_objectives(predictions, batch, sb.Stage.TRAIN) + loss.backward() + self.model_optimizer.step() + self.weights_optimizer.step() + self.model_optimizer.zero_grad() + self.weights_optimizer.zero_grad() + return loss.detach() + + def evaluate_batch(self, batch, stage): + """Computations needed for validation/test batches""" + predictions = self.compute_forward(batch, stage=stage) + with torch.no_grad(): + loss = self.compute_objectives(predictions, batch, stage=stage) + return loss.detach() + + def on_stage_start(self, stage, epoch): + """Gets called at the beginning of each epoch""" + if stage != sb.Stage.TRAIN: + self.cer_metric = self.hparams.cer_computer() + self.wer_metric = self.hparams.error_rate_computer() + + def on_stage_end(self, stage, stage_loss, epoch): + """Gets called at the end of an epoch.""" + # Compute/store important stats + stage_stats = {"loss": stage_loss} + if stage == sb.Stage.TRAIN: + self.train_stats = stage_stats + else: + stage_stats["CER"] = self.cer_metric.summarize("error_rate") + stage_stats["WER"] = self.wer_metric.summarize("error_rate") + + # Perform end-of-iteration things, like annealing, logging, etc. 
+ if stage == sb.Stage.VALID: + old_lr_model, new_lr_model = self.hparams.lr_annealing_model( + stage_stats["loss"] + ) + old_lr_weights, new_lr_weights = self.hparams.lr_annealing_weights( + stage_stats["loss"] + ) + sb.nnet.schedulers.update_learning_rate( + self.model_optimizer, new_lr_model + ) + sb.nnet.schedulers.update_learning_rate( + self.weights_optimizer, new_lr_weights + ) + self.hparams.train_logger.log_stats( + stats_meta={"epoch": epoch, "lr_model": old_lr_model}, + train_stats=self.train_stats, + valid_stats=stage_stats, + ) + self.checkpointer.save_and_keep_only( + meta={"WER": stage_stats["WER"]}, min_keys=["WER"], + ) + elif stage == sb.Stage.TEST: + self.hparams.train_logger.log_stats( + stats_meta={"Epoch loaded": self.hparams.epoch_counter.current}, + test_stats=stage_stats, + ) + with open(self.hparams.wer_file, "w") as w: + self.wer_metric.write_stats(w) + + def init_optimizers(self): + "Initializes the weights optimizer and model optimizer" + self.weights_optimizer = self.hparams.weights_opt_class( + [self.modules.weighted_ssl_model.weights] + ) + self.model_optimizer = self.hparams.model_opt_class( + self.hparams.model.parameters() + ) + # Initializing the weights + if self.checkpointer is not None: + self.checkpointer.add_recoverable("modelopt", self.model_optimizer) + self.checkpointer.add_recoverable( + "weights_opt", self.weights_optimizer + ) + + +# Define custom data procedure +def dataio_prepare(hparams, tokenizer): + """This function prepares the datasets to be used in the brain class. + It also defines the data processing pipeline through user-defined functions.""" + + # 1. Define datasets + data_folder = hparams["data_folder"] + + train_data = sb.dataio.dataset.DynamicItemDataset.from_csv( + csv_path=hparams["train_csv"], replacements={"data_root": data_folder}, + ) + + if hparams["sorting"] == "ascending": + # we sort training data to speed up training and get better results. + train_data = train_data.filtered_sorted( + sort_key="duration", + key_max_value={"duration": hparams["avoid_if_longer_than"]}, + ) + # when sorting do not shuffle in dataloader ! otherwise is pointless + hparams["dataloader_options"]["shuffle"] = False + + elif hparams["sorting"] == "descending": + train_data = train_data.filtered_sorted( + sort_key="duration", + reverse=True, + key_max_value={"duration": hparams["avoid_if_longer_than"]}, + ) + # when sorting do not shuffle in dataloader ! otherwise is pointless + hparams["dataloader_options"]["shuffle"] = False + + elif hparams["sorting"] == "random": + pass + + else: + raise NotImplementedError( + "sorting must be random, ascending or descending" + ) + + valid_data = sb.dataio.dataset.DynamicItemDataset.from_csv( + csv_path=hparams["valid_csv"], replacements={"data_root": data_folder}, + ) + # We also sort the validation data so it is faster to validate + valid_data = valid_data.filtered_sorted(sort_key="duration") + + test_data = sb.dataio.dataset.DynamicItemDataset.from_csv( + csv_path=hparams["test_csv"], replacements={"data_root": data_folder}, + ) + + # We also sort the validation data so it is faster to validate + test_data = test_data.filtered_sorted(sort_key="duration") + + datasets = [train_data, valid_data, test_data] + + # 2. 
Define audio pipeline: + @sb.utils.data_pipeline.takes("wav") + @sb.utils.data_pipeline.provides("sig") + def audio_pipeline(wav): + info = torchaudio.info(wav) + sig = sb.dataio.dataio.read_audio(wav) + resampled = torchaudio.transforms.Resample( + info.sample_rate, hparams["sample_rate"], + )(sig) + return resampled + + sb.dataio.dataset.add_dynamic_item(datasets, audio_pipeline) + + # 3. Define text pipeline: + @sb.utils.data_pipeline.takes("wrd") + @sb.utils.data_pipeline.provides("tokens_list", "tokens") + def text_pipeline(wrd): + tokens_list = tokenizer.sp.encode_as_ids(wrd) + yield tokens_list + tokens = torch.LongTensor(tokens_list) + yield tokens + + sb.dataio.dataset.add_dynamic_item(datasets, text_pipeline) + + # 4. Set output: + sb.dataio.dataset.set_output_keys( + datasets, ["id", "sig", "tokens"], + ) + return train_data, valid_data, test_data + + +if __name__ == "__main__": + + # CLI: + hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:]) + + # If distributed_launch=True then + # create ddp_group with the right communication protocol + sb.utils.distributed.ddp_init_group(run_opts) + + with open(hparams_file) as fin: + hparams = load_hyperpyyaml(fin, overrides) + + # Create experiment directory + sb.create_experiment_directory( + experiment_directory=hparams["output_folder"], + hyperparams_to_save=hparams_file, + overrides=overrides, + ) + + # Dataset preparation + from common_voice_prepare import prepare_common_voice # noqa + + # multi-gpu (ddp) save data preparation + # Due to DDP, we do the preparation ONLY on the main python process + run_on_main( + prepare_common_voice, + kwargs={ + "data_folder": hparams["data_folder"], + "save_folder": hparams["save_folder"], + "train_tsv_file": hparams["train_tsv_file"], + "dev_tsv_file": hparams["dev_tsv_file"], + "test_tsv_file": hparams["test_tsv_file"], + "accented_letters": hparams["accented_letters"], + "language": hparams["language"], + "skip_prep": hparams["skip_prep"], + }, + ) + + # Defining tokenizer and loading it + tokenizer = SentencePiece( + model_dir=hparams["save_folder"], + vocab_size=hparams["output_neurons"], + annotation_train=hparams["train_csv"], + annotation_read="wrd", + model_type=hparams["token_type"], + character_coverage=hparams["character_coverage"], + ) + + # Create the datasets objects as well as tokenization and encoding :-D + train_data, valid_data, test_data = dataio_prepare(hparams, tokenizer) + + # Trainer initialization + asr_brain = ASR( + modules=hparams["modules"], + hparams=hparams, + run_opts=run_opts, + checkpointer=hparams["checkpointer"], + ) + + asr_brain.tokenizer = tokenizer + # Training + asr_brain.fit( + asr_brain.hparams.epoch_counter, + train_data, + valid_data, + train_loader_kwargs=hparams["train_dataloader_opts"], + valid_loader_kwargs=hparams["valid_dataloader_opts"], + ) + + # Testing + asr_brain.hparams.wer_file = hparams["output_folder"] + "/wer_test.txt" + asr_brain.evaluate( + test_data, + min_key="WER", + test_loader_kwargs=hparams["test_dataloader_options"], + ) diff --git a/benchmarks/MP3S/LibriSpeech/LSTM/hparams/ssl.yaml b/benchmarks/MP3S/LibriSpeech/LSTM/hparams/ssl.yaml new file mode 100644 index 0000000..86f3ee2 --- /dev/null +++ b/benchmarks/MP3S/LibriSpeech/LSTM/hparams/ssl.yaml @@ -0,0 +1,188 @@ +# ################################ +# SummaryMixing © 2023 by Samsung Electronics is licensed under CC BY-NC 4.0. 
+ +# This is configurations of SummaryMixing wav2vec 2.0 downstream ASR on LibriSpeech with a LSTM downstream model + +# Usage: Install SpeechBrain MP3S +# Create a folder benchmarks/MP3S/LibriSpeech/LSTM +# Copy this file under benchmarks/MP3S/LibriSpeech/LSTM/hparams +# SummaryMixing: https://arxiv.org/abs/2307.07421 +# SummaryMixing SSL: + +# Authors +# * Titouan Parcollet 2023, 2024 +# * Shucong Zhang 2023, 2024 +# * Rogier van Dalen 2023, 2024 +# * Sourav Bhattacharya 2023, 2024 +# ################################ + +# Seed needs to be set at top of yaml, before objects with parameters are made +seed: 1986 +__set_seed: !apply:torch.manual_seed [!ref ] +output_folder: !ref results/LibriSpeech/ +wer_file: !ref /wer.txt +save_folder: !ref /save +train_log: !ref /train_log.txt + + +# Data files +data_folder: !PLACEHOLDER # e,g./path/to/LibriSpeech +# noise/ris dataset will automatically be downloaded +# data_folder_rirs: !ref +train_splits: ["train-clean-100"] +dev_splits: ["dev-clean"] +test_splits: ["test-clean", "test-other"] + +skip_prep: False +ckpt_interval_minutes: 25 # save checkpoint every N min +train_csv: !ref /train-clean-100.csv +valid_csv: !ref /dev-clean.csv +test_csv: + - !ref /test-clean.csv + - !ref /test-other.csv + +num_layers_ssl: 13 #Number of layers in the SSL model (should be 25 for large ) +pretrained_path: !PLACEHOLDER # e,g./path/to/pre-trained_SummaryMixing_w2v2 +encoder_dim: 768 + +# Training parameters +number_of_epochs: 20 +lr: 0.0002 +lr_weights: 0.01 +sorting: ascending +auto_mix_prec: False +sample_rate: 16000 +language_modelling: False +#ngram_lm_path: !PLACEHOLDER #path/to/4-gram.arpa + +# With data_parallel batch_size is split into N jobs +# With DDP batch_size is multiplied by N jobs +# Must be 3 per GPU to fit 32GB of VRAM +batch_size: 4 +test_batch_size: 4 + +# Dataloader options +train_dataloader_opts: + batch_size: !ref + +valid_dataloader_opts: + batch_size: !ref + +test_dataloader_opts: + batch_size: !ref + +# Model parameters +activation: !name:torch.nn.Sigmoid +dnn_layers: 1 +dnn_neurons: 768 +freeze_encoder: True + +# Outputs +output_neurons: 30 # BPE size, index(blank/eos/bos) = 0 + +# Functions and classes +# +epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter + limit: !ref + +encoder: !new:speechbrain.lobes.models.transformer.Conformer.ConformerEncoder + d_model: 768 + num_layers: 12 + nhead: 8 + d_ffn: 3072 + dropout: 0.1 + layerdrop_prob: 0.0 + attention_type: SummaryMixing + local_proj_hid_dim: [768] + local_proj_out_dim: 768 + summary_hid_dim: [768] + mode: SummaryMixing + output_hidden_states: True + +latentextractor_kernels: [3, 3] +latentextractor_strides: [2, 1] +extractor_dim: 512 +embedding_dim: 768 + +CNN: !new:speechbrain.lobes.models.wav2vec.W2VLatentExtractor + kernel_sizes: !ref + strides: !ref + out_channels: [512, 512] + input_dim: 80 + +weighted_ssl_model: !new:speechbrain.lobes.models.wav2vec.WeightedSSLModel + pretrained_path: !ref + num_layers: 13 + latent_encoder: !ref + CNN: !ref + in_dim: 512 + embedding_dim: 768 + dropout_encoder_input: 0.1 + output_hidden_states: True + +enc: !new:speechbrain.nnet.RNN.LSTM + input_shape: [Null, Null, !ref ] + num_layers: 2 + bidirectional: True + dropout: 0.2 + hidden_size: 1024 + +ctc_lin: !new:speechbrain.nnet.linear.Linear + input_size: 2048 + n_neurons: !ref + +log_softmax: !new:speechbrain.nnet.activations.Softmax + apply_log: True + +ctc_cost: !name:speechbrain.nnet.losses.ctc_loss + blank_index: !ref + +modules: + enc: !ref + ctc_lin: !ref + weighted_ssl_model: 
!ref + +model: !new:torch.nn.ModuleList + - [!ref , !ref ] + +model_opt_class: !name:torch.optim.Adam + lr: !ref + +weights_opt_class: !name:torch.optim.Adam + lr: !ref + +lr_annealing_model: !new:speechbrain.nnet.schedulers.NewBobScheduler + initial_value: !ref + improvement_threshold: 0.0025 + annealing_factor: 0.8 + patient: 0 + +lr_annealing_weights: !new:speechbrain.nnet.schedulers.NewBobScheduler + initial_value: !ref + improvement_threshold: 0.0025 + annealing_factor: 0.9 + patient: 0 + +label_encoder: !new:speechbrain.dataio.encoder.CTCTextEncoder + +checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer + checkpoints_dir: !ref + recoverables: + model: !ref + ssl_model: !ref + scheduler_model: !ref + scheduler_encoder: !ref + counter: !ref + tokenizer: !ref + +blank_index: 0 +unk_index: 1 + + +train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger + save_file: !ref + +error_rate_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats + +cer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats + split_tokens: True diff --git a/benchmarks/MP3S/LibriSpeech/LSTM/train.py b/benchmarks/MP3S/LibriSpeech/LSTM/train.py new file mode 100644 index 0000000..d763448 --- /dev/null +++ b/benchmarks/MP3S/LibriSpeech/LSTM/train.py @@ -0,0 +1,344 @@ +""" SummaryMixing © 2023 by Samsung Electronics is licensed under CC BY-NC 4.0. + +Recipe for training an SSL-based ctc ASR system with librispeech. + Decoding is performed with ctc greedy or LM-rescored decoder. +""" + +import os +import sys +import torch +import logging +import speechbrain as sb +from speechbrain.utils.distributed import run_on_main +from hyperpyyaml import load_hyperpyyaml +from pathlib import Path + +logger = logging.getLogger(__name__) + + +# Define training procedure +class ASR(sb.Brain): + def compute_forward(self, batch, stage): + """Forward computations from the waveform batches to the output probabilities.""" + batch = batch.to(self.device) + wavs, wav_lens = batch.sig + wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device) + # Forward pass + feats = self.modules.weighted_ssl_model(wavs, wav_lens) + y = self.modules.enc(feats) + y = y[0] # As it is an RNN output + # Compute outputs + p_tokens = None + logits = self.modules.ctc_lin(y) + p_ctc = self.hparams.log_softmax(logits) + if stage != sb.Stage.TRAIN: + p_tokens = sb.decoders.ctc_greedy_decode( + p_ctc, wav_lens, blank_id=self.hparams.blank_index + ) + return p_ctc, wav_lens, p_tokens + + def compute_objectives(self, predictions, batch, stage): + """Computes the loss (CTC+NLL) given predictions and targets.""" + + p_ctc, wav_lens, predicted_tokens = predictions + ids = batch.id + tokens, tokens_lens = batch.tokens + loss_ctc = self.hparams.ctc_cost(p_ctc, tokens, wav_lens, tokens_lens) + loss = loss_ctc + + if stage == sb.Stage.VALID: + # Decode token terms to words + predicted_words = [ + "".join(self.tokenizer.decode_ndim(utt_seq)).split(" ") + for utt_seq in predicted_tokens + ] + target_words = [wrd.split(" ") for wrd in batch.wrd] + self.wer_metric.append(ids, predicted_words, target_words) + self.cer_metric.append(ids, predicted_words, target_words) + + elif stage == sb.Stage.TEST: + if self.hparams.language_modelling: + predicted_words = [] + for logs in p_ctc: + text = decoder.decode(logs.detach().cpu().numpy()) + predicted_words.append(text.split(" ")) + else: + predicted_words = [ + "".join(self.tokenizer.decode_ndim(utt_seq)).split(" ") + for utt_seq in predicted_tokens + ] + + target_words = [wrd.split(" ") for wrd in 
batch.wrd] + self.wer_metric.append(ids, predicted_words, target_words) + self.cer_metric.append(ids, predicted_words, target_words) + + return loss + + def fit_batch(self, batch): + """Train the parameters given a single batch in input""" + predictions = self.compute_forward(batch, sb.Stage.TRAIN) + loss = self.compute_objectives(predictions, batch, sb.Stage.TRAIN) + loss.backward() + self.model_optimizer.step() + self.weights_optimizer.step() + + self.model_optimizer.zero_grad() + self.weights_optimizer.zero_grad() + return loss.detach() + + def evaluate_batch(self, batch, stage): + """Computations needed for validation/test batches""" + predictions = self.compute_forward(batch, stage=stage) + with torch.no_grad(): + loss = self.compute_objectives(predictions, batch, stage=stage) + return loss.detach() + + def on_stage_start(self, stage, epoch): + """Gets called at the beginning of each epoch""" + if stage != sb.Stage.TRAIN: + self.cer_metric = self.hparams.cer_computer() + self.wer_metric = self.hparams.error_rate_computer() + + def on_stage_end(self, stage, stage_loss, epoch): + """Gets called at the end of an epoch.""" + # Compute/store important stats + stage_stats = {"loss": stage_loss} + if stage == sb.Stage.TRAIN: + self.train_stats = stage_stats + else: + stage_stats["CER"] = self.cer_metric.summarize("error_rate") + stage_stats["WER"] = self.wer_metric.summarize("error_rate") + + # Perform end-of-iteration things, like annealing, logging, etc. + if stage == sb.Stage.VALID: + old_lr_model, new_lr_model = self.hparams.lr_annealing_model( + stage_stats["loss"] + ) + old_lr_weights, new_lr_weights = self.hparams.lr_annealing_weights( + stage_stats["loss"] + ) + sb.nnet.schedulers.update_learning_rate( + self.model_optimizer, new_lr_model + ) + sb.nnet.schedulers.update_learning_rate( + self.weights_optimizer, new_lr_weights + ) + + self.hparams.train_logger.log_stats( + stats_meta={"epoch": epoch, "lr_model": old_lr_model}, + train_stats=self.train_stats, + valid_stats=stage_stats, + ) + self.checkpointer.save_and_keep_only( + meta={"WER": stage_stats["WER"]}, min_keys=["WER"], + ) + elif stage == sb.Stage.TEST: + self.hparams.train_logger.log_stats( + stats_meta={"Epoch loaded": self.hparams.epoch_counter.current}, + test_stats=stage_stats, + ) + with open(self.hparams.wer_file, "w") as w: + self.wer_metric.write_stats(w) + + def init_optimizers(self): + "Initializes the weights optimizer and model optimizer" + self.weights_optimizer = self.hparams.weights_opt_class( + [self.modules.weighted_ssl_model.weights] + ) + self.model_optimizer = self.hparams.model_opt_class( + self.hparams.model.parameters() + ) + # Initializing the weights + if self.checkpointer is not None: + self.checkpointer.add_recoverable("modelopt", self.model_optimizer) + self.checkpointer.add_recoverable( + "weights_opt", self.weights_optimizer + ) + + +def dataio_prepare(hparams): + """This function prepares the datasets to be used in the brain class. + It also defines the data processing pipeline through user-defined functions.""" + data_folder = hparams["data_folder"] + + train_data = sb.dataio.dataset.DynamicItemDataset.from_csv( + csv_path=hparams["train_csv"], replacements={"data_root": data_folder}, + ) + + if hparams["sorting"] == "ascending": + # we sort training data to speed up training and get better results. + train_data = train_data.filtered_sorted(sort_key="duration") + # when sorting do not shuffle in dataloader ! 
otherwise is pointless + hparams["train_dataloader_opts"]["shuffle"] = False + + elif hparams["sorting"] == "descending": + train_data = train_data.filtered_sorted( + sort_key="duration", reverse=True + ) + # when sorting do not shuffle in dataloader ! otherwise is pointless + hparams["train_dataloader_opts"]["shuffle"] = False + + elif hparams["sorting"] == "random": + pass + + else: + raise NotImplementedError( + "sorting must be random, ascending or descending" + ) + + valid_data = sb.dataio.dataset.DynamicItemDataset.from_csv( + csv_path=hparams["valid_csv"], replacements={"data_root": data_folder}, + ) + valid_data = valid_data.filtered_sorted(sort_key="duration") + + # test is separate + test_datasets = {} + for csv_file in hparams["test_csv"]: + name = Path(csv_file).stem + test_datasets[name] = sb.dataio.dataset.DynamicItemDataset.from_csv( + csv_path=csv_file, replacements={"data_root": data_folder} + ) + test_datasets[name] = test_datasets[name].filtered_sorted( + sort_key="duration" + ) + + datasets = [train_data, valid_data] + [i for k, i in test_datasets.items()] + + # 2. Define audio pipeline: + @sb.utils.data_pipeline.takes("wav") + @sb.utils.data_pipeline.provides("sig") + def audio_pipeline(wav): + sig = sb.dataio.dataio.read_audio(wav) + return sig + + sb.dataio.dataset.add_dynamic_item(datasets, audio_pipeline) + label_encoder = sb.dataio.encoder.CTCTextEncoder() + + # 3. Define text pipeline: + @sb.utils.data_pipeline.takes("wrd") + @sb.utils.data_pipeline.provides( + "wrd", "char_list", "tokens_list", "tokens" + ) + def text_pipeline(wrd): + yield wrd + char_list = list(wrd) + yield char_list + tokens_list = label_encoder.encode_sequence(char_list) + yield tokens_list + tokens = torch.LongTensor(tokens_list) + yield tokens + + sb.dataio.dataset.add_dynamic_item(datasets, text_pipeline) + + lab_enc_file = os.path.join(hparams["save_folder"], "label_encoder.txt") + special_labels = { + "blank_label": hparams["blank_index"], + "unk_label": hparams["unk_index"], + } + label_encoder.load_or_create( + path=lab_enc_file, + from_didatasets=[train_data], + output_key="char_list", + special_labels=special_labels, + sequence_input=True, + ) + + # 4. 
Set output: + sb.dataio.dataset.set_output_keys( + datasets, ["id", "sig", "wrd", "char_list", "tokens"], + ) + return train_data, valid_data, test_datasets, label_encoder + + +if __name__ == "__main__": + + # CLI: + hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:]) + + # If distributed_launch=True then + # create ddp_group with the right communication protocol + sb.utils.distributed.ddp_init_group(run_opts) + + with open(hparams_file) as fin: + hparams = load_hyperpyyaml(fin, overrides) + + # Create experiment directory + sb.create_experiment_directory( + experiment_directory=hparams["output_folder"], + hyperparams_to_save=hparams_file, + overrides=overrides, + ) + + # Dataset prep (parsing Librispeech) + from librispeech_prepare import prepare_librispeech # noqa + + # multi-gpu (ddp) save data preparation + run_on_main( + prepare_librispeech, + kwargs={ + "data_folder": hparams["data_folder"], + "tr_splits": hparams["train_splits"], + "dev_splits": hparams["dev_splits"], + "te_splits": hparams["test_splits"], + "save_folder": hparams["output_folder"], + "merge_lst": hparams["train_splits"], + "merge_name": "train.csv", + "skip_prep": hparams["skip_prep"], + }, + ) + + # here we create the datasets objects as well as tokenization and encoding + train_data, valid_data, test_datasets, label_encoder = dataio_prepare( + hparams + ) + # Loading the labels for the LM decoding and the CTC decoder + if "language_modelling" in hparams: + + if hparams["language_modelling"]: + from pyctcdecode import build_ctcdecoder + + ind2lab = label_encoder.ind2lab + labels = [ind2lab[x] for x in range(len(ind2lab))] + labels = [""] + labels[ + 1: + ] # Replace the token with a blank character, needed for PyCTCdecode + decoder = build_ctcdecoder( + labels, + kenlm_model_path=hparams[ + "ngram_lm_path" + ], # either .arpa or .bin file + alpha=0.5, # tuned on a val set + beta=1.0, # tuned on a val set + ) + else: + hparams["language_modelling"] = False + + # Trainer initialization + asr_brain = ASR( + modules=hparams["modules"], + hparams=hparams, + run_opts=run_opts, + checkpointer=hparams["checkpointer"], + ) + + # Loading the SSL model + # We dynamicaly add the tokenizer to our brain class. + # NB: This tokenizer corresponds to the one used for the LM!! + asr_brain.tokenizer = label_encoder + # Training + asr_brain.fit( + asr_brain.hparams.epoch_counter, + train_data, + valid_data, + train_loader_kwargs=hparams["train_dataloader_opts"], + valid_loader_kwargs=hparams["valid_dataloader_opts"], + ) + + # Testing + for k in test_datasets.keys(): # keys are test_clean, test_other etc + asr_brain.hparams.wer_file = os.path.join( + hparams["output_folder"], "wer_{}.txt".format(k) + ) + asr_brain.evaluate( + test_datasets[k], test_loader_kwargs=hparams["test_dataloader_opts"] + ) diff --git a/benchmarks/MP3S/LibriSpeech/contextnet/hparam/ssl.yaml b/benchmarks/MP3S/LibriSpeech/contextnet/hparam/ssl.yaml new file mode 100644 index 0000000..d2bce58 --- /dev/null +++ b/benchmarks/MP3S/LibriSpeech/contextnet/hparam/ssl.yaml @@ -0,0 +1,184 @@ +# ################################ +# SummaryMixing © 2023 by Samsung Electronics is licensed under CC BY-NC 4.0. 
+ +# This is configurations of SummaryMixing wav2vec 2.0 downstream ASR on LibriSpeech with a ContextNet downstream model + +# Usage: Install SpeechBrain MP3S +# Create a folder benchmarks/MP3S/LibriSpeech/contextnet +# Copy this file under benchmarks/MP3S/LibriSpeech/contextnet/hparams +# SummaryMixing: https://arxiv.org/abs/2307.07421 +# SummaryMixing SSL: + +# Authors +# * Titouan Parcollet 2023, 2024 +# * Shucong Zhang 2023, 2024 +# * Rogier van Dalen 2023, 2024 +# * Sourav Bhattacharya 2023, 2024 +# ################################ + +# Seed needs to be set at top of yaml, before objects with parameters are made +seed: 1986 +__set_seed: !apply:torch.manual_seed [!ref ] +output_folder: !ref results/LibriSpeech/ +wer_file: !ref /wer.txt +save_folder: !ref /save +train_log: !ref /train_log.txt + +# Data files +data_folder: !PLACEHOLDER # e,g./path/to/LibriSpeech +# noise/ris dataset will automatically be downloaded +# data_folder_rirs: !ref +train_splits: ["train-clean-100"] +dev_splits: ["dev-clean"] +test_splits: ["test-clean", "test-other"] +skip_prep: False +ckpt_interval_minutes: 25 # save checkpoint every N min +train_csv: !ref /train-clean-100.csv +valid_csv: !ref /dev-clean.csv +test_csv: + - !ref /test-clean.csv + - !ref /test-other.csv + +num_layers_ssl: 13 +pretrained_path: !PLACEHOLDER # e,g./path/to/pre-trained_SummaryMixing_w2v2 +encoder_dim: 768 + +# Training parameters +number_of_epochs: 20 +lr: 0.0002 +lr_weights: 0.01 +sorting: ascending +auto_mix_prec: False +sample_rate: 16000 +language_modelling: False +#ngram_lm_path: !PLACEHOLDER # path/to/4-gram.arpa + +# With data_parallel batch_size is split into N jobs +# With DDP batch_size is multiplied by N jobs +# Must be 3 per GPU to fit 32GB of VRAM +batch_size: 4 +test_batch_size: 4 + +# Dataloader options +train_dataloader_opts: + batch_size: !ref + +valid_dataloader_opts: + batch_size: !ref + +test_dataloader_opts: + batch_size: !ref + +# Model parameters +activation: !name:torch.nn.Sigmoid +dnn_layers: 1 +dnn_neurons: 768 +freeze_encoder: True + +# Outputs +output_neurons: 30 +# Functions and classes +# +epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter + limit: !ref + +encoder: !new:speechbrain.lobes.models.transformer.Conformer.ConformerEncoder + d_model: 768 + num_layers: 12 + nhead: 8 + d_ffn: 3072 + dropout: 0.1 + layerdrop_prob: 0.0 + attention_type: SummaryMixing + local_proj_hid_dim: [768] + local_proj_out_dim: 768 + summary_hid_dim: [768] + mode: SummaryMixing + output_hidden_states: True + +latentextractor_kernels: [3, 3] +latentextractor_strides: [2, 1] +extractor_dim: 512 +embedding_dim: 768 + +CNN: !new:speechbrain.lobes.models.wav2vec.W2VLatentExtractor + kernel_sizes: !ref + strides: !ref + out_channels: [512, 512] + input_dim: 80 + +weighted_ssl_model: !new:speechbrain.lobes.models.wav2vec.WeightedSSLModel + pretrained_path: !ref + num_layers: 13 + latent_encoder: !ref + CNN: !ref + in_dim: 512 + embedding_dim: 768 + dropout_encoder_input: 0.1 + output_hidden_states: True + +enc: !new:speechbrain.lobes.models.ContextNet.ContextNet + input_shape: [null, null, !ref ] + strides: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] + +# only unitary strides to keep the frame rate + + +ctc_lin: !new:speechbrain.nnet.linear.Linear + input_size: 640 + n_neurons: !ref + +log_softmax: !new:speechbrain.nnet.activations.Softmax + apply_log: True + +ctc_cost: !name:speechbrain.nnet.losses.ctc_loss + blank_index: !ref + +modules: + enc: !ref + ctc_lin: !ref + weighted_ssl_model: !ref + 
+model: !new:torch.nn.ModuleList + - [!ref , !ref ] + +model_opt_class: !name:torch.optim.Adam + lr: !ref + +weights_opt_class: !name:torch.optim.Adam + lr: !ref + +lr_annealing_model: !new:speechbrain.nnet.schedulers.NewBobScheduler + initial_value: !ref + improvement_threshold: 0.0025 + annealing_factor: 0.8 + patient: 0 + +lr_annealing_weights: !new:speechbrain.nnet.schedulers.NewBobScheduler + initial_value: !ref + improvement_threshold: 0.0025 + annealing_factor: 0.9 + patient: 0 + +label_encoder: !new:speechbrain.dataio.encoder.CTCTextEncoder +checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer + checkpoints_dir: !ref + recoverables: + model: !ref + ssl_model: !ref + scheduler_model: !ref + scheduler_encoder: !ref + counter: !ref + tokenizer: !ref + +blank_index: 0 +unk_index: 1 + + +train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger + save_file: !ref + +error_rate_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats + +cer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats + split_tokens: True diff --git a/benchmarks/MP3S/LibriSpeech/contextnet/train.py b/benchmarks/MP3S/LibriSpeech/contextnet/train.py new file mode 100644 index 0000000..9d0359f --- /dev/null +++ b/benchmarks/MP3S/LibriSpeech/contextnet/train.py @@ -0,0 +1,342 @@ +""" SummaryMixing © 2023 by Samsung Electronics is licensed under CC BY-NC 4.0. + +Recipe for training an SSL-based ctc ASR system with librispeech. + Decoding is performed with ctc greedy or LM-rescored decoder. +""" + +import os +import sys +import torch +import logging +import speechbrain as sb +from speechbrain.utils.distributed import run_on_main +from hyperpyyaml import load_hyperpyyaml +from pathlib import Path + +logger = logging.getLogger(__name__) + + +# Define training procedure +class ASR(sb.Brain): + def compute_forward(self, batch, stage): + """Forward computations from the waveform batches to the output probabilities.""" + batch = batch.to(self.device) + wavs, wav_lens = batch.sig + wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device) + # Forward pass + feats = self.modules.weighted_ssl_model(wavs, wav_lens) + y = self.modules.enc(feats) + # Compute outputs + p_tokens = None + logits = self.modules.ctc_lin(y) + p_ctc = self.hparams.log_softmax(logits) + if stage != sb.Stage.TRAIN: + p_tokens = sb.decoders.ctc_greedy_decode( + p_ctc, wav_lens, blank_id=self.hparams.blank_index + ) + return p_ctc, wav_lens, p_tokens + + def compute_objectives(self, predictions, batch, stage): + """Computes the loss (CTC+NLL) given predictions and targets.""" + + p_ctc, wav_lens, predicted_tokens = predictions + ids = batch.id + tokens, tokens_lens = batch.tokens + loss_ctc = self.hparams.ctc_cost(p_ctc, tokens, wav_lens, tokens_lens) + loss = loss_ctc + + if stage == sb.Stage.VALID: + # Decode token terms to words + predicted_words = [ + "".join(self.tokenizer.decode_ndim(utt_seq)).split(" ") + for utt_seq in predicted_tokens + ] + target_words = [wrd.split(" ") for wrd in batch.wrd] + self.wer_metric.append(ids, predicted_words, target_words) + self.cer_metric.append(ids, predicted_words, target_words) + + elif stage == sb.Stage.TEST: + if self.hparams.language_modelling: + predicted_words = [] + for logs in p_ctc: + text = decoder.decode(logs.detach().cpu().numpy()) + predicted_words.append(text.split(" ")) + else: + predicted_words = [ + "".join(self.tokenizer.decode_ndim(utt_seq)).split(" ") + for utt_seq in predicted_tokens + ] + + target_words = [wrd.split(" ") for wrd in batch.wrd] + 
self.wer_metric.append(ids, predicted_words, target_words) + self.cer_metric.append(ids, predicted_words, target_words) + + return loss + + def fit_batch(self, batch): + """Train the parameters given a single batch in input""" + predictions = self.compute_forward(batch, sb.Stage.TRAIN) + loss = self.compute_objectives(predictions, batch, sb.Stage.TRAIN) + loss.backward() + self.model_optimizer.step() + self.weights_optimizer.step() + + self.model_optimizer.zero_grad() + self.weights_optimizer.zero_grad() + return loss.detach() + + def evaluate_batch(self, batch, stage): + """Computations needed for validation/test batches""" + predictions = self.compute_forward(batch, stage=stage) + with torch.no_grad(): + loss = self.compute_objectives(predictions, batch, stage=stage) + return loss.detach() + + def on_stage_start(self, stage, epoch): + """Gets called at the beginning of each epoch""" + if stage != sb.Stage.TRAIN: + self.cer_metric = self.hparams.cer_computer() + self.wer_metric = self.hparams.error_rate_computer() + + def on_stage_end(self, stage, stage_loss, epoch): + """Gets called at the end of an epoch.""" + # Compute/store important stats + stage_stats = {"loss": stage_loss} + if stage == sb.Stage.TRAIN: + self.train_stats = stage_stats + else: + stage_stats["CER"] = self.cer_metric.summarize("error_rate") + stage_stats["WER"] = self.wer_metric.summarize("error_rate") + + # Perform end-of-iteration things, like annealing, logging, etc. + if stage == sb.Stage.VALID: + old_lr_model, new_lr_model = self.hparams.lr_annealing_model( + stage_stats["loss"] + ) + old_lr_weights, new_lr_weights = self.hparams.lr_annealing_weights( + stage_stats["loss"] + ) + sb.nnet.schedulers.update_learning_rate( + self.model_optimizer, new_lr_model + ) + sb.nnet.schedulers.update_learning_rate( + self.weights_optimizer, new_lr_weights + ) + + self.hparams.train_logger.log_stats( + stats_meta={"epoch": epoch, "lr_model": old_lr_model}, + train_stats=self.train_stats, + valid_stats=stage_stats, + ) + self.checkpointer.save_and_keep_only( + meta={"WER": stage_stats["WER"]}, min_keys=["WER"], + ) + elif stage == sb.Stage.TEST: + self.hparams.train_logger.log_stats( + stats_meta={"Epoch loaded": self.hparams.epoch_counter.current}, + test_stats=stage_stats, + ) + with open(self.hparams.wer_file, "w") as w: + self.wer_metric.write_stats(w) + + def init_optimizers(self): + "Initializes the weights optimizer and model optimizer" + self.weights_optimizer = self.hparams.weights_opt_class( + [self.modules.weighted_ssl_model.weights] + ) + self.model_optimizer = self.hparams.model_opt_class( + self.hparams.model.parameters() + ) + # Initializing the weights + if self.checkpointer is not None: + self.checkpointer.add_recoverable("modelopt", self.model_optimizer) + self.checkpointer.add_recoverable( + "weights_opt", self.weights_optimizer + ) + + +def dataio_prepare(hparams): + """This function prepares the datasets to be used in the brain class. + It also defines the data processing pipeline through user-defined functions.""" + data_folder = hparams["data_folder"] + + train_data = sb.dataio.dataset.DynamicItemDataset.from_csv( + csv_path=hparams["train_csv"], replacements={"data_root": data_folder}, + ) + + if hparams["sorting"] == "ascending": + # we sort training data to speed up training and get better results. + train_data = train_data.filtered_sorted(sort_key="duration") + # when sorting do not shuffle in dataloader ! 
otherwise is pointless + hparams["train_dataloader_opts"]["shuffle"] = False + + elif hparams["sorting"] == "descending": + train_data = train_data.filtered_sorted( + sort_key="duration", reverse=True + ) + # when sorting do not shuffle in dataloader ! otherwise is pointless + hparams["train_dataloader_opts"]["shuffle"] = False + + elif hparams["sorting"] == "random": + pass + + else: + raise NotImplementedError( + "sorting must be random, ascending or descending" + ) + + valid_data = sb.dataio.dataset.DynamicItemDataset.from_csv( + csv_path=hparams["valid_csv"], replacements={"data_root": data_folder}, + ) + valid_data = valid_data.filtered_sorted(sort_key="duration") + + # test is separate + test_datasets = {} + for csv_file in hparams["test_csv"]: + name = Path(csv_file).stem + test_datasets[name] = sb.dataio.dataset.DynamicItemDataset.from_csv( + csv_path=csv_file, replacements={"data_root": data_folder} + ) + test_datasets[name] = test_datasets[name].filtered_sorted( + sort_key="duration" + ) + + datasets = [train_data, valid_data] + [i for k, i in test_datasets.items()] + + # 2. Define audio pipeline: + @sb.utils.data_pipeline.takes("wav") + @sb.utils.data_pipeline.provides("sig") + def audio_pipeline(wav): + sig = sb.dataio.dataio.read_audio(wav) + return sig + + sb.dataio.dataset.add_dynamic_item(datasets, audio_pipeline) + label_encoder = sb.dataio.encoder.CTCTextEncoder() + + # 3. Define text pipeline: + @sb.utils.data_pipeline.takes("wrd") + @sb.utils.data_pipeline.provides( + "wrd", "char_list", "tokens_list", "tokens" + ) + def text_pipeline(wrd): + yield wrd + char_list = list(wrd) + yield char_list + tokens_list = label_encoder.encode_sequence(char_list) + yield tokens_list + tokens = torch.LongTensor(tokens_list) + yield tokens + + sb.dataio.dataset.add_dynamic_item(datasets, text_pipeline) + + lab_enc_file = os.path.join(hparams["save_folder"], "label_encoder.txt") + special_labels = { + "blank_label": hparams["blank_index"], + "unk_label": hparams["unk_index"], + } + label_encoder.load_or_create( + path=lab_enc_file, + from_didatasets=[train_data], + output_key="char_list", + special_labels=special_labels, + sequence_input=True, + ) + + # 4. 
Set output: + sb.dataio.dataset.set_output_keys( + datasets, ["id", "sig", "wrd", "char_list", "tokens"], + ) + return train_data, valid_data, test_datasets, label_encoder + + +if __name__ == "__main__": + + # CLI: + hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:]) + + # If distributed_launch=True then + # create ddp_group with the right communication protocol + sb.utils.distributed.ddp_init_group(run_opts) + + with open(hparams_file) as fin: + hparams = load_hyperpyyaml(fin, overrides) + + # Create experiment directory + sb.create_experiment_directory( + experiment_directory=hparams["output_folder"], + hyperparams_to_save=hparams_file, + overrides=overrides, + ) + + # Dataset prep (parsing Librispeech) + from librispeech_prepare import prepare_librispeech # noqa + + # multi-gpu (ddp) save data preparation + run_on_main( + prepare_librispeech, + kwargs={ + "data_folder": hparams["data_folder"], + "tr_splits": hparams["train_splits"], + "dev_splits": hparams["dev_splits"], + "te_splits": hparams["test_splits"], + "save_folder": hparams["output_folder"], + "merge_lst": hparams["train_splits"], + "merge_name": "train.csv", + "skip_prep": hparams["skip_prep"], + }, + ) + + # here we create the datasets objects as well as tokenization and encoding + train_data, valid_data, test_datasets, label_encoder = dataio_prepare( + hparams + ) + # Loading the labels for the LM decoding and the CTC decoder + if "language_modelling" in hparams: + + if hparams["language_modelling"]: + from pyctcdecode import build_ctcdecoder + + ind2lab = label_encoder.ind2lab + labels = [ind2lab[x] for x in range(len(ind2lab))] + labels = [""] + labels[ + 1: + ] # Replace the token with a blank character, needed for PyCTCdecode + decoder = build_ctcdecoder( + labels, + kenlm_model_path=hparams[ + "ngram_lm_path" + ], # either .arpa or .bin file + alpha=0.5, # tuned on a val set + beta=1.0, # tuned on a val set + ) + else: + hparams["language_modelling"] = False + + # Trainer initialization + asr_brain = ASR( + modules=hparams["modules"], + hparams=hparams, + run_opts=run_opts, + checkpointer=hparams["checkpointer"], + ) + + # We dynamicaly add the tokenizer to our brain class. + # NB: This tokenizer corresponds to the one used for the LM!! + asr_brain.tokenizer = label_encoder + # Training + asr_brain.fit( + asr_brain.hparams.epoch_counter, + train_data, + valid_data, + train_loader_kwargs=hparams["train_dataloader_opts"], + valid_loader_kwargs=hparams["valid_dataloader_opts"], + ) + + # Testing + for k in test_datasets.keys(): # keys are test_clean, test_other etc + asr_brain.hparams.wer_file = os.path.join( + hparams["output_folder"], "wer_{}.txt".format(k) + ) + asr_brain.evaluate( + test_datasets[k], test_loader_kwargs=hparams["test_dataloader_opts"] + ) diff --git a/benchmarks/MP3S/SLURP/LSTM_linear/hparams/ssl.yaml b/benchmarks/MP3S/SLURP/LSTM_linear/hparams/ssl.yaml new file mode 100644 index 0000000..e7f5cb4 --- /dev/null +++ b/benchmarks/MP3S/SLURP/LSTM_linear/hparams/ssl.yaml @@ -0,0 +1,201 @@ +# ################################ +# SummaryMixing © 2023 by Samsung Electronics is licensed under CC BY-NC 4.0. 
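The optional n-gram rescoring set up at the end of the LibriSpeech recipe above relies on pyctcdecode. Below is a small, self-contained sketch of how build_ctcdecoder and decode behave on one utterance's CTC log-posteriors; the four-symbol label inventory, the synthetic log-probabilities and the absent KenLM model (kenlm_model_path=None) are assumptions for illustration, whereas the recipe passes the full label list and an .arpa or .bin model with alpha=0.5 and beta=1.0:

import numpy as np
from pyctcdecode import build_ctcdecoder

# Index 0 must be the CTC blank, hence the leading "" as in the recipe above.
labels = ["", "a", "b", " "]
decoder = build_ctcdecoder(labels, kenlm_model_path=None, alpha=0.5, beta=1.0)

# Synthetic (time, vocab) log-posteriors for 4 frames: "a", "a", blank, "b".
log_probs = np.full((4, len(labels)), np.log(1e-3), dtype=np.float32)
log_probs[0, 1] = log_probs[1, 1] = 0.0
log_probs[2, 0] = 0.0
log_probs[3, 2] = 0.0
print(decoder.decode(log_probs))  # expected: "ab"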
+ +# This is configurations of SummaryMixing wav2vec 2.0 downstream Intent classification on SLURP + +# Usage: Install SpeechBrain MP3S +# Create a folder benchmarks/MP3S/SLURP/LSTM_linear +# Copy this file under benchmarks/MP3S/SLURP/LSTM_linear/hparams +# SummaryMixing: https://arxiv.org/abs/2307.07421 +# SummaryMixing SSL: + +# Authors +# * Titouan Parcollet 2023, 2024 +# * Shucong Zhang 2023, 2024 +# * Rogier van Dalen 2023, 2024 +# * Sourav Bhattacharya 2023, 2024 +# ################################ + +# Seed needs to be set at top of yaml, before objects with parameters are made +seed: 1986 +__set_seed: !apply:torch.manual_seed [!ref ] +output_folder: !ref results/SLURP/ +save_folder: !ref /save +train_log: !ref /train_log.txt + +# Data files +# The SLURP dataset will be automatically downloaded in the specified folder +data_folder: !PLACEHOLDER +# data_folder_rirs: !ref +train_splits: ["train_real"] +csv_train: !ref /train-type=direct.csv +csv_valid: !ref /devel-type=direct.csv +csv_test: !ref /test-type=direct.csv +skip_prep: False + +compute_cost: !name:speechbrain.nnet.losses.nll_loss + + +num_layers_ssl: 13 #Number of layers in the SSL model (should be 25 for large ) +pretrained_path: !PLACEHOLDER # e,g./path/to/pre-trained_SummaryMixing_w2v2 +encoder_dim: 768 + +# Training parameters +number_of_epochs: 20 +batch_size: 2 +test_batch_size: 2 +lr: 0.0002 +lr_weights: 0.01 +# token_type: unigram # ["unigram", "bpe", "char"] +sorting: random +ckpt_interval_minutes: 5 # save checkpoint every N min + +# Model parameters +sample_rate: 16000 +emb_size: 128 +dec_neurons: 512 +output_neurons: 18 # index(eos/bos) = 0 + +# Dataloader options +train_dataloader_opts: + batch_size: !ref + shuffle: True + num_workers: 2 # 2 on linux but 0 works on windows + drop_last: False + +valid_dataloader_opts: + batch_size: !ref + +test_dataloader_opts: + batch_size: !ref + +enc: !new:speechbrain.nnet.containers.Sequential + input_shape: [null, null, !ref ] + lstm: !new:speechbrain.nnet.RNN.LSTM + input_size: !ref + bidirectional: True + hidden_size: !ref + num_layers: 2 + linear: !new:speechbrain.nnet.linear.Linear + input_size: !ref * 2 + n_neurons: !ref + +# Decoding parameters +bos_index: 0 +eos_index: 0 +min_decode_ratio: 0.0 +max_decode_ratio: 10.0 +slu_beam_size: 80 +eos_threshold: 1.5 +temperature: 1.25 + +dataloader_opts: + batch_size: !ref + shuffle: True + +epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter + limit: !ref + + +# Models +encoder: !new:speechbrain.lobes.models.transformer.Conformer.ConformerEncoder + d_model: 768 + num_layers: 12 + nhead: 8 + d_ffn: 3072 + dropout: 0.1 + layerdrop_prob: 0.0 + attention_type: SummaryMixing + local_proj_hid_dim: [768] + local_proj_out_dim: 768 + summary_hid_dim: [768] + mode: SummaryMixing + output_hidden_states: True + +latentextractor_kernels: [3, 3] +latentextractor_strides: [2, 1] +extractor_dim: 512 +embedding_dim: 768 + +CNN: !new:speechbrain.lobes.models.wav2vec.W2VLatentExtractor + kernel_sizes: !ref + strides: !ref + out_channels: [512, 512] + input_dim: 80 + +weighted_ssl_model: !new:speechbrain.lobes.models.wav2vec.WeightedSSLModel + pretrained_path: !ref + num_layers: 13 + latent_encoder: !ref + CNN: !ref + in_dim: 512 + embedding_dim: 768 + dropout_encoder_input: 0.1 + output_hidden_states: True + +avg_pool: !new:speechbrain.nnet.pooling.StatisticsPooling + return_std: False + +output_mlp: !new:speechbrain.nnet.linear.Linear + input_size: !ref + n_neurons: 18 + bias: False + +augmentation: 
!new:speechbrain.lobes.augment.TimeDomainSpecAugment + sample_rate: !ref + speeds: [95, 100, 105] + +modules: + enc: !ref + avg_pool: !ref + output_mlp: !ref + weighted_ssl_model: !ref + +model: !new:torch.nn.ModuleList + - [!ref , !ref ] + +tokenizer: !new:sentencepiece.SentencePieceProcessor + +error_stats: !name:speechbrain.utils.metric_stats.MetricStats + metric: !name:speechbrain.nnet.losses.classification_error + reduction: batch + +model_opt_class: !name:torch.optim.Adam + lr: !ref + +weights_opt_class: !name:torch.optim.Adam + lr: !ref + +lr_annealing_model: !new:speechbrain.nnet.schedulers.NewBobScheduler + initial_value: !ref + improvement_threshold: 0.0025 + annealing_factor: 0.8 + patient: 0 + +lr_annealing_weights: !new:speechbrain.nnet.schedulers.NewBobScheduler + initial_value: !ref + improvement_threshold: 0.0025 + annealing_factor: 0.9 + +checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer + checkpoints_dir: !ref + recoverables: + model: !ref + ssl_model: !ref + scheduler_model: !ref + scheduler_encoder: !ref + counter: !ref + +log_softmax: !new:speechbrain.nnet.activations.Softmax + apply_log: True + +seq_cost: !name:speechbrain.nnet.losses.nll_loss + label_smoothing: 0.1 + +train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger + save_file: !ref + +error_rate_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats + +cer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats + split_tokens: True diff --git a/benchmarks/MP3S/SLURP/LSTM_linear/train.py b/benchmarks/MP3S/SLURP/LSTM_linear/train.py new file mode 100644 index 0000000..307f648 --- /dev/null +++ b/benchmarks/MP3S/SLURP/LSTM_linear/train.py @@ -0,0 +1,318 @@ +""" SummaryMixing © 2023 by Samsung Electronics is licensed under CC BY-NC 4.0. + +Recipe for "direct" (speech -> scenario) "Intent" classification using SLURP Dataset. +18 Scenarios classes are present in SLURP (calendar, email) +We encode input waveforms into features using a SSL encoder. +The probing is done using a RNN layer followed by a linear classifier. 
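The probing head described in this docstring can be pictured with a plain-PyTorch sketch: a bidirectional LSTM, a linear projection, average pooling over time, and an 18-way log-softmax classifier. The dimensions follow the hparams above (encoder_dim: 768, dec_neurons: 512, output_neurons: 18); the random tensor stands in for the weighted SSL features, so this is only an illustration and not the recipe's actual SpeechBrain modules:

import torch
import torch.nn as nn

B, T = 2, 50                       # a batch of 2 utterances, 50 SSL frames each
feats = torch.randn(B, T, 768)     # stand-in for the weighted SSL features

lstm = nn.LSTM(768, 512, num_layers=2, bidirectional=True, batch_first=True)
proj = nn.Linear(2 * 512, 512)     # the "linear" block of the enc Sequential
clf = nn.Linear(512, 18)           # output_mlp over the 18 SLURP scenarios

hidden, _ = lstm(feats)
pooled = proj(hidden).mean(dim=1)  # temporal average pooling
log_probs = torch.log_softmax(clf(pooled), dim=-1)
print(log_probs.shape)             # torch.Size([2, 18])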
+""" + + +import os +import sys +from hyperpyyaml import load_hyperpyyaml +import speechbrain as sb +from speechbrain.utils.distributed import run_on_main + + +class IntentIdBrain(sb.Brain): + def compute_forward(self, batch, stage): + """Computation pipeline based on a encoder + emotion classifier.""" + + batch = batch.to(self.device) + wavs, wav_lens = batch.sig + wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device) + + feats = self.modules.weighted_ssl_model(wavs, wav_lens) + + # last dim will be used for AdaptativeAVG pool + outputs = self.modules.enc(feats) + outputs = self.hparams.avg_pool(outputs, wav_lens) + outputs = outputs.view(outputs.shape[0], -1) + outputs = self.modules.output_mlp(outputs) + + outputs = self.hparams.log_softmax(outputs) + return outputs + + def compute_objectives(self, predictions, batch, stage): + """Computes the loss using speaker-id as label.""" + scenario_id, _ = batch.scenario_encoded + scenario_id = scenario_id.squeeze(1) + loss = self.hparams.compute_cost(predictions, scenario_id) + if stage != sb.Stage.TRAIN: + self.error_metrics.append(batch.id, predictions, scenario_id) + + return loss + + def fit_batch(self, batch): + """Trains the parameters given a single batch in input""" + + predictions = self.compute_forward(batch, sb.Stage.TRAIN) + loss = self.compute_objectives(predictions, batch, sb.Stage.TRAIN) + loss.backward() + self.model_optimizer.step() + self.weights_optimizer.step() + + self.model_optimizer.zero_grad() + self.weights_optimizer.zero_grad() + + return loss.detach() + + def on_stage_start(self, stage, epoch=None): + """Gets called at the beginning of each epoch. + Arguments + --------- + stage : sb.Stage + One of sb.Stage.TRAIN, sb.Stage.VALID, or sb.Stage.TEST. + epoch : int + The currently-starting epoch. This is passed + `None` during the test stage. + """ + + # Set up statistics trackers for this stage + self.loss_metric = sb.utils.metric_stats.MetricStats( + metric=sb.nnet.losses.nll_loss + ) + + # Set up evaluation-only statistics trackers + if stage != sb.Stage.TRAIN: + self.error_metrics = self.hparams.error_stats() + + def on_stage_end(self, stage, stage_loss, epoch=None): + """Gets called at the end of an epoch. + Arguments + --------- + stage : sb.Stage + One of sb.Stage.TRAIN, sb.Stage.VALID, sb.Stage.TEST + stage_loss : float + The average loss for all of the data processed in this stage. + epoch : int + The currently-starting epoch. This is passed + `None` during the test stage. + """ + + # Store the train loss until the validation stage. + if stage == sb.Stage.TRAIN: + self.train_loss = stage_loss + + # Summarize the statistics from the stage for record-keeping. + else: + stats = { + "loss": stage_loss, + "error_rate": self.error_metrics.summarize("average"), + } + + # At the end of validation... + if stage == sb.Stage.VALID: + old_lr, new_lr = self.hparams.lr_annealing_model( + stats["error_rate"] + ) + sb.nnet.schedulers.update_learning_rate( + self.model_optimizer, new_lr + ) + + ( + old_lr_encoder, + new_lr_encoder, + ) = self.hparams.lr_annealing_weights(stats["error_rate"]) + sb.nnet.schedulers.update_learning_rate( + self.weights_optimizer, new_lr_encoder + ) + + # The train_logger writes a summary to stdout and to the logfile. 
+ self.hparams.train_logger.log_stats( + {"Epoch": epoch, "lr": old_lr, "wave2vec_lr": old_lr_encoder}, + train_stats={"loss": self.train_loss}, + valid_stats=stats, + ) + + # Save the current checkpoint and delete previous checkpoints, + self.checkpointer.save_and_keep_only( + meta=stats, min_keys=["error_rate"] + ) + + # We also write statistics about test data to stdout and to logfile. + if stage == sb.Stage.TEST: + self.hparams.train_logger.log_stats( + {"Epoch loaded": self.hparams.epoch_counter.current}, + test_stats=stats, + ) + + def init_optimizers(self): + "Initializes the weights optimizer and model optimizer" + self.weights_optimizer = self.hparams.weights_opt_class( + [self.modules.weighted_ssl_model.weights] + ) + self.model_optimizer = self.hparams.model_opt_class( + self.hparams.model.parameters() + ) + + if self.checkpointer is not None: + self.checkpointer.add_recoverable("modelopt", self.model_optimizer) + self.checkpointer.add_recoverable( + "weights_opt", self.weights_optimizer + ) + + +def dataio_prep(hparams): + """This function prepares the datasets to be used in the brain class. + It also defines the data processing pipeline through user-defined + functions. We expect `prepare_mini_librispeech` to have been called before + this, so that the `train.json`, `valid.json`, and `valid.json` manifest + files are available. + Arguments + --------- + hparams : dict + This dictionary is loaded from the `train.yaml` file, and it includes + all the hyperparameters needed for dataset construction and loading. + Returns + ------- + datasets : dict + Contains two keys, "train" and "valid" that correspond + to the appropriate DynamicItemDataset object. + """ + data_folder = hparams["data_folder"] + + train_data = sb.dataio.dataset.DynamicItemDataset.from_csv( + csv_path=hparams["csv_train"], replacements={"data_root": data_folder}, + ) + + if hparams["sorting"] == "ascending": + # we sort training data to speed up training and get better results. + train_data = train_data.filtered_sorted(sort_key="duration") + # when sorting do not shuffle in dataloader ! otherwise is pointless + hparams["dataloader_opts"]["shuffle"] = False + + elif hparams["sorting"] == "descending": + train_data = train_data.filtered_sorted( + sort_key="duration", reverse=True + ) + # when sorting do not shuffle in dataloader ! otherwise is pointless + hparams["dataloader_opts"]["shuffle"] = False + + elif hparams["sorting"] == "random": + pass + + else: + raise NotImplementedError( + "sorting must be random, ascending or descending" + ) + + valid_data = sb.dataio.dataset.DynamicItemDataset.from_csv( + csv_path=hparams["csv_valid"], replacements={"data_root": data_folder}, + ) + valid_data = valid_data.filtered_sorted(sort_key="duration") + + test_data = sb.dataio.dataset.DynamicItemDataset.from_csv( + csv_path=hparams["csv_test"], replacements={"data_root": data_folder}, + ) + test_data = test_data.filtered_sorted(sort_key="duration") + + datasets = [train_data, valid_data, test_data] + + # Define audio pipeline + @sb.utils.data_pipeline.takes("wav") + @sb.utils.data_pipeline.provides("sig") + def audio_pipeline(wav): + """Load the signal, and pass it and its length to the corruption class. + This is done on the CPU in the `collate_fn`.""" + sig = sb.dataio.dataio.read_audio(wav) + return sig + + # Initialization of the label encoder. The label encoder assignes to each + # of the observed label a unique index (e.g, 'spk01': 0, 'spk02': 1, ..) 
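The label pipeline defined just below recovers the scenario name from SLURP's semantics string with a simple quote split. A tiny illustration, assuming a simplified semantics entry of the following shape (the real field also carries the action and entity annotations):

# Hypothetical, simplified SLURP "semantics" value.
semantics = "{'scenario': 'calendar', 'action': 'set_event', 'entities': []}"
scenario = semantics.split("'")[3]
print(scenario)  # calendar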
+ label_encoder = sb.dataio.encoder.CategoricalEncoder() + + # Define label pipeline: + @sb.utils.data_pipeline.takes("semantics") + @sb.utils.data_pipeline.provides("scenario", "scenario_encoded") + def label_pipeline(semantics): + scenario = semantics.split("'")[3] + yield scenario + scenario_encoded = label_encoder.encode_label_torch(scenario) + yield scenario_encoded + + sb.dataio.dataset.add_dynamic_item(datasets, audio_pipeline) + + sb.dataio.dataset.add_dynamic_item(datasets, label_pipeline) + # Define datasets. We also connect the dataset with the data processing + # functions defined above. + sb.dataio.dataset.set_output_keys( + datasets, ["id", "sig", "scenario", "scenario_encoded"], + ) + # Load or compute the label encoder (with multi-GPU DDP support) + # Please, take a look into the lab_enc_file to see the label to index + # mappinng. + + lab_enc_file = os.path.join(hparams["save_folder"], "label_encoder.txt") + label_encoder.load_or_create( + path=lab_enc_file, from_didatasets=[datasets[0]], output_key="scenario", + ) + + return {"train": datasets[0], "valid": datasets[1], "test": datasets[2]} + + +# RECIPE BEGINS! +if __name__ == "__main__": + # Reading command line arguments. + hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:]) + + # Initialize ddp (useful only for multi-GPU DDP training). + sb.utils.distributed.ddp_init_group(run_opts) + + # Load hyperparameters file with command-line overrides. + with open(hparams_file) as fin: + hparams = load_hyperpyyaml(fin, overrides) + + # Create experiment directory + sb.create_experiment_directory( + experiment_directory=hparams["output_folder"], + hyperparams_to_save=hparams_file, + overrides=overrides, + ) + + from prepare import prepare_SLURP # noqa + + # multi-gpu (ddp) save data preparation + run_on_main( + prepare_SLURP, + kwargs={ + "data_folder": hparams["data_folder"], + "save_folder": hparams["output_folder"], + "train_splits": hparams["train_splits"], + "slu_type": "direct", + "skip_prep": hparams["skip_prep"], + }, + ) + + # Data preparation, to be run on only one process. + # Create dataset objects "train", "valid", and "test". + datasets = dataio_prep(hparams) + + # freeze the feature extractor part when unfreezing + + # Initialize the Brain object to prepare for mask training. + ic_id_brain = IntentIdBrain( + modules=hparams["modules"], + hparams=hparams, + run_opts=run_opts, + checkpointer=hparams["checkpointer"], + ) + + # The `fit()` method iterates the training loop, calling the methods + # necessary to update the parameters of the model. Since all objects + # with changing state are managed by the Checkpointer, training can be + # stopped at any point, and will be resumed on next call. + ic_id_brain.fit( + epoch_counter=ic_id_brain.hparams.epoch_counter, + train_set=datasets["train"], + valid_set=datasets["valid"], + train_loader_kwargs=hparams["train_dataloader_opts"], + valid_loader_kwargs=hparams["valid_dataloader_opts"], + ) + + # Load the best checkpoint for evaluation + test_stats = ic_id_brain.evaluate( + test_set=datasets["test"], + min_key="error_rate", + test_loader_kwargs=hparams["test_dataloader_opts"], + ) diff --git a/benchmarks/MP3S/VoxCeleb1/ecapa_tdnn/hparams/ssl.yaml b/benchmarks/MP3S/VoxCeleb1/ecapa_tdnn/hparams/ssl.yaml new file mode 100644 index 0000000..d9456f4 --- /dev/null +++ b/benchmarks/MP3S/VoxCeleb1/ecapa_tdnn/hparams/ssl.yaml @@ -0,0 +1,195 @@ +# ################################ +# SummaryMixing © 2023 by Samsung Electronics is licensed under CC BY-NC 4.0. 
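All hparams files in this patch, including the VoxCeleb configuration that follows, rely on HyperPyYAML: the !PLACEHOLDER entries (data_folder, pretrained_path, verification_file, ...) must be supplied as command-line overrides, which sb.parse_arguments forwards to load_hyperpyyaml together with the rest of the file. A minimal, self-contained illustration of that mechanism; the YAML snippet and the override path are made up:

from hyperpyyaml import load_hyperpyyaml

yaml_string = """
seed: 1986
data_folder: !PLACEHOLDER
output_folder: !ref results/<seed>
"""
hparams = load_hyperpyyaml(yaml_string, overrides={"data_folder": "/data/VoxCeleb1"})
print(hparams["data_folder"], hparams["output_folder"])  # /data/VoxCeleb1 results/1986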
+ +# This is configurations of SummaryMixing wav2vec 2.0 downstream ASV on VoxCeleb + +# Usage: Install SpeechBrain MP3S +# Create a folder benchmarks/MP3S/VoxCeleb1/ecapa_tdnn +# Copy this file under benchmarks/MP3S/VoxCeleb1/ecapa_tdnn/hparams +# SummaryMixing: https://arxiv.org/abs/2307.07421 +# SummaryMixing SSL: + +# Authors +# * Titouan Parcollet 2023, 2024 +# * Shucong Zhang 2023, 2024 +# * Rogier van Dalen 2023, 2024 +# * Sourav Bhattacharya 2023, 2024 +# ################################ + +# Basic parameters +seed: 1986 +__set_seed: !apply:torch.manual_seed [!ref ] +output_folder: !ref results/VoxCeleb1/ +save_folder: !ref /save +train_log: !ref /train_log.txt +# Data files +data_folder: !PLACEHOLDER # e.g. /path/to/Voxceleb +train_annotation: !ref /train.csv +valid_annotation: !ref /dev.csv +num_layers_ssl: 13 +pretrained_path: !PLACEHOLDER # e,g./path/to/pre-trained_SummaryMixing_w2v2 + +verification_file: !PLACEHOLDER #path/to/veri_test2.txt + +skip_prep: False +ckpt_interval_minutes: 15 # save checkpoint every N min +pretrain: True +# Training parameters +number_of_epochs: 15 +batch_size: 8 +lr: 0.001 +lr_final: 0.0001 +mask_length: 10 +mask_prob: 0.65 +lr_weights: 0.01 +sample_rate: 16000 +shuffle: True +random_chunk: True +sentence_len: 3 + +# Feature parameters + +encoder_dim: 768 + +# Number of speakers +out_n_neurons: 1211 #1211 for vox1 # 5994 for vox2, 7205 for vox1+vox2 + +freeze_wav2vec: True +dataloader_options: + batch_size: !ref + shuffle: !ref + drop_last: True + +embedding_model: !new:speechbrain.lobes.models.ECAPA_TDNN.ECAPA_TDNN + input_size: !ref + channels: [512, 512, 512, 512, 1536] + kernel_sizes: [5, 3, 3, 3, 1] + dilations: [1, 2, 3, 4, 1] + groups: [1, 1, 1, 1, 1] + attention_channels: 64 + lin_neurons: 512 + +classifier: !new:speechbrain.lobes.models.ECAPA_TDNN.Classifier + input_size: 512 + out_neurons: !ref + +epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter + limit: !ref + + +encoder: !new:speechbrain.lobes.models.transformer.Conformer.ConformerEncoder + d_model: 768 + num_layers: 12 + nhead: 8 + d_ffn: 3072 + dropout: 0.1 + layerdrop_prob: 0.0 + attention_type: SummaryMixing + local_proj_hid_dim: [768] + local_proj_out_dim: 768 + summary_hid_dim: [768] + mode: SummaryMixing + output_hidden_states: True + +latentextractor_kernels: [3, 3] +latentextractor_strides: [2, 1] +extractor_dim: 512 +embedding_dim: 768 + +CNN: !new:speechbrain.lobes.models.wav2vec.W2VLatentExtractor + kernel_sizes: !ref + strides: !ref + out_channels: [512, 512] + input_dim: 80 + +weighted_ssl_model: !new:speechbrain.lobes.models.wav2vec.WeightedSSLModel + pretrained_path: !ref + num_layers: 13 + latent_encoder: !ref + CNN: !ref + in_dim: 512 + embedding_dim: 768 + dropout_encoder_input: 0.1 + output_hidden_states: True + + +modules: + weighted_ssl_model: !ref + embedding_model: !ref + classifier: !ref + +model: !new:torch.nn.ModuleList + - [!ref , !ref ] + + +# Cost + optimization +compute_error: !name:speechbrain.nnet.losses.classification_error + +compute_cost: !new:speechbrain.nnet.losses.LogSoftmaxWrapper + loss_fn: !new:speechbrain.nnet.losses.AdditiveAngularMargin + margin: 0.2 + scale: 30 + + +model_opt_class: !name:torch.optim.Adam + lr: !ref + weight_decay: 0.000002 + +lr_annealing: !new:speechbrain.nnet.schedulers.LinearScheduler + initial_value: !ref + final_value: !ref + epoch_count: !ref + +weights_opt_class: !name:torch.optim.Adam + lr: !ref + +lr_annealing_weights: !new:speechbrain.nnet.schedulers.NewBobScheduler + initial_value: !ref + 
improvement_threshold: 0.0025 + annealing_factor: 0.9 + patient: 0 + +# Used to load the SB checkpoint of w2v2 +pretrainer: !new:speechbrain.utils.parameter_transfer.Pretrainer + collect_in: !ref + + +# Logging + checkpoints +train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger + save_file: !ref + +error_stats: !name:speechbrain.utils.metric_stats.MetricStats + metric: !name:speechbrain.nnet.losses.classification_error + reduction: batch + +checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer + checkpoints_dir: !ref + recoverables: + embedding_model: !ref + scheduler_wav2vec: !ref + ssl_model: !ref + classifier: !ref + counter: !ref + +train_data: !ref /train.csv +enrol_data: !ref /enrol.csv +test_data: !ref /test.csv + +verif_batch_size: 2 +n_train_snts: 300000 # used for normalization stats + + +# Dataloader options +train_dataloader_opts: + batch_size: !ref + +enrol_dataloader_opts: + batch_size: !ref + +test_dataloader_opts: + batch_size: !ref + + +mean_var_norm_emb: !new:speechbrain.processing.features.InputNormalization + norm_type: global + std_norm: False diff --git a/benchmarks/MP3S/VoxCeleb1/ecapa_tdnn/train.py b/benchmarks/MP3S/VoxCeleb1/ecapa_tdnn/train.py new file mode 100644 index 0000000..5ba9a03 --- /dev/null +++ b/benchmarks/MP3S/VoxCeleb1/ecapa_tdnn/train.py @@ -0,0 +1,468 @@ +""" SummaryMixing © 2023 by Samsung Electronics is licensed under CC BY-NC 4.0. + +Recipe for training then testing speaker embeddings using the VoxCeleb Dataset. +The embeddings are learned using the ECAPA-TDNN architecture +""" + +import os +from tqdm import tqdm +import sys +import logging +import random +import torch +import torchaudio +import speechbrain as sb +from speechbrain.utils.data_utils import download_file +from hyperpyyaml import load_hyperpyyaml +from speechbrain.utils.metric_stats import EER, minDCF +from speechbrain.utils.distributed import run_on_main + + +def compute_embedding(wavs, wav_lens): + """Compute speaker embeddings. + + Arguments + --------- + wavs : Torch.Tensor + Tensor containing the speech waveform (batch, time). + Make sure the sample rate is fs=16000 Hz. + wav_lens: Torch.Tensor + Tensor containing the relative length for each sentence + in the length (e.g., [0.8 0.6 1.0]) + """ + with torch.no_grad(): + wavs, wav_lens = ( + wavs.to(speaker_brain.device), + wav_lens.to(speaker_brain.device), + ) + feats = speaker_brain.modules.weighted_ssl_model(wavs, wav_lens) + embeddings = speaker_brain.modules.embedding_model(feats, wav_lens) + return embeddings.squeeze(1) + + +def compute_embedding_loop(data_loader): + """Computes the embeddings of all the waveforms specified in the + dataloader. + """ + embedding_dict = {} + + with torch.no_grad(): + for batch in tqdm(data_loader, dynamic_ncols=True): + batch = batch.to(hparams["device"]) + seg_ids = batch.id + wavs, lens = batch.sig + + found = False + for seg_id in seg_ids: + if seg_id not in embedding_dict: + found = True + if not found: + continue + wavs, lens = wavs.to(hparams["device"]), lens.to(hparams["device"]) + emb = compute_embedding(wavs, lens).unsqueeze(1) + for i, seg_id in enumerate(seg_ids): + embedding_dict[seg_id] = emb[i].detach().clone() + return embedding_dict + + +def get_verification_scores(veri_test): + """ Computes positive and negative scores given the verification split. 
+ """ + scores = [] + positive_scores = [] + negative_scores = [] + + save_file = os.path.join(hparams["output_folder"], "scores.txt") + s_file = open(save_file, "w") + + # Cosine similarity initialization + similarity = torch.nn.CosineSimilarity(dim=-1, eps=1e-6) + + # creating cohort for score normalization + if "score_norm" in hparams: + train_cohort = torch.stack(list(train_dict.values())) + + for i, line in enumerate(veri_test): + + # Reading verification file (enrol_file test_file label) + lab_pair = int(line.split(" ")[0].rstrip().split(".")[0].strip()) + enrol_id = line.split(" ")[1].rstrip().split(".")[0].strip() + test_id = line.split(" ")[2].rstrip().split(".")[0].strip() + enrol = enrol_dict[enrol_id] + test = test_dict[test_id] + + if "score_norm" in hparams: + # Getting norm stats for enrol impostors + enrol_rep = enrol.repeat(train_cohort.shape[0], 1, 1) + score_e_c = similarity(enrol_rep, train_cohort) + + if "cohort_size" in hparams: + score_e_c = torch.topk( + score_e_c, k=hparams["cohort_size"], dim=0 + )[0] + + mean_e_c = torch.mean(score_e_c, dim=0) + std_e_c = torch.std(score_e_c, dim=0) + + # Getting norm stats for test impostors + test_rep = test.repeat(train_cohort.shape[0], 1, 1) + score_t_c = similarity(test_rep, train_cohort) + + if "cohort_size" in hparams: + score_t_c = torch.topk( + score_t_c, k=hparams["cohort_size"], dim=0 + )[0] + + mean_t_c = torch.mean(score_t_c, dim=0) + std_t_c = torch.std(score_t_c, dim=0) + + # Compute the score for the given sentence + score = similarity(enrol, test)[0] + + # Perform score normalization + if "score_norm" in hparams: + if hparams["score_norm"] == "z-norm": + score = (score - mean_e_c) / std_e_c + elif hparams["score_norm"] == "t-norm": + score = (score - mean_t_c) / std_t_c + elif hparams["score_norm"] == "s-norm": + score_e = (score - mean_e_c) / std_e_c + score_t = (score - mean_t_c) / std_t_c + score = 0.5 * (score_e + score_t) + + # write score file + s_file.write("%s %s %i %f\n" % (enrol_id, test_id, lab_pair, score)) + scores.append(score) + + if lab_pair == 1: + positive_scores.append(score) + else: + negative_scores.append(score) + + s_file.close() + return positive_scores, negative_scores + + +def dataio_prep_verif(params): + "Creates the dataloaders and their data processing pipelines." + + data_folder = params["data_folder"] + + # 1. Declarations: + + # Train data (used for normalization) + train_data = sb.dataio.dataset.DynamicItemDataset.from_csv( + csv_path=params["train_data"], replacements={"data_root": data_folder}, + ) + train_data = train_data.filtered_sorted( + sort_key="duration", select_n=params["n_train_snts"] + ) + + # Enrol data + enrol_data = sb.dataio.dataset.DynamicItemDataset.from_csv( + csv_path=params["enrol_data"], replacements={"data_root": data_folder}, + ) + enrol_data = enrol_data.filtered_sorted(sort_key="duration") + + # Test data + test_data = sb.dataio.dataset.DynamicItemDataset.from_csv( + csv_path=params["test_data"], replacements={"data_root": data_folder}, + ) + test_data = test_data.filtered_sorted(sort_key="duration") + + datasets = [train_data, enrol_data, test_data] + + # 2. 
Define audio pipeline: + @sb.utils.data_pipeline.takes("wav", "start", "stop") + @sb.utils.data_pipeline.provides("sig") + def audio_pipeline(wav, start, stop): + start = int(start) + stop = int(stop) + num_frames = stop - start + sig, fs = torchaudio.load( + wav, num_frames=num_frames, frame_offset=start + ) + sig = sig.transpose(0, 1).squeeze(1) + return sig + + sb.dataio.dataset.add_dynamic_item(datasets, audio_pipeline) + + # 3. Set output: + sb.dataio.dataset.set_output_keys(datasets, ["id", "sig"]) + + # 4 Create dataloaders + train_dataloader = sb.dataio.dataloader.make_dataloader( + train_data, **params["train_dataloader_opts"] + ) + enrol_dataloader = sb.dataio.dataloader.make_dataloader( + enrol_data, **params["enrol_dataloader_opts"] + ) + test_dataloader = sb.dataio.dataloader.make_dataloader( + test_data, **params["test_dataloader_opts"] + ) + + return train_dataloader, enrol_dataloader, test_dataloader + + +class SpeakerBrain(sb.core.Brain): + """Class for speaker embedding training" + """ + + def compute_forward(self, batch, stage): + """Computation pipeline based on a encoder + speaker classifier. + Data augmentation and environmental corruption are applied to the + input speech. + """ + batch = batch.to(self.device) + wavs, lens = batch.sig + feats = self.modules.weighted_ssl_model(wavs, lens) + # Embeddings + speaker classifier + embeddings = self.modules.embedding_model(feats) + outputs = self.modules.classifier(embeddings) + return outputs, lens + + def compute_objectives(self, predictions, batch, stage): + """Computes the loss using speaker-id as label. + """ + predictions, lens = predictions + uttid = batch.id + spkid, _ = batch.spk_id_encoded + + loss = self.hparams.compute_cost(predictions, spkid, lens) + + if stage == sb.Stage.TRAIN and hasattr( + self.hparams.lr_annealing, "on_batch_end" + ): + self.hparams.lr_annealing.on_batch_end(self.model_optimizer) + + if stage != sb.Stage.TRAIN: + self.error_metrics.append(uttid, predictions, spkid, lens) + + return loss + + def fit_batch(self, batch): + """Train the parameters given a single batch in input""" + predictions = self.compute_forward(batch, sb.Stage.TRAIN) + loss = self.compute_objectives(predictions, batch, sb.Stage.TRAIN) + loss.backward() + self.model_optimizer.step() + self.weights_optimizer.step() + + self.model_optimizer.zero_grad() + self.weights_optimizer.zero_grad() + return loss.detach() + + def on_stage_start(self, stage, epoch=None): + """Gets called at the beginning of an epoch.""" + if stage != sb.Stage.TRAIN: + self.error_metrics = self.hparams.error_stats() + + def on_stage_end(self, stage, stage_loss, epoch=None): + """Gets called at the end of an epoch.""" + # Compute/store important stats + stage_stats = {"loss": stage_loss} + if stage == sb.Stage.TRAIN: + self.train_stats = stage_stats + else: + stage_stats["ErrorRate"] = self.error_metrics.summarize("average") + + # Perform end-of-iteration things, like annealing, logging, etc. 
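For clarity, the score-normalisation branch of get_verification_scores defined earlier in this file boils down to the following cosine-scoring sketch; the 512-dimensional random embeddings and the cohort of 300 vectors are stand-ins (the recipe scores lin_neurons: 512 ECAPA embeddings against up to n_train_snts training embeddings, optionally truncated with cohort_size):

import torch

torch.manual_seed(0)
similarity = torch.nn.CosineSimilarity(dim=-1, eps=1e-6)

enrol = torch.randn(1, 512)        # enrolment embedding (stand-in)
test = torch.randn(1, 512)         # test embedding (stand-in)
cohort = torch.randn(300, 1, 512)  # imposter cohort built from the train split

raw = similarity(enrol, test)[0]

# s-norm: normalise the raw score with enrol-vs-cohort and test-vs-cohort stats.
score_e_c = similarity(enrol.repeat(300, 1, 1), cohort)
score_t_c = similarity(test.repeat(300, 1, 1), cohort)
score_e = (raw - score_e_c.mean(0)) / score_e_c.std(0)
score_t = (raw - score_t_c.mean(0)) / score_t_c.std(0)
print(float(raw), float(0.5 * (score_e + score_t)))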
+ if stage == sb.Stage.VALID: + old_lr, new_lr = self.hparams.lr_annealing(epoch) + sb.nnet.schedulers.update_learning_rate( + self.model_optimizer, new_lr + ) + + self.hparams.train_logger.log_stats( + stats_meta={"epoch": epoch, "lr": old_lr}, + train_stats=self.train_stats, + valid_stats=stage_stats, + ) + self.checkpointer.save_and_keep_only( + meta={"ErrorRate": stage_stats["ErrorRate"]}, + min_keys=["ErrorRate"], + ) + + def init_optimizers(self): + "Initializes the weights optimizer and model optimizer" + self.weights_optimizer = self.hparams.weights_opt_class( + [self.modules.weighted_ssl_model.weights] + ) + self.model_optimizer = self.hparams.model_opt_class( + self.hparams.model.parameters() + ) + # Initializing the weights + if self.checkpointer is not None: + self.checkpointer.add_recoverable("modelopt", self.model_optimizer) + self.checkpointer.add_recoverable( + "weights_opt", self.weights_optimizer + ) + + +def dataio_prep(hparams): + "Creates the datasets and their data processing pipelines." + + data_folder = hparams["data_folder"] + + # 1. Declarations: + train_data = sb.dataio.dataset.DynamicItemDataset.from_csv( + csv_path=hparams["train_annotation"], + replacements={"data_root": data_folder}, + ) + + valid_data = sb.dataio.dataset.DynamicItemDataset.from_csv( + csv_path=hparams["valid_annotation"], + replacements={"data_root": data_folder}, + ) + + datasets = [train_data, valid_data] + label_encoder = sb.dataio.encoder.CategoricalEncoder() + + snt_len_sample = int(hparams["sample_rate"] * hparams["sentence_len"]) + + # 2. Define audio pipeline: + @sb.utils.data_pipeline.takes("wav", "start", "stop", "duration") + @sb.utils.data_pipeline.provides("sig") + def audio_pipeline(wav, start, stop, duration): + if hparams["random_chunk"]: + duration_sample = int(duration * hparams["sample_rate"]) + start = random.randint(0, duration_sample - snt_len_sample) + stop = start + snt_len_sample + else: + start = int(start) + stop = int(stop) + num_frames = stop - start + sig, fs = torchaudio.load( + wav, num_frames=num_frames, frame_offset=start + ) + sig = sig.transpose(0, 1).squeeze(1) + return sig + + sb.dataio.dataset.add_dynamic_item(datasets, audio_pipeline) + + # 3. Define text pipeline: + @sb.utils.data_pipeline.takes("spk_id") + @sb.utils.data_pipeline.provides("spk_id", "spk_id_encoded") + def label_pipeline(spk_id): + yield spk_id + spk_id_encoded = label_encoder.encode_sequence_torch([spk_id]) + yield spk_id_encoded + + sb.dataio.dataset.add_dynamic_item(datasets, label_pipeline) + + # 3. Fit encoder: + # Load or compute the label encoder (with multi-GPU DDP support) + lab_enc_file = os.path.join(hparams["save_folder"], "label_encoder.txt") + label_encoder.load_or_create( + path=lab_enc_file, from_didatasets=[train_data], output_key="spk_id", + ) + + # 4. 
Set output: + sb.dataio.dataset.set_output_keys(datasets, ["id", "sig", "spk_id_encoded"]) + + return train_data, valid_data, label_encoder + + +if __name__ == "__main__": + + logger = logging.getLogger(__name__) + # This flag enables the inbuilt cudnn auto-tuner + torch.backends.cudnn.benchmark = True + + # CLI: + hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:]) + + # Initialize ddp (useful only for multi-GPU DDP training) + sb.utils.distributed.ddp_init_group(run_opts) + + # Load hyperparameters file with command-line overrides + with open(hparams_file) as fin: + hparams = load_hyperpyyaml(fin, overrides) + + # Download verification list (to exlude verification sentences from train) + veri_file_path = os.path.join( + hparams["save_folder"], os.path.basename(hparams["verification_file"]) + ) + download_file(hparams["verification_file"], veri_file_path) + + # Dataset prep (parsing VoxCeleb and annotation into csv files) + from voxceleb_prepare import prepare_voxceleb # noqa + + prepare_voxceleb( + data_folder=hparams["data_folder"], + save_folder=hparams["save_folder"], + verification_pairs_file=veri_file_path, + splits=["train", "dev", "test"], + split_ratio=[90, 10], + seg_dur=hparams["sentence_len"], + source=hparams["voxceleb_source"] + if "voxceleb_source" in hparams + else None, + ) + + # Loading wav2vec2.0 + if not hparams["pretrain"]: + run_on_main(hparams["pretrainer"].collect_files) + hparams["pretrainer"].load_collected() + + # Dataset IO prep: creating Dataset objects and proper encodings for phones + train_data, valid_data, label_encoder = dataio_prep(hparams) + + # Create experiment directory + sb.core.create_experiment_directory( + experiment_directory=hparams["output_folder"], + hyperparams_to_save=hparams_file, + overrides=overrides, + ) + + # Brain class initialization + speaker_brain = SpeakerBrain( + modules=hparams["modules"], + hparams=hparams, + run_opts=run_opts, + checkpointer=hparams["checkpointer"], + ) + + # Training + speaker_brain.fit( + speaker_brain.hparams.epoch_counter, + train_data, + valid_data, + train_loader_kwargs=hparams["dataloader_options"], + valid_loader_kwargs=hparams["dataloader_options"], + ) + + # Now preparing for test : + hparams["device"] = speaker_brain.device + + speaker_brain.modules.eval() + train_dataloader, enrol_dataloader, test_dataloader = dataio_prep_verif( + hparams + ) + # Computing enrollment and test embeddings + logger.info("Computing enroll/test embeddings...") + + # First run + enrol_dict = compute_embedding_loop(enrol_dataloader) + test_dict = compute_embedding_loop(test_dataloader) + + if "score_norm" in hparams: + train_dict = compute_embedding_loop(train_dataloader) + + # Compute the EER + logger.info("Computing EER..") + # Reading standard verification split + with open(veri_file_path) as f: + veri_test = [line.rstrip() for line in f] + + positive_scores, negative_scores = get_verification_scores(veri_test) + del enrol_dict, test_dict + + eer, th = EER(torch.tensor(positive_scores), torch.tensor(negative_scores)) + logger.info("EER(%%)=%f", eer * 100) + + min_dcf, th = minDCF( + torch.tensor(positive_scores), torch.tensor(negative_scores) + ) + # Testing + logger.info("minDCF=%f", min_dcf * 100) diff --git a/recipes/AISHELL-1/ASR/transformer/hparams/branchformer_summarymixing.yaml b/recipes/AISHELL-1/ASR/transformer/hparams/branchformer_summarymixing.yaml deleted file mode 100644 index b483ccc..0000000 --- a/recipes/AISHELL-1/ASR/transformer/hparams/branchformer_summarymixing.yaml +++ /dev/null @@ 
-1,306 +0,0 @@ -# ############################################################################ -# -# SummaryMixing © 2023 by Samsung Electronics is licensed under CC BY-NC 4.0. -# -# Model: E2E ASR with Transformer -# Encoder: branchformer Encoder with SummaryMixing -# Decoder: Transformer Decoder + (CTC/ATT joint) beamsearch -# Tokens: BPE with unigram -# losses: CTC + KLdiv (Label Smoothing loss) -# Training: AISHELL-1 -# Authors: Titouan Parcollet, Shucong Zhang, Rogier van Dalen, and Sourav -# Bhattacharya -# ############################################################################ -# Seed needs to be set at top of yaml, before objects with parameters are made -seed: 3407 -__set_seed: !apply:torch.manual_seed [!ref ] -output_folder: !ref results/branchformer_summarymixing_large/ -cer_file: !ref /cer.txt -save_folder: !ref /save -train_log: !ref /train_log.txt - -# Data files -data_folder: !PLACEHOLDER # e,g./path/to/aishell -data_folder_noise: !ref /noise # The noisy sequencies for data augmentation will automatically be downloaded here. -remove_compressed_wavs: False -skip_prep: False -ckpt_interval_minutes: 15 # save checkpoint every N min -train_data: !ref /train.csv -valid_data: !ref /dev.csv -test_data: !ref /test.csv -noise_annotation: !ref /noise.csv #The data manifest files are created by the data preparation script -tokenizer_file: speechbrain/asr-transformer-aishell/tokenizer.ckpt - -####################### Training Parameters #################################### - -number_of_epochs: 200 -batch_size: 16 # For SummaryMixing, we don't use static batching -ctc_weight: 0.3 -gradient_accumulation: 1 -loss_reduction: 'batchmean' -sorting: random -avg_checkpoints: 10 # Number of checkpoints to average for evaluation -precision: fp32 # bf16, fp16 or fp32 - -dynamic_batching: True -max_batch_length: 500 # in terms of "duration" in annotations by default, second here -shuffle: False # if true re-creates batches at each epoch shuffling examples. 
-num_buckets: 200 # floor(log(max_batch_len/left_bucket_len, multiplier)) + 1 -batch_ordering: ascending -dynamic_batch_sampler: - max_batch_length: !ref - shuffle: !ref - num_buckets: !ref - batch_ordering: !ref - -num_workers: 6 - -# stages related parameters -stage_one_epochs: 150 -lr_adam: 0.0008 -lr_sgd: 0.000025 - -# Feature parameters -sample_rate: 16000 -n_fft: 400 -n_mels: 80 - -# Dataloader options -train_dataloader_opts: - batch_size: !ref - shuffle: True - -valid_dataloader_opts: - batch_size: !ref - -test_dataloader_opts: - batch_size: !ref - -####################### Model Parameters ####################################### - -# Transformer -attention_type: SummaryMixing # SummaryMixing, regularMHA or RelPosMHAXL -mode: SummaryMixing # SummaryMixing or SummaryMixing-lite -d_model: 512 -nhead: 1 # 1 is faster but 4 gives very slightly better performance (WER) -num_encoder_layers: 18 -num_decoder_layers: 6 -decoder_linear_units: 2048 -csgu_linear_units: 3072 -csgu_kernel_size: 31 -local_proj_hid_dim: [512] -local_proj_out_dim: !ref -summary_hid_dim: [512] -summary_out_dim: !ref - -transformer_dropout: 0.1 -activation: !name:torch.nn.GELU -output_neurons: 5000 - -# Outputs -blank_index: 0 -label_smoothing: 0.1 -pad_index: 0 -bos_index: 1 -eos_index: 2 - -# Decoding parameters -min_decode_ratio: 0.0 -max_decode_ratio: 1.0 # 1.0 -valid_search_interval: 10 -valid_beam_size: 10 -test_beam_size: 10 -ctc_weight_decode: 0.40 - -############################## Models ########################################## - -CNN: !new:speechbrain.lobes.models.convolution.ConvolutionFrontEnd - input_shape: (8, 10, 80) - num_blocks: 2 - num_layers_per_block: 1 - out_channels: (64, 32) - kernel_sizes: (3, 3) - strides: (2, 2) - residuals: (False, False) - -Transformer: !new:speechbrain.lobes.models.transformer.TransformerASR.TransformerASR # yamllint disable-line rule:line-length - input_size: 640 - tgt_vocab: !ref - d_model: !ref - nhead: !ref - num_encoder_layers: !ref - num_decoder_layers: !ref - dropout: !ref - activation: !ref - branchformer_activation: !ref - encoder_module: branchformer - csgu_linear_units: !ref - kernel_size: !ref - attention_type: !ref - local_proj_hid_dim: !ref - local_proj_out_dim: !ref - summary_hid_dim: !ref - summary_out_dim: !ref - mode: !ref - normalize_before: True - causal: False - -tokenizer: !new:sentencepiece.SentencePieceProcessor - -ctc_lin: !new:speechbrain.nnet.linear.Linear - input_size: !ref - n_neurons: !ref - -seq_lin: !new:speechbrain.nnet.linear.Linear - input_size: !ref - n_neurons: !ref - -modules: - CNN: !ref - Transformer: !ref - seq_lin: !ref - ctc_lin: !ref - -model: !new:torch.nn.ModuleList - - [!ref , !ref , !ref , !ref ] - -# define two optimizers here for two-stage training -Adam: !name:torch.optim.AdamW - lr: 0 - betas: (0.9, 0.98) - eps: 0.000000001 - -SGD: !name:torch.optim.SGD - lr: !ref - momentum: 0.99 - nesterov: True - -############################## Decoding & optimiser ############################ - -# Scorer -ctc_scorer: !new:speechbrain.decoders.scorer.CTCScorer - eos_index: !ref - blank_index: !ref - ctc_fc: !ref - -scorer: !new:speechbrain.decoders.scorer.ScorerBuilder - full_scorers: [!ref ] - weights: - ctc: !ref - -valid_search: !new:speechbrain.decoders.S2STransformerBeamSearcher - modules: [!ref , !ref ] - bos_index: !ref - eos_index: !ref - min_decode_ratio: !ref - max_decode_ratio: !ref - beam_size: !ref - using_eos_threshold: False - length_normalization: True - scorer: !ref - -test_search: 
!new:speechbrain.decoders.S2STransformerBeamSearcher - modules: [!ref , !ref ] - bos_index: !ref - eos_index: !ref - min_decode_ratio: !ref - max_decode_ratio: !ref - beam_size: !ref - using_eos_threshold: False - length_normalization: True - scorer: !ref - -log_softmax: !new:torch.nn.LogSoftmax - dim: -1 - -ctc_cost: !name:speechbrain.nnet.losses.ctc_loss - blank_index: !ref - reduction: !ref - -seq_cost: !name:speechbrain.nnet.losses.kldiv_loss - label_smoothing: !ref - reduction: !ref - -noam_annealing: !new:speechbrain.nnet.schedulers.NoamScheduler - lr_initial: !ref - n_warmup_steps: 25000 - -checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer - checkpoints_dir: !ref - recoverables: - model: !ref - noam_scheduler: !ref - normalizer: !ref - counter: !ref - -epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter - limit: !ref - -normalize: !new:speechbrain.processing.features.InputNormalization - norm_type: global - update_until_epoch: 4 - -compute_features: !new:speechbrain.lobes.features.Fbank - sample_rate: !ref - n_fft: !ref - n_mels: !ref - - -############################## Augmentation #################################### - -# Time Drop -time_drop: !new:speechbrain.augment.freq_domain.SpectrogramDrop - drop_length_low: 0 - drop_length_high: 100 - drop_count_low: 2 - drop_count_high: 2 - replace: "zeros" - dim: 1 - -# Frequency Drop -freq_drop: !new:speechbrain.augment.freq_domain.SpectrogramDrop - drop_length_low: 30 - drop_length_high: 40 - drop_count_low: 2 - drop_count_high: 2 - replace: "zeros" - dim: 2 - -# Time warp -time_warp: !new:speechbrain.augment.freq_domain.Warping - dim: 1 - -fea_augment: !new:speechbrain.augment.augmenter.Augmenter - concat_original: True - repeat_augment: 1 - shuffle_augmentations: False - min_augmentations: 1 - max_augmentations: 1 - augment_start_index: !ref # This leaves unchanges original inputs - concat_end_index: !ref # This leaves unchanges original inputs - augment_prob: 1.0 - augmentations: [ - !ref , - !ref , - !ref ] - -train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger - save_file: !ref - -# AISHELL-1 has spaces between words in the transcripts, -# which Chinese writing normally does not do. -# If remove_spaces, spaces are removed -# from the transcript before computing CER. -remove_spaces: True -split_tokens: !apply:operator.not_ [!ref ] - -cer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats - split_tokens: !ref -acc_computer: !name:speechbrain.utils.Accuracy.AccuracyStats - -pretrainer: !new:speechbrain.utils.parameter_transfer.Pretrainer - collect_in: !ref - loadables: - tokenizer: !ref - paths: - tokenizer: !ref diff --git a/recipes/CommonVoice/ASR/transformer/hparams/branchformer_summarymixing.yaml b/recipes/CommonVoice/ASR/transformer/hparams/branchformer_summarymixing.yaml deleted file mode 100644 index bea1f58..0000000 --- a/recipes/CommonVoice/ASR/transformer/hparams/branchformer_summarymixing.yaml +++ /dev/null @@ -1,277 +0,0 @@ -# ############################################################################ -# -# SummaryMixing © 2023 by Samsung Electronics is licensed under CC BY-NC 4.0. 
-# -# Model: E2E ASR with Transformer -# Encoder: Transformer Encoder -# Decoder: Transformer Decoder + (CTC/ATT joint) beamsearch -# Tokens: unigram -# losses: CTC + KLdiv (Label Smoothing loss) -# Authors: Titouan Parcollet, Shucong Zhang, Rogier van Dalen, and Sourav -# Bhattacharya -# ############################################################################ -# Seed needs to be set at top of yaml, before objects with parameters are made -seed: 1234 -__set_seed: !apply:torch.manual_seed [!ref ] -output_folder: !ref results/transformer_fr/ -test_wer_file: !ref /wer_test.txt -valid_wer_file: !ref /wer_valid.txt -save_folder: !ref /save -train_log: !ref /train_log.txt - -# Data files -data_folder: !PLACEHOLDER # e.g, /localscratch/cv-corpus-5.1-2020-06-22/fr -train_tsv_file: !ref /train.tsv # Standard CommonVoice .tsv files -dev_tsv_file: !ref /dev.tsv # Standard CommonVoice .tsv files -test_tsv_file: !ref /test.tsv # Standard CommonVoice .tsv files -accented_letters: True -language: fr # use 'it' for Italian, 'rw' for Kinyarwanda, 'en' for english -train_csv: !ref /train.csv -valid_csv: !ref /dev.csv -test_csv: !ref /test.csv -skip_prep: False # Skip data preparation - -# We remove utterance slonger than 10s in the train/dev/test sets as -# longer sentences certainly correspond to "open microphones". -avoid_if_longer_than: 10.0 - -ckpt_interval_minutes: 15 # save checkpoint every N min - -####################### Training Parameters #################################### -number_of_epochs: 120 # We used 100 for italian and Dutch -batch_size: 8 # Not used when dynamic batching is activated -ctc_weight: 0.3 -grad_accumulation_factor: 2 -loss_reduction: 'batchmean' -sorting: random -precision: fp32 # bf16, fp16 or fp32 - -# stages related parameters -stage_one_epochs: 40 -lr_adam: 0.0008 # We used 0.0005 for italian and 0.0008 for Dutch -lr_sgd: 0.0001 - -# BPE parameters -token_type: unigram # ["unigram", "bpe", "char"] -character_coverage: 1.0 - -# Feature parameters -sample_rate: 16000 -n_fft: 400 -n_mels: 80 - -# Dataloader options -train_dataloader_opts: - batch_size: !ref - shuffle: True - num_workers: 6 - -valid_dataloader_opts: - batch_size: !ref - num_workers: 6 - -test_dataloader_opts: - batch_size: !ref - num_workers: 6 - -####################### Model Parameters ########################### -# Transformer -attention_type: SummaryMixing # SummaryMixing, regularMHA or RelPosMHAXL -mode: SummaryMixing # SummaryMixing or SummaryMixing-lite -d_model: 512 -nhead: 1 # 1 is faster but 4 gives very slightly better performance (WER) -num_encoder_layers: 18 -num_decoder_layers: 6 -decoder_linear_units: 2048 -csgu_linear_units: 3072 -csgu_kernel_size: 31 -local_proj_hid_dim: [512] -local_proj_out_dim: !ref -summary_hid_dim: [512] -summary_out_dim: !ref -transformer_dropout: 0.1 -activation: !name:torch.nn.GELU -output_neurons: 1000 # We used 350 for Dutch and 1000 for italian - -# Outputs -blank_index: 0 -label_smoothing: 0.0 -pad_index: 0 -bos_index: 1 -eos_index: 2 - -# Decoding parameters -min_decode_ratio: 0.0 -max_decode_ratio: 1.0 -valid_search_interval: 10 -valid_beam_size: 10 -test_beam_size: 80 -ctc_weight_decode: 0.3 -scorer_beam_scale: 0.3 - -############################## models ################################ - -CNN: !new:speechbrain.lobes.models.convolution.ConvolutionFrontEnd - input_shape: (8, 10, 80) - num_blocks: 2 - num_layers_per_block: 1 - out_channels: (64, 32) - kernel_sizes: (3, 3) - strides: (2, 2) - residuals: (False, False) - -Transformer: 
!new:speechbrain.lobes.models.transformer.TransformerASR.TransformerASR # yamllint disable-line rule:line-length - input_size: 640 - tgt_vocab: !ref - d_model: !ref - nhead: !ref - num_encoder_layers: !ref - num_decoder_layers: !ref - dropout: !ref - activation: !ref - branchformer_activation: !ref - encoder_module: branchformer - csgu_linear_units: !ref - kernel_size: !ref - attention_type: !ref - local_proj_hid_dim: !ref - local_proj_out_dim: !ref - summary_hid_dim: !ref - summary_out_dim: !ref - mode: !ref - normalize_before: True - causal: False - -ctc_lin: !new:speechbrain.nnet.linear.Linear - input_size: !ref - n_neurons: !ref - -seq_lin: !new:speechbrain.nnet.linear.Linear - input_size: !ref - n_neurons: !ref - -modules: - CNN: !ref - Transformer: !ref - seq_lin: !ref - ctc_lin: !ref - -model: !new:torch.nn.ModuleList - - [!ref , !ref , !ref , !ref ] - -# We define two optimizers as we have two stages (training + finetuning) -Adam: !name:torch.optim.AdamW - lr: !ref - betas: (0.9, 0.98) - eps: 0.000000001 - -SGD: !name:torch.optim.SGD - lr: !ref - momentum: 0.99 - nesterov: True - -# Scorer -ctc_scorer: !new:speechbrain.decoders.scorer.CTCScorer - eos_index: !ref - blank_index: !ref - ctc_fc: !ref - -scorer: !new:speechbrain.decoders.scorer.ScorerBuilder - full_scorers: [!ref ] - weights: - ctc: !ref - scorer_beam_scale: !ref - -valid_search: !new:speechbrain.decoders.S2STransformerBeamSearcher - modules: [!ref , !ref ] - bos_index: !ref - eos_index: !ref - min_decode_ratio: !ref - max_decode_ratio: !ref - beam_size: !ref - using_eos_threshold: False - length_normalization: True - scorer: !ref - -test_search: !new:speechbrain.decoders.S2STransformerBeamSearcher - modules: [!ref , !ref ] - bos_index: !ref - eos_index: !ref - min_decode_ratio: !ref - max_decode_ratio: !ref - beam_size: !ref - temperature: 1.15 - using_eos_threshold: True - scorer: !ref - -log_softmax: !new:torch.nn.LogSoftmax - dim: -1 - -ctc_cost: !name:speechbrain.nnet.losses.ctc_loss - blank_index: !ref - reduction: !ref - -seq_cost: !name:speechbrain.nnet.losses.kldiv_loss - label_smoothing: !ref - reduction: !ref - -noam_annealing: !new:speechbrain.nnet.schedulers.NoamScheduler - lr_initial: !ref - n_warmup_steps: 25000 - -checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer - checkpoints_dir: !ref - recoverables: - model: !ref - noam_scheduler: !ref - normalizer: !ref - counter: !ref - -epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter - limit: !ref - -normalize: !new:speechbrain.processing.features.InputNormalization - norm_type: global - update_until_epoch: 3 - -############################## Augmentations ################################### - -# Time Drop -time_drop: !new:speechbrain.augment.freq_domain.SpectrogramDrop - drop_length_low: 15 - drop_length_high: 25 - drop_count_low: 5 - drop_count_high: 5 - -# Frequency Drop -freq_drop: !new:speechbrain.augment.freq_domain.SpectrogramDrop - drop_length_low: 25 - drop_length_high: 35 - drop_count_low: 2 - drop_count_high: 2 - dim: 2 - -# Time warp -time_warp: !new:speechbrain.augment.freq_domain.Warping - -fea_augment: !new:speechbrain.augment.augmenter.Augmenter - min_augmentations: 3 - max_augmentations: 3 - augment_prob: 1.0 - augmentations: [ - !ref , - !ref , - !ref ] - -compute_features: !new:speechbrain.lobes.features.Fbank - sample_rate: !ref - n_fft: !ref - n_mels: !ref - -train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger - save_file: !ref - -error_rate_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats 
-acc_computer: !name:speechbrain.utils.Accuracy.AccuracyStats -cer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats - split_tokens: True diff --git a/recipes/Libri-Light/self-supervised-learning/wav2vec2/hparams/summarymixing_wav2vec2.yaml b/recipes/Libri-Light/self-supervised-learning/wav2vec2/hparams/summarymixing_wav2vec2.yaml new file mode 100644 index 0000000..ee0ec32 --- /dev/null +++ b/recipes/Libri-Light/self-supervised-learning/wav2vec2/hparams/summarymixing_wav2vec2.yaml @@ -0,0 +1,180 @@ +# ################################ +# SummaryMixing © 2023 by Samsung Electronics is licensed under CC BY-NC 4.0. + +# This is the pre-training configurations of SummaryMixing wav2vec 2.0. + +# Usage: Install SpeechBrain +# Create a folder recipes/Libri-Light/self-supervised-learning/wav2vec2 +# Copy this file under recipes/Libri-Light/self-supervised-learning/wav2vec2/hparams +# SummaryMixing: https://arxiv.org/abs/2307.07421 +# SummaryMixing SSL: + +# Authors +# * Titouan Parcollet 2023, 2024 +# * Shucong Zhang 2023, 2024 +# * Rogier van Dalen 2023, 2024 +# * Sourav Bhattacharya 2023, 2024 +# ################################ + +# NOTE! to pre-train on Libri-Light, you need to firstly download the Libri-Light data +# through the toolkit in the Libri-Light github repo +# then, use the vad script from the Libri-Light github repo to apply vad to the data +# then, use "make_librilight_csv.py" to make the csv file for SpeechBrain pretraining +# after prepared the data and the csv file, use "python train_sb_wav2vec2_mel.py hparams/summarymixing_wav2vec2.yaml" + +data_folder: # path to the libri-light folder +output_folder: results/summix_w2v2/libri-light-medium +save_folder: !ref /save +# Logging file for every N optimizer steps (many lines) +train_steps_log: !ref /train_steps_log.txt +# Logging file per epoch +train_stage_log: !ref /train_stage_log.txt + +train_csv: # path to train.csv NOTE! please refer to make_librilight_csv.py to see how to create train.csv for libri-light +valid_csv: # path to LibriSpeech dev-clean.csv We use dev-clean to monitor the pre-training, not to do model selection +skip_prep: True # train_csv and valid_csv should be created before runing the SSL pre-training + +avoid_if_longer_than: 30.0 +avoid_if_shorter_than: 10.0 # after vad most utterances are quite long +log_interval: 1000 # Logging every N optimizer steps +precision: fp16 # bf16, fp16 or fp32 +max_grad_norm: 100. + +# The training will either stops at number_of_epochs or optimizer_step_limit +# I.e. the first that is reached. +number_of_epochs: 3000 +optimizer_step_limit: 300000 + +# Dynamic Batching parameters +max_batch_length: 360 +num_buckets: 70 +shuffle: True # if true re-creates batches at each epoch shuffling examples. +batch_ordering: random +grad_accumulation_factor: 4 # we use 4 GPUs. 
360s batch * 4 grad_acc * 4 GPUs = 1.6h +dynamic_batch_sampler_train: + max_batch_length: !ref + num_buckets: !ref + shuffle: !ref + batch_ordering: !ref + +train_dataloader_options: + num_workers: 4 + +test_dataloader_options: + batch_size: 8 # DynamicBatching not used at testing time + num_workers: 4 + +# Training parameters +lr: 0.0005 +warmup: 30000 +# This is equivalent to optimizer_step_limit - warmup +# Necessary to do to have a linear warmup and linear decay directly +# If cooldown < optimizer_step_limit - warmup then a third step with a slower +# decay is applied in the middle (see the implementation of the scheduler) +cooldown: 270000 + +# Loss parameters +diversity_loss_weight: 0.1 +mask_prob: 0.65 +mask_length: 10 +num_negatives: 100 + +# Model parameters +frontend_type: mel_cnn_base +embedding_dim: 768 +extractor_dim: 512 +final_dim: 256 +encoder_layerdrop: 0.05 +latentextractor_kernels: [3, 3] +latentextractor_strides: [2, 1] + +# Feature parameters +sample_rate: 16000 +n_fft: 400 +n_mels: 80 +hop_length: 10 +win_length: 25 + +optimizer: !name:torch.optim.AdamW + lr: !ref + weight_decay: 0.05 + eps: 0.000001 + +epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter + limit: !ref + +CNN: !new:speechbrain.lobes.models.wav2vec.W2VLatentExtractor + kernel_sizes: !ref + strides: !ref + out_channels: [!ref , !ref ] + input_dim: 80 + +encoder: !new:speechbrain.lobes.models.transformer.Conformer.ConformerEncoder + d_model: !ref + num_layers: 12 + nhead: 8 + d_ffn: 3072 + dropout: 0.1 + layerdrop_prob: !ref + attention_type: SummaryMixing #SummaryMixing regularMHA + local_proj_hid_dim: [!ref ] + local_proj_out_dim: !ref + summary_hid_dim: [!ref ] + mode: "SummaryMixing" + +encoder_wrapper: !new:speechbrain.lobes.models.wav2vec.EncoderWrapper + in_dim: !ref + embedding_dim: !ref + latent_encoder: !ref + dropout_encoder_input: 0.1 + +target_quantiser: !new:speechbrain.lobes.models.wav2vec.W2VTargetQuantiser + in_dim: !ref + out_dim: !ref + +feat_proj: !new:torch.nn.Linear + in_features: !ref + out_features: !ref + +modules: + compute_features: !ref + normalize: !ref + CNN: !ref + latent_encoder: !ref + feat_proj: !ref + target_quantiser: !ref + +loss: !new:speechbrain.nnet.losses.ContrastiveLoss + logit_temp: 0.1 + +lr_scheduler: !new:speechbrain.nnet.schedulers.WarmCoolDecayLRSchedule + lr: !ref + warmup: !ref + cooldown: !ref + total_steps: !ref + +normalize: !new:speechbrain.processing.features.InputNormalization + norm_type: sentence + +compute_features: !new:speechbrain.lobes.features.Fbank + sample_rate: !ref + n_fft: !ref + n_mels: !ref + hop_length: !ref + win_length: !ref + +checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer + checkpoints_dir: !ref + recoverables: + CNN: !ref + latent_encoder: !ref + feat_proj: !ref + target_quantiser: !ref + scheduler: !ref + counter: !ref + +train_steps_logger: !new:speechbrain.utils.train_logger.FileTrainLogger + save_file: !ref + +train_stage_logger: !new:speechbrain.utils.train_logger.FileTrainLogger + save_file: !ref \ No newline at end of file diff --git a/recipes/Libri-Light/self-supervised-learning/wav2vec2/make_librilight_csv.py b/recipes/Libri-Light/self-supervised-learning/wav2vec2/make_librilight_csv.py new file mode 100644 index 0000000..24bef85 --- /dev/null +++ b/recipes/Libri-Light/self-supervised-learning/wav2vec2/make_librilight_csv.py @@ -0,0 +1,90 @@ +""" SummaryMixing © 2023 by Samsung Electronics is licensed under CC BY-NC 4.0. 
+
+Usage: Install SpeechBrain
+       Create a folder recipes/Libri-Light/self-supervised-learning/wav2vec2
+       Copy this file under recipes/Libri-Light/self-supervised-learning/wav2vec2
+
+This script creates the SpeechBrain csv files for the Libri-Light dataset:
+1. download the Libri-Light dataset through the toolkit in the Libri-Light github repo
+2. use the VAD script from the Libri-Light repo to apply VAD to the data
+3. run "python make_librilight_csv.py path_to_vad_output path_to_save_csv" to generate the train.csv for the SSL pre-training
+
+
+SummaryMixing: https://arxiv.org/abs/2307.07421
+SummaryMixing SSL:
+
+Authors
+ * Titouan Parcollet 2023, 2024
+ * Shucong Zhang 2023, 2024
+ * Rogier van Dalen 2023, 2024
+ * Sourav Bhattacharya 2023, 2024
+"""
+
+
+import pathlib
+import torchaudio
+import tqdm
+import multiprocessing
+import csv
+import shutil
+from pathlib import Path
+import sys
+import os
+
+
+def make_csv_for_each(subpath_1_csv_file_folder, max_length=20.2):
+    """Writes one csv file per Libri-Light sub-folder, skipping utterances
+    longer than max_length seconds."""
+    subpath_1, csv_file_folder = subpath_1_csv_file_folder
+    for flac_file in subpath_1.glob('**/*.flac'):
+        flac_file_name = flac_file.stem
+        waveform, sample_rate = torchaudio.load(str(flac_file))
+        duration_seconds = waveform.size(1) / sample_rate
+        if duration_seconds > max_length:
+            continue
+        csv_file = f"{csv_file_folder}/{flac_file.parent.stem}.csv"
+        with open(csv_file, mode='a', newline='') as csvfile:
+            csv_writer = csv.writer(
+                csvfile, delimiter=",", quotechar='"', quoting=csv.QUOTE_MINIMAL
+            )
+            csv_writer.writerow(
+                [flac_file_name, duration_seconds, str(flac_file)]
+            )
+
+
+def processes_folder(data_path, csv_file_folder):
+    """Builds the per-folder csv files in parallel under csv_file_folder/tmp."""
+    os.makedirs(csv_file_folder, exist_ok=True)
+    os.makedirs(f"{csv_file_folder}/tmp", exist_ok=True)
+    list_dir = pathlib.Path(data_path)
+    tasks = []
+    for x in list_dir.iterdir():
+        tasks.append((x, f"{csv_file_folder}/tmp"))
+    with multiprocessing.Pool(processes=128) as pool:
+        for _ in tqdm.tqdm(
+            pool.imap_unordered(make_csv_for_each, tasks), total=len(tasks)
+        ):
+            pass
+
+
+def merge_csv_files(csv_file_folder):
+    """Merges the temporary csv files into a single train.csv and removes tmp."""
+    file_list = [str(x) for x in Path(f"{csv_file_folder}/tmp").glob('*.csv')]
+    output_file = f"{csv_file_folder}/train.csv"
+    fieldnames = ["ID", "duration", "wav"]
+
+    with open(output_file, mode='w', newline='', encoding='utf-8') as outfile:
+        csv_writer = csv.writer(
+            outfile, delimiter=",", quotechar='"', quoting=csv.QUOTE_MINIMAL
+        )
+        csv_writer.writerow(fieldnames)
+        for file_path in tqdm.tqdm(file_list):
+            with open(file_path, mode='r', encoding='utf-8') as infile:
+                reader = csv.reader(infile)
+
+                # filter out bad rows; the full wav path is used as the utterance ID
+                for row in reader:
+                    if len(row) == 3 and os.path.exists(row[-1]):
+                        new_row = [row[-1], row[1], row[2]]
+                        csv_writer.writerow(new_row)
+                    else:
+                        print(f"bad row {row}")
+
+    shutil.rmtree(f"{csv_file_folder}/tmp")
+
+
+if __name__ == "__main__":
+    processes_folder(sys.argv[1], sys.argv[2])
+    merge_csv_files(sys.argv[2])
diff --git a/recipes/Libri-Light/self-supervised-learning/wav2vec2/train_sb_wav2vec2_mel.py b/recipes/Libri-Light/self-supervised-learning/wav2vec2/train_sb_wav2vec2_mel.py
new file mode 100644
index 0000000..35d20dc
--- /dev/null
+++ b/recipes/Libri-Light/self-supervised-learning/wav2vec2/train_sb_wav2vec2_mel.py
@@ -0,0 +1,478 @@
+""" SummaryMixing © 2023 by Samsung Electronics is licensed under CC BY-NC 4.0.
+
+This is the pre-training recipe of SummaryMixing wav2vec 2.0.
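+
+The hyperparameters assume 4 GPUs with grad_accumulation_factor: 4, i.e. an
+effective batch of roughly 360 s x 4 x 4 = 1.6 h of audio per optimizer step.
+With a standard SpeechBrain DDP setup, such a run would typically be launched
+with, for example:
+    torchrun --nproc_per_node=4 train_sb_wav2vec2_mel.py hparams/summarymixing_wav2vec2.yaml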
+ +Usage: Install SpeechBrain + Create a folder recipes/Libri-Light/self-supervised-learning/wav2vec2 + Copy this file under recipes/Libri-Light/self-supervised-learning/wav2vec2 + +SummaryMixing: https://arxiv.org/abs/2307.07421 +SummaryMixing SSL: + +Authors + * Titouan Parcollet 2023, 2024 + * Shucong Zhang 2023, 2024 + * Rogier van Dalen 2023, 2024 + * Sourav Bhattacharya 2023, 2024 +""" + +import logging +import sys +import time +from functools import partial + +import speechbrain as sb +import torch +import torch.nn.functional as F +from torch.nn.parallel import DistributedDataParallel +from hyperpyyaml import load_hyperpyyaml + +from speechbrain import Stage +from speechbrain.utils.distributed import run_on_main +from speechbrain.dataio.dataloader import SaveableDataLoader +from speechbrain.dataio.sampler import DynamicBatchSampler +from speechbrain.lobes.models.wav2vec import w2v_mask_collate_fn +from speechbrain.lobes.models.wav2vec import sample_negatives +from speechbrain.core import AMPConfig + +logger = logging.getLogger(__name__) + + +class W2V2Brain(sb.core.Brain): + def compute_forward(self, batch, stage): + """Computes forward pass through wav2vec model and returns encoded and + target embeddings as well as other metrics of interest. + """ + wavs, wav_lens, mask = batch + wavs, wav_lens, mask = ( + wavs.to(self.device), + wav_lens.to(self.device), + mask.to(self.device), + ) + batch_size = wavs.size(0) + + # Mormalisation already done in dataloader + # 1. Go through features extractor + if ( + self.hparams.frontend_type == "w2v2" + or self.hparams.frontend_type == "sew" + ): + latents = self.modules.latent_extractor( + wavs, normalize_signal=False + ) + elif self.hparams.frontend_type == "mel": + with torch.no_grad(): + latents = self.modules.compute_features(wavs) + latents = self.modules.normalize( + latents, wav_lens, epoch=current_epoch + ).detach() + elif self.hparams.frontend_type == "mel_v2": + with torch.no_grad(): + latents = self.modules.compute_features(wavs) + latents = self.modules.normalize( + latents, wav_lens, epoch=current_epoch + ).detach() + elif self.hparams.frontend_type == "mel_pool": + with torch.no_grad(): + latents = self.modules.compute_features(wavs) + latents = self.modules.pooling(latents) + latents = self.modules.normalize( + latents, wav_lens, epoch=current_epoch + ).detach() + elif ( + self.hparams.frontend_type == "mel_cnn" + or self.hparams.frontend_type == "mel_cnn_base" + ): + with torch.no_grad(): + latents = self.modules.compute_features(wavs) + + latents = self.modules.normalize( + latents, wav_lens, + ).detach() + latents = self.modules.CNN(latents) + latents = latents.view(batch_size, latents.shape[1], -1) + elif self.hparams.frontend_type == "sincnet": + latents = self.modules.compute_features(wavs) + latents = self.modules.CNN(latents) + + elif self.hparams.frontend_type == "fast_audio": + latents = self.modules.compute_features(wavs) + latents = self.modules.normalize(latents, wav_lens) + latents = latents.view(batch_size, latents.shape[1], -1) + elif self.hparams.frontend_type == "leaf": + latents = self.modules.compute_features(wavs) + + + # 2. Go through latent (Transformer). + results = self.modules.latent_encoder( + latents, mask=mask, wav_lens=wav_lens, + ) + + embeddings = results["embeddings"] + + # 3. 
Mask some of the latent and projection + embeddings = embeddings[mask] + embeddings = self.modules.feat_proj(embeddings) + results["embeddings"] = embeddings.view( + batch_size, -1, embeddings.size(1) + ) + + latents = latents[mask].view(batch_size, -1, latents.size(2)) + + # 4. Apply the quantiser as well + targets, meta = self.modules.target_quantiser(latents) + results.update(meta) + results["targets"] = targets + return results + + def compute_objectives(self, forward_outputs, batch, stage): + """Samples negatives, computes contrastive loss and accuracy. + """ + + embeddings = forward_outputs["embeddings"] + targets = forward_outputs["targets"] + + negs = sample_negatives(targets, self.hparams.num_negatives) + + loss, accuracy = self.hparams.loss(embeddings, targets, negs) + + # This is only used for logging purpose + if stage != sb.Stage.TRAIN: + self.acc_metric.append(accuracy) + + objectives = { + "loss": loss, + "accuracy": accuracy, + "num_masked": forward_outputs["num_masked"], + "ratio_masked": forward_outputs["ratio_masked"], + } + if ( + "diversity_loss" in forward_outputs + ): # only quantised model has these + objectives.update( + { + "diversity_loss": forward_outputs["diversity_loss"], + "prob_perplex": forward_outputs["prob_perplex"], + "code_perplex": forward_outputs["code_perplex"], + "num_vars": forward_outputs["num_vars"], + "temp": forward_outputs["temp"], + } + ) + + # Compute the loss given the original equation from the paper + loss = objectives["loss"] + if self.hparams.diversity_loss_weight == 0.0: + objectives["backprop_loss"] = loss + else: + objectives["backprop_loss"] = ( + loss + + objectives["diversity_loss"] + * self.hparams.diversity_loss_weight + * objectives["num_masked"] + ) + return objectives + + def fit_batch(self, batch): + amp = AMPConfig.from_name(self.precision) + should_step = (self.step % self.grad_accumulation_factor) == 0 + + # Managing automatic mixed precision + with self.no_sync(not should_step): + if self.use_amp: + with torch.autocast( + dtype=amp.dtype, device_type=torch.device(self.device).type, + ): + outputs = self.compute_forward(batch, Stage.TRAIN) + objectives = self.compute_objectives( + outputs, batch, Stage.TRAIN + ) + else: + outputs = self.compute_forward(batch, Stage.TRAIN) + objectives = self.compute_objectives( + outputs, batch, Stage.TRAIN + ) + + self.scaler.scale( + objectives["backprop_loss"] / self.grad_accumulation_factor + ).backward() + + objectives["total_loss"] = objectives["backprop_loss"].detach() + + if should_step: + self.optimizers_step() + self.on_fit_batch_end(objectives) + + return objectives["backprop_loss"].detach() + + def on_fit_batch_end(self, objectives): + """ Called after fit_batch(), updates learning rate and does per-step logging. """ + if isinstance(self.modules.target_quantiser, DistributedDataParallel): + w2v_model = self.modules.target_quantiser.module + else: + w2v_model = self.modules.target_quantiser + + w2v_model.quantiser.update_temp(self.optimizer_step) + + self.hparams.lr_scheduler(self.optimizer, self.optimizer_step) + + # Perform step-wise logging + if ( + hasattr(self.hparams, "log_interval") + and self.optimizer_step % self.hparams.log_interval == 0 + ): + + # Create a dictionary and fill it with everything we + # want to log such as contrastive loss, diversity loss, + # learning rate etc. 
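+            # For reference, the logged dictionary typically contains the
+            # objectives returned by compute_objectives() (loss, accuracy,
+            # num_masked, ratio_masked, and the diversity / perplexity terms
+            # when the quantiser is used), plus the optimizer step count, the
+            # current learning rate, the running average loss, and the
+            # wall-clock time since the previous log entry.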
+ log_dct = { + k: (v.item() if isinstance(v, torch.Tensor) else v) + for k, v in objectives.items() + } + current_lr = self.optimizer.param_groups[0]["lr"] + log_dct["steps"] = self.optimizer_step + log_dct["lr"] = current_lr + log_dct["avg_loss"] = self.avg_train_loss + + if hasattr(self, "time_last_log"): + run_time_since_last_log = time.time() - self.time_last_log + log_dct["run_time"] = run_time_since_last_log + self.time_last_log = time.time() + + if sb.utils.distributed.if_main_process(): + self.hparams.train_steps_logger.log_stats(stats_meta=log_dct,) + + def evaluate_batch(self, batch, stage): + """ Returns accuracy on contrastive objective. """ + out = self.compute_forward(batch, stage=stage) + objectives = self.compute_objectives(out, batch, stage=stage) + return objectives["backprop_loss"].detach().cpu() + + def on_stage_start(self, stage, epoch): + """Gets called at the beginning of each epoch""" + if stage != sb.Stage.TRAIN: + self.acc_metric = [] + + def on_stage_end(self, stage, stage_loss, epoch=None): + + stage_stats = {"loss": stage_loss} + if stage == sb.Stage.TRAIN: + self.train_stats = stage_stats + + if stage == sb.Stage.VALID: + print(self.acc_metric) + stage_stats["accuracy"] = sum(self.acc_metric) / len( + self.acc_metric + ) + + self.hparams.train_stage_logger.log_stats( + stats_meta={ + "epoch": epoch, + "steps": self.optimizer_step, + "lr": self.optimizer.param_groups[0]["lr"], + }, + train_stats=self.train_stats, + valid_stats=stage_stats, + ) + + self.checkpointer.save_and_keep_only( + end_of_epoch=True, + num_to_keep=5, + meta={"valid_loss": stage_loss}, + ) + + +def dataio_prepare(hparams): + data_folder = hparams["data_folder"] + + train_data = sb.dataio.dataset.DynamicItemDataset.from_csv( + csv_path=hparams["train_csv"], replacements={"data_root": data_folder}, + ) + + # We remove longer and shorter files from the train. + train_data = train_data.filtered_sorted( + sort_key="duration", + key_max_value={"duration": hparams["avoid_if_longer_than"]}, + key_min_value={"duration": hparams["avoid_if_shorter_than"]}, + ) + + valid_data = sb.dataio.dataset.DynamicItemDataset.from_csv( + csv_path=hparams["valid_csv"], replacements={"data_root": data_folder}, + ) + + datasets = [train_data, valid_data] + + def get_output_lengths(input_lengths): + """ Function to get the output length of the feature extractor this is + necessery to compute the masks of wav2vec2. + """ + + def _conv_out_length(input_length, kernel_size, stride): + return torch.floor((input_length - kernel_size) / stride + 1) + + for kernel_size, stride in zip( + hparams["latentextractor_kernels"], + hparams["latentextractor_strides"], + ): + input_lengths = _conv_out_length(input_lengths, kernel_size, stride) + return input_lengths.to(torch.long) + + def get_output_lengths_w2v2_mel(input_lengths, hop_length): + """ Function to get the output length of the feature extractor this is + necessery to compute the masks of wav2vec2. 
+ """ + + def _conv_out_length(input_length, kernel_size, stride): + return torch.floor((input_length - kernel_size) / stride + 1) + + input_lengths = torch.floor( + ((input_lengths / 16000) / (hop_length / 1000) + 1) + ).to(torch.long) + for kernel_size, stride in zip( + hparams["latentextractor_kernels"], + hparams["latentextractor_strides"], + ): + input_lengths = _conv_out_length(input_lengths, kernel_size, stride) + + return input_lengths.to(torch.long) + + def get_output_lengths_mel(input_lengths, hop_length): + """ Function to get the output length of the feature extractor this is + necessery to compute the masks of wav2vec2. + """ + + input_lengths = torch.floor( + ((input_lengths / 16000) / (hop_length / 1000) + 1) + ).to(torch.long) + + return input_lengths + + def get_output_lengths_mel_cnn(input_lengths, hop_length): + """ Function to get the output length of the feature extractor this is + necessery to compute the masks of wav2vec2. + """ + # 2D CNN 32 31 + # input_lengths = torch.floor( + # ((input_lengths / 16000) / (hop_length / 1000) / 2 + 1) + # ).to(torch.long) + + # 2D CNN 32 32 + input_lengths = torch.floor( + ((input_lengths / 16000) / (hop_length / 1000) / 4 + 1) + ).to(torch.long) + return input_lengths + + @sb.utils.data_pipeline.takes("wav") + @sb.utils.data_pipeline.provides("sig") + def audio_pipeline(wav): + sig = sb.dataio.dataio.read_audio(wav) + assert sig.dim() == 1, sig.dim() + + # Audio normalization + with torch.no_grad(): + sig = F.layer_norm(sig, sig.shape) + return sig + + sb.dataio.dataset.add_dynamic_item(datasets, audio_pipeline) + sb.dataio.dataset.set_output_keys(datasets, ["id", "sig"]) + + # We create the DynamicBatch Sampler + dynamic_hparams = hparams["dynamic_batch_sampler_train"] + + train_sampler = DynamicBatchSampler( + train_data, **dynamic_hparams, length_func=lambda x: x["duration"], + ) + + # We define the custom collation function that is necessary for w2v2 to + # generate masks. 
+ if hparams["frontend_type"] == "w2v2": + w2v_mask_collate_fn_partial = partial( + w2v_mask_collate_fn, + get_out_len_fn=get_output_lengths, + mask_prob=hparams["mask_prob"], + mask_length=hparams["mask_length"], + ) + elif hparams["frontend_type"] == "mel_cnn": + w2v_mask_collate_fn_partial = partial( + w2v_mask_collate_fn, + get_out_len_fn=get_output_lengths_mel_cnn, + hop_length=hparams["hop_length"], + mask_prob=hparams["mask_prob"], + mask_length=hparams["mask_length"], + ) + elif hparams["frontend_type"] == "mel_cnn_base": + w2v_mask_collate_fn_partial = partial( + w2v_mask_collate_fn, + hop_length=hparams["hop_length"], + get_out_len_fn=get_output_lengths_w2v2_mel, + mask_prob=hparams["mask_prob"], + mask_length=hparams["mask_length"], + ) + else: + w2v_mask_collate_fn_partial = partial( + w2v_mask_collate_fn, + get_out_len_fn=get_output_lengths_mel, + hop_length=hparams["hop_length"], + mask_prob=hparams["mask_prob"], + mask_length=hparams["mask_length"], + ) + + train_loader_kwargs = { + "batch_sampler": train_sampler, + "collate_fn": w2v_mask_collate_fn_partial, + "num_workers": hparams["train_dataloader_options"]["num_workers"], + "pin_memory": True, + } + + valid_loader = SaveableDataLoader( + valid_data, + collate_fn=w2v_mask_collate_fn_partial, + num_workers=hparams["test_dataloader_options"]["num_workers"], + batch_size=hparams["test_dataloader_options"]["batch_size"], + pin_memory=True, + ) + + return train_data, valid_loader, train_loader_kwargs + + +def main(): + logger.setLevel(logging.INFO) + hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:]) + + sb.utils.distributed.ddp_init_group(run_opts) + + with open(hparams_file) as fin: + hparams = load_hyperpyyaml(fin, overrides) + hparams.update(run_opts) + + sb.create_experiment_directory( + experiment_directory=hparams["output_folder"], + hyperparams_to_save=hparams_file, + overrides=overrides, + ) + + # Update precision to bf16 if the device is CPU and precision is fp16 + if run_opts.get("device") == "cpu" and hparams.get("precision") == "fp16": + hparams["precision"] = "bf16" + + # Part that matters starts here. + train_dataset, valid_loader, train_loader_kwargs = dataio_prepare(hparams) + + brain = W2V2Brain( + modules=hparams["modules"], + opt_class=hparams["optimizer"], + hparams=hparams, + run_opts=run_opts, + checkpointer=hparams["checkpointer"], + ) + + brain.fit( + brain.hparams.epoch_counter, + train_dataset, + valid_loader, + train_loader_kwargs=train_loader_kwargs, + progressbar=True, + ) + + +if __name__ == "__main__": + main() diff --git a/recipes/LibriSpeech/ASR/transformer/hparams/branchformer_summarymixing.yaml b/recipes/LibriSpeech/ASR/transformer/hparams/branchformer_summarymixing.yaml deleted file mode 100644 index c6c2dbf..0000000 --- a/recipes/LibriSpeech/ASR/transformer/hparams/branchformer_summarymixing.yaml +++ /dev/null @@ -1,359 +0,0 @@ -# ############################################################################ -# -# SummaryMixing © 2023 by Samsung Electronics is licensed under CC BY-NC 4.0. 
-# -# Model: E2E ASR with Transformer -# Encoder: branchformer Encoder with SummaryMixing -# Decoder: Transformer Decoder + (CTC/ATT joint) beamsearch + TransformerLM -# Tokens: unigram -# losses: CTC + KLdiv (Label Smoothing loss) -# Training: Librispeech 960h -# Authors: Titouan Parcollet, Shucong Zhang, Rogier van Dalen, and Sourav -# Bhattacharya -# ############################################################################ -# Seed needs to be set at top of yaml, before objects with parameters are made - -seed: 3407 -__set_seed: !apply:torch.manual_seed [!ref ] -output_folder: !ref results/branchformer_summarymixing/ -output_wer_folder: !ref / -save_folder: !ref /save -train_log: !ref /train_log.txt - -# Language model (LM) pretraining -# NB: To avoid mismatch, the speech recognizer must be trained with the same -# tokenizer used for LM training. Here, we download everything from the -# speechbrain HuggingFace repository. However, a local path pointing to a -# directory containing the lm.ckpt and tokenizer.ckpt may also be specified -# instead. E.g if you want to use your own LM / tokenizer. -pretrained_lm_tokenizer_path: speechbrain/asr-transformer-transformerlm-librispeech - -# Data files -data_folder: !PLACEHOLDER # e.g., /path/to/LibriSpeech -# If RIRS_NOISES dir exists in /localscratch/xxx_corpus/RIRS_NOISES -# then data_folder_rirs should be /localscratch/xxx_corpus -# otherwise the dataset will automatically be downloaded -# data_folder_rirs: !ref -train_splits: ["train-clean-100", "train-clean-360", "train-other-500"] -dev_splits: ["dev-clean"] -test_splits: ["test-clean", "test-other"] -skip_prep: False -train_csv: !ref /train.csv -valid_csv: !ref /dev-clean.csv -test_csv: - - !ref /test-clean.csv - - !ref /test-other.csv - -####################### Training Parameters #################################### - -# To make Transformers converge, the global bath size should be large enough. -# The global batch size is computed as batch_size * n_gpus * gradient_accumulation. -# Empirically, we found that this value should be >= 128. -# Please, set your parameters accordingly. -number_of_epochs: 120 -ctc_weight: 0.3 -grad_accumulation_factor: 1 -max_grad_norm: 5.0 -loss_reduction: 'batchmean' -sorting: random -num_workers: 8 -precision: fp32 # bf16, fp16 or fp32 -avg_checkpoints: 10 # Number of checkpoints to average for evaluation - -# stages related parameters -lr_adam: 0.0005 -weight_decay: 0.01 - -# Feature parameters -sample_rate: 16000 -n_fft: 512 -n_mels: 80 -win_length: 32 - -# This setup works well for A100 80GB GPU, adapts it to your needs. -# Or turn it off (but training speed will decrease) -dynamic_batching: True -batch_size: 16 # To be only used if dynamic batching is false -max_batch_length_train: 500 -max_batch_length_val: 100 # we reduce it as the beam is much wider (VRAM) -num_bucket: 200 -shuffle: True # if true re-creates batches at each epoch shuffling examples. 
-max_batch_ex: 128 -batch_ordering: random -dynamic_batch_sampler_train: - max_batch_length: !ref - num_buckets: !ref - shuffle: !ref - batch_ordering: !ref - max_batch_ex: !ref - -dynamic_batch_sampler_valid: - max_batch_length: !ref - num_buckets: !ref - shuffle: !ref - batch_ordering: !ref - max_batch_ex: !ref - -# Dataloader options -train_dataloader_opts: - batch_size: !ref - shuffle: True - num_workers: !ref - -valid_dataloader_opts: - batch_size: 1 - -test_dataloader_opts: - batch_size: 1 - -####################### Model Parameters ####################################### - -# Transformer -attention_type: SummaryMixing # SummaryMixing, regularMHA or RelPosMHAXL -mode: SummaryMixing # SummaryMixing or SummaryMixing-lite -d_model: 512 -nhead: 1 # 1 is faster but 4 gives very slightly better performance (WER) -num_encoder_layers: 18 -num_decoder_layers: 6 -decoder_linear_units: 2048 -csgu_linear_units: 3072 -csgu_kernel_size: 31 -local_proj_hid_dim: [512] -local_proj_out_dim: 512 -summary_hid_dim: [512] -summary_out_dim: 512 -transformer_dropout: 0.1 -activation: !name:torch.nn.GELU -output_neurons: 5000 - -# Outputs -blank_index: 0 -label_smoothing: 0.0 -pad_index: 0 -bos_index: 1 -eos_index: 2 - -# Decoding parameters -min_decode_ratio: 0.0 -max_decode_ratio: 1.0 -valid_search_interval: 10 -valid_beam_size: 10 -test_beam_size: 66 -lm_weight: 0.60 -ctc_weight_decode: 0.40 - -############################## Models ########################################## - -CNN: !new:speechbrain.lobes.models.convolution.ConvolutionFrontEnd - input_shape: (8, 10, 80) - num_blocks: 2 - num_layers_per_block: 1 - out_channels: (64, 32) - kernel_sizes: (3, 3) - strides: (2, 2) - residuals: (False, False) - -Transformer: !new:speechbrain.lobes.models.transformer.TransformerASR.TransformerASR # yamllint disable-line rule:line-length - input_size: 640 - tgt_vocab: !ref - d_model: !ref - nhead: !ref - num_encoder_layers: !ref - num_decoder_layers: !ref - dropout: !ref - activation: !ref - branchformer_activation: !ref - encoder_module: branchformer - csgu_linear_units: !ref - kernel_size: !ref - attention_type: !ref - local_proj_hid_dim: !ref - local_proj_out_dim: !ref - summary_hid_dim: !ref - summary_out_dim: !ref - mode: !ref - normalize_before: True - causal: False - -# This is the TransformerLM that is used according to the Huggingface repository -# Visit the HuggingFace model corresponding to the pretrained_lm_tokenizer_path -# For more details about the model! -# NB: It has to match the pre-trained TransformerLM!! 
-lm_model: !new:speechbrain.lobes.models.transformer.TransformerLM.TransformerLM # yamllint disable-line rule:line-length - vocab: !ref - d_model: 768 - nhead: 12 - num_encoder_layers: 12 - num_decoder_layers: 0 - d_ffn: 3072 - dropout: 0.0 - activation: !name:torch.nn.GELU - normalize_before: False - -tokenizer: !new:sentencepiece.SentencePieceProcessor - -ctc_lin: !new:speechbrain.nnet.linear.Linear - input_size: !ref - n_neurons: !ref - -seq_lin: !new:speechbrain.nnet.linear.Linear - input_size: !ref - n_neurons: !ref - -normalize: !new:speechbrain.processing.features.InputNormalization - norm_type: global - update_until_epoch: 4 - -modules: - CNN: !ref - Transformer: !ref - seq_lin: !ref - ctc_lin: !ref - normalize: !ref - -model: !new:torch.nn.ModuleList - - [!ref , !ref , !ref , !ref ] - - -############################## Decoding & optimiser ############################ - -Adam: !name:torch.optim.AdamW - lr: !ref - betas: (0.9, 0.98) - eps: 0.000000001 - weight_decay: !ref - -# Scorer -ctc_scorer: !new:speechbrain.decoders.scorer.CTCScorer - eos_index: !ref - blank_index: !ref - ctc_fc: !ref - - -transformerlm_scorer: !new:speechbrain.decoders.scorer.TransformerLMScorer - language_model: !ref - temperature: 1.15 - -scorer_valid_search: !new:speechbrain.decoders.scorer.ScorerBuilder - full_scorers: [!ref ] - weights: - ctc: !ref - -scorer_test_search: !new:speechbrain.decoders.scorer.ScorerBuilder - full_scorers: [!ref , !ref ] - weights: - ctc: !ref - transformerlm: !ref - -valid_search: !new:speechbrain.decoders.S2STransformerBeamSearcher - modules: [!ref , !ref ] - bos_index: !ref - eos_index: !ref - min_decode_ratio: !ref - max_decode_ratio: !ref - beam_size: !ref - using_eos_threshold: False - length_normalization: True - scorer: !ref - -test_search: !new:speechbrain.decoders.S2STransformerBeamSearcher - modules: [!ref , !ref ] - bos_index: !ref - eos_index: !ref - min_decode_ratio: !ref - max_decode_ratio: !ref - beam_size: !ref - temperature: 1.15 - using_eos_threshold: False - length_normalization: True - scorer: !ref - -log_softmax: !new:torch.nn.LogSoftmax - dim: -1 - -ctc_cost: !name:speechbrain.nnet.losses.ctc_loss - blank_index: !ref - reduction: !ref - -seq_cost: !name:speechbrain.nnet.losses.kldiv_loss - label_smoothing: !ref - reduction: !ref - -noam_annealing: !new:speechbrain.nnet.schedulers.NoamScheduler - lr_initial: !ref - n_warmup_steps: 30000 - -checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer - checkpoints_dir: !ref - recoverables: - model: !ref - noam_scheduler: !ref - normalizer: !ref - counter: !ref - -epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter - limit: !ref - -############################## Augmentation #################################### - -# Speed perturbation -speed_perturb: !new:speechbrain.augment.time_domain.SpeedPerturb - speeds: [95, 100, 105] - -# Time Drop -time_drop: !new:speechbrain.augment.freq_domain.SpectrogramDrop - drop_length_low: 15 - drop_length_high: 25 - drop_count_low: 4 - drop_count_high: 4 - replace: "mean" - dim: 1 - -# Frequency Drop -freq_drop: !new:speechbrain.augment.freq_domain.SpectrogramDrop - drop_length_low: 10 - drop_length_high: 20 - drop_count_low: 4 - drop_count_high: 4 - replace: "mean" - dim: 2 - -# Time warp -time_warp: !new:speechbrain.augment.freq_domain.Warping - dim: 1 - -fea_augment: !new:speechbrain.augment.augmenter.Augmenter - repeat_augment: 1 - shuffle_augmentations: False - min_augmentations: 3 - max_augmentations: 3 - augment_prob: 1.0 - augmentations: [ - !ref , - !ref , 
- !ref ] - -compute_features: !new:speechbrain.lobes.features.Fbank - sample_rate: !ref - n_fft: !ref - win_length: !ref - n_mels: !ref - -train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger - save_file: !ref - -error_rate_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats -acc_computer: !name:speechbrain.utils.Accuracy.AccuracyStats - -# The pretrainer allows a mapping between pretrained files and instances that -# are declared in the yaml. E.g here, we will download the file lm.ckpt -# and it will be loaded into "lm" which is pointing to the defined -# before. -pretrainer: !new:speechbrain.utils.parameter_transfer.Pretrainer - collect_in: !ref - loadables: - lm: !ref - tokenizer: !ref - paths: - lm: !ref /lm.ckpt - tokenizer: !ref /tokenizer.ckpt diff --git a/speechbrain/lobes/models/VanillaNN.py b/speechbrain/lobes/models/VanillaNN.py index 1c786ab..f937392 100644 --- a/speechbrain/lobes/models/VanillaNN.py +++ b/speechbrain/lobes/models/VanillaNN.py @@ -7,10 +7,10 @@ Source: https://arxiv.org/abs/2307.07421 Authors - * Titouan Parcollet 2023 - * Shucong Zhang 2023 - * Rogier van Dalen 2023 - * Sourav Bhattacharya 2023 + * Titouan Parcollet 2023, 2024 + * Shucong Zhang 2023, 2024 + * Rogier van Dalen 2023, 2024 + * Sourav Bhattacharya 2023, 2024 """ import torch import math diff --git a/speechbrain/lobes/models/transformer/Branchformer.py b/speechbrain/lobes/models/transformer/Branchformer.py deleted file mode 100644 index d8eaf2c..0000000 --- a/speechbrain/lobes/models/transformer/Branchformer.py +++ /dev/null @@ -1,489 +0,0 @@ -""" SummaryMixing © 2023 by Samsung Electronics is licensed under CC BY-NC 4.0. - -This library connects SummaryMixing to the standard SpeechBrain lobes for Branchformer ASR. -Large parts of the code come from the SpeechBrain repository. - -Usage: Install SpeechBrain and copy this file under speechbrain/lobes/models/transformer/ -Source: https://arxiv.org/abs/2307.07421 - -Authors - * Titouan Parcollet 2023 - * Shucong Zhang 2023 - * Rogier van Dalen 2023 - * Sourav Bhattacharya 2023 -""" - -import torch -import torch.nn as nn -from typing import Optional - -from speechbrain.nnet.attention import ( - RelPosMHAXL, - MultiheadAttention, -) -from speechbrain.nnet.hypermixing import HyperMixing -from speechbrain.nnet.normalization import LayerNorm -from speechbrain.lobes.models.convolution import ConvolutionalSpatialGatingUnit -from speechbrain.lobes.models.VanillaNN import VanillaNN -from speechbrain.nnet.summary_mixing import SummaryMixing - - -class ConvolutionBranch(nn.Module): - """This is an implementation of the convolution branch in Branchformer. - - The default structure is: - Channel Proj -> GeLU -> (CNN Spatial Gating) -> Channel Proj -> Dropout - - Arguments - ---------- - input_size : int - The expected size of the feature (channel) dimension. - linear_units: int, optional - Number of neurons in the hidden linear units. - kernel_size: int, optional - Kernel size of non-bottleneck convolutional layer. - activation: torch.nn.Module, optional - Activation function used after pre projection. - gate_activation: torch.nn.Module, optional - Activation function used at the gate of the CSGU module. - dropout: float, optional - Dropout rate. 
- use_linear_after_conv: bool, optional - If True, will apply a linear transformation of size input_size//2 - - Example - ------- - >>> x = torch.rand((8, 60, 512)) - >>> net = ConvolutionBranch(512, 1024) - >>> output = net(x) - >>> output.shape - torch.Size([8, 60, 512]) - """ - - def __init__( - self, - input_size, - linear_units=3072, - kernel_size=31, - activation=nn.GELU, - gate_activation=nn.Identity, - dropout=0.0, - use_linear_after_conv=False, - ): - super().__init__() - - self.pre_channel_proj = nn.Linear(input_size, linear_units) - self.post_channel_proj = nn.Linear(linear_units // 2, input_size) - self.activation = activation() - self.csgu = ConvolutionalSpatialGatingUnit( - input_size=linear_units, - kernel_size=kernel_size, - dropout=dropout, - use_linear_after_conv=use_linear_after_conv, - activation=gate_activation, - ) - - def forward(self, x): - """ - Arguments - ---------- - x: torch.Tensor -> (B, T, D) - - """ - x = self.activation(self.pre_channel_proj(x)) # (B, T, D) - x = self.csgu(x) # (B, T, D//2) - x = self.post_channel_proj(x) # (B, T, D) - - return x - - -class BranchformerEncoderLayer(nn.Module): - """This is an implementation of Branchformer encoder layer. - - Arguments - ---------- - d_model : int - The expected size of the input embedding. - nhead : int - Number of attention heads. - kernel_size : int, optional - Kernel size of convolution model. - kdim : int, optional - Dimension of the key. - vdim : int, optional - Dimension of the value. - activation: torch.nn.Module - Activation function used in each Conformer layer. - dropout : int, optional - Dropout for the encoder. - attention_type: str, optional - type of attention layer, e.g. SummaryMixing, regulaMHA for regular MultiHeadAttention. - csgu_linear_units: int, optional - Number of neurons in the hidden linear units of the CSGU Module. - gate_activation: torch.nn.Module, optional - Activation function used at the gate of the CSGU module. - use_linear_after_conv: bool, optional - If True, will apply a linear transformation of size input_size//2 - local_proj_out_dim: int, optional - The dimension of the output of the local projection branch. This - will be concatenated with the output of the summary branch - (default: 512). - summary_hid_dim: list [int], optional - A list of dimension specifying both the number of hidden layers - as well as the size of them in the summary projection branch - (default: [1024]). - summary_out_dim: int, optional - The dimension of the output of the summary projection branch. This - will be concatenated with the output of the local branch - (default: 1024). - activation: torch.nn.Module, optional - Torch module specifying the activation function used in both the local - and summary branches. - (default: torch.nn.GELU) - mode: string, optional - One of "SummaryMixing" or "SummaryMixing-lite". Changes the SummaryMixing cell - according to the definition of the article. "SummaryMixing-lite" removes the - local project branch. 
- - Example - ------- - >>> import torch - >>> x = torch.rand((8, 60, 512)) - >>> pos_embs = torch.rand((1, 2*60-1, 512)) - >>> net = BranchformerEncoderLayer(nhead=8, d_model=512, kernel_size=3) - >>> output = net(x, pos_embs=pos_embs) - >>> output[0].shape - torch.Size([8, 60, 512]) - """ - - def __init__( - self, - d_model, - nhead, - kernel_size=31, - kdim=None, - vdim=None, - activation=nn.GELU, - dropout=0.0, - attention_type="SummaryMixing", - csgu_linear_units=3072, - gate_activation=nn.Identity, - use_linear_after_conv=False, - local_proj_hid_dim=[512], - local_proj_out_dim=512, - summary_hid_dim=[1024], - summary_out_dim=1024, - mode="SummaryMixing", - ): - super().__init__() - - self.attention_type = attention_type - self.mode = mode - - # If CNN only, no need for the attention branch and merging - if self.attention_type != "cnnonly": - if attention_type == "regularMHA": - self.mha_layer = MultiheadAttention( - nhead=nhead, d_model=d_model, dropout=dropout, kdim=kdim, vdim=vdim, - ) - self.merge_proj = torch.nn.Linear(d_model * 2, d_model) - elif attention_type == "RelPosMHAXL": - # transformerXL style positional encoding - self.mha_layer = RelPosMHAXL( - num_heads=nhead, - embed_dim=d_model, - dropout=dropout, - mask_pos_future=False, - ) - self.merge_proj = torch.nn.Linear(d_model * 2, d_model) - elif attention_type == "hypermixing": - self.mha_layer = HyperMixing( - input_output_dim=d_model, - hypernet_size=local_proj_hid_dim[0], - tied=False, - num_heads=nhead, - fix_tm_hidden_size=False, - ) - self.merge_proj = torch.nn.Linear(d_model * 2, d_model) - - elif attention_type == "SummaryMixing": - self.mha_layer = SummaryMixing( - enc_dim=d_model, - nhead=nhead, - local_proj_hid_dim=local_proj_hid_dim, - local_proj_out_dim=local_proj_out_dim, - summary_hid_dim=summary_hid_dim, - summary_out_dim=summary_out_dim, - activation=activation, - mode=mode, - ) - self.merge_dnn_blocks = summary_hid_dim + [d_model] - self.merge_proj = VanillaNN( - input_shape=[None, None, local_proj_out_dim + summary_out_dim], - dnn_blocks=len(self.merge_dnn_blocks), - dnn_neurons=self.merge_dnn_blocks, - activation=activation, - ) - - self.norm_mhsa = LayerNorm(d_model) - - self.convolution_branch = ConvolutionBranch( - input_size=d_model, - kernel_size=kernel_size, - linear_units=csgu_linear_units, - activation=activation, - gate_activation=gate_activation, - dropout=dropout, - use_linear_after_conv=use_linear_after_conv, - ) - - self.norm_conv = LayerNorm(d_model) - self.dropout = nn.Dropout(dropout) - - def forward( - self, - x, - src_mask: Optional[torch.Tensor] = None, - src_key_padding_mask: Optional[torch.Tensor] = None, - pos_embs: Optional[torch.Tensor] = None, - ): - """ - Arguments - ---------- - x : torch.Tensor - The sequence to the encoder layer. - src_mask : torch.Tensor, optional - The mask for the src sequence. - src_key_padding_mask : torch.Tensor, optional - The mask for the src keys per batch. - pos_embs: torch.Tensor, torch.nn.Module, optional - Module or tensor containing the input sequence positional embeddings - """ - if self.attention_type == "cnnonly": - x2 = x - x2 = self._forward_cnn_branch(x2) - x = x + x2 - self_attn = None - else: - x1 = x - x2 = x - # Branch 1: Self-attention - x1, self_attn = self._forward_mha_branch( - x1, src_mask, src_key_padding_mask, pos_embs - ) - - # Branch 2: Convolutional gating MLP - # In ESPnet, masks are not used?! we do the same but warning! 
- x2 = self._forward_cnn_branch(x2) - - x = x + self.dropout(self.merge_proj(torch.cat([x1, x2], dim=-1))) - - return x, self_attn - - def _forward_cnn_branch( - self, x, - ): - """ - Arguments - ---------- - x : torch.Tensor - The sequence to the encoder layer. - """ - x = self.norm_conv(x) - x = self.convolution_branch(x) - - return self.dropout(x) - - def _forward_mha_branch( - self, - x, - src_mask: Optional[torch.Tensor] = None, - src_key_padding_mask: Optional[torch.Tensor] = None, - pos_embs: Optional[torch.Tensor] = None, - ): - """ - Arguments - ---------- - x : torch.Tensor - The sequence to the encoder layer. - src_mask : torch.Tensor, optional - The mask for the src sequence. - src_key_padding_mask : torch.Tensor, optional - The mask for the src keys per batch. - pos_embs: torch.Tensor, torch.nn.Module, optional - Module or tensor containing the input sequence positional embeddings - """ - - x = self.norm_mhsa(x) - - if self.attention_type == "SummaryMixing": - x = self.mha_layer(x, attention_mask=src_key_padding_mask) - self_attn = None - else: - x, self_attn = self.mha_layer( - x, - x, - x, - attn_mask=src_mask, - key_padding_mask=src_key_padding_mask, - pos_embs=pos_embs, - ) - - return self.dropout(x), self_attn - - -class BranchformerEncoder(nn.Module): - """This class implements the Branchformer encoder. - - Arguments - --------- - num_layers : int - Number of layers. - d_model : int - Embedding dimension size. - nhead : int - Number of attention heads. - kernel_size : int, optional - Kernel size of convolution model. - kdim : int, optional - Dimension of the key. - vdim : int, optional - Dimension of the value. - activation: torch.nn.Module - Activation function used in each Confomer layer. - dropout : int, optional - Dropout for the encoder. - attention_type: str, optional - type of attention layer, e.g. SummaryMixing or regulaMHA for regular MultiHeadAttention. - csgu_linear_units: int, optional - Number of neurons in the hidden linear units of the CSGU Module. - gate_activation: torch.nn.Module, optional - Activation function used at the gate of the CSGU module. - use_linear_after_conv: bool, optional - If True, will apply a linear transformation of size input_size//2. - local_proj_out_dim: int, optional - The dimension of the output of the local projection branch. This - will be concatenated with the output of the summary branch - (default: 512). - summary_hid_dim: list [int], optional - A list of dimension specifying both the number of hidden layers - as well as the size of them in the summary projection branch - (default: [1024]). - summary_out_dim: int, optional - The dimension of the output of the summary projection branch. This - will be concatenated with the output of the local branch - (default: 1024). - activation: torch.nn.Module, optional - Torch module specifying the activation function used in both the local - and summary branches. - (default: torch.nn.GELU) - mode: string, optional - One of "SummaryMixing" or "SummaryMixing-lite". Changes the SummaryMixing cell - according to the definition of the article. "SummaryMixing-lite" removes the - local project branch. 
- - - Example - ------- - >>> import torch - >>> x = torch.rand((8, 60, 512)) - >>> pos_emb = torch.rand((1, 2*60-1, 512)) - >>> net = BranchformerEncoder(1, 512, 8) - >>> output, _ = net(x, pos_embs=pos_emb) - >>> output.shape - torch.Size([8, 60, 512]) - """ - - def __init__( - self, - num_layers, - d_model, - nhead, - kernel_size=31, - kdim=None, - vdim=None, - activation=nn.GELU, - dropout=0.0, - attention_type="SummaryMixing", - csgu_linear_units=3072, - gate_activation=nn.Identity, - use_linear_after_conv=False, - local_proj_hid_dim=[512], - local_proj_out_dim=512, - summary_hid_dim=[1024], - summary_out_dim=1024, - mode="SummaryMixing", - ): - super().__init__() - - self.layers = torch.nn.ModuleList( - [ - BranchformerEncoderLayer( - nhead=nhead, - d_model=d_model, - kdim=kdim, - vdim=vdim, - dropout=dropout, - activation=activation, - kernel_size=kernel_size, - attention_type=attention_type, - csgu_linear_units=csgu_linear_units, - gate_activation=gate_activation, - use_linear_after_conv=use_linear_after_conv, - local_proj_hid_dim=local_proj_hid_dim, - local_proj_out_dim=local_proj_out_dim, - summary_hid_dim=summary_hid_dim, - summary_out_dim=summary_out_dim, - mode=mode, - ) - for i in range(num_layers) - ] - ) - self.norm = LayerNorm(d_model, eps=1e-6) - self.attention_type = attention_type - - def forward( - self, - src, - src_mask: Optional[torch.Tensor] = None, - src_key_padding_mask: Optional[torch.Tensor] = None, - pos_embs: Optional[torch.Tensor] = None, - dynchunktrain_config=None, - ): - """ - Arguments - ---------- - src : torch.Tensor - The sequence to the encoder layer. - src_mask : torch.Tensor, optional - The mask for the src sequence. - src_key_padding_mask : torch.Tensor, optional - The mask for the src keys per batch. - pos_embs: torch.Tensor, torch.nn.Module, - Module or tensor containing the input sequence positional embeddings - If custom pos_embs are given it needs to have the shape (1, 2*S-1, E) - where S is the sequence length, and E is the embedding dimension. - """ - assert ( - dynchunktrain_config is None - ), "Dynamic Chunk Training unsupported for this encoder" - - if self.attention_type == "RelPosMHAXL": - if pos_embs is None: - raise ValueError( - "The chosen attention type for the Branchformer is RelPosMHAXL. For this attention type, the positional embeddings are mandatory" - ) - - output = src - attention_lst = [] - for enc_layer in self.layers: - output, attention = enc_layer( - output, - src_mask=src_mask, - src_key_padding_mask=src_key_padding_mask, - pos_embs=pos_embs, - ) - attention_lst.append(attention) - output = self.norm(output) - - return output, attention_lst diff --git a/speechbrain/lobes/models/transformer/Conformer.py b/speechbrain/lobes/models/transformer/Conformer.py index 6d27adf..d9f9a47 100755 --- a/speechbrain/lobes/models/transformer/Conformer.py +++ b/speechbrain/lobes/models/transformer/Conformer.py @@ -1,10 +1,18 @@ """ SummaryMixing © 2023 by Samsung Electronics is licensed under CC BY-NC 4.0. 
+This library contains SummaryMixing Conformer for supervised learning and self-supervised learning + +Usage: Install SpeechBrain + Copy this file under speechbrain/lobes/models/transformer + +SummaryMixing: https://arxiv.org/abs/2307.07421 +SummaryMixing SSL: + Authors -------- -* Jianyuan Zhong 2020 -* Samuele Cornell 2021 -* Sylvain de Langen 2023 + * Titouan Parcollet 2023, 2024 + * Shucong Zhang 2023, 2024 + * Rogier van Dalen 2023, 2024 + * Sourav Bhattacharya 2023, 2024 """ from dataclasses import dataclass @@ -14,6 +22,7 @@ from typing import Optional, List import speechbrain as sb import warnings +import numpy as np from speechbrain.nnet.attention import ( @@ -22,7 +31,7 @@ PositionalwiseFeedForward, ) from speechbrain.utils.dynamic_chunk_training import DynChunkTrainConfig -from speechbrain.nnet.hypermixing import HyperMixing +# from speechbrain.lobes.models.transformer.hypermixing import HyperMixing from speechbrain.nnet.normalization import LayerNorm from speechbrain.nnet.activations import Swish from speechbrain.nnet.summary_mixing import SummaryMixing @@ -31,10 +40,8 @@ @dataclass class ConformerEncoderLayerStreamingContext: """Streaming metadata and state for a `ConformerEncoderLayer`. - The multi-head attention and Dynamic Chunk Convolution require to save some left context that gets inserted as left padding. - See :class:`.ConvolutionModule` documentation for further details. """ @@ -53,7 +60,6 @@ class ConformerEncoderLayerStreamingContext: dcconv_left_context: Optional[torch.Tensor] = None """Left context to insert at the left of the convolution according to the Dynamic Chunk Convolution method. - Unlike `mha_left_context`, here the amount of frames to keep is fixed and inferred from the kernel size of the convolution module. """ @@ -69,7 +75,6 @@ class ConformerEncoderStreamingContext: class ConvolutionModule(nn.Module): """This is an implementation of convolution module in Conformer. - Arguments ---------- input_size : int @@ -86,7 +91,6 @@ class ConvolutionModule(nn.Module): Whether the convolution should be causal or not. dilation: int, optional Dilation factor for the non bottleneck conv layer. - Example ------- >>> import torch @@ -121,7 +125,9 @@ def __init__( self.layer_norm = nn.LayerNorm(input_size) self.bottleneck = nn.Sequential( # pointwise - nn.Conv1d(input_size, 2 * input_size, kernel_size=1, stride=1, bias=bias), + nn.Conv1d( + input_size, 2 * input_size, kernel_size=1, stride=1, bias=bias + ), nn.GLU(dim=1), ) # depthwise @@ -154,7 +160,6 @@ def forward( dynchunktrain_config: Optional[DynChunkTrainConfig] = None, ): """Applies the convolution to an input tensor `x`. - Arguments --------- x: torch.Tensor @@ -318,7 +323,6 @@ def forward( class ConformerEncoderLayer(nn.Module): """This is an implementation of Conformer encoder layer. - Arguments ---------- d_model : int @@ -355,7 +359,6 @@ class ConformerEncoderLayer(nn.Module): One of "SummaryMixing" or "SummaryMixing-lite". Changes the SummaryMixing cell according to the definition of the article. "SummaryMixing-lite" removes the local project branch. 
- Example ------- >>> import torch @@ -392,7 +395,11 @@ def __init__( if attention_type == "regularMHA": self.mha_layer = MultiheadAttention( - nhead=nhead, d_model=d_model, dropout=dropout, kdim=kdim, vdim=vdim, + nhead=nhead, + d_model=d_model, + dropout=dropout, + kdim=kdim, + vdim=vdim, ) elif attention_type == "RelPosMHAXL": # transformerXL style positional encoding @@ -429,7 +436,10 @@ def __init__( self.ffn_module1 = nn.Sequential( nn.LayerNorm(d_model), PositionalwiseFeedForward( - d_ffn=d_ffn, input_size=d_model, dropout=dropout, activation=activation, + d_ffn=d_ffn, + input_size=d_model, + dropout=dropout, + activation=activation, ), nn.Dropout(dropout), ) @@ -437,7 +447,10 @@ def __init__( self.ffn_module2 = nn.Sequential( nn.LayerNorm(d_model), PositionalwiseFeedForward( - d_ffn=d_ffn, input_size=d_model, dropout=dropout, activation=activation, + d_ffn=d_ffn, + input_size=d_model, + dropout=dropout, + activation=activation, ), nn.Dropout(dropout), ) @@ -480,7 +493,9 @@ def forward( x = self.norm1(x) if self.attention_type == "SummaryMixing": - x = self.mha_layer(x, attention_mask=src_key_padding_mask) + x = self.mha_layer( + x, sum_mask=src_mask, src_padding_mask=src_key_padding_mask + ) self_attn = None else: x, self_attn = self.mha_layer( @@ -512,7 +527,6 @@ def forward_streaming( time. Relies on a mutable context object as initialized by `make_streaming_context` that should be used across chunks. Invoked by `ConformerEncoder.forward_streaming`. - Arguments --------- x : torch.Tensor @@ -540,7 +554,9 @@ def forward_streaming( # compute new MHA left context for the next call to our function if context.mha_left_context_size > 0: - context.mha_left_context = x[..., -context.mha_left_context_size :, :] + context.mha_left_context = x[ + ..., -context.mha_left_context_size :, : + ] # multi-head attention module skip = x @@ -549,7 +565,12 @@ def forward_streaming( x = self.mha_layer(x, attention_mask=None) else: x, self_attn = self.mha_layer( - x, x, x, attn_mask=None, key_padding_mask=None, pos_embs=pos_embs, + x, + x, + x, + attn_mask=None, + key_padding_mask=None, + pos_embs=pos_embs, ) x = x + skip @@ -561,7 +582,9 @@ def forward_streaming( x = torch.cat((context.dcconv_left_context, x), dim=1) # compute new DCConv left context for the next call to our function - context.dcconv_left_context = x[..., -self.convolution_module.padding :, :] + context.dcconv_left_context = x[ + ..., -self.convolution_module.padding :, : + ] # convolution module x = x + self.convolution_module(x) @@ -575,7 +598,6 @@ def forward_streaming( def make_streaming_context(self, mha_left_context_size: int): """Creates a blank streaming context for this encoding layer. - Arguments --------- mha_left_context_size : int @@ -589,7 +611,6 @@ def make_streaming_context(self, mha_left_context_size: int): class ConformerEncoder(nn.Module): """This class implements the Conformer encoder. - Arguments --------- num_layers : int @@ -628,8 +649,6 @@ class ConformerEncoder(nn.Module): One of "SummaryMixing" or "SummaryMixing-lite". Changes the SummaryMixing cell according to the definition of the article. "SummaryMixing-lite" removes the local project branch. 
- - Example ------- >>> import torch @@ -659,6 +678,8 @@ def __init__( local_proj_out_dim=512, summary_hid_dim=[1024], mode="SummaryMixing", + layerdrop_prob=0.0, + output_hidden_states=False, ): super().__init__() @@ -686,6 +707,9 @@ def __init__( ) self.norm = LayerNorm(d_model, eps=1e-6) self.attention_type = attention_type + self.layerdrop_prob = layerdrop_prob + self.rng = np.random.default_rng() + self.output_hidden_states = output_hidden_states def forward( self, @@ -720,18 +744,34 @@ def forward( ) output = src + if self.layerdrop_prob > 0.0: + keep_probs = self.rng.random(len(self.layers)) + else: + keep_probs = None attention_lst = [] - for enc_layer in self.layers: - output, attention = enc_layer( - output, - src_mask=src_mask, - src_key_padding_mask=src_key_padding_mask, - pos_embs=pos_embs, - dynchunktrain_config=dynchunktrain_config, - ) - attention_lst.append(attention) + if self.output_hidden_states: + hidden_lst = [] + for i, enc_layer in enumerate(self.layers): + if ( + not self.training + or self.layerdrop_prob == 0.0 + or keep_probs[i] > self.layerdrop_prob + ): + output, attention = enc_layer( + output, + src_mask=src_mask, + src_key_padding_mask=src_key_padding_mask, + pos_embs=pos_embs, + ) + attention_lst.append(attention) + if self.output_hidden_states: + hidden_lst.append(output) output = self.norm(output) + if self.output_hidden_states: + hidden_lst[-1] = output + return output, hidden_lst, attention_lst + return output, attention_lst def forward_streaming( @@ -744,7 +784,6 @@ def forward_streaming( DynamicChunkTraining-trained models), which is to be used at inference time. Relies on a mutable context object as initialized by `make_streaming_context` that should be used across chunks. - Arguments --------- src : torch.Tensor @@ -775,7 +814,6 @@ def forward_streaming( def make_streaming_context(self, mha_left_context_size: int): """Creates a blank streaming context for the encoder. - Arguments --------- mha_left_context_size : int @@ -795,7 +833,6 @@ def make_streaming_context(self, mha_left_context_size: int): class ConformerDecoderLayer(nn.Module): """This is an implementation of Conformer encoder layer. - Arguments ---------- d_model : int @@ -820,7 +857,6 @@ class ConformerDecoderLayer(nn.Module): Whether the convolutions should be causal or not. attention_type: str, optional type of attention layer, e.g. regulaMHA for regular MultiHeadAttention. - Example ------- >>> import torch @@ -855,7 +891,11 @@ def __init__( if attention_type == "regularMHA": self.mha_layer = MultiheadAttention( - nhead=nhead, d_model=d_model, dropout=dropout, kdim=kdim, vdim=vdim, + nhead=nhead, + d_model=d_model, + dropout=dropout, + kdim=kdim, + vdim=vdim, ) elif attention_type == "RelPosMHAXL": # transformerXL style positional encoding @@ -873,7 +913,10 @@ def __init__( self.ffn_module1 = nn.Sequential( nn.LayerNorm(d_model), PositionalwiseFeedForward( - d_ffn=d_ffn, input_size=d_model, dropout=dropout, activation=activation, + d_ffn=d_ffn, + input_size=d_model, + dropout=dropout, + activation=activation, ), nn.Dropout(dropout), ) @@ -881,7 +924,10 @@ def __init__( self.ffn_module2 = nn.Sequential( nn.LayerNorm(d_model), PositionalwiseFeedForward( - d_ffn=d_ffn, input_size=d_model, dropout=dropout, activation=activation, + d_ffn=d_ffn, + input_size=d_model, + dropout=dropout, + activation=activation, ), nn.Dropout(dropout), ) @@ -944,7 +990,6 @@ def forward( class ConformerDecoder(nn.Module): """This class implements the Transformer decoder. 
- Arguments ---------- num_layers: int @@ -971,8 +1016,6 @@ class ConformerDecoder(nn.Module): Whether the convolutions should be causal or not. attention_type: str, optional type of attention layer, e.g. regulaMHA for regular MultiHeadAttention. - - Example ------- >>> src = torch.rand((8, 60, 512)) @@ -1049,7 +1092,6 @@ def forward( Module or tensor containing the target sequence positional embeddings for each attention layer. pos_embs_src: torch.Tensor, torch.nn.Module, optional Module or tensor containing the source sequence positional embeddings for each attention layer. - """ output = tgt self_attns, multihead_attns = [], [] @@ -1068,4 +1110,4 @@ def forward( multihead_attns.append(multihead_attn) output = self.norm(output) - return output, self_attns, multihead_attns + return output, self_attns, multihead_attns \ No newline at end of file diff --git a/speechbrain/lobes/models/transformer/Transformer.py b/speechbrain/lobes/models/transformer/Transformer.py deleted file mode 100644 index 373efe3..0000000 --- a/speechbrain/lobes/models/transformer/Transformer.py +++ /dev/null @@ -1,1044 +0,0 @@ -""" SummaryMixing © 2023 by Samsung Electronics is licensed under CC BY-NC 4.0. - -This library connects SummaryMixing to the standard SpeechBrain lobes for Transformer ASR. -Large parts of the code come from the SpeechBrain repository. - -Usage: Install SpeechBrain and copy this file under speechbrain/lobes/models/transformer/ -Source: https://arxiv.org/abs/2307.07421 - -Authors - * Titouan Parcollet 2023 - * Shucong Zhang 2023 - * Rogier van Dalen 2023 - * Sourav Bhattacharya 2023 -""" - -import math -import torch -import torch.nn as nn -import speechbrain as sb -from typing import Optional -import numpy as np - -from .Conformer import ConformerEncoder -from .Branchformer import BranchformerEncoder -from speechbrain.nnet.activations import Swish -from speechbrain.nnet.attention import RelPosEncXL -from speechbrain.nnet.CNN import Conv1d -from speechbrain.nnet.summary_mixing import SummaryMixing - - -class TransformerInterface(nn.Module): - """This is an interface for transformer model. - - Users can modify the attributes and define the forward function as - needed according to their own tasks. - - The architecture is based on the paper "Attention Is All You Need": - https://arxiv.org/pdf/1706.03762.pdf - - Arguments - ---------- - d_model: int - The number of expected features in the encoder/decoder inputs (default=512). - nhead: int - The number of heads in the multi-head attention models (default=8). - num_encoder_layers: int, optional - The number of encoder layers in1ì the encoder. - num_decoder_layers: int, optional - The number of decoder layers in the decoder. - dim_ffn: int, optional - The dimension of the feedforward network model hidden layer. - dropout: int, optional - The dropout value. - activation: torch.nn.Module, optional - The activation function for Feed-Forward Network layer, - e.g., relu or gelu or swish. - custom_src_module: torch.nn.Module, optional - Module that processes the src features to expected feature dim. - custom_tgt_module: torch.nn.Module, optional - Module that processes the src features to expected feature dim. - positional_encoding: str, optional - Type of positional encoding used. e.g. 'fixed_abs_sine' for fixed absolute positional encodings. - normalize_before: bool, optional - Whether normalization should be applied before or after MHA or FFN in Transformer layers. 
- Defaults to True as this was shown to lead to better performance and training stability. - kernel_size: int, optional - Kernel size in convolutional layers when Conformer is used. - bias: bool, optional - Whether to use bias in Conformer convolutional layers. - encoder_module: str, optional - Choose between Branchformer, Conformer and Transformer for the encoder. The decoder is fixed to be a Transformer. - conformer_activation: torch.nn.Module, optional - Activation module used after Conformer convolutional layers. E.g. Swish, ReLU etc. it has to be a torch Module. - branchformer_activation: torch.nn.Module, optional - Activation module used within the Branchformer Encoder. E.g. Swish, ReLU etc. it has to be a torch Module. - attention_type: str, optional - Type of attention layer used in all Transformer or Conformer layers. - e.g. SummaryMixing, regularMHA or RelPosMHA. - max_length: int, optional - Max length for the target and source sequence in input. - Used for positional encodings. - causal: bool, optional - Whether the encoder should be causal or not (the decoder is always causal). - If causal the Conformer convolutional layer is causal. - encoder_kdim: int, optional - Dimension of the key for the encoder. - encoder_vdim: int, optional - Dimension of the value for the encoder. - decoder_kdim: int, optional - Dimension of the key for the decoder. - decoder_vdim: int, optional - Dimension of the value for the decoder. - csgu_linear_units: int, optional - Number of neurons in the hidden linear units of the CSGU Module. - -> Branchformer - gate_activation: torch.nn.Module, optional - Activation function used at the gate of the CSGU module. - -> Branchformer - use_linear_after_conv: bool, optional - If True, will apply a linear transformation of size input_size//2. - -> Branchformer - local_proj_out_dim: int, optional - The dimension of the output of the local projection branch. This - will be concatenated with the output of the summary branch - (default: 512). - summary_hid_dim: list [int], optional - A list of dimension specifying both the number of hidden layers - as well as the size of them in the summary projection branch - (default: [1024]). - summary_out_dim: int, optional - The dimension of the output of the summary projection branch. This - will be concatenated with the output of the local branch - (default: 1024). - activation: torch.nn.Module, optional - Torch module specifying the activation function used in both the local - and summary branches. - (default: torch.nn.GELU) - mode: string, optional - One of "SummaryMixing" or "SummaryMixing-lite". Changes the SummaryMixing cell - according to the definition of the article. "SummaryMixing-lite" removes the - local project branch. 
- """ - - def __init__( - self, - d_model=512, - nhead=8, - num_encoder_layers=6, - num_decoder_layers=6, - d_ffn=2048, - dropout=0.1, - activation=nn.ReLU, - custom_src_module=None, - custom_tgt_module=None, - positional_encoding="fixed_abs_sine", - normalize_before=True, - kernel_size: Optional[int] = 31, - bias: Optional[bool] = True, - encoder_module: Optional[str] = "transformer", - conformer_activation: Optional[nn.Module] = Swish, - branchformer_activation: Optional[nn.Module] = nn.GELU, - attention_type: Optional[str] = "SummaryMixing", - max_length: Optional[int] = 2500, - causal: Optional[bool] = False, - encoder_kdim: Optional[int] = None, - encoder_vdim: Optional[int] = None, - decoder_kdim: Optional[int] = None, - decoder_vdim: Optional[int] = None, - csgu_linear_units: Optional[int] = 3072, - gate_activation: Optional[nn.Module] = nn.Identity, - use_linear_after_conv: Optional[bool] = False, - local_proj_hid_dim: Optional[list] = [512], - local_proj_out_dim: Optional[int] = 512, - summary_hid_dim: Optional[list] = [1024], - summary_out_dim: Optional[int] = 1024, - mode: Optional[str] = "SummaryMixing", - ): - super().__init__() - self.causal = causal - self.attention_type = attention_type - self.positional_encoding_type = positional_encoding - self.encoder_kdim = encoder_kdim - self.encoder_vdim = encoder_vdim - self.decoder_kdim = decoder_kdim - self.decoder_vdim = decoder_vdim - - assert attention_type in [ - "regularMHA", - "RelPosMHAXL", - "hypermixing", - "SummaryMixing", - ] - assert positional_encoding in ["fixed_abs_sine", None] - - assert ( - num_encoder_layers + num_decoder_layers > 0 - ), "number of encoder layers and number of decoder layers cannot both be 0!" - - if positional_encoding == "fixed_abs_sine": - self.positional_encoding = PositionalEncoding(d_model, max_length) - elif positional_encoding is None: - pass - # no positional encodings - - # overrides any other pos_embedding - if attention_type == "RelPosMHAXL": - self.positional_encoding = RelPosEncXL(d_model) - self.positional_encoding_decoder = PositionalEncoding(d_model, max_length) - - # initialize the encoder - if num_encoder_layers > 0: - if custom_src_module is not None: - self.custom_src_module = custom_src_module(d_model) - if encoder_module == "transformer": - self.encoder = TransformerEncoder( - nhead=nhead, - num_layers=num_encoder_layers, - d_ffn=d_ffn, - d_model=d_model, - dropout=dropout, - activation=activation, - normalize_before=normalize_before, - causal=self.causal, - attention_type=self.attention_type, - kdim=self.encoder_kdim, - vdim=self.encoder_vdim, - local_proj_hid_dim=local_proj_hid_dim, - local_proj_out_dim=local_proj_out_dim, - summary_hid_dim=summary_hid_dim, - summary_out_dim=summary_out_dim, - mode=mode, - ) - elif encoder_module == "conformer": - self.encoder = ConformerEncoder( - nhead=nhead, - num_layers=num_encoder_layers, - d_ffn=d_ffn, - d_model=d_model, - dropout=dropout, - activation=conformer_activation, - kernel_size=kernel_size, - bias=bias, - causal=self.causal, - attention_type=self.attention_type, - ) - assert normalize_before, "normalize_before must be True for Conformer" - - assert ( - conformer_activation is not None - ), "conformer_activation must not be None" - elif encoder_module == "branchformer": - self.encoder = BranchformerEncoder( - nhead=nhead, - num_layers=num_encoder_layers, - d_model=d_model, - dropout=dropout, - activation=branchformer_activation, - kernel_size=kernel_size, - attention_type=self.attention_type, - 
csgu_linear_units=csgu_linear_units, - gate_activation=gate_activation, - use_linear_after_conv=use_linear_after_conv, - local_proj_hid_dim=local_proj_hid_dim, - local_proj_out_dim=local_proj_out_dim, - summary_hid_dim=summary_hid_dim, - summary_out_dim=summary_out_dim, - mode=mode, - ) - - # initialize the decoder - if num_decoder_layers > 0: - if custom_tgt_module is not None: - self.custom_tgt_module = custom_tgt_module(d_model) - self.decoder = TransformerDecoder( - num_layers=num_decoder_layers, - nhead=nhead, - d_ffn=d_ffn, - d_model=d_model, - dropout=dropout, - activation=activation, - normalize_before=normalize_before, - causal=True, - attention_type="regularMHA", # always use regular attention in decoder - kdim=self.decoder_kdim, - vdim=self.decoder_vdim, - ) - - def forward(self, **kwags): - """Users should modify this function according to their own tasks.""" - raise NotImplementedError - - -class PositionalEncoding(nn.Module): - """This class implements the absolute sinusoidal positional encoding function. - - PE(pos, 2i) = sin(pos/(10000^(2i/dmodel))) - PE(pos, 2i+1) = cos(pos/(10000^(2i/dmodel))) - - Arguments - --------- - input_size: int - Embedding dimension. - max_len : int, optional - Max length of the input sequences (default 2500). - - Example - ------- - >>> a = torch.rand((8, 120, 512)) - >>> enc = PositionalEncoding(input_size=a.shape[-1]) - >>> b = enc(a) - >>> b.shape - torch.Size([1, 120, 512]) - """ - - def __init__(self, input_size, max_len=2500): - super().__init__() - if input_size % 2 != 0: - raise ValueError( - f"Cannot use sin/cos positional encoding with odd channels (got channels={input_size})" - ) - self.max_len = max_len - pe = torch.zeros(self.max_len, input_size, requires_grad=False) - positions = torch.arange(0, self.max_len).unsqueeze(1).float() - denominator = torch.exp( - torch.arange(0, input_size, 2).float() * -(math.log(10000.0) / input_size) - ) - - pe[:, 0::2] = torch.sin(positions * denominator) - pe[:, 1::2] = torch.cos(positions * denominator) - pe = pe.unsqueeze(0) - self.register_buffer("pe", pe) - - def forward(self, x): - """ - Arguments - --------- - x : tensor - Input feature shape (batch, time, fea) - """ - return self.pe[:, : x.size(1)].clone().detach() - - -class TransformerEncoderLayer(nn.Module): - """This is an implementation of self-attention encoder layer. - - Arguments - ---------- - d_ffn: int - The dimension of the feedforward network model hidden layer. - nhead: int - The number of heads in the multi-head attention models (default=8). - d_model: int - The number of expected features in the encoder/decoder inputs (default=512). - kdim: int, optional - Dimension of the key. - vdim: int, optional - Dimension of the value. - dropout: int, optional - The dropout value. - activation: torch.nn.Module, optional - The activation function for Feed-Forward Netowrk layer, - e.g., relu or gelu or swish. - normalize_before: bool, optional - Whether normalization should be applied before or after MHA or FFN in Transformer layers. - Defaults to True as this was shown to lead to better performance and training stability. - attention_type: str, optional - Type of attention layer used in all Transformer or Conformer layers. - e.g. SummaryMixing, regularMHA or RelPosMHA. - ffn_type: str - type of ffn: regularFFN/1dcnn - ffn_cnn_kernel_size_list: list of int - kernel size of 2 1d-convs if ffn_type is 1dcnn - causal: bool, optional - Whether the encoder should be causal or not (the decoder is always causal). 
- If causal the Conformer convolutional layer is causal. - mode: string, optional - One of "SummaryMixing" or "SummaryMixing-lite". Changes the SummaryMixing cell - according to the definition of the article. "SummaryMixing-lite" removes the - local project branch. - - Example - ------- - >>> import torch - >>> x = torch.rand((8, 60, 512)) - >>> net = TransformerEncoderLayer(512, 8, d_model=512) - >>> output = net(x) - >>> output[0].shape - torch.Size([8, 60, 512]) - """ - - def __init__( - self, - d_ffn, - nhead, - d_model, - kdim=None, - vdim=None, - dropout=0.0, - activation=nn.ReLU, - normalize_before=False, - ffn_type="regularFFN", - ffn_cnn_kernel_size_list=[3, 3], - attention_type="SummaryMixing", - causal=False, - local_proj_hid_dim=[512], - local_proj_out_dim=512, - summary_hid_dim=[1024], - summary_out_dim=1024, - mode="SummaryMixing", - ): - super().__init__() - - self.attention_type = attention_type - - if attention_type == "regularMHA": - self.self_att = sb.nnet.attention.MultiheadAttention( - nhead=nhead, d_model=d_model, dropout=dropout, kdim=kdim, vdim=vdim, - ) - - elif attention_type == "RelPosMHAXL": - self.self_att = sb.nnet.attention.RelPosMHAXL( - d_model, nhead, dropout, mask_pos_future=causal - ) - elif attention_type == "hypermixing": - self.self_att = sb.nnet.hypermixing.HyperMixing( - input_output_dim=d_model, - hypernet_size=d_ffn, - tied=False, - num_heads=nhead, - fix_tm_hidden_size=False, - ) - elif attention_type == "SummaryMixing": - self.self_att = SummaryMixing( - enc_dim=d_model, - local_proj_hid_dim=local_proj_hid_dim, - local_proj_out_dim=local_proj_out_dim, - summary_hid_dim=summary_hid_dim, - summary_out_dim=summary_out_dim, - activation=activation, - mode=mode, - ) - - if ffn_type == "regularFFN": - self.pos_ffn = sb.nnet.attention.PositionalwiseFeedForward( - d_ffn=d_ffn, input_size=d_model, dropout=dropout, activation=activation, - ) - elif ffn_type == "1dcnn": - self.pos_ffn = nn.Sequential( - Conv1d( - in_channels=d_model, - out_channels=d_ffn, - kernel_size=ffn_cnn_kernel_size_list[0], - padding="causal" if causal else "same", - ), - nn.ReLU(), - Conv1d( - in_channels=d_ffn, - out_channels=d_model, - kernel_size=ffn_cnn_kernel_size_list[1], - padding="causal" if causal else "same", - ), - ) - - self.norm1 = sb.nnet.normalization.LayerNorm(d_model, eps=1e-6) - self.norm2 = sb.nnet.normalization.LayerNorm(d_model, eps=1e-6) - self.dropout1 = torch.nn.Dropout(dropout) - self.dropout2 = torch.nn.Dropout(dropout) - - self.normalize_before = normalize_before - self.pos_ffn_type = ffn_type - - def forward( - self, - src, - src_mask: Optional[torch.Tensor] = None, - src_key_padding_mask: Optional[torch.Tensor] = None, - pos_embs: Optional[torch.Tensor] = None, - ): - """ - Arguments - ---------- - src : torch.Tensor - The sequence to the encoder layer. - src_mask : torch.Tensor - The mask for the src query for each example in the batch. - src_key_padding_mask : torch.Tensor, optional - The mask for the src keys for each example in the batch. 
- """ - - if self.normalize_before: - src1 = self.norm1(src) - else: - src1 = src - - if self.attention_type == "SummaryMixing": - output = self.self_att(src1, attention_mask=src_key_padding_mask) - self_attn = None - else: - output, self_attn = self.self_att( - src1, - src1, - src1, - attn_mask=src_mask, - key_padding_mask=src_key_padding_mask, - pos_embs=pos_embs, - ) - - # add & norm - src = src + self.dropout1(output) - if not self.normalize_before: - src = self.norm1(src) - - if self.normalize_before: - src1 = self.norm2(src) - else: - src1 = src - output = self.pos_ffn(src1) - - # add & norm - output = src + self.dropout2(output) - if not self.normalize_before: - output = self.norm2(output) - return output, self_attn - - -class TransformerEncoder(nn.Module): - """This class implements the transformer encoder. - - Arguments - --------- - num_layers : int - Number of transformer layers to include. - nhead : int - Number of attention heads. - d_ffn : int - Hidden size of self-attention Feed Forward layer. - d_model : int - The dimension of the input embedding. - kdim : int - Dimension for key (Optional). - vdim : int - Dimension for value (Optional). - dropout : float - Dropout for the encoder (Optional). - input_module: torch class - The module to process the source input feature to expected - feature dimension (Optional). - activation: torch.nn.Module, optional - The activation function for Feed-Forward Netowrk layer, - e.g., relu or gelu or swish. - normalize_before: bool, optional - Whether normalization should be applied before or after MHA or FFN in Transformer layers. - Defaults to True as this was shown to lead to better performance and training stability. - causal: bool, optional - Whether the encoder should be causal or not (the decoder is always causal). - If causal the Conformer convolutional layer is causal. - layerdrop_prob: float - The probability to drop an entire layer - attention_type: str, optional - Type of attention layer used in all Transformer or Conformer layers. - e.g. SummaryMixing, regularMHA or RelPosMHA. - ffn_type: str - type of ffn: regularFFN/1dcnn - ffn_cnn_kernel_size_list: list of int - conv kernel size of 2 1d-convs if ffn_type is 1dcnn - mode: string, optional - One of "SummaryMixing" or "SummaryMixing-lite". Changes the SummaryMixing cell - according to the definition of the article. "SummaryMixing-lite" removes the - local project branch. 
- - Example - ------- - >>> import torch - >>> x = torch.rand((8, 60, 512)) - >>> net = TransformerEncoder(1, 8, 512, d_model=512) - >>> output, _ = net(x) - >>> output.shape - torch.Size([8, 60, 512]) - """ - - def __init__( - self, - num_layers, - nhead, - d_ffn, - input_shape=None, - d_model=None, - kdim=None, - vdim=None, - dropout=0.0, - activation=nn.ReLU, - normalize_before=False, - causal=False, - layerdrop_prob=0.0, - attention_type="regularMHA", - ffn_type="regularFFN", - ffn_cnn_kernel_size_list=[3, 3], - local_proj_hid_dim=[512], - local_proj_out_dim=512, - summary_hid_dim=[1024], - summary_out_dim=1024, - mode="SummaryMixing", - ): - super().__init__() - - self.layers = torch.nn.ModuleList( - [ - TransformerEncoderLayer( - d_ffn=d_ffn, - nhead=nhead, - d_model=d_model, - kdim=kdim, - vdim=vdim, - dropout=dropout, - activation=activation, - normalize_before=normalize_before, - causal=causal, - attention_type=attention_type, - ffn_type=ffn_type, - ffn_cnn_kernel_size_list=ffn_cnn_kernel_size_list, - local_proj_hid_dim=local_proj_hid_dim, - local_proj_out_dim=local_proj_out_dim, - summary_hid_dim=summary_hid_dim, - summary_out_dim=summary_out_dim, - mode=mode, - ) - for i in range(num_layers) - ] - ) - self.norm = sb.nnet.normalization.LayerNorm(d_model, eps=1e-6) - self.layerdrop_prob = layerdrop_prob - self.rng = np.random.default_rng() - - def forward( - self, - src, - src_mask: Optional[torch.Tensor] = None, - src_key_padding_mask: Optional[torch.Tensor] = None, - pos_embs: Optional[torch.Tensor] = None, - dynchunktrain_config=None, - ): - """ - Arguments - ---------- - src : tensor - The sequence to the encoder layer (required). - src_mask : tensor - The mask for the src sequence (optional). - src_key_padding_mask : tensor - The mask for the src keys per batch (optional). - """ - assert ( - dynchunktrain_config is None - ), "Dynamic Chunk Training unsupported for this encoder" - - output = src - if self.layerdrop_prob > 0.0: - keep_probs = self.rng.random(len(self.layers)) - else: - keep_probs = None - attention_lst = [] - for i, enc_layer in enumerate(self.layers): - if ( - not self.training - or self.layerdrop_prob == 0.0 - or keep_probs[i] > self.layerdrop_prob - ): - output, attention = enc_layer( - output, - src_mask=src_mask, - src_key_padding_mask=src_key_padding_mask, - pos_embs=pos_embs, - ) - - attention_lst.append(attention) - output = self.norm(output) - return output, attention_lst - - -class TransformerDecoderLayer(nn.Module): - """This class implements the self-attention decoder layer. - - Arguments - ---------- - d_ffn : int - Hidden size of self-attention Feed Forward layer. - nhead : int - Number of attention heads. - d_model : int - Dimension of the model. - kdim : int - Dimension for key (optional). - vdim : int - Dimension for value (optional). - dropout : float - Dropout for the decoder (optional). 
- - Example - ------- - >>> src = torch.rand((8, 60, 512)) - >>> tgt = torch.rand((8, 60, 512)) - >>> net = TransformerDecoderLayer(1024, 8, d_model=512) - >>> output, self_attn, multihead_attn = net(src, tgt) - >>> output.shape - torch.Size([8, 60, 512]) - """ - - def __init__( - self, - d_ffn, - nhead, - d_model, - kdim=None, - vdim=None, - dropout=0.0, - activation=nn.ReLU, - normalize_before=False, - attention_type="regularMHA", - causal=None, - ): - super().__init__() - self.nhead = nhead - - if attention_type == "regularMHA": - self.self_attn = sb.nnet.attention.MultiheadAttention( - nhead=nhead, d_model=d_model, kdim=kdim, vdim=vdim, dropout=dropout, - ) - self.mutihead_attn = sb.nnet.attention.MultiheadAttention( - nhead=nhead, d_model=d_model, kdim=kdim, vdim=vdim, dropout=dropout, - ) - - elif attention_type == "RelPosMHAXL": - self.self_attn = sb.nnet.attention.RelPosMHAXL( - d_model, nhead, dropout, mask_pos_future=causal - ) - self.mutihead_attn = sb.nnet.attention.RelPosMHAXL( - d_model, nhead, dropout, mask_pos_future=causal - ) - - self.pos_ffn = sb.nnet.attention.PositionalwiseFeedForward( - d_ffn=d_ffn, input_size=d_model, dropout=dropout, activation=activation, - ) - - # normalization layers - self.norm1 = sb.nnet.normalization.LayerNorm(d_model, eps=1e-6) - self.norm2 = sb.nnet.normalization.LayerNorm(d_model, eps=1e-6) - self.norm3 = sb.nnet.normalization.LayerNorm(d_model, eps=1e-6) - self.dropout1 = torch.nn.Dropout(dropout) - self.dropout2 = torch.nn.Dropout(dropout) - self.dropout3 = torch.nn.Dropout(dropout) - - self.normalize_before = normalize_before - - def forward( - self, - tgt, - memory, - tgt_mask=None, - memory_mask=None, - tgt_key_padding_mask=None, - memory_key_padding_mask=None, - pos_embs_tgt=None, - pos_embs_src=None, - ): - """ - Arguments - ---------- - tgt: tensor - The sequence to the decoder layer (required). - memory: tensor - The sequence from the last layer of the encoder (required). - tgt_mask: tensor - The mask for the tgt sequence (optional). - memory_mask: tensor - The mask for the memory sequence (optional). - tgt_key_padding_mask: tensor - The mask for the tgt keys per batch (optional). - memory_key_padding_mask: tensor - The mask for the memory keys per batch (optional). - """ - if self.normalize_before: - tgt1 = self.norm1(tgt) - else: - tgt1 = tgt - - # self-attention over the target sequence - tgt2, self_attn = self.self_attn( - query=tgt1, - key=tgt1, - value=tgt1, - attn_mask=tgt_mask, - key_padding_mask=tgt_key_padding_mask, - pos_embs=pos_embs_tgt, - ) - - # add & norm - tgt = tgt + self.dropout1(tgt2) - if not self.normalize_before: - tgt = self.norm1(tgt) - - if self.normalize_before: - tgt1 = self.norm2(tgt) - else: - tgt1 = tgt - - # multi-head attention over the target sequence and encoder states - - tgt2, multihead_attention = self.mutihead_attn( - query=tgt1, - key=memory, - value=memory, - attn_mask=memory_mask, - key_padding_mask=memory_key_padding_mask, - pos_embs=pos_embs_src, - ) - - # add & norm - tgt = tgt + self.dropout2(tgt2) - if not self.normalize_before: - tgt = self.norm2(tgt) - - if self.normalize_before: - tgt1 = self.norm3(tgt) - else: - tgt1 = tgt - - tgt2 = self.pos_ffn(tgt1) - - # add & norm - tgt = tgt + self.dropout3(tgt2) - if not self.normalize_before: - tgt = self.norm3(tgt) - - return tgt, self_attn, multihead_attention - - -class TransformerDecoder(nn.Module): - """This class implements the Transformer decoder. - - Arguments - ---------- - nhead : int - Number of attention heads. 
- d_ffn : int - Hidden size of self-attention Feed Forward layer. - d_model : int - Dimension of the model. - kdim : int, optional - Dimension for key (Optional). - vdim : int, optional - Dimension for value (Optional). - dropout : float, optional - Dropout for the decoder (Optional). - - Example - ------- - >>> src = torch.rand((8, 60, 512)) - >>> tgt = torch.rand((8, 60, 512)) - >>> net = TransformerDecoder(1, 8, 1024, d_model=512) - >>> output, _, _ = net(src, tgt) - >>> output.shape - torch.Size([8, 60, 512]) - """ - - def __init__( - self, - num_layers, - nhead, - d_ffn, - d_model, - kdim=None, - vdim=None, - dropout=0.0, - activation=nn.ReLU, - normalize_before=False, - causal=False, - attention_type="regularMHA", - ): - super().__init__() - self.layers = torch.nn.ModuleList( - [ - TransformerDecoderLayer( - d_ffn=d_ffn, - nhead=nhead, - d_model=d_model, - kdim=kdim, - vdim=vdim, - dropout=dropout, - activation=activation, - normalize_before=normalize_before, - causal=causal, - attention_type=attention_type, - ) - for _ in range(num_layers) - ] - ) - self.norm = sb.nnet.normalization.LayerNorm(d_model, eps=1e-6) - - def forward( - self, - tgt, - memory, - tgt_mask=None, - memory_mask=None, - tgt_key_padding_mask=None, - memory_key_padding_mask=None, - pos_embs_tgt=None, - pos_embs_src=None, - ): - """ - Arguments - ---------- - tgt : tensor - The sequence to the decoder layer (required). - memory : tensor - The sequence from the last layer of the encoder (required). - tgt_mask : tensor - The mask for the tgt sequence (optional). - memory_mask : tensor - The mask for the memory sequence (optional). - tgt_key_padding_mask : tensor - The mask for the tgt keys per batch (optional). - memory_key_padding_mask : tensor - The mask for the memory keys per batch (optional). - """ - output = tgt - self_attns, multihead_attns = [], [] - for dec_layer in self.layers: - output, self_attn, multihead_attn = dec_layer( - output, - memory, - tgt_mask=tgt_mask, - memory_mask=memory_mask, - tgt_key_padding_mask=tgt_key_padding_mask, - memory_key_padding_mask=memory_key_padding_mask, - pos_embs_tgt=pos_embs_tgt, - pos_embs_src=pos_embs_src, - ) - self_attns.append(self_attn) - multihead_attns.append(multihead_attn) - output = self.norm(output) - - return output, self_attns, multihead_attns - - -class NormalizedEmbedding(nn.Module): - """This class implements the normalized embedding layer for the transformer. - - Since the dot product of the self-attention is always normalized by sqrt(d_model) - and the final linear projection for prediction shares weight with the embedding layer, - we multiply the output of the embedding by sqrt(d_model). - - Arguments - --------- - d_model: int - The number of expected features in the encoder/decoder inputs (default=512). - vocab: int - The vocab size. - - Example - ------- - >>> emb = NormalizedEmbedding(512, 1000) - >>> trg = torch.randint(0, 999, (8, 50)) - >>> emb_fea = emb(trg) - """ - - def __init__(self, d_model, vocab): - super().__init__() - self.emb = sb.nnet.embedding.Embedding( - num_embeddings=vocab, embedding_dim=d_model, blank_id=0 - ) - self.d_model = d_model - - def forward(self, x): - """ Processes the input tensor x and returns an output tensor.""" - return self.emb(x) * math.sqrt(self.d_model) - - -def get_key_padding_mask(padded_input, pad_idx): - """Creates a binary mask to prevent attention to padded locations. - We suggest using get_mask_from_lengths instead of this function. - Arguments - ---------- - padded_input: int - Padded input. 
- pad_idx: - idx for padding element. - - Example - ------- - >>> a = torch.LongTensor([[1,1,0], [2,3,0], [4,5,0]]) - >>> get_key_padding_mask(a, pad_idx=0) - tensor([[False, False, True], - [False, False, True], - [False, False, True]]) - """ - if len(padded_input.shape) == 4: - bz, time, ch1, ch2 = padded_input.shape - padded_input = padded_input.reshape(bz, time, ch1 * ch2) - - key_padded_mask = padded_input.eq(pad_idx).to(padded_input.device) - - # if the input is more than 2d, mask the locations where they are silence - # across all channels - if len(padded_input.shape) > 2: - key_padded_mask = key_padded_mask.float().prod(dim=-1).bool() - return key_padded_mask.detach() - - return key_padded_mask.detach() - - -def get_lookahead_mask(padded_input): - """Creates a binary mask for each sequence which maskes future frames. - - Arguments - --------- - padded_input: torch.Tensor - Padded input tensor. - - Example - ------- - >>> a = torch.LongTensor([[1,1,0], [2,3,0], [4,5,0]]) - >>> get_lookahead_mask(a) - tensor([[0., -inf, -inf], - [0., 0., -inf], - [0., 0., 0.]]) - """ - seq_len = padded_input.shape[1] - mask = ( - torch.triu(torch.ones((seq_len, seq_len), device=padded_input.device)) == 1 - ).transpose(0, 1) - mask = ( - mask.float() - .masked_fill(mask == 0, float("-inf")) - .masked_fill(mask == 1, float(0.0)) - ) - return mask.detach().to(padded_input.device) - - -def get_mask_from_lengths(lengths, max_len=None): - """Creates a binary mask from sequence lengths - Arguments - --------- - lengths: torch.Tensor - A tensor of sequence lengths - max_len: int (Optional) - Maximum sequence length, defaults to None. - Returns - ------- - mask: torch.Tensor - the mask where padded elements are set to True. - Then one can use tensor.masked_fill_(mask, 0) for the masking. - Example - ------- - >>> lengths = torch.tensor([3, 2, 4]) - >>> get_mask_from_lengths(lengths) - tensor([[False, False, False, True], - [False, False, True, True], - [False, False, False, False]]) - """ - if max_len is None: - max_len = torch.max(lengths).item() - seq_range = torch.arange(max_len, device=lengths.device, dtype=lengths.dtype) - return ~(seq_range.unsqueeze(0) < lengths.unsqueeze(1)) diff --git a/speechbrain/lobes/models/transformer/TransformerASR.py b/speechbrain/lobes/models/transformer/TransformerASR.py deleted file mode 100755 index a7a3f76..0000000 --- a/speechbrain/lobes/models/transformer/TransformerASR.py +++ /dev/null @@ -1,665 +0,0 @@ -""" SummaryMixing © 2023 by Samsung Electronics is licensed under CC BY-NC 4.0. - -This library connects SummaryMixing to the standard SpeechBrain lobes for Transformer-based ASR. -Large parts of the code come from the SpeechBrain repository. 
- -Usage: Install SpeechBrain and copy this file under speechbrain/lobes/models/transformer/ -Source: https://arxiv.org/abs/2307.07421 - -Authors - * Titouan Parcollet 2023 - * Shucong Zhang 2023 - * Rogier van Dalen 2023 - * Sourav Bhattacharya 2023 -""" -from dataclasses import dataclass -import torch # noqa 42 -from torch import nn -from typing import Any, Optional -from speechbrain.nnet.linear import Linear -from speechbrain.nnet.containers import ModuleList -from speechbrain.lobes.models.transformer.Transformer import ( - TransformerInterface, - get_lookahead_mask, - get_key_padding_mask, - NormalizedEmbedding, -) -from speechbrain.nnet.activations import Swish -from speechbrain.dataio.dataio import length_to_mask -from speechbrain.utils.dynamic_chunk_training import DynChunkTrainConfig - - -@dataclass -class TransformerASRStreamingContext: - """Streaming metadata and state for a `TransformerASR` instance.""" - - dynchunktrain_config: DynChunkTrainConfig - """Dynamic Chunk Training configuration holding chunk size and context size - information.""" - - encoder_context: Any - """Opaque encoder context information. It is constructed by the encoder's - `make_streaming_context` method and is passed to the encoder when using - `encode_streaming`. - """ - - -def make_transformer_src_mask( - src: torch.Tensor, - causal: bool = False, - dynchunktrain_config: Optional[DynChunkTrainConfig] = None, -) -> Optional[torch.Tensor]: - """Prepare the source transformer mask that restricts which frames can - attend to which frames depending on causal or other simple restricted - attention methods. - - Arguments - --------- - src: torch.Tensor - The source tensor to build a mask from. The contents of the tensor are - not actually used currently; only its shape and other metadata (e.g. - device). - - causal: bool - Whether strict causality shall be used. Frames will not be able to - attend to any future frame. - - dynchunktrain_config: DynChunkTrainConfig, optional - Dynamic Chunk Training configuration. This implements a simple form of - chunkwise attention. Incompatible with `causal`.""" - - if causal: - assert dynchunktrain_config is None - return get_lookahead_mask(src) - - if dynchunktrain_config is not None: - # init a mask that masks nothing by default - # 0 == no mask, 1 == mask - src_mask = torch.zeros( - (src.shape[1], src.shape[1]), device=src.device, dtype=torch.bool, - ) - - # The following is not really the sole source used to implement this, - # but it helps introduce the concept. - # ref: Unified Streaming and Non-streaming Two-pass End-to-end Model - # for Speech Recognition - # https://arxiv.org/pdf/2012.05481.pdf - - timesteps = src.size(1) - - # mask the future at the right of each chunk - for t in range(timesteps): - # if we have a chunk size of 8 then: - # for 0..7 -> mask 8.. - # for 8..15 -> mask 16.. - # etc. 
- next_chunk_index = (t // dynchunktrain_config.chunk_size) + 1 - visible_range = next_chunk_index * dynchunktrain_config.chunk_size - src_mask[t, visible_range:] = True - - # mask the past at the left of each chunk (accounting for left context) - # only relevant if using left context - if not dynchunktrain_config.is_infinite_left_context(): - for t in range(timesteps): - chunk_index = t // dynchunktrain_config.chunk_size - chunk_first_t = chunk_index * dynchunktrain_config.chunk_size - - left_context_frames = ( - dynchunktrain_config.left_context_size - * dynchunktrain_config.chunk_size - ) - - frame_remaining_context = max(0, chunk_first_t - left_context_frames,) - - # end range is exclusive, so there is no off-by-one here - src_mask[t, :frame_remaining_context] = True - - return src_mask - - return None - - -def make_transformer_src_tgt_masks( - src, - tgt=None, - wav_len=None, - pad_idx=0, - causal: bool = False, - dynchunktrain_config: Optional[DynChunkTrainConfig] = None, -): - """This function generates masks for training the transformer model, - opiniated for an ASR context with encoding masks and, optionally, decoding - masks (if specifying `tgt`). - - Arguments - --------- - src : tensor - The sequence to the encoder (required). - tgt : tensor - The sequence to the decoder. - pad_idx : int - The index for token (default=0). - causal: bool - Whether strict causality shall be used. See `make_asr_src_mask` - dynchunktrain_config: DynChunkTrainConfig, optional - Dynamic Chunk Training configuration. See `make_asr_src_mask` - """ - src_key_padding_mask = None - - # mask out audio beyond the length of audio for each batch - if wav_len is not None: - abs_len = torch.round(wav_len * src.shape[1]) - src_key_padding_mask = ~length_to_mask(abs_len).bool() - - # mask out the source - src_mask = make_transformer_src_mask( - src, causal=causal, dynchunktrain_config=dynchunktrain_config - ) - - # If no decoder in the transformer... - if tgt is not None: - tgt_key_padding_mask = get_key_padding_mask(tgt, pad_idx=pad_idx) - tgt_mask = get_lookahead_mask(tgt) - else: - tgt_key_padding_mask = None - tgt_mask = None - - return src_key_padding_mask, tgt_key_padding_mask, src_mask, tgt_mask - - -class TransformerASR(TransformerInterface): - """This is an implementation of transformer model for ASR. - - The architecture is based on the paper "Attention Is All You Need": - https://arxiv.org/pdf/1706.03762.pdf - - Arguments - ---------- - tgt_vocab: int - Size of vocabulary. - input_size: int - Input feature size. - d_model : int, optional - Embedding dimension size. - (default=512). - nhead : int, optional - The number of heads in the multi-head attention models (default=8). - num_encoder_layers : int, optional - The number of sub-encoder-layers in the encoder (default=6). - num_decoder_layers : int, optional - The number of sub-decoder-layers in the decoder (default=6). - dim_ffn : int, optional - The dimension of the feedforward network model (default=2048). - dropout : int, optional - The dropout value (default=0.1). - activation : torch.nn.Module, optional - The activation function of FFN layers. - Recommended: relu or gelu (default=relu). - positional_encoding: str, optional - Type of positional encoding used. e.g. 'fixed_abs_sine' for fixed absolute positional encodings. - normalize_before: bool, optional - Whether normalization should be applied before or after MHA or FFN in Transformer layers. - Defaults to True as this was shown to lead to better performance and training stability. 
- kernel_size: int, optional - Kernel size in convolutional layers when Conformer is used. - bias: bool, optional - Whether to use bias in Conformer convolutional layers. - encoder_module: str, optional - Choose between Conformer and Transformer for the encoder. The decoder is fixed to be a Transformer. - conformer_activation: torch.nn.Module, optional - Activation module used after Conformer convolutional layers. E.g. Swish, ReLU etc. it has to be a torch Module. - branchformer_activation: torch.nn.Module, optional - Activation module used within the Branchformer Encoder. E.g. Swish, ReLU etc. it has to be a torch Module. - attention_type: str, optional - Type of attention layer used in all Transformer or Conformer layers. - e.g. SummaryMixing, regularMHA or RelPosMHA. - max_length: int, optional - Max length for the target and source sequence in input. - Used for positional encodings. - causal: bool, optional - Whether the encoder should be causal or not (the decoder is always causal). - If causal the Conformer convolutional layer is causal. - csgu_linear_units: int, optional - Number of neurons in the hidden linear units of the CSGU Module. - -> Branchformer - gate_activation: torch.nn.Module, optional - Activation function used at the gate of the CSGU module. - -> Branchformer - use_linear_after_conv: bool, optional - If True, will apply a linear transformation of size input_size//2. - -> Branchformer - local_proj_out_dim: int, optional - The dimension of the output of the local projection branch. This - will be concatenated with the output of the summary branch - (default: 512). - summary_hid_dim: list [int], optional - A list of dimension specifying both the number of hidden layers - as well as the size of them in the summary projection branch - (default: [1024]). - summary_out_dim: int, optional - The dimension of the output of the summary projection branch. This - will be concatenated with the output of the local branch - (default: 1024). - activation: torch.nn.Module, optional - Torch module specifying the activation function used in both the local - and summary branches. - (default: torch.nn.GELU) - mode: string, optional - One of "SummaryMixing" or "SummaryMixing-lite". Changes the SummaryMixing cell - according to the definition of the article. "SummaryMixing-lite" removes the - local project branch. - - Example - ------- - >>> src = torch.rand([8, 120, 512]) - >>> tgt = torch.randint(0, 720, [8, 120]) - >>> net = TransformerASR( - ... 720, 512, 512, 1, 1, 1, 1024, activation=torch.nn.GELU - ... 
) - >>> enc_out, dec_out = net.forward(src, tgt) - >>> enc_out.shape - torch.Size([8, 120, 512]) - >>> dec_out.shape - torch.Size([8, 120, 512]) - """ - - def __init__( - self, - tgt_vocab, - input_size, - d_model=512, - nhead=8, - num_encoder_layers=6, - num_decoder_layers=6, - d_ffn=2048, - dropout=0.1, - activation=nn.ReLU, - positional_encoding="fixed_abs_sine", - normalize_before=False, - kernel_size: Optional[int] = 31, - bias: Optional[bool] = True, - encoder_module: Optional[str] = "transformer", - conformer_activation: Optional[nn.Module] = Swish, - branchformer_activation: Optional[nn.Module] = nn.GELU, - attention_type: Optional[str] = "SummaryMixing", - max_length: Optional[int] = 2500, - causal: Optional[bool] = True, - csgu_linear_units: Optional[int] = 3072, - gate_activation: Optional[nn.Module] = nn.Identity, - use_linear_after_conv: Optional[bool] = False, - local_proj_hid_dim: Optional[list] = [512], - local_proj_out_dim: Optional[int] = 512, - summary_hid_dim: Optional[list] = [1024], - summary_out_dim: Optional[int] = 1024, - mode: Optional[str] = "SummaryMixing", - ): - super().__init__( - d_model=d_model, - nhead=nhead, - num_encoder_layers=num_encoder_layers, - num_decoder_layers=num_decoder_layers, - d_ffn=d_ffn, - dropout=dropout, - activation=activation, - positional_encoding=positional_encoding, - normalize_before=normalize_before, - kernel_size=kernel_size, - bias=bias, - encoder_module=encoder_module, - conformer_activation=conformer_activation, - branchformer_activation=branchformer_activation, - attention_type=attention_type, - max_length=max_length, - causal=causal, - csgu_linear_units=csgu_linear_units, - gate_activation=gate_activation, - use_linear_after_conv=use_linear_after_conv, - local_proj_hid_dim=local_proj_hid_dim, - local_proj_out_dim=local_proj_out_dim, - summary_hid_dim=summary_hid_dim, - summary_out_dim=summary_out_dim, - mode=mode, - ) - - self.num_decoder_layers = num_decoder_layers - self.num_encoder_layers = num_encoder_layers - - self.custom_src_module = ModuleList( - Linear( - input_size=input_size, n_neurons=d_model, bias=True, combine_dims=False, - ), - torch.nn.Dropout(dropout), - ) - - if num_decoder_layers > 0: - self.custom_tgt_module = ModuleList(NormalizedEmbedding(d_model, tgt_vocab)) - - # reset parameters using xavier_normal_ - self._init_params() - - def forward(self, src, tgt=None, wav_len=None, pad_idx=0): - """ - Arguments - ---------- - src : torch.Tensor - The sequence to the encoder. - tgt : torch.Tensor - The sequence to the decoder. If None, only the encoder is run. - wav_len: torch.Tensor, optional - Torch Tensor of shape (batch, ) containing the relative length to padded length for each example. - pad_idx : int, optional - The index for token (default=0). 
- """ - - # reshpae the src vector to [Batch, Time, Fea] is a 4d vector is given - if src.ndim == 4: - bz, t, ch1, ch2 = src.shape - src = src.reshape(bz, t, ch1 * ch2) - - ( - src_key_padding_mask, - tgt_key_padding_mask, - src_mask, - tgt_mask, - ) = make_transformer_src_tgt_masks( - src, tgt, wav_len, causal=self.causal, pad_idx=pad_idx - ) - - src = self.custom_src_module(src) - # add pos encoding to queries if are sinusoidal ones else - if ( - self.attention_type == "hypermixing" - or self.attention_type == "SummaryMixing" - ): - pos_embs_encoder = None - elif self.attention_type == "RelPosMHAXL": - pos_embs_encoder = self.positional_encoding(src) - elif self.positional_encoding_type == "fixed_abs_sine": - src = src + self.positional_encoding(src) # add the encodings here - pos_embs_encoder = None - - encoder_out, _ = self.encoder( - src=src, - src_mask=src_mask, - src_key_padding_mask=src_key_padding_mask, - pos_embs=pos_embs_encoder, - ) - - tgt = self.custom_tgt_module(tgt) - - if self.attention_type == "RelPosMHAXL": - # use standard sinusoidal pos encoding in decoder - tgt = tgt + self.positional_encoding_decoder(tgt) - pos_embs_encoder = None # self.positional_encoding(src) - pos_embs_target = None - elif ( - self.positional_encoding_type == "fixed_abs_sine" - or self.attention_type == "hypermixing" - ): - tgt = tgt + self.positional_encoding(tgt) - pos_embs_target = None - pos_embs_encoder = None - - decoder_out, _, _ = self.decoder( - tgt=tgt, - memory=encoder_out, - memory_mask=None, - tgt_mask=tgt_mask, - tgt_key_padding_mask=tgt_key_padding_mask, - memory_key_padding_mask=src_key_padding_mask, - pos_embs_tgt=pos_embs_target, - pos_embs_src=pos_embs_encoder, - ) - - return encoder_out, decoder_out - - @torch.no_grad() - def decode(self, tgt, encoder_out, enc_len=None): - """This method implements a decoding step for the transformer model. - - Arguments - --------- - tgt : torch.Tensor - The sequence to the decoder. - encoder_out : torch.Tensor - Hidden output of the encoder. - enc_len : torch.LongTensor - The actual length of encoder states. - """ - tgt_mask = get_lookahead_mask(tgt) - src_key_padding_mask = None - if enc_len is not None: - src_key_padding_mask = (1 - length_to_mask(enc_len)).bool() - - tgt = self.custom_tgt_module(tgt) - if self.attention_type == "RelPosMHAXL": - # we use fixed positional encodings in the decoder - tgt = tgt + self.positional_encoding_decoder(tgt) - encoder_out = encoder_out + self.positional_encoding_decoder(encoder_out) - # pos_embs_target = self.positional_encoding(tgt) - pos_embs_encoder = None # self.positional_encoding(src) - pos_embs_target = None - elif ( - self.positional_encoding_type == "fixed_abs_sine" - or self.attention_type == "hypermixing" - ): - tgt = tgt + self.positional_encoding(tgt) # add the encodings here - pos_embs_target = None - pos_embs_encoder = None - - prediction, self_attns, multihead_attns = self.decoder( - tgt, - encoder_out, - tgt_mask=tgt_mask, - memory_key_padding_mask=src_key_padding_mask, - pos_embs_tgt=pos_embs_target, - pos_embs_src=pos_embs_encoder, - ) - return prediction, multihead_attns[-1] - - def encode( - self, - src, - wav_len=None, - pad_idx=0, - dynchunktrain_config: Optional[DynChunkTrainConfig] = None, - ): - """ - Encoder forward pass - - Arguments - ---------- - src : torch.Tensor - The sequence to the encoder. - wav_len: torch.Tensor, optional - Torch Tensor of shape (batch, ) containing the relative length to padded length for each example. 
- """ - # reshape the src vector to [Batch, Time, Fea] if a 4d vector is given - if src.dim() == 4: - bz, t, ch1, ch2 = src.shape - src = src.reshape(bz, t, ch1 * ch2) - - (src_key_padding_mask, _, src_mask, _,) = make_transformer_src_tgt_masks( - src, - None, - wav_len, - pad_idx=pad_idx, - causal=self.causal, - dynchunktrain_config=dynchunktrain_config, - ) - - src = self.custom_src_module(src) - if self.attention_type == "hypermixing": - pos_embs_source = None - elif self.attention_type == "RelPosMHAXL": - pos_embs_source = self.positional_encoding(src) - elif self.positional_encoding_type == "fixed_abs_sine": - src = src + self.positional_encoding(src) - pos_embs_source = None - - encoder_out, _ = self.encoder( - src=src, - src_mask=src_mask, - src_key_padding_mask=src_key_padding_mask, - pos_embs=pos_embs_source, - dynchunktrain_config=dynchunktrain_config, - ) - return encoder_out - - def encode_streaming(self, src, context: TransformerASRStreamingContext): - """ - Streaming encoder forward pass - - Arguments - --------- - src : torch.Tensor - The sequence (chunk) to the encoder. - - context : TransformerASRStreamingContext - Mutable reference to the streaming context. This holds the state - needed to persist across chunk inferences and can be built using - `make_streaming_context`. This will get mutated by this function. - - Returns - ------- - Encoder output for this chunk. - - Example - ------- - >>> import torch - >>> from speechbrain.lobes.models.transformer.TransformerASR import TransformerASR - >>> from speechbrain.utils.dynamic_chunk_training import DynChunkTrainConfig - >>> net = TransformerASR( - ... tgt_vocab=100, - ... input_size=64, - ... d_model=64, - ... nhead=8, - ... num_encoder_layers=1, - ... num_decoder_layers=0, - ... d_ffn=128, - ... attention_type="RelPosMHAXL", - ... positional_encoding=None, - ... encoder_module="conformer", - ... normalize_before=True, - ... causal=False, - ... ) - >>> ctx = net.make_streaming_context( - ... DynChunkTrainConfig(16, 24), - ... encoder_kwargs={"mha_left_context_size": 24}, - ... ) - >>> src1 = torch.rand([8, 16, 64]) - >>> src2 = torch.rand([8, 16, 64]) - >>> out1 = net.encode_streaming(src1, ctx) - >>> out1.shape - torch.Size([8, 16, 64]) - >>> ctx.encoder_context.layers[0].mha_left_context.shape - torch.Size([8, 16, 64]) - >>> out2 = net.encode_streaming(src2, ctx) - >>> out2.shape - torch.Size([8, 16, 64]) - >>> ctx.encoder_context.layers[0].mha_left_context.shape - torch.Size([8, 24, 64]) - >>> combined_out = torch.concat((out1, out2), dim=1) - >>> combined_out.shape - torch.Size([8, 32, 64]) - """ - - if src.dim() == 4: - bz, t, ch1, ch2 = src.shape - src = src.reshape(bz, t, ch1 * ch2) - - # HACK: our problem here is that the positional_encoding is computed - # against the size of our source tensor, but we only know how many left - # context frames we're injecting to the encoder within the encoder - # context. - # so this workaround does just that. - # - # i'm not sure how this would be best refactored, but an option would be - # to let the encoder get the pos embedding itself and have a way to - # cache it. - # - # additionally, positional encoding functions take in a whole source - # tensor just to get its attributes (size, device, type) but this is - # sort of silly for the embeddings that don't need one. - # so we craft a dummy empty (uninitialized) tensor to help... 
- known_left_context = context.encoder_context.layers[0].mha_left_context - if known_left_context is None: - pos_encoding_dummy = src - else: - target_shape = list(src.shape) - target_shape[-2] += known_left_context.shape[-2] - pos_encoding_dummy = torch.empty(size=target_shape).to(src) - - src = self.custom_src_module(src) - if self.attention_type == "RelPosMHAXL": - pos_embs_source = self.positional_encoding(pos_encoding_dummy) - - elif self.positional_encoding_type == "fixed_abs_sine": - src = src + self.positional_encoding(pos_encoding_dummy) - pos_embs_source = None - - encoder_out, _ = self.encoder.forward_streaming( - src=src, pos_embs=pos_embs_source, context=context.encoder_context - ) - return encoder_out - - def make_streaming_context( - self, dynchunktrain_config: DynChunkTrainConfig, encoder_kwargs={} - ): - """Creates a blank streaming context for this transformer and its - encoder. - - Arguments - --------- - dynchunktrain_config : DynChunkTrainConfig - Runtime chunkwise attention configuration. - - encoder_kwargs : dict - Parameters to be forward to the encoder's `make_streaming_context`. - Metadata required for the encoder could differ depending on the - encoder. - """ - return TransformerASRStreamingContext( - dynchunktrain_config=dynchunktrain_config, - encoder_context=self.encoder.make_streaming_context(**encoder_kwargs,), - ) - - def _init_params(self): - for p in self.parameters(): - if p.dim() > 1: - torch.nn.init.xavier_normal_(p) - - -class EncoderWrapper(nn.Module): - """This is a wrapper of any ASR transformer encoder. By default, the - TransformerASR .forward() function encodes and decodes. With this wrapper - the .forward() function becomes .encode() only. - - Important: The TransformerASR class must contain a .encode() function. - - Arguments - ---------- - transformer : sb.lobes.models.TransformerInterface - A Transformer instance that contains a .encode() function. - - Example - ------- - >>> src = torch.rand([8, 120, 512]) - >>> tgt = torch.randint(0, 720, [8, 120]) - >>> net = TransformerASR( - ... 720, 512, 512, 8, 1, 1, 1024, activation=torch.nn.GELU - ... ) - >>> encoder = EncoderWrapper(net) - >>> enc_out = encoder(src) - >>> enc_out.shape - torch.Size([8, 120, 512]) - """ - - def __init__(self, transformer, *args, **kwargs): - super().__init__(*args, **kwargs) - self.transformer = transformer - - def forward(self, x, wav_lens=None, pad_idx=0, **kwargs): - """ Processes the input tensor x and returns an output tensor.""" - x = self.transformer.encode(x, wav_lens, pad_idx, **kwargs,) - return x diff --git a/speechbrain/lobes/models/wav2vec.py b/speechbrain/lobes/models/wav2vec.py new file mode 100644 index 0000000..f3f4838 --- /dev/null +++ b/speechbrain/lobes/models/wav2vec.py @@ -0,0 +1,569 @@ +""" SummaryMixing © 2023 by Samsung Electronics is licensed under CC BY-NC 4.0. + +This library contains SummaryMixing wav2vec 2.0. 
for pre-training and downstream tasks + +Usage: Install SpeechBrain + Copy this file under speechbrain/lobes/models + +SummaryMixing: https://arxiv.org/abs/2307.07421 +SummaryMixing SSL: + +Authors + * Titouan Parcollet 2023, 2024 + * Shucong Zhang 2023, 2024 + * Rogier van Dalen 2023, 2024 + * Sourav Bhattacharya 2023, 2024 +""" + +import logging +import torch +import torch.nn.functional as F +import torch.nn as nn +import random +import numpy as np + +from speechbrain.lobes.models.transformer.Transformer import PositionalEncoding +from speechbrain.utils.data_utils import batch_pad_right +from speechbrain.dataio.dataio import length_to_mask +from speechbrain.lobes.models.convolution import ConvolutionFrontEnd +from speechbrain.nnet.CNN import Conv1d +from speechbrain.nnet.normalization import LayerNorm +from speechbrain.nnet.quantisers import GumbelVectorQuantizer + +from speechbrain.processing.features import InputNormalization +from speechbrain.lobes.features import Fbank + +logger = logging.getLogger() + + +class W2VLatentExtractor(nn.Module): + """Convolution based feature extractor from raw audio. + Channel numbers increasing is based on https://arxiv.org/abs/2109.06870 + Arguments + --------- + out_channels : list of ints + Out channels of convolutional layers. + kernel_sizes : list of ints + Kernels of convolutional layers. + strides : list of ints + Strides of convolutional layers. + dropout : float + Dropout of CNN. + Example + ------- + >>> extractor = W2VLatentExtractor() + >>> inputs = torch.rand(10, 5000) + >>> outputs = extractor(inputs) + >>> outputs.shape + torch.Size([10, 14, 512]) + """ + + def __init__( + self, + out_channels=[512, 512, 512, 512, 512, 512, 512], + kernel_sizes=[11, 3, 3, 3, 3, 3, 3], + strides=[5, 2, 2, 2, 2, 2, 2], + dropout=0.0, + conv_init="kaiming", + input_dim=None, + pretrained_path=None, + ): + super().__init__() + + assert len(out_channels) == len(kernel_sizes) == len(strides) + + num_blocks = len(out_channels) + self.kernel_sizes = kernel_sizes + self.strides = strides + self.out_dim = out_channels[-1] + # ! Note this does conv, norm, gelu, dropout. while fairseq does conv, dropout, norm, gelu + # Also fairseq layernorm is forced to fp32 + if input_dim is None: + inp_shape = ( + None, + 16000, + ) + else: + inp_shape = (None, 16000, input_dim) + self.extractor = ConvolutionFrontEnd( + inp_shape, + num_blocks=num_blocks, + num_layers_per_block=1, + out_channels=out_channels, + kernel_sizes=kernel_sizes, + strides=strides, + dilations=[1] * num_blocks, + residuals=[False] * num_blocks, + conv_module=Conv1d, + activation=nn.GELU, + norm=LayerNorm, + dropout=dropout, + conv_bias=False, + padding="valid", + conv_init=conv_init, + ) + self.norm = nn.LayerNorm(out_channels[-1]) + + if pretrained_path: + ckpt = torch.load(pretrained_path) + self.load_state_dict(ckpt) + + def forward(self, x, normalize_signal=True): + """ Calculates latents from audio input. + """ + if normalize_signal: + x = F.layer_norm(x, x.shape[1:]) + + latents = self.extractor(x) + return self.norm(latents) + + def get_output_lengths(self, input_lengths: torch.LongTensor): + """ Calculates output lengths for given input lengths. 
""" + + def _conv_out_length(input_length, kernel_size, stride): + return torch.floor((input_length - kernel_size) / stride + 1) + + for kernel_size, stride in zip(self.kernel_sizes, self.strides): + input_lengths = _conv_out_length(input_lengths, kernel_size, stride) + return input_lengths.to(torch.long) + + +class W2VTargetQuantiser(nn.Module): + """ Wraps ``nnet.quantiser.GumbelVectorQuantizer``, see for documentation on + arguments. + + Example + ------- + >>> quantiser = W2VTargetQuantiser() + >>> inputs = torch.rand(10, 12, 512) + >>> output, meta = quantiser(inputs) + >>> output.shape + torch.Size([10, 12, 256]) + """ + + def __init__( + self, + in_dim=512, + out_dim=256, + quantiser=GumbelVectorQuantizer, + num_vars=320, + temperature_decay=(2.0, 0.25, 0.999995,), + ): + super().__init__() + self.quantiser = quantiser( + in_dim, num_vars, temperature_decay, 2, out_dim + ) + self.proj = nn.Linear(out_dim, out_dim) + + def forward(self, x): + """ Returns quantised targets plus meta information. """ + x = self.quantiser(x) + targets = self.proj(x["x"]) + code_perplex = x["code_perplexity"] + prob_perplex = x["prob_perplex"] + num_vars = x["num_vars"] + temp = x["temp"] + diversity_loss = (num_vars - prob_perplex) / num_vars + meta = { + "diversity_loss": diversity_loss, + "code_perplex": code_perplex, + "prob_perplex": prob_perplex, + "num_vars": num_vars, + "temp": temp, + } + return targets, meta + + +class EncoderWrapper(nn.Module): + """A wrapper that adds positional information, + masks the input and then runs the latent encoder. + Arguments + --------- + in_dim : int + Last dimension of input tensor. + embedding_dim : int + Dimension to project input to and that the latent encoder will use. + latent_encoder : torch.nn.module + Initialized latent encoder object. + positional_encoding : torch.nn.module + Uninitialized nn.module for adding positional information, will use ``embedding_dim``. + dropout_encoder_input : float + Dropout on encoder input. + + Example + ------- + >>> from speechbrain.lobes.models.transformer.Transformer import TransformerEncoder + >>> encoder = TransformerEncoder(d_model=768, num_layers=4, nhead=4, d_ffn=1024) + >>> wrapper = EncoderWrapper(1024, 768, encoder) + >>> inputs = torch.rand(10, 12, 1024) + >>> outputs = wrapper(inputs) + >>> outputs["embeddings"].shape + torch.Size([10, 12, 768]) + """ + + def __init__( + self, + in_dim, + embedding_dim, + latent_encoder, + positional_encoding=PositionalEncoding, + dropout_encoder_input=0.05, + output_hidden_states=False, + pretrained_path=None, + ): + super().__init__() + self.input_projector = nn.Linear(in_dim, embedding_dim) + self.latent_encoder = latent_encoder + self.positional_encoding = positional_encoding( + embedding_dim, max_len=3500 + ) + self.dropout_encoder_input = nn.Dropout(dropout_encoder_input) + self.mask_emb = nn.Parameter( + torch.FloatTensor(embedding_dim).uniform_(), requires_grad=True + ) + self.output_hidden_states = output_hidden_states + if pretrained_path: + ckpt = torch.load(pretrained_path) + self.load_state_dict(ckpt) + + def forward( + self, latents, wav_lens=None, padding_mask=None, mask=None, + ): + """ + Arguments + --------- + latents : torch.Tensor, shape (B, T, C) + Batch of latent representations (AKA frames) output from latent extractor. 
+ wav_lens : torch.Tensor, shape (B,) + The actual (unpadded) relative lengths for each sample of the batch (0 + neg_indcs[neg_indcs >= targets] += 1 + + neg_indcs = neg_indcs + torch.arange(B).unsqueeze(1) * high + y = y.view(-1, C) + negs = y[neg_indcs.view(-1)] + negs = negs.view(B, T, num_neg, C).permute(2, 0, 1, 3) # to N, B, T, C + return negs + + +def w2v_mask_collate_fn(samples_lst, get_out_len_fn, mask_prob, mask_length, hop_length=None): + """ This creates a batch from a list of samples and also creates + the boolean mask that will be used to mask the inputs of the latent + encoder. To create the mask we need to know the output shape after the + latent extractor, therefore the argument `get_out_len_fn`. + One could also create masks per sample (when loading the audio file) and + then collate them but at that time one doesn't know the length of the + shortest sample in the batch (which determines the number of masked frames) + so it's better this way. + + Arguments + --------- + samples_lst : list + List of samples returned by the audio_pipeline. + get_out_len_fn : function + Function that calculates length of sample after it passes through feature extractor. + mask_prob : float + Approximate percentage of frames to mask. + mask_length : int + Number of contiguous frames that will be masked. + + Returns + ------- + wavs_padded : torch.Tensor, shape (B, T) + Audio arrays with right-sided padding. + wav_lens : torch.Tensor, shape (B,) + For each sample the percentage of the array that is not padding. + mask : torch.Tensor, shape (B, T) + Boolean mask to mask frames. + """ + wav_lst, latent_length_lst = [], [] + ids = [] + + for sample in samples_lst: + ids.append(sample["id"]) + sig = sample["sig"] + wav_lst.append(sig) + if hop_length is not None: + latent_length = get_out_len_fn( + torch.as_tensor(sig.size(-1)), hop_length + ) + else: + latent_length = get_out_len_fn(torch.as_tensor(sig.size(-1))) + latent_length_lst.append(latent_length.item()) + bs = len(wav_lst) + wavs_padded, wav_lens = batch_pad_right(wav_lst) + + batch_time_len = max(latent_length_lst) + mask = compute_mask( + (bs, batch_time_len,), latent_length_lst, mask_prob, mask_length + ) + return ( + torch.as_tensor(wavs_padded), + torch.as_tensor(wav_lens), + torch.as_tensor(mask, dtype=torch.bool), + ) + +class WeightedSSLModel(torch.nn.Module): + """This lobe enables the integration of use of weighted sum representations + from different layers in a SSL encoder. + + The model can be used as a fixed feature extractor for SSL benchmarking. It + will download automatically the model from HuggingFace or use a local path. + + More details in recipes/SSL_benchmark + + Arguments + --------- + hub : str + HuggingFace hub name: e.g "facebook/wav2vec2-large-lv60" + num_layers: int + Number of internal layers: e.g 13 for "Base" models. 
+ layernorm: bool + Whether layer representations should be layernormed before sum + Example + ------- + >>> inputs = torch.rand([10, 600]) + >>> model_hub = "facebook/wav2vec2-base-960h" + >>> num_layers = 13 + >>> model = WeightedSSLModel(model_hub, num_layers) + >>> outputs = model(inputs) + """ + + def __init__( + self, + num_layers, + pretrained_path, + latent_encoder, + CNN, + dropout=0.0, + conv_init="kaiming", + in_dim=512, + embedding_dim=768, + positional_encoding=PositionalEncoding, + dropout_encoder_input=0.0, + output_hidden_states=True, + layernorm=False, + full_tune=False, + sample_rate=16000, + n_fft=400, + n_mels=80, + hop_length=10, + win_length=25,): + super().__init__() + + self.compute_features = Fbank( + sample_rate=16000, + n_fft=400, + n_mels=80, + hop_length=10, + win_length=25, + ) + self.inputnorm = InputNormalization(norm_type="sentence") + self.latent_extractor = CNN + + self.encoder_wrapper = EncoderWrapper( + in_dim, + embedding_dim, + latent_encoder, + positional_encoding, + dropout_encoder_input, + output_hidden_states, + ) + + + + latent_extractor_path = f"{pretrained_path}/CNN.ckpt" + latent_encoder_path = f"{pretrained_path}/latent_encoder.ckpt" + + latent_extractor_ckpt = torch.load(latent_extractor_path) + latent_encoder_ckpt = torch.load(latent_encoder_path) + + self.latent_extractor.load_state_dict(latent_extractor_ckpt) + self.encoder_wrapper.load_state_dict(latent_encoder_ckpt) + self.output_hidden_states = output_hidden_states + + self.num_layers = num_layers + # Initializing the learnable weights + zero_init = torch.cat([torch.zeros(self.num_layers)]) + self.weights = torch.nn.Parameter(zero_init, requires_grad=True) + self.layernorm = layernorm + self.full_tune = full_tune + + def forward(self, wav, wav_lens=None): + """This method outputs a weighted sum of the layers representations of the SSL encoder + Arguments + --------- + wav : tensor + The wavs + """ + # SB mel + if not self.full_tune: + with torch.no_grad(): + latents = self.compute_features(wav) + latents = self.inputnorm(latents, wav_lens).detach() + latents = self.latent_extractor(latents) + latents = latents.view(latents.shape[0], latents.shape[1], -1) + + feats = self.encoder_wrapper(latents, wav_lens=wav_lens)[ + "embeddings" + ] + + hidden_states = torch.stack(feats, dim=0).detach() + else: + with torch.no_grad(): + latents = self.compute_features(wav) + latents = self.inputnorm(latents, wav_lens).detach() + latents = self.latent_extractor(latents) + latents = latents.view(latents.shape[0], latents.shape[1], -1) + + feats = self.encoder_wrapper(latents, wav_lens=wav_lens)[ + "embeddings" + ] + + hidden_states = torch.stack(feats, dim=0) + + + # First dimension should be equal to the number of layers in the hparams + assert ( + self.num_layers == hidden_states.shape[0] + ), f"Num layers {self.num_layers} not equal to num hidden states {hidden_states.shape[0]}" + norm_weights = torch.nn.functional.softmax(self.weights, dim=-1) + # Layernorming the layers representations if asked + if self.layernorm: + hidden_states = [ + F.layer_norm(t, (t.shape[-1],)) for t in hidden_states + ] + # Summing the weighted layers + weighted_feats = hidden_states[0] * norm_weights[0] + for i in range(1, len(hidden_states)): + weighted_feats += hidden_states[i] * norm_weights[i] + + return weighted_feats \ No newline at end of file diff --git a/speechbrain/nnet/summary_mixing.py b/speechbrain/nnet/summary_mixing.py index de0e941..d5439ab 100644 --- a/speechbrain/nnet/summary_mixing.py +++ 
b/speechbrain/nnet/summary_mixing.py @@ -1,20 +1,18 @@ """ SummaryMixing © 2023 by Samsung Electronics is licensed under CC BY-NC 4.0. - This library provides the basic building blocks for SummaryMixing. - Usage: Install SpeechBrain and copy this file under speechbrain/nnet/ Source: https://arxiv.org/abs/2307.07421 - Authors - * Titouan Parcollet 2023 - * Shucong Zhang 2023 - * Rogier van Dalen 2023 - * Sourav Bhattacharya 2023 + * Titouan Parcollet 2023, 2024 + * Shucong Zhang 2023, 2024 + * Rogier van Dalen 2023, 2024 + * Sourav Bhattacharya 2023, 2024 """ import math import torch import logging +import numpy as np import torch.nn as nn from typing import Optional from speechbrain.lobes.models.VanillaNN import VanillaNN @@ -26,7 +24,6 @@ class SummaryMixing(nn.Module): """ This class implements SummaryMixing as defined in https://arxiv.org/abs/2307.07421 - Arguments --------- enc_dim: int @@ -44,21 +41,19 @@ class SummaryMixing(nn.Module): summary_hid_dim: list [int], optional A list of dimension specifying both the number of hidden layers as well as the size of them in the summary projection branch - (default: [512]). + (default: [1024]). summary_out_dim: int, optional The dimension of the output of the summary projection branch. This will be concatenated with the output of the local branch - (default: 512). + (default: 1024). activation: torch.nn.Module, optional Torch module specifying the activation function used in both the local and summary branches. (default: torch.nn.GELU) mode: string, optional - One of "SummaryMixing" or "SummaryMixing-lite". Changes the SummaryMixing cell + One of "SummaryMixing", "SummaryMixing-lite" or "SummaryMixing-expdecay". Changes the SummaryMixing cell according to the definition of the article. "SummaryMixing-lite" removes the local project branch. 
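+        "SummaryMixing-expdecay" weights the summary over time with an exponential decay
+        centred on each time step (see ``_laplace_weights``), and "SummaryMixing-fast"
+        computes the local and summary projections with a single shared linear layer
+        whose output is split in two.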
- - Example ------- >>> x = torch.rand(2,4,8) @@ -74,14 +69,19 @@ def __init__( nhead, local_proj_hid_dim: Optional[list] = [512], local_proj_out_dim: Optional[int] = 512, - summary_hid_dim: Optional[list] = [512], - summary_out_dim: Optional[int] = 512, + summary_hid_dim: Optional[list] = [1024], + summary_out_dim: Optional[int] = 1024, activation: Optional[nn.Module] = nn.GELU, mode: Optional[str] = "SummaryMixing", ): super(SummaryMixing, self).__init__() - if mode not in ["SummaryMixing", "SummaryMixing-lite"]: + if mode not in [ + "SummaryMixing", + "SummaryMixing-lite", + "SummaryMixing-expdecay", + "SummaryMixing-fast", + ]: raise ValueError( "The SummaryMixing mode should either be 'SummaryMixing' or 'SummaryMixing-lite'" ) @@ -90,13 +90,19 @@ def __init__( self.local_proj_out_dim = local_proj_out_dim self.summary_hid_dim = summary_hid_dim self.summary_out_dim = summary_out_dim + self.summary_reshaped_dim = int(np.sqrt(summary_out_dim)) self.enc_dim = enc_dim self.activation = activation() self.local_dnn_blocks = local_proj_hid_dim + [local_proj_out_dim] self.summary_dnn_blocks = summary_hid_dim + [summary_out_dim] self.mode = mode + self.dropout = nn.Dropout(0.1) + self.time_dropout = nn.Dropout(0.1) - if self.mode == "SummaryMixing": + if ( + self.mode == "SummaryMixing" + or self.mode == "SummaryMixing-expdecay" + ): self.local_proj = VanillaNN( input_shape=[None, None, enc_dim], @@ -116,83 +122,197 @@ def __init__( self.local_norm = nn.LayerNorm(local_proj_out_dim) self.summary_norm = nn.LayerNorm(summary_out_dim) - self.summary_proj = VanillaNN( - input_shape=[None, None, enc_dim], - dnn_blocks=len(self.summary_dnn_blocks), - dnn_neurons=self.summary_dnn_blocks, - activation=activation, - n_split=nhead, - ) + if self.mode == "SummaryMixing-fast": + self.global_proj = VanillaNN( + input_shape=[None, None, enc_dim], + dnn_blocks=1, + dnn_neurons=self.local_proj_out_dim * 2, + activation=activation, + n_split=1, + ) + + self.summary_local_merging = VanillaNN( + input_shape=[None, None, self.local_proj_out_dim * 2], + dnn_blocks=1, + dnn_neurons=[summary_out_dim], + activation=activation, + ) + + # self.summary_norm = nn.LayerNorm(summary_out_dim) + + else: + self.summary_proj = VanillaNN( + input_shape=[None, None, enc_dim], + dnn_blocks=len(self.summary_dnn_blocks), + dnn_neurons=self.summary_dnn_blocks, + activation=activation, + n_split=nhead, + ) + + if self.mode == "SummaryMixing-expdecay": + self.decay_constant = nn.Parameter( + data=torch.tensor(0.995), requires_grad=False + ) self.apply(self._init_parameters) - def forward(self, x, attention_mask=None): + def forward(self, x, sum_mask=None, src_padding_mask=None): """ This function simply goes forward! - Arguments --------- x: torch.Tensor The expected shape is the standard SpeechBrain one - [Batch, Time, Features] - attention_mask: torch.Tensor - (B, S) to pad before summarizing in time. + sum_mask: torch.Tensor + (Time, Time) per time step mask that can be used to compute different sum between time-step. + this can be useful for streaming, for instance, where each time step has a limited context. + src_padding_mask: torch.Tensor + (Batch, Time) corresponding to padding. We avoid padding when summarizing in time. 
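+
+        For both masks, a value of True marks a position to be ignored; the masks are
+        inverted internally before the summary is computed.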
""" - - if attention_mask is not None: - attention_mask = torch.logical_not(attention_mask).unsqueeze(-1).float() + if src_padding_mask is not None: + src_padding_mask = ( + torch.logical_not(src_padding_mask).unsqueeze(-1).float() + ) else: - attention_mask = torch.ones((x.shape[0], x.shape[1])).unsqueeze(-1).float() + src_padding_mask = ( + torch.ones((x.shape[0], x.shape[1])).unsqueeze(-1).float() + ) - if self.mode == "SummaryMixing": - return self._forward_mixing(x, attention_mask) + if sum_mask is not None: + sum_mask = torch.logical_not(sum_mask).float() + + if ( + self.mode == "SummaryMixing" + or self.mode == "SummaryMixing-expdecay" + ): + return self._forward_mixing(x, sum_mask, src_padding_mask) + elif self.mode == "SummaryMixing-fast": + return self._forward_mixing_fast(x, sum_mask, src_padding_mask) elif self.mode == "SummaryMixing-lite": - return self._forward_avgonly(x, attention_mask) + return self._forward_avgonly(x, sum_mask, src_padding_mask) - def _forward_mixing(self, x, attention_mask): + def _forward_mixing(self, x, sum_mask, src_padding_mask): """ Perform full SummaryMixing. - Arguments --------- x: torch.Tensor The expected shape is the standard SpeechBrain one - [Batch, Time, Features] - attention_mask: torch.Tensor - (B, S) to pad before summarizing in time. + sum_mask: torch.Tensor + (Time, Time) per time step mask that can be used to compute different sum between time-step. + this can be useful for streaming, for instance, where each time step has a limited context. + src_padding_mask: torch.Tensor + (Batch, Time) corresponding to padding. We avoid padding when summarizing in time. """ B, T, F = x.shape # f() (Eq. 1b) - local_summary = self.local_norm(self.local_proj(x) * attention_mask) + local_summary = self.local_norm(self.local_proj(x) * src_padding_mask) # s() (Eq. 2 and 1c) - time_summary = self.summary_proj(x) * attention_mask + time_summary = self.summary_proj(x) * src_padding_mask - # We normalise by real length by counting masking - time_summary = self.summary_norm( - torch.sum(time_summary, dim=1) / torch.sum(attention_mask, dim=1) - ) - time_summary = time_summary.unsqueeze(1).repeat(1, T, 1) + if self.mode == "SummaryMixing-expdecay": + sum_mask = self._laplace_weights( + T, self.decay_constant, sum_mask, x.device + ) + + if sum_mask is None: + + # We normalise by real length by counting masking + time_summary = self.summary_norm( + torch.sum(time_summary, dim=1) + / torch.sum(src_padding_mask, dim=1) + ) + time_summary = time_summary.unsqueeze(1).repeat(1, T, 1) + + else: + + # We must do a masked sum. The mask is [Time, Time] and the features are [B,T,F] + # We therefore can do a matmul between [B,F,T] and [Time,Time].T to obtain [B,F,T] that we can re-transpose. + # We need to be careful when dividing as padding is not included in sum_mask. We need to build the intersection + # of both mask to know the actual real number of elements excluding padding. + + # full_mask_with_pad = torch.matmul(sum_mask, src_padding_mask) + + time_summary = self.summary_norm( + torch.matmul(time_summary.mT, sum_mask.T).mT + ) return self.summary_local_merging( - torch.cat([local_summary, time_summary], dim=-1) + self.time_dropout(torch.cat([local_summary, time_summary], dim=-1)) ) - def _forward_avgonly(self, x, attention_mask): - """ Perform SummaryMixing-lite. + def _forward_mixing_fast(self, x, sum_mask, src_padding_mask): + """ Perform full SummaryMixing. 
+ Arguments + --------- + x: torch.Tensor + The expected shape is the standard SpeechBrain one - [Batch, Time, Features] + sum_mask: torch.Tensor + (Time, Time) per time step mask that can be used to compute different sum between time-step. + this can be useful for streaming, for instance, where each time step has a limited context. + src_padding_mask: torch.Tensor + (Batch, Time) corresponding to padding. We avoid padding when summarizing in time. + """ + + B, T, F = x.shape + + global_proj = self.global_proj(x) * src_padding_mask + split_global_proj = torch.split( + global_proj, self.local_proj_out_dim, dim=-1 + ) + + # split_global_proj[0] = local projection + # split_global_proj[1] = summary projection + + if sum_mask is None: + # We normalise by real length by counting masking + time_summary = torch.sum(split_global_proj[1], dim=1) / torch.sum( + src_padding_mask, dim=1 + ) + time_summary = time_summary.unsqueeze(1).repeat(1, T, 1) + + else: + + # We must do a masked sum. The mask is [Time, Time] and the features are [B,T,F] + # We therefore can do a matmul between [B,F,T] and [Time,Time].T to obtain [B,F,T] that we can re-transpose. + # We need to be careful when dividing as padding is not included in sum_mask. We need to build the intersection + # of both mask to know the actual real number of elements excluding padding. + + full_mask_with_pad = torch.matmul(sum_mask, src_padding_mask) + time_summary = ( + torch.matmul(split_global_proj[1].mT, sum_mask.T).mT + / full_mask_with_pad + ) + + return self.dropout( + self.summary_local_merging( + torch.cat([split_global_proj[0], time_summary], dim=-1) + ) + ) + + def _forward_avgonly(self, x, sum_mask, src_padding_mask): + """ Perform SummaryMixing-lite. Arguments --------- x: torch.Tensor The expected shape is the standard SpeechBrain one - [Batch, Time, Features] - attention_mask: torch.Tensor - (B, S) to pad before summarizing in time. + sum_mask: torch.Tensor + (Time, Time) per time step mask that can be used to compute different sum between time-step. + this can be useful for streaming, for instance, where each time step has a limited context. + src_padding_mask: torch.Tensor + (Batch, Time) corresponding to padding. We avoid padding when summarizing in time. """ B, T, F = x.shape # s() We just do the mean over time # Then we repeat the output matrix T times along the time axis - time_summary = self.summary_proj(x) * attention_mask - time_summary = torch.sum(time_summary, dim=1) / torch.sum(attention_mask, dim=1) + time_summary = self.summary_proj(x) * src_padding_mask + time_summary = torch.sum(time_summary, dim=1) / torch.sum( + src_padding_mask, dim=1 + ) time_summary = time_summary.unsqueeze(1).expand(-1, T, -1) return time_summary @@ -201,6 +321,56 @@ def _init_parameters(self, module): if isinstance(module, nn.Linear): torch.nn.init.zeros_(module.bias) + def _laplace_weights( + self, + size: int, + decay_constant, + binary_mask: Optional[torch.Tensor] = None, + device="cpu", + normalise=False, + ): + """ + Return a square matrix with the diagonal entries the maximum one in each row + and the entries left and right decaying exponentially. + This is like a discrete Laplacian distribution. + If normalise is set to True, in each row, the entries add up to 1. + Arguments + --------- + size: int + The height and width of the returned matrix. + decay_constant: float + The exponential decay per position. + This must be a positive value, and will normally be less than 1. 
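+            For instance, with the value of 0.995 used by the "SummaryMixing-expdecay"
+            mode, a frame 100 steps away from the diagonal receives a weight of about
+            0.995**100 ≈ 0.61.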
+        binary_mask: torch.Tensor
+            A binary mask applied before the rows are normalised.
+        device: str
+            Torch device to copy the generated masks to.
+        normalise: bool
+            If True, normalise each row so that its entries sum to 1 (default: False).
+        """
+
+        # Fill a matrix with integers indicating how far away each element is from
+        # the diagonal.
+        horizontal_distance_to_diagonal = torch.abs(
+            torch.arange(size) - torch.arange(size).unsqueeze(-1)
+        ).to(device)
+
+        # A Laplacian-like shape with the correct decay, but where the diagonal
+        # elements are all 1.
+        absolute_laplacian = torch.exp(
+            horizontal_distance_to_diagonal * torch.log(decay_constant)
+        )
+
+        if binary_mask is not None:
+            absolute_laplacian = absolute_laplacian * binary_mask
+
+        if normalise:
+            # Normalise each row.
+            normalised = absolute_laplacian / torch.sum(
+                absolute_laplacian, dim=1, keepdim=True
+            )
+            return normalised
+
+        return absolute_laplacian
+
     def _reset_parameters(self):
         # Setting a=sqrt(5) in kaiming_uniform is the same as initializing with
         # uniform(-1/sqrt(in_features), 1/sqrt(in_features)). For details, see