-
Notifications
You must be signed in to change notification settings - Fork 1
/
classify.py
131 lines (100 loc) · 5.26 KB
/
classify.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
"""
This inference procedure is an extended and adapted version of the one used in the SquiggleNet project (see
https://github.com/welch-lab/SquiggleNet/blob/master/inference.py). In addition to the original inference procedure, it
pads reads in the same way as is done for the training and validation data (see prepare_training.py) and uses a custom
decision threshold. Furthermore, each batch consists of a fixed number of reads instead of a fixed number of files.
"""
import click
import glob
import numpy as np
import os
import pandas as pd
import torch
from model import Bottleneck, ResNet
from ont_fast5_api.fast5_interface import get_fast5_file
from scipy import stats
# class index of plasmids: selects the plasmid column of the softmax scores in process() and is the binary label value
# that gets mapped to 'plasmid' (every other value is mapped to 'chr')
PLASMID_LABEL = 0
def append_read(read, reads, read_ids):
    """Add the scaled raw signal and the ID of ``read`` to the running batch.

    :param read: read object providing ``get_raw_data(scale=True)`` and a ``read_id`` attribute
    :param reads: list collecting the raw signals of the current batch (extended in place)
    :param read_ids: list collecting the read IDs of the current batch (extended in place)
    :return: the same two lists, ``(reads, read_ids)``
    """
    signal = read.get_raw_data(scale=True)
    reads.append(signal)

    identifier = read.read_id
    read_ids.append(identifier)

    return reads, read_ids
def normalize(data, batch_idx, consistency_correction=1.4826):
    """Normalize each read of a batch and smooth out its extreme signals.

    Every read is centered on its median and divided by ``consistency_correction`` times its median absolute
    deviation. Signals whose absolute normalized value exceeds 3.5 are afterwards replaced by the mean of their two
    direct neighbors (or by the single neighbor at the start/end of a read). Replacements are applied one after
    another in ascending signal order, so a replaced value may serve as neighbor for the next replacement.

    :param data: list with one raw signal (array-like, e.g. a numpy array) per read; modified in place
    :param batch_idx: index of the current batch, only used for the log message
    :param consistency_correction: factor the median absolute deviation is multiplied with before dividing
    :return: ``data`` itself, with every entry replaced by a plain list of normalized signals
    """
    outlier_positions = []

    for read_pos, raw_signal in enumerate(data):
        # normalize using z-score with median absolute deviation
        # NOTE(review): scale='normal' already rescales the MAD by roughly 1.4826, so consistency_correction is
        # applied on top of that -- confirm this matches prepare_training.py before touching it
        center = np.median(raw_signal)
        spread = stats.median_abs_deviation(raw_signal, scale='normal')
        z_scores = (raw_signal - center) / (consistency_correction * spread)
        data[read_pos] = list(z_scores)

        # remember extreme signals (modified absolute z-score larger than 3.5)
        # see Iglewicz and Hoaglin (https://hwbdocuments.env.nm.gov/Los%20Alamos%20National%20Labs/TA%2054/11587.pdf)
        is_extreme = np.abs(z_scores) > 3.5
        outlier_positions.extend((read_pos, signal_pos) for signal_pos in np.flatnonzero(is_extreme).tolist())

    # replace extreme signals with average of closest neighbors
    for read_pos, signal_pos in outlier_positions:
        signal = data[read_pos]
        last_pos = len(signal) - 1

        if signal_pos == 0:
            # first signal: only a right neighbor exists
            signal[signal_pos] = signal[signal_pos + 1]
        elif signal_pos == last_pos:
            # last signal: only a left neighbor exists
            signal[signal_pos] = signal[signal_pos - 1]
        else:
            signal[signal_pos] = (signal[signal_pos - 1] + signal[signal_pos + 1]) / 2

    print(f'[Step 2] Done data normalization with batch {str(batch_idx)}')
    return data
