-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathmain.py
202 lines (166 loc) · 7.35 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
import argparse
import os
import random
import numpy as np
import torch
import tqdm
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import datasets
import losses
import models
from utils import batch_to_device, batch_errors, batch_compute_utils, log_poses, log_errors
if __name__ == '__main__':
    # --- Command-line interface -------------------------------------------
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument(
        'path', metavar='DATA_PATH',
        help='path to the dataset directory, e.g. "/home/data/KingsCollege"'
    )
    parser.add_argument(
        '--loss', help='loss function for training',
        choices=['local_homography', 'global_homography', 'posenet', 'homoscedastic', 'geometric', 'dsac'],
        default='local_homography'
    )
    parser.add_argument('--epochs', help='number of epochs for training', type=int, default=5000)
    parser.add_argument('--batch_size', help='training batch size', type=int, default=64)
    parser.add_argument('--xmin_percentile', help='xmin depth percentile', type=float, default=0.025)
    parser.add_argument('--xmax_percentile', help='xmax depth percentile', type=float, default=0.975)
    parser.add_argument(
        '--weights', metavar='WEIGHTS_PATH',
        help='path to weights with which the model will be initialized'
    )
    parser.add_argument(
        '--device', default='cpu',
        help='set the device to train the model, `cuda` for GPU'
    )
    args = parser.parse_args()

    # Set seed for reproducibility
    seed = 1
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    # Load model (optionally initialized from --weights) and put it in
    # training mode on the requested device
    model = models.load_model(args.weights)
    model.train()
    model.to(args.device)

    # Pick the dataset class from the directory name; anything that is not a
    # known Cambridge or 7-Scenes scene is assumed to be a COLMAP export
    dataset_name = os.path.basename(os.path.normpath(args.path))
    if dataset_name in ['GreatCourt', 'KingsCollege', 'OldHospital', 'ShopFacade', 'StMarysChurch', 'Street']:
        dataset = datasets.CambridgeDataset(args.path, args.xmin_percentile, args.xmax_percentile)
    elif dataset_name in ['chess', 'fire', 'heads', 'office', 'pumpkin', 'redkitchen', 'stairs']:
        dataset = datasets.SevenScenesDataset(args.path, args.xmin_percentile, args.xmax_percentile)
    else:
        dataset = datasets.COLMAPDataset(args.path, args.xmin_percentile, args.xmax_percentile)

    # Wrappers for use with PyTorch's DataLoader
    train_dataset = datasets.RelocDataset(dataset.train_data)
    test_dataset = datasets.RelocDataset(dataset.test_data)

    # Creating data loaders for train and test data; drop_last keeps every
    # training batch at full batch_size
    train_loader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        pin_memory=True,
        collate_fn=datasets.collate_fn,
        drop_last=True
    )
    test_loader = DataLoader(
        test_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        pin_memory=True,
        collate_fn=datasets.collate_fn
    )

    # Adam optimizer default epsilon parameter is 1e-8
    eps = 1e-8

    # Instantiate the loss selected on the command line. Some losses carry
    # learnable parameters (e.g. homoscedastic weights) — they are optimized
    # jointly with the model below.
    if args.loss == 'local_homography':
        criterion = losses.LocalHomographyLoss(device=args.device)
        eps = 1e-14  # Adam optimizer epsilon is set to 1e-14 for homography losses
    elif args.loss == 'global_homography':
        criterion = losses.GlobalHomographyLoss(
            xmin=dataset.train_global_xmin,
            xmax=dataset.train_global_xmax,
            device=args.device
        )
        eps = 1e-14  # Adam optimizer epsilon is set to 1e-14 for homography losses
    elif args.loss == 'posenet':
        criterion = losses.PoseNetLoss(beta=500)
    elif args.loss == 'homoscedastic':
        criterion = losses.HomoscedasticLoss(s_hat_t=0.0, s_hat_q=-3.0, device=args.device)
    elif args.loss == 'geometric':
        criterion = losses.GeometricLoss()
    elif args.loss == 'dsac':
        criterion = losses.DSACLoss()
    else:
        raise Exception(f'Loss {args.loss} not recognized...')

    # Instantiate adam optimizer over model AND loss parameters
    optimizer = torch.optim.Adam(list(model.parameters()) + list(criterion.parameters()), lr=1e-4, eps=eps)

    # Set up tensorboard (dataset_name was already computed above)
    writer = SummaryWriter(os.path.join('logs', dataset_name, args.loss))

    # Set up folder to save weights (exist_ok avoids the check-then-create race)
    os.makedirs(os.path.join(writer.log_dir, 'weights'), exist_ok=True)

    # Set up CSV file to save per-epoch pose estimates (train and test)
    log_file_path = os.path.join(writer.log_dir, 'epochs_poses_log.csv')
    with open(log_file_path, mode='w') as log_file:
        log_file.write('epoch,image_file,type,w_tx_chat,w_ty_chat,w_tz_chat,chat_qw_w,chat_qx_w,chat_qy_w,chat_qz_w\n')

    print('Start training...')
    for epoch in tqdm.tqdm(range(args.epochs)):
        epoch_loss = 0
        errors = {}
        for batch in train_loader:
            optimizer.zero_grad()

            # Move all batch data to proper device
            batch = batch_to_device(batch, args.device)

            # Estimate the pose from the image: the network output is split
            # into a 3-vector translation and a 4-vector quaternion
            batch['w_t_chat'], batch['chat_q_w'] = model(batch['image']).split([3, 4], dim=1)

            # Computes useful data for our batch
            # - Normalized quaternion
            # - Rotation matrix from this normalized quaternion
            # - Reshapes translation component to fit shape (batch_size, 3, 1)
            batch_compute_utils(batch)

            # Compute loss and backprop
            loss = criterion(batch)
            loss.backward()
            optimizer.step()

            # Accumulate the mean of batch losses as the epoch loss
            epoch_loss += loss.item() / len(train_loader)

            # Compute training batch errors and log poses (no autograd needed)
            with torch.no_grad():
                batch_errors(batch, errors)
                with open(log_file_path, mode='a') as log_file:
                    log_poses(log_file, batch, epoch, 'train')

        # Log epoch loss
        writer.add_scalar('train loss', epoch_loss, epoch)

        # NOTE(review): the scope of this no_grad block is reconstructed —
        # it is assumed to cover the whole test evaluation pass; confirm
        # against upstream.
        with torch.no_grad():
            # Log train errors
            log_errors(errors, writer, epoch, 'train')

            # Set the model to eval mode for test data
            model.eval()
            errors = {}
            for batch in test_loader:
                # Compute test poses estimations
                batch = batch_to_device(batch, args.device)
                batch['w_t_chat'], batch['chat_q_w'] = model(batch['image']).split([3, 4], dim=1)
                batch_compute_utils(batch)

                # Log test poses
                with open(log_file_path, mode='a') as log_file:
                    log_poses(log_file, batch, epoch, 'test')

                # Compute test errors
                batch_errors(batch, errors)

            # Log test errors
            log_errors(errors, writer, epoch, 'test')

            # Log loss parameters, if there are any
            for p_name, p in criterion.named_parameters():
                writer.add_scalar(p_name, p, epoch)

            writer.flush()

        # Back to training mode for the next epoch
        model.train()

        # Save model and optimizer weights every 500 epochs and on the last one
        if epoch % 500 == 0 or epoch == args.epochs - 1:
            torch.save({
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'criterion_state_dict': criterion.state_dict()
            }, os.path.join(writer.log_dir, 'weights', f'epoch_{epoch}.pth'))

    writer.close()