-
-
Notifications
You must be signed in to change notification settings - Fork 758
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat(model) : add segmentation model based on self-supervised representation #1362
Changes from 20 commits
6d3af2e
d03906b
3fc2d37
1e370fc
e170eed
9f81c30
e5330fc
7a21fc9
6243f91
328505c
d4ddd53
63a9e42
cbd01a3
f608eb7
d7e9203
81aafdd
b9c89b6
8aba20e
4a8bfe2
2323105
cedf042
06641bf
31d08a4
421ba03
5f9211c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,181 @@ | ||
# MIT License | ||
# | ||
# Copyright (c) 2020 CNRS | ||
# | ||
# Permission is hereby granted, free of charge, to any person obtaining a copy | ||
# of this software and associated documentation files (the "Software"), to deal | ||
# in the Software without restriction, including without limitation the rights | ||
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | ||
# copies of the Software, and to permit persons to whom the Software is | ||
# furnished to do so, subject to the following conditions: | ||
# | ||
# The above copyright notice and this permission notice shall be included in all | ||
# copies or substantial portions of the Software. | ||
# | ||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | ||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | ||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | ||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | ||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | ||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE | ||
# SOFTWARE. | ||
|
||
import sys | ||
import re | ||
from typing import Optional | ||
import torch | ||
import torchaudio | ||
import torch.nn as nn | ||
from torch.nn.functional import normalize | ||
import torch.nn.functional as F | ||
from collections import OrderedDict | ||
from torchaudio.models.wav2vec2 import wav2vec2_model, wavlm_model | ||
from torchaudio.pipelines import Wav2Vec2Bundle | ||
|
||
class SelfSupModel(nn.Module): | ||
|
||
def __init__(self,checkpoint=None,torchaudio_ssl=None,torchaudio_cache=None,finetune=None,average_layers=None,average_all=None,name=None,layer=None,cfg=None): | ||
super().__init__() | ||
if torchaudio_ssl: | ||
if checkpoint: | ||
raise ValueError("Error : Cannot specify both a checkpoint and a torchaudio model.") | ||
|
||
print("\nThe Self-Supervised Model "+str(torchaudio_ssl)+" is loaded from torchaudio.\n") | ||
name,config,ordered_dict = self.dict_torchaudio(torchaudio_ssl,torchaudio_cache) | ||
else: | ||
print("A checkpoint from a Self-Supervised Model is used for training.") | ||
if torch.cuda.is_available(): | ||
ckpt = torch.load(checkpoint) | ||
else: | ||
ckpt = torch.load(checkpoint,map_location=torch.device('cpu')) | ||
#Check if the checkpoint is from an already finetuned Diarization model (containing SSL), or from a SSL pretrained model only | ||
if 'pyannote.audio' in ckpt: #1: Check if there is a Segmentation model attached onto or not | ||
print("The checkpoint is used for finetuning. \nThe attached SSL model will be used for feature extraction.") | ||
name,config,ordered_dict = self.dict_finetune(ckpt) | ||
|
||
else: #Otherwise, load the dictionary of the SSL checkpoint | ||
print("The checkpoint is a pretrained SSL model to use for Segmentation.\nBuilding the SSL model.") | ||
name,config,ordered_dict = self.dict_pretrained(ckpt) | ||
|
||
# Layer-wise pooling (same way as SUPERB) | ||
if not average_all: | ||
if not average_layers : | ||
if layer is None : | ||
print("\nLayer number not specified. Default to layer 1.\n") | ||
|
||
self.layer = 1 | ||
else : | ||
|
||
self.layer = layer | ||
print("\nSelected layer is "+ str(layer) +". \n") | ||
else: | ||
print("Layers "+str(average_layers)+" selected for layer-wise pooling.") | ||
|
||
self.W = nn.Parameter(torch.randn(len(average_layers))) #Set specific number of learnable weights | ||
|
||
self.average_layers = average_layers | ||
|
||
self.layer = max(average_layers) | ||
else: | ||
print("All layers are selected for layer-wise pooling.") | ||
|
||
self.W = nn.Parameter(torch.randn(config['encoder_num_layers'])) #Set max number of learnable weights | ||
|
||
self.average_layers = list(range(config['encoder_num_layers'])) | ||
|
||
self.layer = config['encoder_num_layers'] | ||
|
||
if finetune: #Finetuning not working | ||
print("Self-supervised model is unfrozen.") | ||
#config['encoder_ff_interm_dropout'] = 0.3 | ||
config['encoder_layer_norm_first'] = True | ||
else : | ||
print("Self-supervised model is frozen.") | ||
|
||
config['encoder_num_layers'] = self.layer | ||
ordered_dict = self.remove_layers_dict(ordered_dict,self.layer) #Remove weights from unused transformer encoders | ||
self.model_name = name | ||
self.finetune = finetune #Assign mode | ||
self.average_layers = average_layers | ||
self.feat_size = config['encoder_embed_dim'] #Get feature output dimension | ||
self.config = config #Assign the configuration | ||
SelfSupModel.__name__ = self.model_name #Assign name of the class | ||
|
||
if name is "WAVLM_BASE" or name is "WAVLM_LARGE": #Only wavlm_model has two additional arguments | ||
model = wavlm_model(**config) | ||
else: | ||
model = wav2vec2_model(**config) | ||
model.load_state_dict(ordered_dict) #Assign state dict to the model | ||
|
||
if finetune: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Peux-tu m'expliquer pourquoi cela est nécessaire ? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Je parle du passage au mode There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. J'avais cru comprendre que le .eval() était pertinent lorsque certains modules tel que des couches de Dropout, sont présentes dans le modèle en question que l'on souhaite passer en mode inférence (le cas pour WavLM). J'avais lu ce post qui conseillait l'utilisation des deux : Après, de souvenirs, je n'avais pas identifié de quelconque changement au niveau des features entre le ".eval()" et le "no_grad". Apparemment, le .eval() consommerait plus de mémoire que le no_grad aussi... Donc, pourquoi pas l'enlever. C'était aussi pour étudier les changements entre le .eval() et le .train() quand je voulais voir ce qui se passait au niveau du finetuning. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
En résumé, il ne faut pas utiliser ni Le passage en mode |
||
self.ssl_model = model.train() | ||
else: | ||
self.ssl_model = model.eval() | ||
|
||
|
||
def dict_finetune(self, ckpt): | ||
#Need to reconstruct the dictionary | ||
#Get dict | ||
dict_modules = list(ckpt['state_dict'].keys()) #Get the list of ssl modules | ||
ssl_modules = [key for key in dict_modules if 'selfsupervised' in key] #Extract only the SSL parts | ||
weights = [ckpt['state_dict'][key] for key in ssl_modules] #Get the weights corresponding to the modules | ||
modules_torchaudio = ['.'.join(key.split('.')[2:]) for key in ssl_modules] #Get a new list which contains only torchaudio keywords | ||
ordered_dict = OrderedDict((key,weight) for key,weight in zip(modules_torchaudio,weights)) #Recreate the state_dict | ||
config = ckpt['hyper_parameters']['selfsupervised']['cfg'] #Get config | ||
name = ckpt['hyper_parameters']['selfsupervised']['name'] #Get model name | ||
|
||
return(name,config,ordered_dict) | ||
|
||
def dict_pretrained(self, ckpt): | ||
ordered_dict = ckpt['state_dict'] #Get dict | ||
config = ckpt['config'] #Get config | ||
name = ckpt['model_name'] #Get model name | ||
|
||
return(ckpt['model_name'],ckpt['config'],ckpt['state_dict']) | ||
|
||
def dict_torchaudio(self,torchaudio_ssl,torchaudio_cache): | ||
bundle = getattr(torchaudio.pipelines, torchaudio_ssl) | ||
#Name is torchaudio_ssl | ||
name = torchaudio_ssl #Get name | ||
config = bundle._params #Get config | ||
if torchaudio_cache: | ||
torch.hub.set_dir(torchaudio_cache) #Set cache | ||
ordered_dict = bundle.get_model().state_dict() #Get the dict | ||
|
||
return(name,config,ordered_dict) | ||
def remove_layers_dict(self,state_dict,layer): | ||
keys_to_delete = [] | ||
for key in state_dict.keys(): | ||
if "transformer.layers" in key: | ||
nb = int(re.findall(r'\d+',key)[0]) | ||
if nb>(layer-1): | ||
keys_to_delete.append(key) | ||
for key in keys_to_delete: | ||
del state_dict[key] | ||
|
||
return(state_dict) | ||
|
||
def avg_pool(self,scalars,feat_list): | ||
sum = 0 | ||
for i in range(0,len(feat_list)): | ||
sum = sum + scalars[i]*feat_list[i] | ||
return(sum) | ||
|
||
def forward(self, waveforms: torch.Tensor) -> torch.Tensor: | ||
waveforms = torch.squeeze(waveforms,1) #waveforms : (batch, channel, sample) -> (batch,sample) | ||
if self.finetune: | ||
feat,_ = self.ssl_model.extract_features(waveforms,None,self.layer) | ||
else: | ||
with torch.no_grad(): | ||
feat,_ = self.ssl_model.extract_features(waveforms,None,self.layer) | ||
if self.average_layers: | ||
feat_learn_list = [] | ||
for index in self.average_layers: | ||
feat_learn_list.append(feat[index-1]) | ||
w = self.W.softmax(-1) | ||
outputs = self.avg_pool(w,feat_learn_list) | ||
#print(w) | ||
#print(outputs.size()) | ||
else: | ||
outputs = feat[self.layer-1] | ||
return (outputs) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,183 @@ | ||
from typing import Optional | ||
|
||
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
from einops import rearrange | ||
from pyannote.core.utils.generators import pairwise | ||
|
||
from pyannote.audio.core.model import Model | ||
from pyannote.audio.core.task import Task | ||
from pyannote.audio.models.blocks.selfsup import SelfSupModel | ||
from pyannote.audio.utils.params import merge_dict | ||
|
||
|
||
class PyanSup(Model): | ||
"""PyanHugg segmentation model | ||
|
||
Self-Supervised Model > LSTM > Feed forward > Classifier | ||
|
||
Parameters | ||
---------- | ||
sample_rate : int, optional | ||
Audio sample rate. Defaults to 16kHz (16000). | ||
num_channels : int, optional | ||
Number of channels. Defaults to mono (1). | ||
|
||
selfsupervised : dict, optional | ||
Keyword arugments passed to the selfsupervised block. name and cfg are used to reconstruct the feature extractor from the dictionary of the checkpoint. Layer corresponds to the layer that serves for the feature extraction. | ||
Defaults to { | ||
"name": None, | ||
"layer": None, | ||
"cfg": None, | ||
} | ||
lstm : dict, optional | ||
Keyword arguments passed to the LSTM layer. | ||
Defaults to {"hidden_size": 128, "num_layers": 2, "bidirectional": True}, | ||
i.e. two bidirectional layers with 128 units each. | ||
Set "monolithic" to False to split monolithic multi-layer LSTM into multiple mono-layer LSTMs. | ||
This may proove useful for probing LSTM internals. | ||
linear : dict, optional | ||
Keyword arugments used to initialize linear layers | ||
Defaults to {"hidden_size": 128, "num_layers": 2}, | ||
i.e. two linear layers with 128 units each. | ||
""" | ||
|
||
|
||
|
||
SSL_DEFAULTS = { | ||
"name": None, | ||
"layer": None, | ||
"average_layers": None, | ||
"average_all": False, | ||
"cfg": None, | ||
} | ||
LSTM_DEFAULTS = { | ||
"hidden_size": 128, | ||
"num_layers": 2, | ||
"bidirectional": True, | ||
"monolithic": True, | ||
"dropout": 0.0, | ||
} | ||
LINEAR_DEFAULTS = {"hidden_size": 128, "num_layers": 2} | ||
def __init__( | ||
self, | ||
ckpt: str = None, | ||
torchaudio_ssl: str = None, | ||
torchaudio_cache: str = None, | ||
finetune: bool = False, | ||
selfsupervised: dict = None, | ||
lstm: dict = None, | ||
linear: dict = None, | ||
sample_rate: int = 16000, | ||
num_channels: int = 1, | ||
task: Optional[Task] = None, | ||
): | ||
|
||
super().__init__(sample_rate=sample_rate, num_channels=num_channels, task=task) | ||
|
||
selfsupervised = merge_dict(self.SSL_DEFAULTS, selfsupervised) | ||
lstm = merge_dict(self.LSTM_DEFAULTS, lstm) | ||
lstm["batch_first"] = True | ||
linear = merge_dict(self.LINEAR_DEFAULTS, linear) | ||
self.save_hyperparameters("lstm", "linear") #A first merge is done using the default parameters specified | ||
|
||
print("\n##################################################################") | ||
print("### A self-supervised model is used for the feature extraction ###") | ||
print("##################################################################") | ||
self.selfsupervised = SelfSupModel(ckpt,torchaudio_ssl,torchaudio_cache,finetune,**selfsupervised) | ||
selfsupervised['name'] = self.selfsupervised.model_name | ||
selfsupervised['cfg'] = self.selfsupervised.config | ||
self.save_hyperparameters("selfsupervised") | ||
feat_size = self.selfsupervised.feat_size | ||
|
||
print("##################################################################\n") | ||
|
||
monolithic = lstm["monolithic"] | ||
if monolithic: | ||
multi_layer_lstm = dict(lstm) | ||
del multi_layer_lstm["monolithic"] | ||
self.lstm = nn.LSTM(feat_size, **multi_layer_lstm) | ||
|
||
else: | ||
num_layers = lstm["num_layers"] | ||
if num_layers > 1: | ||
self.dropout = nn.Dropout(p=lstm["dropout"]) | ||
|
||
one_layer_lstm = dict(lstm) | ||
one_layer_lstm["num_layers"] = 1 | ||
one_layer_lstm["dropout"] = 0.0 | ||
del one_layer_lstm["monolithic"] | ||
|
||
self.lstm = nn.ModuleList( | ||
[ | ||
nn.LSTM( | ||
feat_size | ||
if i == 0 | ||
else lstm["hidden_size"] * (2 if lstm["bidirectional"] else 1), | ||
**one_layer_lstm | ||
) | ||
for i in range(num_layers) | ||
] | ||
) | ||
|
||
if linear["num_layers"] < 1: | ||
return | ||
|
||
lstm_out_features: int = self.hparams.lstm["hidden_size"] * ( | ||
2 if self.hparams.lstm["bidirectional"] else 1 | ||
) | ||
self.linear = nn.ModuleList( | ||
[ | ||
nn.Linear(in_features, out_features) | ||
for in_features, out_features in pairwise( | ||
[ | ||
lstm_out_features, | ||
] | ||
+ [self.hparams.linear["hidden_size"]] | ||
* self.hparams.linear["num_layers"] | ||
) | ||
] | ||
) | ||
def build(self): | ||
|
||
if self.hparams.linear["num_layers"] > 0: | ||
in_features = self.hparams.linear["hidden_size"] | ||
else: | ||
in_features = self.hparams.lstm["hidden_size"] * ( | ||
2 if self.hparams.lstm["bidirectional"] else 1 | ||
) | ||
|
||
if self.specifications.powerset: | ||
out_features = self.specifications.num_powerset_classes | ||
else: | ||
out_features = len(self.specifications.classes) | ||
|
||
self.classifier = nn.Linear(in_features, out_features) | ||
self.activation = self.default_activation() | ||
|
||
def forward(self, waveforms: torch.Tensor) -> torch.Tensor: | ||
"""Pass forward | ||
|
||
Parameters | ||
---------- | ||
waveforms : (batch, channel, sample) | ||
|
||
Returns | ||
------- | ||
scores : (batch, frame, classes) | ||
""" | ||
outputs = self.selfsupervised(waveforms) | ||
if self.hparams.lstm["monolithic"]: | ||
outputs, _ = self.lstm(outputs) | ||
else: | ||
for i, lstm in enumerate(self.lstm): | ||
outputs, _ = lstm(outputs) | ||
if i + 1 < self.hparams.lstm["num_layers"]: | ||
outputs = self.dropout(outputs) | ||
|
||
if self.hparams.linear["num_layers"] > 0: | ||
for linear in self.linear: | ||
outputs = F.leaky_relu(linear(outputs)) | ||
|
||
return self.activation(self.classifier(outputs)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you explain?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The issue regarding the funetuning of WavLM seemed similar to a normalization issue that occured during the feature extraction process. If gradient is computed during training, validation will extract feature vectors that are almost identical amongst each frames of the input audio. I assume that it might be the reason why validation does not seem to improve (or change) during training (but I might be completely wrong on this...). Since this problem seemed similar to the one I encountered with WavLM from back a few months (and has been fixed), where features were the also the same amongst the frames, I tried applying a normalization step to see if the behavior of the features extracted would change. Did not seem to be the case... It is one of the many things I tried but forgot to remove when pushing codes ^^