-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy patheval_relpose.py
132 lines (102 loc) · 4.73 KB
/
eval_relpose.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import argparse
import os
import numpy as np
import torch
torch.backends.cuda.matmul.allow_tf32 = True # for gpu >= Ampere and pytorch >= 1.12
from reloc3r.reloc3r_relpose import Reloc3rRelpose, inference_relpose
from reloc3r.datasets import get_data_loader
from reloc3r.utils.metric import *
from reloc3r.utils.device import to_numpy
from tqdm import tqdm
# from pdb import set_trace as bb
def get_args_parser():
    """Create the command-line parser for the relative pose evaluation script.

    Returns:
        argparse.ArgumentParser: parser exposing ``--model``, ``--ckpt``,
        ``--test_dataset``, ``--batch_size`` and ``--num_workers``.
    """
    parser = argparse.ArgumentParser(description='evaluation code for relative camera pose estimation')

    # (flag, value type, default) for every supported option.
    option_table = (
        # model
        ('--model', str, 'Reloc3rRelpose(img_size=512)'),
        ('--ckpt', str, './checkpoints/Reloc3r-512.pth'),
        # test set
        ('--test_dataset', str, "ScanNet1500(resolution=(512,384), seed=777)"),
        ('--batch_size', int, 1),
        ('--num_workers', int, 10),
    )
    for flag, value_type, default_value in option_table:
        parser.add_argument(flag, type=value_type, default=default_value)

    # Disabled option, kept for reference:
    # parser.add_argument('--output_dir', type=str,
    #                     default='./output', help='path where to save the pose errors')
    return parser
def setup_reloc3r_relpose_model(model, ckpt, device):
    """Build the relative pose model, load its checkpoint and put it in eval mode.

    Args:
        model: Python expression, given as a string, that constructs the
            model (e.g. ``'Reloc3rRelpose(img_size=512)'``).
        ckpt: Path to the checkpoint file. When the file does not exist and
            the path contains ``'512'`` or ``'224'``, the matching checkpoint
            is downloaded from the Hugging Face Hub into ``./checkpoints``.
        device: Device the model is moved to and the checkpoint is mapped to.

    Returns:
        The constructed model, with ``checkpoint['model']`` loaded
        non-strictly, in eval mode.

    Raises:
        FileNotFoundError: If ``ckpt`` does not exist and no downloadable
            checkpoint can be inferred from its name.
    """
    print('Building model: {:s}'.format(model))
    # NOTE(security): `model` is executed as Python code via eval(). It comes
    # from the command line, so only run this script with trusted arguments.
    reloc3r_relpose = eval(model)
    reloc3r_relpose.to(device)

    print('Loading checkpoint: {:s}'.format(ckpt))
    if not os.path.exists(ckpt):
        if '512' in ckpt:
            repo_id, filename = 'siyan824/reloc3r-512', 'Reloc3r-512.pth'
        elif '224' in ckpt:
            repo_id, filename = 'siyan824/reloc3r-224', 'Reloc3r-224.pth'
        else:
            raise FileNotFoundError(
                "Checkpoint '{:s}' does not exist and no downloadable checkpoint "
                "matches its name (expected '512' or '224' in the path)".format(ckpt))
        # Imported lazily so the dependency is only needed when a download happens.
        from huggingface_hub import hf_hub_download
        print('Downloading checkpoint from HF...')
        # Load from the path the download actually produced: the requested
        # `ckpt` may point somewhere other than ./checkpoints.
        ckpt = hf_hub_download(repo_id=repo_id, filename=filename, local_dir='./checkpoints')

    checkpoint = torch.load(ckpt, map_location=device)
    # strict=False: keys that do not match between checkpoint and model are ignored.
    reloc3r_relpose.load_state_dict(checkpoint['model'], strict=False)
    reloc3r_relpose.eval()
    return reloc3r_relpose
