Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Working synb0 wrapper in scilpy #734

Merged
merged 30 commits into from
Feb 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
e6f31fc
Working synb0
frheault Aug 15, 2023
ae3e84e
Fix affine and added tests
frheault Aug 15, 2023
7a67dce
Working as synb0 github
frheault Aug 16, 2023
8e42ecc
Fine registration
frheault Aug 16, 2023
0adca04
skip test if not TF
frheault Aug 17, 2023
ecb96bb
Merge branch 'master' of github.com:scilus/scilpy into synb0_integration
frheault Aug 21, 2023
47e1053
Fix Arnaud last comments
Oct 23, 2023
5e68c83
Merge with master
Oct 23, 2023
d0a991a
Add warning
Oct 23, 2023
1878ef7
merge
frheault Nov 2, 2023
0ec6723
merge
frheault Nov 2, 2023
9986857
Move back fetcher
frheault Nov 2, 2023
41caa25
Merge branch 'test_volume_math_2' of github.com:frheault/scilpy into …
frheault Nov 2, 2023
3e9b2c9
Merge branch 'master' of github.com:scilus/scilpy into synb0_integration
frheault Nov 8, 2023
911607c
merge
frheault Nov 8, 2023
c3df582
Added note about test
frheault Nov 8, 2023
ac5787e
Fix merge issue
frheault Nov 8, 2023
4b577fc
Merge branch 'master' of github.com:scilus/scilpy into synb0_integration
frheault Dec 4, 2023
ed8be3c
Fix max comment
frheault Dec 4, 2023
7b84414
Merge branch 'master' of github.com:scilus/scilpy into synb0_integration
Dec 14, 2023
05e61cd
Merge with masther
Dec 14, 2023
9c34008
Fix conflict
Dec 14, 2023
21e503f
Add back the missing file
Dec 14, 2023
c9eba44
Merge branch 'master' of github.com:scilus/scilpy into synb0_integration
frheault Dec 18, 2023
00b0d5a
Merge branch 'master' of github.com:scilus/scilpy into synb0_integration
frheault Jan 15, 2024
5080ab2
Add back fetcher
frheault Jan 15, 2024
da0cf1e
Fix conflict
frheault Feb 20, 2024
03c4c2c
Merge branch 'master' of github.com:scilus/scilpy into synb0_integration
frheault Feb 21, 2024
81b5d27
Arnaud comments
frheault Feb 21, 2024
bc117fb
Use proper filename for script
frheault Feb 21, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file not shown.
48 changes: 43 additions & 5 deletions scilpy/image/volume_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,17 +194,55 @@ def transform_dwi(reg_obj, static, dwi, interpolation='linear'):


def register_image(static, static_grid2world, moving, moving_grid2world,
transformation_type='affine', dwi=None):
transformation_type='affine', dwi=None, fine=False):
"""
Register a moving image to a static image using either rigid or affine
transformations. If a DWI (4D) is provided, it applies the transformation
to each volume.

Parameters
----------
static : ndarray
The static image volume to which the moving image will be registered.
static_grid2world : ndarray
The grid-to-world (vox2ras) transformation associated with the static
image.
moving : ndarray
The moving image volume that needs to be registered to the static image.
moving_grid2world : ndarray
The grid-to-world (vox2ras) transformation associated with the moving
image.
transformation_type : str, optional
The type of transformation ('rigid' or 'affine'). Default is 'affine'.
dwi : ndarray, optional
Diffusion-weighted imaging data (if applicable). Default is None.
fine : bool, optional
Whether to use fine or coarse settings for the registration.
Default is False.

Raises
------
ValueError
If the transformation_type is neither 'rigid' nor 'affine'.

Returns
-------
ndarray or tuple
If `dwi` is None, returns transformed moving image and transformation
matrix.
If `dwi` is not None, returns transformed DWI and transformation matrix.
"""

if transformation_type not in ['rigid', 'affine']:
raise ValueError('Transformation type not available in Dipy')

