Skip to content

Commit

Permalink
Implement diarization error rate
Browse files Browse the repository at this point in the history
  • Loading branch information
wilke0818 committed Sep 6, 2024
1 parent a4c4560 commit 0da6a99
Show file tree
Hide file tree
Showing 4 changed files with 195 additions and 0 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ umap-learn = "~=0.5"
scikit-learn = "~=1.5"
nltk = "~=3.8"
vocos = "~=0.1"
pyinterval = "^1.2.0"

[tool.poetry.group.dev]
optional = true
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
"""Top level file exposing speaker diarization utility functions."""

from .utils import calculate_diarization_error_rate # noqa: F401
129 changes: 129 additions & 0 deletions src/senselab/audio/tasks/speaker_diarization_evaluation/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
"""Defines utility functions for evaluating speaker diarization results."""

from typing import Dict, List

from interval import interval as pyinterval

from senselab.utils.data_structures.script_line import ScriptLine


def calculate_diarization_error_rate(
    inference: List[ScriptLine], ground_truth: List[ScriptLine], speaker_mapping: Dict[str, str], detailed: bool = False
) -> Dict[str, float] | float:
    """Computes the diarization error rate (DER).

    Diarization error rate is the sum of the false alarms (when speech is detected but none is there),
    missed detections (when speech is there but not detected), and speaker confusions (when speech is
    attributed to the wrong speaker), divided by the total length of the ground truth speech.
    For more details see: https://docs.kolena.com/metrics/diarization-error-rate/

    Args:
        inference (List[ScriptLine]): the diarization generated as the result from a model
        ground_truth (List[ScriptLine]): annotations that serve as the ground truth diarization
        speaker_mapping (Dict[str, str]): mapping between speakers in inference and in ground_truth
        detailed (bool): whether to include each component that contributed to the overall diarization error rate

    Returns:
        Either the diarization error rate (float) or a dictionary with the keys "false_alarms",
        "missed_detections", "speaker_confusions" and "der". The first three are durations in the same
        unit as the ScriptLine start/end values (typically seconds); "der" is their sum divided by the
        ground truth length.

    Raises:
        ValueError: if a line has no speaker, or if the ground truth covers no time at all (the error
            rate is a ratio over the ground truth length and is undefined in that case).
    """
    inference_interval, speaker_inference_intervals = _collect_speaker_intervals(inference)
    ground_truth_interval, speaker_ground_truth_intervals = _collect_speaker_intervals(ground_truth)

    inference_interval_length = _interval_length(inference_interval)
    ground_truth_length = _interval_length(ground_truth_interval)
    if ground_truth_length == 0:
        # Fail with a clear message instead of an opaque ZeroDivisionError further down.
        raise ValueError("Cannot compute the diarization error rate: the ground truth contains no speech.")

    confusion_rate = _speaker_confusion(
        speaker_inference_intervals,
        speaker_ground_truth_intervals,
        ground_truth_interval,
        speaker_mapping=speaker_mapping,
    )
    false_alarms = _false_alarms(inference_interval, ground_truth_interval, ground_truth_length)
    missed_detections = _missed_detection(inference_interval, ground_truth_interval, inference_interval_length)

    der = (confusion_rate + false_alarms + missed_detections) / ground_truth_length

    if detailed:
        return {
            "false_alarms": false_alarms,
            "missed_detections": missed_detections,
            "speaker_confusions": confusion_rate,
            "der": der,
        }
    return der


def _collect_speaker_intervals(lines: List[ScriptLine]) -> tuple[pyinterval, Dict[str, pyinterval]]:
    """Builds the union of all segments and the per-speaker unions of segments for one diarization."""
    all_speech = pyinterval()
    per_speaker: Dict[str, pyinterval] = {}
    for line in lines:
        # Explicit check rather than `assert`, which is removed when Python runs with -O.
        if not line.speaker:
            raise ValueError("Every ScriptLine needs a speaker to compute the diarization error rate.")
        segment = pyinterval[line.start, line.end]
        all_speech = all_speech | segment
        # Union with an empty interval leaves the segment unchanged, so first-time speakers need no special case.
        per_speaker[line.speaker] = per_speaker.get(line.speaker, pyinterval()) | segment
    return all_speech, per_speaker


def _false_alarms(inference: pyinterval, ground_truth: pyinterval, ground_truth_length: float) -> float:
    """Calculate the amount of false alarms.

    False alarms are the time covered by the inference but not by the ground truth. The union of the
    inference and the ground truth is the ground truth plus exactly that extra time, so the false alarm
    time is the length of the union minus the length of the ground truth. A single union replaces the
    previous loop that rebuilt the union once per inference component.

    Args:
        inference (pyinterval): union of every inferred segment
        ground_truth (pyinterval): union of every ground truth segment
        ground_truth_length (float): length of ``ground_truth``, passed in so it is not recomputed

    Returns:
        The total time the inference adds on top of the ground truth.
    """
    # float() keeps the return type stable when both intervals are empty and the lengths are the int 0.
    return float(_interval_length(inference | ground_truth) - ground_truth_length)


def _missed_detection(inference: pyinterval, ground_truth: pyinterval, inference_length: float) -> float:
    """Calculate amount of missed detections.

    Missed detections are the time covered by the ground truth but not by the inferred diarization. The
    union of the two is the inference plus exactly that extra time, so the missed time is the length of
    the union minus the length of the inference. A single union replaces the previous loop that rebuilt
    the union once per ground truth component.

    Args:
        inference (pyinterval): union of every inferred segment
        ground_truth (pyinterval): union of every ground truth segment
        inference_length (float): length of ``inference``, passed in so it is not recomputed

    Returns:
        The total time the ground truth adds on top of the inference.
    """
    # float() keeps the return type stable when both intervals are empty and the lengths are the int 0.
    return float(_interval_length(ground_truth | inference) - inference_length)


