Skip to content

Commit

Permalink
Add more missing tests
Browse files Browse the repository at this point in the history
  • Loading branch information
EmmaRenauld committed Sep 11, 2024
1 parent a1a2284 commit b07cc06
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 9 deletions.
4 changes: 2 additions & 2 deletions scilpy/tractanalysis/connectivity_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
compute_streamline_segment
from scilpy.tractograms.streamline_operations import \
(remove_loops as perform_remove_loops,
remove_shap_turns_qb,
remove_sharp_turns_qb,
remove_streamlines_with_overlapping_points, filter_streamlines_by_length)


Expand Down Expand Up @@ -341,7 +341,7 @@ def construct_hdf5_from_connectivity(
if remove_curv_dev:
logging.debug("- Step 4: Removing sharp turns (Qb threshold: {})"
.format(curv_qb_distance))
no_qb_curv_ids = remove_shap_turns_qb(
no_qb_curv_ids = remove_sharp_turns_qb(
current_sft.streamlines, qb_threshold=curv_qb_distance)
qb_curv_ids = np.setdiff1d(np.arange(len(current_sft)),
no_qb_curv_ids)
Expand Down
7 changes: 4 additions & 3 deletions scilpy/tractograms/streamline_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,8 @@ def remove_single_point_streamlines(sft):

def remove_overlapping_points_streamlines(sft, threshold=0.001):
"""
Remove overlapping points from streamlines in a StatefulTractogram.
Remove overlapping points from streamlines in a StatefulTractogram, i.e.
points forming a segment of length < threshold in a given streamline.
Parameters
----------
Expand Down Expand Up @@ -844,7 +845,7 @@ def remove_loops(streamlines, max_angle, num_processes=1):
return ids, streamlines_clean


def remove_shap_turns_qb(streamlines, qb_threshold=15.0, qb_seed=0):
def remove_sharp_turns_qb(streamlines, qb_threshold=15.0, qb_seed=0):
"""
Remove sharp turns from a list of streamlines. Should only be used on
bundled streamlines, not on whole-brain tractograms.
Expand Down Expand Up @@ -918,7 +919,7 @@ def remove_loops_and_sharp_turns(streamlines, max_angle, qb_threshold=None,
num_processes)

if qb_threshold is not None:
ids = remove_shap_turns_qb(streamlines_clean, qb_threshold, qb_seed)
ids = remove_sharp_turns_qb(streamlines_clean, qb_threshold, qb_seed)
return ids


Expand Down
51 changes: 47 additions & 4 deletions scilpy/tractograms/tests/test_streamline_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import os
import tempfile

import nibabel as nib
import numpy as np
from numpy.testing import assert_array_almost_equal
import pytest
Expand All @@ -12,17 +11,21 @@
from scilpy import SCILPY_HOME
from scilpy.io.fetcher import fetch_data, get_testing_files_dict
from scilpy.tractograms.streamline_operations import (
compress_sft,
cut_invalid_streamlines,
filter_streamlines_by_length,
filter_streamlines_by_total_length_per_dim,
get_angles,
get_streamlines_as_linspaces,
resample_streamlines_num_points,
resample_streamlines_step_size,
smooth_line_gaussian,
smooth_line_spline,
parallel_transport_streamline, get_angles, get_streamlines_as_linspaces,
compress_sft, cut_invalid_streamlines)
parallel_transport_streamline,
remove_overlapping_points_streamlines,
remove_single_point_streamlines)
from scilpy.tractograms.tractogram_operations import concatenate_sft


fetch_data(get_testing_files_dict(), keys=['tractograms.zip'])
tmp_dir = tempfile.TemporaryDirectory()

Expand Down Expand Up @@ -100,6 +103,30 @@ def test_cut_invalid_streamlines():
assert nb == 1


def test_remove_single_point_streamlines():
sft = load_tractogram(in_short_sft, in_ref)

# Adding a one-point streamline
sft.streamlines.append([[7, 7, 7]])
new_sft = remove_single_point_streamlines(sft)
assert len(new_sft) == len(sft) - 1


def test_remove_overlapping_points_streamlines():
sft = load_tractogram(in_short_sft, in_ref)

fake_line = np.asarray([[3, 3, 3],
[4, 4, 4],
[5, 5, 5],
[5, 5, 5.00000001]], dtype=float)
sft.streamlines.append(fake_line)

new_sft = remove_overlapping_points_streamlines(sft)
assert len(new_sft.streamlines[-1]) == len(sft.streamlines[-1]) - 1
assert np.all([len(new_sft.streamlines[i]) == len(sft.streamlines[i]) for
i in range(len(sft) - 1)])


def test_filter_streamlines_by_length():
long_sft = load_tractogram(in_long_sft, in_ref)
mid_sft = load_tractogram(in_mid_sft, in_ref)
Expand Down Expand Up @@ -382,3 +409,19 @@ def test_parallel_transport_streamline():
decimal=4)
assert [len(s) for s in pt_streamlines] == [130] * 20
assert len(pt_streamlines) == 20


def test_remove_loops():
# toDO
# Coverage will not work: uses multi-processing
pass


def test_remove_sharp_turns_qb():
# toDO
pass


def test_remove_loops_and_sharp_turns():
# ok. Just a combination of the two previous functions.
pass

0 comments on commit b07cc06

Please sign in to comment.