# Set all parameters for registration
nbins = 32
nbins = 64 if fine else 32
params0 = None
sampling_prop = None
level_iters = [50, 25, 5]
sigmas = [8.0, 4.0, 2.0]
factors = [8, 4, 2]
level_iters = [250, 100, 50, 25] if fine else [50, 25, 5]
sigmas = [8.0, 4.0, 2.0, 1.0] if fine else [8.0, 4.0, 2.0]
factors = [8, 4, 2, 1.0] if fine else [8, 4, 2]
metric = MutualInformationMetric(nbins, sampling_prop)
reg_obj = AffineRegistration(metric=metric, level_iters=level_iters,
sigmas=sigmas, factors=factors, verbosity=0)
Expand Down
18 changes: 18 additions & 0 deletions scilpy/io/fetcher.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-

import inspect
import logging
import hashlib
import os
Expand Down Expand Up @@ -131,3 +132,20 @@ def fetch_data(files_dict, keys=None):
else:
# toDo. Verify that data on disk is the right one.
logging.warning("Not fetching data; already on disk.")


def get_synb0_template_path():
"""
Return MNI 2.5mm template in scilpy repository
Returns
-------
path: str
Template path
"""
import scilpy # ToDo. Is this the only way?
module_path = inspect.getfile(scilpy)
module_path = os.path.dirname(os.path.dirname(module_path))

path = os.path.join(module_path, 'data/',
'mni_icbm152_t1_tal_nlin_asym_09c_masked_2_5.nii.gz')
return path
12 changes: 10 additions & 2 deletions scilpy/preprocessing/distortion_correction.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
# -*- coding: utf-8 -*-

import logging

import numpy as np


def create_acqparams(readout, encoding_direction, nb_b0s=1, nb_rev_b0s=1):
def create_acqparams(readout, encoding_direction, synb0=False,
nb_b0s=1, nb_rev_b0s=1):
"""
Create acqparams for Topup and Eddy

Expand All @@ -23,13 +26,18 @@ def create_acqparams(readout, encoding_direction, nb_b0s=1, nb_rev_b0s=1):
acqparams: np.array
acqparams
"""
if synb0:
logging.warning('Using SyNb0, untested feature. Be careful.')

acqparams = np.zeros((nb_b0s + nb_rev_b0s, 4))
acqparams[:, 3] = readout

enum_direction = {'x': 0, 'y': 1, 'z': 2}
acqparams[0:nb_b0s, enum_direction[encoding_direction]] = 1
if nb_rev_b0s > 0:
acqparams[nb_b0s:, enum_direction[encoding_direction]] = -1
val = -1 if not synb0 else 1
acqparams[nb_b0s:, enum_direction[encoding_direction]] = val
acqparams[nb_b0s:, 3] = readout if not synb0 else 0

return acqparams

