From 1d70ae22bd577d52217d835420fb6cfb42eb9cc6 Mon Sep 17 00:00:00 2001 From: clement-pages Date: Fri, 4 Oct 2024 13:49:20 +0200 Subject: [PATCH] fix `non_silent` index error --- pyannote/audio/pipelines/speech_separation.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/pyannote/audio/pipelines/speech_separation.py b/pyannote/audio/pipelines/speech_separation.py index c1b9b036c..5a06e5d55 100644 --- a/pyannote/audio/pipelines/speech_separation.py +++ b/pyannote/audio/pipelines/speech_separation.py @@ -91,7 +91,7 @@ class SpeechSeparation(SpeakerDiarizationMixin, Pipeline): Usage ----- - >>> pipeline = SpeakerDiarization() + >>> pipeline = SpeechSeparation() >>> diarization, separation = pipeline("/path/to/audio.wav") >>> diarization, separation = pipeline("/path/to/audio.wav", num_speakers=4) >>> diarization, separation = pipeline("/path/to/audio.wav", min_speakers=2, max_speakers=10) @@ -236,7 +236,7 @@ def get_segmentations( segmentations, separations = file[self.CACHED_SEGMENTATION] else: segmentations, separations = self._segmentation(file, hook=hook) - file[self.CACHED_SEGMENTATION] = segmentations + file[self.CACHED_SEGMENTATION] = (segmentations, separations) else: segmentations, separations = self._segmentation(file, hook=hook) @@ -441,7 +441,6 @@ def reconstruct( clustered_segmentations, segmentations.sliding_window ) return clustered_segmentations - return self.to_diarization(clustered_segmentations, count) def apply( self, @@ -583,7 +582,7 @@ def apply( # reconstruct discrete diarization from raw hard clusters - # keep track of inactive speakers + # keep track of inactive speakers at chunk level inactive_speakers = np.sum(binarized_segmentations.data, axis=1) == 0 # shape: (num_chunks, num_speakers) @@ -594,10 +593,18 @@ def apply( count, ) discrete_diarization = self.to_diarization(discrete_diarization, count) + + # remove inactive speakers at the audio level from the diarization + active_speakers = np.sum(discrete_diarization, axis=0) > 0 + # shape: (num_speakers, ) + discrete_diarization.data = discrete_diarization.data[:, active_speakers] + # shape: (num_frames, num_active_speakers) + hook("discrete_diarization", discrete_diarization) clustered_separations = self.reconstruct(separations, hard_clusters, count) frame_duration = separations.sliding_window.duration / separations.data.shape[1] frames = SlidingWindow(step=frame_duration, duration=2 * frame_duration) + sources = Inference.aggregate( clustered_separations, frames=frames, @@ -605,6 +612,8 @@ def apply( missing=0.0, skip_average=True, ) + sources.data = sources.data[:, active_speakers] + # zero-out sources when speaker is inactive # WARNING: this should be rewritten to avoid huge memory consumption if self.separation.leakage_removal: