diff --git a/data/mni_icbm152_t1_tal_nlin_asym_09c_masked_2_5.nii.gz b/data/mni_icbm152_t1_tal_nlin_asym_09c_masked_2_5.nii.gz
new file mode 100755
index 000000000..1ef0fc82a
Binary files /dev/null and b/data/mni_icbm152_t1_tal_nlin_asym_09c_masked_2_5.nii.gz differ
diff --git a/scilpy/image/volume_operations.py b/scilpy/image/volume_operations.py
index 2643e14e7..f8529046d 100644
--- a/scilpy/image/volume_operations.py
+++ b/scilpy/image/volume_operations.py
@@ -194,17 +194,55 @@ def transform_dwi(reg_obj, static, dwi, interpolation='linear'):
 
 
 def register_image(static, static_grid2world, moving, moving_grid2world,
-                   transformation_type='affine', dwi=None):
+                   transformation_type='affine', dwi=None, fine=False):
+    """
+    Register a moving image to a static image using either rigid or affine
+    transformations. If a DWI (4D) is provided, it applies the transformation
+    to each volume.
+
+    Parameters
+    ----------
+    static : ndarray
+        The static image volume to which the moving image will be registered.
+    static_grid2world : ndarray
+        The grid-to-world (vox2ras) transformation associated with the static
+        image.
+    moving : ndarray
+        The moving image volume that needs to be registered to the static
+        image.
+    moving_grid2world : ndarray
+        The grid-to-world (vox2ras) transformation associated with the moving
+        image.
+    transformation_type : str, optional
+        The type of transformation ('rigid' or 'affine'). Default is 'affine'.
+    dwi : ndarray, optional
+        Diffusion-weighted imaging data (if applicable). Default is None.
+    fine : bool, optional
+        Whether to use fine or coarse settings for the registration.
+        Default is False.
+
+    Raises
+    ------
+    ValueError
+        If the transformation_type is neither 'rigid' nor 'affine'.
+
+    Returns
+    -------
+    ndarray or tuple
+        If `dwi` is None, returns the transformed moving image and the
+        transformation matrix.
+        If `dwi` is not None, returns the transformed DWI and the
+        transformation matrix.
+    """
+
     if transformation_type not in ['rigid', 'affine']:
         raise ValueError('Transformation type not available in Dipy')
 
     # Set all parameters for registration
-    nbins = 32
+    nbins = 64 if fine else 32
     params0 = None
     sampling_prop = None
-    level_iters = [50, 25, 5]
-    sigmas = [8.0, 4.0, 2.0]
-    factors = [8, 4, 2]
+    level_iters = [250, 100, 50, 25] if fine else [50, 25, 5]
+    sigmas = [8.0, 4.0, 2.0, 1.0] if fine else [8.0, 4.0, 2.0]
+    factors = [8, 4, 2, 1.0] if fine else [8, 4, 2]
     metric = MutualInformationMetric(nbins, sampling_prop)
     reg_obj = AffineRegistration(metric=metric, level_iters=level_iters,
                                  sigmas=sigmas, factors=factors,
                                  verbosity=0)
diff --git a/scilpy/io/fetcher.py b/scilpy/io/fetcher.py
index fd5b5218b..8113e609d 100644
--- a/scilpy/io/fetcher.py
+++ b/scilpy/io/fetcher.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 
+import inspect
 import logging
 import hashlib
 import os
@@ -131,3 +132,20 @@ def fetch_data(files_dict, keys=None):
         else:
             # toDo. Verify that data on disk is the right one.
             logging.warning("Not fetching data; already on disk.")
+
+
+def get_synb0_template_path():
+    """
+    Return the MNI 2.5 mm template bundled in the scilpy repository.
+
+    Returns
+    -------
+    path: str
+        Template path
+    """
+    import scilpy  # ToDo. Is this the only way?
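Illustrative sketch (not part of the patch): how the new `fine` option and `get_synb0_template_path()` fit together, assuming a skull-stripped T1w already on disk (the file name is hypothetical). With `fine=True`, the registration uses 64 histogram bins and a four-level pyramid that goes down to full resolution.

import nibabel as nib

from scilpy.image.volume_operations import register_image
from scilpy.io.fetcher import get_synb0_template_path

# Bundled 2.5 mm MNI template shipped with this patch
template_img = nib.load(get_synb0_template_path())
# Hypothetical skull-stripped T1w
t1_bet_img = nib.load('t1_bet.nii.gz')

# Affine registration of the T1w onto the template with the finer settings
t1_on_template, transform = register_image(template_img.get_fdata(),
                                           template_img.affine,
                                           t1_bet_img.get_fdata(),
                                           t1_bet_img.affine,
                                           transformation_type='affine',
                                           fine=True)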
+    module_path = inspect.getfile(scilpy)
+    module_path = os.path.dirname(os.path.dirname(module_path))
+
+    path = os.path.join(module_path, 'data/',
+                        'mni_icbm152_t1_tal_nlin_asym_09c_masked_2_5.nii.gz')
+    return path
diff --git a/scilpy/preprocessing/distortion_correction.py b/scilpy/preprocessing/distortion_correction.py
index 0996c6a3f..518d4c032 100644
--- a/scilpy/preprocessing/distortion_correction.py
+++ b/scilpy/preprocessing/distortion_correction.py
@@ -1,9 +1,12 @@
 # -*- coding: utf-8 -*-
 
+import logging
+
 import numpy as np
 
 
-def create_acqparams(readout, encoding_direction, nb_b0s=1, nb_rev_b0s=1):
+def create_acqparams(readout, encoding_direction, synb0=False,
+                     nb_b0s=1, nb_rev_b0s=1):
     """
     Create acqparams for Topup and Eddy
 
@@ -23,13 +26,18 @@ def create_acqparams(readout, encoding_direction, nb_b0s=1, nb_rev_b0s=1):
     acqparams: np.array
         acqparams
     """
+    if synb0:
+        logging.warning('Using SyNb0, untested feature. Be careful.')
+
     acqparams = np.zeros((nb_b0s + nb_rev_b0s, 4))
     acqparams[:, 3] = readout
 
     enum_direction = {'x': 0, 'y': 1, 'z': 2}
     acqparams[0:nb_b0s, enum_direction[encoding_direction]] = 1
     if nb_rev_b0s > 0:
-        acqparams[nb_b0s:, enum_direction[encoding_direction]] = -1
+        val = -1 if not synb0 else 1
+        acqparams[nb_b0s:, enum_direction[encoding_direction]] = val
+        acqparams[nb_b0s:, 3] = readout if not synb0 else 0
 
     return acqparams
diff --git a/scilpy/tractanalysis/afd_along_streamlines.py b/scilpy/tractanalysis/afd_along_streamlines.py
index 9b9b32efa..06c019155 100644
--- a/scilpy/tractanalysis/afd_along_streamlines.py
+++ b/scilpy/tractanalysis/afd_along_streamlines.py
@@ -33,7 +33,6 @@ def afd_map_along_streamlines(sft, fodf, fodf_basis, length_weighting):
         rdAFD map (weighted if length_weighting)
     """
-
     afd_sum, rd_sum, weights = \
         afd_and_rd_sums_along_streamlines(sft, fodf, fodf_basis,
                                           length_weighting)
@@ -112,7 +111,8 @@ def afd_and_rd_sums_along_streamlines(sft, fodf, fodf_basis,
     normalization_weights = np.ones_like(seg_lengths)
     if length_weighting:
-        normalization_weights = seg_lengths / np.linalg.norm(fodf.header.get_zooms()[:3])
+        normalization_weights = seg_lengths / \
+            np.linalg.norm(fodf.header.get_zooms()[:3])
 
     for vox_idx, closest_vertex_index, norm_weight in zip(vox_indices,
                                                           closest_vertex_indices,
                                                           normalization_weights):
@@ -130,5 +130,4 @@ def afd_and_rd_sums_along_streamlines(sft, fodf, fodf_basis,
             weight_map[vox_idx] += norm_weight
 
     rd_sum_map[rd_sum_map < 0.] = 0.
-
     return afd_sum_map, rd_sum_map, weight_map
diff --git a/scripts/scil_dwi_prepare_topup_command.py b/scripts/scil_dwi_prepare_topup_command.py
index b1ce316a0..b39a02776 100755
--- a/scripts/scil_dwi_prepare_topup_command.py
+++ b/scripts/scil_dwi_prepare_topup_command.py
@@ -34,6 +34,9 @@ def _build_arg_parser():
     p.add_argument('--config', default='b02b0.cnf',
                    help='Topup config file [%(default)s].')
 
+    p.add_argument('--synb0', action='store_true',
+                   help='If set, will use SyNb0 custom acqparams file.')
+
     p.add_argument('--encoding_direction', default='y',
                    choices=['x', 'y', 'z'],
                    help='Acquisition direction of the forward b0 '
@@ -119,8 +122,8 @@ def main():
         fused_b0s_path = os.path.join(args.out_directory, args.out_b0s)
         nib.save(nib.Nifti1Image(fused_b0s, b0_img.affine),
                  fused_b0s_path)
 
-    acqparams = create_acqparams(
-        args.readout, args.encoding_direction, b0.shape[-1], rev_b0.shape[-1])
+    acqparams = create_acqparams(args.readout, args.encoding_direction,
+                                 args.synb0, b0.shape[-1], rev_b0.shape[-1])
 
     if not os.path.exists(args.out_directory):
         os.makedirs(args.out_directory)
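For reference, a minimal sketch (not part of the patch) of what `--synb0` changes in the generated acqparams, assuming a 'y' encoding direction and a 0.062 readout: the reverse rows keep the same phase-encoding sign as the forward b0 and get a readout of 0, since the synthesized reverse b0 is treated as distortion-free.

from scilpy.preprocessing.distortion_correction import create_acqparams

create_acqparams(0.062, 'y', synb0=False, nb_b0s=1, nb_rev_b0s=1)
# [[ 0.     1.     0.     0.062]
#  [ 0.    -1.     0.     0.062]]

create_acqparams(0.062, 'y', synb0=True, nb_b0s=1, nb_rev_b0s=1)
# [[ 0.     1.     0.     0.062]
#  [ 0.     1.     0.     0.   ]]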
diff --git a/scripts/scil_volume_b0_synthesis.py b/scripts/scil_volume_b0_synthesis.py
new file mode 100644
index 000000000..99784173e
--- /dev/null
+++ b/scripts/scil_volume_b0_synthesis.py
@@ -0,0 +1,142 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+"""
+Wrapper for SyNb0, available in Dipy, to run it on a single subject.
+Given a b0 and a T1w image along with their brain masks, the script will
+normalize the T1w's WM to 110, co-register both images, register them to the
+appropriate template, run SyNb0 and then transform the result back to the
+original space.
+
+This script must be used carefully, as it is meant to be used in an
+environment with the following dependencies already installed (they are not
+installed by default with Scilpy):
+- tensorflow-addons
+- tensorrt
+- tensorflow
+"""
+
+
+import argparse
+import logging
+import os
+import sys
+import warnings
+
+# Disable tensorflow warnings
+with warnings.catch_warnings():
+    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
+    warnings.simplefilter("ignore")
+    from dipy.nn.synb0 import Synb0
+
+from dipy.align.imaffine import AffineMap
+from dipy.segment.tissue import TissueClassifierHMRF
+import nibabel as nib
+import numpy as np
+from scipy.ndimage import gaussian_filter
+
+from scilpy.io.fetcher import get_synb0_template_path
+from scilpy.io.utils import (add_overwrite_arg,
+                             add_verbose_arg,
+                             assert_inputs_exist,
+                             assert_outputs_exist)
+from scilpy.image.volume_operations import register_image
+
+
+def _build_arg_parser():
+    p = argparse.ArgumentParser(
+        description=__doc__,
+        formatter_class=argparse.RawTextHelpFormatter)
+    p.add_argument('in_b0',
+                   help='Input b0 image.')
+    p.add_argument('in_b0_mask',
+                   help='Input b0 mask.')
+    p.add_argument('in_t1',
+                   help='Input t1w image.')
+    p.add_argument('in_t1_mask',
+                   help='Input t1w mask.')
+    p.add_argument('out_b0',
+                   help='Output b0 image without distortion.')
+
+    add_verbose_arg(p)
+    add_overwrite_arg(p)
+
+    return p
+
+
+def main():
+    parser = _build_arg_parser()
+    args = parser.parse_args()
+    assert_inputs_exist(parser, [args.in_b0, args.in_b0_mask,
+                                 args.in_t1, args.in_t1_mask])
+    assert_outputs_exist(parser, args, args.out_b0)
+
+    logging.getLogger().setLevel(logging.getLevelName(args.verbose))
+    logging.info('The usage of synthetic b0 is not fully tested. '
+                 'Be careful when using it.')
+
+    template_img = nib.load(get_synb0_template_path())
+    template_data = template_img.get_fdata()
+
+    b0_img = nib.load(args.in_b0)
+    b0_skull_data = b0_img.get_fdata()
+    b0_mask_img = nib.load(args.in_b0_mask)
+    b0_mask_data = b0_mask_img.get_fdata()
+
+    t1_img = nib.load(args.in_t1)
+    t1_skull_data = t1_img.get_fdata()
+    t1_mask_img = nib.load(args.in_t1_mask)
+    t1_mask_data = t1_mask_img.get_fdata()
+
+    b0_bet_data = np.zeros(b0_skull_data.shape)
+    b0_bet_data[b0_mask_data > 0] = b0_skull_data[b0_mask_data > 0]
+    t1_bet_data = np.zeros(t1_skull_data.shape)
+    t1_bet_data[t1_mask_data > 0] = t1_skull_data[t1_mask_data > 0]
+
+    # Crude estimation of the WM mean intensity and normalization
+    logging.info('Estimating WM mean intensity')
+    hmrf = TissueClassifierHMRF()
+    t1_bet_data = gaussian_filter(t1_bet_data, 2)
+    _, final_segmentation, _ = hmrf.classify(t1_bet_data, 3, 0.25,
+                                             tolerance=1e-4, max_iter=5)
+    avg_wm = np.mean(t1_skull_data[final_segmentation == 3])
+    t1_skull_data /= avg_wm
+    t1_skull_data *= 110
+
+    # SyNb0 works only in a standard space, so we need to register the images
+    logging.info('Registering images')
+    # Use the BET image for registration
+    t1_bet_to_b0, t1_bet_to_b0_transform = register_image(b0_bet_data,
+                                                          b0_img.affine,
+                                                          t1_bet_data,
+                                                          t1_img.affine,
+                                                          fine=True)
+    affine_map = AffineMap(t1_bet_to_b0_transform,
+                           b0_skull_data.shape, b0_img.affine,
+                           t1_skull_data.shape, t1_img.affine)
+    t1_skull_to_b0 = affine_map.transform(t1_skull_data.astype(np.float64))
+
+    # Then register to MNI (using the BET again)
+    _, t1_bet_to_b0_to_mni_transform = register_image(template_data,
+                                                      template_img.affine,
+                                                      t1_bet_to_b0,
+                                                      b0_img.affine,
+                                                      fine=True)
+    affine_map = AffineMap(t1_bet_to_b0_to_mni_transform,
+                           template_data.shape, template_img.affine,
+                           b0_skull_data.shape, b0_img.affine)
+
+    # But for prediction, we want the skull
+    b0_skull_to_mni = affine_map.transform(b0_skull_data.astype(np.float64))
+    t1_skull_to_mni = affine_map.transform(t1_skull_to_b0.astype(np.float64))
+
+    logging.info('Running SyNb0')
+    SyNb0 = Synb0(args.verbose)
+    rev_b0 = SyNb0.predict(b0_skull_to_mni, t1_skull_to_mni)
+    rev_b0 = affine_map.transform_inverse(rev_b0.astype(np.float64))
+
+    dtype = b0_img.get_data_dtype()
+    nib.save(nib.Nifti1Image(rev_b0.astype(dtype), b0_img.affine),
+             args.out_b0)
+
+
+if __name__ == "__main__":
+    main()
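As an example (not part of the patch), the new script could be invoked as follows, with placeholder file names and the positional order defined in `_build_arg_parser`:

scil_volume_b0_synthesis.py b0_mean.nii.gz b0_mask.nii.gz \
    t1.nii.gz t1_mask.nii.gz b0_synthesized.nii.gz -v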
diff --git a/scripts/tests/test_volume_b0_synthesis.py b/scripts/tests/test_volume_b0_synthesis.py
new file mode 100644
index 000000000..52bea3fc8
--- /dev/null
+++ b/scripts/tests/test_volume_b0_synthesis.py
@@ -0,0 +1,47 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+import os
+import tempfile
+
+import nibabel as nib
+import numpy as np
+import pytest
+
+from scilpy.io.fetcher import fetch_data, get_home, get_testing_files_dict
+
+tensorflow = pytest.importorskip("tensorflow")
+
+
+# If they already exist, this only takes 5 seconds (check md5sum)
+fetch_data(get_testing_files_dict(), keys=['others.zip', 'processing.zip'])
+tmp_dir = tempfile.TemporaryDirectory()
+
+
+def test_help_option(script_runner):
+    ret = script_runner.run('scil_volume_b0_synthesis.py', '--help')
+    assert ret.success
+
+
+@pytest.mark.skipif(tensorflow is None, reason="Tensorflow not installed")
+def test_synthesis(script_runner):
+    os.chdir(os.path.expanduser(tmp_dir.name))
+    in_t1 = os.path.join(get_home(), 'others',
+                         't1.nii.gz')
+    in_b0 = os.path.join(get_home(), 'processing',
+                         'b0_mean.nii.gz')
+
+    t1_img = nib.load(in_t1)
+    b0_img = nib.load(in_b0)
+    t1_data = t1_img.get_fdata()
+    b0_data = b0_img.get_fdata()
+    t1_data[t1_data > 0] = 1
+    b0_data[b0_data > 0] = 1
+    nib.save(nib.Nifti1Image(t1_data.astype(np.uint8), t1_img.affine),
+             't1_mask.nii.gz')
+    nib.save(nib.Nifti1Image(b0_data.astype(np.uint8), b0_img.affine),
+             'b0_mask.nii.gz')
+
+    # Positional order expected by the script: b0, b0 mask, t1, t1 mask, output
+    ret = script_runner.run('scil_volume_b0_synthesis.py',
+                            in_b0, 'b0_mask.nii.gz',
+                            in_t1, 't1_mask.nii.gz',
+                            'b0_synthesized.nii.gz', '-v')
+    assert ret.success