-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
227 lines (189 loc) · 8.49 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
from __future__ import print_function, absolute_import
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import os
import sys
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
# matplotlib.use('Agg')
import time
from config import *
def init_params(net):
    """Initialize the parameters of every layer in `net` in place.

    Conv layers get Kaiming-normal weights, batch-norm layers start as the
    identity affine transform, and linear layers get small normal weights.
    All biases are zeroed.
    """
    for module in net.modules():
        if isinstance(module, nn.Conv2d):
            # He initialization, fan-out mode, for conv weights.
            nn.init.kaiming_normal_(module.weight, mode='fan_out')
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)
        elif isinstance(module, nn.BatchNorm2d):
            # scale=1, shift=0: batch norm starts as a pass-through.
            nn.init.constant_(module.weight, 1)
            nn.init.constant_(module.bias, 0)
        elif isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, std=1e-3)
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)
def get_mean_and_std(dataset):
    """Estimate the per-channel mean and std of a 3-channel image dataset.

    Iterates one sample at a time and averages the per-image channel
    statistics. NOTE: the returned std is the mean of per-image stds,
    not the std over all pixels — a common approximation for normalization.
    """
    loader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2)
    mean = torch.zeros(3)
    std = torch.zeros(3)
    print('==> Computing mean and std..')
    for inputs, _ in loader:
        for ch in range(3):
            channel = inputs[:, ch, :, :]
            mean[ch] += channel.mean()
            std[ch] += channel.std()
    n = len(dataset)
    mean.div_(n)
    std.div_(n)
    return mean, std
def count_parameters(net, all=True):
    """Count the parameters of `net`.

    If all is False, only trainable (requires_grad) parameters are counted.
    (`all` shadows the builtin, but it is kept for caller compatibility.)
    """
    total = 0
    for p in net.parameters():
        if all or p.requires_grad:
            total += p.numel()
    return total
def calculate_acc(dataloader, net, device):
    """Return the top-1 accuracy (in percent) of `net` over `dataloader`.

    Runs under torch.no_grad(). Does NOT switch the network to eval mode;
    callers should call net.eval() beforehand if batch norm / dropout matter.
    Returns 0.0 for an empty dataloader instead of raising ZeroDivisionError.
    """
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in dataloader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = net(images)
            # .data is unnecessary (and discouraged) inside no_grad.
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    if total == 0:
        # Empty loader: report 0% rather than crash on division by zero.
        return 0.0
    return (correct / total) * 100
# INPUTS: `output` has shape [batch_size, category_count]
# and `target` has shape [batch_size] — there is exactly one true class per sample.
# `topk` is a tuple of the k values for which precision is computed.
# `topk` has to be a tuple, so if you pass a single number, do not forget the comma.
def accuracy(output, target, topk=(1,5)):
    """Compute top-k precision over a batch for each k in `topk`.

    Args:
        output: logits/scores of shape [batch_size, category_count].
        target: true class indices of shape [batch_size].
        topk: tuple of k values (must be a tuple — e.g. ``(1,)`` for top-1 only).

    Returns:
        A list (one entry per k) of 1-element tensors holding the
        percentage of samples whose true class is within the top-k predictions.
    """
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)
        _, pred = torch.topk(input=output, k=maxk, dim=1, largest=True, sorted=True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))
        res = []
        for k in topk:
            # reshape (not view): `correct` is non-contiguous after .t(),
            # so .view(-1) raises a RuntimeError on modern PyTorch.
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul(100.0 / batch_size))
        return res
