# exp0_baseline_models.py
# Trains the baseline CIFAR-10 classifiers (ResNet18 and Cifar_Tiny) and evaluates them for retrieval.
import torch.nn as nn
import torch.optim as optim
from nn.retrieval_evaluation import evaluate_model_retrieval
from exp_cifar.cifar_dataset import cifar10_loader
from models.cifar_tiny import Cifar_Tiny
from models.resnet import ResNet18
from nn.nn_utils import train_model, save_model
def train_cifar10_model(net, learning_rates=(0.001, 0.0001), iters=(50, 50), output_path='resnet18_cifar10.model'):
    """
    Trains a baseline (classification model) on CIFAR-10

    The training runs in stages: stage i trains for iters[i] epochs with a newly created Adam optimizer
    that uses learning_rates[i]. The model is written to output_path after every stage.

    :param net: the network to be trained
    :param learning_rates: the learning rates to be used during the training (one per stage)
    :param iters: number of epochs using each of the supplied learning rates
    :param output_path: path to save the trained model
    :return: None
    :raises ValueError: if learning_rates and iters do not have the same length
    """
    # Materialise both schedules so their lengths can be compared; zip() alone would silently
    # drop the extra entries of the longer one
    learning_rates, iters = list(learning_rates), list(iters)
    if len(learning_rates) != len(iters):
        raise ValueError("learning_rates and iters must have the same length (got %d and %d)"
                         % (len(learning_rates), len(iters)))

    # Load data (only the training loader is needed here)
    train_loader, _, _ = cifar10_loader(batch_size=128)

    # Define loss
    criterion = nn.CrossEntropyLoss()

    for lr, n_epochs in zip(learning_rates, iters):
        print("Training with lr=%f for %d iters" % (lr, n_epochs))
        # A new optimizer is created for each stage, so no optimizer state is carried over
        optimizer = optim.Adam(net.parameters(), lr=lr)
        train_model(net, optimizer, criterion, train_loader, epochs=n_epochs)
        # Save after each stage; every save overwrites the same output file
        save_model(net, output_file=output_path)
def train_cifar_models():
    """
    Trains the baseline teacher/student networks one after the other and stores each trained model
    :return:
    """
    # (network constructor, path of the saved model), listed in training order
    baselines = (
        (ResNet18, 'models/resnet18_cifar10.model'),
        (Cifar_Tiny, 'models/tiny_cifar10.model'),
    )

    for build_net, model_path in baselines:
        net = build_net()
        # Move the network to the GPU before training
        net.cuda()
        train_cifar10_model(net, learning_rates=[0.001, 0.0001], iters=[50, 50],
                            output_path=model_path)
def evaluate_cifar_models_retrieval():
    """
    Runs the retrieval evaluation for each of the baseline teacher/student models
    :return:
    """
    # (network constructor, path of the trained model, path of the results file), in evaluation order
    evaluations = (
        (Cifar_Tiny, 'models/tiny_cifar10.model', 'results/tiny_cifar10_baseline.pickle'),
        (ResNet18, 'models/resnet18_cifar10.model', 'results/resnet18_cifar10_baseline.pickle'),
    )

    for build_net, model_path, result_path in evaluations:
        evaluate_model_retrieval(net=build_net(num_classes=10), path=model_path,
                                 result_path=result_path)
if __name__ == "__main__":
    # Training is disabled by default: according to the original author, training the teacher
    # model takes approximately a day, so the pretrained model can be used instead.
    # Uncomment the next line to train the baselines from scratch.
    # train_cifar_models()
    evaluate_cifar_models_retrieval()