diff --git a/pyproject.toml b/pyproject.toml index 1c235711..3df48c7e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -57,7 +57,6 @@ umap-learn = "~=0.5" scikit-learn = "~=1.5" nltk = "~=3.8" vocos = "~=0.1" -pyinterval = "^1.2.0" [tool.poetry.group.dev] optional = true diff --git a/src/senselab/audio/tasks/speaker_diarization_evaluation/utils.py b/src/senselab/audio/tasks/speaker_diarization_evaluation/utils.py index 581d9599..2ff3fc1e 100644 --- a/src/senselab/audio/tasks/speaker_diarization_evaluation/utils.py +++ b/src/senselab/audio/tasks/speaker_diarization_evaluation/utils.py @@ -2,128 +2,54 @@ from typing import Dict, List -from interval import interval as pyinterval +from pyannote.core import Annotation, Segment +from pyannote.metrics.diarization import DiarizationErrorRate, GreedyDiarizationErrorRate from senselab.utils.data_structures.script_line import ScriptLine def calculate_diarization_error_rate( - inference: List[ScriptLine], ground_truth: List[ScriptLine], speaker_mapping: Dict[str, str], detailed: bool = False -) -> Dict[str, float] | float: + hypothesis: List[ScriptLine], + reference: List[ScriptLine], + greedy: bool = False, + return_speaker_mapping: bool = False, + detailed: bool = False, +) -> Dict: """Computes the diarization error rate (DER). - Diarizztion error rate is the sum of the false alarms (when speech is detected but none is there), + Diarizztion error rate is the ratio of the sum of the false alarms (when speech is detected but none is there), missed detections (when speech is there but not detected), and speaker confusions (when speech is - attributed to the wrong speaker). For more details see: + attributed to the wrong speaker) to the total ground truth time spoken. For more details see: https://docs.kolena.com/metrics/diarization-error-rate/ Args: - inference (List[ScriptLine]): the diarization generated as the result from a model - ground_truth (List[ScriptLine]): annotations that serve as the ground truth diarization - speaker_mapping (Dict[str, str]): mapping between speakers in inference and in ground_truth + hypothesis (List[ScriptLine]): the diarization generated as the result from a model + reference (List[ScriptLine]): annotations that serve as the ground truth diarization + greedy (bool): whether to use a greedy speaker mapping vs. one that optimizes for minimizing the confusion + return_speaker_mapping (bool): return the mapping between speakers in the reference and the hypothesis detailed (bool): whether to include each component that contributed to the overall diarization error rate Returns: - Either the diarization error rate (float) or a dictionary containing the diarization error rate and its - individual components. The individual components typically have units of seconds. + A dictionary with at least the diarization error rate, its components if detailed was true, and the + speaker mapping if return_speaker_mapping was given. """ - inference_interval = pyinterval() - ground_truth_interval = pyinterval() - speaker_inference_intervals = {} - speaker_ground_truth_intervals = {} + hypothesis_annotation = Annotation() + reference_annotation = Annotation() - for line in inference: + for line in hypothesis: assert line.speaker - inference_interval = inference_interval | pyinterval[line.start, line.end] - if line.speaker not in speaker_inference_intervals: - speaker_inference_intervals[line.speaker] = pyinterval[line.start, line.end] - else: - tmp = speaker_inference_intervals[line.speaker] | pyinterval[line.start, line.end] - speaker_inference_intervals[line.speaker] = tmp + hypothesis_annotation[Segment(line.start, line.end)] = line.speaker - for line in ground_truth: + for line in reference: assert line.speaker - ground_truth_interval = ground_truth_interval | pyinterval[line.start, line.end] - if line.speaker not in speaker_ground_truth_intervals: - speaker_ground_truth_intervals[line.speaker] = pyinterval[line.start, line.end] - else: - tmp = speaker_ground_truth_intervals[line.speaker] | pyinterval[line.start, line.end] - speaker_ground_truth_intervals[line.speaker] = tmp + reference_annotation[Segment(line.start, line.end)] = line.speaker - inference_interval_length = _interval_length(inference_interval) - ground_truth_length = _interval_length(ground_truth_interval) - confusion_rate = _speaker_confusion( - speaker_inference_intervals, - speaker_ground_truth_intervals, - ground_truth_interval, - speaker_mapping=speaker_mapping, - ) - false_alarms = _false_alarms(inference_interval, ground_truth_interval, ground_truth_length) - missed_detections = _missed_detection(inference_interval, ground_truth_interval, inference_interval_length) + metric = GreedyDiarizationErrorRate() if greedy else DiarizationErrorRate() + der = metric(reference_annotation, hypothesis_annotation, detailed=detailed) + output = {"diarization error rate": der} if not detailed else der + if return_speaker_mapping: + mapping_fn = metric.greedy_mapping if greedy else metric.optimal_mapping + speaker_mapping = mapping_fn(reference_annotation, hypothesis_annotation) + output["speaker_mapping"] = speaker_mapping - der = (confusion_rate + false_alarms + missed_detections) / ground_truth_length - - if detailed: - return { - "false_alarms": false_alarms, - "missed_detections": missed_detections, - "speaker_confusions": confusion_rate, - "der": der, - } - else: - return der - - -def _false_alarms(inference: pyinterval, ground_truth: pyinterval, ground_truth_length: float) -> float: - """Calculate the amount of false alarms. - - Calculates the amount of false alarms comparing the total amount of time the union of each - inference and the ground truth adds. - """ - false_alarms = 0.0 - # print(ground_truth) - for interval in inference.components: - # print(interval) - false_alarms += _interval_length(interval | ground_truth) - ground_truth_length - return false_alarms - - -def _missed_detection(inference: pyinterval, ground_truth: pyinterval, inference_length: float) -> float: - """Calculate amount of missed detections. - - Calculates the amount of missed detections by comparing the total amount of time the union of each - ground truth segment and inferred diariazion adds. - """ - missed_detections = 0.0 - for interval in ground_truth.components: - missed_detections += _interval_length(interval | inference) - inference_length - return missed_detections - - -def _speaker_confusion( - inferred_speaker_intervals: Dict[str, pyinterval], - true_speaker_intervals: Dict[str, pyinterval], - ground_truth: pyinterval, - speaker_mapping: Dict[str, str], -) -> float: - """Calculate amount of speaker confusion. - - Calculates the amount of speaker confusion by testing for each inferred speaker the amount of time - that is added when their inferred speech segments are intersected with their ground truth segments vs. - when they are intersected with the entire ground truth. - """ - confusion = 0.0 - for inferred_speaker, inferred_speaker_interval in inferred_speaker_intervals.items(): - total_overlap = _interval_length(inferred_speaker_interval & ground_truth) - equivalent_true_speaker = speaker_mapping[inferred_speaker] - non_confused_overlap = 0.0 - if equivalent_true_speaker: - ground_truth_speaker_interval = true_speaker_intervals[equivalent_true_speaker] - non_confused_overlap = _interval_length(inferred_speaker_interval & ground_truth_speaker_interval) - confusion += total_overlap - non_confused_overlap - return confusion - - -def _interval_length(interval: pyinterval) -> float: - """Calculates the length in time that the interval represents.""" - return sum([x[1] - x[0] for x in interval]) + return output diff --git a/src/tests/audio/tasks/speaker_diarization_evaluation_test.py b/src/tests/audio/tasks/speaker_diarization_evaluation_test.py index 6d3dea2b..6ca0bd76 100644 --- a/src/tests/audio/tasks/speaker_diarization_evaluation_test.py +++ b/src/tests/audio/tasks/speaker_diarization_evaluation_test.py @@ -25,13 +25,14 @@ def test_diarization_error_rate_non_existent_speaker() -> None: ScriptLine(start=36, end=40, speaker="D"), ] - speaker_mapping = {"A": "A", "B": "B", "C": "C", "D": ""} + speaker_mapping = {"A": "A", "B": "B", "C": "C"} - diarization = calculate_diarization_error_rate(inference, ground_truth, speaker_mapping, detailed=True) + diarization = calculate_diarization_error_rate(inference, ground_truth, return_speaker_mapping=True, detailed=True) assert isinstance(diarization, dict) - assert diarization["false_alarms"] == 4 - assert diarization["missed_detections"] == 3 - assert diarization["speaker_confusions"] == 14 + assert diarization["false alarm"] == 4 + assert diarization["missed detection"] == 3 + assert diarization["confusion"] == 14 + assert diarization["speaker_mapping"] == speaker_mapping def test_diarization_error_rate_undetected_speaker() -> None: @@ -56,7 +57,5 @@ def test_diarization_error_rate_undetected_speaker() -> None: ScriptLine(start=23, end=25, speaker="B"), ] - speaker_mapping = {"A": "A", "B": "B", "C": "C"} - - diarization = calculate_diarization_error_rate(inference, ground_truth, speaker_mapping) - assert diarization == 0.4 + diarization = calculate_diarization_error_rate(inference, ground_truth, greedy=True) + assert diarization["diarization error rate"] == 0.4