Is24 pr #12

Open · wants to merge 2 commits into main
35 changes: 31 additions & 4 deletions README.md
@@ -1,13 +1,30 @@
# SummaryMixing for SpeechBrain v1.0
*Halve your VRAM requirements and train any speech model 30% faster while achieving equivalent or better Word Error Rates and SLU accuracies with SummaryMixing Conformers and Branchformers.*

*Reduce your self-supervised learning (SSL) pre-training time and VRAM requirements by 20%-30% with equivalent or better downstream performance on speech processing tasks.*

## In brief
SummaryMixing is the first alternative to multi-head self-attention (MHSA) that beats it on speech tasks while reducing its complexity significantly (from quadratic to linear).

This repository implements SummaryMixing, a simpler, faster, and much cheaper replacement for self-attention in Conformers and Branchformers for automatic speech recognition, keyword spotting, and intent classification (see the [publication](https://arxiv.org/abs/2307.07421) for further details).

This repository also implements SummaryMixing for SSL pre-training (see the [publication](https://arxiv.org/pdf/2407.13377) for further details) and for a streaming transducer.

The code is fully compatible with the [SpeechBrain](https://speechbrain.github.io/) toolkit -- copy and paste is all you need to start using SummaryMixing in your setup.
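For reference, switching a SpeechBrain Conformer encoder to SummaryMixing comes down to a few extra options in the hparams file. The sketch below is lifted from the SSL configuration added in this PR; the dimensions are illustrative, not prescriptive:

```yaml
encoder: !new:speechbrain.lobes.models.transformer.Conformer.ConformerEncoder
    d_model: 768
    num_layers: 12
    nhead: 8
    d_ffn: 3072
    attention_type: SummaryMixing # select SummaryMixing instead of MHSA
    local_proj_hid_dim: [768]
    local_proj_out_dim: 768
    summary_hid_dim: [768]
    mode: SummaryMixing
```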

## !! A word about using SummaryMixing with SpeechBrain v1.0 !!

The main branch of this repository tracks the latest available version of SpeechBrain. The SSL results in our [publication](https://arxiv.org/pdf/2407.13377) and the streaming-transducer results were obtained with SpeechBrain v1.0. For the Conformer attention-CTC models with SpeechBrain v1.0, the results are:

| Encoder   | Variant        | Dev-clean WER % | Test-clean WER % | Test-other WER % |
|-----------|----------------|-----------------|------------------|------------------|
| Conformer | Self-attention | 1.9             | 2.0              | 4.6              |
| Conformer | SummaryMixing  | 1.9             | 2.0              | 4.6              |

Unfortunately, the results reported in our [publication](https://arxiv.org/abs/2307.07421) and in the table below were obtained with SpeechBrain v0.5 and may not be exactly reproduced with the current code. If you want exactly the same results, please use our dedicated
[branch](https://github.com/SamsungLabs/SummaryMixing/tree/speechbrain_v0.5) that contains the code compatible with SpeechBrain v0.5!


## A glance at SummaryMixing

@@ -44,13 +61,23 @@ Please cite SummaryMixing as follows:
```bibtex
@misc{summarymixing,
title={{SummaryMixing}: A Linear-Complexity Alternative to Self-Attention for Speech Recognition and Understanding},
author={Titouan Parcollet and Rogier van Dalen and Shucong Zhang and Sourav Bhattacharya},
year={2023},
eprint={2307.07421},
archivePrefix={arXiv},
primaryClass={eess.AS},
note={arXiv:2307.07421}
}

@misc{linear_ssl,
title={Linear-Complexity Self-Supervised Learning for Speech Processing},
author={Shucong Zhang and Titouan Parcollet and Rogier van Dalen and Sourav Bhattacharya},
year={2024},
eprint={2407.13377},
archivePrefix={arXiv},
primaryClass={eess.AS},
note={arXiv:2407.13377}
}
```

## Licence
195 changes: 195 additions & 0 deletions benchmarks/MP3S/CommonVoice/LSTM/hparams/ssl_cy.yaml
@@ -0,0 +1,195 @@
# ################################
# SummaryMixing © 2023 by Samsung Electronics is licensed under CC BY-NC 4.0.

# This is the configuration for SummaryMixing wav2vec 2.0 downstream ASR on CommonVoice Welsh (cy) with an LSTM downstream model

# Usage: Install SpeechBrain MP3S
# Create a folder benchmarks/MP3S/CommonVoice/LSTM
# Copy this file under benchmarks/MP3S/CommonVoice/LSTM/hparams
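# Example run (assumed MP3S entry point and paths; adjust to your setup):
#   cd benchmarks/MP3S/CommonVoice/LSTM
#   python train.py hparams/ssl_cy.yaml --data_folder=/path/to/cv-corpus/cy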
# SummaryMixing: https://arxiv.org/abs/2307.07421
# SummaryMixing SSL: https://arxiv.org/pdf/2407.13377

# Authors
# * Titouan Parcollet 2023, 2024
# * Shucong Zhang 2023, 2024
# * Rogier van Dalen 2023, 2024
# * Sourav Bhattacharya 2023, 2024
# ################################

# Seed needs to be set at top of yaml, before objects with parameters are made
seed: 1986
__set_seed: !apply:torch.manual_seed [!ref <seed>]
language: cy # use 'it' for Italian, 'rw' for Kinyarwanda, 'en' for English
output_folder: !ref results/CommonVoice/<language>/<seed>
wer_file: !ref <output_folder>/wer.txt
save_folder: !ref <output_folder>/save
train_log: !ref <output_folder>/train_log.txt

# Data files
data_folder: !PLACEHOLDER # e.g., /local/cv-corpus-11.0-2022-09-21/<language>
train_tsv_file: !ref <data_folder>/train.tsv # Standard CommonVoice .tsv files
dev_tsv_file: !ref <data_folder>/dev.tsv # Standard CommonVoice .tsv files
test_tsv_file: !ref <data_folder>/test.tsv # Standard CommonVoice .tsv files
accented_letters: True
train_csv: !ref <save_folder>/train.csv
valid_csv: !ref <save_folder>/dev.csv
test_csv: !ref <save_folder>/test.csv
skip_prep: False # Skip data preparation

avoid_if_longer_than: 10.0

num_layers_ssl: 13 # Number of layers in the SSL model (should be 25 for the large model)
pretrained_path: !PLACEHOLDER # e.g., /path/to/pre-trained_SummaryMixing_w2v2
encoder_dim: 768

# Training parameters
number_of_epochs: 20
lr: 0.0004
lr_weights: 0.02
sorting: ascending
auto_mix_prec: False
sample_rate: 16000
token_type: bpe # ["unigram", "bpe", "char"]
character_coverage: 1.0


# With data_parallel batch_size is split into N jobs
# With DDP batch_size is multiplied by N jobs
# Must be 3 per GPU to fit 32GB of VRAM
batch_size: 4
test_batch_size: 4

# Dataloader options
train_dataloader_opts:
    batch_size: !ref <batch_size>

dataloader_options:
    batch_size: !ref <batch_size>
    num_workers: 4

test_dataloader_options:
    batch_size: !ref <test_batch_size>
    num_workers: 4

valid_dataloader_opts:
    batch_size: !ref <batch_size>

# Model parameters
activation: !name:torch.nn.Sigmoid
dnn_layers: 1
dnn_neurons: 768
freeze_encoder: True

# Outputs
output_neurons: 100 # BPE size, index(blank/eos/bos) = 0


# Functions and classes
#
epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
    limit: !ref <number_of_epochs>

augmentation: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
    sample_rate: !ref <sample_rate>
    speeds: [95, 100, 105]

encoder: !new:speechbrain.lobes.models.transformer.Conformer.ConformerEncoder
    d_model: 768
    num_layers: 12
    nhead: 8
    d_ffn: 3072
    dropout: 0.1
    layerdrop_prob: 0.0
    attention_type: SummaryMixing
    local_proj_hid_dim: [768]
    local_proj_out_dim: 768
    summary_hid_dim: [768]
    mode: SummaryMixing
    output_hidden_states: True

latentextractor_kernels: [3, 3]
latentextractor_strides: [2, 1]
extractor_dim: 512
embedding_dim: 768

CNN: !new:speechbrain.lobes.models.wav2vec.W2VLatentExtractor
    kernel_sizes: !ref <latentextractor_kernels>
    strides: !ref <latentextractor_strides>
    out_channels: [512, 512]
    input_dim: 80

weighted_ssl_model: !new:speechbrain.lobes.models.wav2vec.WeightedSSLModel
    pretrained_path: !ref <pretrained_path>
    num_layers: 13
    latent_encoder: !ref <encoder>
    CNN: !ref <CNN>
    in_dim: 512
    embedding_dim: 768
    dropout_encoder_input: 0.1
    output_hidden_states: True

enc: !new:speechbrain.nnet.RNN.LSTM
    input_shape: [Null, Null, !ref <encoder_dim>]
    num_layers: 2
    bidirectional: True
    dropout: 0.2
    hidden_size: 1024

ctc_lin: !new:speechbrain.nnet.linear.Linear
    input_size: 2048 # 2 x hidden_size, since the LSTM is bidirectional
    n_neurons: !ref <output_neurons>

log_softmax: !new:speechbrain.nnet.activations.Softmax
    apply_log: True

ctc_cost: !name:speechbrain.nnet.losses.ctc_loss
    blank_index: !ref <blank_index>

modules:
    enc: !ref <enc>
    ctc_lin: !ref <ctc_lin>
    weighted_ssl_model: !ref <weighted_ssl_model>

model: !new:torch.nn.ModuleList
    - [!ref <enc>, !ref <ctc_lin>]

model_opt_class: !name:torch.optim.Adam
    lr: !ref <lr>

weights_opt_class: !name:torch.optim.Adam
    lr: !ref <lr_weights>

lr_annealing_model: !new:speechbrain.nnet.schedulers.NewBobScheduler
    initial_value: !ref <lr>
    improvement_threshold: 0.0025
    annealing_factor: 0.8
    patient: 0

lr_annealing_weights: !new:speechbrain.nnet.schedulers.NewBobScheduler
    initial_value: !ref <lr_weights>
    improvement_threshold: 0.0025
    annealing_factor: 0.9
    patient: 0

label_encoder: !new:speechbrain.dataio.encoder.CTCTextEncoder

checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
    checkpoints_dir: !ref <save_folder>
    recoverables:
        model: !ref <model>
        ssl_model: !ref <weighted_ssl_model>
        scheduler_model: !ref <lr_annealing_model>
        scheduler_encoder: !ref <lr_annealing_weights>
        counter: !ref <epoch_counter>
        tokenizer: !ref <label_encoder>

blank_index: 0
unk_index: 1


train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
    save_file: !ref <train_log>

error_rate_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats

cer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats
    split_tokens: True