-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathutils.py
94 lines (72 loc) · 2.07 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division
import os
import shutil
import numpy as np
import torch
import os
import shutil
import numpy as np
import torch
import os
import sys
import time
import logging
import argparse
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torchvision.transforms as transforms
import torchvision.datasets as dst
class AverageMeter(object):
    """Track the most recent value and the running weighted average of a metric."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero out every accumulated statistic."""
        self.val = 0    # last value passed to update()
        self.avg = 0    # sum / count (0 while nothing has been counted)
        self.sum = 0    # running sum of val * n
        self.count = 0  # running sum of n

    def update(self, val, n=1):
        """Record ``val`` as standing for ``n`` samples.

        Args:
            val: the new observation; it is weighted by ``n`` in the running
                sum, so it is treated as a per-sample mean over ``n`` samples.
            n: weight (number of samples) represented by ``val``.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # Guard the division: update(x, n=0) on a fresh/reset meter would
        # otherwise raise ZeroDivisionError.
        self.avg = self.sum / self.count if self.count else 0
def count_parameters_in_MB(model):
    """Return the number of parameter elements in ``model``, in millions.

    Despite the "MB" in the name, this is an element count divided by 1e6,
    not a byte size: parameter dtypes are not taken into account.
    Parameters shared between submodules are counted once.
    """
    return sum(p.numel() for p in model.parameters()) / 1e6
def create_exp_dir(path):
    """Ensure the experiment directory ``path`` exists, then announce it.

    Missing parent directories are created as well. An already-existing
    directory is fine; any other existing entry at ``path`` (e.g. a regular
    file) makes the underlying OSError propagate.
    """
    # EAFP instead of exists()-then-makedirs(): no window in which another
    # process can create the directory between the check and the call.
    try:
        os.makedirs(path)
    except OSError:
        if not os.path.isdir(path):
            raise
    print('Experiment dir : {}'.format(path))
def load_pretrained_model(model, pretrained_dict):
    """Copy matching entries of ``pretrained_dict`` into ``model``'s weights.

    Only keys that also appear in ``model.state_dict()`` are used; any other
    keys are ignored, and model entries with no counterpart keep their
    current values. The caller's ``pretrained_dict`` is not modified.
    """
    merged_state = model.state_dict()
    # Drop keys the model does not know about, overlay the rest, then load.
    merged_state.update(
        dict((key, tensor) for key, tensor in pretrained_dict.items()
             if key in merged_state)
    )
    model.load_state_dict(merged_state)
def transform_time(s):
    """Split a duration in seconds into whole (hours, minutes, seconds).

    ``s`` is truncated with ``int()`` first, so any fractional part is lost.
    Returns a tuple of ints; minutes and seconds always lie in [0, 60).
    """
    total_seconds = int(s)
    hours, leftover = divmod(total_seconds, 3600)
    minutes, seconds = divmod(leftover, 60)
    return hours, minutes, seconds
def save_checkpoint(state, is_best, save_root):
    """Serialize ``state`` to ``<save_root>/checkpoint.pth.tar``.

    When ``is_best`` is truthy, the freshly written file is also copied to
    ``<save_root>/model_best.pth.tar``. Existing files are overwritten.
    ``save_root`` is not created here, so it must already exist.
    """
    latest_path = os.path.join(save_root, 'checkpoint.pth.tar')
    torch.save(state, latest_path)
    if not is_best:
        return
    shutil.copyfile(latest_path, os.path.join(save_root, 'model_best.pth.tar'))
def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k.

    Args:
        output: 2-D score tensor with one row per sample (top-k is taken
            along dim 1, and the transpose below requires 2-D).
        target: ground-truth class index per sample; ``target.size(0)`` is
            used as the batch size.
        topk: iterable of k values to evaluate.

    Returns:
        List of 0-dim tensors in ``topk`` order, each the percentage (0-100)
        of samples whose target is among their k highest-scored classes.
    """
    maxk = max(topk)
    batch_size = target.size(0)

    # topk(k, dim=1, largest=True, sorted=True): best guesses first.
    # After .t(), row j holds every sample's j-th best guess.
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # `correct` is laid out like a transposed tensor (non-contiguous),
        # so .view(-1) on its first k rows raises for k > 1 on current
        # PyTorch; reshape() copies only when it has to.
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res