-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtools.py
425 lines (355 loc) · 15.1 KB
/
tools.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
""" Tools for protein engineering bandit simulations.
Available functions:
video_frame: Make a single frame of a learning video for a trial.
make_video: Finalizes video creation and shows it. (Must be called after
all video_frame calls and before any other plotting function).
progress_graph: Plots graph of a trial's chosen and maximum T50 over time.
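get_filled_T50s: Converts saved simulation data into uniform T50 data.
learning_curve: Plots the average maximum T50 against iteration for each
strategy.
itrs_to_temp_boxplot: Makes a boxplot of the number of iterations to reach
a given temp.
itrs_to_temp_hist: Makes a density plot of the number of iterations to reach
a given temp.
gp_class: Convenience wrapper for Gaussian Process Classification.
gp_reg: Convenience wrapper for Gaussian Process Regression.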
"""
from matplotlib import animation
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from sklearn.gaussian_process import GaussianProcessRegressor, \
GaussianProcessClassifier
class DP(object):
    """Represents each datapoint in the dataset.

    pred_T50, std, ucb, and prob may be meaningless depending on the context
    of the current trial iteration.

    Public attributes:
        seq: List of 1s and 0s denoting an amino acid sequence. (read-only)
        T50: Float of the T50 associated with seq. Will be float('NaN') for
            inactive seqs. (read-only)
        dps_index: datapoints List index of the DP. (read-only)
        explored: Whether the DP has been chosen in the current trial.
        pred_T50: Predicted T50 of the DP in the current trial iteration.
        std: Standard error of the T50 prediction.
        ucb: Upper confidence bound of the T50 prediction in the current
            trial iteration. Initialised to pred_T50 + 2 * std.
        prob: Functionality probability predicted for the DP in the current
            trial iteration.
    """

    def __init__(self, seq, T50, index):
        """Stores the immutable data and resets the per-trial model state.

        Args:
            seq: Sequence encoding exposed through the seq property.
            T50: Measured T50 exposed through the T50 property.
            index: Position of this DP in the datapoints list, exposed
                through the dps_index property.
        """
        self._seq = seq
        self._T50 = T50
        self._dps_index = index
        self.explored = False
        # Default model outputs used until a prediction is assigned.
        self.pred_T50 = 50
        self.std = 15
        self.ucb = self.pred_T50 + 2*self.std
        self.prob = 1

    def __repr__(self):
        return (f'{type(self).__name__}(dps_index={self._dps_index!r}, '
                f'T50={self._T50!r}, explored={self.explored!r})')

    @property
    def seq(self):
        return self._seq

    @property
    def T50(self):
        return self._T50

    @property
    def dps_index(self):
        return self._dps_index
def _plot_points(points, fmt):
    """Plots (x, y) pairs with a format string and returns the Line2D.

    An empty points list is drawn as an empty line. Unpacking an empty list
    into plt.plot() would leave only the format string, which Matplotlib
    would then interpret as data.
    """
    xs, ys = (zip(*points) if points else ((), ()))
    line, = plt.plot(xs, ys, fmt)
    return line


def video_frame(datapoints, possible_seqs, chosen_seq, fig, final=False):
    """Make a frame of a trial learning video.

    Trial learning videos show the real T50s of each datapoint, the previously
    selected sequences, and the UCB scores of the possible sequences.
    video_frame() should be called on every iteration of a trial. Set final to
    true on the final frame and then call make_video(). No plotting should be
    done between the first video_frame() call and make_video().

    Args:
        datapoints: List of DP objects. The ordering should stay constant
            between calls.
        possible_seqs: List of DP objects for which ucb should be plotted.
        chosen_seq: DP object chosen on the current round.
        fig: plt Figure object where video is plotted. This must be the same
            object on every video_frame call. It is not referenced here:
            drawing goes through pyplot's current axes, so fig has to be the
            current figure.
        final: Whether the current call is the final video_frame() call. Adds
            a legend to the video.

    Returns:
        List of the artists drawn for this frame. Collect one such list per
        call and pass the resulting List of Lists to make_video().
    """
    # Fixed y-range. A previous UCB-based limit was immediately overridden by
    # this call and raised on an empty possible_seqs, so it was removed.
    plt.ylim(20, 90)

    def split_by_activity(dps):
        # Inactive (NaN T50) sequences are drawn at y=21, just above the
        # lower y-limit, because NaN itself cannot be plotted.
        active, inactive = [], []
        for dp in dps:
            if np.isnan(dp.T50):
                inactive.append((dp.dps_index, 21))
            else:
                active.append((dp.dps_index, dp.T50))
        return active, inactive

    explored_pos, explored_neg = split_by_activity(
        [dp for dp in datapoints if dp.explored])
    unexplored_pos, unexplored_neg = split_by_activity(
        [dp for dp in datapoints if not dp.explored])
    ucb = [(dp.dps_index, dp.ucb) for dp in possible_seqs]

    # Drawing order is kept so later series are painted over earlier ones.
    l0 = _plot_points(unexplored_pos, 'ro')
    l1 = _plot_points(unexplored_neg, 'rx')
    l2 = _plot_points(ucb, 'go')
    l3 = _plot_points(explored_pos, 'bo')
    l4 = _plot_points(explored_neg, 'bx')
    l5, = plt.plot(chosen_seq.dps_index, chosen_seq.ucb, 'yo')
    artists = [l0, l1, l2, l3, l4, l5]

    if explored_pos:
        # Horizontal line at the best T50 found so far. axhline's xmin/xmax
        # are axes fractions, so the defaults (0, 1) already span the plot.
        best_T50 = max(t50 for _, t50 in explored_pos)
        artists.append(plt.axhline(best_T50))

    if final:
        l0.set_label('unexplored positive')
        l1.set_label('unexplored negative')
        l2.set_label('predicted positive ucb')
        l3.set_label('explored positive')
        l4.set_label('explored negative')
        l5.set_label('current selected ucb')
        plt.legend()
    return artists
def make_video(fig, artists, show_video=False, save_video_fn=None,
               **ani_kwargs):
    """Builds the trial learning video from the collected frames.

    Args:
        fig: plt Figure object used for video. Must be same fig passed into
            video_frame() calls.
        artists: List of Lists of plt artists, one inner List per frame as
            returned by video_frame().
        show_video: Whether plt.show() is called to display video.
        save_video_fn: Filename of saved video. Video will not be saved if
            None (or any other falsy value).
        **ani_kwargs: Keywords forwarded to the ArtistAnimation()
            initialization. Look at ArtistAnimation documentation for a full
            list.

    Returns:
        The ArtistAnimation object that was created.
    """
    video = animation.ArtistAnimation(fig, artists, **ani_kwargs)

    # Saving happens before showing, as in the original call order.
    if save_video_fn:
        video.save(save_video_fn)

    if show_video:
        plt.show()

    return video
def progress_graph(chosen_seqs):
    """Make a graph that shows the progress of a trial over time.

    The progress graph is a plot of the selected T50 and maximum T50 at each
    iteration. Active selections are drawn as red circles at their T50;
    inactive selections (NaN T50) are drawn as red crosses at y=35. The
    running maximum is drawn as a blue line, with inactive selections counted
    as 0 until an active sequence is found.

    Args:
        chosen_seqs: List of DP objects in order of selection by trial.
    """
    max_T50s = []
    active_pts = []
    inactive_pts = []
    for itr, cs in enumerate(chosen_seqs):
        # Activity is decided by NaN-ness rather than by a 0 sentinel, so a
        # measured T50 of exactly 0 is not mistaken for an inactive sequence.
        if np.isnan(cs.T50):
            cs_T50 = 0
            inactive_pts.append((itr, 35))
        else:
            cs_T50 = cs.T50
            active_pts.append((itr, cs_T50))

        # A NaN T50 never compares greater, so it keeps the previous maximum.
        if not max_T50s or cs.T50 > max_T50s[-1]:
            max_T50s.append(cs_T50)
        else:
            max_T50s.append(max_T50s[-1])

    plt.figure()
    plt.plot(max_T50s, 'b-o')
    for points, fmt in ((active_pts, 'ro'), (inactive_pts, 'rx')):
        # Skip empty series: unpacking an empty list would leave only the
        # format string, which Matplotlib would interpret as data.
        if points:
            plt.plot(*zip(*points), fmt)
    plt.show(block=True)
def get_filled_T50s(run_data, length):
    """Gets T50s for data generated by simulations.

    Args:
        run_data: List of Lists of Lists of DP objects representing chosen
            sequences of each trial (3rd list layer) of each strategy (2nd
            list layer) of a given run (outer list). Only the T50 attribute
            of each object is read.
        length: Int minimum length to extend all trials to, by repetition of
            the last chosen sequence's T50.

    Returns:
        List of Lists of Lists of floats that is the same as run_data but
        each trial length in a strategy is a uniform length and
        thermostability replaces the DP object. Within a strategy every trial
        is padded to max(length, longest trial). A trial with no selections
        is padded with float('NaN').
    """
    run_T50s = []
    for strat_data in run_data:
        # default=0 keeps a strategy without any trials from raising.
        longest_trial = max((len(td) for td in strat_data), default=0)
        target_len = max(length, longest_trial)

        strat_T50s = []
        for trial_data in strat_data:
            trial_T50s = [itr_cs.T50 for itr_cs in trial_data]
            # Repeat the final T50; an empty trial has nothing to repeat, so
            # it is filled with NaN (treated as "inactive" downstream).
            filler = trial_T50s[-1] if trial_T50s else float('NaN')
            trial_T50s.extend([filler] * (target_len - len(trial_T50s)))
            strat_T50s.append(trial_T50s)
        run_T50s.append(strat_T50s)
    return run_T50s
def _get_max_T50s(trial_T50s):
    """Helper that computes the running maximum T50 over a trial.

    NaN entries never raise the maximum; until the first non-NaN T50 is seen
    the running maximum itself is NaN.

    Args:
        trial_T50s: List of selected T50s at each iteration of a trial.

    Returns:
        Numpy array with one entry per iteration holding the largest T50
        selected at or before that iteration.
    """
    best_so_far = float('NaN')
    running_best = []
    for t50 in trial_T50s:
        improves = (not np.isnan(t50)
                    and (t50 > best_so_far or np.isnan(best_so_far)))
        if improves:
            best_so_far = t50
        running_best.append(best_so_far)
    return np.array(running_best)
def learning_curve(run_T50s, strategies, save_fn=None, show=False,
                   ylines=None, xlims=None):
    """Plots curves of average maximum T50 data against number of iterations.

    Args:
        run_T50s: List of Lists (strategies) of Lists (trials) of floats which
            are the T50s of the sequences chosen at each iteration. Trial
            Lists must be a uniform length within a given strategy List.
        strategies: List of Strings for strategies names corresponding to the
            strategy Lists in run_T50s.
        save_fn: Filename (without extension) to save plot. Plot will be saved
            as both png and eps. None means saving will not occur.
        show: Whether plt.show() is called at the end of function.
        ylines: Accepted for backward compatibility; currently not drawn.
        xlims: Optional (left, right) limits for the x axis. None or an empty
            sequence leaves the automatic limits untouched.

    Returns:
        Figure that was created.
    """
    fig = plt.figure()
    plt.xlabel('iteration')
    plt.ylabel('maximum T50')
    for strat_name, strat_T50s in zip(strategies, run_T50s):
        strat_max_T50s = [_get_max_T50s(trial_data) for trial_data
                          in strat_T50s]
        # Transpose so each row holds every trial's value at one iteration.
        untrimmed_step_data = np.array(strat_max_T50s).T
        # Trials that have not found an active sequence yet (NaN) are left
        # out of that iteration's average.
        step_data = tuple(x[~np.isnan(x)] for x in untrimmed_step_data)
        strat_means = [np.mean(x) for x in step_data]
        plt.plot(strat_means, label=strat_name)
    plt.legend()
    # plt.xlim() cannot unpack an empty sequence, so only apply real limits.
    if xlims is not None and len(xlims) > 0:
        plt.xlim(xlims)
    if save_fn:
        plt.savefig(save_fn + '.png')
        plt.savefig(save_fn + '.eps')
    if show:
        plt.show()
    return fig
def _get_itrs_to_temp(T50s, temp):
    """Helper that finds how many iterations pass before temp is reached.

    Args:
        T50s: List of T50s selected at each iteration.
        temp: Float temperature threshold to search for.

    Returns:
        Zero-based index of the first iteration whose T50 is not NaN and is
        at least temp, or len(T50s) if no iteration qualifies.
    """
    first_hit = next(
        (step for step, value in enumerate(T50s)
         if not np.isnan(value) and value >= temp),
        None,
    )
    if first_hit is None:
        return len(T50s)
    return first_hit
def itrs_to_temp_boxplot(run_T50s, strategies, temp_thresh, save_fn=None,
                         show=False):
    """Makes a boxplot of the number of iterations needed to reach the given
    temperature.

    Counts are one-based (a threshold met on the very first selection counts
    as 1), matching itrs_to_temp_hist(), because the y axis is logarithmic
    and cannot display 0.

    Args:
        run_T50s: List of Lists (strategies) of Lists (trials) of floats which
            are the T50s of the sequences chosen at each iteration. Trial
            Lists must be a uniform length within a given strategy List.
        strategies: List of Strings for strategies names corresponding to the
            strategy Lists in run_T50s.
        temp_thresh: Temperature threshold for number of iterations to be
            calculated.
        save_fn: Filename (without extension) to save plot. Plot will be saved
            as both png and eps. None means saving will not occur.
        show: Whether plt.show() is called at the end of function.

    Returns:
        Figure that was created.
    """
    fig = plt.figure()
    plt.ylabel(f'trials needed to reach {temp_thresh}')
    # +1 turns the zero-based iteration index into a count of selections so
    # that every value is positive on the log-scaled axis.
    itrs_to_thresh = [
        [_get_itrs_to_temp(trial_T50s, temp_thresh) + 1
         for trial_T50s in strat_T50s]
        for strat_T50s in run_T50s
    ]
    # NOTE(review): newer Matplotlib releases deprecate `labels` in favour of
    # `tick_labels`; kept as-is for compatibility with older versions.
    plt.boxplot(itrs_to_thresh, labels=strategies)
    plt.yscale('log')
    plt.ylim([0.8, 1000])
    if save_fn:
        plt.savefig(save_fn + '.png')
        plt.savefig(save_fn + '.eps')
    if show:
        plt.show()
    return fig
def itrs_to_temp_hist(run_T50s, strategies, temp_thresh, save_fn=None,
                      show=False):
    """Makes a kernel density plot of the number of iterations needed to
    reach the given temperature.

    One density curve is drawn per strategy on a logarithmic x axis. Counts
    are one-based (a threshold met on the very first selection counts as 1)
    because the log axis cannot display 0.

    Args:
        run_T50s: List of Lists (strategies) of Lists (trials) of floats which
            are the T50s of the sequences chosen at each iteration. Trial
            Lists must be a uniform length within a given strategy List.
        strategies: List of Strings for strategies names corresponding to the
            strategy Lists in run_T50s. Used as the legend labels.
        temp_thresh: Temperature threshold for number of iterations to be
            calculated.
        save_fn: Filename (without extension) to save plot. Plot will be saved
            as both png and eps. None means saving will not occur.
        show: Whether plt.show() is called at the end of function.

    Returns:
        Figure that was created.
    """
    fig = plt.figure()
    # The iteration counts live on the x axis of a density plot.
    plt.xlabel(f'trials needed to reach {temp_thresh}')
    for strat_name, strat_T50s in zip(strategies, run_T50s):
        # min should be 1 not zero
        strat_itrs = [_get_itrs_to_temp(trial_T50s, temp_thresh) + 1
                      for trial_T50s in strat_T50s]
        # Labelling each curve directly keeps the legend correct even if
        # seaborn skips a curve (e.g. for zero-variance data).
        sns.kdeplot(strat_itrs, log_scale=True, bw_adjust=1.5,
                    label=strat_name)
    plt.legend()
    plt.xlim([0.9, 1000])
    if save_fn:
        plt.savefig(save_fn + '.png')
        plt.savefig(save_fn + '.eps')
    if show:
        plt.show()
    return fig
def gp_class(train_data, test_dps, kernel):
    """Wrapper for Gaussian Process Classification.

    A classifier is only fitted when the training labels contain exactly two
    distinct values; otherwise every test DP is assigned probability 1.

    Args:
        train_data: List of (binary_sequence, binary_functionality) for each
            sequence previously selected.
        test_dps: List of DP object to predict functionality on.
        kernel: sklearn.gaussian_process kernel to use in the GPC. See
            sklearn documentation for full list.

    Returns:
        List of functionality probabilities for each DP in test_dps.
    """
    observed_labels = {sample[1] for sample in train_data}
    if len(observed_labels) != 2:
        # one class: just assign p=1 to all sequences
        return [1 for _ in test_dps]

    classifier = GaussianProcessClassifier(kernel=kernel)
    classifier.fit(*zip(*train_data))
    queries = [dp.seq for dp in test_dps]
    # Column 1 of predict_proba is taken as the "functional" probability.
    return [row[1] for row in classifier.predict_proba(queries)]
def gp_reg(train_data, test_dps, kernel):
    """Wrapper for Gaussian Process Regression.

    Args:
        train_data: List of (binary_sequence, target) pairs for each sequence
            previously selected. NOTE(review): the target is presumably the
            measured T50 (earlier wording said binary_functionality) --
            confirm against callers.
        test_dps: List of DP object to predict thermostability on.
        kernel: sklearn.gaussian_process kernel to use in the GPR. See
            sklearn documentation for full list.

    Returns:
        y_mean: Predicted T50s for each DP in test_dps.
        y_std: Standard deviations of the T50 prediction for each DP in
            test_dps.
    """
    regressor = GaussianProcessRegressor(kernel=kernel)
    regressor.fit(*zip(*train_data))
    queries = [dp.seq for dp in test_dps]
    means, stds = regressor.predict(queries, return_std=True)
    return means, stds