def process(reads, read_ids, batch_idx, bmodel, outpath, device, threshold):
    """Classify one batch of reads and store the predictions as CSV file.

    The reads are converted to a float tensor, passed through ``bmodel`` on ``device`` without gradient tracking and
    the softmax over dimension 1 of the model output is compared with ``threshold``. The result is written to
    ``<outpath>/batch_<batch_idx>.csv`` with the columns 'Read ID' and 'Predicted Label'.

    :param reads: equally long signal sequences, one per read (anything ``torch.tensor`` accepts)
    :param read_ids: read IDs in the same order as ``reads``
    :param batch_idx: index of the current batch, used for the output file name and the log message
    :param bmodel: callable model returning one score per class for every read
    :param outpath: directory the CSV file is written to
    :param device: torch device the batch is moved to before the forward pass
    :param threshold: plasmid score a read has to exceed to be labeled 'plasmid'
    """
    # convert to torch tensors
    batch = torch.tensor(reads).float()

    with torch.no_grad():
        logits = bmodel(batch.to(device))
        scores = torch.nn.Softmax(dim=1)(logits)

        # if score of target class > threshold, classify as plasmid
        # (opposite comparison because plasmids are labeled with zero)
        plasmid_scores = scores[:, PLASMID_LABEL]
        binary_labels = (plasmid_scores <= threshold).int().cpu().numpy()

    labels = []
    for binary_label in binary_labels:
        if binary_label == PLASMID_LABEL:
            labels.append('plasmid')
        else:
            labels.append('chr')

    results = pd.DataFrame({'Read ID': read_ids, 'Predicted Label': labels})
    results.to_csv(f'{outpath}/batch_{str(batch_idx)}.csv', index=False)
    print(f'[Step 3] Done processing of batch {str(batch_idx)}')
def _classify_batch(reads, read_ids, batch_idx, bmodel, outpath, device, threshold, max_seq_len):
    """Normalize, zero-pad and classify one batch of reads, logging the steps 1 to 4."""
    print(f'[Step 1] Done loading data until batch {str(batch_idx)}')
    reads = normalize(reads, batch_idx)

    # pad with zeros until maximum sequence length
    # NOTE(review): reads longer than max_seq_len are not shortened here (a negative pad length adds nothing), so the
    # input is presumably expected to be cut beforehand -- confirm against the data preparation
    reads = [r + [0] * (max_seq_len - len(r)) for r in reads]

    process(reads, read_ids, batch_idx, bmodel, outpath, device, threshold)
    print(f'[Step 4] Done with batch {str(batch_idx)}\n')


@click.command()
@click.option('--model', '-m', help='input path to trained model', type=click.Path(exists=True), required=True)
@click.option('--inpath', '-i', help='input path to fast5 data', type=click.Path(exists=True), required=True)
@click.option('--outpath', '-o', help='output path for results', type=click.Path(), required=True)
@click.option('--max_seq_len', '-max', default=4000, help='maximum number of raw signals (after cutoff) used per read')
@click.option('--batch_size', '-b', default=1000, help='number of reads per batch')
@click.option('--threshold', '-t', default=0.5, help='threshold for final classification decision')
def main(model, inpath, outpath, max_seq_len, batch_size, threshold):
    """Classify the reads of all fast5 files found in --inpath and write one CSV file per batch to --outpath."""
    # set device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f'Device: {device}\n')

    # exist_ok avoids failing if the directory appears between a check and its creation
    os.makedirs(outpath, exist_ok=True)

    # load trained model
    # NOTE(review): torch.load unpickles the given file, so only trusted model files should be passed
    bmodel = ResNet(Bottleneck, layers=[2, 2, 2, 2]).to(device)
    bmodel.load_state_dict(torch.load(model, map_location=device))
    # inference only: switch off training-specific behavior of layers like batch normalization or dropout (if the
    # model contains any), otherwise predictions would depend on the composition of each batch
    bmodel.eval()
    print('[Step 0] Done loading model')

    reads, read_ids = list(), list()
    batch_idx = 0

    # sorted to get a reproducible assignment of reads to batches (glob itself guarantees no particular order)
    files = sorted(glob.glob(f'{inpath}/*.fast5'))

    for file in files:
        with get_fast5_file(file, mode='r') as f5:
            for read in f5.get_reads():
                reads, read_ids = append_read(read, reads, read_ids)

                # a batch is complete as soon as it holds batch_size reads, independent of file boundaries
                if len(reads) == batch_size:
                    _classify_batch(reads, read_ids, batch_idx, bmodel, outpath, device, threshold, max_seq_len)
                    reads, read_ids = list(), list()
                    batch_idx += 1

    # classify the remaining reads as final (smaller) batch; doing this after the loop ensures they are not lost
    # if the last file does not contain any read
    if reads:
        _classify_batch(reads, read_ids, batch_idx, bmodel, outpath, device, threshold, max_seq_len)
        batch_idx += 1

    print(f'Classification of {batch_idx} batches finished.\n')
# start the command line interface only if this file is executed as a script (not when it is imported)
if __name__ == '__main__':
    main()