quality_control: refactor confusion matrix initialization
The existing code creates a row and column for each category in the dataset,
plus an extra row/column at the end for unmatched annotations. However, it's
not immediately clear what order these rows/columns are in, because all
access is done indirectly through the `confusion_matrix_labels_rmap`
dictionary.

IMO, this is an unnecessary level of indirection, which obscures the
structure of the matrix. It's easier to just assign rows/columns 0..(N-1) to
the categories and N to "unmatched".

Do just that. Additionally, factor out the code creating the matrix into a
function, because it's used in two places.
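To illustrate the layout this switches to (a minimal sketch with made-up labels, not code from this commit): category indices address the matrix directly, and the trailing row/column collects unmatched annotations.

    import numpy as np

    label_names = ["cat", "dog"]     # N = 2 categories, rows/columns 0 and 1
    label_names.append("unmatched")  # row/column N for unmatched annotations
    matrix = np.zeros((len(label_names), len(label_names)), dtype=int)

    # new scheme: use the label index as-is
    matrix[0, 1] += 1  # a "cat" annotation matched against a "dog" annotation

    # old scheme, for contrast, went through a reverse map first:
    # matrix[rmap[0], rmap[1]] += 1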
SpecLad committed Mar 15, 2024
1 parent f513aa1 commit fa334e9
Showing 1 changed file with 32 additions and 46 deletions.
cvat/apps/quality_control/quality_reports.py
@@ -1874,26 +1874,16 @@ def _find_closest_unmatched_shape(shape: dm.Annotation):
         invalid_labels_count = len(mismatches)
         total_labels_count = valid_labels_count + invalid_labels_count
 
-        confusion_matrix_labels = {
-            i: label.name
-            for i, label in enumerate(self._gt_dataset.categories()[dm.AnnotationType.label])
-            if not label.parent
-        }
-        confusion_matrix_labels[None] = "unmatched"
-        confusion_matrix_labels_rmap = {k: i for i, k in enumerate(confusion_matrix_labels.keys())}
-        confusion_matrix_label_count = len(confusion_matrix_labels)
-        confusion_matrix = np.zeros(
-            (confusion_matrix_label_count, confusion_matrix_label_count), dtype=int
-        )
+        confusion_matrix_labels, confusion_matrix = self._make_zero_confusion_matrix()
         for gt_ann, ds_ann in itertools.chain(
             # fully matched annotations - shape, label, attributes
             matches,
             mismatches,
             zip(itertools.repeat(None), ds_unmatched),
             zip(gt_unmatched, itertools.repeat(None)),
         ):
-            ds_label_idx = confusion_matrix_labels_rmap[ds_ann.label if ds_ann else ds_ann]
-            gt_label_idx = confusion_matrix_labels_rmap[gt_ann.label if gt_ann else gt_ann]
+            ds_label_idx = ds_ann.label if ds_ann else self._UNMATCHED_IDX
+            gt_label_idx = gt_ann.label if gt_ann else self._UNMATCHED_IDX
             confusion_matrix[ds_label_idx, gt_label_idx] += 1
 
         matched_ann_counts = np.diag(confusion_matrix)
@@ -1906,15 +1896,11 @@ def _find_closest_unmatched_shape(shape: dm.Annotation):
         label_recalls = _arr_div(matched_ann_counts, gt_ann_counts)
 
         valid_annotations_count = np.sum(matched_ann_counts)
-        missing_annotations_count = np.sum(confusion_matrix[confusion_matrix_labels_rmap[None], :])
-        extra_annotations_count = np.sum(confusion_matrix[:, confusion_matrix_labels_rmap[None]])
+        missing_annotations_count = np.sum(confusion_matrix[self._UNMATCHED_IDX, :])
+        extra_annotations_count = np.sum(confusion_matrix[:, self._UNMATCHED_IDX])
         total_annotations_count = np.sum(confusion_matrix)
-        ds_annotations_count = (
-            np.sum(ds_ann_counts) - ds_ann_counts[confusion_matrix_labels_rmap[None]]
-        )
-        gt_annotations_count = (
-            np.sum(gt_ann_counts) - gt_ann_counts[confusion_matrix_labels_rmap[None]]
-        )
+        ds_annotations_count = np.sum(ds_ann_counts[: self._UNMATCHED_IDX])
+        gt_annotations_count = np.sum(gt_ann_counts[: self._UNMATCHED_IDX])
 
         self._frame_results[frame_id] = ComparisonReportFrameSummary(
             annotations=ComparisonReportAnnotationsSummary(
@@ -1925,7 +1911,7 @@ def _find_closest_unmatched_shape(shape: dm.Annotation):
                 ds_count=ds_annotations_count,
                 gt_count=gt_annotations_count,
                 confusion_matrix=ConfusionMatrix(
-                    labels=list(confusion_matrix_labels.values()),
+                    labels=confusion_matrix_labels,
                     rows=confusion_matrix,
                     precision=label_precisions,
                     recall=label_recalls,
@@ -1953,6 +1939,22 @@ def _find_closest_unmatched_shape(shape: dm.Annotation):
 
         return conflicts
 
+    # row/column index in the confusion matrix corresponding to unmatched annotations
+    _UNMATCHED_IDX = -1
+
+    def _make_zero_confusion_matrix(self) -> Tuple[List[str], np.ndarray]:
+        label_names = [
+            label.name
+            for label in self._gt_dataset.categories()[dm.AnnotationType.label]
+            if not label.parent
+        ]
+        label_names.append("unmatched")
+        num_labels = len(label_names)
+
+        confusion_matrix = np.zeros((num_labels, num_labels), dtype=int)
+
+        return label_names, confusion_matrix
+
     def generate_report(self) -> ComparisonReport:
         self._find_gt_conflicts()
 
@@ -1985,7 +1987,8 @@ def generate_report(self) -> ComparisonReport:
             ),
         )
         mean_ious = []
-        confusion_matrices = []
+        confusion_matrix_labels, confusion_matrix = self._make_zero_confusion_matrix()
+
         for frame_id, frame_result in self._frame_results.items():
             intersection_frames.append(frame_id)
             conflicts += frame_result.conflicts
@@ -1994,27 +1997,14 @@
                 annotations = deepcopy(frame_result.annotations)
             else:
                 annotations.accumulate(frame_result.annotations)
-            confusion_matrices.append(frame_result.annotations.confusion_matrix.rows)
+            confusion_matrix += frame_result.annotations.confusion_matrix.rows
 
             if annotation_components is None:
                 annotation_components = deepcopy(frame_result.annotation_components)
             else:
                 annotation_components.accumulate(frame_result.annotation_components)
             mean_ious.append(frame_result.annotation_components.shape.mean_iou)
 
-        confusion_matrix_labels = {
-            i: label.name
-            for i, label in enumerate(self._gt_dataset.categories()[dm.AnnotationType.label])
-            if not label.parent
-        }
-        confusion_matrix_labels[None] = "unmatched"
-        confusion_matrix_labels_rmap = {k: i for i, k in enumerate(confusion_matrix_labels.keys())}
-        if confusion_matrices:
-            confusion_matrix = np.sum(confusion_matrices, axis=0)
-        else:
-            confusion_matrix = np.zeros(
-                (len(confusion_matrix_labels), len(confusion_matrix_labels)), dtype=int
-            )
         matched_ann_counts = np.diag(confusion_matrix)
         ds_ann_counts = np.sum(confusion_matrix, axis=1)
         gt_ann_counts = np.sum(confusion_matrix, axis=0)
@@ -2025,15 +2015,11 @@
         label_recalls = _arr_div(matched_ann_counts, gt_ann_counts)
 
         valid_annotations_count = np.sum(matched_ann_counts)
-        missing_annotations_count = np.sum(confusion_matrix[confusion_matrix_labels_rmap[None], :])
-        extra_annotations_count = np.sum(confusion_matrix[:, confusion_matrix_labels_rmap[None]])
+        missing_annotations_count = np.sum(confusion_matrix[self._UNMATCHED_IDX, :])
+        extra_annotations_count = np.sum(confusion_matrix[:, self._UNMATCHED_IDX])
         total_annotations_count = np.sum(confusion_matrix)
-        ds_annotations_count = (
-            np.sum(ds_ann_counts) - ds_ann_counts[confusion_matrix_labels_rmap[None]]
-        )
-        gt_annotations_count = (
-            np.sum(gt_ann_counts) - gt_ann_counts[confusion_matrix_labels_rmap[None]]
-        )
+        ds_annotations_count = np.sum(ds_ann_counts[: self._UNMATCHED_IDX])
+        gt_annotations_count = np.sum(gt_ann_counts[: self._UNMATCHED_IDX])
 
         return ComparisonReport(
             parameters=self.settings,
@@ -2058,7 +2044,7 @@ def generate_report(self) -> ComparisonReport:
                 ds_count=ds_annotations_count,
                 gt_count=gt_annotations_count,
                 confusion_matrix=ConfusionMatrix(
-                    labels=list(confusion_matrix_labels.values()),
+                    labels=confusion_matrix_labels,
                     rows=confusion_matrix,
                     precision=label_precisions,
                     recall=label_recalls,
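A side note on the new _UNMATCHED_IDX = -1 constant (my reading, not part of the commit message): because "unmatched" is appended last, NumPy's negative indexing makes -1 address exactly that row/column, and slicing with [: -1] drops exactly that entry.

    import numpy as np

    ann_counts = np.array([5, 3, 2])  # two real labels plus the trailing "unmatched" entry
    UNMATCHED_IDX = -1

    assert ann_counts[UNMATCHED_IDX] == 2            # unmatched count is the last element
    assert np.sum(ann_counts[:UNMATCHED_IDX]) == 8   # summing up to -1 excludes it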
