Skip to content

Commit

Permalink
using inverse distance as matching method for points
Browse files Browse the repository at this point in the history
  • Loading branch information
bw4sz committed Jan 9, 2025
1 parent c374375 commit c8e4d0f
Show file tree
Hide file tree
Showing 4 changed files with 161 additions and 447 deletions.
194 changes: 39 additions & 155 deletions docs/examples/baseline_boxes.ipynb

Large diffs are not rendered by default.

393 changes: 108 additions & 285 deletions docs/examples/baseline_points.ipynb

Large diffs are not rendered by default.

15 changes: 10 additions & 5 deletions milliontrees/common/metrics/all_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,7 +461,7 @@ class KeypointAccuracy(ElementwiseMetric):
"""Given a specific Intersection over union threshold, determine the
accuracy achieved for a one-class detector."""

def __init__(self, distance_threshold=5, score_threshold=5, name=None, geometry_name="y"):
def __init__(self, distance_threshold=0.1, score_threshold=0.1, name=None, geometry_name="y"):
self.distance_threshold = distance_threshold
self.score_threshold = score_threshold
self.geometry_name = geometry_name
Expand All @@ -478,13 +478,18 @@ def _compute_element_wise(self, y_pred, y_true):

gt_boxes = gt[self.geometry_name]
pred_boxes = target_boxes[target_scores > self.score_threshold]
det_accuracy = torch.mean(torch.stack([ self._accuracy(gt_boxes,pred_boxes,iou_thr) for iou_thr in np.arange(0.5,0.51,0.05)]))
det_accuracy = self._accuracy(gt_boxes,pred_boxes,self.distance_threshold)
batch_results.append(det_accuracy)

return torch.tensor(batch_results)

def _point_iou(self, src_keypoints, pred_keypoints):
return torch.cdist(src_keypoints, pred_keypoints, p=2)
def _point_nearness(self, src_keypoints, pred_keypoints):
distance = torch.cdist(src_keypoints, pred_keypoints, p=2)

# Inverson of distance to get relative distance
relative_distance = 1/distance

return relative_distance

def _accuracy(self, src_keypoints, pred_keypoints, distance_threshold):
total_gt = len(src_keypoints)
Expand All @@ -494,7 +499,7 @@ def _accuracy(self, src_keypoints, pred_keypoints, distance_threshold):
matcher = Matcher(distance_threshold,
distance_threshold,
allow_low_quality_matches=False)
match_quality_matrix = self._point_iou(src_keypoints, pred_keypoints)
match_quality_matrix = self._point_nearness(src_keypoints, pred_keypoints)
results = matcher(match_quality_matrix)
true_positive = torch.count_nonzero(results.unique() != -1)
matched_elements = results[results > -1]
Expand Down
6 changes: 4 additions & 2 deletions milliontrees/datasets/TreePoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,12 @@ def __init__(self,
root_dir='data',
download=False,
split_scheme='official',
geometry_name='y'):
geometry_name='y',
distance_threshold=0.1):
self._version = version
self._split_scheme = split_scheme
self.geometry_name = geometry_name
self.distance_threshold = distance_threshold

if self._split_scheme not in ['official', 'random']:
raise ValueError(
Expand Down Expand Up @@ -120,7 +122,7 @@ def __init__(self,
self._metadata_array = torch.tensor(unique_sources.values.astype('int'))
self._metadata_fields = ['filename_id','source_id']

self._metric = KeypointAccuracy()
self._metric = KeypointAccuracy(distance_threshold=distance_threshold)
self._collate = TreePointsDataset._collate_fn

# eval grouper
Expand Down

0 comments on commit c8e4d0f

Please sign in to comment.