-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathxtract_qc
executable file
·397 lines (336 loc) · 16.9 KB
/
xtract_qc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
#!/usr/bin/env fslpython
import sys,os,glob,subprocess,shutil
import argparse, textwrap
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib import rcParams
rcParams['axes.spines.top'] = False
rcParams['axes.spines.right'] = False
import jinja2
import seaborn as sns
import plotly
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from datetime import datetime
from tqdm import tqdm
# Locate the FSL installation. The script shells out to FSL binaries and reads
# XTRACT reference data from it, so stop with a readable message when FSLDIR is
# unset instead of the TypeError that os.path.join(None, ...) would raise.
FSLDIR = os.getenv('FSLDIR')
if not FSLDIR:
    sys.exit('error: the FSLDIR environment variable is not set; cannot locate FSL')
FSLbin = os.path.join(FSLDIR, 'bin')  # Location of the FSL command-line tools
datadir = os.path.join(FSLDIR, 'data', 'xtract_data')  # Location of xtract data
# some useful functions
def errchk(errflag):
    """Stop the script if an earlier validation step reported a problem.

    Parameters
    ----------
    errflag : int or bool
        Truthy when at least one check failed.

    Raises
    ------
    SystemExit
        With status 1 when ``errflag`` is truthy. Returns ``None`` otherwise.
    """
    if errflag:
        print("Exit without doing anything..")
        # sys.exit instead of quit(): quit() is installed by the `site` module
        # for interactive sessions and is absent under `python -S`. A non-zero
        # status also lets calling shells detect the failure.
        sys.exit(1)
def imgtest(fname):
    """Run FSL's ``imtest`` on *fname* and return the integer it prints.

    Parameters
    ----------
    fname : str
        Path of the image to test.

    Returns
    -------
    int
        The value printed by ``imtest`` (expected to be 1 for a usable image
        and 0 otherwise -- NOTE(review): taken from FSL's documented behaviour,
        not verifiable from this file). Empty output is reported as 0.
    """
    # Pass an argument list without a shell so that paths containing spaces or
    # shell metacharacters reach imtest unchanged.
    result = subprocess.run(
        [os.path.join(FSLbin, 'imtest'), str(fname)],
        capture_output=True,
        text=True,
    )
    answer = result.stdout.strip()
    return int(answer) if answer else 0
class MyParser(argparse.ArgumentParser):
    """Argument parser that prints the full help text when parsing fails."""

    def error(self, message):
        """Write *message* to stderr, show the help text and exit with status 2."""
        print(f'error: {message}', file=sys.stderr)
        self.print_help()
        raise SystemExit(2)
class Usage(Exception):
    """Exception carrying a usage message in the ``msg`` attribute.

    NOTE(review): nothing in the visible code raises or catches this class.
    """

    def __init__(self, msg):
        # Only the attribute is stored; Exception.__init__ is not called here.
        self.msg = msg
# Start-up banner. It is a raw string because the art is full of backslashes
# that are not valid escape sequences ("\ ", "\/", "\_", "\|"); in a normal
# string literal those trigger a SyntaxWarning on recent Python versions. The
# one deliberate "\\" of the original is therefore written as a single "\".
splash = r"""
__ _______ ____ _ ____ _____ ___ ____
\ \/ /_ _| _ \ / \ / ___|_ _/ _ \ / ___|
\ / | | | |_) | / _ \| | | || | | | |
/ \ | | | _ < / ___ \ |___ | || |_| | |___
/_/\_\ |_| |_| \_\/_/ \_\____| |_| \__\_\____|
"""
print(splash)
# Command-line interface. MyParser prints the full help text on any parse error.
parser = MyParser(
    prog='XTRACT QC',
    description='xtract_qc: quality control at the group-level',
    formatter_class=argparse.RawDescriptionHelpFormatter,
    epilog=textwrap.dedent('''\
        Example usage:
        xtract_qc -subject_list sublist.txt -out /data/xtract_qc
        '''),
)
required = parser.add_argument_group('Required arguments')
optional = parser.add_argument_group('Optional arguments')

# Compulsory arguments:
required.add_argument("-subject_list", metavar='<txt>', help="Text file containing line separated subject-wise paths to XTRACT folders", required=True)
required.add_argument("-out", metavar='<folder>', help="Path to output folder", required=True)