def _speaker_confusion(
    inferred_speaker_intervals: Dict[str, pyinterval],
    true_speaker_intervals: Dict[str, pyinterval],
    ground_truth: pyinterval,
    speaker_mapping: Dict[str, str],
) -> float:
    """Calculate amount of speaker confusion.

    For each inferred speaker, the confusion is the time their inferred segments overlap with any ground
    truth speech minus the time those segments overlap with the ground truth speech of the speaker they
    are mapped to.

    Args:
        inferred_speaker_intervals (Dict[str, pyinterval]): inferred segments, keyed by inferred speaker
        true_speaker_intervals (Dict[str, pyinterval]): ground truth segments, keyed by ground truth speaker
        ground_truth (pyinterval): union of every ground truth segment
        speaker_mapping (Dict[str, str]): inferred speaker -> ground truth speaker. An inferred speaker
            that is missing from the mapping, mapped to an empty value, or mapped to a speaker with no
            ground truth segments has no correct overlap, so all of their overlap with the ground truth
            counts as confusion.

    Returns:
        The total confused time summed over all inferred speakers.
    """
    confusion = 0.0
    for inferred_speaker, inferred_speaker_interval in inferred_speaker_intervals.items():
        total_overlap = _interval_length(inferred_speaker_interval & ground_truth)
        # .get() so that an unmapped speaker behaves like one explicitly mapped to "" instead of raising KeyError.
        equivalent_true_speaker = speaker_mapping.get(inferred_speaker)
        non_confused_overlap = 0.0
        if equivalent_true_speaker and equivalent_true_speaker in true_speaker_intervals:
            ground_truth_speaker_interval = true_speaker_intervals[equivalent_true_speaker]
            non_confused_overlap = _interval_length(inferred_speaker_interval & ground_truth_speaker_interval)
        confusion += total_overlap - non_confused_overlap
    return confusion


def _interval_length(interval: pyinterval) -> float:
"""Calculates the length in time that the interval represents."""
return sum([x[1] - x[0] for x in interval])
62 changes: 62 additions & 0 deletions src/tests/audio/tasks/speaker_diarization_evaluation_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
"""Testing functions for speaker diarization evaluations."""

from senselab.audio.tasks.speaker_diarization_evaluation import calculate_diarization_error_rate
from senselab.utils.data_structures.script_line import ScriptLine


def test_diarization_error_rate_non_existent_speaker() -> None:
    """Checks the DER components when the inference contains a speaker absent from the ground truth.

    The segments are taken from the worked example at:
    https://docs.kolena.com/metrics/diarization-error-rate/#implementation-details
    """
    # (start, end, speaker) triples; speaker "D" only exists in the inference.
    reference_segments = [(0, 10, "A"), (13, 21, "B"), (24, 32, "A"), (32, 40, "C")]
    predicted_segments = [(2, 14, "A"), (14, 15, "C"), (15, 20, "B"), (23, 36, "C"), (36, 40, "D")]

    ground_truth = [
        ScriptLine(start=begin, end=finish, speaker=name) for begin, finish, name in reference_segments
    ]
    inference = [
        ScriptLine(start=begin, end=finish, speaker=name) for begin, finish, name in predicted_segments
    ]

    # "D" is mapped to the empty string because it has no ground truth counterpart.
    mapping = {"A": "A", "B": "B", "C": "C", "D": ""}

    result = calculate_diarization_error_rate(inference, ground_truth, mapping, detailed=True)
    assert isinstance(result, dict)

    expected_components = {"false_alarms": 4, "missed_detections": 3, "speaker_confusions": 14}
    for component, expected_value in expected_components.items():
        assert result[component] == expected_value


def test_diarization_error_rate_undetected_speaker() -> None:
    """Tests speaker diarization error rate when a speaker goes undetected.

    Ground truth speaker "D" never appears in the inference. For these segments the ground truth covers
    20 time units, with 1 unit of false alarms, 3 of missed detections and 4 of speaker confusion, so the
    expected error rate is 8 / 20 = 0.4.

    This example diarization can be found here:
    https://docs.kolena.com/metrics/diarization-error-rate/#example
    """
    import math

    ground_truth = [
        ScriptLine(start=0, end=5, speaker="C"),
        ScriptLine(start=5, end=9, speaker="D"),
        ScriptLine(start=10, end=14, speaker="A"),
        ScriptLine(start=14, end=15, speaker="D"),
        ScriptLine(start=17, end=20, speaker="C"),
        ScriptLine(start=22, end=25, speaker="B"),
    ]

    inference = [
        ScriptLine(start=0, end=8, speaker="C"),
        ScriptLine(start=11, end=15, speaker="A"),
        ScriptLine(start=17, end=21, speaker="C"),
        ScriptLine(start=23, end=25, speaker="B"),
    ]

    speaker_mapping = {"A": "A", "B": "B", "C": "C"}

    diarization = calculate_diarization_error_rate(inference, ground_truth, speaker_mapping)
    # Without detailed=True a bare float is returned; narrow the union type before comparing.
    assert isinstance(diarization, float)
    # The rate is a computed floating-point ratio, so compare with a tolerance instead of exact equality.
    assert math.isclose(diarization, 0.4)

0 comments on commit 0da6a99

Please sign in to comment.