# SPDX-FileCopyrightText: 2025 Idiap Research Institute <[email protected]>
# SPDX-FileContributor: Anshul Gupta <[email protected]>
#
# SPDX-License-Identifier: GPL-3.0-only
import os
import sys
import pandas as pd
from tqdm import tqdm
import numpy as np
from sklearn.metrics import recall_score, precision_score, confusion_matrix
import matplotlib
matplotlib.use('tkagg')
from matplotlib import pyplot as plt
from matplotlib.patches import Rectangle

# area of intersection of bbox1 and bbox2, normalized by
# min(area of bbox1, area of bbox2); bboxes are [xmin, ymin, xmax, ymax]
def get_inter(bbox1, bbox2):
    # the boxes do not overlap at all
    if bbox1[2] < bbox2[0] or bbox2[2] < bbox1[0] or bbox1[3] < bbox2[1] or bbox2[3] < bbox1[1]:
        return 0
    area_inter = (np.min([bbox1[2], bbox2[2]]) - np.max([bbox1[0], bbox2[0]])) * \
                 (np.min([bbox1[3], bbox2[3]]) - np.max([bbox1[1], bbox2[1]]))
    area_bbox1 = (bbox1[2] - bbox1[0]) * (bbox1[3] - bbox1[1])
    area_bbox2 = (bbox2[2] - bbox2[0]) * (bbox2[3] - bbox2[1])
    area_bbox = min(area_bbox1, area_bbox2)
    return area_inter / area_bbox
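
# A quick sanity check (illustrative numbers, not from the datasets below):
# two 2x2 boxes shifted horizontally by one unit overlap over a 1x2 strip,
# so the intersection covers half of the smaller box's area.
#
#   >>> get_inter([0, 0, 2, 2], [1, 0, 3, 2])
#   0.5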

def load_annotations(csv_path):
    annotations = pd.read_csv(csv_path)
    return annotations
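
# Expected CSV columns, inferred from how they are read below (an assumption,
# not an official spec; the files may carry additional columns):
#   ground truth: path, pid, image_w, image_h, xmin, ymin, xmax, ymax,
#                 looking_head, multi_person
#   predictions:  path, pid, gaze_x, gaze_y  (gaze in normalized [0, 1]
#                 coordinates; scaled to pixels below)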

def get_looking_at(dataset, pred_path, subset='full', vis=False):
    # Load Annotations
    print('loading gt annotations...')
    if dataset == 'GazeFollow':
        bbox_path = '.../LAH_annotations/GazeFollow'
        data_path = ".../gazefollow_extended"
        gt_path = '.../LAH_annotations/gazefollow_gt_looking_head.csv'
    elif dataset == 'VideoAtt':
        bbox_path = '.../LAH_annotations/images'
        data_path = ".../VideoAttentionTarget/images"
        gt_path = '.../LAH_annotations/vat_gt_looking_head.csv'
    elif dataset == 'ChildPlay':
        bbox_path = '.../LAH_annotations/ChildPlay-test'
        data_path = ".../ChildPlay/images"
        print('subset: ' + subset)
        if subset == 'full':
            gt_path = '.../LAH_annotations/childplay_gt_looking_head.csv'
        elif subset == 'child':
            gt_path = '.../LAH_annotations/childplay_gt_looking_head_child.csv'
        elif subset == 'adult':
            gt_path = '.../LAH_annotations/childplay_gt_looking_head_adult.csv'

    print('loading annotations...')
    gt_annotations = load_annotations(gt_path)
    pred_annotations = load_annotations(pred_path)
    pred_annotations = pred_annotations.groupby(['path', 'pid'])

    # check if the predicted gaze point falls on any head bbox
    gt_looking_head = []
    pred_looking_head = []
    keys = np.arange(len(gt_annotations))
    # randomize indices so visualized examples are not always the same
    if vis:
        np.random.shuffle(keys)

    for key in tqdm(keys, total=len(keys)):
        # read gt annotations
        gt_row = gt_annotations.iloc[key]
        path, pid = gt_row['path'], gt_row['pid']
        width, height = gt_row['image_w'], gt_row['image_h']
        person_bbox = [gt_row['xmin'], gt_row['ymin'], gt_row['xmax'], gt_row['ymax']]
        gt_flag, multi_person = gt_row['looking_head'], gt_row['multi_person']

        # read pred annotation (normalized gaze scaled to pixel coordinates)
        pred_row = pred_annotations.get_group((path, pid)).iloc[0]
        pred_gazex, pred_gazey = pred_row['gaze_x'] * width, pred_row['gaze_y'] * height

        # load head bboxes
        if dataset == 'GazeFollow':
            head_bboxes = np.load(os.path.join(bbox_path, path[:-4] + '-head-detections.npy'))
            # keep detections above a 0.4 confidence threshold
            head_bboxes = [bbox for bbox in head_bboxes if bbox[4] > 0.4]
        else:
            head_bboxes = np.load(os.path.join(bbox_path, path[:-4] + '-head-bboxes.npy'))

        pred_flag = 0
        pred_bbox = []
        for bbox in head_bboxes:  # iterate over heads
            # check if pred gaze falls on this head
            if not pred_flag:
                if bbox[0] < pred_gazex < bbox[2] and bbox[1] < pred_gazey < bbox[3]:
                    # skip heads that overlap heavily with the target person's
                    # bbox (likely the person's own head)
                    if get_inter(bbox, person_bbox) < 0.8 and get_inter(person_bbox, bbox) < 0.8:
                        pred_flag = 1
                        pred_bbox = bbox
                        pred_looking_head.append(1)
        gt_looking_head.append(gt_flag)
        if not pred_flag:
            pred_looking_head.append(0)

        # visualize examples
        if vis and gt_flag and pred_flag:
            fig, ax = plt.subplots()
            img = plt.imread(os.path.join(data_path, path))
            ax.imshow(img)
            # person bbox in green
            rect1 = Rectangle([gt_row['xmin'], gt_row['ymin']], gt_row['xmax'] - gt_row['xmin'],
                              gt_row['ymax'] - gt_row['ymin'], edgecolor='g', facecolor='none')
            ax.add_patch(rect1)
            # predicted head bbox in red
            if pred_flag:
                rect2 = Rectangle([pred_bbox[0], pred_bbox[1]], pred_bbox[2] - pred_bbox[0],
                                  pred_bbox[3] - pred_bbox[1], edgecolor='r', facecolor='none')
                ax.add_patch(rect2)
            plt.show()

    # compute stats; for binary labels sklearn's confusion_matrix is laid out
    # as [[tn, fp], [fn, tp]], hence the unpacking order after .ravel()
    print(' Stats: ')
    tn, fp, fn, tp = confusion_matrix(gt_looking_head, pred_looking_head).ravel()
    print('tn, fp, fn, tp: ', tn, fp, fn, tp)
    print()
    print('Precision: ', precision_score(gt_looking_head, pred_looking_head))
    print('Recall: ', recall_score(gt_looking_head, pred_looking_head))
    print('False positive rate: ', fp / (tn + fp))

if __name__ == '__main__':
    dataset = 'VideoAtt'  # {GazeFollow, VideoAtt, ChildPlay}
    subset = 'adult'  # for ChildPlay, optionally choose from {child, adult}
    pred_path = os.path.join('.../output_VideoAtt_dist.csv')
    vis = False
    get_looking_at(dataset, pred_path, subset, vis)