def get_network(model_type, num_experts, backbone, dataset, device, fuse_conv=False, activation='sigmoid'):
    """Build the requested network and move it to `device`.

    Args:
        model_type: 'cc' (CondConv), 'dycnn' (dynamic conv) or 'coconv'.
        num_experts: number of experts for the dynamic-convolution variants.
        backbone: 'resnet18', 'alexnet' or 'mobilenetv2'.
        dataset: 'cifar100' or 'tiny' (Tiny-ImageNet) — selects the model package.
        fuse_conv, activation: forwarded to the CoConv variants only.

    Exits the process (sys.exit) on any unsupported combination, matching
    this module's error-handling style. Validating up-front fixes the
    original bug where an unknown model_type/dataset with a valid backbone
    raised UnboundLocalError/NameError instead of a clean exit.
    """
    if dataset not in ('cifar100', 'tiny'):
        print('the dataset is not supported')
        sys.exit()
    if model_type not in ('cc', 'dycnn', 'coconv'):
        print('the model type is not supported')
        sys.exit()
    on_cifar = dataset == 'cifar100'
    # ResNet18 and Related Work
    if backbone == 'resnet18':
        if model_type == 'cc':
            if on_cifar:
                from cifar.condconv_resnet import CondConv_ResNet18
            else:
                from tiny.condconv_resnet import CondConv_ResNet18
            net = CondConv_ResNet18(num_experts=num_experts)
        elif model_type == 'dycnn':
            if on_cifar:
                from cifar.dycnn_resnet import Dy_ResNet18
            else:
                from tiny.dycnn_resnet import Dy_ResNet18
            net = Dy_ResNet18(num_experts=num_experts)
        else:  # coconv
            if on_cifar:
                from cifar.coconv_resnet import CoConv_ResNet18
            else:
                from tiny.coconv_resnet import CoConv_ResNet18
            net = CoConv_ResNet18(num_experts=num_experts, fuse_conv=fuse_conv, activation=activation)
    # AlexNet and Related Work
    elif backbone == 'alexnet':
        if model_type == 'cc':
            if on_cifar:
                from cifar.condconv_alexnet import CondConv_AlexNet
            else:
                from tiny.condconv_alexnet import CondConv_AlexNet
            net = CondConv_AlexNet(num_experts=num_experts)
        elif model_type == 'dycnn':
            if on_cifar:
                from cifar.dycnn_alexnet import Dy_AlexNet
            else:
                from tiny.dycnn_alexnet import Dy_AlexNet
            net = Dy_AlexNet(num_experts=num_experts)
        else:  # coconv
            if on_cifar:
                from cifar.coconv_alexnet import CoConv_AlexNet
            else:
                from tiny.coconv_alexnet import CoConv_AlexNet
            net = CoConv_AlexNet(num_experts=num_experts, fuse_conv=fuse_conv, activation=activation)
    # MobileNetV2 and Related Work
    elif backbone == 'mobilenetv2':
        if model_type == 'cc':
            if on_cifar:
                from cifar.condconv_mobilenetv2 import CondConv_MobileNetV2
            else:
                from tiny.condconv_mobilenetv2 import CondConv_MobileNetV2
            net = CondConv_MobileNetV2(num_experts=num_experts)
        elif model_type == 'dycnn':
            if on_cifar:
                from cifar.dycnn_mobilenetv2 import Dy_MobileNetV2
            else:
                from tiny.dycnn_mobilenetv2 import Dy_MobileNetV2
            net = Dy_MobileNetV2(num_experts=num_experts)
        else:  # coconv
            if on_cifar:
                from cifar.coconv_mobilenetv2 import CoConv_MobileNetV2
            else:
                from tiny.coconv_mobilenetv2 import CoConv_MobileNetV2
            net = CoConv_MobileNetV2(num_experts=num_experts, fuse_conv=fuse_conv, activation=activation)
    else:
        print('the network is not supported')
        sys.exit()
    net = net.to(device)
    return net
def get_dataloader(dataset, batch_size):
    """Build (train, test) DataLoaders for 'cifar100' or 'tiny'.

    Training data gets random-crop + horizontal-flip augmentation; both
    splits are normalized with the dataset-specific mean/std constants
    from config. Exits the process for any other dataset name.
    """
    if dataset == 'cifar100':
        normalize = transforms.Normalize(CIFAR100_MEAN, CIFAR100_STD)
        train_tf = transforms.Compose([
            transforms.RandomCrop(size=32, padding=4),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ToTensor(),
            normalize,
        ])
        test_tf = transforms.Compose([transforms.ToTensor(), normalize])
        trainset = torchvision.datasets.CIFAR100(root=DATA_ROOT, train=True, transform=train_tf, download=True)
        testset = torchvision.datasets.CIFAR100(root=DATA_ROOT, train=False, transform=test_tf, download=True)
    elif dataset == 'tiny':
        normalize = transforms.Normalize(TINY_IMAGENET_MEAN, TINY_IMAGENET_STD)
        train_tf = transforms.Compose([
            transforms.RandomCrop(size=64, padding=4),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ToTensor(),
            normalize,
        ])
        test_tf = transforms.Compose([transforms.ToTensor(), normalize])
        trainset = torchvision.datasets.ImageFolder(root=os.path.join(TINY_IMAGENET_DATA_DIR, 'train'), transform=train_tf)
        testset = torchvision.datasets.ImageFolder(root=os.path.join(TINY_IMAGENET_DATA_DIR, 'validation'), transform=test_tf)
    else:
        print('Dataset not supported yet...')
        sys.exit()
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
    testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)
    return trainloader, testloader
def save_plot(train_losses, train_accuracy, val_losses, val_accuracy, args, time_stamp):
    """Save a two-panel training-curve figure (losses / accuracies) under plots/.

    Args:
        train_losses, val_losses: per-epoch loss values (length args.epoch).
        train_accuracy, val_accuracy: per-epoch accuracy values.
        args: needs .network, .dataset, .batch, .epoch for the file name.
        time_stamp: string appended to the file name to keep runs distinct.
    """
    epochs = np.arange(1, args.epoch + 1)
    fig, (ax_loss, ax_acc) = plt.subplots(nrows=2, ncols=1)
    ax_loss.plot(epochs, np.asarray(train_losses), label='train loss')
    ax_loss.plot(epochs, np.asarray(val_losses), label='val loss')
    ax_loss.legend()
    ax_loss.xaxis.set_visible(False)  # x axis is shared with the panel below
    ax_loss.set_ylabel('losses')
    ax_acc.plot(epochs, np.asarray(train_accuracy), label='train acc')
    ax_acc.plot(epochs, np.asarray(val_accuracy), label='val acc')
    ax_acc.legend()
    ax_acc.set_xlabel('epochs')  # x runs 1..args.epoch, so label epochs (was 'batches')
    ax_acc.set_ylabel('acc')
    os.makedirs('plots', exist_ok=True)  # savefig fails if the directory is missing
    fig.savefig('plots/{}-losses-{}-b{}-e{}-{}.png'.format(args.network, args.dataset, args.batch, args.epoch, time_stamp))
    plt.close(fig)  # release the figure; repeated calls otherwise leak memory