Skip to content

Commit

Permalink
fix conflict in requirements
Browse files Browse the repository at this point in the history
  • Loading branch information
bouj1113 authored and bouj1113 committed Aug 5, 2024
2 parents 2dfa03e + 175eb67 commit 6a4afaf
Show file tree
Hide file tree
Showing 89 changed files with 3,867 additions and 1,451 deletions.
3 changes: 2 additions & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ jobs:
- name: Install Scilpy
run: |
python -m pip install --upgrade pip setuptools wheel
python -m pip install --upgrade pip wheel
python -m pip install --upgrade "setuptools<71.0.0"
python -m pip install -e .
- name: Run tests
Expand Down
16 changes: 9 additions & 7 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,26 +4,28 @@ bz2file==0.98.*
coloredlogs==15.0.*
cvxpy==1.4.*
cycler==0.11.*
Cython==0.29.*, !=0.29.29
#Cython==0.29.*, !=0.29.29
Cython==3.0.*
dipy==1.9.*
deepdiff==6.3.0
dmri-amico==2.0.*
dmri-commit==2.1.*
dmri-amico==2.0.3
dmri-commit==2.3.0
docopt==0.6.*
dvc==3.48.*
dvc-http==2.32.*
formulaic==0.3.*
fury==0.10.*
future==0.18.*
GitPython==3.1.*
h5py==3.7.*
h5py==3.10.*
joblib==1.2.*
kiwisolver==1.4.*
matplotlib==3.6.*
PyMCubes==0.1.*
nibabel==5.2.*
nilearn==0.9.*
nltk==3.8.*
numpy==1.23.*
numpy==1.25.*
openpyxl==3.0.*
packaging == 23.2.*
Pillow==10.2.*
Expand All @@ -40,9 +42,9 @@ pytz==2022.6.*
requests==2.28.*
scikit-learn==1.2.*
scikit-image==0.22.*
scipy==1.9.*
scipy==1.11.*
six==1.16.*
spams==2.6.*
statsmodels==0.13.*
trimeshpy==0.0.3
trimeshpy==0.0.4
vtk==9.2.*
98 changes: 98 additions & 0 deletions scilpy/image/labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import os

import numpy as np
from scipy import ndimage as ndi
from scipy.spatial import cKDTree


Expand Down Expand Up @@ -67,6 +68,67 @@ def get_binary_mask_from_labels(atlas, label_list):
return mask


def get_labels_from_mask(mask_data, labels=None, background_label=0):
"""
Get labels from a binary mask which contains multiple blobs. Each blob
will be assigned a label, by default starting from 1. Background will
be assigned the background_label value.
Parameters
----------
mask_data: np.ndarray
The mask data.
labels: list, optional
Labels to assign to each blobs in the mask. Excludes the background
label.
background_label: int
Label for the background.
Returns
-------
label_map: np.ndarray
The labels.
"""
# Get the number of structures and assign labels to each blob
label_map, nb_structures = ndi.label(mask_data)
# Assign labels to each blob if provided
if labels:
# Only keep the first nb_structures labels if the number of labels
# provided is greater than the number of blobs in the mask.
if len(labels) > nb_structures:
logging.warning("Number of labels ({}) does not match the number "
"of blobs in the mask ({}). Only the first {} "
"labels will be used.".format(
len(labels), nb_structures, nb_structures))
# Cannot assign fewer labels than the number of blobs in the mask.
elif len(labels) < nb_structures:
raise ValueError("Number of labels ({}) is less than the number of"
" blobs in the mask ({}).".format(
len(labels), nb_structures))

# Copy the label map to avoid scenarios where the label list contains
# labels that are already present in the label map
custom_label_map = label_map.copy()
# Assign labels to each blob
for idx, label in enumerate(labels[:nb_structures]):
custom_label_map[label_map == idx + 1] = label
label_map = custom_label_map

logging.info('Assigned labels {} to the mask.'.format(
np.unique(label_map[label_map != background_label])))

if background_label != 0 and background_label in label_map:
logging.warning("Background label {} corresponds to a label "
"already in the map. This will cause issues.".format(
background_label))

# Assign background label
if background_label:
label_map[label_map == 0] = background_label

return label_map


