-
Notifications
You must be signed in to change notification settings - Fork 26
/
utils_plus.py
121 lines (109 loc) · 4.33 KB
/
utils_plus.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
#import apex.amp as amp
import torch
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data.sampler import SubsetRandomSampler
import numpy as np
# Valid pixel-value range for images after transforms.ToTensor(): [0, 1].
# Used by clamp/attack_pgd to keep perturbed images inside the image domain.
upper_limit, lower_limit = 1, 0
def clamp(X, lower_limit, upper_limit):
    """Clamp X element-wise into [lower_limit, upper_limit].

    Bounds may be scalars or broadcastable tensors (unlike torch.clamp's
    float-only min/max in older torch versions).
    """
    capped_above = torch.min(X, upper_limit)
    return torch.max(capped_above, lower_limit)
def get_loaders(dir_, batch_size, DATASET='CIFAR10'):
    """Build train/test DataLoaders for CIFAR10 or CIFAR100.

    Args:
        dir_: root directory where the dataset is stored (downloaded if absent).
        batch_size: batch size shared by both loaders.
        DATASET: 'CIFAR10' (default) or 'CIFAR100'.

    Returns:
        (train_loader, test_loader) tuple of torch DataLoaders.

    Raises:
        ValueError: if DATASET names an unsupported dataset. (Previously an
            unknown name fell through both branches and surfaced as a
            confusing NameError on train_dataset.)
    """
    # Standard CIFAR augmentation for training; plain tensor conversion for eval.
    train_transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor()
    ])
    test_transform = transforms.Compose([
        transforms.ToTensor()
    ])
    num_workers = 2
    if DATASET == 'CIFAR10':
        dataset_cls = datasets.CIFAR10
    elif DATASET == 'CIFAR100':
        dataset_cls = datasets.CIFAR100
    else:
        raise ValueError(
            "Unsupported DATASET: %r (expected 'CIFAR10' or 'CIFAR100')" % (DATASET,))
    train_dataset = dataset_cls(
        dir_, train=True, transform=train_transform, download=True)
    test_dataset = dataset_cls(
        dir_, train=False, transform=test_transform, download=True)
    train_loader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_size=batch_size,
        shuffle=True,
        pin_memory=True,
        num_workers=num_workers,
    )
    test_loader = torch.utils.data.DataLoader(
        dataset=test_dataset,
        batch_size=batch_size,
        shuffle=False,
        pin_memory=True,
        num_workers=num_workers,  # shared setting instead of a second hard-coded 2
    )
    return train_loader, test_loader
def CW_loss(x, y):
    """Carlini-Wagner-style margin loss, averaged over the batch.

    For each example: (strongest competing logit) - (true-class logit).
    If the true class is currently the argmax, the competitor is the
    runner-up logit; otherwise it is the max logit itself.

    Args:
        x: logits of shape (batch, num_classes).
        y: integer class labels of shape (batch,).

    Returns:
        Scalar tensor — the mean margin loss.
    """
    sorted_logits, sorted_idx = x.sort(dim=1)
    # 1.0 where the model currently predicts the true class, else 0.0.
    is_correct = (sorted_idx[:, -1] == y).float()
    true_logit = x[np.arange(x.shape[0]), y]
    competitor = sorted_logits[:, -2] * is_correct + sorted_logits[:, -1] * (1. - is_correct)
    return (competitor - true_logit).mean()
