Commit 01ac8f1

Update main.py
There were some conflicts that had to be resolved!
1 parent d0bc64b commit 01ac8f1
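
The 164 deleted lines below are almost entirely leftover Git conflict markers plus the competing branch's side of each conflict; in every hunk shown, the HEAD side is kept and the side brought in by commit 1da05d6 is dropped. For orientation, a conflicted region of main.py looked like this (abridged from the first hunk below):

    elif args.dataset == 'hmdb51':
        num_class = 51
<<<<<<< HEAD
        rgb_read_format = "{:05d}.jpg"
    elif args.dataset == 'kinetics':
        num_class = 400
        rgb_read_format = "{:04d}.jpg"
=======
    elif args.dataset == 'kinetics':
        num_class = 400
        rgb_read_format = "{:05d}.jpg"
>>>>>>> 1da05d6e7d9dc0b61b5fd230758ee355c9700f8a

This commit deletes the three marker lines and everything between ======= and >>>>>>> in each conflict.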

File tree

1 file changed (+0, -164 lines)

main.py

Lines changed: 0 additions & 164 deletions
(unified diff: removed lines are prefixed with "-")
@@ -42,16 +42,10 @@ def main():
         rgb_read_format = "{:05d}.jpg"
     elif args.dataset == 'hmdb51':
         num_class = 51
-<<<<<<< HEAD
         rgb_read_format = "{:05d}.jpg"
     elif args.dataset == 'kinetics':
         num_class = 400
         rgb_read_format = "{:04d}.jpg"
-=======
-    elif args.dataset == 'kinetics':
-        num_class = 400
-        rgb_read_format = "{:05d}.jpg"
->>>>>>> 1da05d6e7d9dc0b61b5fd230758ee355c9700f8a
     elif args.dataset == 'something':
         num_class = 174
         rgb_read_format = "{:04d}.jpg"
@@ -83,42 +77,11 @@ def main():

     print("pretrained_parts: ", args.pretrained_parts)

-<<<<<<< HEAD
-=======
-    if args.arch == "ECO":
-        new_state_dict = init_ECO(model_dict)
-    if args.arch == "ECOfull":
-        new_state_dict = init_ECOfull(model_dict)
-    elif args.arch == "C3DRes18":
-        new_state_dict = init_C3DRes18(model_dict)
-
-    un_init_dict_keys = [k for k in model_dict.keys() if k not in new_state_dict]
-    print("un_init_dict_keys: ", un_init_dict_keys)
-    print("\n------------------------------------")
-
-    for k in un_init_dict_keys:
-        new_state_dict[k] = torch.DoubleTensor(model_dict[k].size()).zero_()
-        if 'weight' in k:
-            if 'bn' in k:
-                print("{} init as: 1".format(k))
-                constant_(new_state_dict[k], 1)
-            else:
-                print("{} init as: xavier".format(k))
-                xavier_uniform_(new_state_dict[k])
-        elif 'bias' in k:
-            print("{} init as: 0".format(k))
-            constant_(new_state_dict[k], 0)
-
-    print("------------------------------------")
-
-    model.load_state_dict(new_state_dict)
->>>>>>> 1da05d6e7d9dc0b61b5fd230758ee355c9700f8a

     if args.resume:
         if os.path.isfile(args.resume):
             print(("=> loading checkpoint '{}'".format(args.resume)))
             checkpoint = torch.load(args.resume)
-<<<<<<< HEAD
             # if not checkpoint['lr']:
             if "lr" not in checkpoint.keys():
                 args.lr = input("No 'lr' attribute found in resume model, please input the 'lr' manually: ")
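
The removed branch above also carried a weight-initialization block: any model_dict key not covered by the pretrained weights was created as a zero tensor, then BatchNorm weights were set to 1, other weights got Xavier init, and biases were zeroed. A minimal standalone sketch of that pattern (the function name and the simplified zero-tensor creation are illustrative, not taken from this repository):

import torch
from torch.nn.init import constant_, xavier_uniform_

def fill_uninitialized_keys(model_dict, new_state_dict):
    # Keys present in the model but missing from the pretrained weights.
    un_init_dict_keys = [k for k in model_dict if k not in new_state_dict]
    for k in un_init_dict_keys:
        new_state_dict[k] = torch.zeros(model_dict[k].size())
        if 'weight' in k:
            if 'bn' in k:
                constant_(new_state_dict[k], 1)     # BatchNorm scales start at 1
            else:
                xavier_uniform_(new_state_dict[k])  # conv/linear weights get Xavier init
        elif 'bias' in k:
            constant_(new_state_dict[k], 0)         # biases start at 0
    return new_state_dict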
@@ -163,24 +126,12 @@ def main():
             model.load_state_dict(new_state_dict)


-=======
-            args.start_epoch = checkpoint['epoch']
-            best_prec1 = checkpoint['best_prec1']
-            model.load_state_dict(checkpoint['state_dict'])
-            print(("=> loaded checkpoint '{}' (epoch {})"
-                  .format(args.resume, checkpoint['epoch'])))
-        else:
-            print(("=> no checkpoint found at '{}'".format(args.resume)))
->>>>>>> 1da05d6e7d9dc0b61b5fd230758ee355c9700f8a

     cudnn.benchmark = True

     # Data loading code
     if args.modality != 'RGBDiff':
-<<<<<<< HEAD
         #input_mean = [0,0,0] #for debugging
-=======
->>>>>>> 1da05d6e7d9dc0b61b5fd230758ee355c9700f8a
         normalize = GroupNormalize(input_mean, input_std)
     else:
         normalize = IdentityTransform()
@@ -215,11 +166,8 @@ def main():
                        GroupCenterCrop(crop_size),
                        Stack(roll=True),
                        ToTorchFormatTensor(div=False),
-<<<<<<< HEAD
                        #Stack(roll=(args.arch == 'C3DRes18') or (args.arch == 'ECO') or (args.arch == 'ECOfull') or (args.arch == 'ECO_2FC')),
                        #ToTorchFormatTensor(div=(args.arch != 'C3DRes18') and (args.arch != 'ECO') and (args.arch != 'ECOfull') and (args.arch != 'ECO_2FC')),
-=======
->>>>>>> 1da05d6e7d9dc0b61b5fd230758ee355c9700f8a
                        normalize,
                    ])),
        batch_size=args.batch_size, shuffle=False,
@@ -244,7 +192,6 @@ def main():
         validate(val_loader, model, criterion, 0)
         return

-<<<<<<< HEAD
     saturate_cnt = 0
     exp_num = 0

@@ -254,10 +201,6 @@ def main():
             saturate_cnt = 0
             print("- Learning rate decreases by a factor of '{}'".format(10**(exp_num)))
             adjust_learning_rate(optimizer, epoch, args.lr_steps, exp_num)
-=======
-    for epoch in range(args.start_epoch, args.epochs):
-        adjust_learning_rate(optimizer, epoch, args.lr_steps)
->>>>>>> 1da05d6e7d9dc0b61b5fd230758ee355c9700f8a

         # train for one epoch
         train(train_loader, model, criterion, optimizer, epoch)
@@ -268,25 +211,19 @@ def main():

         # remember best prec@1 and save checkpoint
         is_best = prec1 > best_prec1
-<<<<<<< HEAD
         if is_best:
             saturate_cnt = 0
         else:
             saturate_cnt = saturate_cnt + 1

         print("- Validation Prec@1 saturates for {} epochs.".format(saturate_cnt))
-=======
->>>>>>> 1da05d6e7d9dc0b61b5fd230758ee355c9700f8a
         best_prec1 = max(prec1, best_prec1)
         save_checkpoint({
             'epoch': epoch + 1,
             'arch': args.arch,
             'state_dict': model.state_dict(),
             'best_prec1': best_prec1,
-<<<<<<< HEAD
             'lr': optimizer.param_groups[-1]['lr'],
-=======
->>>>>>> 1da05d6e7d9dc0b61b5fd230758ee355c9700f8a
         }, is_best)

 def init_ECO(model_dict):
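
Note that the kept branch stores the current learning rate in the checkpoint ('lr': optimizer.param_groups[-1]['lr']), which pairs with the resume logic earlier in the diff that asks for an lr when the key is absent. A rough sketch of that round trip, assuming a conventional save_checkpoint helper (file names here are placeholders, not from this repository):

import shutil
import torch

def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    # Persist the training state; the 'lr' entry lets a later resume restore the schedule.
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, 'model_best.pth.tar')

# On resume, older checkpoints may predate the 'lr' key, so guard for it as the diff does.
checkpoint = torch.load('checkpoint.pth.tar')
lr = checkpoint['lr'] if 'lr' in checkpoint else float(input('No lr in checkpoint, enter it manually: '))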
@@ -299,7 +236,6 @@ def init_ECO(model_dict):

     elif args.pretrained_parts == "2D":

-<<<<<<< HEAD
         if args.net_model2D is not None:
             pretrained_dict_2d = torch.load(args.net_model2D)
             print(("=> loading model - 2D net: '{}'".format(args.net_model2D)))
@@ -316,25 +252,18 @@ def init_ECO(model_dict):
                 print("Problem!")
                 print("k: {}, size: {}".format(k,v.shape))

-=======
-        pretrained_dict_2d = torch.utils.model_zoo.load_url(weight_url_2d)
->>>>>>> 1da05d6e7d9dc0b61b5fd230758ee355c9700f8a
         new_state_dict = {"module.base_model."+k: v for k, v in pretrained_dict_2d['state_dict'].items() if "module.base_model."+k in model_dict}

     elif args.pretrained_parts == "3D":

         new_state_dict = {}
-<<<<<<< HEAD
         if args.net_model3D is not None:
             pretrained_dict_3d = torch.load(args.net_model3D)
             print(("=> loading model - 3D net: '{}'".format(args.net_model3D)))
         else:
             pretrained_dict_3d = torch.load("models/C3DResNet18_rgb_16F_kinetics_v1.pth.tar")
             print(("=> loading model - 3D net-url: '{}'".format("models/C3DResNet18_rgb_16F_kinetics_v1.pth.tar")))

-=======
-        pretrained_dict_3d = torch.load("models/C3DResNet18_rgb_16F_kinetics_v1.pth.tar")
->>>>>>> 1da05d6e7d9dc0b61b5fd230758ee355c9700f8a
         for k, v in pretrained_dict_3d['state_dict'].items():
             if (k in model_dict) and (v.size() == model_dict[k].size()):
                 new_state_dict[k] = v
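
Every branch of init_ECO kept here follows the same recipe: load a pretrained state_dict and copy across only the entries whose (optionally prefixed) key exists in the target model_dict, usually also requiring the tensor shapes to match. A minimal sketch of that filtering step (helper name and file path are illustrative):

import torch

def filter_pretrained(model_dict, pretrained_path, prefix=''):
    # Keep only weights whose prefixed name exists in the model and whose shape matches.
    pretrained = torch.load(pretrained_path)['state_dict']
    return {prefix + k: v for k, v in pretrained.items()
            if prefix + k in model_dict and v.size() == model_dict[prefix + k].size()}

# 2D weights land under the "module.base_model." prefix of the wrapped model, e.g.:
# new_state_dict = filter_pretrained(model_dict, 'models/some_2d_net.pth.tar',
#                                    prefix='module.base_model.')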
@@ -344,7 +273,6 @@ def init_ECO(model_dict):


     elif args.pretrained_parts == "finetune":
-<<<<<<< HEAD
         print(args.net_modelECO)
         print("88"*40)
         if args.net_modelECO is not None:
@@ -357,18 +285,12 @@ def init_ECO(model_dict):



-=======
-
-        print(("=> loading model '{}'".format("models/eco_lite_rgb_16F_kinetics_v2.pth.tar")))
-        pretrained_dict = torch.load("models/eco_lite_rgb_16F_kinetics_v2.pth.tar")
->>>>>>> 1da05d6e7d9dc0b61b5fd230758ee355c9700f8a
         new_state_dict = {k: v for k, v in pretrained_dict['state_dict'].items() if (k in model_dict) and (v.size() == model_dict[k].size())}
         print("*"*50)
         print("Start finetuning ..")

     elif args.pretrained_parts == "both":

-<<<<<<< HEAD
         # Load the 2D net pretrained model
         if args.net_model2D is not None:
             pretrained_dict_2d = torch.load(args.net_model2D)
@@ -389,21 +311,12 @@ def init_ECO(model_dict):

         new_state_dict = {"module.base_model."+k: v for k, v in pretrained_dict_2d['state_dict'].items() if "module.base_model."+k in model_dict}

-=======
-        pretrained_dict_2d = torch.utils.model_zoo.load_url(weight_url_2d)
-        new_state_dict = {"module.base_model."+k: v for k, v in pretrained_dict_2d['state_dict'].items() if "module.base_model."+k in model_dict}
-        pretrained_dict_3d = torch.load("models/C3DResNet18_rgb_16F_kinetics_v1.pth.tar")
->>>>>>> 1da05d6e7d9dc0b61b5fd230758ee355c9700f8a
         for k, v in pretrained_dict_3d['state_dict'].items():
             if (k in model_dict) and (v.size() == model_dict[k].size()):
                 new_state_dict[k] = v

         res3a_2_weight_chunk = torch.chunk(pretrained_dict_3d["state_dict"]["module.base_model.res3a_2.weight"], 4, 1)
         new_state_dict["module.base_model.res3a_2.weight"] = torch.cat((res3a_2_weight_chunk[0], res3a_2_weight_chunk[1], res3a_2_weight_chunk[2]), 1)
-<<<<<<< HEAD
-=======
-
->>>>>>> 1da05d6e7d9dc0b61b5fd230758ee355c9700f8a
     return new_state_dict

 def init_ECOfull(model_dict):
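
The res3a_2 lines kept above do a small piece of shape surgery: the pretrained 3D-conv weight is split into four chunks along the input-channel dimension (dim 1) with torch.chunk, and only the first three chunks are concatenated back, so the weight fits a layer that expects three quarters of the original input channels. A toy illustration with made-up sizes:

import torch

# Pretend pretrained conv weight: (out_channels, in_channels, kT, kH, kW)
pretrained_w = torch.randn(128, 128, 3, 3, 3)

# Split the input-channel axis into 4 chunks of 32 channels each ...
chunks = torch.chunk(pretrained_w, 4, 1)

# ... and keep the first three, giving a weight for a layer with 96 input channels.
adapted_w = torch.cat((chunks[0], chunks[1], chunks[2]), 1)
print(adapted_w.shape)  # torch.Size([128, 96, 3, 3, 3])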
@@ -431,7 +344,6 @@ def init_ECOfull(model_dict):
         new_state_dict["module.base_model.res3a_2.weight"] = torch.cat((res3a_2_weight_chunk[0], res3a_2_weight_chunk[1], res3a_2_weight_chunk[2]), 1)


-<<<<<<< HEAD

     elif args.pretrained_parts == "finetune":
         print(args.net_modelECO)
@@ -443,19 +355,12 @@ def init_ECOfull(model_dict):
             pretrained_dict = torch.load("models/eco_lite_rgb_16F_kinetics_v2.pth.tar")
             print(("=> loading model-finetune-url: '{}'".format("models/eco_lite_rgb_16F_kinetics_v2.pth.tar")))

-=======
-    elif args.pretrained_parts == "finetune":
-
-        print(("=> loading model '{}'".format("models/eco_lite_rgb_16F_kinetics_v2.pth.tar")))
-        pretrained_dict = torch.load("models/eco_lite_rgb_16F_kinetics_v2.pth.tar")
->>>>>>> 1da05d6e7d9dc0b61b5fd230758ee355c9700f8a
         new_state_dict = {k: v for k, v in pretrained_dict['state_dict'].items() if (k in model_dict) and (v.size() == model_dict[k].size())}
         print("*"*50)
         print("Start finetuning ..")

     elif args.pretrained_parts == "both":

-<<<<<<< HEAD
         # Load the 2D net pretrained model
         if args.net_model2D is not None:
             pretrained_dict_2d = torch.load(args.net_model2D)
@@ -476,22 +381,12 @@ def init_ECOfull(model_dict):
             print(("=> loading model - 3D net-url: '{}'".format("models/C3DResNet18_rgb_16F_kinetics_v1.pth.tar")))


-=======
-        pretrained_dict_2d = torch.utils.model_zoo.load_url(weight_url_2d)
-        new_state_dict = {"module.base_model."+k: v for k, v in pretrained_dict_2d['state_dict'].items() if "module.base_model."+k in model_dict}
-        pretrained_dict_3d = torch.load("models/C3DResNet18_rgb_16F_kinetics_v1.pth.tar")
->>>>>>> 1da05d6e7d9dc0b61b5fd230758ee355c9700f8a
         for k, v in pretrained_dict_3d['state_dict'].items():
             if (k in model_dict) and (v.size() == model_dict[k].size()):
                 new_state_dict[k] = v

-<<<<<<< HEAD
         #res3a_2_weight_chunk = torch.chunk(pretrained_dict_3d["state_dict"]["module.base_model.res3a_2.weight"], 4, 1)
         #new_state_dict["module.base_model.res3a_2.weight"] = torch.cat((res3a_2_weight_chunk[0], res3a_2_weight_chunk[1], res3a_2_weight_chunk[2]), 1)
-=======
-        res3a_2_weight_chunk = torch.chunk(pretrained_dict_3d["state_dict"]["module.base_model.res3a_2.weight"], 4, 1)
-        new_state_dict["module.base_model.res3a_2.weight"] = torch.cat((res3a_2_weight_chunk[0], res3a_2_weight_chunk[1], res3a_2_weight_chunk[2]), 1)
->>>>>>> 1da05d6e7d9dc0b61b5fd230758ee355c9700f8a

     return new_state_dict

@@ -526,18 +421,13 @@ def train(train_loader, model, criterion, optimizer, epoch):
     model.train()

     end = time.time()
-<<<<<<< HEAD

     loss_summ = 0
     localtime = time.localtime()
     end_time = time.strftime("%Y/%m/%d-%H:%M:%S", localtime)
     for i, (input, target) in enumerate(train_loader):
         # discard final batch

-=======
-    for i, (input, target) in enumerate(train_loader):
-        # discard final batch
->>>>>>> 1da05d6e7d9dc0b61b5fd230758ee355c9700f8a
         if i == len(train_loader)-1:
             break
         # measure data loading time
@@ -549,7 +439,6 @@ def train(train_loader, model, criterion, optimizer, epoch):
         target_var = target

         # compute output, output size: [batch_size, num_class]
-<<<<<<< HEAD

         output = model(input_var)

@@ -596,53 +485,6 @@ def train(train_loader, model, criterion, optimizer, epoch):
             end_time = time.strftime("%Y/%m/%d-%H:%M:%S", localtime)


-=======
-        output = model(input_var)
-
-        loss = criterion(output, target_var)
-
-        # measure accuracy and record loss
-        prec1, prec5 = accuracy(output.data, target, topk=(1,5))
-        losses.update(loss.item(), input.size(0))
-        top1.update(prec1.item(), input.size(0))
-        top5.update(prec5.item(), input.size(0))
-
-
-        # compute gradient and do SGD step
-        loss.backward()
-
-        if i % args.iter_size == 0:
-            # scale down gradients when iter size is functioning
-            if args.iter_size != 1:
-                for g in optimizer.param_groups:
-                    for p in g['params']:
-                        p.grad /= args.iter_size
-
-            if args.clip_gradient is not None:
-                total_norm = clip_grad_norm_(model.parameters(), args.clip_gradient)
-                if total_norm > args.clip_gradient:
-                    print("clipping gradient: {} with coef {}".format(total_norm, args.clip_gradient / total_norm))
-            else:
-                total_norm = 0
-
-            optimizer.step()
-            optimizer.zero_grad()
-
-
-        # measure elapsed time
-        batch_time.update(time.time() - end)
-        end = time.time()
-
-        if i % args.print_freq == 0:
-            print(('Epoch: [{0}][{1}/{2}], lr: {lr:.5f}\t'
-                   'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
-                   'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
-                   'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
-                   'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
-                   'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
-                       epoch, i, len(train_loader), batch_time=batch_time,
-                       data_time=data_time, loss=losses, top1=top1, top5=top5, lr=optimizer.param_groups[-1]['lr'])))
->>>>>>> 1da05d6e7d9dc0b61b5fd230758ee355c9700f8a


 def validate(val_loader, model, criterion, iter, logger=None):
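
The removed training-loop body above shows the gradient-accumulation scheme used in this training script: backward() runs on every mini-batch, but the optimizer only steps every args.iter_size iterations, after rescaling the accumulated gradients and optionally clipping their norm. The kept HEAD loop is only partially visible in this hunk, so treat the sketch below as an illustration of the pattern rather than the exact kept code; loader and model names are placeholders:

import torch

def train_one_epoch(train_loader, model, criterion, optimizer, iter_size=4, clip_gradient=None):
    model.train()
    optimizer.zero_grad()
    for i, (inputs, targets) in enumerate(train_loader):
        loss = criterion(model(inputs), targets)
        loss.backward()  # gradients accumulate across iterations
        if (i + 1) % iter_size == 0:
            # average the accumulated gradients over the effective batch
            if iter_size != 1:
                for group in optimizer.param_groups:
                    for p in group['params']:
                        if p.grad is not None:
                            p.grad /= iter_size
            if clip_gradient is not None:
                torch.nn.utils.clip_grad_norm_(model.parameters(), clip_gradient)
            optimizer.step()
            optimizer.zero_grad()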
@@ -722,16 +564,10 @@ def update(self, val, n=1):
         self.avg = self.sum / self.count


-<<<<<<< HEAD
 def adjust_learning_rate(optimizer, epoch, lr_steps, exp_num):
     """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
     # decay = 0.1 ** (sum(epoch >= np.array(lr_steps)))
     decay = 0.1 ** (exp_num)
-=======
-def adjust_learning_rate(optimizer, epoch, lr_steps):
-    """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
-    decay = 0.1 ** (sum(epoch >= np.array(lr_steps)))
->>>>>>> 1da05d6e7d9dc0b61b5fd230758ee355c9700f8a
     lr = args.lr * decay
     decay = args.weight_decay
     for param_group in optimizer.param_groups:
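
The two adjust_learning_rate variants capture the schedules the branches disagreed on: the kept version decays the base LR by 10**exp_num, where main() increments exp_num after validation Prec@1 has saturated for a number of epochs (the saturate_cnt counter in the hunks above), while the removed version dropped the LR by 10x at fixed epoch milestones in lr_steps. A standalone sketch of the two decay rules, with the base LR passed explicitly instead of read from args:

import numpy as np

def decay_on_saturation(base_lr, exp_num):
    # Kept variant: exp_num grows whenever validation accuracy stops improving.
    return base_lr * (0.1 ** exp_num)

def decay_on_milestones(base_lr, epoch, lr_steps):
    # Removed variant: drop by 10x for every milestone already passed.
    return base_lr * (0.1 ** sum(epoch >= np.array(lr_steps)))

print(decay_on_saturation(0.001, 2))             # ~1e-05
print(decay_on_milestones(0.001, 45, [30, 60]))  # ~1e-04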