def get_lut_dir():
"""
Return LUT directory in scilpy repository
Expand Down Expand Up @@ -379,3 +441,39 @@ def get_stats_in_label(map_data, label_data, label_lut):
'mean': float(mean_seed),
'std': float(std_seed)}
return out_dict


def merge_labels_into_mask(atlas, filtering_args):
"""
Merge labels into a mask.
Parameters
----------
atlas: np.ndarray
Atlas with labels as a numpy array (uint16) to merge.
filtering_args: str
Filtering arguments from the command line.
Return
------
mask: nibabel.nifti1.Nifti1Image
Mask obtained from the combination of multiple labels.
"""
mask = np.zeros(atlas.shape, dtype=np.uint16)

if ' ' in filtering_args:
values = filtering_args.split(' ')
for filter_opt in values:
if ':' in filter_opt:
vals = [int(x) for x in filter_opt.split(':')]
mask[(atlas >= int(min(vals))) & (atlas <= int(max(vals)))] = 1
else:
mask[atlas == int(filter_opt)] = 1
elif ':' in filtering_args:
values = [int(x) for x in filtering_args.split(':')]
mask[(atlas >= int(min(values))) & (atlas <= int(max(values)))] = 1
else:
mask[atlas == int(filtering_args)] = 1

return mask
52 changes: 50 additions & 2 deletions scilpy/image/tests/test_labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
import pytest

from scilpy.image.labels import (combine_labels, dilate_labels,
get_data_as_labels, get_lut_dir,
remove_labels, split_labels)
get_data_as_labels, get_labels_from_mask,
get_lut_dir, remove_labels, split_labels)
from scilpy.tests.arrays import ref_in_labels, ref_out_labels


Expand Down Expand Up @@ -132,6 +132,54 @@ def test_get_data_as_labels_float():
_ = get_data_as_labels(img)


def test_get_labels_from_mask():
""" Test get_labels_from_mask with default labels. """
# ref_out_labels contains disjoint blobs with values 2,4,6
data = deepcopy(ref_out_labels)
data[data == 2] = 1
data[data == 4] = 2
data[data == 6] = 3
mask = data.astype(bool)

labels = get_labels_from_mask(mask)

assert_equal(labels, data)


def test_get_labels_from_mask_custom_labels_raises():
""" Test get_labels_from_mask with custom labels. """
# ref_out_labels contains disjoint blobs with values 2,4,6
data = deepcopy(ref_out_labels)
mask = data.astype(bool)
labels = get_labels_from_mask(mask, [2, 4, 6, 8])

assert np.unique(labels).size == 4 # including background


def test_get_labels_from_mask_custom_labels():
""" Test get_labels_from_mask with custom labels. """
# ref_out_labels contains disjoint blobs with values 2,4,6
data = deepcopy(ref_out_labels)
mask = data.astype(bool)

labels = get_labels_from_mask(mask, [2, 4, 6])

assert_equal(labels, data)


def test_get_labels_from_mask_custom_background():
""" test get_labels_from_mask with custom background. """
# ref_out_labels contains disjoint blobs with values 2,4,6
data = deepcopy(ref_out_labels)
mask = data.copy().astype(bool)

data[data == 0] = 9

labels = get_labels_from_mask(mask, [2, 4, 6], background_label=9)

assert_equal(labels, data)


def test_get_lut_dir():
lut_dir = get_lut_dir()
assert os.path.isdir(lut_dir)
Expand Down
5 changes: 2 additions & 3 deletions scilpy/io/gradients.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def fsl2mrtrix(fsl_bval_filename, fsl_bvec_filename, mrtrix_filename):
logging.warning('WARNING: Your bvecs seem transposed. ' +
'Transposing them.')

shell_idx = [int(np.where(bval == bvals)[0]) for bval in shells]
shell_idx = [int(np.where(bval == bvals)[0][0]) for bval in shells]
save_gradient_sampling_mrtrix(points, shell_idx, bvals,
mrtrix_filename + '.b')

Expand Down Expand Up @@ -62,8 +62,7 @@ def mrtrix2fsl(mrtrix_filename, fsl_filename):
shells = np.array(mrtrix_b[:, 3])

bvals = np.unique(shells).tolist()
shell_idx = [int(np.where(bval == bvals)[0]) for bval in shells]

shell_idx = [int(np.where(bval == bvals)[0][0]) for bval in shells]
save_gradient_sampling_fsl(points, shell_idx, bvals,
filename_bval=fsl_filename + '.bval',
filename_bvec=fsl_filename + '.bvec')
Expand Down
36 changes: 0 additions & 36 deletions scilpy/io/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,42 +32,6 @@ def load_img(arg):
return img, dtype


def merge_labels_into_mask(atlas, filtering_args):
"""
Merge labels into a mask.
Parameters
----------
atlas: np.ndarray
Atlas with labels as a numpy array (uint16) to merge.
filtering_args: str
Filtering arguments from the command line.
Return
------
mask: nibabel.nifti1.Nifti1Image
Mask obtained from the combination of multiple labels.
"""
mask = np.zeros(atlas.shape, dtype=np.uint16)

if ' ' in filtering_args:
values = filtering_args.split(' ')
for filter_opt in values:
if ':' in filter_opt:
vals = [int(x) for x in filter_opt.split(':')]
mask[(atlas >= int(min(vals))) & (atlas <= int(max(vals)))] = 1
else:
mask[atlas == int(filter_opt)] = 1
elif ':' in filtering_args:
values = [int(x) for x in filtering_args.split(':')]
mask[(atlas >= int(min(values))) & (atlas <= int(max(values)))] = 1
else:
mask[atlas == int(filtering_args)] = 1

return mask


def assert_same_resolution(images):
"""
Check the resolution of multiple images.
Expand Down
52 changes: 45 additions & 7 deletions scilpy/io/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,13 @@ def link_bundles_and_reference(parser, args, input_tractogram_list):
return bundles_references_tuple


