generated from sensein/python-package-template
-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
4 changed files
with
195 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
3 changes: 3 additions & 0 deletions
3
src/senselab/audio/tasks/speaker_diarization_evaluation/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
"""Top level file exposing speaker diarization utility functions.""" | ||
|
||
from .utils import calculate_diarization_error_rate # noqa: F401 |
129 changes: 129 additions & 0 deletions
129
src/senselab/audio/tasks/speaker_diarization_evaluation/utils.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,129 @@ | ||
"""Defines utility functions for evaluating speaker diarization results.""" | ||
|
||
from typing import Dict, List | ||
|
||
from interval import interval as pyinterval | ||
|
||
from senselab.utils.data_structures.script_line import ScriptLine | ||
|
||
|
||
def calculate_diarization_error_rate( | ||
inference: List[ScriptLine], ground_truth: List[ScriptLine], speaker_mapping: Dict[str, str], detailed: bool = False | ||
) -> Dict[str, float] | float: | ||
"""Computes the diarization error rate (DER). | ||
Diarizztion error rate is the sum of the false alarms (when speech is detected but none is there), | ||
missed detections (when speech is there but not detected), and speaker confusions (when speech is | ||
attributed to the wrong speaker). For more details see: | ||
https://docs.kolena.com/metrics/diarization-error-rate/ | ||
Args: | ||
inference (List[ScriptLine]): the diarization generated as the result from a model | ||
ground_truth (List[ScriptLine]): annotations that serve as the ground truth diarization | ||
speaker_mapping (Dict[str, str]): mapping between speakers in inference and in ground_truth | ||
detailed (bool): whether to include each component that contributed to the overall diarization error rate | ||
Returns: | ||
Either the diarization error rate (float) or a dictionary containing the diarization error rate and its | ||
individual components. The individual components typically have units of seconds. | ||
""" | ||
inference_interval = pyinterval() | ||
ground_truth_interval = pyinterval() | ||
speaker_inference_intervals = {} | ||
speaker_ground_truth_intervals = {} | ||
|
||
for line in inference: | ||
assert line.speaker | ||
inference_interval = inference_interval | pyinterval[line.start, line.end] | ||
if line.speaker not in speaker_inference_intervals: | ||
speaker_inference_intervals[line.speaker] = pyinterval[line.start, line.end] | ||
else: | ||
tmp = speaker_inference_intervals[line.speaker] | pyinterval[line.start, line.end] | ||
speaker_inference_intervals[line.speaker] = tmp | ||
|
||
for line in ground_truth: | ||
assert line.speaker | ||
ground_truth_interval = ground_truth_interval | pyinterval[line.start, line.end] | ||
if line.speaker not in speaker_ground_truth_intervals: | ||
speaker_ground_truth_intervals[line.speaker] = pyinterval[line.start, line.end] | ||
else: | ||
tmp = speaker_ground_truth_intervals[line.speaker] | pyinterval[line.start, line.end] | ||
speaker_ground_truth_intervals[line.speaker] = tmp | ||
|
||
inference_interval_length = _interval_length(inference_interval) | ||
ground_truth_length = _interval_length(ground_truth_interval) | ||
confusion_rate = _speaker_confusion( | ||
speaker_inference_intervals, | ||
speaker_ground_truth_intervals, | ||
ground_truth_interval, | ||
speaker_mapping=speaker_mapping, | ||
) | ||
false_alarms = _false_alarms(inference_interval, ground_truth_interval, ground_truth_length) | ||
missed_detections = _missed_detection(inference_interval, ground_truth_interval, inference_interval_length) | ||
|
||
der = (confusion_rate + false_alarms + missed_detections) / ground_truth_length | ||
|
||
if detailed: | ||
return { | ||
"false_alarms": false_alarms, | ||
"missed_detections": missed_detections, | ||
"speaker_confusions": confusion_rate, | ||
"der": der, | ||
} | ||
else: | ||
return der | ||
|
||
|
||
def _false_alarms(inference: pyinterval, ground_truth: pyinterval, ground_truth_length: float) -> float: | ||
"""Calculate the amount of false alarms. | ||
Calculates the amount of false alarms comparing the total amount of time the union of each | ||
inference and the ground truth adds. | ||
""" | ||
false_alarms = 0.0 | ||
# print(ground_truth) | ||
for interval in inference.components: | ||
# print(interval) | ||
false_alarms += _interval_length(interval | ground_truth) - ground_truth_length | ||
return false_alarms | ||
|
||
|
||
def _missed_detection(inference: pyinterval, ground_truth: pyinterval, inference_length: float) -> float: | ||
"""Calculate amount of missed detections. | ||
Calculates the amount of missed detections by comparing the total amount of time the union of each | ||
ground truth segment and inferred diariazion adds. | ||
""" | ||
missed_detections = 0.0 | ||
for interval in ground_truth.components: | ||
missed_detections += _interval_length(interval | inference) - inference_length | ||
return missed_detections | ||
|
||
|
||
def _speaker_confusion( | ||
inferred_speaker_intervals: Dict[str, pyinterval], | ||
true_speaker_intervals: Dict[str, pyinterval], | ||
ground_truth: pyinterval, | ||
speaker_mapping: Dict[str, str], | ||
) -> float: | ||
"""Calculate amount of speaker confusion. | ||
Calculates the amount of speaker confusion by testing for each inferred speaker the amount of time | ||
that is added when their inferred speech segments are intersected with their ground truth segments vs. | ||
when they are intersected with the entire ground truth. | ||
""" | ||
confusion = 0.0 | ||
for inferred_speaker, inferred_speaker_interval in inferred_speaker_intervals.items(): | ||
total_overlap = _interval_length(inferred_speaker_interval & ground_truth) | ||
equivalent_true_speaker = speaker_mapping[inferred_speaker] | ||
non_confused_overlap = 0.0 | ||
if equivalent_true_speaker: | ||
ground_truth_speaker_interval = true_speaker_intervals[equivalent_true_speaker] | ||
non_confused_overlap = _interval_length(inferred_speaker_interval & ground_truth_speaker_interval) | ||
confusion += total_overlap - non_confused_overlap | ||
return confusion | ||
|
||
|
||
def _interval_length(interval: pyinterval) -> float: | ||
"""Calculates the length in time that the interval represents.""" | ||
return sum([x[1] - x[0] for x in interval]) |
62 changes: 62 additions & 0 deletions
62
src/tests/audio/tasks/speaker_diarization_evaluation_test.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
"""Testing functions for speaker diarization evaluations.""" | ||
|
||
from senselab.audio.tasks.speaker_diarization_evaluation import calculate_diarization_error_rate | ||
from senselab.utils.data_structures.script_line import ScriptLine | ||
|
||
|
||
def test_diarization_error_rate_non_existent_speaker() -> None: | ||
"""Tests speaker diarization error rate when a non-existent speaker is found. | ||
This example diarization can be found here: | ||
https://docs.kolena.com/metrics/diarization-error-rate/#implementation-details | ||
""" | ||
ground_truth = [ | ||
ScriptLine(start=0, end=10, speaker="A"), | ||
ScriptLine(start=13, end=21, speaker="B"), | ||
ScriptLine(start=24, end=32, speaker="A"), | ||
ScriptLine(start=32, end=40, speaker="C"), | ||
] | ||
|
||
inference = [ | ||
ScriptLine(start=2, end=14, speaker="A"), | ||
ScriptLine(start=14, end=15, speaker="C"), | ||
ScriptLine(start=15, end=20, speaker="B"), | ||
ScriptLine(start=23, end=36, speaker="C"), | ||
ScriptLine(start=36, end=40, speaker="D"), | ||
] | ||
|
||
speaker_mapping = {"A": "A", "B": "B", "C": "C", "D": ""} | ||
|
||
diarization = calculate_diarization_error_rate(inference, ground_truth, speaker_mapping, detailed=True) | ||
assert isinstance(diarization, dict) | ||
assert diarization["false_alarms"] == 4 | ||
assert diarization["missed_detections"] == 3 | ||
assert diarization["speaker_confusions"] == 14 | ||
|
||
|
||
def test_diarization_error_rate_undetected_speaker() -> None: | ||
"""Tests speaker diarization error rate when a speaker goes undetected. | ||
This example diarization can be found here: | ||
https://docs.kolena.com/metrics/diarization-error-rate/#example | ||
""" | ||
ground_truth = [ | ||
ScriptLine(start=0, end=5, speaker="C"), | ||
ScriptLine(start=5, end=9, speaker="D"), | ||
ScriptLine(start=10, end=14, speaker="A"), | ||
ScriptLine(start=14, end=15, speaker="D"), | ||
ScriptLine(start=17, end=20, speaker="C"), | ||
ScriptLine(start=22, end=25, speaker="B"), | ||
] | ||
|
||
inference = [ | ||
ScriptLine(start=0, end=8, speaker="C"), | ||
ScriptLine(start=11, end=15, speaker="A"), | ||
ScriptLine(start=17, end=21, speaker="C"), | ||
ScriptLine(start=23, end=25, speaker="B"), | ||
] | ||
|
||
speaker_mapping = {"A": "A", "B": "B", "C": "C"} | ||
|
||
diarization = calculate_diarization_error_rate(inference, ground_truth, speaker_mapping) | ||
assert diarization == 0.4 |