Expand Down
5 changes: 2 additions & 3 deletions scilpy/tractanalysis/afd_along_streamlines.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ def afd_map_along_streamlines(sft, fodf, fodf_basis, length_weighting):
rdAFD map (weighted if length_weighting)
"""


afd_sum, rd_sum, weights = \
afd_and_rd_sums_along_streamlines(sft, fodf, fodf_basis,
length_weighting)
Expand Down Expand Up @@ -112,7 +111,8 @@ def afd_and_rd_sums_along_streamlines(sft, fodf, fodf_basis,

normalization_weights = np.ones_like(seg_lengths)
if length_weighting:
normalization_weights = seg_lengths / np.linalg.norm(fodf.header.get_zooms()[:3])
normalization_weights = seg_lengths / \
np.linalg.norm(fodf.header.get_zooms()[:3])

for vox_idx, closest_vertex_index, norm_weight in zip(vox_indices,
closest_vertex_indices,
Expand All @@ -130,5 +130,4 @@ def afd_and_rd_sums_along_streamlines(sft, fodf, fodf_basis,
weight_map[vox_idx] += norm_weight

rd_sum_map[rd_sum_map < 0.] = 0.

return afd_sum_map, rd_sum_map, weight_map
7 changes: 5 additions & 2 deletions scripts/scil_dwi_prepare_topup_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ def _build_arg_parser():
p.add_argument('--config', default='b02b0.cnf',
help='Topup config file [%(default)s].')

p.add_argument('--synb0', action='store_true',
help='If set, will use SyNb0 custom acqparams file.')

p.add_argument('--encoding_direction', default='y',
choices=['x', 'y', 'z'],
help='Acquisition direction of the forward b0 '
Expand Down Expand Up @@ -119,8 +122,8 @@ def main():
fused_b0s_path = os.path.join(args.out_directory, args.out_b0s)
nib.save(nib.Nifti1Image(fused_b0s, b0_img.affine), fused_b0s_path)

acqparams = create_acqparams(
args.readout, args.encoding_direction, b0.shape[-1], rev_b0.shape[-1])
acqparams = create_acqparams(args.readout, args.encoding_direction,
args.synb0, b0.shape[-1], rev_b0.shape[-1])

if not os.path.exists(args.out_directory):
os.makedirs(args.out_directory)
Expand Down
142 changes: 142 additions & 0 deletions scripts/scil_volume_b0_synthesis.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
Wrapper for SyNb0 available in Dipy, to run it on a single subject.
Requires Skull-Strip b0 and t1w images as input, the script will normalize the
t1w's WM to 110, co-register both images, then register it to the appropriate
template, run SyNb0 and then transform the result back to the original space.

This script must be used carefully, as it is not meant to be used in an
environment with the following dependencies already installed (not default
in Scilpy):
- tensorflow-addons
- tensorrt
- tensorflow
"""


import argparse
import logging
import os
import sys
import warnings

# Disable tensorflow warnings
with warnings.catch_warnings():
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
warnings.simplefilter("ignore")
from dipy.nn.synb0 import Synb0

from dipy.align.imaffine import AffineMap
from dipy.segment.tissue import TissueClassifierHMRF
import nibabel as nib
import numpy as np
from scipy.ndimage import gaussian_filter

from scilpy.io.fetcher import get_synb0_template_path
from scilpy.io.utils import (add_overwrite_arg,
add_verbose_arg,
assert_inputs_exist,
assert_outputs_exist)
from scilpy.image.volume_operations import register_image


def _build_arg_parser():
p = argparse.ArgumentParser(
description=__doc__,
formatter_class=argparse.RawTextHelpFormatter)
p.add_argument('in_b0',
help='Input b0 image.')
p.add_argument('in_b0_mask',
help='Input b0 mask.')
p.add_argument('in_t1',
help='Input t1w image.')
p.add_argument('in_t1_mask',
help='Input t1w mask.')
p.add_argument('out_b0',
help='Output b0 image without distortion.')

add_verbose_arg(p)
add_overwrite_arg(p)

return p


def main():
parser = _build_arg_parser()
args = parser.parse_args()
assert_inputs_exist(parser, [args.in_b0, args.in_t1])
assert_outputs_exist(parser, args, args.out_b0)

logging.getLogger().setLevel(logging.getLevelName(args.verbose))
logging.info('The usage of synthetic b0 is not fully tested.'
'Be careful when using it.')

template_img = nib.load(get_synb0_template_path())
template_data = template_img.get_fdata()

b0_img = nib.load(args.in_b0)
b0_skull_data = b0_img.get_fdata()
b0_mask_img = nib.load(args.in_b0_mask)
b0_mask_data = b0_mask_img.get_fdata()

t1_img = nib.load(args.in_t1)
t1_skull_data = t1_img.get_fdata()
t1_mask_img = nib.load(args.in_t1_mask)
t1_mask_data = t1_mask_img.get_fdata()

b0_bet_data = np.zeros(b0_skull_data.shape)
b0_bet_data[b0_mask_data > 0] = b0_skull_data[b0_mask_data > 0]
t1_bet_data = np.zeros(t1_skull_data.shape)
t1_bet_data[t1_mask_data > 0] = t1_skull_data[t1_mask_data > 0]

