21
21
22
22
def min_dist_to_centroid (bundle_pts , centroid_pts , nb_pts ):
23
23
"""
24
- Compute minimal distance to centroids
24
+ Compute minimal distance between two sets of 3D points.
25
+ The 3D points are expected to be in the same space.
26
+
27
+ Typically the bundle_pts will be voxel indices (from argwhere) and the
28
+ centroid_pts will be the 3D positions of a single streamline.
25
29
26
30
Parameters
27
31
----------
@@ -45,7 +49,27 @@ def min_dist_to_centroid(bundle_pts, centroid_pts, nb_pts):
45
49
46
50
47
51
def associate_labels (target_sft , min_label = 1 , max_label = 20 ):
48
- # DOCSTRING
52
+ """
53
+ Associate labels to the streamlines in a target SFT using their lengths.
54
+ Even if unequal distance between points, the labels are interpolated
55
+ linearly so all the points are labeled according to their position.
56
+
57
+ min and max labels are used in case there is a cut in the bundle.
58
+
59
+ Parameters:
60
+ ----------
61
+ target_sft: StatefulTractogram
62
+ The target SFT to label, streamlines can be in any space.
63
+ min_label: int
64
+ Minimum label to use.
65
+ max_label: int
66
+ Maximum label to use.
67
+
68
+ Returns:
69
+ -------
70
+ Array: np.uint16
71
+ Labels for each point along the streamlines.
72
+ """
49
73
50
74
curr_ind = 0
51
75
target_labels = np .zeros (target_sft .streamlines ._data .shape [0 ],
@@ -58,19 +82,28 @@ def associate_labels(target_sft, min_label=1, max_label=20):
58
82
curr_labels = np .round (curr_labels )
59
83
target_labels [curr_ind :curr_ind + len (streamline )] = curr_labels
60
84
curr_ind += len (streamline )
61
-
85
+
62
86
return target_labels , target_sft .streamlines ._data
63
87
64
88
65
89
def find_medoid (points , max_points = 10000 ):
66
90
"""
67
- Find the medoid among a set of points.
91
+ Find the medoid among a set of points. A medoid is a point that minimizes
92
+ the sum of the distances to all other points. Unlike a barycenter, the
93
+ medoid is guaranteed to be one of the points in the set.
68
94
69
95
Parameters:
70
- points (ndarray): Points in N-dimensional space.
96
+ ----------
97
+ points: ndarray
98
+ An array of 3D coordinates.
99
+ max_points: int
100
+ Maximum number of points to use for the computation (will randomly
101
+ select points if the number of points is greater than max_points).
71
102
72
103
Returns:
73
- ndarray: Coordinates of the medoid.
104
+ -------
105
+ np.array:
106
+ The 3D coordinates of the medoid.
74
107
"""
75
108
if len (points ) > max_points :
76
109
selected_indices = np .random .choice (len (points ), max_points ,
@@ -82,17 +115,26 @@ def find_medoid(points, max_points=10000):
82
115
return points [medoid_idx ]
83
116
84
117
85
- def compute_labels_map_barycenters (labels_map , is_euclidian = False , nb_pts = False ):
118
+ def compute_labels_map_barycenters (labels_map , is_euclidian = False ,
119
+ nb_pts = False ):
86
120
"""
87
121
Compute the barycenter for each label in a 3D NumPy array by maximizing
88
122
the distance to the boundary.
89
123
90
124
Parameters:
91
- labels_map (ndarray): The 3D array containing labels from 1-nb_pts.
125
+ ----------
126
+ labels_map: (ndarray)
127
+ The 3D array containing labels from 1-nb_pts.
92
128
euclidian (bool): If True, the barycenter is the mean of the points
129
+ in the mask. If False, the barycenter is the medoid of the points in
130
+ the mask.
131
+ nb_pts: int
132
+ Number of points to use for computing barycenters.
93
133
94
134
Returns:
95
- ndarray: An array of size (nb_pts, 3) containing the barycenter
135
+ -------
136
+ ndarray:
137
+ An array of size (nb_pts, 3) containing the barycenter
96
138
for each label.
97
139
"""
98
140
labels = np .arange (1 , nb_pts + 1 ) if nb_pts else np .unique (labels_map )[1 :]
@@ -127,11 +169,15 @@ def masked_manhattan_distance(mask, target_positions):
127
169
positions, without stepping out of the mask.
128
170
129
171
Parameters:
130
- mask (ndarray): A binary 3D array representing the mask.
172
+ ----------
173
+ mask (ndarray):
174
+ A binary 3D array representing the mask.
131
175
target_positions (list): A list of target positions within the mask.
132
176
133
177
Returns:
134
- ndarray: A 3D array of the same shape as the mask, containing the
178
+ -------
179
+ ndarray:
180
+ A 3D array of the same shape as the mask, containing the
135
181
Manhattan distances.
136
182
"""
137
183
# Initialize distance array with infinite values
@@ -172,6 +218,7 @@ def compute_distance_map(labels_map, binary_mask, nb_pts, use_manhattan=False):
172
218
Computes the distance map for each label in the labels_map.
173
219
174
220
Parameters:
221
+ ----------
175
222
labels_map (numpy.ndarray):
176
223
A 3D array representing the labels map.
177
224
binary_mask (numpy.ndarray):
@@ -182,6 +229,7 @@ def compute_distance_map(labels_map, binary_mask, nb_pts, use_manhattan=False):
182
229
If True, use the Manhattan distance instead of the Euclidian distance.
183
230
184
231
Returns:
232
+ -------
185
233
numpy.ndarray: A 3D array representing the distance map.
186
234
"""
187
235
barycenters = compute_labels_map_barycenters (labels_map ,
@@ -240,17 +288,24 @@ def compute_distance_map(labels_map, binary_mask, nb_pts, use_manhattan=False):
240
288
241
289
def correct_labels_jump (labels_map , streamlines , nb_pts ):
242
290
"""
243
- Computes the distance map for each label in the labels_map.
291
+ Correct the labels jump in the labels map by cutting the streamlines
292
+ where the jump is detected and keeping the longest chunk.
293
+
294
+ This avoid loops in the labels map and ensure that the labels are
295
+ consistent along the streamlines.
244
296
245
297
Parameters:
246
- labels_map (numpy.ndarray):
298
+ ----------
299
+ labels_map (ndarray):
247
300
A 3D array representing the labels map.
248
- streamlines:
301
+ streamlines (ArraySequence):
302
+ The streamlines used to compute the labels map.
249
303
nb_pts (int):
250
304
Number of points to use for computing barycenters.
251
305
252
306
Returns:
253
- numpy.ndarray: A 3D array representing the distance map.
307
+ -------
308
+ ndarray: A 3D array representing the corrected labels map.
254
309
"""
255
310
256
311
labels_data = ndi .map_coordinates (labels_map , streamlines ._data .T - 0.5 ,
@@ -280,7 +335,7 @@ def correct_labels_jump(labels_map, streamlines, nb_pts):
280
335
is_flip = True
281
336
282
337
# Find jumps, cut them and find the longest
283
- max_jump = max (nb_pts // 4 , 1 )
338
+ max_jump = max (nb_pts // 4 , 1 )
284
339
if len (np .argwhere (np .abs (gradient ) > max_jump )) > 0 :
285
340
pos_jump = np .where (np .abs (gradient ) > max_jump )[0 ] + 1
286
341
split_chunk = np .split (curr_labels ,
@@ -307,15 +362,15 @@ def correct_labels_jump(labels_map, streamlines, nb_pts):
307
362
modified_binary_mask = compute_tract_counts_map (final_streamlines ,
308
363
binary_mask .shape )
309
364
modified_binary_mask [modified_binary_mask > 0 ] = 1
310
-
365
+
311
366
# Compute the KDTree for the new streamlines to find the closest
312
367
# labels for each voxel
313
368
kd_tree = KDTree (final_streamlines ._data - 0.5 )
314
369
315
370
indices = np .array (np .nonzero (modified_binary_mask ), dtype = int ).T
316
371
labels_map = np .zeros (labels_map .shape , dtype = np .uint16 )
317
372
neighbor_ids = kd_tree .query_ball_point (indices , r = 1.0 )
318
-
373
+
319
374
for ind , neighbor_id in zip (indices , neighbor_ids ):
320
375
if len (neighbor_id ) == 0 :
321
376
continue
@@ -338,7 +393,40 @@ def correct_labels_jump(labels_map, streamlines, nb_pts):
338
393
def subdivide_bundles (sft , sft_centroid , binary_mask , nb_pts ,
339
394
method = 'centerline' , fix_jumps = True ):
340
395
"""
341
-
396
+ Function to divide a bundle into multiple section along its length.
397
+ The resulting labels map is based on the binary_mask, but the streamlines
398
+ are required for a few internal corrections (for consistency).
399
+
400
+ The default is to use the euclidian/centerline method, which is fast and
401
+ works well for most cases.
402
+
403
+ The hyperplane method allows for more complex shapes and to split the bundles
404
+ into subsections that follow the geometry of each kind of bundle.
405
+ However, this method is slower and requires extra quality control to ensure
406
+ that the labels are correct. This method requires a centroid file that
407
+ contains multiple streamlines.
408
+
409
+ Parameters:
410
+ ----------
411
+ sft (StatefulTractogram):
412
+ Represent the streamlines to be subdivided, streamlines representation
413
+ is useful fro the fix_jump parameter.
414
+ sft_centroid (StatefulTractogram):
415
+ Centroids used as a reference for subdivision.
416
+ binary_mask (ndarray):
417
+ Mask to be converted to a label mask
418
+ nb_pts (int):
419
+ Number of subdivision along streamlines' length
420
+ method (str):
421
+ Choice between centerline or hyperplane for subdivision
422
+ fix_jumps (bool):
423
+ Run the correction for streamlines to reduce big transition along
424
+ its length.
425
+
426
+ Returns:
427
+ -------
428
+ ndarray:
429
+ A 3D array representing the labels map.
342
430
"""
343
431
sft .to_vox ()
344
432
sft_centroid .to_vox ()
@@ -357,10 +445,10 @@ def subdivide_bundles(sft, sft_centroid, binary_mask, nb_pts,
357
445
358
446
indices = np .array (np .nonzero (binary_mask ), dtype = int ).T
359
447
labels = min_dist_to_centroid (indices ,
360
- sft_centroid [0 ].streamlines ._data ,
361
- nb_pts = nb_pts )
448
+ sft_centroid [0 ].streamlines ._data ,
449
+ nb_pts = nb_pts )
362
450
logging .debug ('Computed labels using the euclidian method '
363
- f'in { round (time .time () - timer , 3 )} seconds' )
451
+ f'in { round (time .time () - timer , 3 )} seconds' )
364
452
min_label , max_label = labels .min (), labels .max ()
365
453
366
454
if method == 'centerline' :
@@ -370,9 +458,10 @@ def subdivide_bundles(sft, sft_centroid, binary_mask, nb_pts,
370
458
min_label , max_label = labels .min (), labels .max ()
371
459
del labels , indices
372
460
logging .debug ('Computing Labels using the hyperplane method.\n '
373
- '\t This can take a while...' )
461
+ '\t This can take a while...' )
374
462
# Select 2000 elements from the SFTs to train the classifier
375
- streamlines_length = [length (streamline ) for streamline in sft .streamlines ]
463
+ streamlines_length = [length (streamline )
464
+ for streamline in sft .streamlines ]
376
465
random_indices = np .random .choice (len (sft .streamlines ), 2000 )
377
466
tmp_sft = resample_streamlines_step_size (
378
467
sft [random_indices ], np .min (streamlines_length ) / nb_pts )
@@ -389,7 +478,7 @@ def subdivide_bundles(sft, sft_centroid, binary_mask, nb_pts,
389
478
labels , points = labels [nn_indices ], points [nn_indices ]
390
479
391
480
logging .debug ('\t Associated labels to centroids in '
392
- f'{ round (time .time () - mini_timer , 3 )} seconds' )
481
+ f'{ round (time .time () - mini_timer , 3 )} seconds' )
393
482
394
483
# Initialize the scaler
395
484
mini_timer = time .time ()
@@ -401,7 +490,7 @@ def subdivide_bundles(sft, sft_centroid, binary_mask, nb_pts,
401
490
402
491
svc .fit (X = scaled_streamline_data , y = labels )
403
492
logging .debug ('\t SVC fit of training data in '
404
- f'{ round (time .time () - mini_timer , 3 )} seconds' )
493
+ f'{ round (time .time () - mini_timer , 3 )} seconds' )
405
494
406
495
# Scale the coordinates of the voxels
407
496
mini_timer = time .time ()
@@ -413,10 +502,10 @@ def subdivide_bundles(sft, sft_centroid, binary_mask, nb_pts,
413
502
# Predict the labels for the voxels
414
503
labels = svc .predict (X = scaled_voxel_coords )
415
504
logging .debug ('\t SVC prediction of labels in '
416
- f'{ round (time .time () - mini_timer , 3 )} seconds' )
505
+ f'{ round (time .time () - mini_timer , 3 )} seconds' )
417
506
418
507
logging .debug ('Computed labels using the hyperplane method '
419
- f'in { round (time .time () - timer , 3 )} seconds' )
508
+ f'in { round (time .time () - timer , 3 )} seconds' )
420
509
labels_map = np .zeros (binary_mask .shape , dtype = np .uint16 )
421
510
labels_map [np .where (masked_binary_mask )] = labels
422
511
@@ -434,16 +523,14 @@ def subdivide_bundles(sft, sft_centroid, binary_mask, nb_pts,
434
523
timer = time .time ()
435
524
tmp_sft = resample_streamlines_step_size (sft , 1.0 )
436
525
labels_map = correct_labels_jump (labels_map , tmp_sft .streamlines ,
437
- nb_pts - 2 )
526
+ nb_pts - 2 )
438
527
logging .debug ('Corrected labels jump in '
439
- f'{ round (time .time () - timer , 3 )} seconds' )
440
-
528
+ f'{ round (time .time () - timer , 3 )} seconds' )
441
529
442
530
if endpoints_extended :
443
531
labels_map [labels_map == nb_pts ] = nb_pts - 1
444
532
labels_map [labels_map == 1 ] = 2
445
533
labels_map [labels_map > 0 ] -= 1
446
534
nb_pts -= 2
447
535
448
-
449
536
return labels_map
0 commit comments