forked from nh2tran/DeepNovo
-
Notifications
You must be signed in to change notification settings - Fork 1
/
deepnovo_worker_db.py
487 lines (400 loc) · 18.5 KB
/
deepnovo_worker_db.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
# Copyright 2017 Hieu Tran. All Rights Reserved.
#
# DeepNovo is publicly available for non-commercial uses.
# ==============================================================================
"""TODO(nh2tran): docstring."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import sys
import time
import re
from random import shuffle
from Bio import SeqIO
from pyteomics import parser
import numpy as np
import tensorflow as tf
import deepnovo_config
from deepnovo_cython_modules import get_candidate_intensity
class WorkerDB(object):
    """Database-search worker.

    This class contains the database search module: build_db() turns a
    protein FASTA file into an in-memory peptide db, and search_db() scores,
    for every spectrum, the db peptides whose mass matches the precursor mass
    using both the forward and the backward model.

    Naming conventions:
      "db" stands for "database".
      "pepmod" refers to a modified version of a "peptide", i.e. a list of
      residue symbols where some residues carry the 'mod' suffix.
    """

    def __init__(self):
        """Read the db-search settings from deepnovo_config and print them.

        The db itself is not built here: peptide_count, peptide_list,
        peptide_mass_array and pepmod_maxmass_array stay None until
        build_db() is called.
        """
        print("".join(["="] * 80))  # section-separating line
        print("WorkerDB: __init__()")

        # we currently use deepnovo_config to store both const & settings
        # the settings should be shown in __init__() to keep track carefully
        # input info to build a db
        self.db_fasta_file = deepnovo_config.db_fasta_file
        self.cleavage_rule = deepnovo_config.cleavage_rule
        self.num_missed_cleavage = deepnovo_config.num_missed_cleavage
        self.fixed_mod_list = deepnovo_config.fixed_mod_list
        self.var_mod_list = deepnovo_config.var_mod_list
        self.precursor_mass_tolerance = deepnovo_config.precursor_mass_tolerance
        self.precursor_mass_ppm = deepnovo_config.precursor_mass_ppm
        self.decoy = deepnovo_config.FLAGS.decoy
        print("db_fasta_file = {0:s}".format(self.db_fasta_file))
        print("cleavage_rule = {0:s}".format(self.cleavage_rule))
        print("num_missed_cleavage = {0:d}".format(self.num_missed_cleavage))
        print("fixed_mod_list = {0}".format(self.fixed_mod_list))
        print("var_mod_list = {0}".format(self.var_mod_list))
        print("precursor_mass_tolerance = {0:.4f}".format(self.precursor_mass_tolerance))
        print("precursor_mass_ppm = {0:.6f}".format(self.precursor_mass_ppm))

        # data structure to store a db
        # all attributes will be built/loaded by build_db()
        self.peptide_count = None
        self.peptide_list = None
        self.peptide_mass_array = None
        self.pepmod_maxmass_array = None

    def build_db(self):
        """Build the peptide db from the protein sequences in db_fasta_file.

        Steps:
          1. parse the FASTA file into protein sequences;
          2. cleave every protein with cleavage_rule (allowing up to
             num_missed_cleavage missed cleavages) and keep unique peptides;
          3. drop peptides containing 'X', 'B', 'U' or 'Z';
          4. convert each peptide into a list of residue symbols with 'L'
             replaced by 'I' and fixed modifications applied;
          5. compute each peptide's mass and its maximum mass with every
             variable-modification site modified.

        Sets self.peptide_count, self.peptide_list, self.peptide_mass_array
        and self.pepmod_maxmass_array.
        """
        print("".join(["="] * 80))  # section-separating line
        print("WorkerDB: build_db()")

        # parse the input fasta file into a list of sequences
        # more about SeqIO and SeqRecord: http://biopython.org/wiki/SeqRecord
        with open(self.db_fasta_file, "r") as handle:
            sequence_list = [str(record.seq)
                             for record in SeqIO.parse(handle, "fasta")]
        print("Number of protein sequences: {0:d}".format(len(sequence_list)))

        # cleave protein sequences into a set of unique peptides
        # more about pyteomics.parser.cleave and cleavage rules:
        # https://pythonhosted.org/pyteomics/api/parser.html
        cleavage_rule = parser.expasy_rules[self.cleavage_rule]
        peptide_set = set()
        for sequence in sequence_list:
            peptide_set.update(parser.cleave(
                sequence=sequence,
                rule=cleavage_rule,
                missed_cleavages=self.num_missed_cleavage))

        # skip peptides with undetermined/non-standard amino acids
        # 'X', 'B', 'U' or 'Z'.
        # The set is sorted so that the db order (and hence tie-breaking
        # between equally scored candidates) does not depend on str hash
        # randomization between runs.
        undetermined_aa = set('XBUZ')
        peptide_list = [self._prepare_peptide(peptide)
                        for peptide in sorted(peptide_set)
                        if undetermined_aa.isdisjoint(peptide)]
        peptide_count = len(peptide_list)
        print("Number of peptides: {0:d}".format(peptide_count))

        # for each peptide, find the mass and the max modification mass
        peptide_mass_array = np.zeros(peptide_count)
        pepmod_maxmass_array = np.zeros(peptide_count)
        for index, peptide in enumerate(peptide_list):
            peptide_mass_array[index] = self._compute_peptide_mass(peptide)
            # every variable-modification site modified; used as an upper
            # bound in _filter_by_mass (assumes each variable modification
            # adds mass, as the name "maxmass" implies)
            pepmod = [x + 'mod' if x in self.var_mod_list else x
                      for x in peptide]
            pepmod_maxmass_array[index] = self._compute_peptide_mass(pepmod)

        self.peptide_count = peptide_count
        self.peptide_list = peptide_list
        self.peptide_mass_array = peptide_mass_array
        self.pepmod_maxmass_array = pepmod_maxmass_array

    def _prepare_peptide(self, peptide):
        """Convert a peptide string into the db list of residue symbols.

        'L' is replaced by 'I' first, then every residue found in
        fixed_mod_list gets the 'mod' suffix, e.g. 'LCK' -> ['I', 'Cmod', 'K']
        when fixed_mod_list contains 'C'.
        """
        residue_list = []
        for aa in peptide:
            if aa == 'L':
                aa = 'I'
            if aa in self.fixed_mod_list:
                aa += 'mod'
            residue_list.append(aa)
        return residue_list

    def search_db(self, model, worker_io, predicted_denovo_list=None):
        """Search all spectra provided by worker_io against the db.

        build_db() must have been called first.

        Args:
          model: the model; restore_model(session) is called on it, and its
            input_dict / output_forward / output_backward tensors are used
            for scoring.
          worker_io: provides the spectra in batches (location_batch_list,
            get_spectrum) and receives the predictions (write_prediction).
          predicted_denovo_list: optional list of de novo predictions, each a
            dict with keys "scan" and "sequence". A de novo sequence whose
            mass matches the precursor mass of its scan is added as an extra
            candidate for that scan.

        The tensorflow session and the worker_io input/output are closed
        even if the search raises.
        """
        print("".join(["="] * 80))  # section-separating line
        print("WorkerDB: search_db()")

        # move load/build db here?

        # if provided, convert predicted_denovo_list to dictionary for easy lookup
        denovo_peptide_dict = None
        if predicted_denovo_list is not None:
            print("WorkerDB: search_db() - read denovo peptides")
            denovo_peptide_dict = {predicted["scan"]: predicted["sequence"]
                                   for predicted in predicted_denovo_list}

        print("WorkerDB: search_db() - open tensorflow session")
        session = tf.Session()
        try:
            model.restore_model(session)
            worker_io.open_input()
            try:
                worker_io.get_location()
                worker_io.split_location()
                worker_io.open_output()
                try:
                    self._run_search_loop(model,
                                          worker_io,
                                          session,
                                          denovo_peptide_dict)
                finally:
                    worker_io.close_output()
            finally:
                worker_io.close_input()
        finally:
            session.close()

    def _run_search_loop(self, model, worker_io, session, denovo_peptide_dict):
        """Search every location batch of worker_io and write the predictions.

        Expects worker_io input/output to be open and session to hold the
        restored model; prints the spectrum counters of worker_io at the end.
        """
        print("".join(["="] * 80))  # section-separating line
        print("WorkerDB: search_db() - search loop")
        for index, location_batch in enumerate(worker_io.location_batch_list):
            print("Read {0:d}/{1:d} batches".format(
                index + 1, worker_io.location_batch_count))
            spectrum_batch = worker_io.get_spectrum(location_batch)
            predicted_batch = self._search_db_batch(spectrum_batch,
                                                    model,
                                                    session,
                                                    denovo_peptide_dict)
            worker_io.write_prediction(predicted_batch)

        print("Total spectra: {0:d}".format(worker_io.spectrum_count["total"]))
        print(" read: {0:d}".format(worker_io.spectrum_count["read"]))
        print(" skipped: {0:d}".format(worker_io.spectrum_count["skipped"]))
        print(" by mass: {0:d}".format(worker_io.spectrum_count["skipped_mass"]))

    def _compute_peptide_mass(self, peptide):
        """Return the mass of a peptide given as a list of residue symbols.

        The mass is mass_N_terminus + the sum of deepnovo_config.mass_AA over
        the residues + mass_C_terminus.
        """
        peptide_mass = (deepnovo_config.mass_N_terminus
                        + sum(deepnovo_config.mass_AA[aa] for aa in peptide)
                        + deepnovo_config.mass_C_terminus)
        return peptide_mass

    def _expand_peptide_modification(self, peptide):
        """Return every combination of variable modifications of a peptide.

        Each residue in var_mod_list may or may not carry the 'mod' suffix,
        so a peptide with k such residues expands into 2**k pepmods. The
        first entry is the input list object itself (not a copy); the others
        are new lists. May also use parser.isoforms.
        """
        pepmod_list = [peptide]  # the first entry without any modifications
        mod_count = 0
        for position, aa in enumerate(peptide):
            if aa in self.var_mod_list:
                mod_count += 1
                # add modification of this position to all peptides in the
                # current list, doubling the list
                new_mod_list = []
                for pepmod in pepmod_list:
                    new_mod = pepmod[:]
                    new_mod[position] = aa + 'mod'
                    new_mod_list.append(new_mod)
                pepmod_list = pepmod_list + new_mod_list
        # sanity check of the iterative doubling
        assert len(pepmod_list) == pow(2, mod_count), (
            "Wrong peptide expansion!")
        return pepmod_list

    def _precursor_mass_tolerance(self, precursor_mass):
        """Return the absolute tolerance used around precursor_mass.

        The tolerance is relative: precursor_mass_ppm * precursor_mass. The
        absolute setting self.precursor_mass_tolerance is not used here.
        """
        return self.precursor_mass_ppm * precursor_mass

    def _filter_by_mass(self, precursor_mass):
        """Return the db pepmods whose mass matches precursor_mass.

        1st filter: keep peptides whose unmodified mass is not above, and
        whose max pepmod mass is not below, the tolerance window.
        2nd filter: expand those peptides into all variable-modification
        combinations and keep the pepmods whose exact mass is inside the
        window.

        Returns:
          A list of pepmods (lists of residue symbols); may be empty.
        """
        tolerance = self._precursor_mass_tolerance(precursor_mass)
        mass_low = precursor_mass - tolerance
        mass_high = precursor_mass + tolerance

        # 1st filter by the peptide mass and the max pepmod mass
        filter1_index = np.flatnonzero(np.logical_and(
            np.less_equal(self.peptide_mass_array, mass_high),
            np.greater_equal(self.pepmod_maxmass_array, mass_low)))

        # find all possible modifications
        pepmod_list = []
        for index in filter1_index:
            peptide = self.peptide_list[index]
            pepmod_list += self._expand_peptide_modification(peptide)
        pepmod_mass_array = np.array([self._compute_peptide_mass(pepmod)
                                      for pepmod in pepmod_list])

        # 2nd filter by exact pepmod mass
        filter2_index = np.flatnonzero(np.logical_and(
            np.less_equal(pepmod_mass_array, mass_high),
            np.greater_equal(pepmod_mass_array, mass_low)))
        candidate_list = [pepmod_list[x] for x in filter2_index]
        return candidate_list

    @staticmethod
    def _shuffle_candidates(candidate_list):
        """Return randomly shuffled copies of the candidates (decoy db).

        The input lists are never modified: the candidates may be the db's
        own peptide lists or the caller's de novo sequences.
        """
        decoy_list = []
        for candidate in candidate_list:
            decoy = list(candidate)
            shuffle(decoy)  # works in place on the copy and returns None
            decoy_list.append(decoy)
        return decoy_list

    def _score_spectrum(self,
                        precursor_mass,
                        spectrum_original,
                        state0_c,
                        state0_h,
                        candidate_list,
                        model,
                        model_output_logprob,
                        model_lstm_state,
                        session,
                        direction):
        """Run one direction of the model over all candidates of a spectrum.

        All candidates form a single minibatch, so they must already have
        the same length (padded by _search_db_batch).

        Args:
          precursor_mass: precursor mass of the spectrum.
          spectrum_original: spectrum passed to get_candidate_intensity.
          state0_c, state0_h: initial LSTM state of this spectrum; reshaped
            to one row and repeated for every candidate.
          candidate_list: list of equal-length lists of symbols in
            deepnovo_config.vocab.
          model: model whose input_dict placeholders are fed.
          model_output_logprob, model_lstm_state: output tensors of the
            direction being scored.
          session: tensorflow session holding the restored model.
          direction: forwarded to get_candidate_intensity (_search_db_batch
            passes 0 for the forward and 1 for the backward model).

        Returns:
          A list with one entry per position; entry i is the logprob output
          at step i for all candidates (per the original notes, an array of
          shape [minibatch_size, 26]).
        """
        # convert symbols into id
        candidate_list = [[deepnovo_config.vocab[x] for x in candidate]
                          for candidate in candidate_list]

        # we shall group candidates into minibatches
        # === candidate_len ===
        # s
        # i
        # z
        # e
        # =====================
        minibatch_size = len(candidate_list)  # number of candidates
        candidate_len = len(candidate_list[0])  # length of each candidate

        # candidates share the same state0, so repeat into [minibatch_size, 512]
        # the states will also be updated after every iteration
        state0_c = state0_c.reshape((1, -1))  # reshape to [1, 512]
        state0_h = state0_h.reshape((1, -1))
        minibatch_state_c = np.repeat(state0_c, minibatch_size, axis=0)
        minibatch_state_h = np.repeat(state0_h, minibatch_size, axis=0)

        # mass of each candidate, will be accumulated everytime an AA is appended
        minibatch_prefix_mass = np.zeros(minibatch_size)

        # output is a list of candidate_len arrays of shape [minibatch_size, 26]
        # each row is log of probability distribution over 26 classes/symbols
        output_logprob_list = []

        # recurrent iterations
        for position in range(candidate_len):

            # gather minibatch data
            minibatch_AA_id = np.zeros(minibatch_size)
            for index, candidate in enumerate(candidate_list):
                AA = candidate[position]
                minibatch_AA_id[index] = AA
                minibatch_prefix_mass[index] += deepnovo_config.mass_ID[AA]
            # this is the most time-consuming ~70-75%
            minibatch_intensity = [get_candidate_intensity(spectrum_original,
                                                           precursor_mass,
                                                           prefix_mass,
                                                           direction)
                                   for prefix_mass in np.nditer(minibatch_prefix_mass)]
            # final shape [minibatch_size, 26, 8, 10]
            minibatch_intensity = np.array(minibatch_intensity)

            # model feed
            input_feed = {}
            input_feed[model.input_dict["AAid"][1].name] = minibatch_AA_id
            input_feed[model.input_dict["intensity"].name] = minibatch_intensity
            input_feed[model.input_dict["lstm_state"][0].name] = minibatch_state_c
            input_feed[model.input_dict["lstm_state"][1].name] = minibatch_state_h
            # and run
            output_feed = [model_output_logprob, model_lstm_state]
            output_logprob, (minibatch_state_c, minibatch_state_h) = session.run(
                fetches=output_feed,
                feed_dict=input_feed)

            output_logprob_list.append(output_logprob)

        return output_logprob_list

    def _search_db_batch(self,
                         spectrum_batch,
                         model,
                         session,
                         denovo_peptide_dict):
        """Find the best-scoring db candidate for every spectrum of a batch.

        Inputs:
          spectrum_batch: a list of spectrum, each is a dictionary
            spectrum["scan"]
            spectrum["precursor_mass"]
            spectrum["spectrum_holder"]
            spectrum["spectrum_original_forward"]
            spectrum["spectrum_original_backward"]
          denovo_peptide_dict: None, or a dict scan -> de novo sequence whose
            entry is added as a candidate when its mass matches.

        Outputs:
          predicted_batch: a list of predicted, each is a dictionary
            predicted["scan"]
            predicted["sequence"]
            predicted["score"]
            predicted["position_score"]
          A spectrum without candidates keeps sequence [], score -inf and
          position_score [].
        """
        # initialize the lstm using the spectrum
        # for faster speed, we initialize the whole spectrum_batch instead of 1-by-1
        input_feed = {}
        spectrum_holder = np.array([spectrum["spectrum_holder"]
                                    for spectrum in spectrum_batch])
        input_feed[model.input_dict["spectrum"].name] = spectrum_holder
        output_feed = [model.output_forward["lstm_state0"],
                       model.output_backward["lstm_state0"]]
        ((state0_c_forward, state0_h_forward),
         (state0_c_backward, state0_h_backward)) = session.run(
             fetches=output_feed,
             feed_dict=input_feed)

        predicted_batch = []
        # we search spectrum by spectrum
        # a faster way is to process them in parallel, but hard to debug
        for spectrum_index, spectrum in enumerate(spectrum_batch):

            predicted = {"scan": spectrum["scan"],
                         "sequence": [],
                         "score": -float("inf"),
                         "position_score": []}

            # filter by precursor mass
            # example: [['M', 'D', 'K', 'F', 'Nmod', 'K', 'K']]
            precursor_mass = spectrum["precursor_mass"]
            candidate_list = self._filter_by_mass(precursor_mass)

            # add denovo peptide if provided
            scan = spectrum["scan"]
            if denovo_peptide_dict is not None and scan in denovo_peptide_dict:
                sequence = denovo_peptide_dict[scan]
                # TODO(nh2tran): change the precursor_mass_tolerance of denovo
                sequence_mass = self._compute_peptide_mass(sequence)
                tolerance = self._precursor_mass_tolerance(precursor_mass)
                if abs(precursor_mass - sequence_mass) <= tolerance:
                    candidate_list.append(sequence)

            # if no candidate found, return empty sequence for this spectrum.
            if not candidate_list:
                predicted_batch.append(predicted)
                continue

            # if decoy is activated, randomly shuffle amino acids to form decoy db.
            # Shuffle copies: some candidates are the db's own peptide lists or
            # the caller's de novo sequences, which must not be permuted.
            if self.decoy:
                candidate_list = self._shuffle_candidates(candidate_list)

            # add special GO/EOS and reverse
            # example: [['_GO', 'M', 'D', 'K', 'F', 'Nmod', 'K', 'K', '_EOS']]
            candidate_forward_list = [[deepnovo_config._GO] + x + [deepnovo_config._EOS]
                                      for x in candidate_list]
            candidate_backward_list = [x[::-1] for x in candidate_forward_list]

            # add PAD to all candidates to the same max length
            # [['_GO', 'M', 'D', 'K', 'F', 'Nmod', 'K', 'K', '_EOS', '_PAD', '_PAD']]
            # due to the same precursor mass, candidates have very similar lengths
            candidate_len_list = [len(x) for x in candidate_list]
            candidate_maxlen = max(candidate_len_list)
            for index, length in enumerate(candidate_len_list):
                if length < candidate_maxlen:
                    pad_size = candidate_maxlen - length
                    candidate_forward_list[index] += [deepnovo_config._PAD] * pad_size
                    candidate_backward_list[index] += [deepnovo_config._PAD] * pad_size

            # score the spectrum against its candidates
            # using the forward model
            logprob_forward_list = self._score_spectrum(
                precursor_mass,
                spectrum["spectrum_original_forward"],
                state0_c_forward[spectrum_index],
                state0_h_forward[spectrum_index],
                candidate_forward_list,
                model,
                model.output_forward["logprob"],
                model.output_forward["lstm_state"],
                session,
                direction=0)
            # and using the backward model
            logprob_backward_list = self._score_spectrum(
                precursor_mass,
                spectrum["spectrum_original_backward"],
                state0_c_backward[spectrum_index],
                state0_h_backward[spectrum_index],
                candidate_backward_list,
                model,
                model.output_backward["logprob"],
                model.output_backward["lstm_state"],
                session,
                direction=1)

            # logprob_forward_list is a list of padded-length entries, each
            # [minibatch_size, 26]: one row per candidate, log of probability
            # distribution over 26 classes/symbols.
            # find the best scoring candidate
            for index, candidate in enumerate(candidate_list):

                # only calculate score on the actual length, not on GO/EOS/PAD
                candidate_len = candidate_len_list[index]

                # align forward and backward logprob: output step p of the
                # forward pass predicts candidate[p]; the backward pass runs on
                # the reversed candidate, so its outputs are reversed here
                logprob_forward = [logprob_forward_list[position][index]
                                   for position in range(candidate_len)]
                logprob_backward = [logprob_backward_list[position][index]
                                    for position in range(candidate_len)]
                logprob_backward = logprob_backward[::-1]

                # score is the sum of logprob(AA) of the candidate in both directions
                # averaged by the candidate length
                position_score = []
                for position in range(candidate_len):
                    AA = candidate[position]
                    AA_id = deepnovo_config.vocab[AA]
                    position_score.append(logprob_forward[position][AA_id]
                                          + logprob_backward[position][AA_id])
                score = sum(position_score) / candidate_len
                if score > predicted["score"]:
                    predicted["sequence"] = candidate
                    predicted["score"] = score
                    predicted["position_score"] = position_score

            predicted_batch.append(predicted)

        return predicted_batch