diff --git a/scripts/pyannote_31_experiment.py b/scripts/pyannote_31_experiment.py index e8b3e5c2..84905b27 100644 --- a/scripts/pyannote_31_experiment.py +++ b/scripts/pyannote_31_experiment.py @@ -10,8 +10,8 @@ from datasets import load_dataset from senselab.audio.tasks.preprocessing import resample_hf_dataset -from senselab.audio.tasks.pyannote_speaker_diarization_31 import ( - pyannote_diarize_31, +from senselab.audio.tasks.pyannote_speaker_diarization import ( + pyannote_diarize, ) from senselab.utils.tasks.input_output import _from_hf_dataset_to_dict @@ -26,7 +26,10 @@ print("Resampled dataset.") print("Diarizing dataset...") -dataset_diarized = pyannote_diarize_31(dataset, batched=True, batch_size=2) +dataset_diarized = pyannote_diarize(dataset, + batched=True, + batch_size=2, + model_revision="3.1") print("Diarized dataset.") print(json.dumps(dataset_diarized, indent=4)) diff --git a/src/senselab/audio/tasks/pyannote_speaker_diarization_31.py b/src/senselab/audio/tasks/pyannote_speaker_diarization.py similarity index 78% rename from src/senselab/audio/tasks/pyannote_speaker_diarization_31.py rename to src/senselab/audio/tasks/pyannote_speaker_diarization.py index 7165e33f..4d1a0971 100644 --- a/src/senselab/audio/tasks/pyannote_speaker_diarization_31.py +++ b/src/senselab/audio/tasks/pyannote_speaker_diarization.py @@ -1,7 +1,7 @@ """Diarizes a dataset with Pyannote speaker diarization 3.1. If it runs very quickly with little output after running once, delete the -cache and re-run. +cache and re-run. Not tested with models other than speaker diarization 3.1. Pyannote speaker diarization 3.1: https://huggingface.co/pyannote/speaker-diarization-3.1 @@ -39,21 +39,26 @@ def _annotation_to_dict(annotation: Annotation) -> List[Tuple]: return dirization_list -def _pyannote_diarize_31_batch( - batch: Dataset, hf_token: Optional[str] = None - ) -> Dict[str, Any]: +def _pyannote_diarize_batch( + batch: Dataset, + hf_token: Optional[str], + model_name: str, + model_revision: str +) -> Dict[str, Any]: """Diarize a batch of audio files using the Pyannote diarization model. Args: batch: A batch of audio files from a Hugging Face dataset. hf_token: The Hugging Face API token. + model_name: The model name used. + model_revision: The model version used. Returns: A dictionary containing the diarizations for the batch.Becomes a column in the dataset when returned. """ pipeline = Pipeline.from_pretrained( - "pyannote/speaker-diarization-3.1", + model_name + "-" + model_revision, use_auth_token=hf_token) diarizations = [] @@ -80,11 +85,15 @@ def _pyannote_diarize_31_batch( return {'pyannote31_diarizations': diarizations} -def pyannote_diarize_31(dataset: Dict[str, Any], - hf_token: Optional[str] = None, - batched: bool = False, - batch_size: int = 1, - cache_path: str = "scripts/cache/") -> Dict[str, Any]: +def pyannote_diarize( + dataset: Dict[str, Any], + hf_token: Optional[str] = None, + batched: bool = False, + batch_size: int = 1, + cache_path: str = "scripts/cache/", + model_name: str = "pyannote/speaker-diarization", + model_revision: str = "3.1", # 2.0, 3.0, or 3.1 +) -> Dict[str, Any]: """Diarizes the audio files in a Hugging Face dataset. The diarizations are @@ -97,6 +106,8 @@ def pyannote_diarize_31(dataset: Dict[str, Any], batch_size: Number of samples to process in each batch, if batching is enabled. cache_path: The path to the cache directory. + model_name: The model name used. + model_revision: The model version used. Returns: hf_dataset_diarized: The dataset with an added column @@ -107,7 +118,9 @@ def pyannote_diarize_31(dataset: Dict[str, Any], hf_dataset = _from_dict_to_hf_dataset(dataset) hf_dataset_diarized = hf_dataset.map( - lambda x: _pyannote_diarize_31_batch(x, hf_token), + lambda x: _pyannote_diarize_batch(x, hf_token, + model_name, + model_revision), batched=batched, batch_size=batch_size, cache_file_name=cache_path + "pyannote31_cache", diff --git a/src/senselab/audio/tasks/pyannote_speaker_diarization_31_pydra.py b/src/senselab/audio/tasks/pyannote_speaker_diarization_pydra.py similarity index 61% rename from src/senselab/audio/tasks/pyannote_speaker_diarization_31_pydra.py rename to src/senselab/audio/tasks/pyannote_speaker_diarization_pydra.py index 35380586..34c39540 100644 --- a/src/senselab/audio/tasks/pyannote_speaker_diarization_31_pydra.py +++ b/src/senselab/audio/tasks/pyannote_speaker_diarization_pydra.py @@ -2,6 +2,6 @@ import pydra -from senselab.audio.tasks.pyannote_speaker_diarization_31 import pyannote_diarize_31 +from senselab.audio.tasks.pyannote_speaker_diarization import pyannote_diarize -transcribe_dataset_with_hf_pt = pydra.mark.task(pyannote_diarize_31) \ No newline at end of file +transcribe_dataset_with_hf_pt = pydra.mark.task(pyannote_diarize) \ No newline at end of file