add WavLM-Base model to PyanNet.py as a replacement for SincNet
Added the WavLM-Base model, which replaces the SincNet feature extraction
model within the PyanNet architecture (the model is loaded from
HuggingFace.co, outside of the class).
SevKod committed May 5, 2023
1 parent 11b56a1 commit 6d3af2e
Showing 1 changed file with 39 additions and 12 deletions.
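
For context, a minimal standalone sketch (not part of the diff) of what the replacement relies on: WavLM-Base returns both convolutional features of width 512, which this commit feeds to the LSTM, and transformer hidden states of width 768. Shapes assume the `transformers` WavLM implementation.

```python
# Sketch: inspect WavLM-Base outputs to see where the new LSTM input size comes from.
import torch
from transformers import AutoModel

wavlm = AutoModel.from_pretrained("microsoft/wavlm-base")
wavlm.eval()

waveforms = torch.randn(2, 32000)  # (batch, sample): 2 seconds of 16 kHz audio
with torch.no_grad():
    out = wavlm(waveforms)

print(out.extract_features.shape)   # (2, frames, 512) -- CNN encoder features (used here)
print(out.last_hidden_state.shape)  # (2, frames, 768) -- last transformer layer
```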
51 changes: 39 additions & 12 deletions pyannote/audio/models/segmentation/PyanNet.py
@@ -34,6 +34,20 @@
from pyannote.audio.models.blocks.sincnet import SincNet
from pyannote.audio.utils.params import merge_dict

+## WAVLM_BASE
+# The PyanNet model must be moved to CUDA in the training script.
+
+# The model is loaded outside of the PyanNet class.
+
+from transformers import AutoModel
+
+# Loading the model from HuggingFace (requires git lfs to fetch the .bin checkpoint):
+# model = AutoModel.from_pretrained('/content/drive/MyDrive/PyanNet/wavlm-base')
+
+model = AutoModel.from_pretrained('microsoft/wavlm-base')
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+model.to(device)  # Move the model to the GPU (assuming accelerator="gpu" in the PyTorch Lightning Trainer)

class PyanNet(Model):
"""PyanNet segmentation model
@@ -62,6 +76,7 @@ class PyanNet(Model):
"""

SINCNET_DEFAULTS = {"stride": 10}

LSTM_DEFAULTS = {
"hidden_size": 128,
"num_layers": 2,
@@ -91,13 +106,13 @@ def __init__(
        self.save_hyperparameters("sincnet", "lstm", "linear")

        self.sincnet = SincNet(**self.hparams.sincnet)

        monolithic = lstm["monolithic"]
        if monolithic:
            multi_layer_lstm = dict(lstm)
            del multi_layer_lstm["monolithic"]
-            self.lstm = nn.LSTM(60, **multi_layer_lstm)
+            self.lstm = nn.LSTM(512, **multi_layer_lstm)
        else:
            num_layers = lstm["num_layers"]
            if num_layers > 1:
@@ -111,7 +126,7 @@ def __init__(
            self.lstm = nn.ModuleList(
                [
                    nn.LSTM(
-                        60
+                        512
                        if i == 0
                        else lstm["hidden_size"] * (2 if lstm["bidirectional"] else 1),
                        **one_layer_lstm
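
The hard-coded 512 above corresponds to the width of WavLM-Base's last convolutional layer, i.e. the last dimension of `extract_features`. A quick standalone check, assuming the `transformers` config API (this snippet is not in the commit):

```python
# Sanity check: the LSTM input size matches the checkpoint's conv feature width.
from transformers import AutoConfig

config = AutoConfig.from_pretrained("microsoft/wavlm-base")
assert config.conv_dim[-1] == 512  # width of extract_features for wavlm-base
```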
@@ -167,22 +182,34 @@ def forward(self, waveforms: torch.Tensor) -> torch.Tensor:
        -------
        scores : (batch, frame, classes)
        """

-        outputs = self.sincnet(waveforms)
+        # outputs = self.sincnet(waveforms)
+
+        # WavLM feature extraction
+        waveforms = torch.squeeze(waveforms, 1)  # (batch, channel, sample) -> (batch, sample)
+        with torch.no_grad():
+            feat = model(waveforms)  # Forward pass through the (frozen) WavLM model
+
+        outputs = feat.extract_features  # Convolutional features: (batch, frame, feature)

        if self.hparams.lstm["monolithic"]:
-            outputs, _ = self.lstm(
-                rearrange(outputs, "batch feature frame -> batch frame feature")
-            )
+            # No need to rearrange the output: the features are already (batch, frame, feature)
+            # outputs, _ = self.lstm(
+            #     rearrange(outputs, "batch feature frame -> batch frame feature"))
+            outputs, _ = self.lstm(outputs)
        else:
-            outputs = rearrange(outputs, "batch feature frame -> batch frame feature")
+            # outputs = rearrange(outputs, "batch feature frame -> batch frame feature").cuda()
            for i, lstm in enumerate(self.lstm):
                outputs, _ = lstm(outputs)
                if i + 1 < self.hparams.lstm["num_layers"]:
                    outputs = self.dropout(outputs)

        if self.hparams.linear["num_layers"] > 0:
            for linear in self.linear:
                outputs = F.leaky_relu(linear(outputs))

        return self.activation(self.classifier(outputs))
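
End to end, the new feature path in forward() can be exercised in isolation. A hedged sketch, assuming a monolithic, bidirectional, batch-first LSTM (as in PyanNet's defaults); the LSTM here is a stand-in built by hand, not the module from the class:

```python
# Sketch: dummy batch through the WavLM -> LSTM path that forward() now uses.
import torch
import torch.nn as nn
from transformers import AutoModel

model = AutoModel.from_pretrained("microsoft/wavlm-base")
model.eval()
lstm = nn.LSTM(512, hidden_size=128, num_layers=2,
               bidirectional=True, batch_first=True)

waveforms = torch.randn(2, 1, 32000)     # (batch, channel, sample)
waveforms = torch.squeeze(waveforms, 1)  # -> (batch, sample), as in forward()
with torch.no_grad():
    feat = model(waveforms)              # WavLM stays frozen (no gradients)

outputs = feat.extract_features          # (batch, frame, 512), already batch-first
outputs, _ = lstm(outputs)               # -> (batch, frame, 256) for a bidirectional LSTM
print(outputs.shape)
```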