# Optional arguments:
# The help text names the real default: the tract list is read from the
# structureList file under the XTRACT data directory; there is no -xtract option.
optional.add_argument("-tract_list", metavar='<list>', help="Comma separated list of tracts to include (default = all tracts listed in $FSLDIR/data/xtract_data/HUMAN/structureList)")
optional.add_argument("-thr", metavar='<float>', default=0.001, type=float, help="Threshold applied to XTRACT tracts for volume calculation (default = 0.001).")
optional.add_argument("-n_std", metavar='<float>', default=2, type=float, help="The number of standard deviations (either side of mean) to allow before being flagged as an outlier (default = 2).")
optional.add_argument("-use_prior", action="store_true", default=False, help="If already run xtract_qc, use previously created metrics (default = create new metrics and overwrite)")
argsa = parser.parse_args()
subject_list = argsa.subject_list
out = argsa.out
thr = argsa.thr
n_std = argsa.n_std

# Step 0 : Prepare
# Check compulsory arguments
errflag = 0
if not os.path.isfile(subject_list):
    print(f'Subject list file {subject_list} not found')
    errflag = 1

if os.path.isdir(out):
    print('Warning: Output folder already exists. Some of the files may be overwritten')

# check optional arguments
# Both values must be strictly positive. "not x > 0" also rejects negative and
# NaN input, which an "== 0" comparison would let through.
if not thr > 0:
    print(f'Tract threshold (-thr) has been set to {thr}. Must be greater than 0. Default is 0.001')
    errflag = 1

if not n_std > 0:
    print(f'Number of standard deviations (-n_std) has been set to {n_std}. Must be greater than 0. Default is 2')
    errflag = 1

# Stop here if any check failed, before anything is created on disk
errchk(errflag)
os.makedirs(out, exist_ok=True)
# get subject information
# ndmin=1 keeps a one-line subject list as a 1-element array; without it
# np.loadtxt returns a 0-d array and len() below fails.
subject_list = np.loadtxt(subject_list, dtype='str', ndmin=1)

# get tract information
if argsa.tract_list is not None:
    tracts = argsa.tract_list.split(',')
else:
    print(f'Getting tract list from default: {os.path.join(datadir, "HUMAN", "structureList")}')
    tracts = []
    with open(os.path.join(datadir, 'HUMAN', 'structureList')) as f:
        for line in f:
            struct = line.strip()
            # skip blank lines and comments; the tract name is the first field
            if struct and not struct.startswith("#"):
                tracts.append(struct.split()[0])
    # NOTE(review): the original indentation was lost, so it is unclear whether
    # this sort applied only to the default list (as here) or also to a
    # user-supplied -tract_list. Confirm against the upstream file.
    tracts.sort()

# dimensions of cohort
nsubs = len(subject_list)
ntracts = len(tracts)
print(f'{nsubs} XTRACT directories in list')
print(f'{ntracts} tracts')

# The percentages computed later divide by nsubs * ntracts
if nsubs == 0 or ntracts == 0:
    print('Nothing to do: the subject list or the tract list is empty')
    sys.exit(1)
# function to get metrics
def get_metrics(sub_path, sub_out, tracts, thr):
    """Collect per-tract QC metrics for one subject and cache them in *sub_out*.

    Parameters
    ----------
    sub_path : str
        The subject's XTRACT folder; tracts are looked up under
        ``<sub_path>/tracts/<tract>/``.
    sub_out : str
        Existing folder that receives the ``filecheck``, ``waytotal`` and
        ``volume`` text files (one value per tract, in the order of *tracts*).
    tracts : list of str
        Tract names.
    thr : float
        Lower threshold passed to ``fslstats -l`` before measuring volume.

    Returns
    -------
    tuple of numpy.ndarray
        ``(filecheck, tract_waytotal, tract_volume)``, each of length
        ``len(tracts)``. ``filecheck`` is 1 where ``densityNorm.nii.gz`` passed
        ``imgtest`` and 0 otherwise. Waytotal and volume are NaN where the
        corresponding file is missing.
    """
    ntracts = len(tracts)
    filecheck = np.ones((ntracts,))
    # NaN marks "not measured"; it is skipped by pandas statistics downstream
    tract_waytotal = np.full((ntracts,), np.nan)
    tract_volume = np.full((ntracts,), np.nan)

    for nt, tract in enumerate(tracts):
        tfile = os.path.join(sub_path, 'tracts', tract, 'densityNorm.nii.gz')
        wayfile = os.path.join(sub_path, 'tracts', tract, 'waytotal')

        # filecheck
        # imgtest returns an int, so test truthiness: "is False" never matches
        # an int and would leave every missing tract marked as present.
        if not imgtest(tfile):
            filecheck[nt] = 0
        else:
            # tract volume: second space-separated field of `fslstats -l thr -V`
            tvol = subprocess.run([os.path.join(FSLbin, 'fslstats'), tfile, '-l', str(thr), '-V'], capture_output=True, text=True)
            tract_volume[nt] = float(tvol.stdout.split(' ')[1])

        # get waytotal
        if os.path.isfile(wayfile):
            tract_waytotal[nt] = np.loadtxt(wayfile, dtype=int)

    # save check to subject directory
    # '%.0f' on truncated values writes the same integers as '%i' but can also
    # write 'nan' (np.loadtxt reads that back as NaN); '%i' raises on NaN.
    np.savetxt(os.path.join(sub_out, 'filecheck'), filecheck, fmt='%i')
    np.savetxt(os.path.join(sub_out, 'waytotal'), np.trunc(tract_waytotal), fmt='%.0f')
    np.savetxt(os.path.join(sub_out, 'volume'), np.trunc(tract_volume), fmt='%.0f')
    return filecheck, tract_waytotal, tract_volume
