implement wavlm inside PyanNet class and add wavlm block
SevKod committed May 9, 2023
1 parent 6d3af2e commit d03906b
Showing 2 changed files with 71 additions and 63 deletions.
44 changes: 44 additions & 0 deletions pyannote/audio/models/blocks/wavlm.py
@@ -0,0 +1,44 @@
# MIT License
#
# Copyright (c) 2020 CNRS
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.


from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel

@hbredin (Member) · May 9, 2023

Did you also try the model from torchaudio?

That would remove this additional transformers dependency.

Other criteria to decide between torchaudio and transformers include:

  • is it possible to extract features from different layers (and not only at the end of the network)?
  • is it easy to fine-tune?

@SevKod (Author, Contributor) · May 15, 2023

> Did you also try the model from torchaudio?

I did. The main issue is that the WavLM model (specifically) from torchaudio is faulty only when using torch.no_grad(). I made a quick example here:

https://colab.research.google.com/drive/1Fk2UHFEb5Ae_QBQMxz29dRcGBNIcEATf?usp=sharing

This issue was fixed very recently (pytorch/audio#3219), but the fix is only available in the Preview (nightly) version of PyTorch (2.1.0). I managed to get the model working with Pyannote on that specific version.

> Other criteria to decide between torchaudio and transformers include:
>
> is it possible to extract features from different layers (and not only at the end of the network)?

It is indeed possible, with both libraries, to extract the features (hidden states) of a specific layer.

> is it easy to fine-tune?

I managed to follow the Pyannote fine-tuning tutorial without any noticeable issue for both models.

The main thing to note is that, overall, the torchaudio model computes the feature extraction a bit faster than the transformers one, which makes training faster. After comparing the results, both models show similar performance (mainly tested on a segmentation task).

I will commit a Torchaudio version...


class WavLM(nn.Module):
    def __init__(self):
        super().__init__()

        # load the pretrained checkpoint from the HuggingFace Hub
        self.wvlm = AutoModel.from_pretrained("microsoft/wavlm-base")

    def forward(self, waveforms: torch.Tensor) -> torch.Tensor:
        # waveforms : (batch, channel, sample) -> (batch, sample)
        waveforms = torch.squeeze(waveforms, 1)

        # compute the features; extract_features is the output of the
        # convolutional feature encoder : (batch, frame, 512)
        with torch.no_grad():
            outputs = self.wvlm(waveforms).extract_features

        return outputs
90 changes: 27 additions & 63 deletions pyannote/audio/models/segmentation/PyanNet.py
@@ -1,26 +1,3 @@

from typing import Optional

import torch
@@ -32,22 +9,9 @@
from pyannote.audio.core.model import Model
from pyannote.audio.core.task import Task
from pyannote.audio.models.blocks.sincnet import SincNet
from pyannote.audio.models.blocks.wavlm import WavLM
from pyannote.audio.utils.params import merge_dict


class PyanNet(Model):
"""PyanNet segmentation model
Expand Down Expand Up @@ -76,7 +40,6 @@ class PyanNet(Model):
"""

SINCNET_DEFAULTS = {"stride": 10}

LSTM_DEFAULTS = {
"hidden_size": 128,
"num_layers": 2,
@@ -88,6 +51,7 @@ class PyanNet(Model):

    def __init__(
        self,
        model: str = None,
        sincnet: dict = None,
        lstm: dict = None,
        linear: dict = None,
@@ -104,15 +68,21 @@ def __init__(
lstm["batch_first"] = True
linear = merge_dict(self.LINEAR_DEFAULTS, linear)
self.save_hyperparameters("sincnet", "lstm", "linear")
self.model = model

if model == "wavlm":
self.wavlm = WavLM()
feat_size = 512
else :
self.sincnet = SincNet(**self.hparams.sincnet)
feat_size = 60

self.sincnet = SincNet(**self.hparams.sincnet)


monolithic = lstm["monolithic"]
if monolithic:
multi_layer_lstm = dict(lstm)
del multi_layer_lstm["monolithic"]
self.lstm = nn.LSTM(512, **multi_layer_lstm)
self.lstm = nn.LSTM(feat_size, **multi_layer_lstm)

        else:
            num_layers = lstm["num_layers"]
            if num_layers > 1:
@@ -126,7 +96,7 @@ def __init__(
            self.lstm = nn.ModuleList(
                [
                    nn.LSTM(
                        feat_size
                        if i == 0
                        else lstm["hidden_size"] * (2 if lstm["bidirectional"] else 1),
                        **one_layer_lstm
@@ -182,34 +152,28 @@ def forward(self, waveforms: torch.Tensor) -> torch.Tensor:
        -------
        scores : (batch, frame, classes)
        """

        if self.model == "wavlm":
            outputs = self.wavlm(waveforms)
        else:
            outputs = self.sincnet(waveforms)

        if self.hparams.lstm["monolithic"]:
            if self.model == "wavlm":
                # WavLM features are already structured as
                # (batch, frame, feature), so no rearrange is needed
                outputs, _ = self.lstm(outputs)
            else:
                outputs, _ = self.lstm(
                    rearrange(outputs, "batch feature frame -> batch frame feature")
                )
        else:
            if self.model != "wavlm":
                outputs = rearrange(outputs, "batch feature frame -> batch frame feature")
            for i, lstm in enumerate(self.lstm):
                outputs, _ = lstm(outputs)
                if i + 1 < self.hparams.lstm["num_layers"]:
                    outputs = self.dropout(outputs)

        if self.hparams.linear["num_layers"] > 0:
            for linear in self.linear:
                outputs = F.leaky_relu(linear(outputs))

        return self.activation(self.classifier(outputs))
