generated from sensein/python-package-template
-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #152 from sensein/implement_der
Implement diaraization error rate
- Loading branch information
Showing
3 changed files
with
119 additions
and
0 deletions.
There are no files selected for viewing
3 changes: 3 additions & 0 deletions
3
src/senselab/audio/tasks/speaker_diarization_evaluation/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
"""Top level file exposing speaker diarization utility functions.""" | ||
|
||
from .utils import calculate_diarization_error_rate # noqa: F401 |
55 changes: 55 additions & 0 deletions
55
src/senselab/audio/tasks/speaker_diarization_evaluation/utils.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
"""Defines utility functions for evaluating speaker diarization results.""" | ||
|
||
from typing import Dict, List | ||
|
||
from pyannote.core import Annotation, Segment | ||
from pyannote.metrics.diarization import DiarizationErrorRate, GreedyDiarizationErrorRate | ||
|
||
from senselab.utils.data_structures.script_line import ScriptLine | ||
|
||
|
||
def calculate_diarization_error_rate( | ||
hypothesis: List[ScriptLine], | ||
reference: List[ScriptLine], | ||
greedy: bool = False, | ||
return_speaker_mapping: bool = False, | ||
detailed: bool = False, | ||
) -> Dict: | ||
"""Computes the diarization error rate (DER). | ||
Diarizztion error rate is the ratio of the sum of the false alarms (when speech is detected but none is there), | ||
missed detections (when speech is there but not detected), and speaker confusions (when speech is | ||
attributed to the wrong speaker) to the total ground truth time spoken. For more details see: | ||
https://docs.kolena.com/metrics/diarization-error-rate/ | ||
Args: | ||
hypothesis (List[ScriptLine]): the diarization generated as the result from a model | ||
reference (List[ScriptLine]): annotations that serve as the ground truth diarization | ||
greedy (bool): whether to use a greedy speaker mapping vs. one that optimizes for minimizing the confusion | ||
return_speaker_mapping (bool): return the mapping between speakers in the reference and the hypothesis | ||
detailed (bool): whether to include each component that contributed to the overall diarization error rate | ||
Returns: | ||
A dictionary with at least the diarization error rate, its components if detailed was true, and the | ||
speaker mapping if return_speaker_mapping was given. | ||
""" | ||
hypothesis_annotation = Annotation() | ||
reference_annotation = Annotation() | ||
|
||
for line in hypothesis: | ||
assert line.speaker | ||
hypothesis_annotation[Segment(line.start, line.end)] = line.speaker | ||
|
||
for line in reference: | ||
assert line.speaker | ||
reference_annotation[Segment(line.start, line.end)] = line.speaker | ||
|
||
metric = GreedyDiarizationErrorRate() if greedy else DiarizationErrorRate() | ||
der = metric(reference_annotation, hypothesis_annotation, detailed=detailed) | ||
output = {"diarization error rate": der} if not detailed else der | ||
if return_speaker_mapping: | ||
mapping_fn = metric.greedy_mapping if greedy else metric.optimal_mapping | ||
speaker_mapping = mapping_fn(reference_annotation, hypothesis_annotation) | ||
output["speaker_mapping"] = speaker_mapping | ||
|
||
return output |
61 changes: 61 additions & 0 deletions
61
src/tests/audio/tasks/speaker_diarization_evaluation_test.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
"""Testing functions for speaker diarization evaluations.""" | ||
|
||
from senselab.audio.tasks.speaker_diarization_evaluation import calculate_diarization_error_rate | ||
from senselab.utils.data_structures.script_line import ScriptLine | ||
|
||
|
||
def test_diarization_error_rate_non_existent_speaker() -> None: | ||
"""Tests speaker diarization error rate when a non-existent speaker is found. | ||
This example diarization can be found here: | ||
https://docs.kolena.com/metrics/diarization-error-rate/#implementation-details | ||
""" | ||
ground_truth = [ | ||
ScriptLine(start=0, end=10, speaker="A"), | ||
ScriptLine(start=13, end=21, speaker="B"), | ||
ScriptLine(start=24, end=32, speaker="A"), | ||
ScriptLine(start=32, end=40, speaker="C"), | ||
] | ||
|
||
inference = [ | ||
ScriptLine(start=2, end=14, speaker="A"), | ||
ScriptLine(start=14, end=15, speaker="C"), | ||
ScriptLine(start=15, end=20, speaker="B"), | ||
ScriptLine(start=23, end=36, speaker="C"), | ||
ScriptLine(start=36, end=40, speaker="D"), | ||
] | ||
|
||
speaker_mapping = {"A": "A", "B": "B", "C": "C"} | ||
|
||
diarization = calculate_diarization_error_rate(inference, ground_truth, return_speaker_mapping=True, detailed=True) | ||
assert isinstance(diarization, dict) | ||
assert diarization["false alarm"] == 4 | ||
assert diarization["missed detection"] == 3 | ||
assert diarization["confusion"] == 14 | ||
assert diarization["speaker_mapping"] == speaker_mapping | ||
|
||
|
||
def test_diarization_error_rate_undetected_speaker() -> None: | ||
"""Tests speaker diarization error rate when a speaker goes undetected. | ||
This example diarization can be found here: | ||
https://docs.kolena.com/metrics/diarization-error-rate/#example | ||
""" | ||
ground_truth = [ | ||
ScriptLine(start=0, end=5, speaker="C"), | ||
ScriptLine(start=5, end=9, speaker="D"), | ||
ScriptLine(start=10, end=14, speaker="A"), | ||
ScriptLine(start=14, end=15, speaker="D"), | ||
ScriptLine(start=17, end=20, speaker="C"), | ||
ScriptLine(start=22, end=25, speaker="B"), | ||
] | ||
|
||
inference = [ | ||
ScriptLine(start=0, end=8, speaker="C"), | ||
ScriptLine(start=11, end=15, speaker="A"), | ||
ScriptLine(start=17, end=21, speaker="C"), | ||
ScriptLine(start=23, end=25, speaker="B"), | ||
] | ||
|
||
diarization = calculate_diarization_error_rate(inference, ground_truth, greedy=True) | ||
assert diarization["diarization error rate"] == 0.4 |