diff --git a/pyannote/audio/pipelines/clustering.py b/pyannote/audio/pipelines/clustering.py index 839011943..a906d4136 100644 --- a/pyannote/audio/pipelines/clustering.py +++ b/pyannote/audio/pipelines/clustering.py @@ -24,14 +24,18 @@ import random +import warnings +import textwrap from enum import Enum -from typing import Tuple +from typing import Tuple, Optional import numpy as np +import matplotlib.pyplot as plt from einops import rearrange from pyannote.core import SlidingWindow, SlidingWindowFeature from pyannote.pipeline import Pipeline from pyannote.pipeline.parameter import Categorical, Integer, Uniform +from scipy.cluster.vq import kmeans2 from scipy.cluster.hierarchy import fcluster, linkage from scipy.optimize import linear_sum_assignment from scipy.spatial.distance import cdist @@ -57,12 +61,20 @@ def __init__( def set_num_clusters( self, num_embeddings: int, - num_clusters: int = None, - min_clusters: int = None, - max_clusters: int = None, + num_clusters: Optional[int] = None, + min_clusters: Optional[int] = None, + max_clusters: Optional[int] = None, ): - min_clusters = num_clusters or min_clusters or 1 + + if num_embeddings < min_clusters: + warnings.warn(textwrap.dedent(f"""\ + Number of provided embeddings ({num_embeddings}) is smaller than + the minimum number of clusters requested ({min_clusters}). + Impossible to respect the `min_clusters` constraint. + """ + )); + min_clusters = max(1, min(num_embeddings, min_clusters)) max_clusters = num_clusters or max_clusters or num_embeddings max_clusters = max(1, min(num_embeddings, max_clusters)) @@ -81,7 +93,7 @@ def set_num_clusters( def filter_embeddings( self, embeddings: np.ndarray, - segmentations: SlidingWindowFeature = None, + segmentations: Optional[SlidingWindowFeature] = None, ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: """Filter NaN embeddings and downsample embeddings @@ -199,12 +211,12 @@ def assign_embeddings( def __call__( self, embeddings: np.ndarray, - segmentations: SlidingWindowFeature = None, - num_clusters: int = None, - min_clusters: int = None, - max_clusters: int = None, + segmentations: Optional[SlidingWindowFeature] = None, + num_clusters: Optional[int] = None, + min_clusters: Optional[int] = None, + max_clusters: Optional[int] = None, **kwargs, - ) -> np.ndarray: + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: """Apply clustering Parameters @@ -329,7 +341,7 @@ def cluster( embeddings: np.ndarray, min_clusters: int, max_clusters: int, - num_clusters: int = None, + num_clusters: Optional[int] = None, ): """ @@ -356,7 +368,7 @@ def cluster( # heuristic to reduce self.min_cluster_size when num_embeddings is very small # (0.1 value is kind of arbitrary, though) min_cluster_size = min( - self.min_cluster_size, max(1, round(0.1 * num_embeddings)) + self.min_cluster_size, max(1, min(round(0.1 * num_embeddings), round(num_embeddings / min_clusters))) ) # linkage function will complain when there is just one embedding to cluster @@ -377,7 +389,7 @@ def cluster( dendrogram: np.ndarray = linkage( embeddings, method=self.method, metric=self.metric ) - + # apply the predefined threshold clusters = fcluster(dendrogram, self.threshold, criterion="distance") - 1 @@ -407,6 +419,9 @@ def cluster( best_iteration = num_embeddings - 1 best_num_large_clusters = 1 + # first, we are trying to find the clustering which still + # respects the `min_cluster_size` constraint + # traverse the dendrogram by going further and further away # from the "optimal" threshold @@ -429,24 +444,30 @@ def cluster( if abs(num_large_clusters - num_clusters) < abs( best_num_large_clusters - num_clusters ): - best_iteration = iteration best_num_large_clusters = num_large_clusters # stop traversing the dendrogram as soon as we found a good candidate if num_large_clusters == num_clusters: break - # re-apply best iteration in case we did not find a perfect candidate + # if we did not find the best candidate, fall back to k-means. + # TODO: can we do better? if best_num_large_clusters != num_clusters: - clusters = ( - fcluster(_dendrogram, best_iteration, criterion="distance") - 1 - ) + warnings.warn("Agglomerative clustering did not return appropriate result. Falling back to k-means.") + + _, clusters = kmeans2(embeddings, num_clusters, minit="++", seed=42) cluster_unique, cluster_counts = np.unique(clusters, return_counts=True) - large_clusters = cluster_unique[cluster_counts >= min_cluster_size] + + # all clusters then are large enough, + new_min_cluster_size = np.min(cluster_counts) + large_clusters = cluster_unique num_large_clusters = len(large_clusters) - print( - f"Found only {num_large_clusters} clusters. Using a smaller value than {min_cluster_size} for `min_cluster_size` might help." - ) + + warnings.warn(textwrap.dedent(f"""\ + Could not find a good candidate with `min_cluster_size` {min_cluster_size}, + given the necessary number of clusters {num_clusters}. The new minimal cluster size is {new_min_cluster_size}.""" + )) + min_cluster_size = new_min_cluster_size if num_large_clusters == 0: clusters[:] = 0 diff --git a/pyannote/audio/pipelines/speaker_diarization.py b/pyannote/audio/pipelines/speaker_diarization.py index 86869d979..45fbed294 100644 --- a/pyannote/audio/pipelines/speaker_diarization.py +++ b/pyannote/audio/pipelines/speaker_diarization.py @@ -25,6 +25,7 @@ import functools import itertools import math +import warnings from typing import Callable, Optional, Text, Union, Dict, Any import numpy as np @@ -495,7 +496,11 @@ def apply( # number of detected clusters is the number of different speakers num_different_speakers = centroids.shape[0] # quick sanity check - assert num_different_speakers >= min_speakers and num_different_speakers <= max_speakers + if num_different_speakers < min_speakers or num_different_speakers > max_speakers: + warnings.warn( + f"Number of detected speakers ({num_different_speakers}) " + f"outside of [{min_speakers}, {max_speakers}] range" + ) # during counting, we could possibly overcount the number of instantaneous # speakers due to segmentation errors, so we cap the maximum instantaneous number