Skip to content

Commit

Permalink
Merge pull request #152 from sensein/implement_der
Browse files Browse the repository at this point in the history
Implement diaraization error rate
  • Loading branch information
fabiocat93 authored Sep 12, 2024
2 parents a4c4560 + 34ed72a commit 861b090
Show file tree
Hide file tree
Showing 3 changed files with 119 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
"""Top level file exposing speaker diarization utility functions."""

from .utils import calculate_diarization_error_rate # noqa: F401
55 changes: 55 additions & 0 deletions src/senselab/audio/tasks/speaker_diarization_evaluation/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
"""Defines utility functions for evaluating speaker diarization results."""

from typing import Dict, List

from pyannote.core import Annotation, Segment
from pyannote.metrics.diarization import DiarizationErrorRate, GreedyDiarizationErrorRate

from senselab.utils.data_structures.script_line import ScriptLine


def calculate_diarization_error_rate(
hypothesis: List[ScriptLine],
reference: List[ScriptLine],
greedy: bool = False,
return_speaker_mapping: bool = False,
detailed: bool = False,
) -> Dict:
"""Computes the diarization error rate (DER).
Diarizztion error rate is the ratio of the sum of the false alarms (when speech is detected but none is there),
missed detections (when speech is there but not detected), and speaker confusions (when speech is
attributed to the wrong speaker) to the total ground truth time spoken. For more details see:
https://docs.kolena.com/metrics/diarization-error-rate/
Args:
hypothesis (List[ScriptLine]): the diarization generated as the result from a model
reference (List[ScriptLine]): annotations that serve as the ground truth diarization
greedy (bool): whether to use a greedy speaker mapping vs. one that optimizes for minimizing the confusion
return_speaker_mapping (bool): return the mapping between speakers in the reference and the hypothesis
detailed (bool): whether to include each component that contributed to the overall diarization error rate
Returns:
A dictionary with at least the diarization error rate, its components if detailed was true, and the
speaker mapping if return_speaker_mapping was given.
"""
hypothesis_annotation = Annotation()
reference_annotation = Annotation()

for line in hypothesis:
assert line.speaker
hypothesis_annotation[Segment(line.start, line.end)] = line.speaker

for line in reference:
assert line.speaker
reference_annotation[Segment(line.start, line.end)] = line.speaker

metric = GreedyDiarizationErrorRate() if greedy else DiarizationErrorRate()
der = metric(reference_annotation, hypothesis_annotation, detailed=detailed)
output = {"diarization error rate": der} if not detailed else der
if return_speaker_mapping:
mapping_fn = metric.greedy_mapping if greedy else metric.optimal_mapping
speaker_mapping = mapping_fn(reference_annotation, hypothesis_annotation)
output["speaker_mapping"] = speaker_mapping

return output
61 changes: 61 additions & 0 deletions src/tests/audio/tasks/speaker_diarization_evaluation_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
"""Testing functions for speaker diarization evaluations."""

from senselab.audio.tasks.speaker_diarization_evaluation import calculate_diarization_error_rate
from senselab.utils.data_structures.script_line import ScriptLine


def test_diarization_error_rate_non_existent_speaker() -> None:
"""Tests speaker diarization error rate when a non-existent speaker is found.
This example diarization can be found here:
https://docs.kolena.com/metrics/diarization-error-rate/#implementation-details
"""
ground_truth = [
ScriptLine(start=0, end=10, speaker="A"),
ScriptLine(start=13, end=21, speaker="B"),
ScriptLine(start=24, end=32, speaker="A"),
ScriptLine(start=32, end=40, speaker="C"),
]

inference = [
ScriptLine(start=2, end=14, speaker="A"),
ScriptLine(start=14, end=15, speaker="C"),
ScriptLine(start=15, end=20, speaker="B"),
ScriptLine(start=23, end=36, speaker="C"),
ScriptLine(start=36, end=40, speaker="D"),
]

speaker_mapping = {"A": "A", "B": "B", "C": "C"}

diarization = calculate_diarization_error_rate(inference, ground_truth, return_speaker_mapping=True, detailed=True)
assert isinstance(diarization, dict)
assert diarization["false alarm"] == 4
assert diarization["missed detection"] == 3
assert diarization["confusion"] == 14
assert diarization["speaker_mapping"] == speaker_mapping


def test_diarization_error_rate_undetected_speaker() -> None:
"""Tests speaker diarization error rate when a speaker goes undetected.
This example diarization can be found here:
https://docs.kolena.com/metrics/diarization-error-rate/#example
"""
ground_truth = [
ScriptLine(start=0, end=5, speaker="C"),
ScriptLine(start=5, end=9, speaker="D"),
ScriptLine(start=10, end=14, speaker="A"),
ScriptLine(start=14, end=15, speaker="D"),
ScriptLine(start=17, end=20, speaker="C"),
ScriptLine(start=22, end=25, speaker="B"),
]

inference = [
ScriptLine(start=0, end=8, speaker="C"),
ScriptLine(start=11, end=15, speaker="A"),
ScriptLine(start=17, end=21, speaker="C"),
ScriptLine(start=23, end=25, speaker="B"),
]

diarization = calculate_diarization_error_rate(inference, ground_truth, greedy=True)
assert diarization["diarization error rate"] == 0.4

0 comments on commit 861b090

Please sign in to comment.