diff --git a/scilpy/tests/utils.py b/scilpy/tests/utils.py new file mode 100644 index 000000000..19777aebe --- /dev/null +++ b/scilpy/tests/utils.py @@ -0,0 +1,15 @@ +# -*- coding: utf-8 -*- + +import numpy as np + + +def nan_array_equal(a, b): + a = np.asarray(a) + b = np.asarray(b) + + nan_a = np.argwhere(np.isnan(a)) + nan_b = np.argwhere(np.isnan(a)) + + a = a[~np.isnan(a)] + b = b[~np.isnan(b)] + return np.array_equal(a, b) and np.array_equal(nan_a, nan_b) diff --git a/scilpy/tractanalysis/connectivity_segmentation.py b/scilpy/tractanalysis/connectivity_segmentation.py index 4fbb19718..3c236b958 100644 --- a/scilpy/tractanalysis/connectivity_segmentation.py +++ b/scilpy/tractanalysis/connectivity_segmentation.py @@ -15,7 +15,7 @@ compute_streamline_segment from scilpy.tractograms.streamline_operations import \ (remove_loops as perform_remove_loops, - remove_shap_turns_qb, + remove_sharp_turns_qb, remove_streamlines_with_overlapping_points, filter_streamlines_by_length) @@ -341,7 +341,7 @@ def construct_hdf5_from_connectivity( if remove_curv_dev: logging.debug("- Step 4: Removing sharp turns (Qb threshold: {})" .format(curv_qb_distance)) - no_qb_curv_ids = remove_shap_turns_qb( + no_qb_curv_ids = remove_sharp_turns_qb( current_sft.streamlines, qb_threshold=curv_qb_distance) qb_curv_ids = np.setdiff1d(np.arange(len(current_sft)), no_qb_curv_ids) diff --git a/scilpy/tractograms/dps_and_dpp_management.py b/scilpy/tractograms/dps_and_dpp_management.py index 0a0816fa1..bf40e4350 100644 --- a/scilpy/tractograms/dps_and_dpp_management.py +++ b/scilpy/tractograms/dps_and_dpp_management.py @@ -307,10 +307,8 @@ def perform_operation_dpp_to_dps(op_name, sft, dpp_name, endpoints_only=False): if endpoints_only: new_data_per_streamline = [] for s in sft.data_per_point[dpp_name]: - start = s[0] - end = s[-1] - concat = np.concatenate((start[:], end[:])) - new_data_per_streamline.append(call_op(concat)) + fake_s = np.asarray([s[0], s[-1]]) + new_data_per_streamline.append(call_op(fake_s)) else: new_data_per_streamline = [] for s in sft.data_per_point[dpp_name]: diff --git a/scilpy/tractograms/lazy_tractogram_operations.py b/scilpy/tractograms/lazy_tractogram_operations.py index dcf6dfe9a..04498f34d 100644 --- a/scilpy/tractograms/lazy_tractogram_operations.py +++ b/scilpy/tractograms/lazy_tractogram_operations.py @@ -31,11 +31,14 @@ def lazy_streamlines_count(in_tractogram_path): tractogram_file = nib.streamlines.load(in_tractogram_path, lazy_load=True) - return tractogram_file.header[key] + return int(tractogram_file.header[key]) def lazy_concatenate(in_tractograms, out_ext): """ + Concatenates tractograms, if they can be concatenated. Headers must be + compatible. + Parameters ---------- in_tractograms: list @@ -45,7 +48,7 @@ def lazy_concatenate(in_tractograms, out_ext): Returns ------- - out_tractogram: Lazy tractogram + out_tractogram: LazyTractogram The concatenated data header: nibabel header or None Depending on the data type. diff --git a/scilpy/tractograms/streamline_operations.py b/scilpy/tractograms/streamline_operations.py index f1cddc883..2036fd98d 100644 --- a/scilpy/tractograms/streamline_operations.py +++ b/scilpy/tractograms/streamline_operations.py @@ -92,36 +92,52 @@ def _get_point_on_line(first_point, second_point, vox_lower_corner): return first_point + ray * (t0 + t1) / 2. -def get_angles(sft): - """Color streamlines according to their length. +def get_angles(sft, degrees=True, add_zeros=False): + """ + Returns the angle between each segment of the streamlines. Parameters ---------- sft: StatefulTractogram - The tractogram. + The tractogram, with N streamlines. + degrees: bool + If True, returns angles in degree. Else, in radian. + add_zeros: bool + For a streamline of length M, there are M-1 segments, and M-2 angles. + If add_zeros is set to True, a 0 angle is added at both ends of the + returned values, to get M values. Returns ------- angles: list[np.ndarray] - The angles per streamline, in degree. + List of N numpy arrays. The angles per streamline, in degree. """ angles = [] for i in range(len(sft.streamlines)): dirs = np.diff(sft.streamlines[i], axis=0) dirs /= np.linalg.norm(dirs, axis=-1, keepdims=True) cos_angles = np.sum(dirs[:-1, :] * dirs[1:, :], axis=1) + # Resolve numerical instability cos_angles = np.minimum(np.maximum(-1.0, cos_angles), 1.0) - line_angles = [0.0] + list(np.arccos(cos_angles)) + [0.0] - angles.extend(line_angles) + line_angles = list(np.arccos(cos_angles)) + + if add_zeros: + line_angles = [0.0] + line_angles + [0.0] + + if degrees: + line_angles = np.rad2deg(line_angles) - angles = np.rad2deg(angles) + angles.append(line_angles) return angles -def get_values_along_length(sft): - """Get the streamlines' coordinate positions according to their length. +def get_streamlines_as_linspaces(sft): + """ + For each streamline, returns a list of M values ranging between 0 and 1, + where M is the number of points in the streamline. This gives the position + of each coordinate per respect to the streamline's length. Parameters ---------- @@ -134,8 +150,8 @@ def get_values_along_length(sft): For each streamline, the linear distribution of its length. """ positions = [] - for i in range(len(sft.streamlines)): - positions.extend(list(np.linspace(0, 1, len(sft.streamlines[i])))) + for s in sft.streamlines: + positions.append(list(np.linspace(0, 1, len(s)))) return positions @@ -208,6 +224,8 @@ def cut_invalid_streamlines(sft, epsilon=0.001): ---------- sft: StatefulTractogram The sft to remove invalid points from. + epsilon: float + Error allowed when verifying the bounding box. Returns ------- @@ -227,7 +245,7 @@ def cut_invalid_streamlines(sft, epsilon=0.001): sft.to_corner() copy_sft = copy.deepcopy(sft) - indices_to_remove, _ = copy_sft.remove_invalid_streamlines() + indices_to_cut, _ = copy_sft.remove_invalid_streamlines() new_streamlines = [] new_data_per_point = {} @@ -239,21 +257,27 @@ def cut_invalid_streamlines(sft, epsilon=0.001): cutting_counter = 0 for ind in range(len(sft.streamlines)): - # No reason to try to cut if all points are within the volume - if ind in indices_to_remove: - best_pos = [0, 0] - cur_pos = [0, 0] + if ind in indices_to_cut: + # This streamline was detected as invalid + + best_pos = [0, 0] # First and last valid points of longest segment + cur_pos = [0, 0] # First and last valid points of current segment for pos, point in enumerate(sft.streamlines[ind]): if (point < epsilon).any() or \ (point >= sft.dimensions - epsilon).any(): + # The coordinate is < 0 or > box. Starting new segment cur_pos = [pos+1, pos+1] - if cur_pos[1] - cur_pos[0] > best_pos[1] - best_pos[0]: + elif cur_pos[1] - cur_pos[0] > best_pos[1] - best_pos[0]: + # We found a longer good segment. best_pos = cur_pos + + # Ready to check next point. cur_pos[1] += 1 + # Appending the longest segment to the list of streamlines if not best_pos == [0, 0]: new_streamlines.append( - sft.streamlines[ind][best_pos[0]:best_pos[1]-1]) + sft.streamlines[ind][best_pos[0]:best_pos[1]]) cutting_counter += 1 for key in sft.data_per_streamline.keys(): new_data_per_streamline[key].append( @@ -263,8 +287,9 @@ def cut_invalid_streamlines(sft, epsilon=0.001): sft.data_per_point[key][ind][ best_pos[0]:best_pos[1]-1]) else: - logging.warning('Streamlines entirely out of the volume.') + logging.warning('Streamline entirely out of the volume.') else: + # No reason to try to cut if all points are within the volume new_streamlines.append(sft.streamlines[ind]) for key in sft.data_per_streamline.keys(): new_data_per_streamline[key].append( @@ -310,7 +335,8 @@ def remove_single_point_streamlines(sft): def remove_overlapping_points_streamlines(sft, threshold=0.001): """ - Remove overlapping points from streamlines in a StatefulTractogram. + Remove overlapping points from streamlines in a StatefulTractogram, i.e. + points forming a segment of length < threshold in a given streamline. Parameters ---------- @@ -819,7 +845,7 @@ def remove_loops(streamlines, max_angle, num_processes=1): return ids, streamlines_clean -def remove_shap_turns_qb(streamlines, qb_threshold=15.0, qb_seed=0): +def remove_sharp_turns_qb(streamlines, qb_threshold=15.0, qb_seed=0): """ Remove sharp turns from a list of streamlines. Should only be used on bundled streamlines, not on whole-brain tractograms. @@ -893,7 +919,7 @@ def remove_loops_and_sharp_turns(streamlines, max_angle, qb_threshold=None, num_processes) if qb_threshold is not None: - ids = remove_shap_turns_qb(streamlines_clean, qb_threshold, qb_seed) + ids = remove_sharp_turns_qb(streamlines_clean, qb_threshold, qb_seed) return ids diff --git a/scilpy/tractograms/tests/test_dps_and_dpp_management.py b/scilpy/tractograms/tests/test_dps_and_dpp_management.py index c5f323e2e..c32f2eace 100644 --- a/scilpy/tractograms/tests/test_dps_and_dpp_management.py +++ b/scilpy/tractograms/tests/test_dps_and_dpp_management.py @@ -4,6 +4,7 @@ from dipy.io.stateful_tractogram import StatefulTractogram, Space, Origin from scilpy.image.volume_space_management import DataVolume +from scilpy.tests.utils import nan_array_equal from scilpy.tractograms.dps_and_dpp_management import ( add_data_as_color_dpp, convert_dps_to_dpp, project_map_to_streamlines, project_dpp_to_map, perform_operation_on_dpp, perform_operation_dpp_to_dps, @@ -26,18 +27,6 @@ def _get_small_sft(): return fake_sft -def nan_array_equal(a, b): - a = np.asarray(a) - b = np.asarray(b) - - nan_a = np.argwhere(np.isnan(a)) - nan_b = np.argwhere(np.isnan(a)) - - a = a[~np.isnan(a)] - b = b[~np.isnan(b)] - return np.array_equal(a, b) and np.array_equal(nan_a, nan_b) - - def test_add_data_as_color_dpp(): lut = get_lookup_table('viridis') @@ -143,13 +132,28 @@ def test_project_dpp_to_map(): fake_sft = _get_small_sft() fake_sft.data_per_point['my_dpp'] = [[1]*3, [2]*4] - map_data = project_dpp_to_map(fake_sft, 'my_dpp', sum_lines=True) + # Average + map_data = project_dpp_to_map(fake_sft, 'my_dpp') + expected = np.zeros((3, 3, 3)) # fake_ref is 3x3x3 + expected[0, 0, 0] = 1 # the 3 points of the first streamline + expected[1, 1, 1] = 2 # the 4 points of the second streamline + assert np.array_equal(map_data, expected) + # Sum + map_data = project_dpp_to_map(fake_sft, 'my_dpp', sum_lines=True) expected = np.zeros((3, 3, 3)) # fake_ref is 3x3x3 expected[0, 0, 0] = 3 * 1 # the 3 points of the first streamline expected[1, 1, 1] = 4 * 2 # the 4 points of the second streamline assert np.array_equal(map_data, expected) + # Option 'endpoints_only': + map_data = project_dpp_to_map(fake_sft, 'my_dpp', sum_lines=True, + endpoints_only=True) + expected = np.zeros((3, 3, 3)) # fake_ref is 3x3x3 + expected[0, 0, 0] = 2 * 1 # only 2 points of the first streamline + expected[1, 1, 1] = 2 * 2 # only 2 points of the second streamline + assert np.array_equal(map_data, expected) + def test_perform_operation_on_dpp(): fake_sft = _get_small_sft() @@ -176,12 +180,23 @@ def test_perform_operation_on_dpp(): assert np.array_equal(dpp[0].squeeze(), [1] * 3) assert np.array_equal(dpp[1].squeeze(), [2] * 4) + # Option 'endpoints only': + dpp = perform_operation_on_dpp('max', fake_sft, 'my_dpp', + endpoints_only=True) + assert nan_array_equal(dpp[0].squeeze(), [1.0, np.nan, 1]) + assert nan_array_equal(dpp[1].squeeze(), [2.0, np.nan, 2]) + def test_perform_operation_dpp_to_dps(): fake_sft = _get_small_sft() + + # This fake dpp contains two values per point: [1, 0] at each point for the + # first streamline (length 3), [2, 0] for the second (length 4). fake_sft.data_per_point['my_dpp'] = [[[1, 0]]*3, [[2, 0]]*4] + # Operations are done separately for each value. + # Mean: dps = perform_operation_dpp_to_dps('mean', fake_sft, 'my_dpp') assert np.array_equal(dps[0], [1, 0]) @@ -202,6 +217,12 @@ def test_perform_operation_dpp_to_dps(): assert np.array_equal(dps[0], [1, 0]) assert np.array_equal(dps[1], [2, 0]) + # Option 'endpoints_only': + dps = perform_operation_dpp_to_dps('sum', fake_sft, 'my_dpp', + endpoints_only=True) + assert np.array_equal(dps[0], [2 * 1, 0]) + assert np.array_equal(dps[1], [2 * 2, 0]) + def test_perform_correlation_on_endpoints(): fake_sft = _get_small_sft() diff --git a/scilpy/tractograms/tests/test_lazy_tractogram_operations.py b/scilpy/tractograms/tests/test_lazy_tractogram_operations.py new file mode 100644 index 000000000..8c6380828 --- /dev/null +++ b/scilpy/tractograms/tests/test_lazy_tractogram_operations.py @@ -0,0 +1,25 @@ +# -*- coding: utf-8 -*- +import os + +from scilpy import SCILPY_HOME +from scilpy.io.fetcher import fetch_data, get_testing_files_dict +from scilpy.tractograms.lazy_tractogram_operations import \ + lazy_streamlines_count, lazy_concatenate + +# If they already exist, this only takes 5 seconds (check md5sum) +fetch_data(get_testing_files_dict(), keys=['tractograms.zip']) +main_path = os.path.join(SCILPY_HOME, 'tractograms', 'streamline_operations') + + +def test_lazy_tractogram_count(): + in_file = os.path.join(main_path, 'bundle_4.tck') + nb = lazy_streamlines_count(in_file) + assert nb == 10 + + +def test_lazy_concatenate(): + in_file1 = os.path.join(main_path, 'bundle_4.tck') + in_file2 = os.path.join(main_path, 'bundle_4_cut_endpoints.tck') + + out_trk, out_header = lazy_concatenate([in_file1, in_file2], '.tck') + assert len(out_trk) == 20 diff --git a/scilpy/tractograms/tests/test_streamline_operations.py b/scilpy/tractograms/tests/test_streamline_operations.py index 8b2b6c5df..ccddcc972 100644 --- a/scilpy/tractograms/tests/test_streamline_operations.py +++ b/scilpy/tractograms/tests/test_streamline_operations.py @@ -2,9 +2,6 @@ import os import tempfile -from dipy.io.streamline import load_tractogram -from dipy.tracking.streamlinespeed import length -import nibabel as nib import numpy as np from numpy.testing import assert_array_almost_equal import pytest @@ -14,80 +11,129 @@ from scilpy import SCILPY_HOME from scilpy.io.fetcher import fetch_data, get_testing_files_dict from scilpy.tractograms.streamline_operations import ( + compress_sft, + cut_invalid_streamlines, filter_streamlines_by_length, filter_streamlines_by_total_length_per_dim, + get_angles, + get_streamlines_as_linspaces, resample_streamlines_num_points, resample_streamlines_step_size, smooth_line_gaussian, smooth_line_spline, - parallel_transport_streamline) + parallel_transport_streamline, + remove_overlapping_points_streamlines, + remove_single_point_streamlines) from scilpy.tractograms.tractogram_operations import concatenate_sft - fetch_data(get_testing_files_dict(), keys=['tractograms.zip']) tmp_dir = tempfile.TemporaryDirectory() +# Streamlines and masks relevant to the tests here. +test_files_path = os.path.join(SCILPY_HOME, 'tractograms', + 'streamline_operations') +in_long_sft = os.path.join(test_files_path, 'bundle_4.tck') +in_mid_sft = os.path.join(test_files_path, 'bundle_4_cut_endpoints.tck') +in_short_sft = os.path.join(test_files_path, 'bundle_4_cut_center.tck') +in_ref = os.path.join(test_files_path, 'bundle_4_wm.nii.gz') +in_rois = os.path.join(test_files_path, 'bundle_4_head_tail_offset.nii.gz') -def _setup_files(): - """ Load streamlines and masks relevant to the tests here. - """ - - os.chdir(os.path.expanduser(tmp_dir.name)) - in_long_sft = os.path.join(SCILPY_HOME, 'tractograms', - 'streamline_operations', - 'bundle_4.tck') - in_mid_sft = os.path.join(SCILPY_HOME, 'tractograms', - 'streamline_operations', - 'bundle_4_cut_endpoints.tck') - in_short_sft = os.path.join(SCILPY_HOME, 'tractograms', - 'streamline_operations', - 'bundle_4_cut_center.tck') - in_ref = os.path.join(SCILPY_HOME, 'tractograms', - 'streamline_operations', - 'bundle_4_wm.nii.gz') - - in_rois = os.path.join(SCILPY_HOME, 'tractograms', - 'streamline_operations', - 'bundle_4_head_tail_offset.nii.gz') - - # Load sft - long_sft = load_tractogram(in_long_sft, in_ref) - mid_sft = load_tractogram(in_mid_sft, in_ref) - short_sft = load_tractogram(in_short_sft, in_ref) - sft = concatenate_sft([long_sft, mid_sft, short_sft]) +def test_get_angles(): + fake_straight_line = np.asarray([[0, 0, 0], + [1, 1, 1], + [2, 2, 2], + [3, 3, 3]], dtype=float) + fake_ninety_degree = np.asarray([[0, 0, 0], + [1, 1, 0], + [0, 2, 0]], dtype=float) - # Load mask - rois = nib.load(in_rois) - return sft, rois + sft = load_tractogram(in_short_sft, in_ref) + sft.streamlines = [fake_straight_line, fake_ninety_degree] + angles = get_angles(sft) + assert np.array_equal(angles[0], [0, 0]) + assert np.array_equal(angles[1], [90]) -def test_angles(): - # toDo - pass + angles = get_angles(sft, add_zeros=True) + assert np.array_equal(angles[0], [0, 0, 0, 0]) + assert np.array_equal(angles[1], [0, 90, 0]) -def test_get_values_along_length(): - # toDo - pass +def test_get_streamlines_as_linspaces(): + sft = load_tractogram(in_short_sft, in_ref) + lines = get_streamlines_as_linspaces(sft) + assert len(lines) == len(sft) + assert len(lines[0]) == len(sft.streamlines[0]) + assert lines[0][0] == 0 + assert lines[0][-1] == 1 def test_compress_sft(): - # toDo - pass + sft = load_tractogram(in_long_sft, in_ref) + compressed = compress_sft(sft, tol_error=0.01) + assert len(sft) == len(compressed) + + for s, sc in zip(sft.streamlines, compressed.streamlines): + # All streamlines should be shorter once compressed + assert len(sc) <= len(s) + + # Streamlines' endpoints should not be changed + assert np.allclose(s[0], sc[0]) + assert np.allclose(s[-1], sc[-1]) + + # Not testing more than that, as it uses Dipy's method, tested by Dipy def test_cut_invalid_streamlines(): - # toDo - pass + sft = load_tractogram(in_long_sft, in_ref) + sft.to_vox() + cut, nb = cut_invalid_streamlines(sft) + assert len(cut) == len(sft) + assert nb == 0 -def test_filter_streamlines_by_length_max_length(): - """ Test the filter_streamlines_by_length function with a max length. - """ + # Faking an invalid streamline. Currently, volume is 64x64x3 + sft.streamlines[0][-1, :] = [65.0, 65.0, 2.0] + cut, nb = cut_invalid_streamlines(sft) + assert len(cut) == len(sft) + assert np.all([len(sc) <= len(s) for s, sc in + zip(sft.streamlines, cut.streamlines)]) + assert len(cut.streamlines[0]) == len(sft.streamlines[0]) - 1 + assert nb == 1 + + +def test_remove_single_point_streamlines(): + sft = load_tractogram(in_short_sft, in_ref) + + # Adding a one-point streamline + sft.streamlines.append([[7, 7, 7]]) + new_sft = remove_single_point_streamlines(sft) + assert len(new_sft) == len(sft) - 1 - sft, _ = _setup_files() +def test_remove_overlapping_points_streamlines(): + sft = load_tractogram(in_short_sft, in_ref) + + fake_line = np.asarray([[3, 3, 3], + [4, 4, 4], + [5, 5, 5], + [5, 5, 5.00000001]], dtype=float) + sft.streamlines.append(fake_line) + + new_sft = remove_overlapping_points_streamlines(sft) + assert len(new_sft.streamlines[-1]) == len(sft.streamlines[-1]) - 1 + assert np.all([len(new_sft.streamlines[i]) == len(sft.streamlines[i]) for + i in range(len(sft) - 1)]) + + +def test_filter_streamlines_by_length(): + long_sft = load_tractogram(in_long_sft, in_ref) + mid_sft = load_tractogram(in_mid_sft, in_ref) + short_sft = load_tractogram(in_short_sft, in_ref) + sft = concatenate_sft([long_sft, mid_sft, short_sft]) + + # === 1. Using max length === min_length = 0. max_length = 100 # Filter streamlines by length and get the lengths @@ -100,13 +146,7 @@ def test_filter_streamlines_by_length_max_length(): assert np.all(lengths <= max_length) - -def test_filter_streamlines_by_length_min_length(): - """ Test the filter_streamlines_by_length function with a min length. - """ - - sft, _ = _setup_files() - + # === 2. Using min length === min_length = 100 max_length = np.inf @@ -120,14 +160,7 @@ def test_filter_streamlines_by_length_min_length(): # Test that streamlines shorter than 100 were removed. assert np.all(lengths >= min_length) - -def test_filter_streamlines_by_length_min_and_max_length(): - """ Test the filter_streamlines_by_length function with a min - and max length. - """ - - sft, _ = _setup_files() - + # === 3. Using both min and max length === min_length = 100 max_length = 120 @@ -142,21 +175,18 @@ def test_filter_streamlines_by_length_min_and_max_length(): assert np.all(lengths >= min_length) and np.all(lengths <= max_length) -def test_filter_streamlines_by_total_length_per_dim_x(): - """ Test the filter_streamlines_by_total_length_per_dim function. - This function is quite awkward to test without reimplementing - the logic, but luckily we have data going purely left-right. - - This test also tests the return of rejected streamlines. - """ - - # Streamlines are going purely left-right, so the - # x dimension should have the longest span. - sft, _ = _setup_files() +def test_filter_streamlines_by_total_length_per_dim(): + long_sft = load_tractogram(in_long_sft, in_ref) + mid_sft = load_tractogram(in_mid_sft, in_ref) + short_sft = load_tractogram(in_short_sft, in_ref) + sft = concatenate_sft([long_sft, mid_sft, short_sft]) min_length = 115 max_length = 125 + # === 1. Test x dimension === + # Test sft has streamlines that are going purely left-right, so the x + # dimension should have the longest span. constraint = [min_length, max_length] inf_constraint = [-np.inf, np.inf] @@ -164,7 +194,7 @@ def test_filter_streamlines_by_total_length_per_dim_x(): # No rejected streamlines should be returned filtered_sft, ids, rejected = filter_streamlines_by_total_length_per_dim( sft, constraint, inf_constraint, inf_constraint, - True, False) + use_abs=True, save_rejected=False) lengths = length(filtered_sft.streamlines) # Test that streamlines were removed and that the test is not bogus. @@ -174,27 +204,12 @@ def test_filter_streamlines_by_total_length_per_dim_x(): # No rejected streamlines should have been returned assert rejected is None - -def test_filter_streamlines_by_total_length_per_dim_y(): - """ Test the filter_streamlines_by_total_length_per_dim function. - This function is quite awkward to test without reimplementing - the logic. We rotate the streamlines to be purely up-down. - - This test also tests the return of rejected streamlines. The rejected - streamlines should have "invalid" lengths. - """ - - # Streamlines are going purely left-right, so the - # streamlines have to be rotated to be purely up-down. - sft, _ = _setup_files() + # === 2. Testing y dimension === # Rotate streamlines by swapping x and y for all streamlines swapped_streamlines_y = [s[:, [1, 0, 2]] for s in sft.streamlines] sft_y = sft.from_sft(swapped_streamlines_y, sft) - min_length = 115 - max_length = 125 - constraint = [min_length, max_length] inf_constraint = [-np.inf, np.inf] @@ -211,17 +226,7 @@ def test_filter_streamlines_by_total_length_per_dim_y(): assert np.all(np.logical_or(min_length > rejected_lengths, rejected_lengths > max_length)) - -def test_filter_streamlines_by_total_length_per_dim_z(): - """ Test the filter_streamlines_by_total_length_per_dim function. - This function is quite awkward to test without reimplementing - the logic. - """ - - # Streamlines are going purely left-right, so the - # streamlines have to be rotated to be purely front-back. - sft, _ = _setup_files() - + # === 3. Testing z dimension === # Rotate streamlines by swapping x and z for all streamlines swapped_streamlines_y = [s[:, [2, 1, 0]] for s in sft.streamlines] sft_y = sft.from_sft(swapped_streamlines_y, sft) @@ -240,42 +245,38 @@ def test_filter_streamlines_by_total_length_per_dim_z(): # Test that streamlines were removed and that the test is not bogus. assert len(filtered_sft) < len(sft) - assert np.all(lengths >= min_length) and np.all(lengths <= max_length) -def test_resample_streamlines_num_points_2(): - """ Test the resample_streamlines_num_points function to 2 points. - """ +def test_resample_streamlines_num_points(): + long_sft = load_tractogram(in_long_sft, in_ref) + mid_sft = load_tractogram(in_mid_sft, in_ref) + short_sft = load_tractogram(in_short_sft, in_ref) + sft = concatenate_sft([long_sft, mid_sft, short_sft]) - sft, _ = _setup_files() + # Test 1. To two points nb_points = 2 - resampled_sft = resample_streamlines_num_points(sft, nb_points) lengths = [len(s) == nb_points for s in resampled_sft.streamlines] - assert np.all(lengths) - -def test_resample_streamlines_num_points_1000(): - """ Test the resample_streamlines_num_points function to 1000 points. - """ - - sft, _ = _setup_files() + # Test 2. To 1000 points. nb_points = 1000 - resampled_sft = resample_streamlines_num_points(sft, nb_points) lengths = [len(s) == nb_points for s in resampled_sft.streamlines] assert np.all(lengths) -def test_resample_streamlines_step_size_1mm(): +def test_resample_streamlines_step_size(): """ Test the resample_streamlines_step_size function to 1mm. """ + long_sft = load_tractogram(in_long_sft, in_ref) + mid_sft = load_tractogram(in_mid_sft, in_ref) + short_sft = load_tractogram(in_short_sft, in_ref) + sft = concatenate_sft([long_sft, mid_sft, short_sft]) - sft, _ = _setup_files() - + # Test 1. To 1 mm step_size = 1.0 resampled_sft = resample_streamlines_step_size(sft, step_size) @@ -286,13 +287,7 @@ def test_resample_streamlines_step_size_1mm(): # Tolerance of 10% of the step size assert np.allclose(steps, step_size, atol=0.1), steps - -def test_resample_streamlines_step_size_01mm(): - """ Test the resample_streamlines_step_size function to 0.1mm. - """ - - sft, _ = _setup_files() - + # Test 2. To 0.1 mm step_size = 0.1 resampled_sft = resample_streamlines_step_size(sft, step_size) @@ -310,7 +305,7 @@ def test_smooth_line_gaussian_error(): value of 0, therefore it should throw and error. """ - sft, _ = _setup_files() + sft = load_tractogram(in_long_sft, in_ref) streamline = sft.streamlines[0] # Add noise to the streamline @@ -326,8 +321,7 @@ def test_smooth_line_gaussian(): streamline and smoothing it. The smoothed streamline should be closer to the original streamline than the noisy one. """ - - sft, _ = _setup_files() + sft = load_tractogram(in_long_sft, in_ref) streamline = sft.streamlines[0] rng = np.random.default_rng(1337) @@ -354,8 +348,7 @@ def test_smooth_line_spline_error(): streamline and smoothing it. The function does not accept a sigma value of 0, therefore it should throw and error. """ - - sft, _ = _setup_files() + sft = load_tractogram(in_long_sft, in_ref) streamline = sft.streamlines[0] # Add noise to the streamline @@ -371,8 +364,7 @@ def test_smooth_line_spline(): streamline and smoothing it. The smoothed streamline should be closer to the original streamline than the noisy one. """ - - sft, _ = _setup_files() + sft = load_tractogram(in_short_sft, in_ref) streamline = sft.streamlines[-1] rng = np.random.default_rng(1337) @@ -400,7 +392,7 @@ def test_generate_matched_points(): def test_parallel_transport_streamline(): - sft, _ = _setup_files() + sft = load_tractogram(in_long_sft, in_ref) streamline = sft.streamlines[0] rng = np.random.default_rng(3018) @@ -417,3 +409,19 @@ def test_parallel_transport_streamline(): decimal=4) assert [len(s) for s in pt_streamlines] == [130] * 20 assert len(pt_streamlines) == 20 + + +def test_remove_loops(): + # toDO + # Coverage will not work: uses multi-processing + pass + + +def test_remove_sharp_turns_qb(): + # toDO + pass + + +def test_remove_loops_and_sharp_turns(): + # ok. Just a combination of the two previous functions. + pass diff --git a/scripts/scil_tractogram_assign_custom_color.py b/scripts/scil_tractogram_assign_custom_color.py index 5ebc38144..c5cb39bd6 100755 --- a/scripts/scil_tractogram_assign_custom_color.py +++ b/scripts/scil_tractogram_assign_custom_color.py @@ -63,8 +63,8 @@ assert_outputs_exist, load_matrix_in_any_format) from scilpy.tractograms.dps_and_dpp_management import add_data_as_color_dpp -from scilpy.tractograms.streamline_operations import (get_values_along_length, - get_angles) +from scilpy.tractograms.streamline_operations import ( + get_streamlines_as_linspaces, get_angles) from scilpy.viz.color import get_lookup_table from scilpy.viz.color import prepare_colorbar_figure @@ -204,9 +204,11 @@ def main(): data = nib.load(args.from_anatomy).get_fdata() data = map_coordinates(data, concat_points, order=0) elif args.along_profile: - data = get_values_along_length(sft) + data = get_streamlines_as_linspaces(sft) + data = np.hstack(data) else: # args.local_angle: - data = get_angles(sft) + data = get_angles(sft, add_zeros=True) + data = np.hstack(data) # Processing sft, lbound, ubound = add_data_as_color_dpp( diff --git a/scripts/scil_tractogram_count_streamlines.py b/scripts/scil_tractogram_count_streamlines.py index b42005993..b83503d7c 100755 --- a/scripts/scil_tractogram_count_streamlines.py +++ b/scripts/scil_tractogram_count_streamlines.py @@ -44,7 +44,7 @@ def main(): assert_inputs_exist(parser, args.in_tractogram) bundle_name, _ = os.path.splitext(os.path.basename(args.in_tractogram)) - count = int(lazy_streamlines_count(args.in_tractogram)) + count = lazy_streamlines_count(args.in_tractogram) if args.print_count_alone: print(count) diff --git a/scripts/scil_tractogram_print_info.py b/scripts/scil_tractogram_print_info.py index f093b26d3..bdea51b95 100755 --- a/scripts/scil_tractogram_print_info.py +++ b/scripts/scil_tractogram_print_info.py @@ -60,7 +60,8 @@ def main(): steps = np.hstack(steps) print(json.dumps( - {'min_length_mm': float(np.min(lengths_mm)), + {'number_streamlines': int(len(sft)), + 'min_length_mm': float(np.min(lengths_mm)), 'mean_length_mm': float(np.mean(lengths_mm)), 'max_length_mm': float(np.max(lengths_mm)), 'std_length_mm': float(np.std(lengths_mm)),