Commit

Fall back on k-means clustering if the agglomerative clustering cannot be done correctly with the given constraints on number of speakers
flyingleafe committed May 9, 2023
1 parent d5a14f7 commit 3124b28
Showing 2 changed files with 50 additions and 24 deletions.
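
For context, here is a minimal, self-contained sketch of the fallback path the commit message describes: try agglomerative clustering first and, if it does not yield the requested number of sufficiently large clusters, fall back to k-means. The function name and parameters are illustrative (not the pipeline's actual API), the linkage settings are fixed to `centroid`/`euclidean` for brevity, and the dendrogram search is simplified to a single `maxclust` cut; only the `kmeans2(..., minit="++", seed=42)` call mirrors the diff below.

import warnings

import numpy as np
from scipy.cluster.hierarchy import fcluster, linkage
from scipy.cluster.vq import kmeans2


def cluster_with_fallback(
    embeddings: np.ndarray, num_clusters: int, min_cluster_size: int = 2
) -> np.ndarray:
    # try agglomerative clustering first: cut the dendrogram into `num_clusters` groups
    dendrogram = linkage(embeddings, method="centroid", metric="euclidean")
    clusters = fcluster(dendrogram, num_clusters, criterion="maxclust") - 1

    # keep the result only if it yields the requested number of "large" clusters
    _, counts = np.unique(clusters, return_counts=True)
    if np.sum(counts >= min_cluster_size) == num_clusters:
        return clusters

    # otherwise fall back to k-means with k-means++ initialization
    # and `num_clusters` centroids
    warnings.warn(
        "Agglomerative clustering did not return an appropriate result. "
        "Falling back to k-means."
    )
    _, clusters = kmeans2(embeddings, num_clusters, minit="++", seed=42)
    return clusters


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    embeddings = rng.normal(size=(100, 16))
    print(cluster_with_fallback(embeddings, num_clusters=3))

The actual pipeline additionally traverses several dendrogram thresholds and relaxes `min_cluster_size` before giving up, as the diff below shows.
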
67 changes: 44 additions & 23 deletions pyannote/audio/pipelines/clustering.py
@@ -24,14 +24,18 @@


import random
import warnings
import textwrap
from enum import Enum
from typing import Tuple
from typing import Tuple, Optional

import numpy as np
import matplotlib.pyplot as plt
from einops import rearrange
from pyannote.core import SlidingWindow, SlidingWindowFeature
from pyannote.pipeline import Pipeline
from pyannote.pipeline.parameter import Categorical, Integer, Uniform
from scipy.cluster.vq import kmeans2
from scipy.cluster.hierarchy import fcluster, linkage
from scipy.optimize import linear_sum_assignment
from scipy.spatial.distance import cdist
@@ -57,12 +61,20 @@ def __init__(
def set_num_clusters(
self,
num_embeddings: int,
num_clusters: int = None,
min_clusters: int = None,
max_clusters: int = None,
num_clusters: Optional[int] = None,
min_clusters: Optional[int] = None,
max_clusters: Optional[int] = None,
):

min_clusters = num_clusters or min_clusters or 1

if num_embeddings < min_clusters:
warnings.warn(textwrap.dedent(f"""\
Number of provided embeddings ({num_embeddings}) is smaller than
the minimum number of clusters requested ({min_clusters}).
Impossible to respect the `min_clusters` constraint.
"""
))

min_clusters = max(1, min(num_embeddings, min_clusters))
max_clusters = num_clusters or max_clusters or num_embeddings
max_clusters = max(1, min(num_embeddings, max_clusters))
@@ -81,7 +93,7 @@ def set_num_clusters(
def filter_embeddings(
self,
embeddings: np.ndarray,
segmentations: SlidingWindowFeature = None,
segmentations: Optional[SlidingWindowFeature] = None,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
"""Filter NaN embeddings and downsample embeddings
@@ -199,12 +211,12 @@ def assign_embeddings(
def __call__(
self,
embeddings: np.ndarray,
segmentations: SlidingWindowFeature = None,
num_clusters: int = None,
min_clusters: int = None,
max_clusters: int = None,
segmentations: Optional[SlidingWindowFeature] = None,
num_clusters: Optional[int] = None,
min_clusters: Optional[int] = None,
max_clusters: Optional[int] = None,
**kwargs,
) -> np.ndarray:
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
"""Apply clustering
Parameters
@@ -329,7 +341,7 @@ def cluster(
embeddings: np.ndarray,
min_clusters: int,
max_clusters: int,
num_clusters: int = None,
num_clusters: Optional[int] = None,
):
"""
@@ -356,7 +368,7 @@ def cluster(
# heuristic to reduce self.min_cluster_size when num_embeddings is very small
# (0.1 value is kind of arbitrary, though)
min_cluster_size = min(
self.min_cluster_size, max(1, round(0.1 * num_embeddings))
self.min_cluster_size, max(1, min(round(0.1 * num_embeddings), round(num_embeddings / min_clusters)))
)

# linkage function will complain when there is just one embedding to cluster
@@ -377,7 +389,7 @@ def cluster(
dendrogram: np.ndarray = linkage(
embeddings, method=self.method, metric=self.metric
)

# apply the predefined threshold
clusters = fcluster(dendrogram, self.threshold, criterion="distance") - 1

@@ -407,6 +419,9 @@ def cluster(
best_iteration = num_embeddings - 1
best_num_large_clusters = 1

# first, try to find a clustering that still
# respects the `min_cluster_size` constraint

# traverse the dendrogram by going further and further away
# from the "optimal" threshold

@@ -429,24 +444,30 @@
if abs(num_large_clusters - num_clusters) < abs(
best_num_large_clusters - num_clusters
):
best_iteration = iteration
best_num_large_clusters = num_large_clusters

# stop traversing the dendrogram as soon as we found a good candidate
if num_large_clusters == num_clusters:
break

# re-apply best iteration in case we did not find a perfect candidate
# if we did not find the best candidate, fall back to k-means.
# TODO: can we do better?
if best_num_large_clusters != num_clusters:
clusters = (
fcluster(_dendrogram, best_iteration, criterion="distance") - 1
)
warnings.warn("Agglomerative clustering did not return appropriate result. Falling back to k-means.")

_, clusters = kmeans2(embeddings, num_clusters, minit="++", seed=42)
cluster_unique, cluster_counts = np.unique(clusters, return_counts=True)
large_clusters = cluster_unique[cluster_counts >= min_cluster_size]

# all clusters are then large enough
new_min_cluster_size = np.min(cluster_counts)
large_clusters = cluster_unique
num_large_clusters = len(large_clusters)
print(
f"Found only {num_large_clusters} clusters. Using a smaller value than {min_cluster_size} for `min_cluster_size` might help."
)

warnings.warn(textwrap.dedent(f"""\
Could not find a good candidate with `min_cluster_size` {min_cluster_size},
given the required number of clusters {num_clusters}. The new minimum cluster size is {new_min_cluster_size}."""
))
min_cluster_size = new_min_cluster_size

if num_large_clusters == 0:
clusters[:] = 0
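
One detail worth spelling out from the clustering.py diff above: the `min_cluster_size` heuristic now also caps the value at roughly `num_embeddings / min_clusters`, so that `min_clusters` clusters of that size can actually coexist. A small illustrative computation with made-up numbers (not part of the commit):

num_embeddings = 200
min_clusters = 20
configured_min_cluster_size = 15

# old heuristic: cap at 10% of the embeddings
old_value = min(configured_min_cluster_size, max(1, round(0.1 * num_embeddings)))

# new heuristic: additionally cap at num_embeddings / min_clusters
new_value = min(
    configured_min_cluster_size,
    max(1, min(round(0.1 * num_embeddings), round(num_embeddings / min_clusters))),
)

print(old_value)  # 15 -> 20 clusters of size >= 15 would need at least 300 embeddings
print(new_value)  # 10 -> 20 clusters of size 10 fit exactly into 200 embeddings
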
7 changes: 6 additions & 1 deletion pyannote/audio/pipelines/speaker_diarization.py
@@ -25,6 +25,7 @@
import functools
import itertools
import math
import warnings
from typing import Callable, Optional, Text, Union, Dict, Any

import numpy as np
@@ -495,7 +496,11 @@ def apply(
# number of detected clusters is the number of different speakers
num_different_speakers = centroids.shape[0]
# quick sanity check
assert num_different_speakers >= min_speakers and num_different_speakers <= max_speakers
if num_different_speakers < min_speakers or num_different_speakers > max_speakers:
warnings.warn(
f"Number of detected speakers ({num_different_speakers}) "
f"outside of [{min_speakers}, {max_speakers}] range"
)

# during counting, we could possibly overcount the number of instantaneous
# speakers due to segmentation errors, so we cap the maximum instantaneous number
