scripts for SummaryMixing SSL #9

Open · wants to merge 1 commit into base: SummaryMixing_w2v2
44 changes: 15 additions & 29 deletions README.md
@@ -1,41 +1,27 @@
# SummaryMixing for SpeechBrain v1.0
*Halve your VRAM requirements and train any speech model 30% faster while achieving equivalent or better Word Error Rates and SLU accuracies with SummaryMixing Conformers and Branchformers.*

## !! A word about using SummaryMixing with SpeechBrain V1.0 !!

The main branch of this repository will keep tracking the latest version of SpeechBrain available. Unfortunately, the results reported in our [publication](https://arxiv.org/abs/2307.07421) and below in the table were obtained with SpeechBrain v0.5 and may not be exactly reproduced with the current code. If you want the exact same results, please use our dedicated
[branch](https://github.com/SamsungLabs/SummaryMixing/tree/speechbrain_v0.5), which contains the code compatible with SpeechBrain v0.5!
# SummaryMixing wav2vec 2.0
> **Review comment (Contributor):** We should not erase the previous Readme, it should be combined.

We equip wav2vec 2.0 (w2v2) with SummaryMixing, our linear-time alternative to quadratic-cost self-attention. Compared to self-attention based w2v2, SummaryMixing based w2v2 greatly reduces the cost of self-supervised pre-training and gives better or the same level of performance on downstream tasks.

## In brief
This repository implements SummaryMixing, a simpler, faster and much cheaper replacement for self-attention in Conformers and Branchformers for automatic speech recognition, keyword spotting and intent classification (see the [publication](https://arxiv.org/abs/2307.07421) for further details). The code is fully compatible with the [SpeechBrain](https://speechbrain.github.io/) toolkit version 0.5 -- copy and paste is all you need to start using SummaryMixing in your setup. If you wish to run with the latest version of SpeechBrain (v1.0+), please go to the main branch of this repository. SummaryMixing is the first alternative to MHSA able to beat it on speech tasks while reducing its complexity significantly (from quadratic to linear).
This repository implements SummaryMixing w2v2. The code is fully compatible with the [SpeechBrain](https://speechbrain.github.io/) toolkit -- copy and paste is all you need to start using SummaryMixing in your setup.
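
For instance, with the files of this branch copied into SpeechBrain, the SummaryMixing Conformer encoder used in this repository can be instantiated directly from Python. This is a minimal sketch: the class and arguments mirror the hparams file included in this pull request, assuming the modified `ConformerEncoder` of this branch.

```python
from speechbrain.lobes.models.transformer.Conformer import ConformerEncoder

# Same settings as the `encoder` entry of the hparams file added in this PR.
encoder = ConformerEncoder(
    num_layers=12,
    d_model=768,
    d_ffn=3072,
    nhead=8,
    dropout=0.1,
    layerdrop_prob=0.0,
    attention_type="SummaryMixing",  # replaces self-attention
    local_proj_hid_dim=[768],
    local_proj_out_dim=768,
    summary_hid_dim=[768],
    mode="SummaryMixing",
    output_hidden_states=True,
)
```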

## A glance at SummaryMixing

SummaryMixing is a linear-time alternative to self-attention (SA) for speech processing models such as Transformers, Conformers or Branchformers. Instead of computing pair-wise scores between tokens (leading to quadratic-time complexity for SA), it summarises a whole utterance with a mean over vectors for all time steps. SummaryMixing is based on recent [findings](https://arxiv.org/pdf/2207.02971.pdf) demonstrating that self-attention could be unnecessary for speech recognition, as the attention weights of trained ASR systems are almost uniformly distributed across the tokens composing a sequence. SummaryMixing is also a generalisation of the recent [HyperMixer](https://arxiv.org/abs/2203.03691) and [HyperConformer](https://arxiv.org/abs/2305.18281) to better and simpler mixing functions. In a SummaryMixing cell, which takes the same inputs and produces the same outputs as self-attention, contributions from each time step are first transformed and then averaged globally before being fed back to each time step. This is visible in Figure 1 of the [article](https://arxiv.org/abs/2307.07421). Therefore, the time complexity is reduced to linear.
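
As a rough illustration of this mechanism (a toy sketch, not the implementation used in this repository), a SummaryMixing-style cell can be written in a few lines of PyTorch: each frame goes through a local transformation, a second transformation is averaged over time into a single summary vector, and the summary is fed back to every frame.

```python
import torch
import torch.nn as nn

class SummaryMixingSketch(nn.Module):
    """Toy linear-time mixing: per-frame local transform + global mean summary."""

    def __init__(self, dim):
        super().__init__()
        self.local = nn.Sequential(nn.Linear(dim, dim), nn.GELU())    # f(x_t)
        self.summary = nn.Sequential(nn.Linear(dim, dim), nn.GELU())  # s(x_t)
        self.combine = nn.Linear(2 * dim, dim)                        # fed back to each step

    def forward(self, x):  # x: (batch, time, dim)
        local = self.local(x)                                # transform every time step
        summary = self.summary(x).mean(dim=1, keepdim=True)  # one summary per utterance, O(T)
        summary = summary.expand_as(local)                   # broadcast back to all steps
        return self.combine(torch.cat([local, summary], dim=-1))

x = torch.rand(2, 100, 768)               # (batch, time, features)
print(SummaryMixingSketch(768)(x).shape)  # torch.Size([2, 100, 768])
```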

In this branch, we use SummaryMixing for self-supervised learning by equipping w2v2 with it. For a detailed description, please refer to this [article]().

### A few results

A SummaryMixing-equipped Conformer outperforms a self-attention equivalent model on Librispeech test-clean (2.1% vs 2.3%) and test-other (5.1% vs 5.4%). This is achieved with a 30% reduction in training time as well as less than half of the memory budget (from 46GB to 21GB). Such gains are also visible with CommonVoice, AISHELL-1 and Tedlium2. The gain also holds at decoding time: the real-time factor (RTF) of a SummaryMixing Branchformer remains stable (does not increase) with the sentence length, while the RTF of the same model with self-attention increases quadratically. The SpeechBrain configuration files in this repository can reproduce these numbers.

The following Table gives an idea of the results observed with Librispeech. More results on CommonVoice, AISHELL, Tedlium, SLURP, and Google Speech Command are available in the [article](https://arxiv.org/abs/2307.07421).
| Encoder | Variant | Dev-clean | Test-clean | Test-other | GPU | VRAM |
|------------------|----------------------|--------------------|---------------------|---------------------|----------------|---------------|
| | | **WER \%** | **WER \%** | **WER \%** | **hours** | **GB** |
| ContextNet | N.A. | 3.3 | 2.3 | 5.9 | 160 | 25 |
| Transformer | Self-attention | 3.3 | 2.3 | 5.5 | 129 | 40 |
| Conformer | Self-attention | 2.8 | 2.3 | 5.4 | 137 | 46 |
| Branchformer | Self-attention | 2.9 | 2.2 | 5.1 | 132 | 45 |
| | CNN Only | 3.1 | 2.4 | 5.7 | 83 | 22 |
| | HyperMixer | 3.1 | 2.3 | 5.6 | 126 | 30 |
| | FastFormer | 3.0 | 2.2 | 5.4 | 96 | 23 |
| | **Proposed** | | | | | |
| Conformer | SummaryMixing | 2.8 | 2.1 | 5.1 | 98 | 21 |
| Branchformer | SummaryMixing-lite | 3.0 | 2.2 | 5.2 | 98 | 23 |
| | SummaryMixing | 2.9 | 2.2 | 5.1 | 105 | 26 |
| | +Summary Decoder | 3.1 | 2.3 | 5.3 | 104 | 26 |


<img src="summarymixing.png" alt="RTF performance" style="height: 400px;"/>
In the experiments of the [article](), SummaryMixing-equipped w2v2 reduces the pre-training time and memory budget by 18% and 23%, respectively, with better or equivalent results on the downstream tasks of automatic speech recognition, intent classification, emotion recognition, and automatic speaker verification. The following table gives the results of SummaryMixing-based and attention-based SSL models on CommonVoice Welsh ASR and SLURP intent classification. For the results of the other downstream tasks, please refer to the [article](). The SpeechBrain configuration files in this repository can reproduce these numbers.


| Context Encoder | Size | Pre-trained on | Welsh 15.8 WER % | SLURP Intent Classification Acc. % |
|------------------|------|--------------------|------------------|-------------------------------------|
| Self-attention   | 166M | LibriLight 4.3k h  | 50.8             | 78.1                                |
| SummaryMixing    | 155M | LibriLight 4.3k h  | 48.3             | 80.5                                |
| w2v2 base        | 95M  | LibriSpeech 960 h  | 54.5             | 77.7                                |
| w2v2 large       | 317M | LibriLight 60k h   | 45.4             | 79.0                                |


## Citation
195 changes: 195 additions & 0 deletions benchmarks/MP3S/CommonVoice/LSTM/hparams/ssl_cy.yaml
@@ -0,0 +1,195 @@
# ################################
# SummaryMixing © 2023 by Samsung Electronics is licensed under CC BY-NC 4.0.

# This is the configuration for SummaryMixing wav2vec 2.0 downstream ASR on CommonVoice cy with an LSTM downstream model

# Usage: Install SpeechBrain MP3S
# Create a folder benchmarks/MP3S/CommonVoice/LSTM
# Copy this file under benchmarks/MP3S/CommonVoice/LSTM/hparams
# SummaryMixing: https://arxiv.org/abs/2307.07421
# SummaryMixing SSL:

# Authors
# * Titouan Parcollet 2023, 2024
# * Shucong Zhang 2023, 2024
# * Rogier van Dalen 2023, 2024
# * Sourav Bhattacharya 2023, 2024
# ################################

# Seed needs to be set at top of yaml, before objects with parameters are made
seed: 1986
__set_seed: !apply:torch.manual_seed [!ref <seed>]
language: cy # Welsh; use 'it' for Italian, 'rw' for Kinyarwanda, 'en' for English
output_folder: !ref results/CommonVoice/<language>/<seed>
wer_file: !ref <output_folder>/wer.txt
save_folder: !ref <output_folder>/save
train_log: !ref <output_folder>/train_log.txt

# Data files
data_folder: !PLACEHOLDER # e.g., /local/cv-corpus-11.0-2022-09-21/<language>
train_tsv_file: !ref <data_folder>/train.tsv # Standard CommonVoice .tsv files
dev_tsv_file: !ref <data_folder>/dev.tsv # Standard CommonVoice .tsv files
test_tsv_file: !ref <data_folder>/test.tsv # Standard CommonVoice .tsv files
accented_letters: True
train_csv: !ref <save_folder>/train.csv
valid_csv: !ref <save_folder>/dev.csv
test_csv: !ref <save_folder>/test.csv
skip_prep: False # Skip data preparation

avoid_if_longer_than: 10.0

num_layers_ssl: 13 # Number of layers in the SSL model (should be 25 for large)
pretrained_path: !PLACEHOLDER # e.g., /path/to/pre-trained_SummaryMixing_w2v2
encoder_dim: 768

# Training parameters
number_of_epochs: 20
lr: 0.0004
lr_weights: 0.02
sorting: ascending
auto_mix_prec: False
sample_rate: 16000
token_type: bpe # ["unigram", "bpe", "char"]
character_coverage: 1.0


# With data_parallel batch_size is split into N jobs
# With DDP batch_size is multiplied by N jobs
# Must be 3 per GPU to fit 32GB of VRAM
batch_size: 4
test_batch_size: 4

# Dataloader options
train_dataloader_opts:
    batch_size: !ref <batch_size>
dataloader_options:
    batch_size: !ref <batch_size>
    num_workers: 4
test_dataloader_options:
    batch_size: !ref <test_batch_size>
    num_workers: 4


valid_dataloader_opts:
    batch_size: !ref <batch_size>

# Model parameters
activation: !name:torch.nn.Sigmoid
dnn_layers: 1
dnn_neurons: 768
freeze_encoder: True

# Outputs
output_neurons: 100 # BPE size, index(blank/eos/bos) = 0


# Functions and classes
#
epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
    limit: !ref <number_of_epochs>

augmentation: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
    sample_rate: !ref <sample_rate>
    speeds: [95, 100, 105]

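# SummaryMixing Conformer context encoder of the pre-trained w2v2 model.
# Self-attention is replaced by SummaryMixing, and the hidden states of all
# layers are returned so that they can be combined by the weighted SSL
# wrapper defined below.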
encoder: !new:speechbrain.lobes.models.transformer.Conformer.ConformerEncoder
    d_model: 768
    num_layers: 12
    nhead: 8
    d_ffn: 3072
    dropout: 0.1
    layerdrop_prob: 0.0
    attention_type: SummaryMixing
    local_proj_hid_dim: [768]
    local_proj_out_dim: 768
    summary_hid_dim: [768]
    mode: SummaryMixing
    output_hidden_states: True

latentextractor_kernels: [3, 3]
latentextractor_strides: [2, 1]
extractor_dim: 512
embedding_dim: 768

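# Convolutional latent extractor of the w2v2 front-end: two convolutional
# layers turning 80-dimensional input features into 512-dimensional latents.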
CNN: !new:speechbrain.lobes.models.wav2vec.W2VLatentExtractor
    kernel_sizes: !ref <latentextractor_kernels>
    strides: !ref <latentextractor_strides>
    out_channels: [512, 512]
    input_dim: 80

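# Pre-trained SummaryMixing w2v2 loaded from <pretrained_path> and kept
# frozen during downstream training (see freeze_encoder above). Downstream
# features are a learned weighted combination of its 13 hidden layers.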
weighted_ssl_model: !new:speechbrain.lobes.models.wav2vec.WeightedSSLModel
    pretrained_path: !ref <pretrained_path>
    num_layers: 13
    latent_encoder: !ref <encoder>
    CNN: !ref <CNN>
    in_dim: 512
    embedding_dim: 768
    dropout_encoder_input: 0.1
    output_hidden_states: True

enc: !new:speechbrain.nnet.RNN.LSTM
    input_shape: [Null, Null, !ref <encoder_dim>]
    num_layers: 2
    bidirectional: True
    dropout: 0.2
    hidden_size: 1024

ctc_lin: !new:speechbrain.nnet.linear.Linear
    input_size: 2048
    n_neurons: !ref <output_neurons>

log_softmax: !new:speechbrain.nnet.activations.Softmax
    apply_log: True

ctc_cost: !name:speechbrain.nnet.losses.ctc_loss
    blank_index: !ref <blank_index>

modules:
    enc: !ref <enc>
    ctc_lin: !ref <ctc_lin>
    weighted_ssl_model: !ref <weighted_ssl_model>

model: !new:torch.nn.ModuleList
    - [!ref <enc>, !ref <ctc_lin>]

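# Two separate optimisers: one for the downstream LSTM/CTC model and one
# for the layer-combination weights of the frozen SSL encoder.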
model_opt_class: !name:torch.optim.Adam
    lr: !ref <lr>

weights_opt_class: !name:torch.optim.Adam
    lr: !ref <lr_weights>

lr_annealing_model: !new:speechbrain.nnet.schedulers.NewBobScheduler
    initial_value: !ref <lr>
    improvement_threshold: 0.0025
    annealing_factor: 0.8
    patient: 0

lr_annealing_weights: !new:speechbrain.nnet.schedulers.NewBobScheduler
    initial_value: !ref <lr_weights>
    improvement_threshold: 0.0025
    annealing_factor: 0.9
    patient: 0

label_encoder: !new:speechbrain.dataio.encoder.CTCTextEncoder

checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
    checkpoints_dir: !ref <save_folder>
    recoverables:
        model: !ref <model>
        ssl_model: !ref <weighted_ssl_model>
        scheduler_model: !ref <lr_annealing_model>
        scheduler_encoder: !ref <lr_annealing_weights>
        counter: !ref <epoch_counter>
        tokenizer: !ref <label_encoder>

blank_index: 0
unk_index: 1


train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
    save_file: !ref <train_log>

error_rate_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats

cer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats
    split_tokens: True
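
Like any SpeechBrain recipe, this hparams file is meant to be loaded with HyperPyYAML by the corresponding training script. A minimal sketch, with placeholder paths standing in for the two `!PLACEHOLDER` entries:

```python
from hyperpyyaml import load_hyperpyyaml

# Placeholder values; point them to your CommonVoice data and pre-trained model.
overrides = {
    "data_folder": "/path/to/cv-corpus/cy",
    "pretrained_path": "/path/to/pre-trained_SummaryMixing_w2v2",
}

with open("benchmarks/MP3S/CommonVoice/LSTM/hparams/ssl_cy.yaml") as fin:
    hparams = load_hyperpyyaml(fin, overrides)

ssl_model = hparams["weighted_ssl_model"]  # frozen pre-trained SummaryMixing w2v2
downstream = hparams["model"]              # LSTM encoder + CTC head
```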