# Crude estimation of the WM mean intensity and normalization
logging.info('Estimating WM mean intensity')
hmrf = TissueClassifierHMRF()
t1_bet_data = gaussian_filter(t1_bet_data, 2)
_, final_segmentation, _ = hmrf.classify(t1_bet_data, 3, 0.25,
tolerance=1e-4, max_iter=5)
avg_wm = np.mean(t1_skull_data[final_segmentation == 3])
t1_skull_data /= avg_wm
t1_skull_data *= 110

# SyNB0 works only in a standard space, so we need to register the images
logging.info('Registering images')
# Use the BET image for registration
t1_bet_to_b0, t1_bet_to_b0_transform = register_image(b0_bet_data,
b0_img.affine,
t1_bet_data,
t1_img.affine,
fine=True)
affine_map = AffineMap(t1_bet_to_b0_transform,
b0_skull_data.shape, b0_img.affine,
t1_skull_data.shape, t1_img.affine)
t1_skull_to_b0 = affine_map.transform(t1_skull_data.astype(np.float64))

# Then register to MNI (using the BET again)
_, t1_bet_to_b0_to_mni_transform = register_image(template_data,
template_img.affine,
t1_bet_to_b0,
b0_img.affine,
fine=True)
affine_map = AffineMap(t1_bet_to_b0_to_mni_transform,
template_data.shape, template_img.affine,
b0_skull_data.shape, b0_img.affine)

# But for prediction, we want the skull
b0_skull_to_mni = affine_map.transform(b0_skull_data.astype(np.float64))
t1_skull_to_mni = affine_map.transform(t1_skull_to_b0.astype(np.float64))

logging.info('Running SyN-B0')
SyNb0 = Synb0(args.verbose)
rev_b0 = SyNb0.predict(b0_skull_to_mni, t1_skull_to_mni)
rev_b0 = affine_map.transform_inverse(rev_b0.astype(np.float64))

dtype = b0_img.get_data_dtype()
nib.save(nib.Nifti1Image(rev_b0.astype(dtype), b0_img.affine),
args.out_b0)


if __name__ == "__main__":
main()
47 changes: 47 additions & 0 deletions scripts/tests/test_volume_b0_synthesis.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

from scilpy.io.fetcher import fetch_data, get_home, get_testing_files_dict
import os
import tempfile

import pytest
import nibabel as nib
import numpy as np
tensorflow = pytest.importorskip("tensorflow")


# If they already exist, this only takes 5 seconds (check md5sum)
fetch_data(get_testing_files_dict(), keys=['others.zip', 'processing.zip'])
tmp_dir = tempfile.TemporaryDirectory()


def test_help_option(script_runner):
ret = script_runner.run('scil_volume_b0_synthesis.py', '--help')
assert ret.success


@pytest.mark.skipif(tensorflow is None, reason="Tensorflow not installed")
def test_synthesis(script_runner):
os.chdir(os.path.expanduser(tmp_dir.name))
in_t1 = os.path.join(get_home(), 'others',
't1.nii.gz')
in_b0 = os.path.join(get_home(), 'processing',
'b0_mean.nii.gz')

t1_img = nib.load(in_t1)
b0_img = nib.load(in_b0)
t1_data = t1_img.get_fdata()
b0_data = b0_img.get_fdata()
t1_data[t1_data > 0] = 1
b0_data[b0_data > 0] = 1
nib.save(nib.Nifti1Image(t1_data.astype(np.uint8), t1_img.affine),
't1_mask.nii.gz')
nib.save(nib.Nifti1Image(b0_data.astype(np.uint8), b0_img.affine),
'b0_mask.nii.gz')

ret = script_runner.run('scil_volume_b0_synthesis.py',
in_t1, 't1_mask.nii.gz',
in_b0, 'b0_mask.nii.gz',
'b0_synthesized.nii.gz', '-v')
assert ret.success