Skip to content

Commit

Permalink
resliced improved
Browse files Browse the repository at this point in the history
  • Loading branch information
skoudoro committed Jul 28, 2023
1 parent 9c8e6ba commit 54e97f3
Show file tree
Hide file tree
Showing 3 changed files with 140 additions and 25 deletions.
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ dependencies = [
"requests>=2.26.0",
"scipy",
"bs4",
"matplotlib",
"git+https://github.com/MIC-DKFZ/HD-BET.git"
]
dynamic = ["version"]

Expand Down
10 changes: 7 additions & 3 deletions quantconn/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,11 @@ def process(db_path: Annotated[Path, typer.Option("--db-path", "-db",
subjects = get_valid_subjects(db_path, subject)

for sub in subjects:
t1_path = pjoin(db_path, "anat", f"{sub}_T1w.nii.gz")
t1_path = pjoin(db_path, sub, "anat", f"{sub}_T1w.nii.gz")
# t1_label_path = pjoin(db_path, sub, "anat", "aparc+aseg.nii.gz")
t1_label_path = pjoin(db_path, sub, "anat", "atlas_freesurfer_inT1space.nii.gz")
if not os.path.exists(t1_label_path):
t1_label_path = None
for mod in ["A", "B"]:
data_folder = pjoin(db_path, sub, mod)
output_path = pjoin(destination, sub, mod)
Expand All @@ -74,8 +78,8 @@ def process(db_path: Annotated[Path, typer.Option("--db-path", "-db",
process_data(pjoin(data_folder, "dwi.nii.gz"),
pjoin(data_folder, "dwi.bval"),
pjoin(data_folder, "dwi.bvec"),
t1_path,
output_path)
t1_path, output_path,
t1_labels_fname=t1_label_path)
print(":green_circle: [bold green]Success ! :love-you_gesture: [/bold green]")
except Exception as e:
print(f":boom: [bold red]Error while processing {sub} case {mod}[/bold red]")
Expand Down
153 changes: 131 additions & 22 deletions quantconn/process.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,17 @@
import os
from os.path import join as pjoin

import numpy as np
import matplotlib.pyplot as plt
from rich import print

from HD_BET.run import run_hd_bet

import nibabel as nib
from nibabel.streamlines.trk import TrkFile
from dipy.align import affine_registration
from dipy.align.streamlinear import whole_brain_slr
from dipy.align.reslice import reslice
from dipy.core.gradients import gradient_table
from dipy.io.gradients import read_bvals_bvecs
from dipy.io.image import load_nifti, save_nifti
Expand All @@ -20,23 +28,32 @@
from dipy.io.stateful_tractogram import Space, StatefulTractogram
from dipy.io.streamline import save_trk, load_trk
from dipy.tracking.local_tracking import LocalTracking
from dipy.tracking.streamline import Streamlines
from dipy.tracking.streamline import Streamlines, transform_streamlines
from dipy.segment.mask import median_otsu
from dipy.segment.bundles import RecoBundles


from quantconn.download import get_30_bundles_atlas_hcp842


def process_data(nifti_fname, bval_fname, bvec_fname, t1_fname, output_path):
data, affine, data_img = load_nifti(nifti_fname, return_img=True)
bvals, bvecs = read_bvals_bvecs(bval_fname, bvec_fname)
gtab = gradient_table(bvals, bvecs)
def process_data(nifti_fname, bval_fname, bvec_fname, t1_fname, output_path,
t1_labels_fname=None):
dwi_data, dwi_affine, dwi_img = load_nifti(nifti_fname, return_img=True)
dwi_bvals, dwi_bvecs = read_bvals_bvecs(bval_fname, bvec_fname)
gtab = gradient_table(dwi_bvals, dwi_bvecs)

print(':left_arrow_curving_right: Sampling/reslicing data')
vox_sz = dwi_img.header.get_zooms()[:3]
new_vox_size = [2.2, 2.2, 2.2]
resliced_data, resliced_affine = reslice(dwi_data, dwi_affine, vox_sz,
new_vox_size)

save_nifti(pjoin(output_path, 'resliced_data.nii.gz'),
resliced_data, resliced_affine)

print(':left_arrow_curving_right: Building mask')
maskdata, mask = median_otsu(data, median_radius=3,
vol_idx=np.where(gtab.b0s_mask)[0],
numpass=1, autocrop=True, dilate=2)
maskdata, mask = median_otsu(
resliced_data, vol_idx=np.where(gtab.b0s_mask)[0][:2])

print(':left_arrow_curving_right: Computing DTI metrics')
tenmodel = TensorModel(gtab)
Expand All @@ -47,27 +64,33 @@ def process_data(nifti_fname, bval_fname, bvec_fname, t1_fname, output_path):
FA = np.clip(FA, 0, 1)

tensor_vals = lower_triangular(tenfit.quadratic_form)
ten_img = nifti1_symmat(tensor_vals, affine=affine)
ten_img = nifti1_symmat(tensor_vals, affine=resliced_affine)

save_nifti(pjoin(output_path, 'tensors.nii.gz'), ten_img.get_fdata(),
affine)
save_nifti(pjoin(output_path, 'fa.nii.gz'), FA.astype(np.float32), affine)
save_nifti(pjoin(output_path, 'tensors.nii.gz'),
ten_img.get_fdata().squeeze(), resliced_affine)

save_nifti(pjoin(output_path, 'fa.nii.gz'), FA.astype(np.float32),
resliced_affine)

GA = geodesic_anisotropy(tenfit.evals)
save_nifti(pjoin(output_path, 'ga.nii.gz'), GA.astype(np.float32), affine)
save_nifti(pjoin(output_path, 'ga.nii.gz'), GA.astype(np.float32),
resliced_affine)

RGB = color_fa(FA, tenfit.evecs)
save_nifti(pjoin(output_path, 'rgb.nii.gz'), np.array(255 * RGB, 'uint8'),
affine)
resliced_affine)

MD = mean_diffusivity(tenfit.evals)
save_nifti(pjoin(output_path, 'md.nii.gz'), MD.astype(np.float32), affine)
save_nifti(pjoin(output_path, 'md.nii.gz'), MD.astype(np.float32),
resliced_affine)

AD = axial_diffusivity(tenfit.evals)
save_nifti(pjoin(output_path, 'ad.nii.gz'), AD.astype(np.float32), affine)
save_nifti(pjoin(output_path, 'ad.nii.gz'), AD.astype(np.float32),
resliced_affine)

RD = radial_diffusivity(tenfit.evals)
save_nifti(pjoin(output_path, 'rd.nii.gz'), RD.astype(np.float32), affine)
save_nifti(pjoin(output_path, 'rd.nii.gz'), RD.astype(np.float32),
resliced_affine)

# TODO: Get White matter mask
# download The-HCP-MMP1.0-atlas
Expand All @@ -87,14 +110,20 @@ def process_data(nifti_fname, bval_fname, bvec_fname, t1_fname, output_path):
print(':left_arrow_curving_right: Whole Brain Tractography')
stopping_criterion = ThresholdStoppingCriterion(csa_peaks.gfa, .25)

seeds = utils.seeds_from_mask(white_matter, affine, density=[2, 2, 2])
seeds = utils.seeds_from_mask(white_matter, resliced_affine,
density=[2, 2, 2])

streamlines_generator = LocalTracking(csa_peaks, stopping_criterion, seeds,
affine=affine, step_size=.5)
affine=resliced_affine, step_size=.5)
target_streamlines = Streamlines(streamlines_generator)

target_sft = StatefulTractogram(target_streamlines, data_img, Space.RASMM)
save_trk(target_sft, pjoin(output_path, "full_tractogram.trk"))
header = create_tractogram_header(TrkFile, resliced_affine,
maskdata.shape[:3],
new_vox_size,
''.join(nib.aff2axcodes(resliced_affine)))
target_sft = StatefulTractogram(target_streamlines, header, Space.VOX)
save_trk(target_sft, pjoin(output_path, "full_tractogram.trk"),
bbox_valid_check=False)

# Recobunble

Expand Down Expand Up @@ -145,11 +174,91 @@ def process_data(nifti_fname, bval_fname, bvec_fname, t1_fname, output_path):
save_trk(reco, pjoin(output_path, f"{bundle_name}_in_atlas_space.trk"),
bbox_valid_check=False)
reco = StatefulTractogram(target_streamlines[model_labels],
data_img, Space.RASMM)
header, Space.RASMM)
save_trk(reco, pjoin(output_path, f"{bundle_name}_in_orig_space.trk"),
bbox_valid_check=False)

##########################################################################
#
# Connectivity matrix
#
##########################################################################

if not t1_labels_fname:
print(':left_arrow_curving_right: No T1 labels file provided, Skipping connectivity matrix')
return

print(':left_arrow_curving_right: Connectivity matrix: Loading data')
t1_data, t1_affine, t1_img = load_nifti(t1_fname, return_img=True)
label_data, label_affine, label_voxsize = load_nifti(t1_labels_fname,
return_voxsize=True)

print(':left_arrow_curving_right: Connectivity matrix: T1 skullstripping')
use_hd_bet = True
t1_skullstrip_fname = pjoin(
output_path,
os.path.basename(t1_fname).replace('.nii.gz', 'skullstrip.nii.gz')
)
if use_hd_bet:
run_hd_bet(t1_fname, t1_skullstrip_fname, mode='fast', device='cpu',
do_tta=False)
else:
from dipy.nn.evac import EVACPlus
evac = EVACPlus()
mask_volume = evac.predict(t1_data, t1_affine,
t1_img.header.get_zooms()[:3])
masked_volume = mask_volume * t1_data
save_nifti(t1_skullstrip_fname, masked_volume, t1_affine)

t1_noskull_data, t1_noskull_affine, t1_noskull_img = \
load_nifti(t1_skullstrip_fname, return_img=True)

print(':left_arrow_curving_right: Connectivity matrix: Registering DWI B0s to T1 /labels')
# Take one B0 instead of all of them or correct motion.
mean_b0 = np.mean(maskdata[..., gtab.b0s_mask], -1)
warped_b0, warped_b0_affine = affine_registration(
mean_b0, t1_noskull_data, moving_affine=resliced_affine,
static_affine=t1_noskull_affine)

save_nifti(pjoin(output_path, "warped_b0.nii.gz"), warped_b0,
t1_noskull_affine)

print(':left_arrow_curving_right: Connectivity matrix: Transforming Streamlines')
target_streamlines_in_t1 = transform_streamlines(target_streamlines,
warped_b0_affine,
in_place=True)
# filter small streamlines

# if 1:
# print(nb.aff2axcodes(mapping))
header = create_tractogram_header(TrkFile, warped_b0_affine,
maskdata.shape[:3],
new_vox_size, 'RAS')

target_streamlines_in_t1_sft = StatefulTractogram(target_streamlines_in_t1,
header, Space.VOX)
save_trk(target_streamlines_in_t1_sft,
pjoin(output_path, "full_tractogram_in_t1.trk"),
bbox_valid_check=False)

# import ipdb; ipdb.set_trace()

# interactive = True
# if interactive:
# from dipy.viz.horizon.app import horizon
# horizon(tractograms=[resliced_target_sft],images=[(label_data, label_affine)], interactive=True, cluster=True)

# print(':left_arrow_curving_right: Connectivity matrix')
# # Connectivity matrix
M, grouping = utils.connectivity_matrix(
target_streamlines_in_t1_sft, warped_b0_affine,
label_data.get_fdata().astype(np.uint8), return_mapping=True,
mapping_as_streamlines=True)

# import ipdb; ipdb.set_trace()
# plt.imshow(np.log1p(M), interpolation='nearest')
# plt.savefig(pjoin(output_path, "connectivity.png"))




Expand Down

0 comments on commit 54e97f3

Please sign in to comment.