From fa334e9e7dfc21fc52bd96613f071e86456be5fd Mon Sep 17 00:00:00 2001
From: Roman Donchenko
Date: Fri, 15 Mar 2024 16:00:33 +0200
Subject: [PATCH] quality_control: refactor confusion matrix initialization

The existing code creates a row and column for each category in the
dataset, plus an extra row/column at the end for unmatched annotations.
However, it's not immediately clear what order these rows/columns are in,
because all access is done indirectly through the
`confusion_matrix_labels_rmap` dictionary.

IMO, this is an unnecessary level of indirection, which obscures the
structure of the matrix. It's easier to just assign rows/columns 0..(N-1)
to the categories and N to "unmatched". Do just that.

Additionally, factor out the code creating the matrix into a function,
because it's used in two places.
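To illustrate the resulting layout (a standalone sketch, not code from
this patch; the label names and top-level variable names here are
invented, but the indexing mirrors what the patch does):

    import numpy as np

    # Hypothetical labels; the real code takes them from the GT dataset.
    label_names = ["cat", "dog"]
    label_names.append("unmatched")  # the trailing row/column N
    UNMATCHED_IDX = -1  # same trick as _UNMATCHED_IDX in the patch

    confusion_matrix = np.zeros((len(label_names), len(label_names)), dtype=int)

    # Rows are DS labels, columns are GT labels.
    confusion_matrix[0, 0] += 1              # matched pair with label 0
    confusion_matrix[1, UNMATCHED_IDX] += 1  # DS annotation without a GT match

    # Slicing with [:UNMATCHED_IDX] drops the "unmatched" entry, so this
    # counts only real DS annotations (2 here).
    ds_ann_counts = np.sum(confusion_matrix, axis=1)
    ds_annotations_count = np.sum(ds_ann_counts[:UNMATCHED_IDX])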
---
 cvat/apps/quality_control/quality_reports.py | 78 ++++++++------------
 1 file changed, 32 insertions(+), 46 deletions(-)

diff --git a/cvat/apps/quality_control/quality_reports.py b/cvat/apps/quality_control/quality_reports.py
index 0c30dffc783..3c2585e3601 100644
--- a/cvat/apps/quality_control/quality_reports.py
+++ b/cvat/apps/quality_control/quality_reports.py
@@ -1874,17 +1874,7 @@ def _find_closest_unmatched_shape(shape: dm.Annotation):
         invalid_labels_count = len(mismatches)
         total_labels_count = valid_labels_count + invalid_labels_count
 
-        confusion_matrix_labels = {
-            i: label.name
-            for i, label in enumerate(self._gt_dataset.categories()[dm.AnnotationType.label])
-            if not label.parent
-        }
-        confusion_matrix_labels[None] = "unmatched"
-        confusion_matrix_labels_rmap = {k: i for i, k in enumerate(confusion_matrix_labels.keys())}
-        confusion_matrix_label_count = len(confusion_matrix_labels)
-        confusion_matrix = np.zeros(
-            (confusion_matrix_label_count, confusion_matrix_label_count), dtype=int
-        )
+        confusion_matrix_labels, confusion_matrix = self._make_zero_confusion_matrix()
         for gt_ann, ds_ann in itertools.chain(
             # fully matched annotations - shape, label, attributes
             matches,
@@ -1892,8 +1882,8 @@ def _find_closest_unmatched_shape(shape: dm.Annotation):
             zip(itertools.repeat(None), ds_unmatched),
             zip(gt_unmatched, itertools.repeat(None)),
         ):
-            ds_label_idx = confusion_matrix_labels_rmap[ds_ann.label if ds_ann else ds_ann]
-            gt_label_idx = confusion_matrix_labels_rmap[gt_ann.label if gt_ann else gt_ann]
+            ds_label_idx = ds_ann.label if ds_ann else self._UNMATCHED_IDX
+            gt_label_idx = gt_ann.label if gt_ann else self._UNMATCHED_IDX
             confusion_matrix[ds_label_idx, gt_label_idx] += 1
 
         matched_ann_counts = np.diag(confusion_matrix)
@@ -1906,15 +1896,11 @@ def _find_closest_unmatched_shape(shape: dm.Annotation):
         label_recalls = _arr_div(matched_ann_counts, gt_ann_counts)
 
         valid_annotations_count = np.sum(matched_ann_counts)
-        missing_annotations_count = np.sum(confusion_matrix[confusion_matrix_labels_rmap[None], :])
-        extra_annotations_count = np.sum(confusion_matrix[:, confusion_matrix_labels_rmap[None]])
+        missing_annotations_count = np.sum(confusion_matrix[self._UNMATCHED_IDX, :])
+        extra_annotations_count = np.sum(confusion_matrix[:, self._UNMATCHED_IDX])
         total_annotations_count = np.sum(confusion_matrix)
-        ds_annotations_count = (
-            np.sum(ds_ann_counts) - ds_ann_counts[confusion_matrix_labels_rmap[None]]
-        )
-        gt_annotations_count = (
-            np.sum(gt_ann_counts) - gt_ann_counts[confusion_matrix_labels_rmap[None]]
-        )
+        ds_annotations_count = np.sum(ds_ann_counts[: self._UNMATCHED_IDX])
+        gt_annotations_count = np.sum(gt_ann_counts[: self._UNMATCHED_IDX])
 
         self._frame_results[frame_id] = ComparisonReportFrameSummary(
             annotations=ComparisonReportAnnotationsSummary(
@@ -1925,7 +1911,7 @@ def _find_closest_unmatched_shape(shape: dm.Annotation):
                 ds_count=ds_annotations_count,
                 gt_count=gt_annotations_count,
                 confusion_matrix=ConfusionMatrix(
-                    labels=list(confusion_matrix_labels.values()),
+                    labels=confusion_matrix_labels,
                     rows=confusion_matrix,
                     precision=label_precisions,
                     recall=label_recalls,
@@ -1953,6 +1939,22 @@ def _find_closest_unmatched_shape(shape: dm.Annotation):
 
         return conflicts
 
+    # row/column index in the confusion matrix corresponding to unmatched annotations
+    _UNMATCHED_IDX = -1
+
+    def _make_zero_confusion_matrix(self) -> Tuple[List[str], np.ndarray]:
+        label_names = [
+            label.name
+            for label in self._gt_dataset.categories()[dm.AnnotationType.label]
+            if not label.parent
+        ]
+        label_names.append("unmatched")
+        num_labels = len(label_names)
+
+        confusion_matrix = np.zeros((num_labels, num_labels), dtype=int)
+
+        return label_names, confusion_matrix
+
     def generate_report(self) -> ComparisonReport:
         self._find_gt_conflicts()
 
@@ -1985,7 +1987,8 @@ def generate_report(self) -> ComparisonReport:
             ),
         )
         mean_ious = []
-        confusion_matrices = []
+        confusion_matrix_labels, confusion_matrix = self._make_zero_confusion_matrix()
+
         for frame_id, frame_result in self._frame_results.items():
             intersection_frames.append(frame_id)
             conflicts += frame_result.conflicts
@@ -1994,7 +1997,7 @@ def generate_report(self) -> ComparisonReport:
                 annotations = deepcopy(frame_result.annotations)
             else:
                 annotations.accumulate(frame_result.annotations)
-            confusion_matrices.append(frame_result.annotations.confusion_matrix.rows)
+            confusion_matrix += frame_result.annotations.confusion_matrix.rows
 
             if annotation_components is None:
                 annotation_components = deepcopy(frame_result.annotation_components)
@@ -2002,19 +2005,6 @@ def generate_report(self) -> ComparisonReport:
                 annotation_components.accumulate(frame_result.annotation_components)
             mean_ious.append(frame_result.annotation_components.shape.mean_iou)
 
-        confusion_matrix_labels = {
-            i: label.name
-            for i, label in enumerate(self._gt_dataset.categories()[dm.AnnotationType.label])
-            if not label.parent
-        }
-        confusion_matrix_labels[None] = "unmatched"
-        confusion_matrix_labels_rmap = {k: i for i, k in enumerate(confusion_matrix_labels.keys())}
-        if confusion_matrices:
-            confusion_matrix = np.sum(confusion_matrices, axis=0)
-        else:
-            confusion_matrix = np.zeros(
-                (len(confusion_matrix_labels), len(confusion_matrix_labels)), dtype=int
-            )
         matched_ann_counts = np.diag(confusion_matrix)
         ds_ann_counts = np.sum(confusion_matrix, axis=1)
         gt_ann_counts = np.sum(confusion_matrix, axis=0)
@@ -2025,15 +2015,11 @@ def generate_report(self) -> ComparisonReport:
         label_recalls = _arr_div(matched_ann_counts, gt_ann_counts)
 
         valid_annotations_count = np.sum(matched_ann_counts)
-        missing_annotations_count = np.sum(confusion_matrix[confusion_matrix_labels_rmap[None], :])
-        extra_annotations_count = np.sum(confusion_matrix[:, confusion_matrix_labels_rmap[None]])
+        missing_annotations_count = np.sum(confusion_matrix[self._UNMATCHED_IDX, :])
+        extra_annotations_count = np.sum(confusion_matrix[:, self._UNMATCHED_IDX])
         total_annotations_count = np.sum(confusion_matrix)
-        ds_annotations_count = (
-            np.sum(ds_ann_counts) - ds_ann_counts[confusion_matrix_labels_rmap[None]]
-        )
-        gt_annotations_count = (
-            np.sum(gt_ann_counts) - gt_ann_counts[confusion_matrix_labels_rmap[None]]
-        )
+        ds_annotations_count = np.sum(ds_ann_counts[: self._UNMATCHED_IDX])
+        gt_annotations_count = np.sum(gt_ann_counts[: self._UNMATCHED_IDX])
 
         return ComparisonReport(
             parameters=self.settings,
@@ -2058,7 +2044,7 @@ def generate_report(self) -> ComparisonReport:
                 ds_count=ds_annotations_count,
                 gt_count=gt_annotations_count,
                 confusion_matrix=ConfusionMatrix(
-                    labels=list(confusion_matrix_labels.values()),
+                    labels=confusion_matrix_labels,
                     rows=confusion_matrix,
                     precision=label_precisions,
                     recall=label_recalls,