Skip to content

Commit

Permalink
Add model name and revision to args
Browse files Browse the repository at this point in the history
  • Loading branch information
ibevers committed May 28, 2024
1 parent 592bfb4 commit e3ac8f6
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 16 deletions.
9 changes: 6 additions & 3 deletions scripts/pyannote_31_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
from datasets import load_dataset

from senselab.audio.tasks.preprocessing import resample_hf_dataset
from senselab.audio.tasks.pyannote_speaker_diarization_31 import (
pyannote_diarize_31,
from senselab.audio.tasks.pyannote_speaker_diarization import (
pyannote_diarize,
)
from senselab.utils.tasks.input_output import _from_hf_dataset_to_dict

Expand All @@ -26,7 +26,10 @@
print("Resampled dataset.")

print("Diarizing dataset...")
dataset_diarized = pyannote_diarize_31(dataset, batched=True, batch_size=2)
dataset_diarized = pyannote_diarize(dataset,
batched=True,
batch_size=2,
model_revision="3.1")
print("Diarized dataset.")

print(json.dumps(dataset_diarized, indent=4))
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Diarizes a dataset with Pyannote speaker diarization 3.1.
If it runs very quickly with little output after running once, delete the
cache and re-run.
cache and re-run. Not tested with models other than speaker diarization 3.1.
Pyannote speaker diarization 3.1:
https://huggingface.co/pyannote/speaker-diarization-3.1
Expand Down Expand Up @@ -39,21 +39,26 @@ def _annotation_to_dict(annotation: Annotation) -> List[Tuple]:
return dirization_list


def _pyannote_diarize_31_batch(
batch: Dataset, hf_token: Optional[str] = None
) -> Dict[str, Any]:
def _pyannote_diarize_batch(
batch: Dataset,
hf_token: Optional[str],
model_name: str,
model_revision: str
) -> Dict[str, Any]:
"""Diarize a batch of audio files using the Pyannote diarization model.
Args:
batch: A batch of audio files from a Hugging Face dataset.
hf_token: The Hugging Face API token.
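model_name: The Hugging Face model ID prefix, e.g.
"pyannote/speaker-diarization".
model_revision: The model version, e.g. "3.1". It is appended to
model_name after a hyphen to form the full model ID.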
Returns:
A dictionary containing the diarizations for the batch. Becomes a
column in the dataset when returned.
"""
pipeline = Pipeline.from_pretrained(
"pyannote/speaker-diarization-3.1",
model_name + "-" + model_revision,
use_auth_token=hf_token)

diarizations = []
Expand All @@ -80,11 +85,15 @@ def _pyannote_diarize_31_batch(
return {'pyannote31_diarizations': diarizations}


def pyannote_diarize_31(dataset: Dict[str, Any],
hf_token: Optional[str] = None,
batched: bool = False,
batch_size: int = 1,
cache_path: str = "scripts/cache/") -> Dict[str, Any]:
def pyannote_diarize(
dataset: Dict[str, Any],
hf_token: Optional[str] = None,
batched: bool = False,
batch_size: int = 1,
cache_path: str = "scripts/cache/",
model_name: str = "pyannote/speaker-diarization",
model_revision: str = "3.1", # 2.0, 3.0, or 3.1
) -> Dict[str, Any]:
"""Diarizes the audio files in a Hugging Face dataset.
The diarizations are
Expand All @@ -97,6 +106,8 @@ def pyannote_diarize_31(dataset: Dict[str, Any],
batch_size: Number of samples to process in each batch,
if batching is enabled.
cache_path: The path to the cache directory.
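model_name: The Hugging Face model ID prefix, e.g.
"pyannote/speaker-diarization".
model_revision: The model version, e.g. "3.1". It is joined to
model_name with a hyphen to form the model ID that is loaded.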
Returns:
hf_dataset_diarized: The dataset with an added column
Expand All @@ -107,7 +118,9 @@ def pyannote_diarize_31(dataset: Dict[str, Any],

hf_dataset = _from_dict_to_hf_dataset(dataset)
hf_dataset_diarized = hf_dataset.map(
lambda x: _pyannote_diarize_31_batch(x, hf_token),
lambda x: _pyannote_diarize_batch(x, hf_token,
model_name,
model_revision),
batched=batched,
batch_size=batch_size,
cache_file_name=cache_path + "pyannote31_cache",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@

import pydra

from senselab.audio.tasks.pyannote_speaker_diarization_31 import pyannote_diarize_31
from senselab.audio.tasks.pyannote_speaker_diarization import pyannote_diarize

transcribe_dataset_with_hf_pt = pydra.mark.task(pyannote_diarize_31)
transcribe_dataset_with_hf_pt = pydra.mark.task(pyannote_diarize)

0 comments on commit e3ac8f6

Please sign in to comment.