# filecheck, waytotal and tract volume arrays
# subcheck[ns] is 1 while the subject's XTRACT folder is believed to exist and
# is set to 0 when it is missing. The frames hold one row per subject and one
# column per tract; float dtype keeps NaN as the "no value" marker.
subcheck = np.ones((nsubs,))
filecheck_all = pd.DataFrame(index=np.arange(0, nsubs), columns=tracts, dtype=float)
tract_waytotal_all = pd.DataFrame(index=np.arange(0, nsubs), columns=tracts, dtype=float)
tract_volume_all = pd.DataFrame(index=np.arange(0, nsubs), columns=tracts, dtype=float)

# loop through subjects and tracts
# if using prior run, check if output files already exist
# if it does, do nothing and move on
# if it doesn't exist, get metrics
for ns, sub_path in enumerate(tqdm(subject_list)):
    sub_out = os.path.join(sub_path, 'xtract_qc')
    if not os.path.isdir(sub_path):
        # Flag the subject as missing (0, not 1, otherwise the summary below
        # never triggers) and record that none of its tracts were found.
        subcheck[ns] = 0
        filecheck_all.loc[ns, :] = 0
        print(f'{sub_path} missing')
        continue

    prior_files = [os.path.join(sub_out, name) for name in ('filecheck', 'waytotal', 'volume')]
    if argsa.use_prior and all(os.path.isfile(p) for p in prior_files):
        filecheck_all.loc[ns, :] = np.loadtxt(prior_files[0])
        tract_waytotal_all.loc[ns, :] = np.loadtxt(prior_files[1])
        tract_volume_all.loc[ns, :] = np.loadtxt(prior_files[2])
    else:
        # start from a clean per-subject QC folder
        shutil.rmtree(sub_out, ignore_errors=True)
        os.makedirs(sub_out, exist_ok=True)
        filecheck_all.loc[ns, :], tract_waytotal_all.loc[ns, :], tract_volume_all.loc[ns, :] = get_metrics(sub_path, sub_out, tracts, thr)

n_missing_subs = int(nsubs - np.sum(subcheck))
if n_missing_subs > 0:
    print(f'Could not find XTRACT folders for {n_missing_subs} subject(s)')
    missing_list = os.path.join(out, 'missing_subjects.txt')
    np.savetxt(missing_list, subject_list[subcheck == 0], fmt="%s")
    print(f'Excluding from QC\nSee {missing_list} for list')

print('Generating report...')
##########################################
# now report subject's missing results and give percent success for each tract and flag any subjects with all missing
# working with filecheck metric
# Rows without any recorded value (NaN) count as "no tract found", so the sums
# below stay finite.
filecheck_vals = np.nan_to_num(filecheck_all.values.astype(float))
sumcheck = filecheck_vals.sum(axis=1)  # number of tracts found per subject

sub_missing_txt = []
sub_missing_all_txt = []
# np.flatnonzero yields plain integer indices, so subject_list[idx] is a single
# path rather than a 1-element array.
for idx in np.flatnonzero(sumcheck < ntracts):  # subjects missing data
    sub = subject_list[idx]
    n_missing = int(ntracts - sumcheck[idx])
    if n_missing < ntracts:
        missing_tracts = [tracts[i] for i in np.flatnonzero(filecheck_vals[idx] == 0)]
        sub_missing_txt.append(f'{sub} missing {n_missing} tracts: {missing_tracts}')
    else:
        sub_missing_all_txt.append(f'{sub} missing all tracts')