def check_tract_trk(parser, filename):

_, ext = os.path.splitext(filename)
if ext != '.trk':
parser.error('File {} is not a .trk file.'.format(filename))


def check_tracts_same_format(parser, filename_list):
_, ref_ext = os.path.splitext(filename_list[0])

Expand Down Expand Up @@ -667,7 +674,7 @@ def add_compression_arg(p, additional_msg=''):
compress arg.
"""
p.add_argument('--compress', dest='compress_th', nargs='?', const=0.1,
type=ranged_type(float, 0, None),
type=ranged_type(float, 0, None, min_excluded=True),
help='If set, compress the resulting streamline. Value is '
'the maximum \ncompression distance in mm.'
+ additional_msg + '[%(const)s]')
Expand Down Expand Up @@ -1049,9 +1056,12 @@ def parser_color_type(arg):
return f


def ranged_type(value_type, min_value=None, max_value=None):
def ranged_type(value_type, min_value=None, max_value=None,
min_excluded=False, max_excluded=False):
"""Return a function handle of an argument type function for ArgumentParser
checking a range: `min_value` <= arg <= `max_value`.
With min_excluded and max_excluded, the verification becomes
`min_value` < arg < `max_value`.
Parameters
----------
Expand All @@ -1061,6 +1071,10 @@ def ranged_type(value_type, min_value=None, max_value=None):
Minimum acceptable argument value.
max_value : scalar
Maximum acceptable argument value.
min_excluded: bool
If true, the accepted range is ]min_value, max_value].
max_excluded: bool
If true, the accepted range is [min_value, max_value[.
Returns
-------
Expand All @@ -1076,16 +1090,40 @@ def range_checker(arg: str):
f = value_type(arg)
except ValueError:
raise argparse.ArgumentTypeError(f"must be a valid {value_type}")

smaller = np.less
bigger = np.greater
if min_excluded:
smaller = np.less_equal
if max_excluded:
bigger = np.greater_equal

if min_value is not None and max_value is not None:
if f < min_value or f > max_value:
raise argparse.ArgumentTypeError(
f"must be within [{min_value}, {max_value}]")
if smaller(f, min_value) or bigger(f, max_value):
if min_excluded and max_excluded:
raise argparse.ArgumentTypeError(
f"must be within ]{min_value}, {max_value}[")
elif min_excluded:
raise argparse.ArgumentTypeError(
f"must be within ]{min_value}, {max_value}")
elif max_excluded:
raise argparse.ArgumentTypeError(
f"must be within [{min_value}, {max_value}[")
else:
raise argparse.ArgumentTypeError(
f"must be within [{min_value}, {max_value}]")
elif min_value is not None:
if f < min_value:
raise argparse.ArgumentTypeError(f"must be >= {min_value}")
if min_excluded:
raise argparse.ArgumentTypeError(f"must be > {min_value}")
else:
raise argparse.ArgumentTypeError(f"must be >= {min_value}")
elif max_value is not None:
if f > max_value:
raise argparse.ArgumentTypeError(f"must be <= {max_value}")
if max_excluded:
raise argparse.ArgumentTypeError(f"must be < {max_value}")
else:
raise argparse.ArgumentTypeError(f"must be <= {max_value}")
return f

# Return handle to checking function
Expand Down
Loading

0 comments on commit 6a4afaf

Please sign in to comment.