Fix instantaneous speaker count exceeding max_speakers or detected number of clusters #1351

Merged: 11 commits, Nov 16, 2023
CHANGELOG.md (3 additions, 0 deletions)
@@ -15,6 +15,7 @@

- fix(pipeline): add missing "embedding" hook call in `SpeakerDiarization`
- fix(pipeline): fix `AgglomerativeClustering` to honor `num_clusters` when provided
- fix(pipeline): fix frame-wise speaker count exceeding `max_speakers` or detected `num_speakers` in `SpeakerDiarization` pipeline

### Improvements

@@ -26,6 +27,8 @@
- BREAKING(setup): remove `onnxruntime` dependency.
You can still use the ONNX `hbredin/wespeaker-voxceleb-resnet34-LM` model, but you will have to install `onnxruntime` yourself.
- BREAKING(pipeline): remove `logging_hook` (use `ArtifactHook` instead)
- BREAKING(pipeline): remove `onset` and `offset` parameters from `SpeakerDiarizationMixin.speaker_count`.
You should now binarize segmentations before passing them to `speaker_count`.

## Version 3.0.1 (2023-09-28)

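For readers upgrading: the breaking change above moves binarization out of `speaker_count` and into the caller. A minimal sketch of the new calling convention, assuming random scores, a hypothetical 5s / 500ms chunking, and a made-up frame resolution:

```python
import numpy as np
from pyannote.core import SlidingWindow, SlidingWindowFeature

from pyannote.audio.pipelines.utils.diarization import SpeakerDiarizationMixin
from pyannote.audio.utils.signal import binarize

# hypothetical raw segmentation scores: (num_chunks, num_frames, num_speakers)
chunks = SlidingWindow(start=0.0, duration=5.0, step=0.5)
segmentations = SlidingWindowFeature(np.random.rand(10, 293, 3), chunks)

# hypothetical frame resolution of the segmentation model
frames = SlidingWindow(start=0.0, duration=0.017, step=0.017)

# before this PR: speaker_count(segmentations, onset=0.5, ...)
# after this PR: binarize explicitly, then count
binarized = binarize(segmentations, onset=0.5, initial_state=False)
count = SpeakerDiarizationMixin.speaker_count(
    binarized, warm_up=(0.1, 0.1), frames=frames
)
print(count.data.shape)  # (num_frames, 1) instantaneous speaker count
```

The pipelines below (`resegmentation.py`, `speaker_diarization.py`) follow exactly this pattern.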
pyannote/audio/pipelines/clustering.py (0 additions, 1 deletion)
@@ -253,7 +253,6 @@ def __call__(
hard_clusters = np.zeros((num_chunks, num_speakers), dtype=np.int8)
soft_clusters = np.ones((num_chunks, num_speakers, 1))
centroids = np.mean(train_embeddings, axis=0, keepdims=True)

return hard_clusters, soft_clusters, centroids

train_clusters = self.cluster(
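For context, the hunk above touches an early-return path that builds a degenerate single-cluster result. A standalone sketch of the shapes involved, with made-up sizes:

```python
import numpy as np

num_chunks, num_speakers = 10, 3
train_embeddings = np.random.rand(42, 256)  # made-up (num_embeddings, dimension)

# every chunk-level speaker is assigned to cluster 0 ...
hard_clusters = np.zeros((num_chunks, num_speakers), dtype=np.int8)
# ... with full confidence in that single cluster ...
soft_clusters = np.ones((num_chunks, num_speakers, 1))
# ... whose centroid is the mean of all training embeddings
centroids = np.mean(train_embeddings, axis=0, keepdims=True)

print(hard_clusters.shape, soft_clusters.shape, centroids.shape)
# (10, 3) (10, 3, 1) (1, 256)
```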
pyannote/audio/pipelines/resegmentation.py (9 additions, 2 deletions)
@@ -39,6 +39,7 @@
get_model,
)
from pyannote.audio.utils.permutation import mae_cost_func, permutate
from pyannote.audio.utils.signal import binarize


class Resegmentation(SpeakerDiarizationMixin, Pipeline):
@@ -181,11 +182,17 @@ def apply(

hook("segmentation", segmentations)

# estimate frame-level number of instantaneous speakers
count = self.speaker_count(
# binarize segmentations before speaker counting
binarized_segmentations: SlidingWindowFeature = binarize(
segmentations,
onset=self.onset,
offset=self.offset,
initial_state=False,
)

# estimate frame-level number of instantaneous speakers
count = self.speaker_count(
binarized_segmentations,
warm_up=(self.warm_up, self.warm_up),
frames=self._frames,
)
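`Resegmentation.apply` now binarizes with hysteresis (separate `onset` and `offset` thresholds) before counting. A simplified, NumPy-only re-implementation of the mechanism on a single activation track, for illustration only (the library's `binarize` operates on `SlidingWindowFeature` and differs in details):

```python
import numpy as np

def hysteresis_binarize(
    scores: np.ndarray, onset: float, offset: float, initial_state: bool = False
) -> np.ndarray:
    """Switch on above `onset`, off below `offset`, else keep previous state."""
    active = initial_state
    out = np.zeros_like(scores, dtype=bool)
    for i, score in enumerate(scores):
        active = score > offset if active else score > onset
        out[i] = active
    return out

scores = np.array([0.2, 0.7, 0.55, 0.35, 0.8, 0.3])
print(hysteresis_binarize(scores, onset=0.6, offset=0.4).astype(int))
# [0 1 1 0 1 0]
```

The two thresholds prevent rapid on/off flicker when scores hover around a single cutoff.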
pyannote/audio/pipelines/speaker_diarization.py (46 additions, 19 deletions)
@@ -25,6 +25,8 @@
import functools
import itertools
import math
import textwrap
import warnings
from typing import Callable, Optional, Text, Union

import numpy as np
@@ -478,12 +480,19 @@ def apply(
hook("segmentation", segmentations)
# shape: (num_chunks, num_frames, local_num_speakers)

# binarize segmentation
if self._segmentation.model.specifications.powerset:
binarized_segmentations = segmentations
else:
binarized_segmentations: SlidingWindowFeature = binarize(
segmentations,
onset=self.segmentation.threshold,
initial_state=False,
)

# estimate frame-level number of instantaneous speakers
count = self.speaker_count(
segmentations,
onset=0.5
if self._segmentation.model.specifications.powerset
else self.segmentation.threshold,
binarized_segmentations,
frames=self._frames,
warm_up=(0.0, 0.0),
)

return diarization

# binarize segmentation
if self._segmentation.model.specifications.powerset:
binarized_segmentations = segmentations
else:
binarized_segmentations: SlidingWindowFeature = binarize(
segmentations,
onset=self.segmentation.threshold,
initial_state=False,
)

if self.klustering == "OracleClustering" and not return_embeddings:
embeddings = None
else:
# hard_clusters: (num_chunks, num_speakers)
# centroids: (num_speakers, dimension)

# number of detected clusters is the number of different speakers
num_different_speakers = np.max(hard_clusters) + 1

# detected number of speakers can still be out of bounds
# (specifically, lower than `min_speakers`), since there could be too few embeddings
# to make enough clusters with a given minimum cluster size.
if num_different_speakers < min_speakers or num_different_speakers > max_speakers:
warnings.warn(textwrap.dedent(
f"""
The detected number of speakers ({num_different_speakers}) is outside
the given bounds [{min_speakers}, {max_speakers}]. This can happen if the
given audio file is too short to contain {min_speakers} or more speakers.
Try to lower the desired minimal number of speakers.
"""
))

# during counting, we could possibly overcount the number of instantaneous
# speakers due to segmentation errors, so we cap the maximum instantaneous number
# of speakers by the `max_speakers` value
count.data = np.minimum(count.data, max_speakers).astype(np.int8)

# reconstruct discrete diarization from raw hard clusters

# keep track of inactive speakers
@@ -588,18 +608,25 @@ def apply(
if not return_embeddings:
return diarization

# this can happen when we use OracleClustering
if centroids is None:
return diarization, None

# The number of centroids may be smaller than the number of speakers
# in the annotation. This can happen if the number of active speakers
# obtained from `speaker_count` for some frames is larger than the number
# of clusters obtained from `clustering`. In this case, we append zero embeddings
# for extra speakers
if len(diarization.labels()) > centroids.shape[0]:
centroids = np.pad(centroids, ((0, len(diarization.labels()) - centroids.shape[0]), (0, 0)))

# re-order centroids so that they match
# the order given by diarization.labels()
inverse_mapping = {label: index for index, label in mapping.items()}
centroids = centroids[
[inverse_mapping[label] for label in diarization.labels()]
]

# FIXME: the number of centroids may be smaller than the number of speakers
# in the annotation. This can happen if the number of active speakers
# obtained from `speaker_count` for some frames is larger than the number
# of clusters obtained from `clustering`. Will be fixed in the future

return diarization, centroids

def get_metric(self) -> GreedyDiarizationErrorRate:
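The two numerical safeguards added above are easy to check in isolation; a sketch with made-up counts and centroids (shapes and values are illustrative only):

```python
import numpy as np

# 1. cap the frame-level speaker count at max_speakers
max_speakers = 2
count_data = np.array([[0], [1], [3], [2], [4]], dtype=np.int8)  # overcounted frames
capped = np.minimum(count_data, max_speakers).astype(np.int8)
print(capped.ravel())  # [0 1 2 2 2]

# 2. pad centroids with zero embeddings when the diarization ends up
#    with more speakers than clustering produced centroids for
labels = ["SPEAKER_00", "SPEAKER_01", "SPEAKER_02"]
centroids = np.random.rand(2, 256)  # one centroid short
if len(labels) > centroids.shape[0]:
    centroids = np.pad(centroids, ((0, len(labels) - centroids.shape[0]), (0, 0)))
print(centroids.shape)  # (3, 256): the extra speaker gets a zero embedding
```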
pyannote/audio/pipelines/utils/diarization.py (4 additions, 14 deletions)
@@ -117,26 +117,19 @@ def optimal_mapping(
else:
return mapped_hypothesis

# TODO: get rid of onset/offset (binarization should be applied before calling speaker_count)
# TODO: get rid of warm-up parameter (trimming should be applied before calling speaker_count)
@staticmethod
def speaker_count(
segmentations: SlidingWindowFeature,
onset: float = 0.5,
offset: float = None,
binarized_segmentations: SlidingWindowFeature,
warm_up: Tuple[float, float] = (0.1, 0.1),
frames: SlidingWindow = None,
) -> SlidingWindowFeature:
"""Estimate frame-level number of instantaneous speakers

Parameters
----------
segmentations : SlidingWindowFeature
(num_chunks, num_frames, num_classes)-shaped scores.
onset : float, optional
Onset threshold. Defaults to 0.5
offset : float, optional
Offset threshold. Defaults to `onset`.
binarized_segmentations : SlidingWindowFeature
(num_chunks, num_frames, num_classes)-shaped binarized scores.
warm_up : (float, float) tuple, optional
Left/right warm up ratio of chunk duration.
Defaults to (0.1, 0.1), i.e. 10% on both sides.
Expand All @@ -151,10 +144,7 @@ def speaker_count(
(num_frames, 1)-shaped instantaneous speaker count
"""

binarized: SlidingWindowFeature = binarize(
segmentations, onset=onset, offset=offset, initial_state=False
)
trimmed = Inference.trim(binarized, warm_up=warm_up)
trimmed = Inference.trim(binarized_segmentations, warm_up=warm_up)
count = Inference.aggregate(
np.sum(trimmed, axis=-1, keepdims=True),
frames=frames,
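With binarization moved out, `speaker_count` reduces to trim, per-frame sum, and aggregation. A simplified, NumPy-only sketch of the same logic on already-binarized, non-overlapping chunks (the real `Inference.trim` / `Inference.aggregate` also handle chunk overlap and the output timeline):

```python
import numpy as np

# made-up binarized chunks: (num_chunks, num_frames, num_speakers)
binarized = np.random.rand(4, 100, 3) > 0.5

# trim 10% warm-up on each side of every chunk
num_frames = binarized.shape[1]
left = right = int(0.1 * num_frames)
trimmed = binarized[:, left : num_frames - right, :]

# per-frame number of active speakers within each chunk
per_chunk_count = np.sum(trimmed, axis=-1, keepdims=True)

# with non-overlapping chunks, aggregation is just concatenation + rounding
count = np.round(per_chunk_count.reshape(-1, 1)).astype(np.int8)
print(count.shape)  # (320, 1)
```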