# get overall success
sumcheck_all = np.round((np.sum(sumcheck) / (ntracts * nsubs)) * 100, 3)  # now overall percent success across subjects and tracts
sumcheck_all_txt = f'{sumcheck_all}% ({int(np.sum(sumcheck))} of {int(ntracts * nsubs)})'
tract_success = (filecheck_vals.sum(axis=0) / nsubs) * 100  # percent of subjects successfully completed per tract

# report simple filecheck metric
# ps_inc is passed to the report template; it is 1 only when something failed.
# NOTE(review): the original indentation was lost; the figure is assumed to be
# produced only in that case -- confirm against the upstream file.
ps_inc = 0
if sumcheck_all < 100:
    ps_inc = 1
    tract_success = pd.DataFrame({'tracts': tracts, 'percent_successful': tract_success})
    fig = px.line(tract_success, x="tracts", y="percent_successful",
                  labels=dict(tracts="Tract", percent_successful="Percent successful (%)", )
                  )
    # dashed reference line at 100 %
    fig.add_shape(type="line", x0=0, y0=100, x1=ntracts, y1=100, line=dict(color="orange", dash="dash"))
    fig.update_xaxes(title_text="Tract")
    fig.write_html(os.path.join(out, 'percent_success_fig.html'))
##########################################
# make metric distribution plots
def myround(x, base=5):
    """Round *x* to the nearest multiple of *base*, never returning less than *x*.

    If the nearest multiple falls below *x*, the next multiple up is returned.
    """
    nearest = round(x / base) * base
    return nearest + base if nearest < x else nearest
