-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathutils.py
248 lines (215 loc) · 10.2 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
import math
import eagerpy as ep
import numpy as np
import torch
from src.attacks.base import DirectionAttack, PerturbationAttack
from src.attacks.queries_counter import AttackPhase, QueriesCounter
from src.model_wrappers.general_model import ModelWrapper
def atleast_kd(x: ep.Tensor, k: int) -> ep.Tensor:
    """Reshape ``x`` by appending trailing singleton axes until it has ``k`` dimensions.

    Tensors that already have ``k`` or more dimensions are reshaped to their own shape,
    i.e. no axes are added.

    Adapted from foolbox: https://github.com/bethgelab/foolbox/blob/master/foolbox/devutils.py
    """
    n_missing_axes = max(k - x.ndim, 0)
    target_shape = (*x.shape, *([1] * n_missing_axes))
    return x.reshape(target_shape)
def flatten(x: ep.Tensor, keep: int = 1) -> ep.Tensor:
    """Flatten ``x`` starting at axis ``keep``; axes before ``keep`` are left untouched.

    Adapted from foolbox: https://github.com/bethgelab/foolbox/blob/master/foolbox/devutils.py
    """
    first_merged_axis = keep
    return x.flatten(start=first_merged_axis)
# Default bisection tolerance (default ``tol`` of ``opt_binary_search``).
DEFAULT_LINE_SEARCH_TOL: float = 1e-5
# Default number of candidate points per batched query (default ``batch_size`` of
# ``_batched_line_search_body``).
MAX_BATCH_SIZE: int = 100
def opt_binary_search(attack: DirectionAttack | PerturbationAttack,
                      model: ModelWrapper,
                      x: torch.Tensor,
                      y: torch.Tensor,
                      target: torch.Tensor | None,
                      theta: torch.Tensor,
                      queries_counter: QueriesCounter,
                      initial_lbd: float,
                      phase: AttackPhase,
                      first_step_phase: AttackPhase | None = None,
                      tol: float = DEFAULT_LINE_SEARCH_TOL) -> tuple[float, QueriesCounter, float]:
    """Binary-search the boundary distance along ``theta`` around ``initial_lbd``.

    The search first brackets the boundary:

    * if the point at ``initial_lbd`` is NOT on the correct boundary side (as judged by
      ``attack.is_correct_boundary_side``), a single step up to ``initial_lbd * 1.02`` is tried;
    * otherwise the lower end is shrunk by a factor 0.99 until a point on the wrong side is found.

    It then bisects the bracket until it is no wider than ``tol``.

    Args:
        attack: attack used to build and check candidate points. A ``DirectionAttack`` receives the
            distance as a separate argument of ``get_x_adv``; a ``PerturbationAttack`` receives
            ``theta * distance``.
        model: model queried through ``attack``.
        x: original input.
        y: label passed to the boundary-side check.
        target: optional target passed to the boundary-side check.
        theta: search direction / perturbation.
        queries_counter: counter threaded through every query.
        initial_lbd: starting distance along ``theta``.
        phase: phase recorded for every query except the first.
        first_step_phase: phase recorded for the first query (defaults to ``phase``).
        tol: bisection stops once ``lbd_hi - lbd_lo <= tol``.

    Returns:
        ``(lbd, queries_counter, lbd_factor)``:

        * ``lbd`` is the upper end of the final bracket, a distance that was confirmed to be on the
          correct side. The exception is when the 1.02 step also failed: then ``initial_lbd * 1.02``
          is returned without having been confirmed.
        * ``queries_counter`` accounts for every query made.
        * ``lbd_factor`` is the ratio between the upper end of the initial bracket and
          ``initial_lbd`` (1.02 after stepping up, 1.0 otherwise).

    Raises:
        TypeError: if ``attack`` is neither a ``DirectionAttack`` nor a ``PerturbationAttack``.
    """
    if isinstance(attack, DirectionAttack):
        def x_adv_at(lbd_: float) -> torch.Tensor:
            return attack.get_x_adv(x, theta, lbd_)
    elif isinstance(attack, PerturbationAttack):
        def x_adv_at(lbd_: float) -> torch.Tensor:
            return attack.get_x_adv(x, theta * lbd_)
    else:
        raise TypeError(f"Unsupported attack type: {type(attack).__name__}")

    def is_correct_boundary_side_local(lbd_: float,
                                       qc: QueriesCounter,
                                       query_phase: AttackPhase = phase) -> tuple[torch.Tensor, QueriesCounter]:
        return attack.is_correct_boundary_side(model, x_adv_at(lbd_), y, target, qc, query_phase, x)

    lbd = initial_lbd
    # The very first query may be accounted under a different phase than the rest of the search.
    success, queries_counter = is_correct_boundary_side_local(lbd, queries_counter, first_step_phase or phase)
    if not success.item():
        # Starting point is on the wrong side: try a single 2% step outwards.
        lbd_lo = lbd
        lbd_hi = lbd * 1.02
        # Keep the returned counter in both outcomes so this query is always accounted for.
        success, queries_counter = is_correct_boundary_side_local(lbd_hi, queries_counter)
        if not success.item():
            return lbd * 1.02, queries_counter, 1.02
    else:
        # Starting point is on the correct side: shrink the lower end until it crosses the boundary.
        lbd_hi = lbd
        lbd_lo = lbd * 0.99
        while True:
            # The final (failing) query's counter is kept as well.
            success, queries_counter = is_correct_boundary_side_local(lbd_lo, queries_counter)
            if not success.item():
                break
            lbd_lo *= 0.99

    lbd_factor = lbd_hi / lbd
    diff = lbd_hi - lbd_lo
    while diff > tol:
        lbd_mid = (lbd_lo + lbd_hi) / 2
        # Float resolution reached: the midpoint collapsed onto one of the bounds.
        if lbd_mid == lbd_hi or lbd_mid == lbd_lo:
            break
        success, queries_counter = is_correct_boundary_side_local(lbd_mid, queries_counter)
        if success.item():
            lbd_hi = lbd_mid
        else:
            lbd_lo = lbd_mid
        # Avoid numerical issues (e.g. with GPU tensors) where the bracket stops shrinking.
        if diff <= lbd_hi - lbd_lo:
            break
        diff = lbd_hi - lbd_lo
    return lbd_hi, queries_counter, lbd_factor
