feat(model) : add segmentation model based on self-supervised representation #1362

Merged
merged 25 commits on Sep 18, 2023
Changes from 20 commits
Commits (25)
6d3af2e
add WaVLM-Base model to PyanNet.py in replacement of SincNet
SevKod May 5, 2023
d03906b
implement wavlm inside PyanNet class and add wavlm block
SevKod May 9, 2023
3fc2d37
add support of all Torchaudio self-supverised models to PyanNet, incl…
SevKod May 15, 2023
1e370fc
add support of ssl models from huggingface to pyannote using PyanHugg…
SevKod May 26, 2023
e170eed
remove support for sincnet block
SevKod May 31, 2023
9f81c30
Remove unnecessary computation for unused deeper layers.
SevKod Jun 5, 2023
e5330fc
add support for fairseq pretrained ssl models
SevKod Jun 20, 2023
7a21fc9
fairseq dependency only used if needed
SevKod Jun 20, 2023
6243f91
Merge branch 'develop' into PyanNetWavLM
hbredin Jun 20, 2023
328505c
Remove unnecessary computation for unused deeper layers (regarding a …
SevKod Jul 5, 2023
d4ddd53
Merge branch 'PyanNetWavLM' of github.com:SevKod/pyannote-audio into …
SevKod Jul 5, 2023
63a9e42
Merge branch 'develop' into PyanNetWavLM
hbredin Jul 6, 2023
cbd01a3
Remove HuggingFace and fairseq dependencies from self-sup
SevKod Jul 10, 2023
f608eb7
Merge branch 'PyanNetWavLM' of github.com:SevKod/pyannote-audio into …
SevKod Jul 10, 2023
d7e9203
add support for torchaudio self sup models
SevKod Jul 12, 2023
81aafdd
fixed bug condition of wavlm_base and wavlm_large
SevKod Jul 13, 2023
b9c89b6
add layer-wise pooling and finetuning (still wip)
SevKod Aug 2, 2023
8aba20e
Merge branch 'develop' into PyanNetWavLM
hbredin Aug 2, 2023
4a8bfe2
Merge branch 'develop' into PyanNetWavLM
hbredin Aug 2, 2023
2323105
Merge branch 'develop' into PyanNetWavLM
hbredin Sep 10, 2023
cedf042
feat: add SSeRiouSS architecture
hbredin Sep 13, 2023
06641bf
chore: remove old PyanSup
hbredin Sep 13, 2023
31d08a4
chore: remove now replaced SelfSupModel block
hbredin Sep 15, 2023
421ba03
doc: update changelog
hbredin Sep 15, 2023
5f9211c
Merge branch 'develop' into PyanNetWavLM
hbredin Sep 18, 2023
181 changes: 181 additions & 0 deletions pyannote/audio/models/blocks/selfsup.py
@@ -0,0 +1,181 @@
# MIT License
#
# Copyright (c) 2020 CNRS
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import sys
import re
from typing import Optional
import torch
import torchaudio
import torch.nn as nn
from torch.nn.functional import normalize
import torch.nn.functional as F
from collections import OrderedDict
from torchaudio.models.wav2vec2 import wav2vec2_model, wavlm_model
from torchaudio.pipelines import Wav2Vec2Bundle

class SelfSupModel(nn.Module):

    def __init__(self, checkpoint=None, torchaudio_ssl=None, torchaudio_cache=None, finetune=None, average_layers=None, average_all=None, name=None, layer=None, cfg=None):
        super().__init__()
        if torchaudio_ssl:
            if checkpoint:
                raise ValueError("Error: Cannot specify both a checkpoint and a torchaudio model.")

            print("\nThe Self-Supervised Model " + str(torchaudio_ssl) + " is loaded from torchaudio.\n")
            name, config, ordered_dict = self.dict_torchaudio(torchaudio_ssl, torchaudio_cache)
        else:
            print("A checkpoint from a Self-Supervised Model is used for training.")
            if torch.cuda.is_available():
                ckpt = torch.load(checkpoint)
            else:
                ckpt = torch.load(checkpoint, map_location=torch.device('cpu'))
            # Check whether the checkpoint comes from an already finetuned diarization model (containing the SSL model) or from an SSL pretrained model only
            if 'pyannote.audio' in ckpt:  # 1: check whether a Segmentation model is attached
                print("The checkpoint is used for finetuning. \nThe attached SSL model will be used for feature extraction.")
                name, config, ordered_dict = self.dict_finetune(ckpt)

            else:  # Otherwise, load the dictionary of the SSL checkpoint
                print("The checkpoint is a pretrained SSL model to use for Segmentation.\nBuilding the SSL model.")
                name, config, ordered_dict = self.dict_pretrained(ckpt)

        # Layer-wise pooling (same way as SUPERB)
        if not average_all:
            if not average_layers:
                if layer is None:
                    print("\nLayer number not specified. Default to layer 1.\n")
                    self.layer = 1
                else:
                    self.layer = layer
                    print("\nSelected layer is " + str(layer) + ".\n")
            else:
                print("Layers " + str(average_layers) + " selected for layer-wise pooling.")
                self.W = nn.Parameter(torch.randn(len(average_layers)))  # Set specific number of learnable weights
                self.average_layers = average_layers
                self.layer = max(average_layers)
        else:
            print("All layers are selected for layer-wise pooling.")
            self.W = nn.Parameter(torch.randn(config['encoder_num_layers']))  # Set max number of learnable weights
            self.average_layers = list(range(config['encoder_num_layers']))
            self.layer = config['encoder_num_layers']

        if finetune:  # Finetuning not working yet
            print("Self-supervised model is unfrozen.")
            # config['encoder_ff_interm_dropout'] = 0.3
            config['encoder_layer_norm_first'] = True
Member:
Can you explain?

Contributor Author:
The issue with finetuning WavLM looked similar to a normalization issue that occurred during feature extraction. When gradients are computed during training, validation extracts feature vectors that are almost identical across the frames of the input audio. I assume this might be why validation does not seem to improve (or change) during training (but I might be completely wrong about this...). Since the problem seemed similar to the one I had with WavLM a few months back (since fixed), where features were also identical across frames, I tried applying a normalization step to see whether the behavior of the extracted features would change. It did not seem to make a difference... It is one of the many things I tried and forgot to remove before pushing ^^

        else:
            print("Self-supervised model is frozen.")

        config['encoder_num_layers'] = self.layer
        ordered_dict = self.remove_layers_dict(ordered_dict, self.layer)  # Remove weights from unused transformer encoders
        self.model_name = name
        self.finetune = finetune  # Assign mode
        self.average_layers = average_layers
        self.feat_size = config['encoder_embed_dim']  # Get feature output dimension
        self.config = config  # Assign the configuration
        SelfSupModel.__name__ = self.model_name  # Assign name of the class

        if name == "WAVLM_BASE" or name == "WAVLM_LARGE":  # Only wavlm_model takes two additional arguments
            model = wavlm_model(**config)
        else:
            model = wav2vec2_model(**config)
        model.load_state_dict(ordered_dict)  # Assign state dict to the model

        if finetune:
Member:
Can you explain why this is necessary?

Member:
I am talking about the switch to eval mode in the case where WavLM is frozen.

Contributor Author:
My understanding was that .eval() is relevant when the model you want to switch to inference mode contains modules such as Dropout layers (which is the case for WavLM). I had read this post, which recommended using both:

https://stackoverflow.com/questions/55627780/evaluating-pytorch-models-with-torch-no-grad-vs-model-eval

That said, from memory I did not notice any difference in the features between ".eval()" and "no_grad". Apparently .eval() also consumes more memory than no_grad... so it could just as well be removed. It was also there to study the differences between .eval() and .train() while I was investigating what happens during finetuning.

Member:
model.eval() and torch.no_grad() play two very different roles.

  • torch.no_grad() disables gradient computation for the layers it wraps, so it is what you want when freezing part of the network.
  • model.eval() switches layers that behave differently during training (e.g. dropout, which randomly disables some weights, or batchnorm, which keeps a running average of the data flowing through it) to inference mode, removing that randomness.

In short, neither model.eval() nor model.train() should be used to control whether or not the feature extraction part is finetuned. Only torch.no_grad() should be used (to freeze) or omitted (to finetune).

Switching between eval and train mode is handled automatically by pytorch-lightning during the validation and training phases.
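For reference, a minimal illustrative sketch of that recommendation (not code from this PR; `extractor` stands for any torch.nn.Module feature extractor):

import torch
import torch.nn as nn

def extract(extractor: nn.Module, waveforms: torch.Tensor, finetune: bool) -> torch.Tensor:
    # Freezing is controlled by torch.no_grad() alone; train/eval mode is left to the
    # training loop (e.g. pytorch-lightning) and is not toggled here.
    if finetune:
        return extractor(waveforms)  # gradients flow back into the extractor
    with torch.no_grad():
        return extractor(waveforms)  # extractor weights stay frozen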

            self.ssl_model = model.train()
        else:
            self.ssl_model = model.eval()

    def dict_finetune(self, ckpt):
        # Reconstruct the state dict of the SSL block from the full segmentation checkpoint
        dict_modules = list(ckpt['state_dict'].keys())  # Get the list of modules
        ssl_modules = [key for key in dict_modules if 'selfsupervised' in key]  # Extract only the SSL parts
        weights = [ckpt['state_dict'][key] for key in ssl_modules]  # Get the weights corresponding to the modules
        modules_torchaudio = ['.'.join(key.split('.')[2:]) for key in ssl_modules]  # Keep only the torchaudio part of each key
        ordered_dict = OrderedDict((key, weight) for key, weight in zip(modules_torchaudio, weights))  # Recreate the state_dict
        config = ckpt['hyper_parameters']['selfsupervised']['cfg']  # Get config
        name = ckpt['hyper_parameters']['selfsupervised']['name']  # Get model name

        return name, config, ordered_dict

    def dict_pretrained(self, ckpt):
        ordered_dict = ckpt['state_dict']  # Get dict
        config = ckpt['config']  # Get config
        name = ckpt['model_name']  # Get model name

        return name, config, ordered_dict

    def dict_torchaudio(self, torchaudio_ssl, torchaudio_cache):
        bundle = getattr(torchaudio.pipelines, torchaudio_ssl)
        name = torchaudio_ssl  # Name is the torchaudio pipeline identifier
        config = bundle._params  # Get config
        if torchaudio_cache:
            torch.hub.set_dir(torchaudio_cache)  # Set cache
        ordered_dict = bundle.get_model().state_dict()  # Get the dict

        return name, config, ordered_dict

    def remove_layers_dict(self, state_dict, layer):
        # Drop the weights of transformer layers deeper than the last one actually used
        keys_to_delete = []
        for key in state_dict.keys():
            if "transformer.layers" in key:
                nb = int(re.findall(r'\d+', key)[0])
                if nb > (layer - 1):
                    keys_to_delete.append(key)
        for key in keys_to_delete:
            del state_dict[key]

        return state_dict

    def avg_pool(self, scalars, feat_list):
        # Weighted sum of the selected layer outputs (layer-wise pooling)
        pooled = 0
        for i in range(len(feat_list)):
            pooled = pooled + scalars[i] * feat_list[i]
        return pooled

    def forward(self, waveforms: torch.Tensor) -> torch.Tensor:
        waveforms = torch.squeeze(waveforms, 1)  # waveforms: (batch, channel, sample) -> (batch, sample)
        if self.finetune:
            feat, _ = self.ssl_model.extract_features(waveforms, None, self.layer)
        else:
            with torch.no_grad():
                feat, _ = self.ssl_model.extract_features(waveforms, None, self.layer)
        if self.average_layers:
            feat_learn_list = []
            for index in self.average_layers:
                feat_learn_list.append(feat[index - 1])
            w = self.W.softmax(-1)
            outputs = self.avg_pool(w, feat_learn_list)
            # print(w)
            # print(outputs.size())
        else:
            outputs = feat[self.layer - 1]
        return outputs
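For context, a hypothetical usage sketch of this block on its own (illustrative only, not part of the PR; it assumes the torchaudio WAVLM_BASE bundle can be downloaded):

import torch
from pyannote.audio.models.blocks.selfsup import SelfSupModel

# Frozen WavLM Base from torchaudio, features taken from transformer layer 4.
ssl = SelfSupModel(torchaudio_ssl="WAVLM_BASE", layer=4)

waveforms = torch.randn(2, 1, 16000)  # (batch, channel, sample) at 16 kHz
features = ssl(waveforms)             # (batch, frame, ssl.feat_size)
print(features.shape)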
183 changes: 183 additions & 0 deletions pyannote/audio/models/segmentation/PyanSup.py
@@ -0,0 +1,183 @@
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from pyannote.core.utils.generators import pairwise

from pyannote.audio.core.model import Model
from pyannote.audio.core.task import Task
from pyannote.audio.models.blocks.selfsup import SelfSupModel
from pyannote.audio.utils.params import merge_dict


class PyanSup(Model):
    """PyanSup segmentation model

    Self-Supervised Model > LSTM > Feed forward > Classifier

    Parameters
    ----------
    sample_rate : int, optional
        Audio sample rate. Defaults to 16kHz (16000).
    num_channels : int, optional
        Number of channels. Defaults to mono (1).
    selfsupervised : dict, optional
        Keyword arguments passed to the selfsupervised block. "name" and "cfg" are used to reconstruct
        the feature extractor from the checkpoint dictionary; "layer" is the layer used for feature extraction.
        Defaults to {"name": None, "layer": None, "average_layers": None, "average_all": False, "cfg": None}.
    lstm : dict, optional
        Keyword arguments passed to the LSTM layer.
        Defaults to {"hidden_size": 128, "num_layers": 2, "bidirectional": True},
        i.e. two bidirectional layers with 128 units each.
        Set "monolithic" to False to split monolithic multi-layer LSTM into multiple mono-layer LSTMs.
        This may prove useful for probing LSTM internals.
    linear : dict, optional
        Keyword arguments used to initialize linear layers.
        Defaults to {"hidden_size": 128, "num_layers": 2},
        i.e. two linear layers with 128 units each.
    """

    SSL_DEFAULTS = {
        "name": None,
        "layer": None,
        "average_layers": None,
        "average_all": False,
        "cfg": None,
    }
    LSTM_DEFAULTS = {
        "hidden_size": 128,
        "num_layers": 2,
        "bidirectional": True,
        "monolithic": True,
        "dropout": 0.0,
    }
    LINEAR_DEFAULTS = {"hidden_size": 128, "num_layers": 2}

    def __init__(
        self,
        ckpt: str = None,
        torchaudio_ssl: str = None,
        torchaudio_cache: str = None,
        finetune: bool = False,
        selfsupervised: dict = None,
        lstm: dict = None,
        linear: dict = None,
        sample_rate: int = 16000,
        num_channels: int = 1,
        task: Optional[Task] = None,
    ):

        super().__init__(sample_rate=sample_rate, num_channels=num_channels, task=task)

        selfsupervised = merge_dict(self.SSL_DEFAULTS, selfsupervised)
        lstm = merge_dict(self.LSTM_DEFAULTS, lstm)
        lstm["batch_first"] = True
        linear = merge_dict(self.LINEAR_DEFAULTS, linear)
        self.save_hyperparameters("lstm", "linear")  # A first merge is done using the default parameters specified

        print("\n##################################################################")
        print("### A self-supervised model is used for the feature extraction ###")
        print("##################################################################")
        self.selfsupervised = SelfSupModel(ckpt, torchaudio_ssl, torchaudio_cache, finetune, **selfsupervised)
        selfsupervised['name'] = self.selfsupervised.model_name
        selfsupervised['cfg'] = self.selfsupervised.config
        self.save_hyperparameters("selfsupervised")
        feat_size = self.selfsupervised.feat_size

        print("##################################################################\n")

        monolithic = lstm["monolithic"]
        if monolithic:
            multi_layer_lstm = dict(lstm)
            del multi_layer_lstm["monolithic"]
            self.lstm = nn.LSTM(feat_size, **multi_layer_lstm)

        else:
            num_layers = lstm["num_layers"]
            if num_layers > 1:
                self.dropout = nn.Dropout(p=lstm["dropout"])

            one_layer_lstm = dict(lstm)
            one_layer_lstm["num_layers"] = 1
            one_layer_lstm["dropout"] = 0.0
            del one_layer_lstm["monolithic"]

            self.lstm = nn.ModuleList(
                [
                    nn.LSTM(
                        feat_size
                        if i == 0
                        else lstm["hidden_size"] * (2 if lstm["bidirectional"] else 1),
                        **one_layer_lstm
                    )
                    for i in range(num_layers)
                ]
            )

        if linear["num_layers"] < 1:
            return

        lstm_out_features: int = self.hparams.lstm["hidden_size"] * (
            2 if self.hparams.lstm["bidirectional"] else 1
        )
        self.linear = nn.ModuleList(
            [
                nn.Linear(in_features, out_features)
                for in_features, out_features in pairwise(
                    [
                        lstm_out_features,
                    ]
                    + [self.hparams.linear["hidden_size"]]
                    * self.hparams.linear["num_layers"]
                )
            ]
        )

    def build(self):

        if self.hparams.linear["num_layers"] > 0:
            in_features = self.hparams.linear["hidden_size"]
        else:
            in_features = self.hparams.lstm["hidden_size"] * (
                2 if self.hparams.lstm["bidirectional"] else 1
            )

        if self.specifications.powerset:
            out_features = self.specifications.num_powerset_classes
        else:
            out_features = len(self.specifications.classes)

        self.classifier = nn.Linear(in_features, out_features)
        self.activation = self.default_activation()

    def forward(self, waveforms: torch.Tensor) -> torch.Tensor:
        """Pass forward

        Parameters
        ----------
        waveforms : (batch, channel, sample)

        Returns
        -------
        scores : (batch, frame, classes)
        """
        outputs = self.selfsupervised(waveforms)
        if self.hparams.lstm["monolithic"]:
            outputs, _ = self.lstm(outputs)
        else:
            for i, lstm in enumerate(self.lstm):
                outputs, _ = lstm(outputs)
                if i + 1 < self.hparams.lstm["num_layers"]:
                    outputs = self.dropout(outputs)

        if self.hparams.linear["num_layers"] > 0:
            for linear in self.linear:
                outputs = F.leaky_relu(linear(outputs))

        return self.activation(self.classifier(outputs))
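A hedged usage sketch of the new model (hypothetical, not from the PR): training it on a segmentation task might look roughly like the following, assuming a pyannote database protocol named "MyDatabase.SpeakerDiarization.MyProtocol" is registered and the SpeakerDiarization task is used.

import pytorch_lightning as pl
from pyannote.database import get_protocol
from pyannote.audio.tasks import SpeakerDiarization
from pyannote.audio.models.segmentation import PyanSup

protocol = get_protocol("MyDatabase.SpeakerDiarization.MyProtocol")
task = SpeakerDiarization(protocol, duration=5.0)

# WavLM Base from torchaudio as a frozen feature extractor, layer 6 features.
model = PyanSup(torchaudio_ssl="WAVLM_BASE", selfsupervised={"layer": 6}, task=task)

trainer = pl.Trainer(max_epochs=1)
trainer.fit(model)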
3 changes: 2 additions & 1 deletion pyannote/audio/models/segmentation/__init__.py
@@ -21,5 +21,6 @@
# SOFTWARE.

from .PyanNet import PyanNet
from .PyanSup import PyanSup

__all__ = ["PyanNet"]
__all__ = ["PyanNet","PyanSup"]