Skip to content

Commit

Permalink
Reimplement diarization error rate using pyannote
Browse files Browse the repository at this point in the history
  • Loading branch information
wilke0818 committed Sep 9, 2024
1 parent 0da6a99 commit 34ed72a
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 114 deletions.
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@ umap-learn = "~=0.5"
scikit-learn = "~=1.5"
nltk = "~=3.8"
vocos = "~=0.1"
pyinterval = "^1.2.0"

[tool.poetry.group.dev]
optional = true
Expand Down
134 changes: 30 additions & 104 deletions src/senselab/audio/tasks/speaker_diarization_evaluation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,128 +2,54 @@

from typing import Dict, List

from interval import interval as pyinterval
from pyannote.core import Annotation, Segment
from pyannote.metrics.diarization import DiarizationErrorRate, GreedyDiarizationErrorRate

from senselab.utils.data_structures.script_line import ScriptLine


def calculate_diarization_error_rate(
inference: List[ScriptLine], ground_truth: List[ScriptLine], speaker_mapping: Dict[str, str], detailed: bool = False
) -> Dict[str, float] | float:
hypothesis: List[ScriptLine],
reference: List[ScriptLine],
greedy: bool = False,
return_speaker_mapping: bool = False,
detailed: bool = False,
) -> Dict:
"""Computes the diarization error rate (DER).
Diarizztion error rate is the sum of the false alarms (when speech is detected but none is there),
Diarizztion error rate is the ratio of the sum of the false alarms (when speech is detected but none is there),
missed detections (when speech is there but not detected), and speaker confusions (when speech is
attributed to the wrong speaker). For more details see:
attributed to the wrong speaker) to the total ground truth time spoken. For more details see:
https://docs.kolena.com/metrics/diarization-error-rate/
Args:
inference (List[ScriptLine]): the diarization generated as the result from a model
ground_truth (List[ScriptLine]): annotations that serve as the ground truth diarization
speaker_mapping (Dict[str, str]): mapping between speakers in inference and in ground_truth
hypothesis (List[ScriptLine]): the diarization generated as the result from a model
reference (List[ScriptLine]): annotations that serve as the ground truth diarization
greedy (bool): whether to use a greedy speaker mapping vs. one that optimizes for minimizing the confusion
return_speaker_mapping (bool): return the mapping between speakers in the reference and the hypothesis
detailed (bool): whether to include each component that contributed to the overall diarization error rate
Returns:
Either the diarization error rate (float) or a dictionary containing the diarization error rate and its
individual components. The individual components typically have units of seconds.
A dictionary with at least the diarization error rate, its components if detailed was true, and the
speaker mapping if return_speaker_mapping was given.
"""
inference_interval = pyinterval()
ground_truth_interval = pyinterval()
speaker_inference_intervals = {}
speaker_ground_truth_intervals = {}
hypothesis_annotation = Annotation()
reference_annotation = Annotation()

for line in inference:
for line in hypothesis:
assert line.speaker
inference_interval = inference_interval | pyinterval[line.start, line.end]
if line.speaker not in speaker_inference_intervals:
speaker_inference_intervals[line.speaker] = pyinterval[line.start, line.end]
else:
tmp = speaker_inference_intervals[line.speaker] | pyinterval[line.start, line.end]
speaker_inference_intervals[line.speaker] = tmp
hypothesis_annotation[Segment(line.start, line.end)] = line.speaker

for line in ground_truth:
for line in reference:
assert line.speaker
ground_truth_interval = ground_truth_interval | pyinterval[line.start, line.end]
if line.speaker not in speaker_ground_truth_intervals:
speaker_ground_truth_intervals[line.speaker] = pyinterval[line.start, line.end]
else:
tmp = speaker_ground_truth_intervals[line.speaker] | pyinterval[line.start, line.end]
speaker_ground_truth_intervals[line.speaker] = tmp
reference_annotation[Segment(line.start, line.end)] = line.speaker

