Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tractograms: more unit tests #1027

Merged
merged 10 commits into from
Sep 20, 2024
15 changes: 15 additions & 0 deletions scilpy/tests/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# -*- coding: utf-8 -*-

import numpy as np


def nan_array_equal(a, b):
a = np.asarray(a)
b = np.asarray(b)

nan_a = np.argwhere(np.isnan(a))
nan_b = np.argwhere(np.isnan(a))

a = a[~np.isnan(a)]
b = b[~np.isnan(b)]
return np.array_equal(a, b) and np.array_equal(nan_a, nan_b)
4 changes: 2 additions & 2 deletions scilpy/tractanalysis/connectivity_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
compute_streamline_segment
from scilpy.tractograms.streamline_operations import \
(remove_loops as perform_remove_loops,
remove_shap_turns_qb,
remove_sharp_turns_qb,
remove_streamlines_with_overlapping_points, filter_streamlines_by_length)


Expand Down Expand Up @@ -341,7 +341,7 @@ def construct_hdf5_from_connectivity(
if remove_curv_dev:
logging.debug("- Step 4: Removing sharp turns (Qb threshold: {})"
.format(curv_qb_distance))
no_qb_curv_ids = remove_shap_turns_qb(
no_qb_curv_ids = remove_sharp_turns_qb(
current_sft.streamlines, qb_threshold=curv_qb_distance)
qb_curv_ids = np.setdiff1d(np.arange(len(current_sft)),
no_qb_curv_ids)
Expand Down
6 changes: 2 additions & 4 deletions scilpy/tractograms/dps_and_dpp_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,10 +307,8 @@ def perform_operation_dpp_to_dps(op_name, sft, dpp_name, endpoints_only=False):
if endpoints_only:
new_data_per_streamline = []
for s in sft.data_per_point[dpp_name]:
start = s[0]
end = s[-1]
concat = np.concatenate((start[:], end[:]))
new_data_per_streamline.append(call_op(concat))
fake_s = np.asarray([s[0], s[-1]])
new_data_per_streamline.append(call_op(fake_s))
else:
new_data_per_streamline = []
for s in sft.data_per_point[dpp_name]:
Expand Down
7 changes: 5 additions & 2 deletions scilpy/tractograms/lazy_tractogram_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,14 @@ def lazy_streamlines_count(in_tractogram_path):

tractogram_file = nib.streamlines.load(in_tractogram_path,
lazy_load=True)
return tractogram_file.header[key]
return int(tractogram_file.header[key])


def lazy_concatenate(in_tractograms, out_ext):
"""
Concatenates tractograms, if they can be concatenated. Headers must be
compatible.

Parameters
----------
in_tractograms: list
Expand All @@ -45,7 +48,7 @@ def lazy_concatenate(in_tractograms, out_ext):

Returns
-------
out_tractogram: Lazy tractogram
out_tractogram: LazyTractogram
The concatenated data
header: nibabel header or None
Depending on the data type.
Expand Down
70 changes: 48 additions & 22 deletions scilpy/tractograms/streamline_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,36 +92,52 @@ def _get_point_on_line(first_point, second_point, vox_lower_corner):
return first_point + ray * (t0 + t1) / 2.


def get_angles(sft):
"""Color streamlines according to their length.
def get_angles(sft, degrees=True, add_zeros=False):
"""
Returns the angle between each segment of the streamlines.

Parameters
----------
sft: StatefulTractogram
The tractogram.
The tractogram, with N streamlines.
degrees: bool
If True, returns angles in degree. Else, in radian.
add_zeros: bool
For a streamline of length M, there are M-1 segments, and M-2 angles.
If add_zeros is set to True, a 0 angle is added at both ends of the
returned values, to get M values.

Returns
-------
angles: list[np.ndarray]
The angles per streamline, in degree.
List of N numpy arrays. The angles per streamline, in degree.
"""
angles = []
for i in range(len(sft.streamlines)):
dirs = np.diff(sft.streamlines[i], axis=0)
dirs /= np.linalg.norm(dirs, axis=-1, keepdims=True)
cos_angles = np.sum(dirs[:-1, :] * dirs[1:, :], axis=1)

# Resolve numerical instability
cos_angles = np.minimum(np.maximum(-1.0, cos_angles), 1.0)
line_angles = [0.0] + list(np.arccos(cos_angles)) + [0.0]
angles.extend(line_angles)
line_angles = list(np.arccos(cos_angles))

if add_zeros:
line_angles = [0.0] + line_angles + [0.0]

if degrees:
line_angles = np.rad2deg(line_angles)

angles = np.rad2deg(angles)
angles.append(line_angles)

return angles


def get_values_along_length(sft):
"""Get the streamlines' coordinate positions according to their length.
def get_streamlines_as_linspaces(sft):
"""
For each streamline, returns a list of M values ranging between 0 and 1,
where M is the number of points in the streamline. This gives the position
of each coordinate per respect to the streamline's length.

Parameters
----------
Expand All @@ -134,8 +150,8 @@ def get_values_along_length(sft):
For each streamline, the linear distribution of its length.
"""
positions = []
for i in range(len(sft.streamlines)):
positions.extend(list(np.linspace(0, 1, len(sft.streamlines[i]))))
for s in sft.streamlines:
positions.append(list(np.linspace(0, 1, len(s))))

return positions

Expand Down Expand Up @@ -208,6 +224,8 @@ def cut_invalid_streamlines(sft, epsilon=0.001):
----------
sft: StatefulTractogram
The sft to remove invalid points from.
epsilon: float
Error allowed when verifying the bounding box.

Returns
-------
Expand All @@ -227,7 +245,7 @@ def cut_invalid_streamlines(sft, epsilon=0.001):
sft.to_corner()

copy_sft = copy.deepcopy(sft)
indices_to_remove, _ = copy_sft.remove_invalid_streamlines()
indices_to_cut, _ = copy_sft.remove_invalid_streamlines()

new_streamlines = []
new_data_per_point = {}
Expand All @@ -239,21 +257,27 @@ def cut_invalid_streamlines(sft, epsilon=0.001):

cutting_counter = 0
for ind in range(len(sft.streamlines)):
# No reason to try to cut if all points are within the volume
if ind in indices_to_remove:
best_pos = [0, 0]
cur_pos = [0, 0]
if ind in indices_to_cut:
# This streamline was detected as invalid

best_pos = [0, 0] # First and last valid points of longest segment
cur_pos = [0, 0] # First and last valid points of current segment
for pos, point in enumerate(sft.streamlines[ind]):
if (point < epsilon).any() or \
(point >= sft.dimensions - epsilon).any():
# The coordinate is < 0 or > box. Starting new segment
cur_pos = [pos+1, pos+1]
if cur_pos[1] - cur_pos[0] > best_pos[1] - best_pos[0]:
elif cur_pos[1] - cur_pos[0] > best_pos[1] - best_pos[0]:
# We found a longer good segment.
best_pos = cur_pos

# Ready to check next point.
cur_pos[1] += 1

# Appending the longest segment to the list of streamlines
if not best_pos == [0, 0]:
new_streamlines.append(
sft.streamlines[ind][best_pos[0]:best_pos[1]-1])
sft.streamlines[ind][best_pos[0]:best_pos[1]])
cutting_counter += 1
for key in sft.data_per_streamline.keys():
new_data_per_streamline[key].append(
Expand All @@ -263,8 +287,9 @@ def cut_invalid_streamlines(sft, epsilon=0.001):
sft.data_per_point[key][ind][
best_pos[0]:best_pos[1]-1])
else:
logging.warning('Streamlines entirely out of the volume.')
logging.warning('Streamline entirely out of the volume.')
else:
# No reason to try to cut if all points are within the volume
new_streamlines.append(sft.streamlines[ind])
for key in sft.data_per_streamline.keys():
new_data_per_streamline[key].append(
Expand Down Expand Up @@ -310,7 +335,8 @@ def remove_single_point_streamlines(sft):

def remove_overlapping_points_streamlines(sft, threshold=0.001):
"""
Remove overlapping points from streamlines in a StatefulTractogram.
Remove overlapping points from streamlines in a StatefulTractogram, i.e.
points forming a segment of length < threshold in a given streamline.

Parameters
----------
Expand Down Expand Up @@ -819,7 +845,7 @@ def remove_loops(streamlines, max_angle, num_processes=1):
return ids, streamlines_clean


def remove_shap_turns_qb(streamlines, qb_threshold=15.0, qb_seed=0):
def remove_sharp_turns_qb(streamlines, qb_threshold=15.0, qb_seed=0):
"""
Remove sharp turns from a list of streamlines. Should only be used on
bundled streamlines, not on whole-brain tractograms.
Expand Down Expand Up @@ -893,7 +919,7 @@ def remove_loops_and_sharp_turns(streamlines, max_angle, qb_threshold=None,
num_processes)

if qb_threshold is not None:
ids = remove_shap_turns_qb(streamlines_clean, qb_threshold, qb_seed)
ids = remove_sharp_turns_qb(streamlines_clean, qb_threshold, qb_seed)
return ids


Expand Down
47 changes: 34 additions & 13 deletions scilpy/tractograms/tests/test_dps_and_dpp_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from dipy.io.stateful_tractogram import StatefulTractogram, Space, Origin

from scilpy.image.volume_space_management import DataVolume
from scilpy.tests.utils import nan_array_equal
from scilpy.tractograms.dps_and_dpp_management import (
add_data_as_color_dpp, convert_dps_to_dpp, project_map_to_streamlines,
project_dpp_to_map, perform_operation_on_dpp, perform_operation_dpp_to_dps,
Expand All @@ -26,18 +27,6 @@ def _get_small_sft():
return fake_sft


def nan_array_equal(a, b):
a = np.asarray(a)
b = np.asarray(b)

nan_a = np.argwhere(np.isnan(a))
nan_b = np.argwhere(np.isnan(a))

a = a[~np.isnan(a)]
b = b[~np.isnan(b)]
return np.array_equal(a, b) and np.array_equal(nan_a, nan_b)


def test_add_data_as_color_dpp():
lut = get_lookup_table('viridis')

Expand Down Expand Up @@ -143,13 +132,28 @@ def test_project_dpp_to_map():
fake_sft = _get_small_sft()
fake_sft.data_per_point['my_dpp'] = [[1]*3, [2]*4]

map_data = project_dpp_to_map(fake_sft, 'my_dpp', sum_lines=True)
# Average
map_data = project_dpp_to_map(fake_sft, 'my_dpp')
expected = np.zeros((3, 3, 3)) # fake_ref is 3x3x3
expected[0, 0, 0] = 1 # the 3 points of the first streamline
expected[1, 1, 1] = 2 # the 4 points of the second streamline
assert np.array_equal(map_data, expected)

# Sum
map_data = project_dpp_to_map(fake_sft, 'my_dpp', sum_lines=True)
expected = np.zeros((3, 3, 3)) # fake_ref is 3x3x3
expected[0, 0, 0] = 3 * 1 # the 3 points of the first streamline
expected[1, 1, 1] = 4 * 2 # the 4 points of the second streamline
assert np.array_equal(map_data, expected)

# Option 'endpoints_only':
map_data = project_dpp_to_map(fake_sft, 'my_dpp', sum_lines=True,
endpoints_only=True)
expected = np.zeros((3, 3, 3)) # fake_ref is 3x3x3
expected[0, 0, 0] = 2 * 1 # only 2 points of the first streamline
expected[1, 1, 1] = 2 * 2 # only 2 points of the second streamline
assert np.array_equal(map_data, expected)


def test_perform_operation_on_dpp():
fake_sft = _get_small_sft()
Expand All @@ -176,12 +180,23 @@ def test_perform_operation_on_dpp():
assert np.array_equal(dpp[0].squeeze(), [1] * 3)
assert np.array_equal(dpp[1].squeeze(), [2] * 4)

# Option 'endpoints only':
dpp = perform_operation_on_dpp('max', fake_sft, 'my_dpp',
endpoints_only=True)
assert nan_array_equal(dpp[0].squeeze(), [1.0, np.nan, 1])
assert nan_array_equal(dpp[1].squeeze(), [2.0, np.nan, 2])


def test_perform_operation_dpp_to_dps():
fake_sft = _get_small_sft()

# This fake dpp contains two values per point: [1, 0] at each point for the
# first streamline (length 3), [2, 0] for the second (length 4).
fake_sft.data_per_point['my_dpp'] = [[[1, 0]]*3,
[[2, 0]]*4]

# Operations are done separately for each value.

# Mean:
dps = perform_operation_dpp_to_dps('mean', fake_sft, 'my_dpp')
assert np.array_equal(dps[0], [1, 0])
Expand All @@ -202,6 +217,12 @@ def test_perform_operation_dpp_to_dps():
assert np.array_equal(dps[0], [1, 0])
assert np.array_equal(dps[1], [2, 0])

# Option 'endpoints_only':
dps = perform_operation_dpp_to_dps('sum', fake_sft, 'my_dpp',
endpoints_only=True)
assert np.array_equal(dps[0], [2 * 1, 0])
assert np.array_equal(dps[1], [2 * 2, 0])


def test_perform_correlation_on_endpoints():
fake_sft = _get_small_sft()
Expand Down
25 changes: 25 additions & 0 deletions scilpy/tractograms/tests/test_lazy_tractogram_operations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# -*- coding: utf-8 -*-
import os

from scilpy import SCILPY_HOME
from scilpy.io.fetcher import fetch_data, get_testing_files_dict
from scilpy.tractograms.lazy_tractogram_operations import \
lazy_streamlines_count, lazy_concatenate

# If they already exist, this only takes 5 seconds (check md5sum)
fetch_data(get_testing_files_dict(), keys=['tractograms.zip'])
main_path = os.path.join(SCILPY_HOME, 'tractograms', 'streamline_operations')


def test_lazy_tractogram_count():
in_file = os.path.join(main_path, 'bundle_4.tck')
nb = lazy_streamlines_count(in_file)
assert nb == 10


def test_lazy_concatenate():
in_file1 = os.path.join(main_path, 'bundle_4.tck')
in_file2 = os.path.join(main_path, 'bundle_4_cut_endpoints.tck')

out_trk, out_header = lazy_concatenate([in_file1, in_file2], '.tck')
assert len(out_trk) == 20
Loading