def plot_tractwise_metric(metric, metric_name, n_std):
    """Save a grid of per-tract density plots as ``<metric_name>_fig.jpg`` in ``out``.

    NOTE: a second ``plot_tractwise_metric`` is defined further down this file
    and replaces this one when the script runs, so this is not the version the
    script calls.

    Parameters
    ----------
    metric : pandas.DataFrame
        One column per tract (column ``n`` is plotted under ``tracts[n]``).
    metric_name : str
        Used for the output file name.
    n_std : float
        Dashed vertical markers are drawn at mean +/- ``n_std`` standard
        deviations.
    """
    k = 4  # panels per row
    # Ceiling division gives enough rows for every tract; rounding up to a
    # multiple of 5 and dividing by 4 can come up short (5 tracts -> 1 row).
    m = max(1, -(-ntracts // k))
    # squeeze=False keeps `ax` two-dimensional even for a single row
    fig, ax = plt.subplots(m, k, figsize=(16, m*1.75), squeeze=False)

    for n, panel in enumerate(ax.ravel()):  # row-major, same order as a j/i loop
        if n >= ntracts:
            panel.axis('off')
            continue

        # plain float array without NaN: .std() fails on object arrays and
        # NaN would propagate into every statistic
        x = np.asarray(metric.values[:, n], dtype=float)
        x = x[~np.isnan(x)]
        panel.set_title(tracts[n], size=18, y=1.05)

        # plot distribution
        sns.kdeplot(x, ax=panel, fill=False, color='crimson')
        if not panel.lines:
            # no density curve was drawn, so there is nothing to annotate
            continue
        kdeline = panel.lines[0]
        xs = kdeline.get_xdata()
        ys = kdeline.get_ydata()
        middle = x.mean()
        median = np.median(x)
        sdev = x.std()
        left = middle - sdev
        right = middle + sdev
        panel.vlines(middle, 0, np.interp(middle, xs, ys), color='crimson', ls=':')
        panel.vlines(median, 0, np.interp(median, xs, ys), color='crimson', ls='--')
        # whole curve shaded once, the +/- 1 sd band shaded a second time
        panel.fill_between(xs, 0, ys, facecolor='crimson', alpha=0.5)
        panel.fill_between(xs, 0, ys, where=(left <= xs) & (xs <= right), interpolate=True, facecolor='crimson', alpha=0.5)

        # add markers for n_std standard deviations
        left = middle - n_std*sdev
        right = middle + n_std*sdev
        panel.axvline(left, color='k', ls='--')
        panel.axvline(right, color='k', ls='--')

        # set axis options
        panel.set_xlim(left=0)
        panel.ticklabel_format(style='sci', scilimits=(0, 0))

    fig.tight_layout()
    fig.savefig(os.path.join(out, f'{metric_name}_fig.jpg'))
    plt.close(fig)
# NOTE(review): these two module-level names are not read again in the visible
# code (the functions around them take their own `metric` / `metric_name`
# arguments). They look like leftovers from interactive development; confirm
# before removing.
metric = tract_waytotal_all
metric_name = 'waytotal'
def plot_tractwise_metric(metric, metric_name, n_std):
    """Save one density plot per tract as ``<metric_name>_fig.jpg`` in ``out``.

    *metric* is a frame with one column per tract. It is reshaped to long form
    and drawn on a seaborn FacetGrid, four panels per row with independent
    axes. Each panel shows the density curve, a dotted line at the mean, a
    dashed line at the median and a shaded band covering mean +/- ``n_std``
    standard deviations.
    """
    panels_per_row = 4

    def draw_panel(data, color, **kwargs):
        # Called by FacetGrid.map once per tract, with that tract's values
        centre = np.mean(data)
        spread = np.std(data)
        sns.kdeplot(data, fill=False, color=color)
        plt.axvline(x=centre, color=color, ls=':')
        plt.axvline(x=np.median(data), color=color, ls='--')
        plt.axvspan(centre - n_std * spread, centre + n_std * spread, color=color, alpha=0.5)
        for axis_name in ('x', 'y'):
            plt.ticklabel_format(style='sci', axis=axis_name, scilimits=(0,0))

    long_form = pd.melt(metric, var_name='tract', value_name=metric_name)
    grid = sns.FacetGrid(long_form, col="tract", col_wrap=panels_per_row, height=4, aspect=1.75, sharey=False, sharex=False)
    grid.map(draw_panel, metric_name, color='crimson')
    grid.set_titles("{col_name}")
    # leave room above the panels for the figure title
    plt.subplots_adjust(top=0.9)
    grid.fig.suptitle(f'Metric: {metric_name}', fontsize=20)
    grid.savefig(os.path.join(out, f'{metric_name}_fig.jpg'))
    plt.close()
# Draw the distribution figure for each metric, waytotal first and then volume
for metric_frame, metric_label in ((tract_waytotal_all, 'waytotal'), (tract_volume_all, 'volume')):
    plot_tractwise_metric(metric_frame, metric_label, n_std)
##########################################
# list of subjects with waytotal=0, but have tract files
# these subjects are likely to have failed due to bad registration
# One mask (tract file present AND waytotal == 0) drives both the choice of
# subjects and the tracts listed for each, so a tract whose file is missing is
# not reported here as an "empty" file.
zero_wt = (filecheck_all.values == 1) & (tract_waytotal_all.values == 0)
sub_ind_zerowt = np.flatnonzero(zero_wt.any(axis=1))  # sorted, unique subject rows

sub_zerowt_txt = []
for idx in sub_ind_zerowt:
    sub = subject_list[idx]
    empty_tracts = [tracts[i] for i in np.flatnonzero(zero_wt[idx])]
    nempty = len(empty_tracts)
    sub_zerowt_txt.append(f'{sub} has empty tracts files for {nempty}: {empty_tracts}')
sub_zerowt_txt.append(f'Total: {len(sub_ind_zerowt)} subjects with empty tracts')
# find outliers based on waytotal and volume
# add here a string of subjects where their volume/waytotal is >2*std away from mean
def get_outliers(metric):
    """Flag potential outliers in a subjects-by-tracts metric frame.

    A value is flagged when it is at least ``n_std`` standard deviations away
    from its tract (column) mean and is greater than zero. ``n_std`` and
    ``ntracts`` are read from module scope.

    Parameters
    ----------
    metric : pandas.DataFrame
        One row per subject, one column per tract.

    Returns
    -------
    outliers : pandas.DataFrame of bool
        Same shape as *metric*; True where the value is flagged.
    outliers_summary : pandas.Series
        Number of flagged subjects per tract.
    extreme_outliers : pandas.Series of bool
        True for subjects flagged in at least ``rint(0.5 * ntracts)`` tracts.
    """
    deviation = (metric - metric.mean()).abs()
    # subject > n_std away from mean but not 0
    outliers = (deviation >= n_std * metric.std()) & (metric > 0)
    # Explicit axes: np.sum(DataFrame) relies on the deprecated axis=None
    # behaviour of DataFrame.sum, which is due to return a scalar instead.
    outliers_summary = outliers.sum(axis=0)  # tract-wise outliers
    # now getting subject with more than 0.5*ntracts
    extreme_outliers = outliers.sum(axis=1) >= np.rint(0.5 * ntracts)
    return outliers, outliers_summary, extreme_outliers
# make outlier plots
# Two stacked panels (waytotal on top, volume below). Each has a primary y axis
# for the number of flagged subjects and a secondary one for the percentage.
fig = make_subplots(
    rows=2, cols=1,
    subplot_titles=('Waytotal', 'Volume'), shared_xaxes=True,
    specs=[[{"secondary_y": True}], [{"secondary_y": True}]],
)

panel_results = {}
for panel_row, (key, frame) in enumerate((('wt', tract_waytotal_all), ('vol', tract_volume_all)), start=1):
    flagged, counts, extreme = get_outliers(frame)
    outliers_summary = pd.DataFrame({'tracts': tracts, 'no_outliers':counts.values, 'perc':np.round(100*counts.values/nsubs,2)})
    for column, trace_name, use_secondary in (('no_outliers', "Number", False), ('perc', "Percent", True)):
        fig.add_trace(
            go.Scatter(x=outliers_summary['tracts'], y=outliers_summary[column], name=trace_name, marker=dict(color='darkslategray')),
            row=panel_row, col=1, secondary_y=use_secondary,
        )
    fig.update_layout(showlegend=False, hovermode="x unified")
    fig.update_yaxes(title_text="# potential outliers", secondary_y=False)
    fig.update_yaxes(title_text="% potential outliers", secondary_y=True)
    panel_results[key] = (flagged, extreme)

# keep the per-metric results under the names used further down the script
wt_outliers, wt_extreme_outliers = panel_results['wt']
vol_outliers, vol_extreme_outliers = panel_results['vol']

fig.update_xaxes(showticklabels=False)  # hide all the xticks
fig.update_xaxes(title_text="Tract", showticklabels=True, row=2, col=1, tickmode='linear')
fig.write_html(os.path.join(out, 'outlier_fig.html'))
# build outlier report txt/csv
# extreme outliers (those with more than 0.5*ntracts as outliers)
# A subject is listed when it is an extreme outlier for volume or for waytotal;
# flatnonzero on the combined mask gives sorted, unique subject rows.
grot = np.flatnonzero((vol_extreme_outliers.values == 1) | (wt_extreme_outliers.values == 1))
sub_extreme_outliers_txt = [f'{subject_list[idx]}\n' for idx in grot]
if sub_extreme_outliers_txt:
    # the with-block closes the file even if a write fails
    with open(os.path.join(out, 'extreme_outliers.txt'), 'w') as f:
        f.writelines(sub_extreme_outliers_txt)

# if subject is n_std away from mean for waytotal/volume
# save outliers as .csv
# NOTE(review): only the waytotal flags are written and counted here, although
# the comment above mentions volume as well -- confirm the intent.
sub_outliers_path = os.path.join(out, 'tract_outliers.csv')
wt_outliers.set_index(subject_list, inplace=True)
wt_outliers = wt_outliers.astype("int")
wt_outliers.to_csv(sub_outliers_path)
# Sum the underlying array: nested np.sum on a DataFrame relies on the
# deprecated axis=None behaviour of DataFrame.sum.
n_outliers = int(wt_outliers.values.sum())
########################################## Create report
# The template is loaded from the XTRACT data directory
env = jinja2.Environment(loader=jinja2.FileSystemLoader(searchpath=datadir))
template = env.get_template('template_qc_report.html')
filename = os.path.join(out, 'qc_report.html')

# Render first and write afterwards, so a template error does not leave an
# empty or truncated qc_report.html behind.
report_html = template.render(
    time = datetime.now().strftime("%d/%m/%Y %H:%M:%S"),
    ntracts = ntracts,
    nsubs = nsubs,
    n_std = n_std,
    sumcheck_all_txt = sumcheck_all_txt,
    sub_missing_txt = sub_missing_txt,
    sub_missing_all_txt = sub_missing_all_txt,
    sub_zerowt_txt = sub_zerowt_txt,
    sub_outliers_path = sub_outliers_path,
    n_outliers = n_outliers,
    sub_extreme_outliers_txt = sub_extreme_outliers_txt,
    ps_inc = ps_inc
)
with open(filename, 'w') as fh:
    fh.write(report_html)

print(f'Report generated\nSee {filename}')
# sys.exit instead of quit(): quit() is provided by the `site` module for
# interactive use and is not guaranteed to exist in a program.
sys.exit(0)