Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Nawm generation based on distance map #1026

Merged
merged 14 commits into from
Oct 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ docopt==0.6.*
dvc==3.48.*
dvc-http==2.32.*
formulaic==0.3.*
fury==0.10.*
fury==0.11.*
future==0.18.*
GitPython==3.1.*
h5py==3.10.*
Expand Down
31 changes: 31 additions & 0 deletions scilpy/image/tests/test_volume_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
compute_distance_map, compute_snr,
crop_volume, flip_volume,
mask_data_with_default_cube,
compute_distance_map,
compute_nawm,
merge_metrics, normalize_metric,
resample_volume, reshape_volume,
register_image)
Expand Down Expand Up @@ -360,3 +362,32 @@ def test_compute_distance_map_wrong_shape():
assert False
except ValueError:
assert True


def test_compute_nawm_3D():
lesion_img = np.zeros((3, 3, 3))
lesion_img[1, 1, 1] = 1

nawm = compute_nawm(lesion_img, nb_ring=0, ring_thickness=2)
assert np.sum(nawm) == 1

try:
nawm = compute_nawm(lesion_img, nb_ring=2, ring_thickness=0)
assert False
except ValueError:
assert True

nawm = compute_nawm(lesion_img, nb_ring=1, ring_thickness=2)
assert np.sum(nawm) == 53


def test_compute_nawm_4D():
lesion_img = np.zeros((10, 10, 10))
lesion_img[4, 4, 4] = 1
lesion_img[2, 2, 2] = 2

