
Commit b9548a7

feat(model): add segmentation model based on self-supervised representation (#1362)

Co-authored-by: Hervé BREDIN <[email protected]>
SevKod and hbredin committed Sep 20, 2023
1 parent 6740db2 commit b9548a7
Showing 4 changed files with 255 additions and 6 deletions.
9 changes: 5 additions & 4 deletions CHANGELOG.md
@@ -32,15 +32,16 @@

### Features and improvements

- feat(task): add [powerset](https://www.isca-speech.org/archive/interspeech_2023/plaquet23_interspeech.html) support to `SpeakerDiarization` task
- feat(task): add support for multi-task models
- feat(task): add support for label scope in speaker diarization task
- feat(task): add support for missing classes in multi-label segmentation task
- feat(model): add segmentation model based on torchaudio self-supervised representation
- feat(pipeline): send pipeline to device with `pipeline.to(device)`
- feat(pipeline): make `segmentation_batch_size` and `embedding_batch_size` mutable in `SpeakerDiarization` pipeline (they now default to `1`)
- feat(pipeline): add `return_embeddings` option to `SpeakerDiarization` pipeline (see the usage sketch after this list)
- feat(pipeline): add progress hook to pipelines
- feat(pipeline): check version compatibility at load time
- improve(task): load metadata as tensors rather than pyannote.core instances
- improve(task): improve error message on missing specifications
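
As a quick, illustrative sketch of how the pipeline-level entries above fit together (the checkpoint name and audio path are placeholders, not part of this commit):

```python
import torch
from pyannote.audio import Pipeline

# load a pretrained speaker diarization pipeline (name is illustrative)
pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization")

# new: send the whole pipeline to a device with pipeline.to(device)
pipeline.to(torch.device("cuda"))

# new: batch sizes are now mutable attributes (they default to 1)
pipeline.segmentation_batch_size = 32
pipeline.embedding_batch_size = 32

# new: optionally return one embedding per speaker along with the diarization
diarization, embeddings = pipeline("audio.wav", return_embeddings=True)
```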

13 changes: 13 additions & 0 deletions pyannote/audio/cli/train_config/model/SSeRiouSS.yaml
@@ -0,0 +1,13 @@
# @package _group_
_target_: pyannote.audio.models.segmentation.SSeRiouSS
wav2vec: WAVLM_BASE
wav2vec_layer: -1
lstm:
  hidden_size: 128
  num_layers: 4
  bidirectional: true
  monolithic: true
  dropout: 0.5
linear:
  hidden_size: 128
  num_layers: 2
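
This Hydra config selects the new model for training: `_target_` names the class to build, and the remaining keys become its constructor arguments. A minimal sketch of the equivalent instantiation outside of any training run (assuming hydra-core and omegaconf are installed):

```python
from hydra.utils import instantiate
from omegaconf import OmegaConf

# same content as the YAML above, minus the @package directive
cfg = OmegaConf.create(
    {
        "_target_": "pyannote.audio.models.segmentation.SSeRiouSS",
        "wav2vec": "WAVLM_BASE",
        "wav2vec_layer": -1,
        "lstm": {
            "hidden_size": 128,
            "num_layers": 4,
            "bidirectional": True,
            "monolithic": True,
            "dropout": 0.5,
        },
        "linear": {"hidden_size": 128, "num_layers": 2},
    }
)

# equivalent to calling SSeRiouSS(wav2vec="WAVLM_BASE", wav2vec_layer=-1, ...)
model = instantiate(cfg)
```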
234 changes: 234 additions & 0 deletions pyannote/audio/models/segmentation/SSeRiouSS.py
@@ -0,0 +1,234 @@
# MIT License
#
# Copyright (c) 2023- CNRS
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.


from typing import Optional, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
from pyannote.core.utils.generators import pairwise

from pyannote.audio.core.model import Model
from pyannote.audio.core.task import Task
from pyannote.audio.utils.params import merge_dict


class SSeRiouSS(Model):
    """Self-Supervised Representation for Speaker Segmentation

    wav2vec > LSTM > Feed forward > Classifier

    Parameters
    ----------
    sample_rate : int, optional
        Audio sample rate. Defaults to 16kHz (16000).
    num_channels : int, optional
        Number of channels. Defaults to mono (1).
    wav2vec : dict or str, optional
        Defaults to "WAVLM_BASE".
    wav2vec_layer : int, optional
        Index of layer to use as input to the LSTM.
        Defaults (-1) to use average of all layers (with learnable weights).
    lstm : dict, optional
        Keyword arguments passed to the LSTM layer.
        Defaults to {"hidden_size": 128, "num_layers": 4, "bidirectional": True},
        i.e. four bidirectional layers with 128 units each.
        Set "monolithic" to False to split monolithic multi-layer LSTM into multiple mono-layer LSTMs.
        This may prove useful for probing LSTM internals.
    linear : dict, optional
        Keyword arguments used to initialize linear layers.
        Defaults to {"hidden_size": 128, "num_layers": 2},
        i.e. two linear layers with 128 units each.
    """

    WAV2VEC_DEFAULTS = "WAVLM_BASE"

    LSTM_DEFAULTS = {
        "hidden_size": 128,
        "num_layers": 4,
        "bidirectional": True,
        "monolithic": True,
        "dropout": 0.0,
    }
    LINEAR_DEFAULTS = {"hidden_size": 128, "num_layers": 2}

    def __init__(
        self,
        wav2vec: Union[dict, str] = None,
        wav2vec_layer: int = -1,
        lstm: dict = None,
        linear: dict = None,
        sample_rate: int = 16000,
        num_channels: int = 1,
        task: Optional[Task] = None,
    ):
        super().__init__(sample_rate=sample_rate, num_channels=num_channels, task=task)

        # fall back to WAV2VEC_DEFAULTS ("WAVLM_BASE") so that the default
        # advertised in the docstring actually works when `wav2vec` is omitted
        if wav2vec is None:
            wav2vec = self.WAV2VEC_DEFAULTS

        if isinstance(wav2vec, str):
            # `wav2vec` is one of the supported pipelines from torchaudio (e.g. "WAVLM_BASE")
            if hasattr(torchaudio.pipelines, wav2vec):
                bundle = getattr(torchaudio.pipelines, wav2vec)
                if sample_rate != bundle._sample_rate:
                    raise ValueError(
                        f"Expected {bundle._sample_rate}Hz, found {sample_rate}Hz."
                    )
                wav2vec_dim = bundle._params["encoder_embed_dim"]
                wav2vec_num_layers = bundle._params["encoder_num_layers"]
                self.wav2vec = bundle.get_model()

            # `wav2vec` is a path to a self-supervised representation checkpoint
            else:
                _checkpoint = torch.load(wav2vec)
                wav2vec = _checkpoint.pop("config")
                self.wav2vec = torchaudio.models.wav2vec2_model(**wav2vec)
                state_dict = _checkpoint.pop("state_dict")
                self.wav2vec.load_state_dict(state_dict)
                wav2vec_dim = wav2vec["encoder_embed_dim"]
                wav2vec_num_layers = wav2vec["encoder_num_layers"]

        # `wav2vec` is a config dictionary understood by `wav2vec2_model`
        # this branch is typically used by Model.from_pretrained(...)
        elif isinstance(wav2vec, dict):
            self.wav2vec = torchaudio.models.wav2vec2_model(**wav2vec)
            wav2vec_dim = wav2vec["encoder_embed_dim"]
            wav2vec_num_layers = wav2vec["encoder_num_layers"]

        if wav2vec_layer < 0:
            # learnable weights for averaging all wav2vec layers
            self.wav2vec_weights = nn.Parameter(
                data=torch.ones(wav2vec_num_layers), requires_grad=True
            )

        lstm = merge_dict(self.LSTM_DEFAULTS, lstm)
        lstm["batch_first"] = True
        linear = merge_dict(self.LINEAR_DEFAULTS, linear)

        self.save_hyperparameters("wav2vec", "wav2vec_layer", "lstm", "linear")

        monolithic = lstm["monolithic"]
        if monolithic:
            multi_layer_lstm = dict(lstm)
            del multi_layer_lstm["monolithic"]
            self.lstm = nn.LSTM(wav2vec_dim, **multi_layer_lstm)

        else:
            num_layers = lstm["num_layers"]
            if num_layers > 1:
                self.dropout = nn.Dropout(p=lstm["dropout"])

            one_layer_lstm = dict(lstm)
            one_layer_lstm["num_layers"] = 1
            one_layer_lstm["dropout"] = 0.0
            del one_layer_lstm["monolithic"]

            # stack of mono-layer LSTMs: the first consumes wav2vec features,
            # subsequent ones consume the (possibly bidirectional) hidden states
            self.lstm = nn.ModuleList(
                [
                    nn.LSTM(
                        wav2vec_dim
                        if i == 0
                        else lstm["hidden_size"] * (2 if lstm["bidirectional"] else 1),
                        **one_layer_lstm,
                    )
                    for i in range(num_layers)
                ]
            )

        if linear["num_layers"] < 1:
            return

        lstm_out_features: int = self.hparams.lstm["hidden_size"] * (
            2 if self.hparams.lstm["bidirectional"] else 1
        )
        # feed-forward stack: chains lstm_out_features -> hidden_size -> ... -> hidden_size
        self.linear = nn.ModuleList(
            [
                nn.Linear(in_features, out_features)
                for in_features, out_features in pairwise(
                    [
                        lstm_out_features,
                    ]
                    + [self.hparams.linear["hidden_size"]]
                    * self.hparams.linear["num_layers"]
                )
            ]
        )

    def build(self):
        if self.hparams.linear["num_layers"] > 0:
            in_features = self.hparams.linear["hidden_size"]
        else:
            in_features = self.hparams.lstm["hidden_size"] * (
                2 if self.hparams.lstm["bidirectional"] else 1
            )

        if isinstance(self.specifications, tuple):
            raise ValueError("SSeRiouSS model does not support multi-tasking.")

        if self.specifications.powerset:
            # powerset encoding: one class per subset of speakers
            out_features = self.specifications.num_powerset_classes
        else:
            out_features = len(self.specifications.classes)

        self.classifier = nn.Linear(in_features, out_features)
        self.activation = self.default_activation()

    def forward(self, waveforms: torch.Tensor) -> torch.Tensor:
        """Pass forward

        Parameters
        ----------
        waveforms : (batch, channel, sample)

        Returns
        -------
        scores : (batch, frame, classes)
        """

        num_layers = (
            None if self.hparams.wav2vec_layer < 0 else self.hparams.wav2vec_layer
        )

        # extract SSL features without gradient tracking:
        # the wav2vec backbone stays frozen during training
        with torch.no_grad():
            outputs, _ = self.wav2vec.extract_features(
                waveforms.squeeze(1), num_layers=num_layers
            )

        if num_layers is None:
            # combine all layers with softmax-normalized learnable weights
            outputs = torch.stack(outputs, dim=-1) @ F.softmax(
                self.wav2vec_weights, dim=0
            )
        else:
            # keep the output of the requested layer only
            outputs = outputs[-1]

        if self.hparams.lstm["monolithic"]:
            outputs, _ = self.lstm(outputs)
        else:
            for i, lstm in enumerate(self.lstm):
                outputs, _ = lstm(outputs)
                if i + 1 < self.hparams.lstm["num_layers"]:
                    outputs = self.dropout(outputs)

        if self.hparams.linear["num_layers"] > 0:
            for linear in self.linear:
                outputs = F.leaky_relu(linear(outputs))

        return self.activation(self.classifier(outputs))
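
To sanity-check the new model, here is a minimal, hypothetical usage sketch (dummy audio, untrained weights). Note that the classification head is only created when a task calls `build()`, so a bare model cannot run a full forward pass yet; the sketch probes the frozen SSL front-end that `forward()` relies on:

```python
import torch
from pyannote.audio.models.segmentation import SSeRiouSS

# instantiate with the torchaudio WavLM Base bundle (downloaded on first use)
model = SSeRiouSS(wav2vec="WAVLM_BASE")

# 10 seconds of dummy mono 16kHz audio: (batch, channel, sample)
waveforms = torch.randn(2, 1, 160_000)

with torch.no_grad():
    features, _ = model.wav2vec.extract_features(waveforms.squeeze(1))

# WAVLM_BASE has 12 transformer layers with 768-dimensional outputs
print(len(features), features[0].shape)
```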
5 changes: 3 additions & 2 deletions pyannote/audio/models/segmentation/__init__.py
@@ -1,6 +1,6 @@
# MIT License
#
# Copyright (c) 2020- CNRS
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -21,5 +21,6 @@
# SOFTWARE.

from .PyanNet import PyanNet
from .SSeRiouSS import SSeRiouSS

__all__ = ["PyanNet"]
__all__ = ["PyanNet", "SSeRiouSS"]
