diff --git a/CHANGELOG.md b/CHANGELOG.md index 6feac98ae..19e25f36e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,7 @@ - fix(pipeline): add missing "embedding" hook call in `SpeakerDiarization` - fix(pipeline): fix `AgglomerativeClustering` to honor `num_clusters` when provided +- fix(pipeline): fix frame-wise speaker count exceeding `max_speakers` or detected `num_speakers` in `SpeakerDiarization` pipeline ### Improvements @@ -26,6 +27,8 @@ - BREAKING(setup): remove `onnxruntime` dependency. You can still use ONNX `hbredin/wespeaker-voxceleb-resnet34-LM` but you will have to install `onnxruntime` yourself. - BREAKING(pipeline): remove `logging_hook` (use `ArtifactHook` instead) +- BREAKING(pipeline): remove `onset` and `offset` parameters in `SpeakerDiarizationMixin.speaker_count` + You should now binarize segmentations before passing them to `speaker_count` ## Version 3.0.1 (2023-09-28) diff --git a/pyannote/audio/pipelines/clustering.py b/pyannote/audio/pipelines/clustering.py index c51cdcc50..b63ab214f 100644 --- a/pyannote/audio/pipelines/clustering.py +++ b/pyannote/audio/pipelines/clustering.py @@ -253,7 +253,6 @@ def __call__( hard_clusters = np.zeros((num_chunks, num_speakers), dtype=np.int8) soft_clusters = np.ones((num_chunks, num_speakers, 1)) centroids = np.mean(train_embeddings, axis=0, keepdims=True) - return hard_clusters, soft_clusters, centroids train_clusters = self.cluster( diff --git a/pyannote/audio/pipelines/resegmentation.py b/pyannote/audio/pipelines/resegmentation.py index bb71abf22..d01e5d65f 100644 --- a/pyannote/audio/pipelines/resegmentation.py +++ b/pyannote/audio/pipelines/resegmentation.py @@ -39,6 +39,7 @@ get_model, ) from pyannote.audio.utils.permutation import mae_cost_func, permutate +from pyannote.audio.utils.signal import binarize class Resegmentation(SpeakerDiarizationMixin, Pipeline): @@ -181,11 +182,17 @@ def apply( hook("segmentation", segmentations) - # estimate frame-level number of instantaneous speakers - 
count = self.speaker_count( + # binarize segmentations before speaker counting + binarized_segmentations: SlidingWindowFeature = binarize( segmentations, onset=self.onset, offset=self.offset, + initial_state=False, + ) + + # estimate frame-level number of instantaneous speakers + count = self.speaker_count( + binarized_segmentations, warm_up=(self.warm_up, self.warm_up), frames=self._frames, ) diff --git a/pyannote/audio/pipelines/speaker_diarization.py b/pyannote/audio/pipelines/speaker_diarization.py index d5cf04e05..354f6be7e 100644 --- a/pyannote/audio/pipelines/speaker_diarization.py +++ b/pyannote/audio/pipelines/speaker_diarization.py @@ -25,6 +25,8 @@ import functools import itertools import math +import textwrap +import warnings from typing import Callable, Optional, Text, Union import numpy as np @@ -478,12 +480,19 @@ def apply( hook("segmentation", segmentations) # shape: (num_chunks, num_frames, local_num_speakers) + # binarize segmentation + if self._segmentation.model.specifications.powerset: + binarized_segmentations = segmentations + else: + binarized_segmentations: SlidingWindowFeature = binarize( + segmentations, + onset=self.segmentation.threshold, + initial_state=False, + ) + # estimate frame-level number of instantaneous speakers count = self.speaker_count( - segmentations, - onset=0.5 - if self._segmentation.model.specifications.powerset - else self.segmentation.threshold, + binarized_segmentations, frames=self._frames, warm_up=(0.0, 0.0), ) @@ -499,16 +508,6 @@ def apply( return diarization - # binarize segmentation - if self._segmentation.model.specifications.powerset: - binarized_segmentations = segmentations - else: - binarized_segmentations: SlidingWindowFeature = binarize( - segmentations, - onset=self.segmentation.threshold, - initial_state=False, - ) - if self.klustering == "OracleClustering" and not return_embeddings: embeddings = None else: @@ -533,6 +532,27 @@ def apply( # hard_clusters: (num_chunks, num_speakers) # centroids: 
(num_speakers, dimension) + # number of detected clusters is the number of different speakers + num_different_speakers = np.max(hard_clusters) + 1 + + # detected number of speakers can still be out of bounds + # (specifically, lower than `min_speakers`), since there could be too few embeddings + # to make enough clusters with a given minimum cluster size. + if num_different_speakers < min_speakers or num_different_speakers > max_speakers: + warnings.warn(textwrap.dedent( + f""" + The detected number of speakers ({num_different_speakers}) is outside + the given bounds [{min_speakers}, {max_speakers}]. This can happen if the + given audio file is too short to contain {min_speakers} or more speakers. + Try to lower the desired minimal number of speakers. + """ + )) + + # during counting, we could possibly overcount the number of instantaneous + # speakers due to segmentation errors, so we cap the maximum instantaneous number + # of speakers by the `max_speakers` value + count.data = np.minimum(count.data, max_speakers).astype(np.int8) + # reconstruct discrete diarization from raw hard clusters # keep track of inactive speakers @@ -588,6 +608,18 @@ def apply( if not return_embeddings: return diarization + # this can happen when we use OracleClustering + if centroids is None: + return diarization, None + + # The number of centroids may be smaller than the number of speakers + # in the annotation. This can happen if the number of active speakers + # obtained from `speaker_count` for some frames is larger than the number + # of clusters obtained from `clustering`. 
In this case, we append zero embeddings + # for extra speakers + if len(diarization.labels()) > centroids.shape[0]: + centroids = np.pad(centroids, ((0, len(diarization.labels()) - centroids.shape[0]), (0, 0))) + # re-order centroids so that they match # the order given by diarization.labels() inverse_mapping = {label: index for index, label in mapping.items()} @@ -595,11 +627,6 @@ def apply( [inverse_mapping[label] for label in diarization.labels()] ] - # FIXME: the number of centroids may be smaller than the number of speakers - # in the annotation. This can happen if the number of active speakers - # obtained from `speaker_count` for some frames is larger than the number - # of clusters obtained from `clustering`. Will be fixed in the future - return diarization, centroids def get_metric(self) -> GreedyDiarizationErrorRate: diff --git a/pyannote/audio/pipelines/utils/diarization.py b/pyannote/audio/pipelines/utils/diarization.py index 91413350b..4a35f7049 100644 --- a/pyannote/audio/pipelines/utils/diarization.py +++ b/pyannote/audio/pipelines/utils/diarization.py @@ -117,13 +117,10 @@ def optimal_mapping( else: return mapped_hypothesis - # TODO: get rid of onset/offset (binarization should be applied before calling speaker_count) # TODO: get rid of warm-up parameter (trimming should be applied before calling speaker_count) @staticmethod def speaker_count( - segmentations: SlidingWindowFeature, - onset: float = 0.5, - offset: float = None, + binarized_segmentations: SlidingWindowFeature, warm_up: Tuple[float, float] = (0.1, 0.1), frames: SlidingWindow = None, ) -> SlidingWindowFeature: @@ -131,12 +128,8 @@ def speaker_count( Parameters ---------- - segmentations : SlidingWindowFeature - (num_chunks, num_frames, num_classes)-shaped scores. - onset : float, optional - Onset threshold. Defaults to 0.5 - offset : float, optional - Offset threshold. Defaults to `onset`. 
+ binarized_segmentations : SlidingWindowFeature + (num_chunks, num_frames, num_classes)-shaped binarized scores. warm_up : (float, float) tuple, optional Left/right warm up ratio of chunk duration. Defaults to (0.1, 0.1), i.e. 10% on both sides. @@ -151,10 +144,7 @@ def speaker_count( (num_frames, 1)-shaped instantaneous speaker count """ - binarized: SlidingWindowFeature = binarize( - segmentations, onset=onset, offset=offset, initial_state=False - ) - trimmed = Inference.trim(binarized, warm_up=warm_up) + trimmed = Inference.trim(binarized_segmentations, warm_up=warm_up) count = Inference.aggregate( np.sum(trimmed, axis=-1, keepdims=True), frames=frames,