def attack_pgd(model, X, y, epsilon, alpha, attack_iters, restarts, use_CWloss=False, normalize=None):
    """PGD (projected gradient descent) L-inf adversarial attack with random restarts.

    Args:
        model: classifier; called as model(normalize(X + delta)).
        X: input batch (expected in [0, 1] given the module-level
           lower_limit/upper_limit bounds).
        y: true labels for X.
        epsilon: L-inf radius of the perturbation ball.
        alpha: per-iteration step size.
        attack_iters: gradient steps per restart.
        restarts: number of independent random starts.
        use_CWloss: use CW_loss instead of cross-entropy to drive the attack.
        normalize: callable applied to the perturbed input before the model
            (e.g. per-channel standardization); must not be None.

    Returns:
        max_delta: per-example perturbation that achieved the highest
        cross-entropy loss across all restarts (same shape as X, on CUDA).
    """
    max_loss = torch.zeros(y.shape[0]).cuda()
    max_delta = torch.zeros_like(X).cuda()
    for _ in range(restarts):
        # Random start inside the epsilon ball, projected so X + delta stays
        # within the valid image range [lower_limit, upper_limit].
        delta = torch.zeros_like(X).cuda()
        delta.uniform_(-epsilon, epsilon)
        delta.data = clamp(delta, lower_limit - X, upper_limit - X)
        delta.requires_grad = True
        for _ in range(attack_iters):
            output = model(normalize(X + delta))
            # Only keep attacking examples the model still classifies correctly;
            # stop early once every example in the batch is fooled.
            index = torch.where(output.max(1)[1] == y)
            if len(index[0]) == 0:
                break
            if use_CWloss:
                loss = CW_loss(output, y)
            else:
                loss = F.cross_entropy(output, y)
            loss.backward()
            grad = delta.grad.detach()
            d = delta[index[0], :, :, :]
            g = grad[index[0], :, :, :]
            # Ascent step on the loss, then project back into the epsilon ball
            # and into the valid image range.
            d = torch.clamp(d + alpha * torch.sign(g), -epsilon, epsilon)
            d = clamp(d, lower_limit - X[index[0], :, :, :], upper_limit - X[index[0], :, :, :])
            # Write through .data so the update is not tracked by autograd.
            delta.data[index[0], :, :, :] = d
            delta.grad.zero_()
        # Across restarts keep, per example, the delta with the highest
        # cross-entropy loss (selection uses CE even when CW drove the attack).
        all_loss = F.cross_entropy(model(normalize(X + delta)), y, reduction='none').detach()
        max_delta[all_loss >= max_loss] = delta.detach()[all_loss >= max_loss]
        max_loss = torch.max(max_loss, all_loss)
    return max_delta
def evaluate_pgd(test_loader, model, attack_iters, restarts, eps=8, step=2, use_CWloss=False, normalize=None):
    """Evaluate robust loss/accuracy of `model` under a PGD attack.

    eps and step are given in pixel units (out of 255) and converted to the
    [0, 1] scale used by the attack.

    Returns:
        (mean adversarial cross-entropy loss, adversarial accuracy).
    """
    epsilon = eps / 255.
    alpha = step / 255.
    total_loss = 0
    total_correct = 0
    total_seen = 0
    model.eval()
    for X, y in test_loader:
        X, y = X.cuda(), y.cuda()
        pgd_delta = attack_pgd(model, X, y, epsilon, alpha, attack_iters, restarts,
                               use_CWloss=use_CWloss, normalize=normalize)
        # Gradients were only needed to craft the attack, not to score it.
        with torch.no_grad():
            logits = model(normalize(X + pgd_delta))
            batch_loss = F.cross_entropy(logits, y)
            total_loss += batch_loss.item() * y.size(0)
            total_correct += (logits.max(1)[1] == y).sum().item()
            total_seen += y.size(0)
    return total_loss / total_seen, total_correct / total_seen
def evaluate_standard(test_loader, model, normalize=None):
    """Compute clean (non-adversarial) loss and accuracy on `test_loader`.

    Args:
        test_loader: iterable of (X, y) batches.
        model: classifier; called as model(normalize(X)).
        normalize: callable applied to inputs before the model; must not be None.

    Returns:
        (mean cross-entropy loss, accuracy) over all examples.
    """
    model.eval()
    loss_sum = 0
    correct = 0
    seen = 0
    with torch.no_grad():
        for X, y in test_loader:
            X, y = X.cuda(), y.cuda()
            logits = model(normalize(X))
            loss_sum += F.cross_entropy(logits, y).item() * y.size(0)
            correct += (logits.max(1)[1] == y).sum().item()
            seen += y.size(0)
    return loss_sum / seen, correct / seen