diff --git a/pyproject.toml b/pyproject.toml index 424e2bc..068f1e9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,6 +25,8 @@ dependencies = [ "requests>=2.26.0", "scipy", "bs4", + "matplotlib", + "git+https://github.com/MIC-DKFZ/HD-BET.git" ] dynamic = ["version"] diff --git a/quantconn/cli.py b/quantconn/cli.py index b3e166e..3990aaa 100644 --- a/quantconn/cli.py +++ b/quantconn/cli.py @@ -53,7 +53,11 @@ def process(db_path: Annotated[Path, typer.Option("--db-path", "-db", subjects = get_valid_subjects(db_path, subject) for sub in subjects: - t1_path = pjoin(db_path, "anat", f"{sub}_T1w.nii.gz") + t1_path = pjoin(db_path, sub, "anat", f"{sub}_T1w.nii.gz") + # t1_label_path = pjoin(db_path, sub, "anat", "aparc+aseg.nii.gz") + t1_label_path = pjoin(db_path, sub, "anat", "atlas_freesurfer_inT1space.nii.gz") + if not os.path.exists(t1_label_path): + t1_label_path = None for mod in ["A", "B"]: data_folder = pjoin(db_path, sub, mod) output_path = pjoin(destination, sub, mod) @@ -74,8 +78,8 @@ def process(db_path: Annotated[Path, typer.Option("--db-path", "-db", process_data(pjoin(data_folder, "dwi.nii.gz"), pjoin(data_folder, "dwi.bval"), pjoin(data_folder, "dwi.bvec"), - t1_path, - output_path) + t1_path, output_path, + t1_labels_fname=t1_label_path) print(":green_circle: [bold green]Success ! :love-you_gesture: [/bold green]") except Exception as e: print(f":boom: [bold red]Error while processing {sub} case {mod}[/bold red]") diff --git a/quantconn/process.py b/quantconn/process.py index 15bc3dd..56e4eaa 100644 --- a/quantconn/process.py +++ b/quantconn/process.py @@ -1,9 +1,17 @@ +import os from os.path import join as pjoin import numpy as np +import matplotlib.pyplot as plt from rich import print +from HD_BET.run import run_hd_bet + +import nibabel as nib +from nibabel.streamlines.trk import TrkFile +from dipy.align import affine_registration from dipy.align.streamlinear import whole_brain_slr +from dipy.align.reslice import reslice from dipy.core.gradients import gradient_table from dipy.io.gradients import read_bvals_bvecs from dipy.io.image import load_nifti, save_nifti @@ -20,7 +28,7 @@ from dipy.io.stateful_tractogram import Space, StatefulTractogram from dipy.io.streamline import save_trk, load_trk from dipy.tracking.local_tracking import LocalTracking -from dipy.tracking.streamline import Streamlines +from dipy.tracking.streamline import Streamlines, transform_streamlines from dipy.segment.mask import median_otsu from dipy.segment.bundles import RecoBundles @@ -28,15 +36,24 @@ from quantconn.download import get_30_bundles_atlas_hcp842 -def process_data(nifti_fname, bval_fname, bvec_fname, t1_fname, output_path): - data, affine, data_img = load_nifti(nifti_fname, return_img=True) - bvals, bvecs = read_bvals_bvecs(bval_fname, bvec_fname) - gtab = gradient_table(bvals, bvecs) +def process_data(nifti_fname, bval_fname, bvec_fname, t1_fname, output_path, + t1_labels_fname=None): + dwi_data, dwi_affine, dwi_img = load_nifti(nifti_fname, return_img=True) + dwi_bvals, dwi_bvecs = read_bvals_bvecs(bval_fname, bvec_fname) + gtab = gradient_table(dwi_bvals, dwi_bvecs) + + print(':left_arrow_curving_right: Sampling/reslicing data') + vox_sz = dwi_img.header.get_zooms()[:3] + new_vox_size = [2.2, 2.2, 2.2] + resliced_data, resliced_affine = reslice(dwi_data, dwi_affine, vox_sz, + new_vox_size) + + save_nifti(pjoin(output_path, 'resliced_data.nii.gz'), + resliced_data, resliced_affine) print(':left_arrow_curving_right: Building mask') - maskdata, mask = median_otsu(data, median_radius=3, - vol_idx=np.where(gtab.b0s_mask)[0], - numpass=1, autocrop=True, dilate=2) + maskdata, mask = median_otsu( + resliced_data, vol_idx=np.where(gtab.b0s_mask)[0][:2]) print(':left_arrow_curving_right: Computing DTI metrics') tenmodel = TensorModel(gtab) @@ -47,27 +64,33 @@ def process_data(nifti_fname, bval_fname, bvec_fname, t1_fname, output_path): FA = np.clip(FA, 0, 1) tensor_vals = lower_triangular(tenfit.quadratic_form) - ten_img = nifti1_symmat(tensor_vals, affine=affine) + ten_img = nifti1_symmat(tensor_vals, affine=resliced_affine) - save_nifti(pjoin(output_path, 'tensors.nii.gz'), ten_img.get_fdata(), - affine) - save_nifti(pjoin(output_path, 'fa.nii.gz'), FA.astype(np.float32), affine) + save_nifti(pjoin(output_path, 'tensors.nii.gz'), + ten_img.get_fdata().squeeze(), resliced_affine) + + save_nifti(pjoin(output_path, 'fa.nii.gz'), FA.astype(np.float32), + resliced_affine) GA = geodesic_anisotropy(tenfit.evals) - save_nifti(pjoin(output_path, 'ga.nii.gz'), GA.astype(np.float32), affine) + save_nifti(pjoin(output_path, 'ga.nii.gz'), GA.astype(np.float32), + resliced_affine) RGB = color_fa(FA, tenfit.evecs) save_nifti(pjoin(output_path, 'rgb.nii.gz'), np.array(255 * RGB, 'uint8'), - affine) + resliced_affine) MD = mean_diffusivity(tenfit.evals) - save_nifti(pjoin(output_path, 'md.nii.gz'), MD.astype(np.float32), affine) + save_nifti(pjoin(output_path, 'md.nii.gz'), MD.astype(np.float32), + resliced_affine) AD = axial_diffusivity(tenfit.evals) - save_nifti(pjoin(output_path, 'ad.nii.gz'), AD.astype(np.float32), affine) + save_nifti(pjoin(output_path, 'ad.nii.gz'), AD.astype(np.float32), + resliced_affine) RD = radial_diffusivity(tenfit.evals) - save_nifti(pjoin(output_path, 'rd.nii.gz'), RD.astype(np.float32), affine) + save_nifti(pjoin(output_path, 'rd.nii.gz'), RD.astype(np.float32), + resliced_affine) # TODO: Get White matter mask # download The-HCP-MMP1.0-atlas @@ -87,14 +110,20 @@ def process_data(nifti_fname, bval_fname, bvec_fname, t1_fname, output_path): print(':left_arrow_curving_right: Whole Brain Tractography') stopping_criterion = ThresholdStoppingCriterion(csa_peaks.gfa, .25) - seeds = utils.seeds_from_mask(white_matter, affine, density=[2, 2, 2]) + seeds = utils.seeds_from_mask(white_matter, resliced_affine, + density=[2, 2, 2]) streamlines_generator = LocalTracking(csa_peaks, stopping_criterion, seeds, - affine=affine, step_size=.5) + affine=resliced_affine, step_size=.5) target_streamlines = Streamlines(streamlines_generator) - target_sft = StatefulTractogram(target_streamlines, data_img, Space.RASMM) - save_trk(target_sft, pjoin(output_path, "full_tractogram.trk")) + header = create_tractogram_header(TrkFile, resliced_affine, + maskdata.shape[:3], + new_vox_size, + ''.join(nib.aff2axcodes(resliced_affine))) + target_sft = StatefulTractogram(target_streamlines, header, Space.VOX) + save_trk(target_sft, pjoin(output_path, "full_tractogram.trk"), + bbox_valid_check=False) # Recobunble @@ -145,11 +174,91 @@ def process_data(nifti_fname, bval_fname, bvec_fname, t1_fname, output_path): save_trk(reco, pjoin(output_path, f"{bundle_name}_in_atlas_space.trk"), bbox_valid_check=False) reco = StatefulTractogram(target_streamlines[model_labels], - data_img, Space.RASMM) + header, Space.RASMM) save_trk(reco, pjoin(output_path, f"{bundle_name}_in_orig_space.trk"), bbox_valid_check=False) + ########################################################################## + # # Connectivity matrix + # + ########################################################################## + + if not t1_labels_fname: + print(':left_arrow_curving_right: No T1 labels file provided, Skipping connectivity matrix') + return + + print(':left_arrow_curving_right: Connectivity matrix: Loading data') + t1_data, t1_affine, t1_img = load_nifti(t1_fname, return_img=True) + label_data, label_affine, label_voxsize = load_nifti(t1_labels_fname, + return_voxsize=True) + + print(':left_arrow_curving_right: Connectivity matrix: T1 skullstripping') + use_hd_bet = True + t1_skullstrip_fname = pjoin( + output_path, + os.path.basename(t1_fname).replace('.nii.gz', 'skullstrip.nii.gz') + ) + if use_hd_bet: + run_hd_bet(t1_fname, t1_skullstrip_fname, mode='fast', device='cpu', + do_tta=False) + else: + from dipy.nn.evac import EVACPlus + evac = EVACPlus() + mask_volume = evac.predict(t1_data, t1_affine, + t1_img.header.get_zooms()[:3]) + masked_volume = mask_volume * t1_data + save_nifti(t1_skullstrip_fname, masked_volume, t1_affine) + + t1_noskull_data, t1_noskull_affine, t1_noskull_img = \ + load_nifti(t1_skullstrip_fname, return_img=True) + + print(':left_arrow_curving_right: Connectivity matrix: Registering DWI B0s to T1 /labels') + # Take one B0 instead of all of them or correct motion. + mean_b0 = np.mean(maskdata[..., gtab.b0s_mask], -1) + warped_b0, warped_b0_affine = affine_registration( + mean_b0, t1_noskull_data, moving_affine=resliced_affine, + static_affine=t1_noskull_affine) + + save_nifti(pjoin(output_path, "warped_b0.nii.gz"), warped_b0, + t1_noskull_affine) + + print(':left_arrow_curving_right: Connectivity matrix: Transforming Streamlines') + target_streamlines_in_t1 = transform_streamlines(target_streamlines, + warped_b0_affine, + in_place=True) + # filter small streamlines + + # if 1: + # print(nb.aff2axcodes(mapping)) + header = create_tractogram_header(TrkFile, warped_b0_affine, + maskdata.shape[:3], + new_vox_size, 'RAS') + + target_streamlines_in_t1_sft = StatefulTractogram(target_streamlines_in_t1, + header, Space.VOX) + save_trk(target_streamlines_in_t1_sft, + pjoin(output_path, "full_tractogram_in_t1.trk"), + bbox_valid_check=False) + + # import ipdb; ipdb.set_trace() + + # interactive = True + # if interactive: + # from dipy.viz.horizon.app import horizon + # horizon(tractograms=[resliced_target_sft],images=[(label_data, label_affine)], interactive=True, cluster=True) + + # print(':left_arrow_curving_right: Connectivity matrix') + # # Connectivity matrix + M, grouping = utils.connectivity_matrix( + target_streamlines_in_t1_sft, warped_b0_affine, + label_data.get_fdata().astype(np.uint8), return_mapping=True, + mapping_as_streamlines=True) + + # import ipdb; ipdb.set_trace() + # plt.imshow(np.log1p(M), interpolation='nearest') + # plt.savefig(pjoin(output_path, "connectivity.png")) +