helper.py
"""
Helper functions for training scVI
"""
import time
import numpy as np
import tensorflow as tf
from benchmarking import *  # provides imputation_error used below
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt


def format_time(seconds):
    m, s = divmod(seconds, 60)
    h, m = divmod(m, 60)
    return "%d:%02d:%02d" % (h, m, s)


def train_model(model, expression, sess, num_epochs, step=None, batch=None, kl=None):
    expression_train, expression_test = expression
    scVI_batch = batch is not None
    if scVI_batch:
        batch_train, batch_test = batch
    if step is None:
        step = model.train_step
    batch_size = 128
    iterep = int(expression_train.shape[0] / float(batch_size)) - 1
    training_history = {"t_loss": [], "v_loss": [], "time": [], "epoch": []}
    training_history["n_hidden"] = model.n_hidden
    training_history["model"] = model.__class__.__name__
    training_history["n_input"] = model.n_input
    training_history["dropout_nn"] = model.dropout_rate
    training_history["dispersion"] = model.dispersion
    training_history["n_layers"] = model.n_layers
    if kl is None:
        warmup = lambda x: np.minimum(1, x / 400.)
    else:
        warmup = lambda x: kl
    begin = time.time()
    for t in range(iterep * num_epochs):
        # KL warm-up
        end_epoch, epoch = t % iterep == 0, t // iterep
        kl = warmup(epoch)
        # arrange data in batches
        index_train = np.random.choice(np.arange(expression_train.shape[0]), size=batch_size)
        x_train = expression_train[index_train].astype(np.float32)
        # prepare data dictionaries
        dic_train = {model.expression: x_train, model.training_phase: True, model.kl_scale: kl}
        dic_test = {model.expression: expression_test, model.training_phase: False, model.kl_scale: kl}
        if scVI_batch:
            b_train = batch_train[index_train]
            dic_train[model.batch_ind] = b_train
            dic_train[model.mmd_scale] = 10
            dic_test[model.batch_ind] = batch_test
            dic_test[model.mmd_scale] = 10
        # run an optimization step
        _, l_tr = sess.run([step, model.loss], feed_dict=dic_train)
        if end_epoch:
            now = time.time()
            l_t = sess.run(model.loss, feed_dict=dic_test)
            training_history["t_loss"].append(l_tr)
            training_history["v_loss"].append(l_t)
            training_history["time"].append(format_time(int(now - begin)))
            training_history["epoch"].append(epoch)
        if np.isnan(l_tr):
            break
    return training_history
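
# Example usage (a minimal sketch; the `scVI` module, its model class and constructor
# arguments, and the TF1-style session setup below are assumptions, not defined here):
#   import scVI
#   model = scVI.scVIModel(expression_train.shape[1])
#   sess = tf.Session()
#   sess.run(tf.global_variables_initializer())
#   history = train_model(model, (expression_train, expression_test), sess, num_epochs=250)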


def train_model_patience(model, expression, sess, num_epochs, step=None, batch=None, kl=None):
    """
    Customized training loop with early-stopping detection, commonly called EarlyStopping
    in the literature: once the validation loss has not improved for 12 consecutive epochs,
    the scale parameters rho (px_scale) are exported, but training continues as in the scVI
    manuscript to show the stability of these values.
    """
    expression_train, expression_test = expression
    scVI_batch = batch is not None
    if scVI_batch:
        batch_train, batch_test = batch
    if step is None:
        step = model.train_step
    batch_size = 128
    iterep = int(expression_train.shape[0] / float(batch_size)) - 1
    training_history = {"t_loss": [], "v_loss": [], "time": [], "epoch": []}
    training_history["n_hidden"] = model.n_hidden
    training_history["model"] = model.__class__.__name__
    training_history["n_input"] = model.n_input
    training_history["dropout_nn"] = model.dropout_rate
    training_history["dispersion"] = model.dispersion
    training_history["n_layers"] = model.n_layers
    if kl is None:
        warmup = lambda x: np.minimum(1, x / 400.)
    else:
        warmup = lambda x: kl
    begin = time.time()
    # early stopping schedule
    best_performance = np.inf
    flag_has_exported_rho = False
    wait = 0
    rho_early = None
    rho_final = None
    for t in range(iterep * num_epochs):
        # KL warm-up
        end_epoch, epoch = t % iterep == 0, t // iterep
        kl = warmup(epoch)
        # arrange data in batches
        index_train = np.random.choice(np.arange(expression_train.shape[0]), size=batch_size)
        x_train = expression_train[index_train].astype(np.float32)
        # prepare data dictionaries
        dic_train = {model.expression: x_train, model.training_phase: True, model.kl_scale: kl}
        dic_test = {model.expression: expression_test, model.training_phase: False, model.kl_scale: kl}
        if scVI_batch:
            b_train = batch_train[index_train]
            dic_train[model.batch_ind] = b_train
            dic_train[model.mmd_scale] = 10
            dic_test[model.batch_ind] = batch_test
            dic_test[model.mmd_scale] = 10
        # run an optimization step
        _, l_tr = sess.run([step, model.loss], feed_dict=dic_train)
        if end_epoch:
            now = time.time()
            l_t = sess.run(model.loss, feed_dict=dic_test)
            if wait >= 12:
                # patience exhausted: export rho once, then keep training
                if not flag_has_exported_rho:
                    print("scVI ran for " + str(epoch) + " epochs")
                    print("SAVING RHO")
                    rho_early, _, _ = eval_scale_params(model, expression_test, sess)
                    flag_has_exported_rho = True
            else:
                if l_t < best_performance:
                    # there is an improvement
                    best_performance = l_t
                if l_t > best_performance:
                    wait += 1
                else:
                    wait = 0
            training_history["t_loss"].append(l_tr)
            training_history["v_loss"].append(l_t)
            training_history["time"].append(format_time(int(now - begin)))
            training_history["epoch"].append(epoch)
            print("epoch:", epoch, " l_t:", l_t, " wait:", wait)
        if np.isnan(l_tr):
            break
    print("SAVING RHO")
    rho_final, _, _ = eval_scale_params(model, expression_test, sess)
    return training_history, rho_early, rho_final
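
# A sketch of how the returned values might be consumed (illustrative only):
#   history, rho_early, rho_final = train_model_patience(
#       model, (expression_train, expression_test), sess, num_epochs=250)
#   # rho_early: scale parameters at the early-stopping point (None if never triggered),
#   # rho_final: scale parameters after the full num_epochs of training.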


def eval_params(model, data, sess, batch=None):
    dic_full = {model.expression: data, model.training_phase: False, model.kl_scale: 1}
    if batch is not None:
        dic_full[model.batch_ind] = batch
        dic_full[model.mmd_scale] = 0
    rate, dropout = sess.run((model.px_rate, model.px_dropout), feed_dict=dic_full)
    dispersion = np.tile(sess.run(tf.exp(model.px_r)), (rate.shape[0], 1))
    return rate, dispersion, dropout


def eval_scale_params(model, data, sess, batch=None):
    dic_full = {model.expression: data, model.training_phase: False, model.kl_scale: 1}
    if batch is not None:
        dic_full[model.batch_ind] = batch
        dic_full[model.mmd_scale] = 0
    scale, dropout = sess.run((model.px_scale, model.px_dropout), feed_dict=dic_full)
    dispersion = np.tile(sess.run(tf.exp(model.px_r)), (scale.shape[0], 1))
    return scale, dispersion, dropout
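
# Both eval_params and eval_scale_params return (cells x genes) arrays: the per-gene
# dispersion exp(px_r) is tiled across cells to match the rate/scale matrix. A minimal
# usage sketch (variable names below are illustrative, not defined in this module):
#   scale, dispersion, dropout = eval_scale_params(model, expression_test, sess)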


def eval_imputed_data(model, corrupted_info, expression_train, sess, batch=None):
    (X_zero, i0, j0, ix0) = corrupted_info
    dic_zero = {model.expression: X_zero, model.training_phase: False, model.kl_scale: 1.}
    if batch is not None:
        dic_zero[model.batch_ind] = batch
        dic_zero[model.mmd_scale] = 0
    rate_ = sess.run(model.px_rate, feed_dict=dic_zero)
    return imputation_error(rate_, expression_train, X_zero, i0, j0, ix0)


def eval_likelihood(model, data, sess, batch=None):
    dic_full = {model.expression: data, model.training_phase: False, model.kl_scale: 1}
    if batch is not None:
        dic_full[model.batch_ind] = batch
        dic_full[model.mmd_scale] = 0
    return sess.run(model.loss, feed_dict=dic_full)


def eval_latent(model, data, sess, batch=None):
    dic_full = {model.expression: data, model.training_phase: False, model.kl_scale: 1}
    if batch is not None:
        dic_full[model.batch_ind] = batch
        dic_full[model.mmd_scale] = 0
    return sess.run(model.z, feed_dict=dic_full)


def plot_training_info(result):
    plt.plot(result["epoch"], result["t_loss"], label="training loss")
    plt.plot(result["epoch"], result["v_loss"], label="validation loss")
    plt.xlabel("number of epochs")
    plt.ylabel("objective function")
    plt.legend()
    plt.tight_layout()


def show_tSNE(latent, labels, cmap=plt.get_cmap("tab10", 7), return_tSNE=False):
    if latent.shape[1] != 2:
        latent = TSNE().fit_transform(latent)
    plt.figure(figsize=(10, 10))
    plt.scatter(latent[:, 0], latent[:, 1], c=labels, cmap=cmap, edgecolors='none')
    plt.axis("off")
    plt.tight_layout()
    if return_tSNE:
        return latent
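

# Example end-to-end usage (a minimal sketch; `labels_test` is assumed to be an
# integer-coded label vector aligned with `expression_test`, and is not defined here):
#   history = train_model(model, (expression_train, expression_test), sess, num_epochs=250)
#   plot_training_info(history)
#   latent = eval_latent(model, expression_test, sess)
#   show_tSNE(latent, labels_test)
#   plt.show()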