def build_dataset(dataset, batch_size, num_workers, test=False):
    """Create a data loader for ``dataset`` through ``get_data_loader``.

    Args:
        dataset: Dataset specification forwarded unchanged to ``get_data_loader``.
        batch_size: Batch size forwarded to the loader.
        num_workers: Worker count forwarded to the loader.
        test: When true, the loader is built without shuffling and without
            dropping the last batch; otherwise both are enabled.

    Returns:
        The loader returned by ``get_data_loader``.
    """
    split = ('Train', 'Test')[test]
    print(f'Building {split} data loader for {dataset}')

    # Shuffling and dropping the last batch are only enabled outside test mode.
    training = not (test)
    loader = get_data_loader(
        dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_mem=True,
        shuffle=training,
        drop_last=training,
    )

    # Reports len(loader), i.e. whatever length the loader object defines.
    print('Dataset length: ', len(loader))
    return loader
def test(args):
    """Evaluate relative camera pose estimation on the requested test sets.

    Builds the model from ``args.model`` / ``args.ckpt``, creates one test
    loader per '+'-separated entry of ``args.test_dataset``, and for every
    sample compares the predicted pose with the ground truth
    ``inverse(view1 camera_pose) @ view2 camera_pose``. Rotation errors and
    translation direction errors are collected and summarized by
    ``error_auc`` at thresholds 5, 10 and 20 (presumably degrees -- confirm
    against reloc3r.utils.metric).

    Args:
        args: Parsed options providing ``model``, ``ckpt``, ``test_dataset``,
            ``batch_size`` and ``num_workers``.
    """
    # if not os.path.exists(args.output_dir):
    #     os.makedirs(args.output_dir)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    reloc3r_relpose = setup_reloc3r_relpose_model(args.model, args.ckpt, device)

    # One loader per dataset spec, keyed by the name in front of the first '('.
    data_loader_test = {}
    for dataset_spec in args.test_dataset.split('+'):
        dataset_name = dataset_spec.split('(')[0]
        data_loader_test[dataset_name] = build_dataset(
            dataset_spec, args.batch_size, args.num_workers, test=True)

    # start evaluation
    rot_err_list, transl_err_list = [], []
    for test_name, testset in data_loader_test.items():
        print(f'Testing {test_name:s}')
        with torch.no_grad():
            for batch in tqdm(testset):
                pose = inference_relpose(batch, reloc3r_relpose, device)

                view1, view2 = batch
                gt_pose2to1 = torch.inverse(view1['camera_pose']) @ view2['camera_pose']

                # rotation angular err: top-left 3x3 block, one sample at a time
                batch_rot_errs = [
                    get_rot_err(to_numpy(rot_pred), to_numpy(gt_pose2to1[idx, 0:3, 0:3]))
                    for idx, rot_pred in enumerate(pose[:, 0:3, 0:3])
                ]

                # translation direction angular err: both vectors are
                # normalized to unit length before being compared
                batch_transl_errs = []
                for idx, transl_pred in enumerate(pose[:, 0:3, 3]):
                    transl = to_numpy(transl_pred)
                    gt_transl = to_numpy(gt_pose2to1[idx, 0:3, -1])
                    transl_dir = transl / np.linalg.norm(transl)
                    gt_transl_dir = gt_transl / np.linalg.norm(gt_transl)
                    batch_transl_errs.append(get_transl_ang_err(transl_dir, gt_transl_dir))

                rot_err_list.extend(batch_rot_errs)
                transl_err_list.extend(batch_transl_errs)

    # NOTE(review): the nesting of this summary was reconstructed from a
    # flattened source -- confirm it is meant to run once, after all test sets.
    rerrs = np.array(rot_err_list)
    terrs = np.array(transl_err_list)
    print('In total {} pairs'.format(len(rerrs)))

    # auc
    print(error_auc(rerrs, terrs, thresholds=[5, 10, 20]))

    # # save err list to file
    # err_list = np.concatenate((rerrs[:,None], terrs[:,None]), axis=-1)
    # output_file = '{}/pose_error_list.txt'.format(args.output_dir)
    # np.savetxt(output_file, err_list)
    # print('Pose errors saved to {}'.format(output_file))
if __name__ == '__main__':
    # Script entry point: parse the command-line options and run the evaluation.
    test(get_args_parser().parse_args())