nawm = compute_nawm(lesion_img, nb_ring=2, ring_thickness=1)
assert nawm.shape == (10, 10, 10, 2)
val, count = np.unique(nawm[..., 0], return_counts=True)
assert np.array_equal(val, [0, 1, 2, 3])
assert np.array_equal(count, [967, 1, 6, 26])
71 changes: 71 additions & 0 deletions scilpy/image/volume_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -788,6 +788,8 @@ def compute_distance_map(mask_1, mask_2, symmetric=False,
If True, compute the symmetric distance map. Default is np.inf
max_distance: float, optional
Maximum distance to consider for kdtree exploration. Default is None.
If you put any value, coordinates further than this value will be
considered as np.inf.

Returns
-------
Expand All @@ -811,3 +813,72 @@ def compute_distance_map(mask_1, mask_2, symmetric=False,
distance_map[np.where(mask_2)] = distance

return distance_map


def compute_nawm(lesion_atlas, nb_ring, ring_thickness, mask=None):
"""
Compute the NAWM (Normal Appearing White Matter) from a lesion map.
The rings go from 2 to nb_ring + 2, with the lesion being 1.

The optional mask is used to compute the rings only in the mask
region. This can be useful to avoid useless computation.

If the lesion_atlas is binary, the output will be 3D. If the lesion_atlas
is a label map, the output will be 4D, with each label having its own NAWM.

Parameters
----------
lesion_atlas: np.ndarray
Lesion map. Can be binary or label map.
nb_ring: int
Number of rings to compute.
ring_thickness: int
Thickness of the rings.
mask: np.ndarray, optional
Mask where to compute the NAWM. Default is None.

Returns
-------
nawm: np.ndarray
NAWM volume(s), 3D if binary lesion map, 4D if label lesion map.

"""
if ring_thickness < 1:
raise ValueError("Ring thickness must be at least 1.")

if np.unique(lesion_atlas).size == 1:
raise ValueError('Input lesion map is empty.')
is_binary = True if np.unique(lesion_atlas).size == 2 else False
labels = np.unique(lesion_atlas)[1:]
nawm = np.zeros(lesion_atlas.shape + (len(labels),), dtype=float)

if mask is None:
mask = np.ones(lesion_atlas.shape, dtype=np.uint8)

max_distance = (nb_ring * ring_thickness) + 1

for i, label in enumerate(labels):
curr_mask = np.zeros(lesion_atlas.shape, dtype=np.uint8)
curr_mask[lesion_atlas == label] = 1
curr_dist_map = compute_distance_map(mask,
curr_mask,
max_distance=max_distance)
curr_dist_map[np.isinf(curr_dist_map)] = 0

# Mask to remember where values were computed
to_increase_mask = np.zeros(lesion_atlas.shape, dtype=np.uint8)
to_increase_mask[curr_dist_map > 0] = 1

# Compute the rings. The lesion should be 1, and the first ring
# should be 2, and the max ring should be nb_ring + 1.
curr_dist_map = np.ceil(curr_dist_map / ring_thickness)
curr_dist_map[to_increase_mask > 0] += 1
curr_dist_map[curr_mask > 0] += 1
curr_dist_map[curr_dist_map > nb_ring + 1] = 0

nawm[..., i] = curr_dist_map

if is_binary:
nawm = np.squeeze(nawm)

return nawm.astype(np.uint16)
125 changes: 125 additions & 0 deletions scripts/scil_lesions_generate_nawm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
The NAWM (Normal Appearing White Matter) is the white matter that is
neighboring a lesion. It is used to compute metrics in the white matter
surrounding lesions.

This script will generate concentric rings around the lesions, with the rings
going from 2 to nb_ring + 2, with the lesion being 1.

The optional mask is used to compute the rings only in the mask
region. This can be useful to avoid useless computation.

If the lesion_atlas is binary, the output will be 3D. If the lesion_atlas
is a label map, the output will be either:
- 4D, with each label having its own NAWM.
- 3D, if using --split_4D and saved into a folder as multiple 3D files.

WARNING: Voxels must be isotropic.
"""

import argparse
import logging
import os

import nibabel as nib
import numpy as np

from scilpy.image.labels import get_data_as_labels
from scilpy.image.volume_operations import compute_nawm
from scilpy.io.image import get_data_as_mask
from scilpy.io.utils import (add_overwrite_arg, assert_inputs_exist,
assert_outputs_exist,
assert_output_dirs_exist_and_empty,
add_verbose_arg)
from scilpy.utils.filenames import split_name_with_nii


def _build_arg_parser():
p = argparse.ArgumentParser(description=__doc__,
formatter_class=argparse.RawTextHelpFormatter)

p.add_argument('in_image',
help='Lesions file as mask OR labels (.nii.gz).\n'
'(must be uint8 for mask, uint16 for labels).')
p.add_argument('out_image',
help='Output NAWM file (.nii.gz).\n'
'If using --split_4D, this will be the prefix of the '
'output files.')

p.add_argument('--nb_ring', type=int, default=3,
help='Integer representing the number of rings to be '
'created.')
p.add_argument('--ring_thickness', type=int, default=2,
frheault marked this conversation as resolved.
Show resolved Hide resolved
help='Integer representing the thickness (in voxels) of '
'the rings to be created.')
p.add_argument('--mask',
help='Mask where to compute the NAWM (e.g WM mask).')
p.add_argument('--split_4D', metavar='OUT_DIR',
help='Provided lesions will be split into multiple files.\n'
'The output files will be named using out_image as '
'a prefix.')

add_verbose_arg(p)
add_overwrite_arg(p)

return p


def main():
parser = _build_arg_parser()
args = parser.parse_args()
logging.getLogger().setLevel(logging.getLevelName(args.verbose))
frheault marked this conversation as resolved.
Show resolved Hide resolved

if args.nb_ring < 1:
parser.error('The number of rings must be at least 1.')
if args.ring_thickness < 1:
parser.error('The ring thickness must be at least 1.')

assert_inputs_exist(parser, args.in_image, args.mask)
if not args.split_4D:
assert_outputs_exist(parser, args, args.out_image)

lesion_img = nib.load(args.in_image)
lesion_atlas = get_data_as_labels(lesion_img)
voxel_size = lesion_img.header.get_zooms()

if not np.allclose(voxel_size, np.mean(voxel_size)):
raise ValueError('Voxels must be isotropic.')

if args.split_4D and np.unique(lesion_atlas).size <= 2:
raise ValueError('Split only works with multiple lesion labels')
elif args.split_4D:
assert_output_dirs_exist_and_empty(parser, args, args.split_4D)

if not args.split_4D and np.unique(lesion_atlas).size > 2:
logging.warning('The input lesion atlas has multiple labels. '
'Converting to binary.')
lesion_atlas[lesion_atlas > 0] = 1

if args.mask:
mask_img = nib.load(args.mask)
mask_data = get_data_as_mask(mask_img)
else:
mask_data = None

nawm = compute_nawm(lesion_atlas, args.nb_ring, args.ring_thickness,
mask=mask_data)

if args.split_4D:
for i in range(nawm.shape[-1]):
label = np.unique(lesion_atlas)[i+1]
base, ext = split_name_with_nii(args.in_image)
base = os.path.basename(base)
lesion_name = os.path.join(args.split_4D,
f'{base}_nawm_{label}{ext}')
nib.save(nib.Nifti1Image(nawm[..., i], lesion_img.affine),
lesion_name)
else:
nib.save(nib.Nifti1Image(nawm, lesion_img.affine), args.out_image)


if __name__ == "__main__":
main()
2 changes: 1 addition & 1 deletion scripts/tests/test_labels_from_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from scilpy.io.fetcher import fetch_data, get_testing_files_dict

# If they already exist, this only takes 5 seconds (check md5sum)
fetch_data(get_testing_files_dict(), keys=['atlas.zip'])
fetch_data(get_testing_files_dict(), keys=['tractograms.zip'])
tmp_dir = tempfile.TemporaryDirectory()


Expand Down
29 changes: 29 additions & 0 deletions scripts/tests/test_lesions_generate_nawm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-


import os
import tempfile

from scilpy import SCILPY_HOME
from scilpy.io.fetcher import fetch_data, get_testing_files_dict


# If they already exist, this only takes 5 seconds (check md5sum)
fetch_data(get_testing_files_dict(), keys=['atlas.zip'])
tmp_dir = tempfile.TemporaryDirectory()


def test_help_option(script_runner):
ret = script_runner.run('scil_lesions_generate_nawm.py', '--help')
assert ret.success


def test_execution_atlas(script_runner, monkeypatch):
monkeypatch.chdir(os.path.expanduser(tmp_dir.name))
in_atlas = os.path.join(SCILPY_HOME, 'atlas',
'atlas_freesurfer_v2_single_brainstem.nii.gz')
ret = script_runner.run('scil_lesions_generate_nawm.py', in_atlas,
'nawm.nii.gz', '--nb_ring', '3',
'--ring_thickness', '2')
assert ret.success
8 changes: 7 additions & 1 deletion scripts/tests/test_volume_stats_in_labels.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,22 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import os
import tempfile

from scilpy import SCILPY_HOME
from scilpy.io.fetcher import fetch_data, get_testing_files_dict

fetch_data(get_testing_files_dict(), keys=['plot.zip'])
tmp_dir = tempfile.TemporaryDirectory()


def test_help_option(script_runner):
ret = script_runner.run('scil_volume_stats_in_labels.py', '--help')
assert ret.success


def test_execution(script_runner):
def test_execution(script_runner, monkeypatch):
monkeypatch.chdir(os.path.expanduser(tmp_dir.name))
in_map = os.path.join(SCILPY_HOME, 'plot', 'fa.nii.gz')
in_atlas = os.path.join(SCILPY_HOME, 'plot', 'atlas_brainnetome.nii.gz')
atlas_lut = os.path.join(SCILPY_HOME, 'plot', 'atlas_brainnetome.json')
Expand Down