Skip to content

Commit

Permalink
Merge pull request #1026 from frheault/nawm_generation
Browse files Browse the repository at this point in the history
Nawm generation based on distance map
  • Loading branch information
frheault authored Oct 11, 2024
2 parents f690978 + 2ce4f62 commit a07a8ec
Show file tree
Hide file tree
Showing 7 changed files with 265 additions and 3 deletions.
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ docopt==0.6.*
dvc==3.48.*
dvc-http==2.32.*
formulaic==0.3.*
fury==0.10.*
fury==0.11.*
future==0.18.*
GitPython==3.1.*
h5py==3.10.*
Expand Down
31 changes: 31 additions & 0 deletions scilpy/image/tests/test_volume_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
compute_distance_map, compute_snr,
crop_volume, flip_volume,
mask_data_with_default_cube,
compute_distance_map,
compute_nawm,
merge_metrics, normalize_metric,
resample_volume, reshape_volume,
register_image)
Expand Down Expand Up @@ -360,3 +362,32 @@ def test_compute_distance_map_wrong_shape():
assert False
except ValueError:
assert True


def test_compute_nawm_3D():
lesion_img = np.zeros((3, 3, 3))
lesion_img[1, 1, 1] = 1

nawm = compute_nawm(lesion_img, nb_ring=0, ring_thickness=2)
assert np.sum(nawm) == 1

try:
nawm = compute_nawm(lesion_img, nb_ring=2, ring_thickness=0)
assert False
except ValueError:
assert True

nawm = compute_nawm(lesion_img, nb_ring=1, ring_thickness=2)
assert np.sum(nawm) == 53


def test_compute_nawm_4D():
lesion_img = np.zeros((10, 10, 10))
lesion_img[4, 4, 4] = 1
lesion_img[2, 2, 2] = 2