def opt_line_search(attack: PerturbationAttack | DirectionAttack,
                    model: ModelWrapper,
                    x: torch.Tensor,
                    y: torch.Tensor,
                    target: torch.Tensor | None,
                    theta: torch.Tensor,
                    queries_counter: QueriesCounter,
                    initial_lbd: float,
                    phase: AttackPhase,
                    initial_phase: AttackPhase,
                    current_best: float | None,
                    n_searches: int,
                    max_search_steps: int | None,
                    batch_size: int,
                    lower_b: float | None = None,
                    upper_b: float | None = None,
                    step_size: float | None = None) -> tuple[float, QueriesCounter]:
    """Find the boundary distance along ``theta`` with a one- or two-stage batched linear search.

    Each stage walks down from a starting distance in fixed steps, querying points in batches
    through ``_batched_line_search_body``. It keeps the last distance that was still on the correct
    boundary side.

    If ``current_best`` is given and ``initial_lbd > current_best``, the point at ``current_best`` is
    queried first, under ``initial_phase``. If that point is not on the correct side, ``inf`` is
    returned without searching. Otherwise the search starts from ``current_best``.

    Args:
        attack: attack used to build and check candidate points.
        model: model queried through ``attack``.
        x: original input.
        y: label passed to the boundary-side checks.
        target: optional target passed to the boundary-side checks.
        theta: search direction / perturbation.
        queries_counter: counter threaded through every query.
        initial_lbd: starting distance when ``current_best`` does not apply.
        phase: phase recorded for the batched search queries.
        initial_phase: phase recorded for the ``current_best`` check.
        current_best: best distance found so far, or ``None``.
        n_searches: number of stages; must be 1 or 2 (asserted). With 2, a coarse search is followed
            by a fine search starting from the coarse result.
        max_search_steps: step budget; with ``n_searches == 2`` each stage uses
            ``ceil(sqrt(max_search_steps))`` steps. Exactly one of this and ``step_size`` must be set.
        batch_size: upper bound on the number of points per batched query.
        lower_b: if given, the range used to size the steps has lower end ``lbd * lower_b``.
            This is computed before ``upper_b`` is applied; without it the lower end is 0.
        upper_b: if given, the search starts from ``lbd * upper_b``.
        step_size: finest step size; the number of steps is derived from the search range.

    Returns:
        ``(lbd, queries_counter)``. If the very first point of the search is already on the wrong
        side, ``lbd * 1.02`` is returned, and a warning is printed when ``upper_b`` was used.

    Raises:
        TypeError: if ``attack`` is neither a ``DirectionAttack`` nor a ``PerturbationAttack``.
        ValueError: if the configuration yields no search steps. This happens with an empty or
            negative range in ``step_size`` mode, or with ``max_search_steps < 1``. Such a search
            cannot make progress.
    """
    if not isinstance(attack, (DirectionAttack, PerturbationAttack)):
        raise TypeError(f"Unsupported attack type: {type(attack).__name__}")

    if current_best is not None and initial_lbd > current_best:
        if isinstance(attack, DirectionAttack):
            x_adv = attack.get_x_adv(x, theta, current_best)
        else:
            x_adv = attack.get_x_adv(x, theta * current_best)
        success, queries_counter = attack.is_correct_boundary_side(model, x_adv, y, target, queries_counter,
                                                                   initial_phase, x)
        if not success.item():
            return float('inf'), queries_counter
        lbd = current_best
    else:
        lbd = initial_lbd

    # The lower end is derived from the distance *before* the upper_b overshoot is applied.
    lower_lbd = lbd * lower_b if lower_b is not None else 0.
    if upper_b is not None:
        lbd = lbd * upper_b

    assert n_searches in {1, 2}
    if max_search_steps is None:
        assert step_size is not None, 'Either step_size or max_search_steps must be specified'
        n_fine_steps = (lbd - lower_lbd) / step_size
        if n_fine_steps <= 0:
            # Would otherwise divide by zero (two stages) or produce empty batches (one stage).
            raise ValueError(f"Empty line-search range: lbd={lbd}, lower_lbd={lower_lbd}, step_size={step_size}")
        if n_searches == 2:
            # Two stages of ~sqrt(n) steps each cover the same range as one stage of n steps.
            search_max_steps = math.ceil(math.sqrt(n_fine_steps))
            first_search_step_size = (lbd - lower_lbd) / math.sqrt(n_fine_steps)
        else:
            search_max_steps = math.ceil(n_fine_steps)
            first_search_step_size = step_size
    else:
        assert step_size is None, 'Only one of step_size or max_search_steps must be specified'
        if max_search_steps < 1:
            raise ValueError(f"max_search_steps must be at least 1, got {max_search_steps}")
        if n_searches == 2:
            search_max_steps = math.ceil(math.sqrt(max_search_steps))
        else:
            search_max_steps = max_search_steps
        first_search_step_size = (lbd - lower_lbd) / search_max_steps

    search_batch_size = min(search_max_steps, batch_size)
    first_search_lbd, first_search_queries_counter, first_query_failed = _batched_line_search_body(
        attack,
        model,
        x,
        y,
        target,
        theta,
        queries_counter,
        lbd,
        phase,
        first_search_step_size,
        search_batch_size,
        # Here we count each query of the first search as equivalent to search_max_steps queries of when we do 1
        equivalent_simulated_queries=search_max_steps,
        # But we don't count the queries from the last batch as they will be counted in the second search
        count_last_batch_for_sim=False)
    if first_query_failed:
        lbd_to_return = lbd * 1.02
        if upper_b is not None:
            print("Warning: line search overshoot was not enough")
        return lbd_to_return, first_search_queries_counter

    if n_searches == 1:
        return first_search_lbd, first_search_queries_counter

    # Fine stage: restart from the coarse result with a step search_max_steps times smaller.
    second_search_step_size = first_search_step_size / search_max_steps
    final_lbd, second_search_queries_counter, _ = _batched_line_search_body(
        attack,
        model,
        x,
        y,
        target,
        theta,
        first_search_queries_counter,
        first_search_lbd,
        phase,
        second_search_step_size,
        search_batch_size,
        # Here each query has the same step size as if we were doing one search only
        equivalent_simulated_queries=1,
        # And we count the queries from the last batch as they are not counted in the first search
        count_last_batch_for_sim=True)
    return final_lbd, second_search_queries_counter
