-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy patheval.py
99 lines (80 loc) · 3.53 KB
/
eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
# Adapted from
# https://github.com/fra31/auto-attack/blob/master/autoattack/examples/eval.py
import os
import argparse
import torch
import torch.nn as nn
import torchvision.datasets as datasets
import torch.utils.data as data
import torchvision.transforms as transforms
# import sys
# sys.path.insert(0,'..')
from models.wideresnet import *
from models.resnet import *
from ensemble import Ensemble
def build_arg_parser():
    """Return the command-line parser for this AutoAttack evaluation script."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_dir', type=str, default='./data')
    parser.add_argument('--norm', type=str, default='Linf')
    parser.add_argument('--epsilon', type=float, default=8./255.)
    # With --single this is the checkpoint file itself; otherwise it is the
    # prefix of the ensemble checkpoints '<model>-0.pt', '<model>-1.pt', ...
    parser.add_argument('--model', type=str, default='./model_test.pt')
    parser.add_argument('--n_ex', type=int, default=1000)
    parser.add_argument('--individual', action='store_true')
    parser.add_argument('--save_dir', type=str, default='./results')
    parser.add_argument('--batch_size', type=int, default=500)
    parser.add_argument('--log_path', type=str, default='./log_file.txt')
    parser.add_argument('--version', type=str, default='standard')
    parser.add_argument('--single', action='store_true', default=False)
    return parser


def ensemble_checkpoint_paths(prefix, n_members=3):
    """Return the ensemble checkpoint paths '<prefix>-0.pt' .. '<prefix>-<n_members-1>.pt'."""
    return ['{}-{}.pt'.format(prefix, i) for i in range(n_members)]


def load_model(args):
    """Load the model under evaluation, move it to CUDA and switch it to eval mode.

    With ``args.single`` a single WideResNet state dict is loaded from
    ``args.model``; otherwise three WideResNet state dicts are loaded from
    ``ensemble_checkpoint_paths(args.model)`` and wrapped in ``Ensemble``.
    """
    if args.single:
        model = WideResNet()
        model.load_state_dict(torch.load(args.model))
    else:
        members = []
        for path in ensemble_checkpoint_paths(args.model):
            member = WideResNet()
            member.load_state_dict(torch.load(path))
            members.append(member)
        model = Ensemble(*members)
    model.cuda()
    model.eval()
    return model


def load_cifar10_test(data_dir):
    """Return the CIFAR-10 test split as concatenated (images, labels) tensors.

    Images are converted with ``transforms.ToTensor()`` only; the dataset is
    downloaded into ``data_dir`` if missing. The loader is iterated once
    (the original script iterated it twice, decoding every image twice).
    """
    dataset = datasets.CIFAR10(root=data_dir, train=False,
                               transform=transforms.ToTensor(), download=True)
    loader = data.DataLoader(dataset, batch_size=1000, shuffle=False, num_workers=0)
    images, labels = [], []
    for x, y in loader:
        images.append(x)
        labels.append(y)
    return torch.cat(images, 0), torch.cat(labels, 0)


def adv_save_path(save_dir, version, n_ex, epsilon, individual=False):
    """Return the .pth path the adversarial results are written to.

    The individual-mode template previously carried two extra ``{}`` fields
    (``_plus_{}_cheap_{}``) with no matching arguments, so ``str.format``
    raised IndexError after the attack had already run.
    """
    if individual:
        template = '{}/{}_{}_individual_1_{}_eps_{:.5f}.pth'
    else:
        template = '{}/{}_{}_1_{}_eps_{:.5f}.pth'
    return template.format(save_dir, 'aa', version, n_ex, epsilon)


if __name__ == '__main__':
    args = build_arg_parser().parse_args()

    # load model and data
    model = load_model(args)
    x_test, y_test = load_cifar10_test(args.data_dir)

    # create save dir (exist_ok avoids the check-then-create race)
    os.makedirs(args.save_dir, exist_ok=True)

    # load attack
    from autoattack import AutoAttack
    adversary = AutoAttack(model, norm=args.norm, eps=args.epsilon,
                           log_path=args.log_path, version=args.version)

    # example of custom version
    if args.version == 'custom':
        adversary.attacks_to_run = ['apgd-ce', 'fab']
        adversary.apgd.n_restarts = 2
        adversary.fab.n_restarts = 2

    x_eval, y_eval = x_test[:args.n_ex], y_test[:args.n_ex]

    # run attack and save images
    with torch.no_grad():
        if not args.individual:
            adv_complete = adversary.run_standard_evaluation(
                x_eval, y_eval, bs=args.batch_size)
            torch.save({'adv_complete': adv_complete},
                       adv_save_path(args.save_dir, args.version,
                                     adv_complete.shape[0], args.epsilon))
        else:
            # individual version, each attack is run on all test points
            adv_complete = adversary.run_standard_evaluation_individual(
                x_eval, y_eval, bs=args.batch_size)
            torch.save(adv_complete,
                       adv_save_path(args.save_dir, args.version, args.n_ex,
                                     args.epsilon, individual=True))