Skip to content

Commit c48d8cc

Browse files
author
frheault
committed
documentation
2 parents 3ec0991 + 23d65e5 commit c48d8cc

File tree

3 files changed

+148
-51
lines changed

3 files changed

+148
-51
lines changed

Diff for: scilpy/tractanalysis/distance_to_centroid.py

+119-32
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,11 @@
2121

2222
def min_dist_to_centroid(bundle_pts, centroid_pts, nb_pts):
2323
"""
24-
Compute minimal distance to centroids
24+
Compute minimal distance between two sets of 3D points.
25+
The 3D points are expected to be in the same space.
26+
27+
Typically the bundle_pts will be voxel indices (from argwhere) and the
28+
centroid_pts will be the 3D positions of a single streamline.
2529
2630
Parameters
2731
----------
@@ -45,7 +49,27 @@ def min_dist_to_centroid(bundle_pts, centroid_pts, nb_pts):
4549

4650

4751
def associate_labels(target_sft, min_label=1, max_label=20):
48-
# DOCSTRING
52+
"""
53+
Associate labels to the streamlines in a target SFT using their lengths.
54+
Even if unequal distance between points, the labels are interpolated
55+
linearly so all the points are labeled according to their position.
56+
57+
min and max labels are used in case there is a cut in the bundle.
58+
59+
Parameters:
60+
----------
61+
target_sft: StatefulTractogram
62+
The target SFT to label, streamlines can be in any space.
63+
min_label: int
64+
Minimum label to use.
65+
max_label: int
66+
Maximum label to use.
67+
68+
Returns:
69+
-------
70+
Array: np.uint16
71+
Labels for each point along the streamlines.
72+
"""
4973

5074
curr_ind = 0
5175
target_labels = np.zeros(target_sft.streamlines._data.shape[0],
@@ -58,19 +82,28 @@ def associate_labels(target_sft, min_label=1, max_label=20):
5882
curr_labels = np.round(curr_labels)
5983
target_labels[curr_ind:curr_ind+len(streamline)] = curr_labels
6084
curr_ind += len(streamline)
61-
85+
6286
return target_labels, target_sft.streamlines._data
6387

6488

6589
def find_medoid(points, max_points=10000):
6690
"""
67-
Find the medoid among a set of points.
91+
Find the medoid among a set of points. A medoid is a point that minimizes
92+
the sum of the distances to all other points. Unlike a barycenter, the
93+
medoid is guaranteed to be one of the points in the set.
6894
6995
Parameters:
70-
points (ndarray): Points in N-dimensional space.
96+
----------
97+
points: ndarray
98+
An array of 3D coordinates.
99+
max_points: int
100+
Maximum number of points to use for the computation (will randomly
101+
select points if the number of points is greater than max_points).
71102
72103
Returns:
73-
ndarray: Coordinates of the medoid.
104+
-------
105+
np.array:
106+
The 3D coordinates of the medoid.
74107
"""
75108
if len(points) > max_points:
76109
selected_indices = np.random.choice(len(points), max_points,
@@ -82,17 +115,26 @@ def find_medoid(points, max_points=10000):
82115
return points[medoid_idx]
83116

84117

85-
def compute_labels_map_barycenters(labels_map, is_euclidian=False, nb_pts=False):
118+
def compute_labels_map_barycenters(labels_map, is_euclidian=False,
119+
nb_pts=False):
86120
"""
87121
Compute the barycenter for each label in a 3D NumPy array by maximizing
88122
the distance to the boundary.
89123
90124
Parameters:
91-
labels_map (ndarray): The 3D array containing labels from 1-nb_pts.
125+
----------
126+
labels_map: (ndarray)
127+
The 3D array containing labels from 1-nb_pts.
92128
euclidian (bool): If True, the barycenter is the mean of the points
129+
in the mask. If False, the barycenter is the medoid of the points in
130+
the mask.
131+
nb_pts: int
132+
Number of points to use for computing barycenters.
93133
94134
Returns:
95-
ndarray: An array of size (nb_pts, 3) containing the barycenter
135+
-------
136+
ndarray:
137+
An array of size (nb_pts, 3) containing the barycenter
96138
for each label.
97139
"""
98140
labels = np.arange(1, nb_pts+1) if nb_pts else np.unique(labels_map)[1:]
@@ -127,11 +169,15 @@ def masked_manhattan_distance(mask, target_positions):
127169
positions, without stepping out of the mask.
128170
129171
Parameters:
130-
mask (ndarray): A binary 3D array representing the mask.
172+
----------
173+
mask (ndarray):
174+
A binary 3D array representing the mask.
131175
target_positions (list): A list of target positions within the mask.
132176
133177
Returns:
134-
ndarray: A 3D array of the same shape as the mask, containing the
178+
-------
179+
ndarray:
180+
A 3D array of the same shape as the mask, containing the
135181
Manhattan distances.
136182
"""
137183
# Initialize distance array with infinite values
@@ -172,6 +218,7 @@ def compute_distance_map(labels_map, binary_mask, nb_pts, use_manhattan=False):
172218
Computes the distance map for each label in the labels_map.
173219
174220
Parameters:
221+
----------
175222
labels_map (numpy.ndarray):
176223
A 3D array representing the labels map.
177224
binary_mask (numpy.ndarray):
@@ -182,6 +229,7 @@ def compute_distance_map(labels_map, binary_mask, nb_pts, use_manhattan=False):
182229
If True, use the Manhattan distance instead of the Euclidian distance.
183230
184231
Returns:
232+
-------
185233
numpy.ndarray: A 3D array representing the distance map.
186234
"""
187235
barycenters = compute_labels_map_barycenters(labels_map,
@@ -240,17 +288,24 @@ def compute_distance_map(labels_map, binary_mask, nb_pts, use_manhattan=False):
240288

241289
def correct_labels_jump(labels_map, streamlines, nb_pts):
242290
"""
243-
Computes the distance map for each label in the labels_map.
291+
Correct the labels jump in the labels map by cutting the streamlines
292+
where the jump is detected and keeping the longest chunk.
293+
294+
This avoid loops in the labels map and ensure that the labels are
295+
consistent along the streamlines.
244296
245297
Parameters:
246-
labels_map (numpy.ndarray):
298+
----------
299+
labels_map (ndarray):
247300
A 3D array representing the labels map.
248-
streamlines:
301+
streamlines (ArraySequence):
302+
The streamlines used to compute the labels map.
249303
nb_pts (int):
250304
Number of points to use for computing barycenters.
251305
252306
Returns:
253-
numpy.ndarray: A 3D array representing the distance map.
307+
-------
308+
ndarray: A 3D array representing the corrected labels map.
254309
"""
255310

256311
labels_data = ndi.map_coordinates(labels_map, streamlines._data.T - 0.5,
@@ -280,7 +335,7 @@ def correct_labels_jump(labels_map, streamlines, nb_pts):
280335
is_flip = True
281336

282337
# Find jumps, cut them and find the longest
283-
max_jump = max(nb_pts // 4 , 1)
338+
max_jump = max(nb_pts // 4, 1)
284339
if len(np.argwhere(np.abs(gradient) > max_jump)) > 0:
285340
pos_jump = np.where(np.abs(gradient) > max_jump)[0] + 1
286341
split_chunk = np.split(curr_labels,
@@ -307,15 +362,15 @@ def correct_labels_jump(labels_map, streamlines, nb_pts):
307362
modified_binary_mask = compute_tract_counts_map(final_streamlines,
308363
binary_mask.shape)
309364
modified_binary_mask[modified_binary_mask > 0] = 1
310-
365+
311366
# Compute the KDTree for the new streamlines to find the closest
312367
# labels for each voxel
313368
kd_tree = KDTree(final_streamlines._data - 0.5)
314369

315370
indices = np.array(np.nonzero(modified_binary_mask), dtype=int).T
316371
labels_map = np.zeros(labels_map.shape, dtype=np.uint16)
317372
neighbor_ids = kd_tree.query_ball_point(indices, r=1.0)
318-
373+
319374
for ind, neighbor_id in zip(indices, neighbor_ids):
320375
if len(neighbor_id) == 0:
321376
continue
@@ -338,7 +393,40 @@ def correct_labels_jump(labels_map, streamlines, nb_pts):
338393
def subdivide_bundles(sft, sft_centroid, binary_mask, nb_pts,
339394
method='centerline', fix_jumps=True):
340395
"""
341-
396+
Function to divide a bundle into multiple section along its length.
397+
The resulting labels map is based on the binary_mask, but the streamlines
398+
are required for a few internal corrections (for consistency).
399+
400+
The default is to use the euclidian/centerline method, which is fast and
401+
works well for most cases.
402+
403+
The hyperplane method allows for more complex shapes and to split the bundles
404+
into subsections that follow the geometry of each kind of bundle.
405+
However, this method is slower and requires extra quality control to ensure
406+
that the labels are correct. This method requires a centroid file that
407+
contains multiple streamlines.
408+
409+
Parameters:
410+
----------
411+
sft (StatefulTractogram):
412+
Represent the streamlines to be subdivided, streamlines representation
413+
is useful fro the fix_jump parameter.
414+
sft_centroid (StatefulTractogram):
415+
Centroids used as a reference for subdivision.
416+
binary_mask (ndarray):
417+
Mask to be converted to a label mask
418+
nb_pts (int):
419+
Number of subdivision along streamlines' length
420+
method (str):
421+
Choice between centerline or hyperplane for subdivision
422+
fix_jumps (bool):
423+
Run the correction for streamlines to reduce big transition along
424+
its length.
425+
426+
Returns:
427+
-------
428+
ndarray:
429+
A 3D array representing the labels map.
342430
"""
343431
sft.to_vox()
344432
sft_centroid.to_vox()
@@ -357,10 +445,10 @@ def subdivide_bundles(sft, sft_centroid, binary_mask, nb_pts,
357445

358446
indices = np.array(np.nonzero(binary_mask), dtype=int).T
359447
labels = min_dist_to_centroid(indices,
360-
sft_centroid[0].streamlines._data,
361-
nb_pts=nb_pts)
448+
sft_centroid[0].streamlines._data,
449+
nb_pts=nb_pts)
362450
logging.debug('Computed labels using the euclidian method '
363-
f'in {round(time.time() - timer, 3)} seconds')
451+
f'in {round(time.time() - timer, 3)} seconds')
364452
min_label, max_label = labels.min(), labels.max()
365453

366454
if method == 'centerline':
@@ -370,9 +458,10 @@ def subdivide_bundles(sft, sft_centroid, binary_mask, nb_pts,
370458
min_label, max_label = labels.min(), labels.max()
371459
del labels, indices
372460
logging.debug('Computing Labels using the hyperplane method.\n'
373-
'\tThis can take a while...')
461+
'\tThis can take a while...')
374462
# Select 2000 elements from the SFTs to train the classifier
375-
streamlines_length = [length(streamline) for streamline in sft.streamlines]
463+
streamlines_length = [length(streamline)
464+
for streamline in sft.streamlines]
376465
random_indices = np.random.choice(len(sft.streamlines), 2000)
377466
tmp_sft = resample_streamlines_step_size(
378467
sft[random_indices], np.min(streamlines_length) / nb_pts)
@@ -389,7 +478,7 @@ def subdivide_bundles(sft, sft_centroid, binary_mask, nb_pts,
389478
labels, points = labels[nn_indices], points[nn_indices]
390479

391480
logging.debug('\tAssociated labels to centroids in '
392-
f'{round(time.time() - mini_timer, 3)} seconds')
481+
f'{round(time.time() - mini_timer, 3)} seconds')
393482

394483
# Initialize the scaler
395484
mini_timer = time.time()
@@ -401,7 +490,7 @@ def subdivide_bundles(sft, sft_centroid, binary_mask, nb_pts,
401490

402491
svc.fit(X=scaled_streamline_data, y=labels)
403492
logging.debug('\tSVC fit of training data in '
404-
f'{round(time.time() - mini_timer, 3)} seconds')
493+
f'{round(time.time() - mini_timer, 3)} seconds')
405494

406495
# Scale the coordinates of the voxels
407496
mini_timer = time.time()
@@ -413,10 +502,10 @@ def subdivide_bundles(sft, sft_centroid, binary_mask, nb_pts,
413502
# Predict the labels for the voxels
414503
labels = svc.predict(X=scaled_voxel_coords)
415504
logging.debug('\tSVC prediction of labels in '
416-
f'{round(time.time() - mini_timer, 3)} seconds')
505+
f'{round(time.time() - mini_timer, 3)} seconds')
417506

418507
logging.debug('Computed labels using the hyperplane method '
419-
f'in {round(time.time() - timer, 3)} seconds')
508+
f'in {round(time.time() - timer, 3)} seconds')
420509
labels_map = np.zeros(binary_mask.shape, dtype=np.uint16)
421510
labels_map[np.where(masked_binary_mask)] = labels
422511

@@ -434,16 +523,14 @@ def subdivide_bundles(sft, sft_centroid, binary_mask, nb_pts,
434523
timer = time.time()
435524
tmp_sft = resample_streamlines_step_size(sft, 1.0)
436525
labels_map = correct_labels_jump(labels_map, tmp_sft.streamlines,
437-
nb_pts - 2)
526+
nb_pts - 2)
438527
logging.debug('Corrected labels jump in '
439-
f'{round(time.time() - timer, 3)} seconds')
440-
528+
f'{round(time.time() - timer, 3)} seconds')
441529

442530
if endpoints_extended:
443531
labels_map[labels_map == nb_pts] = nb_pts - 1
444532
labels_map[labels_map == 1] = 2
445533
labels_map[labels_map > 0] -= 1
446534
nb_pts -= 2
447535

448-
449536
return labels_map

0 commit comments

Comments
 (0)