def _batched_line_search_body(attack: PerturbationAttack | DirectionAttack,
                              model: ModelWrapper,
                              x: torch.Tensor,
                              y: torch.Tensor,
                              target: torch.Tensor | None,
                              theta: torch.Tensor,
                              queries_counter: QueriesCounter,
                              initial_lbd: float,
                              phase: AttackPhase,
                              step_size: float,
                              batch_size: int = MAX_BATCH_SIZE,
                              equivalent_simulated_queries: int = 1,
                              count_last_batch_for_sim: bool = False) -> tuple[float, QueriesCounter, bool]:
    """Walk down from ``initial_lbd`` in steps of ``step_size``, querying ``batch_size`` points at a time.

    Candidate distances are ``initial_lbd - i * step_size`` for ``i = 0, 1, 2, ...``. Consecutive
    batches of candidates are checked with ``attack.is_correct_boundary_side_batched`` until a batch
    contains a point that is not on the correct boundary side.

    Args:
        attack: attack used to build candidate points. A ``DirectionAttack`` receives the distances
            as a separate argument; a ``PerturbationAttack`` receives ``theta`` scaled by them.
        model, x, y, target, phase: forwarded to the batched boundary-side check.
        theta: search direction / perturbation.
        queries_counter: counter threaded through every batched query.
        initial_lbd: distance of the first candidate.
        step_size: decrement between consecutive candidate distances.
        batch_size: number of candidates per batched query; must be at least 1.
        equivalent_simulated_queries, count_last_batch_for_sim: forwarded unchanged to the batched
            check (query accounting); the last forwarded flag marks the first batch.

    Returns:
        ``(lbd, queries_counter, first_query_failed)``:

        * ``lbd`` is the last candidate before the first failing one. If the very first candidate
          failed, it is ``initial_lbd``, which was not confirmed.
        * ``first_query_failed`` is True exactly when the first candidate of the first batch failed.

    Raises:
        ValueError: if ``batch_size < 1``.
        TypeError: if ``attack`` is neither a ``DirectionAttack`` nor a ``PerturbationAttack``.
    """
    if batch_size < 1:
        # Every batch would be empty. If the batched check then reports an empty `success`,
        # `success.all()` is vacuously True and the loop below never ends.
        raise ValueError(f"batch_size must be at least 1, got {batch_size}")
    if isinstance(attack, DirectionAttack):
        def make_batch(lbds_: torch.Tensor) -> torch.Tensor:
            return attack.get_x_adv(x, theta, lbds_)
    elif isinstance(attack, PerturbationAttack):
        def make_batch(lbds_: torch.Tensor) -> torch.Tensor:
            # NOTE(review): the extra unsqueeze depends on PerturbationAttack's theta layout - confirm.
            return attack.get_x_adv(x, theta * lbds_.unsqueeze(-1))
    else:
        raise TypeError(f"Unsupported attack type: {type(attack).__name__}")

    # One distance per batch element, broadcastable against the non-leading dimensions of x.
    lbds_inner_shape = (1,) * (x.ndim - 1)
    success = torch.tensor([True])
    batch_idx = 0
    lbds = np.array([initial_lbd])
    # The loop body always runs at least once, so previous_last_lbd is always bound afterwards.
    while success.all():
        # Remember the previous batch's last distance in case the next batch fails at its first element.
        previous_last_lbd = lbds[-1]
        step_indices = np.arange(batch_idx * batch_size, (batch_idx + 1) * batch_size)
        lbds = (initial_lbd - step_indices * step_size).reshape(-1, *lbds_inner_shape)
        # Candidates are queried in float32; the returned distance keeps the float64 numpy value.
        lbds_torch = torch.from_numpy(lbds).float().to(device=x.device)
        success, queries_counter = attack.is_correct_boundary_side_batched(model, make_batch(lbds_torch), y, target,
                                                                           queries_counter, phase, x,
                                                                           equivalent_simulated_queries,
                                                                           count_last_batch_for_sim, batch_idx == 0)
        batch_idx += 1

    # Index of the first failing candidate (argmin returns the first of equal minima).
    unsafe_query_idx = int(torch.argmin(success.to(torch.int)).item())
    if unsafe_query_idx == 0:
        # The whole latest batch failed: fall back to the last candidate of the previous batch.
        lbd = previous_last_lbd.item()
    else:
        lbd = lbds[unsafe_query_idx - 1].item()
    # Exiting after the first batch with its very first element failing means the first query failed.
    first_query_failed = batch_idx == 1 and unsafe_query_idx == 0
    return lbd, queries_counter, first_query_failed