inference_interval_length = _interval_length(inference_interval)
ground_truth_length = _interval_length(ground_truth_interval)
confusion_rate = _speaker_confusion(
speaker_inference_intervals,
speaker_ground_truth_intervals,
ground_truth_interval,
speaker_mapping=speaker_mapping,
)
false_alarms = _false_alarms(inference_interval, ground_truth_interval, ground_truth_length)
missed_detections = _missed_detection(inference_interval, ground_truth_interval, inference_interval_length)
metric = GreedyDiarizationErrorRate() if greedy else DiarizationErrorRate()
der = metric(reference_annotation, hypothesis_annotation, detailed=detailed)
output = {"diarization error rate": der} if not detailed else der
if return_speaker_mapping:
mapping_fn = metric.greedy_mapping if greedy else metric.optimal_mapping
speaker_mapping = mapping_fn(reference_annotation, hypothesis_annotation)
output["speaker_mapping"] = speaker_mapping

der = (confusion_rate + false_alarms + missed_detections) / ground_truth_length

if detailed:
return {
"false_alarms": false_alarms,
"missed_detections": missed_detections,
"speaker_confusions": confusion_rate,
"der": der,
}
else:
return der


def _false_alarms(inference: pyinterval, ground_truth: pyinterval, ground_truth_length: float) -> float:
"""Calculate the amount of false alarms.
Calculates the amount of false alarms comparing the total amount of time the union of each
inference and the ground truth adds.
"""
false_alarms = 0.0
# print(ground_truth)
for interval in inference.components:
# print(interval)
false_alarms += _interval_length(interval | ground_truth) - ground_truth_length
return false_alarms


def _missed_detection(inference: pyinterval, ground_truth: pyinterval, inference_length: float) -> float:
"""Calculate amount of missed detections.
Calculates the amount of missed detections by comparing the total amount of time the union of each
ground truth segment and inferred diariazion adds.
"""
missed_detections = 0.0
for interval in ground_truth.components:
missed_detections += _interval_length(interval | inference) - inference_length
return missed_detections


def _speaker_confusion(
inferred_speaker_intervals: Dict[str, pyinterval],
true_speaker_intervals: Dict[str, pyinterval],
ground_truth: pyinterval,
speaker_mapping: Dict[str, str],
) -> float:
"""Calculate amount of speaker confusion.
Calculates the amount of speaker confusion by testing for each inferred speaker the amount of time
that is added when their inferred speech segments are intersected with their ground truth segments vs.
when they are intersected with the entire ground truth.
"""
confusion = 0.0
for inferred_speaker, inferred_speaker_interval in inferred_speaker_intervals.items():
total_overlap = _interval_length(inferred_speaker_interval & ground_truth)
equivalent_true_speaker = speaker_mapping[inferred_speaker]
non_confused_overlap = 0.0
if equivalent_true_speaker:
ground_truth_speaker_interval = true_speaker_intervals[equivalent_true_speaker]
non_confused_overlap = _interval_length(inferred_speaker_interval & ground_truth_speaker_interval)
confusion += total_overlap - non_confused_overlap
return confusion


def _interval_length(interval: pyinterval) -> float:
"""Calculates the length in time that the interval represents."""
return sum([x[1] - x[0] for x in interval])
return output
17 changes: 8 additions & 9 deletions src/tests/audio/tasks/speaker_diarization_evaluation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,14 @@ def test_diarization_error_rate_non_existent_speaker() -> None:
ScriptLine(start=36, end=40, speaker="D"),
]

speaker_mapping = {"A": "A", "B": "B", "C": "C", "D": ""}
speaker_mapping = {"A": "A", "B": "B", "C": "C"}

diarization = calculate_diarization_error_rate(inference, ground_truth, speaker_mapping, detailed=True)
diarization = calculate_diarization_error_rate(inference, ground_truth, return_speaker_mapping=True, detailed=True)
assert isinstance(diarization, dict)
assert diarization["false_alarms"] == 4
assert diarization["missed_detections"] == 3
assert diarization["speaker_confusions"] == 14
assert diarization["false alarm"] == 4
assert diarization["missed detection"] == 3
assert diarization["confusion"] == 14
assert diarization["speaker_mapping"] == speaker_mapping


def test_diarization_error_rate_undetected_speaker() -> None:
Expand All @@ -56,7 +57,5 @@ def test_diarization_error_rate_undetected_speaker() -> None:
ScriptLine(start=23, end=25, speaker="B"),
]

speaker_mapping = {"A": "A", "B": "B", "C": "C"}

diarization = calculate_diarization_error_rate(inference, ground_truth, speaker_mapping)
assert diarization == 0.4
diarization = calculate_diarization_error_rate(inference, ground_truth, greedy=True)
assert diarization["diarization error rate"] == 0.4

0 comments on commit 34ed72a

Please sign in to comment.