From dd470483ece18b56efa9436057c56176fe5aa132 Mon Sep 17 00:00:00 2001 From: Jan Profant Date: Tue, 30 Jan 2024 10:06:48 +0100 Subject: [PATCH] Initial version with dict-based annotations-segments --- pyannote/audio/core/task.py | 25 +++++++++++++++++++ .../audio/tasks/segmentation/multilabel.py | 5 ++-- .../overlapped_speech_detection.py | 4 +-- .../tasks/segmentation/speaker_diarization.py | 4 +-- .../segmentation/voice_activity_detection.py | 4 +-- 5 files changed, 30 insertions(+), 12 deletions(-) diff --git a/pyannote/audio/core/task.py b/pyannote/audio/core/task.py index 82e8939fe..d69f8ddad 100644 --- a/pyannote/audio/core/task.py +++ b/pyannote/audio/core/task.py @@ -642,6 +642,31 @@ def setup(self, stage=None): f"does not correspond to the cached one ({self.prepared_data['protocol']})" ) + # regroup annotations-segments into a per-file_id dict, since a dict cannot be stored in the cached .npy file + annotations = self.prepared_data['annotations-segments'] + annotations_dict = defaultdict(list) + file_ids = [] + for annotation in annotations: + file_id = annotation[0] + file_ids.append(file_id) + annotations_dict[file_id].append(annotation) + + segment_dtype = [ + ( + "file_id", + get_dtype(max(a[0] for a in annotations)), + ), + ("start", "f"), + ("end", "f"), + ("file_label_idx", get_dtype(max(a[3] for a in annotations))), + ("database_label_idx", get_dtype(max(a[4] for a in annotations))), + ("global_label_idx", get_dtype(max(a[5] for a in annotations))), + ] + + for file_id in file_ids: + annotations_dict[file_id] = np.array(annotations_dict[file_id], dtype=segment_dtype) + self.prepared_data['annotations-segments'] = annotations_dict + @property def specifications(self) -> Union[Specifications, Tuple[Specifications]]: # setup metadata on-demand the first time specifications are requested and missing diff --git a/pyannote/audio/tasks/segmentation/multilabel.py b/pyannote/audio/tasks/segmentation/multilabel.py index 66a28e7ba..d721b7825 100644 --- 
a/pyannote/audio/tasks/segmentation/multilabel.py +++ b/pyannote/audio/tasks/segmentation/multilabel.py @@ -281,10 +281,9 @@ def prepare_chunk(self, file_id: int, start_time: float, duration: float): sample = dict() sample["X"], _ = self.model.audio.crop(file, chunk, duration=duration) + # gather all annotations of current file - annotations = self.prepared_data["annotations-segments"][ - self.prepared_data["annotations-segments"]["file_id"] == file_id - ] + annotations = self.prepared_data["annotations-segments"][file_id] # gather all annotations with non-empty intersection with current chunk chunk_annotations = annotations[ diff --git a/pyannote/audio/tasks/segmentation/overlapped_speech_detection.py b/pyannote/audio/tasks/segmentation/overlapped_speech_detection.py index 7249ed0f4..4362bc722 100644 --- a/pyannote/audio/tasks/segmentation/overlapped_speech_detection.py +++ b/pyannote/audio/tasks/segmentation/overlapped_speech_detection.py @@ -172,9 +172,7 @@ def prepare_chunk(self, file_id: int, start_time: float, duration: float): sample["X"], _ = self.model.audio.crop(file, chunk, duration=duration) # gather all annotations of current file - annotations = self.prepared_data["annotations-segments"][ - self.prepared_data["annotations-segments"]["file_id"] == file_id - ] + annotations = self.prepared_data["annotations-segments"][file_id] # gather all annotations with non-empty intersection with current chunk chunk_annotations = annotations[ diff --git a/pyannote/audio/tasks/segmentation/speaker_diarization.py b/pyannote/audio/tasks/segmentation/speaker_diarization.py index 47c5adc63..503ca8612 100644 --- a/pyannote/audio/tasks/segmentation/speaker_diarization.py +++ b/pyannote/audio/tasks/segmentation/speaker_diarization.py @@ -343,9 +343,7 @@ def prepare_chunk(self, file_id: int, start_time: float, duration: float): sample["X"], _ = self.model.audio.crop(file, chunk, duration=duration) # gather all annotations of current file - annotations = 
self.prepared_data["annotations-segments"][ - self.prepared_data["annotations-segments"]["file_id"] == file_id - ] + annotations = self.prepared_data["annotations-segments"][file_id] # gather all annotations with non-empty intersection with current chunk chunk_annotations = annotations[ diff --git a/pyannote/audio/tasks/segmentation/voice_activity_detection.py b/pyannote/audio/tasks/segmentation/voice_activity_detection.py index 183fa2ffc..435c6e175 100644 --- a/pyannote/audio/tasks/segmentation/voice_activity_detection.py +++ b/pyannote/audio/tasks/segmentation/voice_activity_detection.py @@ -154,9 +154,7 @@ def prepare_chunk(self, file_id: int, start_time: float, duration: float): sample["X"], _ = self.model.audio.crop(file, chunk, duration=duration) # gather all annotations of current file - annotations = self.prepared_data["annotations-segments"][ - self.prepared_data["annotations-segments"]["file_id"] == file_id - ] + annotations = self.prepared_data["annotations-segments"][file_id] # gather all annotations with non-empty intersection with current chunk chunk_annotations = annotations[