nawm = compute_nawm(lesion_img, nb_ring=2, ring_thickness=1)
assert nawm.shape == (10, 10, 10, 2)
val, count = np.unique(nawm[..., 0], return_counts=True)
assert np.array_equal(val, [0, 1, 2, 3])
assert np.array_equal(count, [967, 1, 6, 26])
71 changes: 71 additions & 0 deletions scilpy/image/volume_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -788,6 +788,8 @@ def compute_distance_map(mask_1, mask_2, symmetric=False,
If True, compute the symmetric distance map. Default is np.inf
max_distance: float, optional
Maximum distance to consider for kdtree exploration. Default is None.
If you put any value, coordinates further than this value will be
considered as np.inf.
Returns
-------
Expand All @@ -811,3 +813,72 @@ def compute_distance_map(mask_1, mask_2, symmetric=False,
distance_map[np.where(mask_2)] = distance

return distance_map


def compute_nawm(lesion_atlas, nb_ring, ring_thickness, mask=None):
"""
Compute the NAWM (Normal Appearing White Matter) from a lesion map.
The rings go from 2 to nb_ring + 2, with the lesion being 1.
The optional mask is used to compute the rings only in the mask
region. This can be useful to avoid useless computation.
If the lesion_atlas is binary, the output will be 3D. If the lesion_atlas
is a label map, the output will be 4D, with each label having its own NAWM.
Parameters
----------
lesion_atlas: np.ndarray
Lesion map. Can be binary or label map.
nb_ring: int
Number of rings to compute.
ring_thickness: int
Thickness of the rings.
mask: np.ndarray, optional
Mask where to compute the NAWM. Default is None.
Returns
-------
nawm: np.ndarray
NAWM volume(s), 3D if binary lesion map, 4D if label lesion map.
"""
if ring_thickness < 1:
raise ValueError("Ring thickness must be at least 1.")

if np.unique(lesion_atlas).size == 1:
raise ValueError('Input lesion map is empty.')
is_binary = True if np.unique(lesion_atlas).size == 2 else False
labels = np.unique(lesion_atlas)[1:]
nawm = np.zeros(lesion_atlas.shape + (len(labels),), dtype=float)

if mask is None:
mask = np.ones(lesion_atlas.shape, dtype=np.uint8)

max_distance = (nb_ring * ring_thickness) + 1

for i, label in enumerate(labels):
curr_mask = np.zeros(lesion_atlas.shape, dtype=np.uint8)
curr_mask[lesion_atlas == label] = 1
curr_dist_map = compute_distance_map(mask,
curr_mask,
max_distance=max_distance)
curr_dist_map[np.isinf(curr_dist_map)] = 0

# Mask to remember where values were computed
to_increase_mask = np.zeros(lesion_atlas.shape, dtype=np.uint8)
to_increase_mask[curr_dist_map > 0] = 1

# Compute the rings. The lesion should be 1, and the first ring
# should be 2, and the max ring should be nb_ring + 1.
curr_dist_map = np.ceil(curr_dist_map / ring_thickness)
curr_dist_map[to_increase_mask > 0] += 1
curr_dist_map[curr_mask > 0] += 1
curr_dist_map[curr_dist_map > nb_ring + 1] = 0

nawm[..., i] = curr_dist_map

if is_binary:
nawm = np.squeeze(nawm)

return nawm.astype(np.uint16)
125 changes: 125 additions & 0 deletions scripts/scil_lesions_generate_nawm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
The NAWM (Normal Appearing White Matter) is the white matter that is
neighboring a lesion. It is used to compute metrics in the white matter
surrounding lesions.
This script will generate concentric rings around the lesions, with the rings
going from 2 to nb_ring + 2, with the lesion being 1.
The optional mask is used to compute the rings only in the mask
region. This can be useful to avoid useless computation.
If the lesion_atlas is binary, the output will be 3D. If the lesion_atlas
is a label map, the output will be either:
- 4D, with each label having its own NAWM.
- 3D, if using --split_4D and saved into a folder as multiple 3D files.
WARNING: Voxels must be isotropic.
"""

import argparse
import logging
import os

import nibabel as nib
import numpy as np

from scilpy.image.labels import get_data_as_labels
from scilpy.image.volume_operations import compute_nawm
from scilpy.io.image import get_data_as_mask
from scilpy.io.utils import (add_overwrite_arg, assert_inputs_exist,
assert_outputs_exist,
assert_output_dirs_exist_and_empty,
add_verbose_arg)
from scilpy.utils.filenames import split_name_with_nii


def _build_arg_parser():
p = argparse.ArgumentParser(description=__doc__,
formatter_class=argparse.RawTextHelpFormatter)

p.add_argument('in_image',
help='Lesions file as mask OR labels (.nii.gz).\n'
'(must be uint8 for mask, uint16 for labels).')
p.add_argument('out_image',
help='Output NAWM file (.nii.gz).\n'
'If using --split_4D, this will be the prefix of the '
'output files.')

p.add_argument('--nb_ring', type=int, default=3,
help='Integer representing the number of rings to be '
'created.')
p.add_argument('--ring_thickness', type=int, default=2,
help='Integer representing the thickness (in voxels) of '
'the rings to be created.')
p.add_argument('--mask',
help='Mask where to compute the NAWM (e.g WM mask).')
p.add_argument('--split_4D', metavar='OUT_DIR',
help='Provided lesions will be split into multiple files.\n'
'The output files will be named using out_image as '
'a prefix.')

add_verbose_arg(p)
add_overwrite_arg(p)

return p


def main():
parser = _build_arg_parser()
args = parser.parse_args()
logging.getLogger().setLevel(logging.getLevelName(args.verbose))

if args.nb_ring < 1:
parser.error('The number of rings must be at least 1.')
if args.ring_thickness < 1:
parser.error('The ring thickness must be at least 1.')

assert_inputs_exist(parser, args.in_image, args.mask)
if not args.split_4D:
assert_outputs_exist(parser, args, args.out_image)

lesion_img = nib.load(args.in_image)
lesion_atlas = get_data_as_labels(lesion_img)
voxel_size = lesion_img.header.get_zooms()

if not np.allclose(voxel_size, np.mean(voxel_size)):
raise ValueError('Voxels must be isotropic.')

if args.split_4D and np.unique(lesion_atlas).size <= 2:
raise ValueError('Split only works with multiple lesion labels')
elif args.split_4D:
assert_output_dirs_exist_and_empty(parser, args, args.split_4D)

if not args.split_4D and np.unique(lesion_atlas).size > 2:
logging.warning('The input lesion atlas has multiple labels. '
'Converting to binary.')
lesion_atlas[lesion_atlas > 0] = 1

if args.mask:
mask_img = nib.load(args.mask)
mask_data = get_data_as_mask(mask_img)
else:
mask_data = None

nawm = compute_nawm(lesion_atlas, args.nb_ring, args.ring_thickness,
mask=mask_data)

if args.split_4D:
for i in range(nawm.shape[-1]):
label = np.unique(lesion_atlas)[i+1]
base, ext = split_name_with_nii(args.in_image)
base = os.path.basename(base)
lesion_name = os.path.join(args.split_4D,
f'{base}_nawm_{label}{ext}')
nib.save(nib.Nifti1Image(nawm[..., i], lesion_img.affine),
lesion_name)
else:
nib.save(nib.Nifti1Image(nawm, lesion_img.affine), args.out_image)


if __name__ == "__main__":
main()
2 changes: 1 addition & 1 deletion scripts/tests/test_labels_from_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from scilpy.io.fetcher import fetch_data, get_testing_files_dict

# If they already exist, this only takes 5 seconds (check md5sum)
fetch_data(get_testing_files_dict(), keys=['atlas.zip'])
fetch_data(get_testing_files_dict(), keys=['tractograms.zip'])
tmp_dir = tempfile.TemporaryDirectory()


Expand Down
29 changes: 29 additions & 0 deletions scripts/tests/test_lesions_generate_nawm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-


import os
import tempfile

from scilpy import SCILPY_HOME
from scilpy.io.fetcher import fetch_data, get_testing_files_dict


# If they already exist, this only takes 5 seconds (check md5sum)
fetch_data(get_testing_files_dict(), keys=['atlas.zip'])
tmp_dir = tempfile.TemporaryDirectory()


def test_help_option(script_runner):
ret = script_runner.run('scil_lesions_generate_nawm.py', '--help')
assert ret.success


def test_execution_atlas(script_runner, monkeypatch):
monkeypatch.chdir(os.path.expanduser(tmp_dir.name))
in_atlas = os.path.join(SCILPY_HOME, 'atlas',
'atlas_freesurfer_v2_single_brainstem.nii.gz')
ret = script_runner.run('scil_lesions_generate_nawm.py', in_atlas,
'nawm.nii.gz', '--nb_ring', '3',
'--ring_thickness', '2')
assert ret.success
8 changes: 7 additions & 1 deletion scripts/tests/test_volume_stats_in_labels.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,22 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import os
import tempfile

from scilpy import SCILPY_HOME
from scilpy.io.fetcher import fetch_data, get_testing_files_dict

fetch_data(get_testing_files_dict(), keys=['plot.zip'])
tmp_dir = tempfile.TemporaryDirectory()


def test_help_option(script_runner):
ret = script_runner.run('scil_volume_stats_in_labels.py', '--help')
assert ret.success


def test_execution(script_runner):
def test_execution(script_runner, monkeypatch):
monkeypatch.chdir(os.path.expanduser(tmp_dir.name))
in_map = os.path.join(SCILPY_HOME, 'plot', 'fa.nii.gz')
in_atlas = os.path.join(SCILPY_HOME, 'plot', 'atlas_brainnetome.nii.gz')
atlas_lut = os.path.join(SCILPY_HOME, 'plot', 'atlas_brainnetome.json')
Expand Down

0 comments on commit a07a8ec

Please sign in to comment.