From 756c38516e4ccd74d36d53cb90939def05d4d3c0 Mon Sep 17 00:00:00 2001 From: frheault Date: Mon, 30 Oct 2023 15:45:24 -0400 Subject: [PATCH 01/14] working version --- requirements.txt | 2 +- .../scil_compute_bundle_voxel_label_map.py | 182 +++++++++--------- 2 files changed, 97 insertions(+), 87 deletions(-) diff --git a/requirements.txt b/requirements.txt index 85b83b0a6..b4eb13dcf 100644 --- a/requirements.txt +++ b/requirements.txt @@ -44,4 +44,4 @@ h5py>=2.8.0 packaging>=19.0 tqdm>=4.30.0 --e git+https://github.com/scilus/hot_dipy@1.8.0.dev0#egg=dipy \ No newline at end of file +-e git+https://github.com/scilus/hot_dipy@1.8.0.dev0#egg=dipy diff --git a/scripts/scil_compute_bundle_voxel_label_map.py b/scripts/scil_compute_bundle_voxel_label_map.py index 70b81682d..973bb0db9 100755 --- a/scripts/scil_compute_bundle_voxel_label_map.py +++ b/scripts/scil_compute_bundle_voxel_label_map.py @@ -13,7 +13,7 @@ from dipy.align.streamlinear import StreamlineLinearRegistration from dipy.io.streamline import save_tractogram -from dipy.io.stateful_tractogram import StatefulTractogram, set_sft_logger_level +from dipy.io.stateful_tractogram import StatefulTractogram, set_sft_logger_level, Space from dipy.io.utils import is_header_compatible import matplotlib.pyplot as plt import nibabel as nib @@ -21,6 +21,7 @@ import numpy as np import scipy.ndimage as ndi from scipy.spatial import cKDTree +from scipy.ndimage import binary_erosion from scilpy.image.volume_math import correlation from scilpy.io.streamlines import load_tractogram_with_reference @@ -33,10 +34,13 @@ from scilpy.tractograms.streamline_and_mask_operations import \ cut_outside_of_mask_streamlines from scilpy.tractograms.streamline_operations import resample_streamlines_num_points +from scilpy.tractograms.streamline_and_mask_operations import \ + get_head_tail_density_maps from scilpy.utils.streamlines import uniformize_bundle_sft from scilpy.viz.utils import get_colormap + def _build_arg_parser(): p = argparse.ArgumentParser( description=__doc__, @@ -63,6 +67,53 @@ def _build_arg_parser(): return p +from sklearn.svm import SVC +from time import time + +from collections import defaultdict + +def compute_overlap_and_mapping(small_data, full_data): + """ + Compute the overlap between labels in the small and full datasets + and generate a mapping based on maximizing this overlap. 
+ + Parameters: + - small_data: np.ndarray, data from the smaller image + - full_data: np.ndarray, data from the full image + - unique_small: np.ndarray, unique labels in the smaller image + - unique_full: np.ndarray, unique labels in the full image + + Returns: + - dict, mapping from labels in the smaller image to labels in the full image + """ + mapping = {} + unique_small = np.unique(small_data) + unique_full = np.unique(full_data) + for label_small in unique_small: + if label_small == 0: + continue # Skip background + + overlaps = defaultdict(int) + mask_small = small_data == label_small + + for label_full in unique_full: + if label_full == 0: + continue # Skip background + + mask_full = full_data == label_full + overlap = np.sum(mask_small & mask_full) + + if overlap > 0: + overlaps[label_full] = overlap + + if overlaps: + best_match = max(overlaps, key=overlaps.get) + mapping[label_small] = best_match + else: + # If no overlap found, continue with default increasing +1 scheme + mapping[label_small] = label_small + 1 + + return mapping def main(): parser = _build_arg_parser() @@ -129,6 +180,8 @@ def main(): # Chop off some streamlines concat_sft = StatefulTractogram.from_sft([], sft_list[0]) + concat_sft.to_vox() + concat_sft.to_corner() for i in range(len(sft_list)): sft_list[i] = cut_outside_of_mask_streamlines(sft_list[i], binary_bundle) @@ -139,7 +192,8 @@ def main(): else args.nb_pts sft_centroid = resample_streamlines_num_points(sft_centroid, args.nb_pts) - tmp_sft = resample_streamlines_num_points(concat_sft, args.nb_pts) + uniformize_bundle_sft(concat_sft, ref_bundle=sft_centroid[0]) + tmp_sft = resample_streamlines_num_points(concat_sft[0:1000], args.nb_pts) if not args.new_labelling: new_streamlines = sft_centroid.streamlines.copy() @@ -151,88 +205,44 @@ def main(): moving=sft_centroid.streamlines) sft_centroid.streamlines = srm.transform(sft_centroid.streamlines) - uniformize_bundle_sft(concat_sft, ref_bundle=sft_centroid[0]) - labels, dists = min_dist_to_centroid(concat_sft.streamlines._data, - sft_centroid.streamlines._data, - args.nb_pts) - labels += 1 # 0 means no labels - - # It is not allowed that labels jumps labels for consistency - # Streamlines should have continous labels - final_streamlines = [] - final_label = [] - final_dists = [] - curr_ind = 0 - for i, streamline in enumerate(concat_sft.streamlines): - next_ind = curr_ind + len(streamline) - curr_labels = labels[curr_ind:next_ind] - curr_dists = dists[curr_ind:next_ind] - curr_ind = next_ind - - # Flip streamlines so the labels increase (facilitate if/else) - # Should always be ordered in nextflow pipeline - gradient = np.gradient(curr_labels) - if len(np.argwhere(gradient < 0)) > len(np.argwhere(gradient > 0)): - streamline = streamline[::-1] - curr_labels = curr_labels[::-1] - curr_dists = curr_dists[::-1] - - # # Find jumps, cut them and find the longest - gradient = np.ediff1d(curr_labels) - max_jump = max(args.nb_pts // 5, 1) - if len(np.argwhere(np.abs(gradient) > max_jump)) > 0: - pos_jump = np.where(np.abs(gradient) > max_jump)[0] + 1 - split_chunk = np.split(curr_labels, - pos_jump) - - max_len = 0 - max_pos = 0 - for j, chunk in enumerate(split_chunk): - if len(chunk) > max_len: - max_len = len(chunk) - max_pos = j - - curr_labels = split_chunk[max_pos] - gradient_chunk = np.ediff1d(chunk) - if len(np.unique(np.sign(gradient_chunk))) > 1: + t0 = time() + if not args.new_labelling: + labels, _ = min_dist_to_centroid(concat_sft.streamlines._data, + sft_centroid.streamlines._data, + args.nb_pts) + 
labels += 1 # 0 means no labels + labels_map = np.zeros(binary_bundle.shape, dtype=np.int16) + indices = np.array(np.nonzero(binary_bundle), dtype=int).T + + kd_tree = cKDTree(concat_sft.streamlines._data) + for ind in indices: + _, neighbor_ids = kd_tree.query(ind, k=5) + + if not len(neighbor_ids): continue - streamline = np.split(streamline, - pos_jump)[max_pos] - curr_dists = np.split(curr_dists, - pos_jump)[max_pos] - final_streamlines.append(streamline) - final_label.append(curr_labels) - final_dists.append(curr_dists) + labels_val = labels[neighbor_ids] - final_streamlines = ArraySequence(final_streamlines) - final_labels = ArraySequence(final_label) - final_dists = ArraySequence(final_dists) + vote = np.bincount(labels_val) + total = np.arange(np.amax(labels_val+1)) + winner = total[np.argmax(vote)] + labels_map[ind[0], ind[1], ind[2]] = winner - kd_tree = cKDTree(final_streamlines._data) - labels_map = np.zeros(binary_bundle.shape, dtype=np.int16) - distance_map = np.zeros(binary_bundle.shape, dtype=float) - indices = np.array(np.nonzero(binary_bundle), dtype=int).T - - for ind in indices: - _, neighbor_ids = kd_tree.query(ind, k=5) + else: + svc = SVC(C=1, kernel='rbf') + labels = np.tile(np.arange(0,args.nb_pts)[::-1], len(sft_centroid)) + labels += 1 + svc.fit(X=sft_centroid.streamlines._data, y=labels) - if not len(neighbor_ids): - continue + labels_pred = svc.predict(X=np.array(np.where(binary_bundle)).T) + labels_map = np.zeros(binary_bundle.shape, dtype=np.int16) + labels_map[np.where(binary_bundle)] = labels_pred - labels_val = final_labels._data[neighbor_ids] - dists_val = final_dists._data[neighbor_ids] - sum_dists_vox = np.sum(dists_val) - weights_vox = np.exp(-dists_val / sum_dists_vox) + print('a', time()-t0) - vote = np.bincount(labels_val, weights=weights_vox) - total = np.arange(np.amax(labels_val+1)) - winner = total[np.argmax(vote)] - labels_map[ind[0], ind[1], ind[2]] = winner - distance_map[ind[0], ind[1], ind[2]] = np.average(dists_val) - cmap = get_colormap(args.colormap) + cmap = get_colormap(args.colormap) for i, sft in enumerate(sft_list): if len(sft_list) > 1: sub_out_dir = os.path.join(args.out_dir, 'session_{}'.format(i+1)) @@ -246,9 +256,9 @@ def main(): nib.save(nib.Nifti1Image((binary_list[i]*labels_map).astype(np.uint16), sft_list[0].affine), os.path.join(sub_out_dir, 'labels_map.nii.gz')) - nib.save(nib.Nifti1Image(binary_list[i]*distance_map, - sft_list[0].affine), - os.path.join(sub_out_dir, 'distance_map.nii.gz')) + # nib.save(nib.Nifti1Image(binary_list[i]*distance_map, + # sft_list[0].affine), + # os.path.join(sub_out_dir, 'distance_map.nii.gz')) nib.save(nib.Nifti1Image(binary_list[i]*corr_map, sft_list[0].affine), os.path.join(sub_out_dir, 'correlation_map.nii.gz')) @@ -257,9 +267,9 @@ def main(): tmp_labels = ndi.map_coordinates(labels_map, sft.streamlines._data.T-0.5, order=0) - tmp_dists = ndi.map_coordinates(distance_map, - sft.streamlines._data.T-0.5, - order=0) + # tmp_dists = ndi.map_coordinates(distance_map, + # sft.streamlines._data.T-0.5, + # order=0) tmp_corr = ndi.map_coordinates(corr_map, sft.streamlines._data.T-0.5, order=0) @@ -273,11 +283,11 @@ def main(): save_tractogram(new_sft, os.path.join(sub_out_dir, 'labels.trk')) - if len(sft): - new_sft.data_per_point['color']._data = cmap( - tmp_dists / np.max(tmp_dists))[:, 0:3] * 255 - save_tractogram(new_sft, - os.path.join(sub_out_dir, 'distance.trk')) + # if len(sft): + # new_sft.data_per_point['color']._data = cmap( + # tmp_dists / np.max(tmp_dists))[:, 0:3] * 255 + # 
save_tractogram(new_sft, + # os.path.join(sub_out_dir, 'distance.trk')) if len(sft): new_sft.data_per_point['color']._data = cmap(tmp_corr)[ From db64932847ac9b3255c035b37aae786ae9e499bf Mon Sep 17 00:00:00 2001 From: frheault Date: Tue, 31 Oct 2023 11:37:24 -0400 Subject: [PATCH 02/14] Cleaner --- scilpy/tractanalysis/distance_to_centroid.py | 5 ++- .../scil_compute_bundle_voxel_label_map.py | 38 +++++++++---------- 2 files changed, 22 insertions(+), 21 deletions(-) diff --git a/scilpy/tractanalysis/distance_to_centroid.py b/scilpy/tractanalysis/distance_to_centroid.py index e27ef3e98..07df19898 100644 --- a/scilpy/tractanalysis/distance_to_centroid.py +++ b/scilpy/tractanalysis/distance_to_centroid.py @@ -7,8 +7,9 @@ def min_dist_to_centroid(bundle_pts, centroid_pts, nb_pts): tree = KDTree(centroid_pts, copy_data=True) dists, labels = tree.query(bundle_pts, k=1) - dists, labels = np.expand_dims( - dists, axis=1), np.expand_dims(labels, axis=1) + + dists = np.expand_dims(dists, axis=1) + labels = np.expand_dims(labels, axis=1) labels = np.mod(labels, nb_pts) diff --git a/scripts/scil_compute_bundle_voxel_label_map.py b/scripts/scil_compute_bundle_voxel_label_map.py index 973bb0db9..41b7b2712 100755 --- a/scripts/scil_compute_bundle_voxel_label_map.py +++ b/scripts/scil_compute_bundle_voxel_label_map.py @@ -207,37 +207,37 @@ def main(): t0 = time() if not args.new_labelling: - labels, _ = min_dist_to_centroid(concat_sft.streamlines._data, + indices = np.array(np.nonzero(binary_bundle), dtype=int).T + labels, _ = min_dist_to_centroid(indices, sft_centroid.streamlines._data, args.nb_pts) labels += 1 # 0 means no labels labels_map = np.zeros(binary_bundle.shape, dtype=np.int16) - indices = np.array(np.nonzero(binary_bundle), dtype=int).T - - kd_tree = cKDTree(concat_sft.streamlines._data) - for ind in indices: - _, neighbor_ids = kd_tree.query(ind, k=5) - - if not len(neighbor_ids): - continue - - labels_val = labels[neighbor_ids] - - vote = np.bincount(labels_val) - total = np.arange(np.amax(labels_val+1)) - winner = total[np.argmax(vote)] - labels_map[ind[0], ind[1], ind[2]] = winner - + labels_map[np.where(binary_bundle)] = labels else: svc = SVC(C=1, kernel='rbf') - labels = np.tile(np.arange(0,args.nb_pts)[::-1], len(sft_centroid)) + #def transfer_and_diffuse_labels(sft_source, sft_target): + + labels = np.tile(np.arange(0,args.nb_pts)[::-1], len(tmp_sft)) labels += 1 - svc.fit(X=sft_centroid.streamlines._data, y=labels) + svc.fit(X=tmp_sft.streamlines._data, y=labels) labels_pred = svc.predict(X=np.array(np.where(binary_bundle)).T) labels_map = np.zeros(binary_bundle.shape, dtype=np.int16) labels_map[np.where(binary_bundle)] = labels_pred + exp_labels = np.tile(np.arange(0,args.nb_pts)[::-1], len(sft_centroid)) + exp_labels += 1 + svc.fit(X=sft_centroid.streamlines._data, y=exp_labels) + + exp_labels_pred = svc.predict(X=np.array(np.where(binary_bundle)).T) + exp_labels_map = np.zeros(binary_bundle.shape, dtype=np.int16) + exp_labels_map[np.where(binary_bundle)] = exp_labels_pred + + mapping = compute_overlap_and_mapping(labels_map, exp_labels_map) + print(mapping) + + print('a', time()-t0) From 71840d9a1f00049a73ed4ea48200f2c7b715e784 Mon Sep 17 00:00:00 2001 From: frheault Date: Wed, 1 Nov 2023 10:44:06 -0400 Subject: [PATCH 03/14] working distance --- scilpy/tractanalysis/distance_to_centroid.py | 148 ++++++++-- .../scil_compute_bundle_voxel_label_map.py | 255 +++++++++++------- 2 files changed, 285 insertions(+), 118 deletions(-) diff --git 
a/scilpy/tractanalysis/distance_to_centroid.py b/scilpy/tractanalysis/distance_to_centroid.py
index 07df19898..98e066003 100644
--- a/scilpy/tractanalysis/distance_to_centroid.py
+++ b/scilpy/tractanalysis/distance_to_centroid.py
@@ -1,26 +1,144 @@
 # -*- coding: utf-8 -*-

 import numpy as np
+from scipy.ndimage import binary_dilation
 from scipy.spatial import KDTree


-def min_dist_to_centroid(bundle_pts, centroid_pts, nb_pts):
-    tree = KDTree(centroid_pts, copy_data=True)
-    dists, labels = tree.query(bundle_pts, k=1)
+def transfer_and_diffuse_labels(target_sft, source_sft, nb_pts=20):
+    tree = KDTree(source_sft.streamlines._data, copy_data=True)
+    pts_ids = tree.query_ball_point(target_sft.streamlines._data, r=4)

-    dists = np.expand_dims(dists, axis=1)
-    labels = np.expand_dims(labels, axis=1)
+    max_count_labels = []
+    for pts_id in pts_ids:
+        if not pts_id:  # If no source point is close enough
+            max_count_labels.append(-1)  # -1 wraps to 65535 once cast to uint16
+            continue
+        labels = np.mod(pts_id, nb_pts) + 1
+        unique_labels, counts = np.unique(labels, return_counts=True)
+        max_count_label = unique_labels[np.argmax(counts)]
+        max_count_labels.append(max_count_label)
+
+    labels = np.array(max_count_labels, dtype=np.uint16)

-    labels = np.mod(labels, nb_pts)
+    curr_ind = 0
+    for _, streamline in enumerate(target_sft.streamlines):
+        next_ind = curr_ind + len(streamline)
+        curr_labels = labels[curr_ind:next_ind]
+        labels[curr_ind:next_ind] = diffuse_labels(streamline, curr_labels)
+        curr_ind = next_ind

-    sum_dist = np.expand_dims(np.sum(dists, axis=1), axis=1)
-    weights = np.exp(-dists / sum_dist)
+    return labels

-    votes = []
-    for i in range(len(bundle_pts)):
-        vote = np.bincount(labels[i], weights=weights[i])
-        total = np.arange(np.amax(labels[i])+1)
-        winner = total[np.argmax(vote)]
-        votes.append(winner)
-    return np.array(votes, dtype=np.uint16), np.average(dists, axis=1)
+
+
+def min_dist_to_centroid(target_pts, source_pts, nb_pts=None,
+                         pre_computed_labels=None):
+    if nb_pts is None and pre_computed_labels is None:
+        raise ValueError('Either nb_pts or labels must be provided.')
+
+    tree = KDTree(source_pts, copy_data=True)
+    _, labels = tree.query(target_pts, k=1)
+
+    if pre_computed_labels is None:
+        labels = np.mod(labels, nb_pts) + 1
+    else:
+        labels = pre_computed_labels[labels]
+
+    return labels.astype(np.uint16)
+
+
+def diffuse_labels(streamline, labels):
+    """
+    Replace unlabeled points (65535, i.e. -1 cast to uint16) in the polyline
+    using a diffusion algorithm.
+
+    Parameters:
+    streamline (ndarray): Coordinates of the polyline.
+    labels (ndarray): Labels corresponding to the points in the polyline.
+
+    Returns:
+    ndarray: Updated labels with unlabeled points replaced.
+    """
+    iteration = 0
+    while np.any(labels == 65535):  # Continue until no unlabeled point is left
+        for i, label in enumerate(labels):
+            if label == 65535:
+                # Find the closest point carrying a valid label
+                min_distance = np.inf
+                closest_label = -1
+                for j, other_label in enumerate(labels):
+                    if other_label != 65535:
+                        distance = np.linalg.norm(streamline[i]-streamline[j])
+                        if distance < min_distance:
+                            min_distance = distance
+                            closest_label = other_label
+                # Update the label; fall back to 1 if no labeled point exists
+                # or if diffusion stalls, so the loop always terminates
+                if iteration > 10 or closest_label == -1:
+                    labels[i] = 1
+                else:
+                    labels[i] = closest_label
+        iteration += 1
+    return labels
+
+from scipy.spatial.distance import pdist, squareform
+
+def find_medoid(points):
+    """
+    Find the medoid among a set of points.
+
+    Parameters:
+    points (ndarray): Points in N-dimensional space.
+
+    Returns:
+    ndarray: Coordinates of the medoid.
+ """ + distance_matrix = squareform(pdist(points)) + medoid_idx = np.argmin(distance_matrix.sum(axis=1)) + return points[medoid_idx] + + +def compute_shell_barycenters(labels_map): + """ + Compute the barycenter for each label in a 3D NumPy array by maximizing + the distance to the boundary. + + Parameters: + labels_map (ndarray): The 3D array containing labels from 1-nb_pts. + + Returns: + ndarray: An array of size (nb_pts, 3) containing the barycenter + for each label. + """ + labels = np.unique(labels_map)[1:] + barycenters = np.zeros((len(labels), 3)) + + for label in labels: + mask = np.zeros_like(labels_map) + mask[labels_map == label] = 1 + mask_coords = np.argwhere(mask) + + barycenter = find_medoid(mask_coords) + barycenters[label - 1] = barycenter + + return barycenters + + +def compute_euclidean_barycenters(labels_map): + """ + Compute the euclidean barycenter for each label in a 3D NumPy array. + + Parameters: + labels_map (ndarray): The 3D array containing labels from 1-nb_pts. + + Returns: + ndarray: A NumPy array of shape (nb_pts, 3) containing the barycenter + for each label. + """ + labels = np.unique(labels_map)[1:] + barycenters = np.zeros((len(labels), 3)) + + for label in labels: + indices = np.argwhere(labels_map == label) + if indices.size > 0: + barycenter = np.mean(indices, axis=0) + barycenters[label-1, :] = barycenter + + return barycenters diff --git a/scripts/scil_compute_bundle_voxel_label_map.py b/scripts/scil_compute_bundle_voxel_label_map.py index 41b7b2712..72c773537 100755 --- a/scripts/scil_compute_bundle_voxel_label_map.py +++ b/scripts/scil_compute_bundle_voxel_label_map.py @@ -8,12 +8,15 @@ The number of labels will be the same as the centroid's number of points. """ +from sklearn.svm import SVC +from collections import defaultdict +from time import time import argparse import os -from dipy.align.streamlinear import StreamlineLinearRegistration +# from dipy.align.streamlinear import BundleMinDistanceMetric, StreamlineLinearRegistration from dipy.io.streamline import save_tractogram -from dipy.io.stateful_tractogram import StatefulTractogram, set_sft_logger_level, Space +from dipy.io.stateful_tractogram import StatefulTractogram, set_sft_logger_level, Space, Origin from dipy.io.utils import is_header_compatible import matplotlib.pyplot as plt import nibabel as nib @@ -30,7 +33,10 @@ assert_inputs_exist, assert_output_dirs_exist_and_empty) from scilpy.tractanalysis.streamlines_metrics import compute_tract_counts_map -from scilpy.tractanalysis.distance_to_centroid import min_dist_to_centroid +from scilpy.tractanalysis.distance_to_centroid import (min_dist_to_centroid, + compute_euclidean_barycenters, + compute_shell_barycenters, + transfer_and_diffuse_labels) from scilpy.tractograms.streamline_and_mask_operations import \ cut_outside_of_mask_streamlines from scilpy.tractograms.streamline_operations import resample_streamlines_num_points @@ -40,7 +46,6 @@ from scilpy.viz.utils import get_colormap - def _build_arg_parser(): p = argparse.ArgumentParser( description=__doc__, @@ -67,55 +72,9 @@ def _build_arg_parser(): return p -from sklearn.svm import SVC -from time import time - -from collections import defaultdict - -def compute_overlap_and_mapping(small_data, full_data): - """ - Compute the overlap between labels in the small and full datasets - and generate a mapping based on maximizing this overlap. 
- - Parameters: - - small_data: np.ndarray, data from the smaller image - - full_data: np.ndarray, data from the full image - - unique_small: np.ndarray, unique labels in the smaller image - - unique_full: np.ndarray, unique labels in the full image - - Returns: - - dict, mapping from labels in the smaller image to labels in the full image - """ - mapping = {} - unique_small = np.unique(small_data) - unique_full = np.unique(full_data) - for label_small in unique_small: - if label_small == 0: - continue # Skip background - - overlaps = defaultdict(int) - mask_small = small_data == label_small - - for label_full in unique_full: - if label_full == 0: - continue # Skip background - - mask_full = full_data == label_full - overlap = np.sum(mask_small & mask_full) - - if overlap > 0: - overlaps[label_full] = overlap - - if overlaps: - best_match = max(overlaps, key=overlaps.get) - mapping[label_small] = best_match - else: - # If no overlap found, continue with default increasing +1 scheme - mapping[label_small] = label_small + 1 - - return mapping def main(): + t0 = time() parser = _build_arg_parser() args = parser.parse_args() set_sft_logger_level('ERROR') @@ -143,7 +102,8 @@ def main(): if not is_header_compatible(sft_list[0], sft_list[-1]): parser.error('Header of {} and {} are not compatible'.format( args.in_bundles[0], filename)) - + print('Loading time', time()-t0) + t0 = time() density_list = [] binary_list = [] for sft in sft_list: @@ -187,60 +147,156 @@ def main(): binary_bundle) if len(sft_list[i]): concat_sft += sft_list[i] - + print('Chop time', time()-t0) + t0 = time() args.nb_pts = len(sft_centroid.streamlines[0]) if args.nb_pts is None \ else args.nb_pts sft_centroid = resample_streamlines_num_points(sft_centroid, args.nb_pts) uniformize_bundle_sft(concat_sft, ref_bundle=sft_centroid[0]) - tmp_sft = resample_streamlines_num_points(concat_sft[0:1000], args.nb_pts) - - if not args.new_labelling: - new_streamlines = sft_centroid.streamlines.copy() - sft_centroid = StatefulTractogram.from_sft([new_streamlines[0]], - sft_centroid) - else: - srr = StreamlineLinearRegistration() - srm = srr.optimize(static=tmp_sft.streamlines, - moving=sft_centroid.streamlines) - sft_centroid.streamlines = srm.transform(sft_centroid.streamlines) + tmp_sft = resample_streamlines_num_points(concat_sft[0:2500], args.nb_pts) + print('Uni+SLR time', time()-t0) + t0 = time() t0 = time() if not args.new_labelling: indices = np.array(np.nonzero(binary_bundle), dtype=int).T - labels, _ = min_dist_to_centroid(indices, - sft_centroid.streamlines._data, - args.nb_pts) - labels += 1 # 0 means no labels + labels = min_dist_to_centroid(indices, + sft_centroid[0].streamlines._data, + nb_pts=args.nb_pts) labels_map = np.zeros(binary_bundle.shape, dtype=np.int16) labels_map[np.where(binary_bundle)] = labels + barycenters = compute_euclidean_barycenters(labels_map) + nib.save(nib.Nifti1Image(labels_map, sft_list[0].affine), + os.path.join(args.out_dir, 'labels_map_1.nii.gz')) + labels = ndi.map_coordinates(labels_map, + concat_sft.streamlines._data.T-0.5, + order=0) + print('Euclidian time', time()-t0) + t0 = time() else: - svc = SVC(C=1, kernel='rbf') - #def transfer_and_diffuse_labels(sft_source, sft_target): - - labels = np.tile(np.arange(0,args.nb_pts)[::-1], len(tmp_sft)) - labels += 1 + svc = SVC(C=1, kernel='rbf', cache_size=1000) + labels = transfer_and_diffuse_labels(tmp_sft, sft_centroid) + print('Diffuse time', time()-t0) + t0 = time() svc.fit(X=tmp_sft.streamlines._data, y=labels) + print('Fit time', time()-t0) 
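# [Editor's note] Minimal, self-contained sketch of the SVC label propagation
# fitted just above, assuming scikit-learn as imported by this patch. All
# arrays here are hypothetical stand-ins, not part of the commit.
import numpy as np
from sklearn.svm import SVC

rng = np.random.default_rng(0)
pts = rng.random((200, 3)) * 20           # stand-in for streamline points
lbls = (pts[:, 0] // 4).astype(int) + 1   # stand-in section labels 1..5
clf = SVC(C=1, kernel='rbf').fit(pts, lbls)

mask = np.ones((10, 10, 10), dtype=bool)        # stand-in for the binary bundle
pred = clf.predict(np.array(np.where(mask)).T)  # one predicted label per voxel
# [End editor's note]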
+ t0 = time() + # print(exp_labels) - labels_pred = svc.predict(X=np.array(np.where(binary_bundle)).T) - labels_map = np.zeros(binary_bundle.shape, dtype=np.int16) - labels_map[np.where(binary_bundle)] = labels_pred - - exp_labels = np.tile(np.arange(0,args.nb_pts)[::-1], len(sft_centroid)) - exp_labels += 1 - svc.fit(X=sft_centroid.streamlines._data, y=exp_labels) + exp_labels = svc.predict(X=np.array(np.where(binary_bundle)).T) + print('Predict time', time()-t0) + t0 = time() - exp_labels_pred = svc.predict(X=np.array(np.where(binary_bundle)).T) exp_labels_map = np.zeros(binary_bundle.shape, dtype=np.int16) - exp_labels_map[np.where(binary_bundle)] = exp_labels_pred - - mapping = compute_overlap_and_mapping(labels_map, exp_labels_map) - print(mapping) - - - print('a', time()-t0) - - + exp_labels_map[np.where(binary_bundle)] = exp_labels + barycenters = compute_shell_barycenters(exp_labels_map) + exp_labels = svc.predict(X=barycenters) + + labels = ndi.map_coordinates(exp_labels_map, + concat_sft.streamlines._data.T-0.5, + order=0) + print('Map making time', time()-t0) + t0 = time() + + barycenter_sft = StatefulTractogram([barycenters], sft_centroid, + space=Space.VOX, origin=Origin.TRACKVIS) + + # It is not allowed that labels jumps labels for consistency + # Streamlines should have continous labels + final_streamlines = [] + final_label = [] + curr_ind = 0 + for i, streamline in enumerate(concat_sft.streamlines): + next_ind = curr_ind + len(streamline) + curr_labels = labels[curr_ind:next_ind] + curr_ind = next_ind + + # Flip streamlines so the labels increase (facilitate if/else) + # Should always be ordered in nextflow pipeline + gradient = np.gradient(curr_labels) + if len(np.argwhere(gradient < 0)) > len(np.argwhere(gradient > 0)): + streamline = streamline[::-1] + curr_labels = curr_labels[::-1] + + # # Find jumps, cut them and find the longest + gradient = np.ediff1d(curr_labels) + max_jump = 2 + if len(np.argwhere(np.abs(gradient) > max_jump)) > 0: + pos_jump = np.where(np.abs(gradient) > max_jump)[0] + 1 + split_chunk = np.split(curr_labels, + pos_jump) + + max_len = 0 + max_pos = 0 + for j, chunk in enumerate(split_chunk): + if len(chunk) > max_len: + max_len = len(chunk) + max_pos = j + + curr_labels = split_chunk[max_pos] + gradient_chunk = np.ediff1d(chunk) + if len(np.unique(np.sign(gradient_chunk))) > 1: + continue + streamline = np.split(streamline, + pos_jump)[max_pos] + + final_streamlines.append(streamline) + final_label.append(curr_labels) + + final_streamlines = ArraySequence(final_streamlines) + final_labels = ArraySequence(final_label) + + indices = np.array(np.nonzero(binary_bundle), dtype=int).T + labels = min_dist_to_centroid(indices, + final_streamlines._data, + pre_computed_labels=final_labels._data) + labels_map = np.zeros(binary_bundle.shape, dtype=np.int16) + labels_map[np.where(binary_bundle)] = labels + print('Clean Up time', time()-t0) + t0 = time() + dists = np.ones(binary_bundle.shape, dtype=float) * -1 + + save_tractogram(barycenter_sft, os.path.join(args.out_dir, + 'barycenters.trk')) + + import dijkstra3d + barycenter_bin = compute_tract_counts_map(barycenter_sft.streamlines, + barycenter_sft.dimensions) + barycenter_bin[barycenter_bin > 0] = 1 + # for label in range(1, args.nb_pts+1): + # indices = np.array(np.nonzero(labels_map == label), dtype=int).T + # field = np.ones(labels_map.shape, dtype=float) + # for ind in indices: + # ind = tuple(ind) + # + # path = dijkstra3d.dijkstra(field, barycenter, ind, compass=True) + # dists[ind] = len(path)-1 + 
for label in range(1, args.nb_pts+1): + mask = np.zeros(labels_map.shape) + mask[labels_map == label] = 1 + barycenter_bin_intersect = barycenter_bin * mask + barycenter_intersect_coords = np.array(np.nonzero(barycenter_bin_intersect), + dtype=int).T + bundle_disjoint, num_labels = ndi.label(mask) + iterations = 0 + + while num_labels > 1: + mask = ndi.binary_dilation(mask) + bundle_disjoint, num_labels = ndi.label(mask) + iterations += 1 + print('a', label, iterations, num_labels) + + barycenter = tuple(np.round(barycenters[label-1]).astype(int)) + print(label, labels_map[barycenter], barycenter) + curr_dists = dijkstra3d.distance_field(mask, + source=barycenter_intersect_coords) + dists[labels_map == label] = curr_dists[labels_map == label] + print(np.unique(curr_dists, return_counts=True)) + print() + + print('Dijkstra time', time()-t0) + t0 = time() cmap = get_colormap(args.colormap) for i, sft in enumerate(sft_list): @@ -256,20 +312,17 @@ def main(): nib.save(nib.Nifti1Image((binary_list[i]*labels_map).astype(np.uint16), sft_list[0].affine), os.path.join(sub_out_dir, 'labels_map.nii.gz')) - # nib.save(nib.Nifti1Image(binary_list[i]*distance_map, - # sft_list[0].affine), - # os.path.join(sub_out_dir, 'distance_map.nii.gz')) nib.save(nib.Nifti1Image(binary_list[i]*corr_map, sft_list[0].affine), os.path.join(sub_out_dir, 'correlation_map.nii.gz')) + nib.save(nib.Nifti1Image(binary_list[i]*dists, + sft_list[0].affine), + os.path.join(sub_out_dir, 'distance_map.nii.gz')) if len(sft): tmp_labels = ndi.map_coordinates(labels_map, sft.streamlines._data.T-0.5, order=0) - # tmp_dists = ndi.map_coordinates(distance_map, - # sft.streamlines._data.T-0.5, - # order=0) tmp_corr = ndi.map_coordinates(corr_map, sft.streamlines._data.T-0.5, order=0) @@ -283,17 +336,13 @@ def main(): save_tractogram(new_sft, os.path.join(sub_out_dir, 'labels.trk')) - # if len(sft): - # new_sft.data_per_point['color']._data = cmap( - # tmp_dists / np.max(tmp_dists))[:, 0:3] * 255 - # save_tractogram(new_sft, - # os.path.join(sub_out_dir, 'distance.trk')) - - if len(sft): + if len(sft) and len(args.in_bundles) > 1: new_sft.data_per_point['color']._data = cmap(tmp_corr)[ :, 0:3] * 255 save_tractogram(new_sft, os.path.join(sub_out_dir, 'correlation.trk')) + print('Finish time', time()-t0) + t0 = time() if __name__ == '__main__': From e84f08483c5ef6ba0cd7f1b8be79dad1d0167339 Mon Sep 17 00:00:00 2001 From: frheault Date: Thu, 2 Nov 2023 09:47:23 -0400 Subject: [PATCH 04/14] WOrking euclidian and hyperplane --- scilpy/tractanalysis/distance_to_centroid.py | 177 ++++++++-------- .../scil_compute_bundle_voxel_label_map.py | 189 ++++++++---------- 2 files changed, 183 insertions(+), 183 deletions(-) diff --git a/scilpy/tractanalysis/distance_to_centroid.py b/scilpy/tractanalysis/distance_to_centroid.py index 98e066003..da4a87ae8 100644 --- a/scilpy/tractanalysis/distance_to_centroid.py +++ b/scilpy/tractanalysis/distance_to_centroid.py @@ -1,35 +1,11 @@ # -*- coding: utf-8 -*- +import heapq + +from dipy.tracking.metrics import length import numpy as np -from scipy.ndimage import binary_dilation from scipy.spatial import KDTree - - -def transfer_and_diffuse_labels(target_sft, source_sft, nb_pts=20,): - tree = KDTree(source_sft.streamlines._data, copy_data=True) - pts_ids = tree.query_ball_point(target_sft.streamlines._data, r=4) - - max_count_labels = [] - for pts_id in pts_ids: - if not pts_id: # If no source point is close enough - max_count_labels.append(-1) - continue - labels = np.mod(pts_id, nb_pts) + 1 - unique_labels, 
counts = np.unique(labels, return_counts=True) - max_count_label = unique_labels[np.argmax(counts)] - max_count_labels.append(max_count_label) - - labels = np.array(max_count_labels, dtype=np.uint16) - - curr_ind = 0 - for _, streamline in enumerate(target_sft.streamlines): - next_ind = curr_ind + len(streamline) - curr_labels = labels[curr_ind:next_ind] - labels[curr_ind:next_ind] = diffuse_labels(streamline, curr_labels) - curr_ind = next_ind - - return labels - +from scipy.spatial.distance import pdist, squareform def min_dist_to_centroid(target_pts, source_pts, nb_pts=None, @@ -48,45 +24,52 @@ def min_dist_to_centroid(target_pts, source_pts, nb_pts=None, return labels.astype(np.uint16) -def diffuse_labels(streamline, labels): - """ - Replace -1 labels in the polyline using a diffusion algorithm. +def associate_labels(target_sft, source_sft, + nb_pts=20): + kdtree = KDTree(source_sft.streamlines._data) - Parameters: - streamline (ndarray): Coordinates of the polyline. - labels (ndarray): Labels corresponding to the points in the polyline. + # Initialize vote counters + head_votes = np.zeros(nb_pts, dtype=int) + tail_votes = np.zeros(nb_pts, dtype=int) - Returns: - ndarray: Updated labels with -1 replaced. - """ - iteration = 0 - while np.any(labels == 65535): # Continue until no -1 labels are left - for i, label in enumerate(labels): - if label == 65535: - # Find closest point with a non-negative label - min_distance = np.inf - closest_label = -1 - for j, other_label in enumerate(labels): - if other_label != 65535: - distance = np.linalg.norm(streamline[i]-streamline[j]) - if distance < min_distance: - min_distance = distance - closest_label = other_label - # Update the label - if iteration > 10: - labels[i] = 1 - labels[i] = closest_label - return labels + for streamline in target_sft.streamlines: + head = streamline[0] + tail = streamline[-1] + + # Find closest IDs in the target + closest_head_id = kdtree.query(head)[1] + closest_tail_id = kdtree.query(tail)[1] + + # Knowing the centroids are already labels correctly, their + # label is the modulo of the ID (based on nb_pts) + closest_head_label = np.mod(closest_head_id, nb_pts) + 1 + closest_tail_label = np.mod(closest_tail_id, nb_pts) + 1 + head_votes[closest_head_label - 1] += 1 + tail_votes[closest_tail_label - 1] += 1 + + # Trouver l'étiquette avec le plus de votes + most_voted_head = np.argmax(head_votes) + 1 + most_voted_tail = np.argmax(tail_votes) + 1 + + labels = [] + for i in range(len(target_sft)): + streamline = target_sft.streamlines[i] + lengths = np.insert(length(streamline, along=True), 0, 0) + lengths = (lengths / np.max(lengths)) * \ + (most_voted_tail - most_voted_head) + most_voted_head + + labels = np.concatenate((labels, lengths)) + + return labels.astype(np.uint16), most_voted_head, most_voted_tail -from scipy.spatial.distance import pdist, squareform def find_medoid(points): """ Find the medoid among a set of points. - + Parameters: points (ndarray): Points in N-dimensional space. - + Returns: ndarray: Coordinates of the medoid. """ @@ -95,50 +78,82 @@ def find_medoid(points): return points[medoid_idx] -def compute_shell_barycenters(labels_map): +def compute_labels_map_barycenters(labels_map, euclidian=False, nb_pts=False): """ Compute the barycenter for each label in a 3D NumPy array by maximizing the distance to the boundary. - + Parameters: labels_map (ndarray): The 3D array containing labels from 1-nb_pts. 
- + euclidian (bool): If True, the barycenter is the mean of the points + Returns: ndarray: An array of size (nb_pts, 3) containing the barycenter for each label. """ - labels = np.unique(labels_map)[1:] + labels = np.arange(1, nb_pts+1) if nb_pts else np.unique(labels_map)[1:] barycenters = np.zeros((len(labels), 3)) + barycenters[:] = np.NAN for label in labels: - mask = np.zeros_like(labels_map) - mask[labels_map == label] = 1 - mask_coords = np.argwhere(mask) + indices = np.argwhere(labels_map == label) + if indices.size > 0: + mask = np.zeros_like(labels_map) + mask[labels_map == label] = 1 + mask_coords = np.argwhere(mask) + + if euclidian: + barycenter = np.mean(mask_coords, axis=0) + else: + barycenter = find_medoid(mask_coords) + if labels_map[tuple(barycenter.astype(int))] != label: + tree = KDTree(indices) + _, ind = tree.query(barycenter, k=1) + barycenter = indices[ind] - barycenter = find_medoid(mask_coords) - barycenters[label - 1] = barycenter + barycenters[label - 1] = barycenter - return barycenters + return np.array(barycenters) -def compute_euclidean_barycenters(labels_map): +def masked_manhattan_distance(mask, target_positions): """ - Compute the euclidean barycenter for each label in a 3D NumPy array. + Compute the Manhattan distance from every position in a mask to a set of positions, + without stepping out of the mask. Parameters: - labels_map (ndarray): The 3D array containing labels from 1-nb_pts. + mask (ndarray): A binary 3D array representing the mask. + target_positions (list): A list of target positions within the mask. Returns: - ndarray: A NumPy array of shape (nb_pts, 3) containing the barycenter - for each label. + ndarray: A 3D array of the same shape as the mask, containing the Manhattan distances. """ - labels = np.unique(labels_map)[1:] - barycenters = np.zeros((len(labels), 3)) + # Initialize distance array with infinite values + distances = np.full(mask.shape, np.inf) - for label in labels: - indices = np.argwhere(labels_map == label) - if indices.size > 0: - barycenter = np.mean(indices, axis=0) - barycenters[label-1, :] = barycenter + # Initialize priority queue and set distance for target positions to zero + priority_queue = [] + for x, y, z in target_positions: + heapq.heappush(priority_queue, (0, (x, y, z))) + distances[x, y, z] = 0 + + # Directions for moving in the grid (Manhattan distance) + directions = [(0, 0, 1), (0, 0, -1), (0, 1, 0), + (0, -1, 0), (1, 0, 0), (-1, 0, 0)] + + while priority_queue: + current_distance, (x, y, z) = heapq.heappop(priority_queue) + + for dx, dy, dz in directions: + nx, ny, nz = x + dx, y + dy, z + dz + + if 0 <= nx < mask.shape[0] and 0 <= ny < mask.shape[1] and 0 <= nz < mask.shape[2]: + if mask[nx, ny, nz]: + new_distance = current_distance + 1 + + if new_distance < distances[nx, ny, nz]: + distances[nx, ny, nz] = new_distance + heapq.heappush( + priority_queue, (new_distance, (nx, ny, nz))) - return barycenters + return distances diff --git a/scripts/scil_compute_bundle_voxel_label_map.py b/scripts/scil_compute_bundle_voxel_label_map.py index 72c773537..c9a2fd512 100755 --- a/scripts/scil_compute_bundle_voxel_label_map.py +++ b/scripts/scil_compute_bundle_voxel_label_map.py @@ -23,8 +23,6 @@ from nibabel.streamlines.array_sequence import ArraySequence import numpy as np import scipy.ndimage as ndi -from scipy.spatial import cKDTree -from scipy.ndimage import binary_erosion from scilpy.image.volume_math import correlation from scilpy.io.streamlines import load_tractogram_with_reference @@ -34,14 +32,13 @@ 
assert_output_dirs_exist_and_empty) from scilpy.tractanalysis.streamlines_metrics import compute_tract_counts_map from scilpy.tractanalysis.distance_to_centroid import (min_dist_to_centroid, - compute_euclidean_barycenters, - compute_shell_barycenters, - transfer_and_diffuse_labels) + compute_labels_map_barycenters, + associate_labels, + masked_manhattan_distance) from scilpy.tractograms.streamline_and_mask_operations import \ cut_outside_of_mask_streamlines -from scilpy.tractograms.streamline_operations import resample_streamlines_num_points -from scilpy.tractograms.streamline_and_mask_operations import \ - get_head_tail_density_maps +from scilpy.tractograms.streamline_operations import \ + resample_streamlines_num_points, resample_streamlines_step_size from scilpy.utils.streamlines import uniformize_bundle_sft from scilpy.viz.utils import get_colormap @@ -153,54 +150,40 @@ def main(): else args.nb_pts sft_centroid = resample_streamlines_num_points(sft_centroid, args.nb_pts) - uniformize_bundle_sft(concat_sft, ref_bundle=sft_centroid[0]) - tmp_sft = resample_streamlines_num_points(concat_sft[0:2500], args.nb_pts) - print('Uni+SLR time', time()-t0) - t0 = time() + # Select 2000 elements from the SFTs + random_indices = np.random.choice(len(concat_sft), 2000, + replace=False) + tmp_sft = resample_streamlines_step_size(concat_sft[random_indices], 2.0) + t0 = time() if not args.new_labelling: indices = np.array(np.nonzero(binary_bundle), dtype=int).T labels = min_dist_to_centroid(indices, sft_centroid[0].streamlines._data, nb_pts=args.nb_pts) - labels_map = np.zeros(binary_bundle.shape, dtype=np.int16) + labels_map = np.zeros(binary_bundle.shape, dtype=np.uint16) labels_map[np.where(binary_bundle)] = labels - barycenters = compute_euclidean_barycenters(labels_map) - nib.save(nib.Nifti1Image(labels_map, sft_list[0].affine), - os.path.join(args.out_dir, 'labels_map_1.nii.gz')) labels = ndi.map_coordinates(labels_map, concat_sft.streamlines._data.T-0.5, order=0) - print('Euclidian time', time()-t0) - t0 = time() else: + # The head and tail are the labels (not the indices) + labels, _, _ = associate_labels(tmp_sft, sft_centroid, + args.nb_pts) + svc = SVC(C=1, kernel='rbf', cache_size=1000) - labels = transfer_and_diffuse_labels(tmp_sft, sft_centroid) - print('Diffuse time', time()-t0) - t0 = time() svc.fit(X=tmp_sft.streamlines._data, y=labels) - print('Fit time', time()-t0) - t0 = time() - # print(exp_labels) - exp_labels = svc.predict(X=np.array(np.where(binary_bundle)).T) - print('Predict time', time()-t0) - t0 = time() - exp_labels_map = np.zeros(binary_bundle.shape, dtype=np.int16) + exp_labels_map = np.zeros(binary_bundle.shape, dtype=np.uint16) exp_labels_map[np.where(binary_bundle)] = exp_labels - barycenters = compute_shell_barycenters(exp_labels_map) - exp_labels = svc.predict(X=barycenters) labels = ndi.map_coordinates(exp_labels_map, concat_sft.streamlines._data.T-0.5, order=0) - print('Map making time', time()-t0) - t0 = time() - - barycenter_sft = StatefulTractogram([barycenters], sft_centroid, - space=Space.VOX, origin=Origin.TRACKVIS) + print('Map making time', time()-t0) + t0 = time() # It is not allowed that labels jumps labels for consistency # Streamlines should have continous labels @@ -251,50 +234,57 @@ def main(): labels = min_dist_to_centroid(indices, final_streamlines._data, pre_computed_labels=final_labels._data) - labels_map = np.zeros(binary_bundle.shape, dtype=np.int16) + labels_map = np.zeros(binary_bundle.shape, dtype=np.uint16) labels_map[np.where(binary_bundle)] 
= labels print('Clean Up time', time()-t0) t0 = time() - dists = np.ones(binary_bundle.shape, dtype=float) * -1 - - save_tractogram(barycenter_sft, os.path.join(args.out_dir, - 'barycenters.trk')) - - import dijkstra3d - barycenter_bin = compute_tract_counts_map(barycenter_sft.streamlines, - barycenter_sft.dimensions) + + barycenters = compute_labels_map_barycenters(labels_map, + euclidian=args.new_labelling, + nb_pts=args.nb_pts) + + isnan = np.isnan(barycenters).all(axis=1) + # These two return the labels (not the indices) + head = np.argmax(~isnan) + 1 + tail = len(isnan) - np.argmax(~isnan[::-1]) + + distance_map = np.zeros(binary_bundle.shape, dtype=float) + barycenter_strs = [barycenters[head-1:tail]] + barycenter_bin = compute_tract_counts_map(barycenter_strs, + sft_centroid.dimensions) barycenter_bin[barycenter_bin > 0] = 1 - # for label in range(1, args.nb_pts+1): - # indices = np.array(np.nonzero(labels_map == label), dtype=int).T - # field = np.ones(labels_map.shape, dtype=float) - # for ind in indices: - # ind = tuple(ind) - # - # path = dijkstra3d.dijkstra(field, barycenter, ind, compass=True) - # dists[ind] = len(path)-1 - for label in range(1, args.nb_pts+1): + for label in range(head, tail+1): mask = np.zeros(labels_map.shape) mask[labels_map == label] = 1 + labels_coords = np.array(np.where(mask)).T + if labels_coords.size == 0: + continue + barycenter_bin_intersect = barycenter_bin * mask - barycenter_intersect_coords = np.array(np.nonzero(barycenter_bin_intersect), - dtype=int).T - bundle_disjoint, num_labels = ndi.label(mask) - iterations = 0 - - while num_labels > 1: - mask = ndi.binary_dilation(mask) + barycenter_intersect_coords = np.array( + np.nonzero(barycenter_bin_intersect), dtype=int).T + + if barycenter_intersect_coords.size == 0: + continue + + if not args.new_labelling: + distances = np.linalg.norm(barycenter_intersect_coords[:, np.newaxis] - + labels_coords, axis=-1) + distance_map[labels_map == label] = np.min(distances, axis=0) + + else: bundle_disjoint, num_labels = ndi.label(mask) - iterations += 1 - print('a', label, iterations, num_labels) + iterations = 0 - barycenter = tuple(np.round(barycenters[label-1]).astype(int)) - print(label, labels_map[barycenter], barycenter) - curr_dists = dijkstra3d.distance_field(mask, - source=barycenter_intersect_coords) - dists[labels_map == label] = curr_dists[labels_map == label] - print(np.unique(curr_dists, return_counts=True)) - print() + while num_labels > 1: + mask = ndi.binary_dilation(mask) + bundle_disjoint, num_labels = ndi.label(mask) + iterations += 1 + coords = [tuple(coord) for coord in barycenter_intersect_coords] + curr_dists = masked_manhattan_distance(mask, coords) + distance_map[labels_map == + label] = curr_dists[labels_map == label]+1 print('Dijkstra time', time()-t0) t0 = time() @@ -304,43 +294,38 @@ def main(): sub_out_dir = os.path.join(args.out_dir, 'session_{}'.format(i+1)) else: sub_out_dir = args.out_dir + new_sft = StatefulTractogram.from_sft(sft.streamlines, sft_list[0]) + new_sft.data_per_point['color'] = ArraySequence(new_sft.streamlines) if not os.path.isdir(sub_out_dir): os.mkdir(sub_out_dir) - # Save each session map if multiple inputs - nib.save(nib.Nifti1Image((binary_list[i]*labels_map).astype(np.uint16), - sft_list[0].affine), - os.path.join(sub_out_dir, 'labels_map.nii.gz')) - nib.save(nib.Nifti1Image(binary_list[i]*corr_map, - sft_list[0].affine), - os.path.join(sub_out_dir, 'correlation_map.nii.gz')) - nib.save(nib.Nifti1Image(binary_list[i]*dists, - sft_list[0].affine), - 
os.path.join(sub_out_dir, 'distance_map.nii.gz')) - - if len(sft): - tmp_labels = ndi.map_coordinates(labels_map, - sft.streamlines._data.T-0.5, - order=0) - tmp_corr = ndi.map_coordinates(corr_map, - sft.streamlines._data.T-0.5, - order=0) - cmap = plt.colormaps[args.colormap] - new_sft.data_per_point['color'] = ArraySequence( - new_sft.streamlines) - - # Nicer visualisation for MI-Brain - new_sft.data_per_point['color']._data = cmap( - tmp_labels / np.max(tmp_labels))[:, 0:3] * 255 - save_tractogram(new_sft, - os.path.join(sub_out_dir, 'labels.trk')) - - if len(sft) and len(args.in_bundles) > 1: - new_sft.data_per_point['color']._data = cmap(tmp_corr)[ - :, 0:3] * 255 - save_tractogram(new_sft, - os.path.join(sub_out_dir, 'correlation.trk')) + # Dictionary to hold the data for each type + data_dict = {'labels': labels_map.astype(np.uint16), + 'distance': distance_map.astype(float), + 'correlation': corr_map.astype(float)} + + # Iterate through each type to save the files + for basename, map in data_dict.items(): + nib.save(nib.Nifti1Image((binary_list[i] * map), sft_list[0].affine), + os.path.join(sub_out_dir, "{}_map.nii.gz".format(basename))) + + if basename == 'correlation' and len(args.in_bundles) == 1: + continue + + if len(sft): + tmp_data = ndi.map_coordinates( + map, sft.streamlines._data.T - 0.5, order=0) + + max_val = args.nb_pts if basename == 'labels' else np.max( + tmp_data) + new_sft.data_per_point['color']._data = cmap( + tmp_data / max_val)[:, 0:3] * 255 + + # Save the tractogram + save_tractogram(new_sft, + os.path.join(sub_out_dir, + "{}.trk".format(basename))) print('Finish time', time()-t0) t0 = time() From 1bafe3dee5b838e12ac72ba7ad174e2d4a9f448d Mon Sep 17 00:00:00 2001 From: frheault Date: Tue, 7 Nov 2023 12:53:50 -0500 Subject: [PATCH 05/14] Best fastest version ever --- scilpy/image/labels.py | 73 +++++++ scilpy/tractanalysis/distance_to_centroid.py | 109 +++++++--- scilpy/utils/streamlines.py | 29 ++- .../scil_compute_bundle_voxel_label_map.py | 201 ++++++------------ 4 files changed, 235 insertions(+), 177 deletions(-) diff --git a/scilpy/image/labels.py b/scilpy/image/labels.py index 4a48b5e82..f6dbb82a4 100644 --- a/scilpy/image/labels.py +++ b/scilpy/image/labels.py @@ -1,4 +1,6 @@ # -*- coding: utf-8 -*- +import time +from scipy.ndimage import generic_filter import inspect import logging import os @@ -295,3 +297,74 @@ def dilate_labels(data, vox_size, distance, nbr_processes, data = data.reshape(img_shape) return data + + +def weighted_vote_median_filter(labels, density): + """ + Apply a weighted median voting filter on a 3D array of labels using the density and + distance from the center voxel as weights. + + Parameters: + labels (numpy.ndarray): A 3D numpy array of labels. + density (numpy.ndarray): A 3D numpy array of the same shape as labels, representing density/probability. + + Returns: + numpy.ndarray: A 3D numpy array with the filtered labels. 
+    """
+    # Precompute the 3x3x3 distance kernel
+    kernel_size = 3
+    pad_width = kernel_size // 2
+
+    density = density.astype(float) / np.max(density)
+
+    # Generate distances for a 3x3x3 kernel
+    x, y, z = np.indices((kernel_size, kernel_size, kernel_size)) - pad_width
+    distances = np.sqrt(x**2 + y**2 + z**2)
+    weights_distance = 1 / (1 + distances)
+
+    # Pad the labels and density arrays
+    pad_width = kernel_size // 2
+    padded_labels = np.pad(labels, pad_width, mode='constant',
+                           constant_values=0)
+    padded_density = np.pad(density, pad_width, mode='constant',
+                            constant_values=0)
+
+    # Create an array to hold the new labels
+    new_labels = np.zeros_like(padded_labels)
+
+    # Iterate over the 3D indices of the original labels array
+    for ind in np.argwhere(padded_labels > 0):
+        x, y, z = ind
+        # Extract the 3x3x3 cube around the current voxel
+        cube_labels = padded_labels[x-pad_width:x + pad_width + 1,
+                                    y-pad_width:y + pad_width + 1,
+                                    z-pad_width:z + pad_width + 1]
+        cube_density = padded_density[x-pad_width:x + pad_width + 1,
+                                      y-pad_width:y + pad_width + 1,
+                                      z-pad_width:z + pad_width + 1]
+
+        # Compute weights for each label based on density and distance
+        weights = cube_density * weights_distance
+
+        # Flatten the cube arrays to use in weighted voting
+        flat_labels = cube_labels.flatten()
+        flat_weights = weights.flatten()
+
+        # Calculate the weighted count for each label
+        unique_labels = np.unique(flat_labels)
+        label_weights = np.zeros_like(unique_labels, dtype=float)
+        for i, label in enumerate(unique_labels):
+            label_weights[i] = np.sum(flat_weights[flat_labels == label])
+
+        # Select the label with the highest weighted count, but keep the
+        # original label when the winner is background or jumps by more
+        # than one section
+        selected_label = unique_labels[np.argmax(label_weights)]
+        if np.abs(float(selected_label) - float(padded_labels[x, y, z])) > 1 or \
+                selected_label == 0:
+            new_labels[x, y, z] = padded_labels[x, y, z]
+        else:
+            new_labels[x, y, z] = selected_label
+
+    return new_labels[pad_width:-pad_width,
+                      pad_width:-pad_width,
+                      pad_width:-pad_width]
diff --git a/scilpy/tractanalysis/distance_to_centroid.py b/scilpy/tractanalysis/distance_to_centroid.py
index da4a87ae8..23e4a33b6 100644
--- a/scilpy/tractanalysis/distance_to_centroid.py
+++ b/scilpy/tractanalysis/distance_to_centroid.py
@@ -24,43 +25,36 @@ def min_dist_to_centroid(target_pts, source_pts, nb_pts=None,
     return labels.astype(np.uint16)


-def associate_labels(target_sft, source_sft,
-                     nb_pts=20):
-    kdtree = KDTree(source_sft.streamlines._data)
+def associate_labels(target_sft, source_sft, nb_pts=20):
+    # KDTree for the target streamlines
+    target_kdtree = KDTree(target_sft.streamlines._data)

-    # Initialize vote counters
-    head_votes = np.zeros(nb_pts, dtype=int)
-    tail_votes = np.zeros(nb_pts, dtype=int)
+    distances, _ = target_kdtree.query(source_sft.streamlines._data, k=1,
+                                       distance_upper_bound=5)
+    valid_points = distances != np.inf

-    for streamline in target_sft.streamlines:
-        head = streamline[0]
-        tail = streamline[-1]
+    # Find the first and last indices of non-infinite values
+    if valid_points.any():
+        valid_points = np.mod(np.flatnonzero(valid_points), nb_pts)
+        labels, count = np.unique(valid_points, return_counts=True)
+        count = count / np.sum(count)
+        count[count < 1.0 / (nb_pts*1.5)] = np.NaN
+        valid_indices = np.where(~np.isnan(count))[0]

-        # Find closest IDs in the target
-
closest_head_id = kdtree.query(head)[1] - closest_tail_id = kdtree.query(tail)[1] - - # Knowing the centroids are already labels correctly, their - # label is the modulo of the ID (based on nb_pts) - closest_head_label = np.mod(closest_head_id, nb_pts) + 1 - closest_tail_label = np.mod(closest_tail_id, nb_pts) + 1 - head_votes[closest_head_label - 1] += 1 - tail_votes[closest_tail_label - 1] += 1 - - # Trouver l'étiquette avec le plus de votes - most_voted_head = np.argmax(head_votes) + 1 - most_voted_tail = np.argmax(tail_votes) + 1 + # Find the first and last non-NaN indices + head = labels[valid_indices[0]] + 1 + tail = labels[valid_indices[-1]] + 1 labels = [] for i in range(len(target_sft)): streamline = target_sft.streamlines[i] - lengths = np.insert(length(streamline, along=True), 0, 0) + lengths = np.insert(length(streamline, along=True), 0, 0)[::-1] lengths = (lengths / np.max(lengths)) * \ - (most_voted_tail - most_voted_head) + most_voted_head + (head - tail) + tail labels = np.concatenate((labels, lengths)) - - return labels.astype(np.uint16), most_voted_head, most_voted_tail + + return np.round(labels), head, tail def find_medoid(points): @@ -78,7 +72,7 @@ def find_medoid(points): return points[medoid_idx] -def compute_labels_map_barycenters(labels_map, euclidian=False, nb_pts=False): +def compute_labels_map_barycenters(labels_map, is_euclidian=False, nb_pts=False): """ Compute the barycenter for each label in a 3D NumPy array by maximizing the distance to the boundary. @@ -102,7 +96,7 @@ def compute_labels_map_barycenters(labels_map, euclidian=False, nb_pts=False): mask[labels_map == label] = 1 mask_coords = np.argwhere(mask) - if euclidian: + if is_euclidian: barycenter = np.mean(mask_coords, axis=0) else: barycenter = find_medoid(mask_coords) @@ -157,3 +151,60 @@ def masked_manhattan_distance(mask, target_positions): priority_queue, (new_distance, (nx, ny, nz))) return distances + + +import numpy as np +import scipy.ndimage as ndi +from scilpy.tractanalysis.streamlines_metrics import compute_tract_counts_map +def compute_distance_map(labels_map, binary_map, new_labelling, nb_pts): + """ + Computes the distance map for each label in the labels_map. + + Parameters: + labels_map (numpy.ndarray): A 3D array representing the labels map. + binary_map (numpy.ndarray): A 3D binary map used to calculate barycenter binary map. + new_labelling (bool): A flag to determine the type of distance calculation. + nb_pts (int): Number of points to use for computing barycenters. + + Returns: + numpy.ndarray: A 3D array representing the distance map. 
+ """ + barycenters = compute_labels_map_barycenters(labels_map, + is_euclidian=new_labelling, + nb_pts=nb_pts) + + isnan = np.isnan(barycenters).all(axis=1) + head = np.argmax(~isnan) + 1 + tail = len(isnan) - np.argmax(~isnan[::-1]) + + distance_map = np.zeros(binary_map.shape, dtype=float) + barycenter_strs = [barycenters[head-1:tail]] + barycenter_bin = compute_tract_counts_map(barycenter_strs, binary_map.shape) + barycenter_bin[barycenter_bin > 0] = 1 + + for label in range(head, tail+1): + mask = np.zeros(labels_map.shape) + mask[labels_map == label] = 1 + labels_coords = np.array(np.where(mask)).T + if labels_coords.size == 0: + continue + + barycenter_bin_intersect = barycenter_bin * mask + barycenter_intersect_coords = np.array(np.nonzero(barycenter_bin_intersect), + dtype=int).T + + if barycenter_intersect_coords.size == 0: + continue + + if not new_labelling: + distances = np.linalg.norm( + barycenter_intersect_coords[:, np.newaxis] - labels_coords, + axis=-1) + distance_map[labels_map == label] = np.min(distances, axis=0) + else: + coords = [tuple(coord) for coord in barycenter_intersect_coords] + curr_dists = masked_manhattan_distance(binary_map, coords) + distance_map[labels_map == label] = \ + curr_dists[labels_map == label] + + return distance_map diff --git a/scilpy/utils/streamlines.py b/scilpy/utils/streamlines.py index ccfe52fb9..bfcdfaccf 100644 --- a/scilpy/utils/streamlines.py +++ b/scilpy/utils/streamlines.py @@ -35,6 +35,8 @@ def uniformize_bundle_sft(sft, axis=None, ref_bundle=None, swap=False): sft.to_corner() density = get_endpoints_density_map(sft.streamlines, sft.dimensions, point_to_select=3) + sft.to_space(old_space) + sft.to_origin(old_origin) indices = np.argwhere(density > 0) kmeans = KMeans(n_clusters=2, random_state=0, copy_x=True, n_init=20).fit(indices) @@ -51,22 +53,33 @@ def uniformize_bundle_sft(sft, axis=None, ref_bundle=None, swap=False): axis_name = ['x', 'y', 'z'] if axis is None or ref_bundle is not None: if ref_bundle is not None: - ref_bundle.to_vox() - ref_bundle.to_corner() centroid = get_streamlines_centroid(ref_bundle.streamlines, 20)[0] else: centroid = get_streamlines_centroid(sft.streamlines, 20)[0] - main_dir_ends = np.argmax(np.abs(centroid[0] - centroid[-1])) - main_dir_displacement = np.argmax( - np.abs(np.sum(np.gradient(centroid, axis=0), axis=0))) + endpoints_distance = np.abs(centroid[0] - centroid[-1]) + total_travel = np.abs( + np.sum(np.gradient(centroid, axis=0), axis=0)) + + # Reweigth the distance to the endpoints and the total travel + # to avoid having a XYbundle oriented in the wrong direction + endpoints_distance[-1] *= 0.9 + total_travel[-1] *= 0.9 + main_dir_ends = np.argmax(endpoints_distance) + main_dir_displacement = np.argmax(total_travel) if main_dir_displacement != main_dir_ends \ or main_dir_displacement != main_dir_barycenter: logging.info('Ambiguity in orientation, you should use --axis') - axis = axis_name[main_dir_displacement] + + # Get the winner + winner = np.zeros(3) + winner[main_dir_displacement] += 1 + winner[main_dir_ends] += 1 + winner[main_dir_barycenter] += 1 + axis_pos = np.argmax(winner) + axis = axis_name[axis_pos] logging.info('Orienting endpoints in the {} axis'.format(axis)) - axis_pos = axis_name.index(axis) if bool(k_means_centers[0][axis_pos] > k_means_centers[1][axis_pos]) ^ bool(swap): @@ -95,8 +108,6 @@ def uniformize_bundle_sft(sft, axis=None, ref_bundle=None, swap=False): for key in sft.data_per_point[i]: sft.data_per_point[key][i] = \ sft.data_per_point[key][i][::-1] - 
sft.to_space(old_space) - sft.to_origin(old_origin) def uniformize_bundle_sft_using_mask(sft, mask, swap=False): diff --git a/scripts/scil_compute_bundle_voxel_label_map.py b/scripts/scil_compute_bundle_voxel_label_map.py index c9a2fd512..c9e2a0d29 100755 --- a/scripts/scil_compute_bundle_voxel_label_map.py +++ b/scripts/scil_compute_bundle_voxel_label_map.py @@ -8,6 +8,10 @@ The number of labels will be the same as the centroid's number of points. """ +from sklearn.preprocessing import MinMaxScaler +from sklearn.linear_model import SGDClassifier +from sklearn.svm import LinearSVC +from sklearn.kernel_approximation import RBFSampler from sklearn.svm import SVC from collections import defaultdict from time import time @@ -24,17 +28,18 @@ import numpy as np import scipy.ndimage as ndi +from scilpy.image.labels import weighted_vote_median_filter from scilpy.image.volume_math import correlation from scilpy.io.streamlines import load_tractogram_with_reference from scilpy.io.utils import (add_overwrite_arg, + add_processes_arg, add_reference_arg, assert_inputs_exist, assert_output_dirs_exist_and_empty) from scilpy.tractanalysis.streamlines_metrics import compute_tract_counts_map from scilpy.tractanalysis.distance_to_centroid import (min_dist_to_centroid, - compute_labels_map_barycenters, - associate_labels, - masked_manhattan_distance) + compute_distance_map, + associate_labels) from scilpy.tractograms.streamline_and_mask_operations import \ cut_outside_of_mask_streamlines from scilpy.tractograms.streamline_operations import \ @@ -63,7 +68,10 @@ def _build_arg_parser(): '[%(default)s].') p.add_argument('--new_labelling', action='store_true', help='Use the new labelling method (multi-centroids).') + p.add_argument('--skip_uniformize', action='store_true', + help='Skip uniformization of the bundles orientation.') + add_processes_arg(p) add_reference_arg(p) add_overwrite_arg(p) @@ -82,15 +90,14 @@ def main(): sft_centroid = load_tractogram_with_reference(parser, args, args.in_centroid) - sft_centroid.to_vox() - sft_centroid.to_corner() - sft_list = [] for filename in args.in_bundles: sft = load_tractogram_with_reference(parser, args, filename) if not len(sft.streamlines): raise IOError('Empty bundle file {}. ' 'Skipping'.format(args.in_bundle)) + if not args.skip_uniformize: + uniformize_bundle_sft(sft, ref_bundle=sft_centroid) sft.to_vox() sft.to_corner() sft_list.append(sft) @@ -99,7 +106,11 @@ def main(): if not is_header_compatible(sft_list[0], sft_list[-1]): parser.error('Header of {} and {} are not compatible'.format( args.in_bundles[0], filename)) - print('Loading time', time()-t0) + + # Perform after the uniformization + sft_centroid.to_vox() + sft_centroid.to_corner() + t0 = time() density_list = [] binary_list = [] @@ -123,15 +134,15 @@ def main(): # Slightly cut the bundle at the edgge to clean up single streamline voxels # with no neighbor. 
Remove isolated voxels to keep a single 'blob' - binary_bundle = np.zeros(corr_map.shape, dtype=bool) - binary_bundle[corr_map > 0.5] = 1 + binary_map = np.zeros(corr_map.shape, dtype=bool) + binary_map[corr_map > 0.5] = 1 - bundle_disjoint, _ = ndi.label(binary_bundle) + bundle_disjoint, _ = ndi.label(binary_map) unique, count = np.unique(bundle_disjoint, return_counts=True) val = unique[np.argmax(count[1:])+1] - binary_bundle[bundle_disjoint != val] = 0 + binary_map[bundle_disjoint != val] = 0 - corr_map = corr_map*binary_bundle + corr_map = corr_map*binary_map nib.save(nib.Nifti1Image(corr_map, sft_list[0].affine), os.path.join(args.out_dir, 'correlation_map.nii.gz')) @@ -141,10 +152,10 @@ def main(): concat_sft.to_corner() for i in range(len(sft_list)): sft_list[i] = cut_outside_of_mask_streamlines(sft_list[i], - binary_bundle) + binary_map) if len(sft_list[i]): concat_sft += sft_list[i] - print('Chop time', time()-t0) + t0 = time() args.nb_pts = len(sft_centroid.streamlines[0]) if args.nb_pts is None \ else args.nb_pts @@ -152,141 +163,55 @@ def main(): sft_centroid = resample_streamlines_num_points(sft_centroid, args.nb_pts) # Select 2000 elements from the SFTs - random_indices = np.random.choice(len(concat_sft), 2000, + random_indices = np.random.choice(len(concat_sft), + min(len(concat_sft), 2000), replace=False) tmp_sft = resample_streamlines_step_size(concat_sft[random_indices], 2.0) t0 = time() if not args.new_labelling: - indices = np.array(np.nonzero(binary_bundle), dtype=int).T + indices = np.array(np.nonzero(binary_map), dtype=int).T labels = min_dist_to_centroid(indices, sft_centroid[0].streamlines._data, nb_pts=args.nb_pts) - labels_map = np.zeros(binary_bundle.shape, dtype=np.uint16) - labels_map[np.where(binary_bundle)] = labels - labels = ndi.map_coordinates(labels_map, - concat_sft.streamlines._data.T-0.5, - order=0) else: - # The head and tail are the labels (not the indices) labels, _, _ = associate_labels(tmp_sft, sft_centroid, args.nb_pts) - svc = SVC(C=1, kernel='rbf', cache_size=1000) - svc.fit(X=tmp_sft.streamlines._data, y=labels) - exp_labels = svc.predict(X=np.array(np.where(binary_bundle)).T) - - exp_labels_map = np.zeros(binary_bundle.shape, dtype=np.uint16) - exp_labels_map[np.where(binary_bundle)] = exp_labels - - labels = ndi.map_coordinates(exp_labels_map, - concat_sft.streamlines._data.T-0.5, - order=0) - print('Map making time', time()-t0) - t0 = time() - - # It is not allowed that labels jumps labels for consistency - # Streamlines should have continous labels - final_streamlines = [] - final_label = [] - curr_ind = 0 - for i, streamline in enumerate(concat_sft.streamlines): - next_ind = curr_ind + len(streamline) - curr_labels = labels[curr_ind:next_ind] - curr_ind = next_ind - - # Flip streamlines so the labels increase (facilitate if/else) - # Should always be ordered in nextflow pipeline - gradient = np.gradient(curr_labels) - if len(np.argwhere(gradient < 0)) > len(np.argwhere(gradient > 0)): - streamline = streamline[::-1] - curr_labels = curr_labels[::-1] - - # # Find jumps, cut them and find the longest - gradient = np.ediff1d(curr_labels) - max_jump = 2 - if len(np.argwhere(np.abs(gradient) > max_jump)) > 0: - pos_jump = np.where(np.abs(gradient) > max_jump)[0] + 1 - split_chunk = np.split(curr_labels, - pos_jump) - - max_len = 0 - max_pos = 0 - for j, chunk in enumerate(split_chunk): - if len(chunk) > max_len: - max_len = len(chunk) - max_pos = j - - curr_labels = split_chunk[max_pos] - gradient_chunk = np.ediff1d(chunk) - if 
len(np.unique(np.sign(gradient_chunk))) > 1: - continue - streamline = np.split(streamline, - pos_jump)[max_pos] - - final_streamlines.append(streamline) - final_label.append(curr_labels) - - final_streamlines = ArraySequence(final_streamlines) - final_labels = ArraySequence(final_label) - - indices = np.array(np.nonzero(binary_bundle), dtype=int).T - labels = min_dist_to_centroid(indices, - final_streamlines._data, - pre_computed_labels=final_labels._data) - labels_map = np.zeros(binary_bundle.shape, dtype=np.uint16) - labels_map[np.where(binary_bundle)] = labels - print('Clean Up time', time()-t0) - t0 = time() - - barycenters = compute_labels_map_barycenters(labels_map, - euclidian=args.new_labelling, - nb_pts=args.nb_pts) - - isnan = np.isnan(barycenters).all(axis=1) - # These two return the labels (not the indices) - head = np.argmax(~isnan) + 1 - tail = len(isnan) - np.argmax(~isnan[::-1]) - - distance_map = np.zeros(binary_bundle.shape, dtype=float) - barycenter_strs = [barycenters[head-1:tail]] - barycenter_bin = compute_tract_counts_map(barycenter_strs, - sft_centroid.dimensions) - barycenter_bin[barycenter_bin > 0] = 1 - for label in range(head, tail+1): - mask = np.zeros(labels_map.shape) - mask[labels_map == label] = 1 - labels_coords = np.array(np.where(mask)).T - if labels_coords.size == 0: - continue - - barycenter_bin_intersect = barycenter_bin * mask - barycenter_intersect_coords = np.array( - np.nonzero(barycenter_bin_intersect), dtype=int).T - - if barycenter_intersect_coords.size == 0: - continue - - if not args.new_labelling: - distances = np.linalg.norm(barycenter_intersect_coords[:, np.newaxis] - - labels_coords, axis=-1) - distance_map[labels_map == label] = np.min(distances, axis=0) - - else: - bundle_disjoint, num_labels = ndi.label(mask) - iterations = 0 - - while num_labels > 1: - mask = ndi.binary_dilation(mask) - bundle_disjoint, num_labels = ndi.label(mask) - iterations += 1 - - coords = [tuple(coord) for coord in barycenter_intersect_coords] - curr_dists = masked_manhattan_distance(mask, coords) - distance_map[labels_map == - label] = curr_dists[labels_map == label]+1 - print('Dijkstra time', time()-t0) - t0 = time() + # Initialize the scaler and the RBF sampler + scaler = MinMaxScaler(feature_range=(-1, 1)) + rbf_feature = RBFSampler(gamma=1.0, n_components=500, random_state=1) + + # Fit the scaler to the streamline data and transform it + scaler.fit(tmp_sft.streamlines._data) + scaled_streamline_data = scaler.transform(tmp_sft.streamlines._data) + + # Fit the RBFSampler to the scaled data and transform it + rbf_feature.fit(scaled_streamline_data) + features = rbf_feature.transform(scaled_streamline_data) + + # Initialize and fit the SGDClassifier with log loss + sgd_clf = SGDClassifier(loss='log_loss', max_iter=10000, tol=1e-4, + alpha=0.0001, random_state=1, + n_jobs=min(args.nb_pts, args.nbr_processes)) + sgd_clf.fit(X=features, y=labels) + + # Scale the coordinates of the voxels and transform with RBFSampler + voxel_coords = np.array(np.where(binary_map)).T + scaled_voxel_coords = scaler.transform(voxel_coords) + transformed_voxel_coords = rbf_feature.transform(scaled_voxel_coords) + + # Predict the labels for the voxels + labels = sgd_clf.predict(X=transformed_voxel_coords) + ### + print('SVC time for {}'.format(args.out_dir), time()-t0) + + labels_map = np.zeros(binary_map.shape, dtype=np.uint16) + labels_map[np.where(binary_map)] = labels + density_map = np.sum(density_list, axis=0) * binary_map + labels_map = weighted_vote_median_filter(labels_map, 
density_map) + distance_map = compute_distance_map(labels_map, binary_map, + args.new_labelling, args.nb_pts) cmap = get_colormap(args.colormap) for i, sft in enumerate(sft_list): @@ -326,8 +251,6 @@ def main(): save_tractogram(new_sft, os.path.join(sub_out_dir, "{}.trk".format(basename))) - print('Finish time', time()-t0) - t0 = time() if __name__ == '__main__': From 30a99b271e233424a2b94c43ef92b4f91fbdaf51 Mon Sep 17 00:00:00 2001 From: frheault Date: Tue, 7 Nov 2023 14:30:33 -0500 Subject: [PATCH 06/14] ALl but cleaning --- scilpy/image/labels.py | 71 ------------------- scilpy/tractanalysis/distance_to_centroid.py | 51 ++++++------- .../scil_compute_bundle_voxel_label_map.py | 4 +- 3 files changed, 23 insertions(+), 103 deletions(-) diff --git a/scilpy/image/labels.py b/scilpy/image/labels.py index f6dbb82a4..5c160e5ed 100644 --- a/scilpy/image/labels.py +++ b/scilpy/image/labels.py @@ -297,74 +297,3 @@ def dilate_labels(data, vox_size, distance, nbr_processes, data = data.reshape(img_shape) return data - - -def weighted_vote_median_filter(labels, density): - """ - Apply a weighted median voting filter on a 3D array of labels using the density and - distance from the center voxel as weights. - - Parameters: - labels (numpy.ndarray): A 3D numpy array of labels. - density (numpy.ndarray): A 3D numpy array of the same shape as labels, representing density/probability. - - Returns: - numpy.ndarray: A 3D numpy array with the filtered labels. - """ - # Precompute the 3x3x3 distance kernel - kernel_size = 3 - pad_width = kernel_size // 2 - - density = density.astype(float) / np.max(density) - - # Generate distances for a 3x3x3 kernel - x, y, z = np.indices((kernel_size, kernel_size, kernel_size)) - pad_width - distances = np.sqrt(x**2 + y**2 + z**2) - weights_distance = 1 / (1 + distances) - - # Pad the labels and density arrays - pad_width = kernel_size // 2 - padded_labels = np.pad(labels, pad_width, mode='constant', - constant_values=0) - padded_density = np.pad(density, pad_width, mode='constant', - constant_values=0) - - # Create an array to hold the new labels - new_labels = np.zeros_like(padded_labels) - - # Iterate over the 3D indices of the original labels array - for ind in np.argwhere(padded_labels > 0): - x, y, z = ind - # Extract the 3x3x3 cube around the current voxel - cube_labels = padded_labels[x-pad_width:x + pad_width + 1, - y-pad_width:y + pad_width + 1, - z-pad_width:z + pad_width + 1] - cube_density = padded_density[x-pad_width:x + pad_width + 1, - y-pad_width:y + pad_width + 1, - z-pad_width:z + pad_width + 1] - - # Compute weights for each label based on density and distance - weights = cube_density * weights_distance - - # Flatten the cube arrays to use in weighted voting - flat_labels = cube_labels.flatten() - flat_weights = weights.flatten() - - # Calculate the weighted count for each label - unique_labels = np.unique(flat_labels) - label_weights = np.zeros_like(unique_labels, dtype=float) - for i, label in enumerate(unique_labels): - label_weights[i] = np.sum(flat_weights[flat_labels == label]) - - # Select the label with the highest weighted count - selected_label = unique_labels[np.argmax(label_weights)] - if np.abs(float(selected_label) - float(padded_labels[x, y, z])) > 1 or \ - selected_label == 0: - new_labels[x, y, z] = padded_labels[x, y, z] - else: - new_labels[x, y, z] = selected_label - new_labels[x, y, z] = unique_labels[np.argmax(label_weights)] - - return new_labels[pad_width:-pad_width, - pad_width:-pad_width, - pad_width:-pad_width] diff --git 
a/scilpy/tractanalysis/distance_to_centroid.py b/scilpy/tractanalysis/distance_to_centroid.py index 23e4a33b6..d99e6905a 100644 --- a/scilpy/tractanalysis/distance_to_centroid.py +++ b/scilpy/tractanalysis/distance_to_centroid.py @@ -1,5 +1,7 @@ # -*- coding: utf-8 -*- +from scilpy.tractanalysis.streamlines_metrics import compute_tract_counts_map +import scipy.ndimage as ndi from nibabel.streamlines.array_sequence import ArraySequence import heapq @@ -27,34 +29,27 @@ def min_dist_to_centroid(target_pts, source_pts, nb_pts=None, def associate_labels(target_sft, source_sft, nb_pts=20): # KDTree for the target streamlines - target_kdtree = KDTree(target_sft.streamlines._data) - - distances, _ = target_kdtree.query(source_sft.streamlines._data, k=1, - distance_upper_bound=5) - valid_points = distances != np.inf - - # Find the first and last indices of non-infinite values - if valid_points.any(): - valid_points = np.mod(np.flatnonzero(valid_points), nb_pts) - labels, count = np.unique(valid_points, return_counts=True) - count = count / np.sum(count) - count[count < 1.0 / (nb_pts*1.5)] = np.NaN - valid_indices = np.where(~np.isnan(count))[0] - - # Find the first and last non-NaN indices - head = labels[valid_indices[0]] + 1 - tail = labels[valid_indices[-1]] + 1 - - labels = [] - for i in range(len(target_sft)): - streamline = target_sft.streamlines[i] + source_kdtree = KDTree(source_sft.streamlines._data) + final_labels = np.zeros(target_sft.streamlines._data.shape[0], + dtype=float) + curr_ind = 0 + for streamline in target_sft.streamlines: + distances, ids = source_kdtree.query(streamline, k=1) + + valid_points = distances != np.inf + + curr_labels = np.mod(ids[valid_points], nb_pts) + 1 + + head = np.min(curr_labels) + tail = np.max(curr_labels) + lengths = np.insert(length(streamline, along=True), 0, 0)[::-1] lengths = (lengths / np.max(lengths)) * \ (head - tail) + tail + final_labels[curr_ind:curr_ind+len(lengths)] = lengths + curr_ind += len(lengths) - labels = np.concatenate((labels, lengths)) - - return np.round(labels), head, tail + return np.round(final_labels), head, tail def find_medoid(points): @@ -153,9 +148,6 @@ def masked_manhattan_distance(mask, target_positions): return distances -import numpy as np -import scipy.ndimage as ndi -from scilpy.tractanalysis.streamlines_metrics import compute_tract_counts_map def compute_distance_map(labels_map, binary_map, new_labelling, nb_pts): """ Computes the distance map for each label in the labels_map. 
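Note: the masked Manhattan distance used by compute_distance_map is, at heart, a multi-source breadth-first search constrained to the mask: every target voxel starts at distance 0 and the front grows one 6-connected step at a time without ever leaving the mask. A minimal, self-contained sketch of that idea (not the scilpy implementation; `mask` is assumed to be a boolean 3D array and `targets` a list of voxel index tuples):

from collections import deque

import numpy as np


def masked_manhattan_sketch(mask, targets):
    # All step costs are 1, so a FIFO queue settles voxels in nondecreasing
    # distance; the heap used in the real implementation is interchangeable
    # with a deque here.
    dist = np.full(mask.shape, np.inf)
    queue = deque()
    for t in targets:
        dist[t] = 0
        queue.append(t)
    steps = [(1, 0, 0), (-1, 0, 0), (0, 1, 0),
             (0, -1, 0), (0, 0, 1), (0, 0, -1)]
    while queue:
        x, y, z = queue.popleft()
        for dx, dy, dz in steps:
            n = (x + dx, y + dy, z + dz)
            # Expand only to in-bounds, in-mask voxels not yet reached
            if all(0 <= n[i] < mask.shape[i] for i in range(3)) \
                    and mask[n] and dist[n] == np.inf:
                dist[n] = dist[x, y, z] + 1
                queue.append(n)
    return dist

Voxels of the mask that cannot be reached from any target keep an infinite distance, which makes disconnected regions easy to spot.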
@@ -179,7 +171,8 @@ def compute_distance_map(labels_map, binary_map, new_labelling, nb_pts): distance_map = np.zeros(binary_map.shape, dtype=float) barycenter_strs = [barycenters[head-1:tail]] - barycenter_bin = compute_tract_counts_map(barycenter_strs, binary_map.shape) + barycenter_bin = compute_tract_counts_map( + barycenter_strs, binary_map.shape) barycenter_bin[barycenter_bin > 0] = 1 for label in range(head, tail+1): @@ -206,5 +199,5 @@ def compute_distance_map(labels_map, binary_map, new_labelling, nb_pts): curr_dists = masked_manhattan_distance(binary_map, coords) distance_map[labels_map == label] = \ curr_dists[labels_map == label] - + return distance_map diff --git a/scripts/scil_compute_bundle_voxel_label_map.py b/scripts/scil_compute_bundle_voxel_label_map.py index c9e2a0d29..9b9c9bb12 100755 --- a/scripts/scil_compute_bundle_voxel_label_map.py +++ b/scripts/scil_compute_bundle_voxel_label_map.py @@ -203,13 +203,11 @@ def main(): # Predict the labels for the voxels labels = sgd_clf.predict(X=transformed_voxel_coords) - ### + print('SVC time for {}'.format(args.out_dir), time()-t0) labels_map = np.zeros(binary_map.shape, dtype=np.uint16) labels_map[np.where(binary_map)] = labels - density_map = np.sum(density_list, axis=0) * binary_map - labels_map = weighted_vote_median_filter(labels_map, density_map) distance_map = compute_distance_map(labels_map, binary_map, args.new_labelling, args.nb_pts) From e5bb8d4124c30b1461b97822a0389556689d1eb4 Mon Sep 17 00:00:00 2001 From: frheault Date: Tue, 7 Nov 2023 15:51:42 -0500 Subject: [PATCH 07/14] Ready to showcase version --- scilpy/image/labels.py | 2 - scilpy/tractanalysis/distance_to_centroid.py | 128 +++++++++++++++--- .../scil_compute_bundle_voxel_label_map.py | 38 ++++-- 3 files changed, 137 insertions(+), 31 deletions(-) diff --git a/scilpy/image/labels.py b/scilpy/image/labels.py index 5c160e5ed..4a48b5e82 100644 --- a/scilpy/image/labels.py +++ b/scilpy/image/labels.py @@ -1,6 +1,4 @@ # -*- coding: utf-8 -*- -import time -from scipy.ndimage import generic_filter import inspect import logging import os diff --git a/scilpy/tractanalysis/distance_to_centroid.py b/scilpy/tractanalysis/distance_to_centroid.py index d99e6905a..302387a51 100644 --- a/scilpy/tractanalysis/distance_to_centroid.py +++ b/scilpy/tractanalysis/distance_to_centroid.py @@ -1,15 +1,15 @@ # -*- coding: utf-8 -*- - -from scilpy.tractanalysis.streamlines_metrics import compute_tract_counts_map -import scipy.ndimage as ndi -from nibabel.streamlines.array_sequence import ArraySequence import heapq from dipy.tracking.metrics import length +from nibabel.streamlines.array_sequence import ArraySequence import numpy as np +import scipy.ndimage as ndi from scipy.spatial import KDTree from scipy.spatial.distance import pdist, squareform +from scilpy.tractanalysis.streamlines_metrics import compute_tract_counts_map + def min_dist_to_centroid(target_pts, source_pts, nb_pts=None, pre_computed_labels=None): @@ -34,7 +34,8 @@ def associate_labels(target_sft, source_sft, nb_pts=20): dtype=float) curr_ind = 0 for streamline in target_sft.streamlines: - distances, ids = source_kdtree.query(streamline, k=1) + distances, ids = source_kdtree.query(streamline, + k=max(1, nb_pts // 5)) valid_points = distances != np.inf @@ -46,6 +47,7 @@ def associate_labels(target_sft, source_sft, nb_pts=20): lengths = np.insert(length(streamline, along=True), 0, 0)[::-1] lengths = (lengths / np.max(lengths)) * \ (head - tail) + tail + final_labels[curr_ind:curr_ind+len(lengths)] = lengths curr_ind += 
len(lengths) @@ -107,15 +109,16 @@ def compute_labels_map_barycenters(labels_map, is_euclidian=False, nb_pts=False) def masked_manhattan_distance(mask, target_positions): """ - Compute the Manhattan distance from every position in a mask to a set of positions, - without stepping out of the mask. + Compute the Manhattan distance from every position in a mask to a set of + positions, without stepping out of the mask. Parameters: mask (ndarray): A binary 3D array representing the mask. target_positions (list): A list of target positions within the mask. Returns: - ndarray: A 3D array of the same shape as the mask, containing the Manhattan distances. + ndarray: A 3D array of the same shape as the mask, containing the + Manhattan distances. """ # Initialize distance array with infinite values distances = np.full(mask.shape, np.inf) @@ -136,7 +139,9 @@ def masked_manhattan_distance(mask, target_positions): for dx, dy, dz in directions: nx, ny, nz = x + dx, y + dy, z + dz - if 0 <= nx < mask.shape[0] and 0 <= ny < mask.shape[1] and 0 <= nz < mask.shape[2]: + if 0 <= nx < mask.shape[0] and \ + 0 <= ny < mask.shape[1] and \ + 0 <= nz < mask.shape[2]: if mask[nx, ny, nz]: new_distance = current_distance + 1 @@ -153,26 +158,42 @@ def compute_distance_map(labels_map, binary_map, new_labelling, nb_pts): Computes the distance map for each label in the labels_map. Parameters: - labels_map (numpy.ndarray): A 3D array representing the labels map. - binary_map (numpy.ndarray): A 3D binary map used to calculate barycenter binary map. - new_labelling (bool): A flag to determine the type of distance calculation. - nb_pts (int): Number of points to use for computing barycenters. + labels_map (numpy.ndarray): + A 3D array representing the labels map. + binary_map (numpy.ndarray): + A 3D binary map used to calculate barycenter binary map. + new_labelling (bool): + A flag to determine the type of distance calculation. + nb_pts (int): + Number of points to use for computing barycenters. Returns: - numpy.ndarray: A 3D array representing the distance map. + numpy.ndarray: A 3D array representing the distance map. 
""" barycenters = compute_labels_map_barycenters(labels_map, is_euclidian=new_labelling, nb_pts=nb_pts) - + # If the first/last few points are NaN, remove them this indicates that the + # head/tail are not 1-NB_PTS isnan = np.isnan(barycenters).all(axis=1) head = np.argmax(~isnan) + 1 tail = len(isnan) - np.argmax(~isnan[::-1]) + # Identify the indices that do contain NaN values after/before head/tail + tmp_barycenter = barycenters[head-1:tail] + valid_indices = np.argwhere( + ~np.isnan(tmp_barycenter).any(axis=1)).flatten() + valid_data = tmp_barycenter[valid_indices] + interpolated_data = np.array( + [np.interp(np.arange(len(tmp_barycenter)), + valid_indices, + valid_data[:, i]) for i in range(tmp_barycenter.shape[1])]).T + barycenters[head-1:tail] = interpolated_data + distance_map = np.zeros(binary_map.shape, dtype=float) barycenter_strs = [barycenters[head-1:tail]] - barycenter_bin = compute_tract_counts_map( - barycenter_strs, binary_map.shape) + barycenter_bin = compute_tract_counts_map(barycenter_strs, + binary_map.shape) barycenter_bin[barycenter_bin > 0] = 1 for label in range(head, tail+1): @@ -201,3 +222,76 @@ def compute_distance_map(labels_map, binary_map, new_labelling, nb_pts): curr_dists[labels_map == label] return distance_map + + +def correct_labels_jump(labels_map, streamlines, nb_pts): + labels_data = ndi.map_coordinates(labels_map, streamlines._data.T - 0.5, + order=0) + + # It is not allowed that labels jumps labels for consistency + # Streamlines should have continous labels + final_streamlines = [] + final_labels = [] + curr_ind = 0 + for streamline in streamlines: + next_ind = curr_ind + len(streamline) + curr_labels = labels_data[curr_ind:next_ind] + curr_ind = next_ind + + # Flip streamlines so the labels increase (facilitate if/else) + # Should always be ordered in nextflow pipeline + gradient = np.gradient(curr_labels) + if len(np.argwhere(gradient < 0)) > len(np.argwhere(gradient > 0)): + streamline = streamline[::-1] + curr_labels = curr_labels[::-1] + + # Find jumps, cut them and find the longest + gradient = np.ediff1d(curr_labels) + max_jump = max(nb_pts // 2, 1) + if len(np.argwhere(np.abs(gradient) > max_jump)) > 0: + pos_jump = np.where(np.abs(gradient) > max_jump)[0] + 1 + split_chunk = np.split(curr_labels, + pos_jump) + + max_len = 0 + max_pos = 0 + for j, chunk in enumerate(split_chunk): + if len(chunk) > max_len: + max_len = len(chunk) + max_pos = j + + curr_labels = split_chunk[max_pos] + gradient_chunk = np.ediff1d(chunk) + if len(np.unique(np.sign(gradient_chunk))) > 1: + continue + streamline = np.split(streamline, + pos_jump)[max_pos] + + final_streamlines.append(streamline) + final_labels.append(curr_labels) + + # Once the streamlines abnormalities are corrected, we can + # recompute the labels map with the new streamlines/labels + final_labels = ArraySequence(final_labels) + final_streamlines = ArraySequence(final_streamlines) + + kd_tree = KDTree(final_streamlines._data) + indices = np.array(np.nonzero(labels_map), dtype=int).T + labels_map = np.zeros(labels_map.shape, dtype=np.int16) + + for ind in indices: + neighbor_dists, neighbor_ids = kd_tree.query(ind, k=5) + + if not len(neighbor_ids): + continue + + labels_val = final_labels._data[neighbor_ids] + sum_dists_vox = np.sum(neighbor_dists) + weights = np.exp(-neighbor_dists / sum_dists_vox) + + vote = np.bincount(labels_val, weights=weights) + total = np.arange(np.amax(labels_val+1)) + winner = total[np.argmax(vote)] + labels_map[ind[0], ind[1], ind[2]] = winner + + return labels_map 
diff --git a/scripts/scil_compute_bundle_voxel_label_map.py b/scripts/scil_compute_bundle_voxel_label_map.py index 9b9c9bb12..861b5bed2 100755 --- a/scripts/scil_compute_bundle_voxel_label_map.py +++ b/scripts/scil_compute_bundle_voxel_label_map.py @@ -8,27 +8,23 @@ The number of labels will be the same as the centroid's number of points. """ -from sklearn.preprocessing import MinMaxScaler -from sklearn.linear_model import SGDClassifier -from sklearn.svm import LinearSVC -from sklearn.kernel_approximation import RBFSampler -from sklearn.svm import SVC -from collections import defaultdict from time import time import argparse import os -# from dipy.align.streamlinear import BundleMinDistanceMetric, StreamlineLinearRegistration + from dipy.io.streamline import save_tractogram -from dipy.io.stateful_tractogram import StatefulTractogram, set_sft_logger_level, Space, Origin +from dipy.io.stateful_tractogram import (StatefulTractogram, + set_sft_logger_level) from dipy.io.utils import is_header_compatible -import matplotlib.pyplot as plt import nibabel as nib from nibabel.streamlines.array_sequence import ArraySequence import numpy as np import scipy.ndimage as ndi +from sklearn.preprocessing import MinMaxScaler +from sklearn.linear_model import SGDClassifier +from sklearn.kernel_approximation import RBFSampler -from scilpy.image.labels import weighted_vote_median_filter from scilpy.image.volume_math import correlation from scilpy.io.streamlines import load_tractogram_with_reference from scilpy.io.utils import (add_overwrite_arg, @@ -39,7 +35,8 @@ from scilpy.tractanalysis.streamlines_metrics import compute_tract_counts_map from scilpy.tractanalysis.distance_to_centroid import (min_dist_to_centroid, compute_distance_map, - associate_labels) + associate_labels, + correct_labels_jump) from scilpy.tractograms.streamline_and_mask_operations import \ cut_outside_of_mask_streamlines from scilpy.tractograms.streamline_operations import \ @@ -160,13 +157,17 @@ def main(): args.nb_pts = len(sft_centroid.streamlines[0]) if args.nb_pts is None \ else args.nb_pts + # This allows to have a more uniform (in size) first and last labels + if args.new_labelling: + args.nb_pts += 2 + sft_centroid = resample_streamlines_num_points(sft_centroid, args.nb_pts) # Select 2000 elements from the SFTs random_indices = np.random.choice(len(concat_sft), min(len(concat_sft), 2000), replace=False) - tmp_sft = resample_streamlines_step_size(concat_sft[random_indices], 2.0) + tmp_sft = resample_streamlines_step_size(concat_sft[random_indices], 1.0) t0 = time() if not args.new_labelling: @@ -208,8 +209,21 @@ def main(): labels_map = np.zeros(binary_map.shape, dtype=np.uint16) labels_map[np.where(binary_map)] = labels + t0 = time() + + # Correct the labels to have a more uniform size + if args.new_labelling: + labels_map[labels_map == args.nb_pts] = args.nb_pts - 1 + labels_map[labels_map == 1] = 2 + labels_map[labels_map > 0] -= 1 + labels_map = correct_labels_jump(labels_map, concat_sft.streamlines, + args.nb_pts) + print('Correct labels time for {}'.format(args.out_dir), time()-t0) + + t0 = time() distance_map = compute_distance_map(labels_map, binary_map, args.new_labelling, args.nb_pts) + print('Distance map time for {}'.format(args.out_dir), time()-t0) cmap = get_colormap(args.colormap) for i, sft in enumerate(sft_list): From c5333db3b1c74b99b3a10aaa22ef5f44fe302a7b Mon Sep 17 00:00:00 2001 From: frheault Date: Wed, 8 Nov 2023 09:07:44 -0500 Subject: [PATCH 08/14] Remove last 2 --- 
scripts/scil_compute_bundle_voxel_label_map.py | 1 + 1 file changed, 1 insertion(+) diff --git a/scripts/scil_compute_bundle_voxel_label_map.py b/scripts/scil_compute_bundle_voxel_label_map.py index 861b5bed2..9d0ce1026 100755 --- a/scripts/scil_compute_bundle_voxel_label_map.py +++ b/scripts/scil_compute_bundle_voxel_label_map.py @@ -216,6 +216,7 @@ def main(): labels_map[labels_map == args.nb_pts] = args.nb_pts - 1 labels_map[labels_map == 1] = 2 labels_map[labels_map > 0] -= 1 + args.nb_pts -= 2 labels_map = correct_labels_jump(labels_map, concat_sft.streamlines, args.nb_pts) print('Correct labels time for {}'.format(args.out_dir), time()-t0) From 77f2d54dad143a897226e1b2584f7bc05bd83be7 Mon Sep 17 00:00:00 2001 From: frheault Date: Wed, 8 Nov 2023 14:53:34 -0500 Subject: [PATCH 09/14] Robust testing, docstring and logging --- scilpy/tractanalysis/distance_to_centroid.py | 20 +- .../scil_compute_bundle_voxel_label_map.py | 187 +++++++++++++----- .../test_compute_bundle_voxel_label_map.py | 17 +- 3 files changed, 169 insertions(+), 55 deletions(-) diff --git a/scilpy/tractanalysis/distance_to_centroid.py b/scilpy/tractanalysis/distance_to_centroid.py index 302387a51..5dbc25796 100644 --- a/scilpy/tractanalysis/distance_to_centroid.py +++ b/scilpy/tractanalysis/distance_to_centroid.py @@ -54,7 +54,7 @@ def associate_labels(target_sft, source_sft, nb_pts=20): return np.round(final_labels), head, tail -def find_medoid(points): +def find_medoid(points, max_points=10000): """ Find the medoid among a set of points. @@ -64,6 +64,11 @@ def find_medoid(points): Returns: ndarray: Coordinates of the medoid. """ + if len(points) > max_points: + selected_indices = np.random.choice(len(points), max_points, + replace=False) + points = points[selected_indices] + distance_matrix = squareform(pdist(points)) medoid_idx = np.argmin(distance_matrix.sum(axis=1)) return points[medoid_idx] @@ -92,14 +97,15 @@ def compute_labels_map_barycenters(labels_map, is_euclidian=False, nb_pts=False) mask = np.zeros_like(labels_map) mask[labels_map == label] = 1 mask_coords = np.argwhere(mask) - if is_euclidian: barycenter = np.mean(mask_coords, axis=0) else: barycenter = find_medoid(mask_coords) + # If the barycenter is not in the mask, find the closest point if labels_map[tuple(barycenter.astype(int))] != label: tree = KDTree(indices) _, ind = tree.query(barycenter, k=1) + del tree barycenter = indices[ind] barycenters[label - 1] = barycenter @@ -153,7 +159,7 @@ def masked_manhattan_distance(mask, target_positions): return distances -def compute_distance_map(labels_map, binary_map, new_labelling, nb_pts): +def compute_distance_map(labels_map, binary_map, is_euclidian, nb_pts): """ Computes the distance map for each label in the labels_map. @@ -162,7 +168,7 @@ def compute_distance_map(labels_map, binary_map, new_labelling, nb_pts): A 3D array representing the labels map. binary_map (numpy.ndarray): A 3D binary map used to calculate barycenter binary map. - new_labelling (bool): + hyperplane (bool): A flag to determine the type of distance calculation. nb_pts (int): Number of points to use for computing barycenters. @@ -171,7 +177,7 @@ def compute_distance_map(labels_map, binary_map, new_labelling, nb_pts): numpy.ndarray: A 3D array representing the distance map. 
""" barycenters = compute_labels_map_barycenters(labels_map, - is_euclidian=new_labelling, + is_euclidian=is_euclidian, nb_pts=nb_pts) # If the first/last few points are NaN, remove them this indicates that the # head/tail are not 1-NB_PTS @@ -210,7 +216,7 @@ def compute_distance_map(labels_map, binary_map, new_labelling, nb_pts): if barycenter_intersect_coords.size == 0: continue - if not new_labelling: + if is_euclidian: distances = np.linalg.norm( barycenter_intersect_coords[:, np.newaxis] - labels_coords, axis=-1) @@ -277,7 +283,7 @@ def correct_labels_jump(labels_map, streamlines, nb_pts): kd_tree = KDTree(final_streamlines._data) indices = np.array(np.nonzero(labels_map), dtype=int).T - labels_map = np.zeros(labels_map.shape, dtype=np.int16) + labels_map = np.zeros(labels_map.shape, dtype=np.uint16) for ind in indices: neighbor_dists, neighbor_ids = kd_tree.query(ind, k=5) diff --git a/scripts/scil_compute_bundle_voxel_label_map.py b/scripts/scil_compute_bundle_voxel_label_map.py index 9d0ce1026..d1ec88b7c 100755 --- a/scripts/scil_compute_bundle_voxel_label_map.py +++ b/scripts/scil_compute_bundle_voxel_label_map.py @@ -2,20 +2,57 @@ # -*- coding: utf-8 -*- """ -Compute label image (Nifti) from bundle and centroid. -Each voxel will have the label of its nearest centroid point. - -The number of labels will be the same as the centroid's number of points. +Compute label image (Nifti) from bundle(s) and centroid(s). +Each voxel will have a label that represents its position along the bundle. + +The number of labels will be the same as the centroid's number of points, +unless specified otherwise. + +# Single bundle case + This script takes as input a bundle file, a centroid streamline corresponding + to the bundle. It computes label images, where each voxel is assigned the + label of its nearest centroid point. The resulting images represent the + labels, distances between the bundle and centroid. + +# Multiple bundle case + When providing multiple (co-registered) bundles, the script will compute a + correlation map, which shows the spatial correlation between density maps + It will also compute the labels maps for all bundles at once, ensuring + that the labels are spatial consistent between bundles. + +# Hyperplane method + The default is to use the euclidian/centerline method, which is fast and + works well for most cases. + The hyperplane method allows for more complex shapes and to split the bundles + into subsection that follow the geometry of each kind of bundle. + However, this method is slower and requires extra quality control to ensure + that the labels are correct. This method requires a centroid file that + contains multiple streamlines. + This method is based on the following paper [1], but was heavily modified + and adapted to work more robustly across datasets. + +# Manhatan distance + The default distance (to barycenter of label) is the euclidian distance. + The manhattan distance can be used instead to compute the distance to the + barycenter without stepping out of the mask. + +Colormap selection affects tractograms coloring for visualization only. +For detailed information on usage and parameters, please refer to the script's +documentation. 
+ +Author: +------- +Francois Rheault +francois.m.rheault@usherbrooke.ca """ -from time import time import argparse +import logging import os - +import time from dipy.io.streamline import save_tractogram -from dipy.io.stateful_tractogram import (StatefulTractogram, - set_sft_logger_level) +from dipy.io.stateful_tractogram import StatefulTractogram from dipy.io.utils import is_header_compatible import nibabel as nib from nibabel.streamlines.array_sequence import ArraySequence @@ -30,6 +67,7 @@ from scilpy.io.utils import (add_overwrite_arg, add_processes_arg, add_reference_arg, + add_verbose_arg, assert_inputs_exist, assert_output_dirs_exist_and_empty) from scilpy.tractanalysis.streamlines_metrics import compute_tract_counts_map @@ -45,9 +83,17 @@ from scilpy.viz.utils import get_colormap +EPILOG = """ +[1] Neher, Peter, Dusan Hirjak, and Klaus Maier-Hein. "Radiomic tractometry: a + rich and tract-specific class of imaging biomarkers for neuroscience and + medical applications." Research Square (2023). +""" + + def _build_arg_parser(): p = argparse.ArgumentParser( description=__doc__, + epilog=EPILOG, formatter_class=argparse.RawTextHelpFormatter) p.add_argument('in_bundles', nargs='+', @@ -63,23 +109,30 @@ def _build_arg_parser(): p.add_argument('--colormap', default='jet', help='Select the colormap for colored trk (data_per_point) ' '[%(default)s].') - p.add_argument('--new_labelling', action='store_true', - help='Use the new labelling method (multi-centroids).') + p.add_argument('--hyperplane', action='store_true', + help='Use the hyperplane method (multi-centroids) instead ' + 'of the euclidian method (single-centroid).') + p.add_argument('--use_manhattan', action='store_true', + help='Use the manhattan distance instead of the euclidian ' + 'distance.') p.add_argument('--skip_uniformize', action='store_true', help='Skip uniformization of the bundles orientation.') + p.add_argument('--correlation_thr', type=float, default=0.1, + help='Threshold for the correlation map. Only for multi ' + 'bundle case. [%(default)s]') add_processes_arg(p) add_reference_arg(p) + add_verbose_arg(p) add_overwrite_arg(p) return p def main(): - t0 = time() parser = _build_arg_parser() args = parser.parse_args() - set_sft_logger_level('ERROR') + # set_sft_logger_level('ERROR') assert_inputs_exist(parser, args.in_bundles + [args.in_centroid], optional=args.reference) assert_output_dirs_exist_and_empty(parser, args, args.out_dir) @@ -87,12 +140,18 @@ def main(): sft_centroid = load_tractogram_with_reference(parser, args, args.in_centroid) + if args.verbose: + logging.getLogger().setLevel(logging.INFO) + + # When doing longitudinal data, the split in subsection can be done + # on all the bundles at once. Needs to be co-registered. + timer = time.time() sft_list = [] for filename in args.in_bundles: sft = load_tractogram_with_reference(parser, args, filename) if not len(sft.streamlines): raise IOError('Empty bundle file {}. 
' - 'Skipping'.format(args.in_bundle)) + 'Skipping'.format(args.in_bundles)) if not args.skip_uniformize: uniformize_bundle_sft(sft, ref_bundle=sft_centroid) sft.to_vox() @@ -104,27 +163,33 @@ def main(): parser.error('Header of {} and {} are not compatible'.format( args.in_bundles[0], filename)) - # Perform after the uniformization sft_centroid.to_vox() sft_centroid.to_corner() + logging.info('Loaded {} bundle(s) in {} seconds.'.format( + len(args.in_bundles), round(time.time() - timer, 3))) - t0 = time() density_list = [] binary_list = [] + timer = time.time() for sft in sft_list: density = compute_tract_counts_map(sft.streamlines, sft.dimensions).astype(float) - binary = np.zeros(sft.dimensions) + binary = np.zeros(sft.dimensions, dtype=np.uint8) binary[density > 0] = 1 binary_list.append(binary) density_list.append(density) if not is_header_compatible(sft_centroid, sft_list[0]): raise IOError('{} and {}do not have a compatible header'.format( - args.in_centroid, args.in_bundle)) + args.in_centroid, args.in_bundles)) + logging.info('Computed density and binary map(s) in {}.'.format( + round(time.time() - timer, 3))) if len(density_list) > 1: + timer = time.time() corr_map = correlation(density_list, None) + logging.info('Computed correlation map in {} seconds.'.format( + round(time.time() - timer, 3))) else: corr_map = density_list[0].astype(float) corr_map[corr_map > 0] = 1 @@ -132,18 +197,18 @@ def main(): # Slightly cut the bundle at the edgge to clean up single streamline voxels # with no neighbor. Remove isolated voxels to keep a single 'blob' binary_map = np.zeros(corr_map.shape, dtype=bool) - binary_map[corr_map > 0.5] = 1 + binary_map[corr_map > args.correlation_thr] = 1 bundle_disjoint, _ = ndi.label(binary_map) unique, count = np.unique(bundle_disjoint, return_counts=True) val = unique[np.argmax(count[1:])+1] binary_map[bundle_disjoint != val] = 0 - corr_map = corr_map*binary_map - nib.save(nib.Nifti1Image(corr_map, sft_list[0].affine), + nib.save(nib.Nifti1Image(corr_map * binary_map, sft_list[0].affine), os.path.join(args.out_dir, 'correlation_map.nii.gz')) - # Chop off some streamlines + # A bundle must be contiguous, we can't have isolated voxels. 
+ timer = time.time() concat_sft = StatefulTractogram.from_sft([], sft_list[0]) concat_sft.to_vox() concat_sft.to_corner() @@ -152,80 +217,103 @@ def main(): binary_map) if len(sft_list[i]): concat_sft += sft_list[i] + logging.info('Chop bundle(s) in {} seconds.'.format( + round(time.time() - timer, 3))) - t0 = time() args.nb_pts = len(sft_centroid.streamlines[0]) if args.nb_pts is None \ else args.nb_pts # This allows to have a more uniform (in size) first and last labels - if args.new_labelling: + if args.hyperplane: args.nb_pts += 2 sft_centroid = resample_streamlines_num_points(sft_centroid, args.nb_pts) - # Select 2000 elements from the SFTs - random_indices = np.random.choice(len(concat_sft), - min(len(concat_sft), 2000), - replace=False) - tmp_sft = resample_streamlines_step_size(concat_sft[random_indices], 1.0) - - t0 = time() - if not args.new_labelling: + timer = time.time() + if not args.hyperplane: indices = np.array(np.nonzero(binary_map), dtype=int).T labels = min_dist_to_centroid(indices, sft_centroid[0].streamlines._data, nb_pts=args.nb_pts) + logging.info('Computed labels using the euclidian method ' + 'in {} seconds'.format(round(time.time() - timer, 3))) else: + logging.info('Computing Labels using the hyperplane method.\n' + '\tThis can take a while...') + # Select 2000 elements from the SFTs to train the classifier + random_indices = np.random.choice(len(concat_sft), + min(len(concat_sft), 2000), + replace=False) + tmp_sft = resample_streamlines_step_size(concat_sft[random_indices], + 1.0) + # Associate the labels to the streamlines using the centroids as + # reference (to handle shorter bundles due to missing data) + mini_timer = time.time() labels, _, _ = associate_labels(tmp_sft, sft_centroid, args.nb_pts) + print('\tAssociated labels to_centroids in {} seconds'.format( + round(time.time() - mini_timer, 3))) # Initialize the scaler and the RBF sampler scaler = MinMaxScaler(feature_range=(-1, 1)) - rbf_feature = RBFSampler(gamma=1.0, n_components=500, random_state=1) + rbf_feature = RBFSampler(gamma=1.0, n_components=1000, random_state=1) # Fit the scaler to the streamline data and transform it + mini_timer = time.time() scaler.fit(tmp_sft.streamlines._data) scaled_streamline_data = scaler.transform(tmp_sft.streamlines._data) # Fit the RBFSampler to the scaled data and transform it rbf_feature.fit(scaled_streamline_data) features = rbf_feature.transform(scaled_streamline_data) + print('\tScaler and RBF kernel approximation in {} seconds'.format( + round(time.time() - mini_timer, 3))) # Initialize and fit the SGDClassifier with log loss - sgd_clf = SGDClassifier(loss='log_loss', max_iter=10000, tol=1e-4, + mini_timer = time.time() + sgd_clf = SGDClassifier(loss='log_loss', max_iter=10000, tol=1e-5, alpha=0.0001, random_state=1, n_jobs=min(args.nb_pts, args.nbr_processes)) sgd_clf.fit(X=features, y=labels) + print('\tSGDClassifier fit of training data in {} seconds'.format( + round(time.time() - mini_timer, 3))) # Scale the coordinates of the voxels and transform with RBFSampler + mini_timer = time.time() voxel_coords = np.array(np.where(binary_map)).T scaled_voxel_coords = scaler.transform(voxel_coords) transformed_voxel_coords = rbf_feature.transform(scaled_voxel_coords) # Predict the labels for the voxels labels = sgd_clf.predict(X=transformed_voxel_coords) + print('\tSGDClassifier prediction of labels in {} seconds'.format( + round(time.time() - mini_timer, 3))) - print('SVC time for {}'.format(args.out_dir), time()-t0) - + logging.info('Computed labels using the 
hyperplane method ' + 'in {} seconds'.format(round(time.time() - timer, 3))) labels_map = np.zeros(binary_map.shape, dtype=np.uint16) labels_map[np.where(binary_map)] = labels - t0 = time() - # Correct the labels to have a more uniform size - if args.new_labelling: + # Correct the hyperplane labels to have a more uniform size + if args.hyperplane: labels_map[labels_map == args.nb_pts] = args.nb_pts - 1 labels_map[labels_map == 1] = 2 labels_map[labels_map > 0] -= 1 args.nb_pts -= 2 + + timer = time.time() labels_map = correct_labels_jump(labels_map, concat_sft.streamlines, args.nb_pts) - print('Correct labels time for {}'.format(args.out_dir), time()-t0) + logging.info('Corrected labels jump in {} seconds'.format( + round(time.time() - timer, 3))) - t0 = time() + timer = time.time() distance_map = compute_distance_map(labels_map, binary_map, - args.new_labelling, args.nb_pts) - print('Distance map time for {}'.format(args.out_dir), time()-t0) + args.use_manhattan, args.nb_pts) + logging.info('Computed distance map in {} seconds'.format( + round(time.time() - timer, 3))) + timer = time.time() cmap = get_colormap(args.colormap) for i, sft in enumerate(sft_list): if len(sft_list) > 1: @@ -240,8 +328,8 @@ def main(): # Dictionary to hold the data for each type data_dict = {'labels': labels_map.astype(np.uint16), - 'distance': distance_map.astype(float), - 'correlation': corr_map.astype(float)} + 'distance': distance_map.astype(np.float32), + 'correlation': corr_map.astype(np.float32)} # Iterate through each type to save the files for basename, map in data_dict.items(): @@ -255,8 +343,13 @@ def main(): tmp_data = ndi.map_coordinates( map, sft.streamlines._data.T - 0.5, order=0) - max_val = args.nb_pts if basename == 'labels' else np.max( - tmp_data) + if basename == 'labels': + max_val = args.nb_pts + elif basename == 'correlation': + max_val = 1 + else: + max_val = np.max(tmp_data) + max_val = args.nb_pts new_sft.data_per_point['color']._data = cmap( tmp_data / max_val)[:, 0:3] * 255 @@ -264,6 +357,8 @@ def main(): save_tractogram(new_sft, os.path.join(sub_out_dir, "{}.trk".format(basename))) + logging.info('Saved all data to {} in {} seconds'.format( + args.out_dir, round(time.time() - timer, 3))) if __name__ == '__main__': diff --git a/scripts/tests/test_compute_bundle_voxel_label_map.py b/scripts/tests/test_compute_bundle_voxel_label_map.py index 0f50fc14d..d7fe55fd8 100644 --- a/scripts/tests/test_compute_bundle_voxel_label_map.py +++ b/scripts/tests/test_compute_bundle_voxel_label_map.py @@ -17,7 +17,7 @@ def test_help_option(script_runner): assert ret.success -def test_execution_tractometry(script_runner): +def test_execution_tractometry_euclidian(script_runner): os.chdir(os.path.expanduser(tmp_dir.name)) in_bundle = os.path.join(get_home(), 'tractometry', 'IFGWM.trk') @@ -25,6 +25,19 @@ def test_execution_tractometry(script_runner): 'IFGWM_uni_c_10.trk') ret = script_runner.run('scil_compute_bundle_voxel_label_map.py', in_bundle, in_centroid, - 'results_dir/', + 'results_euc/', '--colormap', 'viridis') assert ret.success + +def test_execution_tractometry_hyperplane(script_runner): + os.chdir(os.path.expanduser(tmp_dir.name)) + in_bundle = os.path.join(get_home(), 'tractometry', + 'IFGWM.trk') + in_centroid = os.path.join(get_home(), 'tractometry', + 'IFGWM_uni_c_10.trk') + ret = script_runner.run('scil_compute_bundle_voxel_label_map.py', + in_bundle, in_centroid, + 'results_man/', + '--colormap', 'viridis', + '--hyperplane', '--use_manhattan') + assert ret.success From 
52552df8b630b27aa6b2312d3a4824614af5bd3c Mon Sep 17 00:00:00 2001 From: frheault Date: Fri, 24 Nov 2023 15:34:02 -0500 Subject: [PATCH 10/14] Fix scaler --- scilpy/tractanalysis/distance_to_centroid.py | 144 ++++++++++++------ scilpy/utils/streamlines.py | 4 +- .../scil_compute_bundle_voxel_label_map.py | 81 +++++----- 3 files changed, 139 insertions(+), 90 deletions(-) diff --git a/scilpy/tractanalysis/distance_to_centroid.py b/scilpy/tractanalysis/distance_to_centroid.py index 5dbc25796..d95ec28a9 100644 --- a/scilpy/tractanalysis/distance_to_centroid.py +++ b/scilpy/tractanalysis/distance_to_centroid.py @@ -1,4 +1,5 @@ # -*- coding: utf-8 -*- +from scilpy.tractograms.streamline_and_mask_operations import get_head_tail_density_maps import heapq from dipy.tracking.metrics import length @@ -9,49 +10,90 @@ from scipy.spatial.distance import pdist, squareform from scilpy.tractanalysis.streamlines_metrics import compute_tract_counts_map +from sklearn.preprocessing import MinMaxScaler +from sklearn.linear_model import SGDClassifier +from sklearn.svm import SVC +from sklearn.kernel_approximation import RBFSampler, Nystroem -def min_dist_to_centroid(target_pts, source_pts, nb_pts=None, - pre_computed_labels=None): - if nb_pts is None and pre_computed_labels is None: - raise ValueError('Either nb_pts or labels must be provided.') - +def min_dist_to_centroid(target_pts, source_pts, nb_pts=None): tree = KDTree(source_pts, copy_data=True) _, labels = tree.query(target_pts, k=1) - - if pre_computed_labels is None: - labels = np.mod(labels, nb_pts) + 1 - else: - labels = pre_computed_labels[labels] + labels = np.mod(labels, nb_pts) + 1 return labels.astype(np.uint16) -def associate_labels(target_sft, source_sft, nb_pts=20): - # KDTree for the target streamlines - source_kdtree = KDTree(source_sft.streamlines._data) - final_labels = np.zeros(target_sft.streamlines._data.shape[0], - dtype=float) +def associate_labels(target_sft, source_sft, nb_pts=20, sample_set=False, + sample_size=None): curr_ind = 0 - for streamline in target_sft.streamlines: - distances, ids = source_kdtree.query(streamline, - k=max(1, nb_pts // 5)) - - valid_points = distances != np.inf - - curr_labels = np.mod(ids[valid_points], nb_pts) + 1 - - head = np.min(curr_labels) - tail = np.max(curr_labels) - - lengths = np.insert(length(streamline, along=True), 0, 0)[::-1] - lengths = (lengths / np.max(lengths)) * \ - (head - tail) + tail - - final_labels[curr_ind:curr_ind+len(lengths)] = lengths - curr_ind += len(lengths) + source_labels = np.zeros(source_sft.streamlines._data.shape[0], + dtype=float) + for streamline in source_sft.streamlines: + curr_length = np.insert(length(streamline, along=True), 0, 0) + curr_labels = np.interp(curr_length, + [0, curr_length[-1]], + [1, nb_pts]) + curr_labels = np.round(curr_labels) + source_labels[curr_ind:curr_ind+len(streamline)] = curr_labels + curr_ind += len(streamline) + + scaler = MinMaxScaler(feature_range=(-1, 1)) + scaler.fit(target_sft.streamlines._data) + scaled_streamline_data = scaler.transform(source_sft.streamlines._data) + + svc = SVC(C=1.0, kernel='rbf', random_state=1) + svc.fit(X=scaled_streamline_data, y=source_labels) - return np.round(final_labels), head, tail + curr_ind = 0 + target_labels = np.zeros(target_sft.streamlines._data.shape[0], + dtype=float) + for streamline in target_sft.streamlines: + head_tail = [streamline[0], streamline[-1]] + scaled_head_tail_data = scaler.transform(head_tail) + head_tail_labels = svc.predict(X=scaled_head_tail_data) + curr_length = 
np.insert(length(streamline, along=True), 0, 0) + curr_labels = np.interp(curr_length, + [0, curr_length[-1]], + head_tail_labels) + target_labels[curr_ind:curr_ind+len(streamline)] = curr_labels + curr_ind += len(streamline) + + target_labels = np.round(target_labels).astype(int) + + if sample_set: + if sample_size is None: + sample_size = np.unique(target_labels, return_counts=True)[1].min() + + # Sort points by labels + sorted_indices = target_labels.argsort() + sorted_points = target_sft.streamlines._data[sorted_indices] + sorted_labels = target_labels[sorted_indices] + + # Find the start and end of each label + unique_labels, start_indices = np.unique( + sorted_labels, return_index=True) + end_indices = np.roll(start_indices, -1) + end_indices[-1] = len(target_labels) + + # Sample points and labels for each label + sampled_points = [] + sampled_labels = [] + for start, end, label in zip(start_indices, end_indices, unique_labels): + num_points = end - start + indices_to_sample = min(num_points, sample_size) + sampled_indices = np.random.choice( + np.arange(start, end), size=indices_to_sample, replace=False) + sampled_points.append(sorted_points[sampled_indices]) + sampled_labels.extend([label] * indices_to_sample) + + # Concatenate all sampled points + sampled_points = np.concatenate(sampled_points) + sampled_labels = np.array(sampled_labels) + + return sampled_labels, sampled_points + else: + return target_labels, target_sft.streamlines._data def find_medoid(points, max_points=10000): @@ -233,6 +275,8 @@ def compute_distance_map(labels_map, binary_map, is_euclidian, nb_pts): def correct_labels_jump(labels_map, streamlines, nb_pts): labels_data = ndi.map_coordinates(labels_map, streamlines._data.T - 0.5, order=0) + binary_map = np.zeros(labels_map.shape, dtype=np.uint8) + binary_map[labels_map > 0] = 1 # It is not allowed that labels jumps labels for consistency # Streamlines should have continous labels @@ -241,15 +285,18 @@ def correct_labels_jump(labels_map, streamlines, nb_pts): curr_ind = 0 for streamline in streamlines: next_ind = curr_ind + len(streamline) - curr_labels = labels_data[curr_ind:next_ind] + curr_labels = labels_data[curr_ind:next_ind].astype(int) curr_ind = next_ind # Flip streamlines so the labels increase (facilitate if/else) # Should always be ordered in nextflow pipeline - gradient = np.gradient(curr_labels) + gradient = np.ediff1d(curr_labels) + + is_flip = False if len(np.argwhere(gradient < 0)) > len(np.argwhere(gradient > 0)): streamline = streamline[::-1] curr_labels = curr_labels[::-1] + is_flip = True # Find jumps, cut them and find the longest gradient = np.ediff1d(curr_labels) @@ -273,6 +320,9 @@ def correct_labels_jump(labels_map, streamlines, nb_pts): streamline = np.split(streamline, pos_jump)[max_pos] + if is_flip: + streamline = streamline[::-1] + curr_labels = curr_labels[::-1] final_streamlines.append(streamline) final_labels.append(curr_labels) @@ -286,18 +336,20 @@ def correct_labels_jump(labels_map, streamlines, nb_pts): labels_map = np.zeros(labels_map.shape, dtype=np.uint16) for ind in indices: - neighbor_dists, neighbor_ids = kd_tree.query(ind, k=5) - + neighbor_ids = kd_tree.query_ball_point(ind, r=2.0) if not len(neighbor_ids): - continue + _, neighbor_ids = kd_tree.query(ind, k=5) + labels_val = np.median(final_labels._data[neighbor_ids]) - labels_val = final_labels._data[neighbor_ids] - sum_dists_vox = np.sum(neighbor_dists) - weights = np.exp(-neighbor_dists / sum_dists_vox) - vote = np.bincount(labels_val, weights=weights) - total = 
np.arange(np.amax(labels_val+1)) - winner = total[np.argmax(vote)] - labels_map[ind[0], ind[1], ind[2]] = winner + labels_map[ind[0], ind[1], ind[2]] = labels_val + # return labels_map + # To ensure spatial smoothness, we apply a gaussian filter on the labels + tmp_labels_map = np.zeros(labels_map.shape+(nb_pts+3,), dtype=float) + for i in np.unique(labels_map)[1:]: + tmp_labels_map[labels_map == i, i] = 1 + tmp_labels_map[..., i] = ndi.gaussian_filter(tmp_labels_map[..., i], + sigma=2) + labels_map = np.argmax(tmp_labels_map, axis=-1).astype(np.uint16) - return labels_map + return binary_map * labels_map diff --git a/scilpy/utils/streamlines.py b/scilpy/utils/streamlines.py index bfcdfaccf..4bfbbb5b1 100644 --- a/scilpy/utils/streamlines.py +++ b/scilpy/utils/streamlines.py @@ -58,8 +58,8 @@ def uniformize_bundle_sft(sft, axis=None, ref_bundle=None, swap=False): else: centroid = get_streamlines_centroid(sft.streamlines, 20)[0] endpoints_distance = np.abs(centroid[0] - centroid[-1]) - total_travel = np.abs( - np.sum(np.gradient(centroid, axis=0), axis=0)) + total_travel = np.sum( + np.abs(np.gradient(centroid, axis=0)), axis=0) # Reweigth the distance to the endpoints and the total travel # to avoid having a XYbundle oriented in the wrong direction diff --git a/scripts/scil_compute_bundle_voxel_label_map.py b/scripts/scil_compute_bundle_voxel_label_map.py index d1ec88b7c..c3101253b 100755 --- a/scripts/scil_compute_bundle_voxel_label_map.py +++ b/scripts/scil_compute_bundle_voxel_label_map.py @@ -60,12 +60,12 @@ import scipy.ndimage as ndi from sklearn.preprocessing import MinMaxScaler from sklearn.linear_model import SGDClassifier -from sklearn.kernel_approximation import RBFSampler +from sklearn.svm import SVC +from sklearn.kernel_approximation import RBFSampler, Nystroem from scilpy.image.volume_math import correlation from scilpy.io.streamlines import load_tractogram_with_reference from scilpy.io.utils import (add_overwrite_arg, - add_processes_arg, add_reference_arg, add_verbose_arg, assert_inputs_exist, @@ -120,8 +120,10 @@ def _build_arg_parser(): p.add_argument('--correlation_thr', type=float, default=0.1, help='Threshold for the correlation map. Only for multi ' 'bundle case. [%(default)s]') + p.add_argument('--streamlines_thr', type=int, default=2, + help='Threshold for the minimum number of streamlines in a ' + 'voxel to be included [%(default)s].' ) - add_processes_arg(p) add_reference_arg(p) add_verbose_arg(p) add_overwrite_arg(p) @@ -175,7 +177,7 @@ def main(): density = compute_tract_counts_map(sft.streamlines, sft.dimensions).astype(float) binary = np.zeros(sft.dimensions, dtype=np.uint8) - binary[density > 0] = 1 + binary[density >= args.streamlines_thr] = 1 binary_list.append(binary) density_list.append(density) @@ -196,9 +198,10 @@ def main(): # Slightly cut the bundle at the edgge to clean up single streamline voxels # with no neighbor. 
Remove isolated voxels to keep a single 'blob' - binary_map = np.zeros(corr_map.shape, dtype=bool) - binary_map[corr_map > args.correlation_thr] = 1 + binary_map = np.max(binary_list, axis=0) + binary_map[corr_map < args.correlation_thr] = 0 + # TODO eliminate the bottom quartile of the blob bundle_disjoint, _ = ndi.label(binary_map) unique, count = np.unique(bundle_disjoint, return_counts=True) val = unique[np.argmax(count[1:])+1] @@ -224,8 +227,10 @@ def main(): else args.nb_pts # This allows to have a more uniform (in size) first and last labels - if args.hyperplane: + endpoints_extended = False + if args.hyperplane and args.nb_pts >= 5: args.nb_pts += 2 + endpoints_extended = True sft_centroid = resample_streamlines_num_points(sft_centroid, args.nb_pts) @@ -249,61 +254,53 @@ def main(): # Associate the labels to the streamlines using the centroids as # reference (to handle shorter bundles due to missing data) mini_timer = time.time() - labels, _, _ = associate_labels(tmp_sft, sft_centroid, - args.nb_pts) - print('\tAssociated labels to_centroids in {} seconds'.format( - round(time.time() - mini_timer, 3))) - - # Initialize the scaler and the RBF sampler - scaler = MinMaxScaler(feature_range=(-1, 1)) - rbf_feature = RBFSampler(gamma=1.0, n_components=1000, random_state=1) + sample_size = np.count_nonzero(binary_map) // args.nb_pts + labels, points, = associate_labels(tmp_sft, sft_centroid, + args.nb_pts, sample_set=True, + sample_size=sample_size) - # Fit the scaler to the streamline data and transform it + logging.info('\tAssociated labels to centroids in {} seconds'.format( + round(time.time() - mini_timer, 3))) + + # Initialize the scaler mini_timer = time.time() - scaler.fit(tmp_sft.streamlines._data) - scaled_streamline_data = scaler.transform(tmp_sft.streamlines._data) + scaler = MinMaxScaler(feature_range=(-1, 1)) + scaler.fit(points) + scaled_streamline_data = scaler.transform(points) - # Fit the RBFSampler to the scaled data and transform it - rbf_feature.fit(scaled_streamline_data) - features = rbf_feature.transform(scaled_streamline_data) - print('\tScaler and RBF kernel approximation in {} seconds'.format( - round(time.time() - mini_timer, 3))) + svc = SVC(C=1.0, kernel='rbf', random_state=1) - # Initialize and fit the SGDClassifier with log loss - mini_timer = time.time() - sgd_clf = SGDClassifier(loss='log_loss', max_iter=10000, tol=1e-5, - alpha=0.0001, random_state=1, - n_jobs=min(args.nb_pts, args.nbr_processes)) - sgd_clf.fit(X=features, y=labels) - print('\tSGDClassifier fit of training data in {} seconds'.format( + svc.fit(X=scaled_streamline_data, y=labels) + logging.info('\tSVC fit of training data in {} seconds'.format( round(time.time() - mini_timer, 3))) - # Scale the coordinates of the voxels and transform with RBFSampler + # Scale the coordinates of the voxels mini_timer = time.time() voxel_coords = np.array(np.where(binary_map)).T scaled_voxel_coords = scaler.transform(voxel_coords) - transformed_voxel_coords = rbf_feature.transform(scaled_voxel_coords) # Predict the labels for the voxels - labels = sgd_clf.predict(X=transformed_voxel_coords) - print('\tSGDClassifier prediction of labels in {} seconds'.format( + labels = svc.predict(X=scaled_voxel_coords) + logging.info('\tSVC prediction of labels in {} seconds'.format( round(time.time() - mini_timer, 3))) - logging.info('Computed labels using the hyperplane method ' - 'in {} seconds'.format(round(time.time() - timer, 3))) + logging.info('Computed labels using the hyperplane method ' + 'in {} 
seconds'.format(round(time.time() - timer, 3))) labels_map = np.zeros(binary_map.shape, dtype=np.uint16) labels_map[np.where(binary_map)] = labels - # Correct the hyperplane labels to have a more uniform size - if args.hyperplane: + # # Correct the hyperplane labels to have a more uniform size + + timer = time.time() + tmp_sft = resample_streamlines_step_size(concat_sft, 1.0) + labels_map = correct_labels_jump(labels_map, tmp_sft.streamlines, + args.nb_pts - 2) + + if args.hyperplane and endpoints_extended: labels_map[labels_map == args.nb_pts] = args.nb_pts - 1 labels_map[labels_map == 1] = 2 labels_map[labels_map > 0] -= 1 args.nb_pts -= 2 - - timer = time.time() - labels_map = correct_labels_jump(labels_map, concat_sft.streamlines, - args.nb_pts) logging.info('Corrected labels jump in {} seconds'.format( round(time.time() - timer, 3))) From e4a373f363db7e65523efe0acc00c586de8fb480 Mon Sep 17 00:00:00 2001 From: frheault Date: Wed, 2 Oct 2024 16:22:20 -0400 Subject: [PATCH 11/14] Fix a few mistakes, deleted a file --- scilpy/tractanalysis/bundle_operations.py | 3 +- scilpy/tractanalysis/distance_to_centroid.py | 109 ++++- scilpy/utils/streamlines.py | 415 ------------------- scripts/scil_bundle_label_map.py | 154 ++----- 4 files changed, 144 insertions(+), 537 deletions(-) delete mode 100644 scilpy/utils/streamlines.py diff --git a/scilpy/tractanalysis/bundle_operations.py b/scilpy/tractanalysis/bundle_operations.py index 297bb847e..de1f4fbf8 100644 --- a/scilpy/tractanalysis/bundle_operations.py +++ b/scilpy/tractanalysis/bundle_operations.py @@ -60,7 +60,7 @@ def uniformize_bundle_sft(sft, axis=None, ref_bundle=None, swap=False): old_origin = sft.origin sft.to_vox() sft.to_corner() - print('a', sft) + density = get_endpoints_density_map(sft, point_to_select=3) indices = np.argwhere(density > 0) kmeans = KMeans(n_clusters=2, random_state=0, copy_x=True, @@ -343,4 +343,3 @@ def remove_outliers_qb(streamlines, threshold, nb_points=12, nb_samplings=30, outliers_ids, inliers_ids = prune(streamlines, threshold, summary) return outliers_ids, inliers_ids - diff --git a/scilpy/tractanalysis/distance_to_centroid.py b/scilpy/tractanalysis/distance_to_centroid.py index bbb929bd6..ce3de7793 100644 --- a/scilpy/tractanalysis/distance_to_centroid.py +++ b/scilpy/tractanalysis/distance_to_centroid.py @@ -1,4 +1,8 @@ # -*- coding: utf-8 -*- +from scilpy.tractograms.streamline_operations import \ + resample_streamlines_num_points, resample_streamlines_step_size +import time +import logging from scilpy.tractograms.streamline_and_mask_operations import get_head_tail_density_maps import heapq @@ -61,8 +65,8 @@ def associate_labels(target_sft, source_sft, nb_pts=20, sample_set=False, curr_ind = 0 target_labels = np.zeros(target_sft.streamlines._data.shape[0], dtype=float) - - # TODO Single prediction array + + # TODO Single prediction array for streamline in target_sft.streamlines: head_tail = [streamline[0], streamline[-1]] scaled_head_tail_data = scaler.transform(head_tail) @@ -216,14 +220,14 @@ def masked_manhattan_distance(mask, target_positions): return distances -def compute_distance_map(labels_map, binary_map, is_euclidian, nb_pts): +def compute_distance_map(labels_map, binary_mask, is_euclidian, nb_pts): """ Computes the distance map for each label in the labels_map. Parameters: labels_map (numpy.ndarray): A 3D array representing the labels map. - binary_map (numpy.ndarray): + binary_mask (numpy.ndarray): A 3D binary map used to calculate barycenter binary map. 
hyperplane (bool): A flag to determine the type of distance calculation. @@ -253,10 +257,10 @@ def compute_distance_map(labels_map, binary_map, is_euclidian, nb_pts): valid_data[:, i]) for i in range(tmp_barycenter.shape[1])]).T barycenters[head-1:tail] = interpolated_data - distance_map = np.zeros(binary_map.shape, dtype=float) + distance_map = np.zeros(binary_mask.shape, dtype=float) barycenter_strs = [barycenters[head-1:tail]] barycenter_bin = compute_tract_counts_map(barycenter_strs, - binary_map.shape) + binary_mask.shape) barycenter_bin[barycenter_bin > 0] = 1 for label in range(head, tail+1): @@ -280,7 +284,7 @@ def compute_distance_map(labels_map, binary_map, is_euclidian, nb_pts): distance_map[labels_map == label] = np.min(distances, axis=0) else: coords = [tuple(coord) for coord in barycenter_intersect_coords] - curr_dists = masked_manhattan_distance(binary_map, coords) + curr_dists = masked_manhattan_distance(binary_mask, coords) distance_map[labels_map == label] = \ curr_dists[labels_map == label] @@ -290,8 +294,8 @@ def compute_distance_map(labels_map, binary_map, is_euclidian, nb_pts): def correct_labels_jump(labels_map, streamlines, nb_pts): labels_data = ndi.map_coordinates(labels_map, streamlines._data.T - 0.5, order=0) - binary_map = np.zeros(labels_map.shape, dtype=np.uint8) - binary_map[labels_map > 0] = 1 + binary_mask = np.zeros(labels_map.shape, dtype=np.uint8) + binary_mask[labels_map > 0] = 1 # It is not allowed that labels jumps labels for consistency # Streamlines should have continous labels @@ -351,7 +355,7 @@ def correct_labels_jump(labels_map, streamlines, nb_pts): labels_map = np.zeros(labels_map.shape, dtype=np.uint16) for ind in indices: - neighbor_ids = kd_tree.query_ball_point(ind, r=2.0) + neighbor_ids = kd_tree.query_ball_point(ind, r=1.0) if not len(neighbor_ids): _, neighbor_ids = kd_tree.query(ind, k=5) labels_val = np.median(final_labels._data[neighbor_ids]) @@ -366,4 +370,87 @@ def correct_labels_jump(labels_map, streamlines, nb_pts): sigma=2) labels_map = np.argmax(tmp_labels_map, axis=-1).astype(np.uint16) - return binary_map * labels_map + return binary_mask * labels_map + + +def subdivide_bundles(sft, sft_centroid, binary_mask, nb_pts, + method='centerline'): + # This allows to have a more uniform (in size) first and last labels + endpoints_extended = False + if method == 'hyperplane' and nb_pts >= 5: + nb_pts += 2 + endpoints_extended = True + + sft_centroid = resample_streamlines_num_points(sft_centroid, nb_pts) + + timer = time.time() + if method == 'centerline': + indices = np.array(np.nonzero(binary_mask), dtype=int).T + labels = min_dist_to_centroid(indices, + sft_centroid[0].streamlines._data, + nb_pts=nb_pts) + logging.info('Computed labels using the euclidian method ' + f'in {round(time.time() - timer, 3)} seconds') + else: + logging.info('Computing Labels using the hyperplane method.\n' + '\tThis can take a while...') + # Select 2000 elements from the SFTs to train the classifier + random_indices = np.random.choice(len(sft), + min(len(sft), 2000), + replace=False) + tmp_sft = resample_streamlines_step_size(sft[random_indices], + 1.0) + # Associate the labels to the streamlines using the centroids as + # reference (to handle shorter bundles due to missing data) + mini_timer = time.time() + sample_size = np.count_nonzero(binary_mask) // nb_pts + labels, points, = associate_labels(tmp_sft, sft_centroid, + nb_pts, sample_set=True, + sample_size=sample_size) + + logging.info('\tAssociated labels to centroids in ' + f'{round(time.time() - 
mini_timer, 3)} seconds') + + # Initialize the scaler + mini_timer = time.time() + scaler = MinMaxScaler(feature_range=(-1, 1)) + scaler.fit(points) + scaled_streamline_data = scaler.transform(points) + + svc = SVC(C=1.0, kernel='rbf', random_state=1) + + svc.fit(X=scaled_streamline_data, y=labels) + logging.info('\tSVC fit of training data in ' + f'{round(time.time() - mini_timer, 3)} seconds') + + # Scale the coordinates of the voxels + mini_timer = time.time() + voxel_coords = np.array(np.where(binary_mask)).T + scaled_voxel_coords = scaler.transform(voxel_coords) + + # Predict the labels for the voxels + labels = svc.predict(X=scaled_voxel_coords) + logging.info('\tSVC prediction of labels in ' + f'{round(time.time() - mini_timer, 3)} seconds') + + logging.info('Computed labels using the hyperplane method ' + f'in {round(time.time() - timer, 3)} seconds') + labels_map = np.zeros(binary_mask.shape, dtype=np.uint16) + labels_map[np.where(binary_mask)] = labels + + # # Correct the hyperplane labels to have a more uniform size + + timer = time.time() + tmp_sft = resample_streamlines_step_size(sft, 1.0) + labels_map = correct_labels_jump(labels_map, tmp_sft.streamlines, + nb_pts - 2) + + if method == 'hyperplane' and endpoints_extended: + labels_map[labels_map == nb_pts] = nb_pts - 1 + labels_map[labels_map == 1] = 2 + labels_map[labels_map > 0] -= 1 + nb_pts -= 2 + logging.info('Corrected labels jump in ' + f'{round(time.time() - timer, 3)} seconds') + + return labels_map diff --git a/scilpy/utils/streamlines.py b/scilpy/utils/streamlines.py deleted file mode 100644 index 107623b85..000000000 --- a/scilpy/utils/streamlines.py +++ /dev/null @@ -1,415 +0,0 @@ -# -*- coding: utf-8 -*- -import copy -import logging - -from dipy.io.stateful_tractogram import StatefulTractogram -from dipy.tracking.streamline import set_number_of_points -from dipy.tracking.streamlinespeed import compress_streamlines -import numpy as np -from scipy.spatial import cKDTree -from sklearn.cluster import KMeans - -from scilpy.io.utils import load_matrix_in_any_format -from scilpy.tractanalysis.bundle_operations import get_streamlines_centroid -from scilpy.tractograms.streamline_and_mask_operations import \ - get_endpoints_density_map - - -def uniformize_bundle_sft(sft, axis=None, ref_bundle=None, swap=False): - """Uniformize the streamlines in the given tractogram. 
- - Parameters - ---------- - sft: StatefulTractogram - The tractogram that contains the list of streamlines to be uniformized - axis: int, optional - Orient endpoints in the given axis - ref_bundle: streamlines - Orient endpoints the same way as this bundle (or centroid) - swap: boolean, optional - Swap the orientation of streamlines - """ - old_space = sft.space - old_origin = sft.origin - sft.to_vox() - sft.to_corner() - density = get_endpoints_density_map(sft, point_to_select=3) - sft.to_space(old_space) - sft.to_origin(old_origin) - indices = np.argwhere(density > 0) - kmeans = KMeans(n_clusters=2, random_state=0, copy_x=True, - n_init=20).fit(indices) - - labels = np.zeros(density.shape) - for i in range(len(kmeans.labels_)): - labels[tuple(indices[i])] = kmeans.labels_[i]+1 - - k_means_centers = kmeans.cluster_centers_ - main_dir_barycenter = np.argmax( - np.abs(k_means_centers[0] - k_means_centers[-1])) - - if len(sft.streamlines) > 0: - axis_name = ['x', 'y', 'z'] - if axis is None or ref_bundle is not None: - if ref_bundle is not None: - centroid = get_streamlines_centroid(ref_bundle.streamlines, - 20)[0] - else: - centroid = get_streamlines_centroid(sft.streamlines, 20)[0] - endpoints_distance = np.abs(centroid[0] - centroid[-1]) - total_travel = np.sum( - np.abs(np.gradient(centroid, axis=0)), axis=0) - - # Reweigth the distance to the endpoints and the total travel - # to avoid having a XYbundle oriented in the wrong direction - endpoints_distance[-1] *= 0.9 - total_travel[-1] *= 0.9 - main_dir_ends = np.argmax(endpoints_distance) - main_dir_displacement = np.argmax(total_travel) - - if main_dir_displacement != main_dir_ends \ - or main_dir_displacement != main_dir_barycenter: - logging.info('Ambiguity in orientation, you should use --axis') - - # Get the winner - winner = np.zeros(3) - winner[main_dir_displacement] += 1 - winner[main_dir_ends] += 1 - winner[main_dir_barycenter] += 1 - axis_pos = np.argmax(winner) - axis = axis_name[axis_pos] - logging.info('Orienting endpoints in the {} axis'.format(axis)) - - if bool(k_means_centers[0][axis_pos] > - k_means_centers[1][axis_pos]) ^ bool(swap): - labels[labels == 1] = 3 - labels[labels == 2] = 1 - labels[labels == 3] = 2 - - for i in range(len(sft.streamlines)): - if ref_bundle: - res_centroid = set_number_of_points(centroid, 20) - res_streamlines = set_number_of_points(sft.streamlines[i], 20) - norm_direct = np.sum( - np.linalg.norm(res_centroid - res_streamlines, axis=0)) - norm_flip = np.sum( - np.linalg.norm(res_centroid - res_streamlines[::-1], axis=0)) - if bool(norm_direct > norm_flip) ^ bool(swap): - sft.streamlines[i] = sft.streamlines[i][::-1] - for key in sft.data_per_point[i]: - sft.data_per_point[key][i] = \ - sft.data_per_point[key][i][::-1] - else: - # Bitwise XOR - if bool(labels[tuple(sft.streamlines[i][0].astype(int))] > - labels[tuple(sft.streamlines[i][-1].astype(int))]) ^ bool(swap): - sft.streamlines[i] = sft.streamlines[i][::-1] - for key in sft.data_per_point[i]: - sft.data_per_point[key][i] = \ - sft.data_per_point[key][i][::-1] - - -def uniformize_bundle_sft_using_mask(sft, mask, swap=False): - """Uniformize the streamlines in the given tractogram so head is closer to - to a region of interest. - - Parameters - ---------- - sft: StatefulTractogram - The tractogram that contains the list of streamlines to be uniformized - mask: np.ndarray - Mask to use as a reference for the ROI. 
- swap: boolean, optional - Swap the orientation of streamlines - """ - - # barycenter = np.average(np.argwhere(mask), axis=0) - old_space = sft.space - old_origin = sft.origin - sft.to_vox() - sft.to_corner() - - tree = cKDTree(np.argwhere(mask)) - for i in range(len(sft.streamlines)): - head_dist = tree.query(sft.streamlines[i][0])[0] - tail_dist = tree.query(sft.streamlines[i][-1])[0] - if bool(head_dist > tail_dist) ^ bool(swap): - sft.streamlines[i] = sft.streamlines[i][::-1] - for key in sft.data_per_point[i]: - sft.data_per_point[key][i] = \ - sft.data_per_point[key][i][::-1] - - sft.to_space(old_space) - sft.to_origin(old_origin) - - -def clip_and_normalize_data_for_cmap(args, data): - if args.LUT: - LUT = load_matrix_in_any_format(args.LUT) - for i, val in enumerate(LUT): - data[data == i+1] = val - - if args.min_range is not None or args.max_range is not None: - data = np.clip(data, args.min_range, args.max_range) - - # get data values range - if args.min_cmap is not None: - lbound = args.min_cmap - else: - lbound = np.min(data) - if args.max_cmap is not None: - ubound = args.max_cmap - else: - ubound = np.max(data) - - if args.log: - data[data > 0] = np.log10(data[data > 0]) - - # normalize data between 0 and 1 - data -= lbound - data = data / ubound if ubound > 0 else data - return data, lbound, ubound - - -def get_color_streamlines_from_angle(sft, args): - """Color streamlines according to their length. - - Parameters - ---------- - sft: StatefulTractogram - The tractogram that contains the list of streamlines to be colored - args: NameSpace - The colormap options. - - Returns - ------- - color: np.ndarray - An array of shape (nb_streamlines, 3) containing the RGB values of - streamlines - lbound: float - Minimal value - ubound: float - Maximal value - """ - angles = [] - for i in range(len(sft.streamlines)): - dirs = np.diff(sft.streamlines[i], axis=0) - dirs /= np.linalg.norm(dirs, axis=-1, keepdims=True) - cos_angles = np.sum(dirs[:-1, :] * dirs[1:, :], axis=1) - # Resolve numerical instability - cos_angles = np.minimum(np.maximum(-1.0, cos_angles), 1.0) - line_angles = [0.0] + list(np.arccos(cos_angles)) + [0.0] - angles.extend(line_angles) - - angles = np.rad2deg(angles) - - return clip_and_normalize_data_for_cmap(args, angles) - - -def get_color_streamlines_along_length(sft, args): - """Color streamlines according to their length. - - Parameters - ---------- - sft: StatefulTractogram - The tractogram that contains the list of streamlines to be colored - args: NameSpace - The colormap options. 
- - Returns - ------- - color: np.ndarray - An array of shape (nb_streamlines, 3) containing the RGB values of - streamlines - lbound: int - Minimal value - ubound: int - Maximal value - """ - positions = [] - for i in range(len(sft.streamlines)): - positions.extend(list(np.linspace(0, 1, len(sft.streamlines[i])))) - - return clip_and_normalize_data_for_cmap(args, positions) - - -def filter_tractogram_data(tractogram, streamline_ids): - """ Filter tractogram according to streamline ids and keep the data - - Parameters: - ----------- - tractogram: StatefulTractogram - Tractogram containing the data to be filtered - streamline_ids: array_like - List of streamline ids the data corresponds to - - Returns: - -------- - new_tractogram: Tractogram or StatefulTractogram - Returns a new tractogram with only the selected streamlines - and data - """ - - streamline_ids = np.asarray(streamline_ids, dtype=int) - - assert np.all( - np.in1d(streamline_ids, np.arange(len(tractogram.streamlines))) - ), "Received ids outside of streamline range" - - new_streamlines = tractogram.streamlines[streamline_ids] - new_data_per_streamline = tractogram.data_per_streamline[streamline_ids] - new_data_per_point = tractogram.data_per_point[streamline_ids] - - # Could have been nice to deepcopy the tractogram modify the attributes in - # place instead of creating a new one, but tractograms cant be subsampled - # if they have data - - return StatefulTractogram.from_sft( - new_streamlines, - tractogram, - data_per_point=new_data_per_point, - data_per_streamline=new_data_per_streamline) - - -def compress_sft(sft, tol_error=0.01): - """ Compress a stateful tractogram. Uses Dipy's compress_streamlines, but - deals with space better. - - Dipy's description: - The compression consists in merging consecutive segments that are - nearly collinear. The merging is achieved by removing the point the two - segments have in common. - - The linearization process [Presseau15]_ ensures that every point being - removed are within a certain margin (in mm) of the resulting streamline. - Recommendations for setting this margin can be found in [Presseau15]_ - (in which they called it tolerance error). - - The compression also ensures that two consecutive points won't be too far - from each other (precisely less or equal than `max_segment_length`mm). - This is a tradeoff to speed up the linearization process [Rheault15]_. A - low value will result in a faster linearization but low compression, - whereas a high value will result in a slower linearization but high - compression. - - [Presseau C. et al., A new compression format for fiber tracking datasets, - NeuroImage, no 109, 73-83, 2015.] - - Parameters - ---------- - sft: StatefulTractogram - The sft to compress. - tol_error: float (optional) - Tolerance error in mm (default: 0.01). A rule of thumb is to set it - to 0.01mm for deterministic streamlines and 0.1mm for probabilitic - streamlines. - - Returns - ------- - compressed_sft : StatefulTractogram - """ - # Go to world space - orig_space = sft.space - sft.to_rasmm() - - # Compress streamlines - compressed_streamlines = compress_streamlines(sft.streamlines, - tol_error=tol_error) - if sft.data_per_point is not None and sft.data_per_point.keys(): - logging.warning("Initial StatefulTractogram contained data_per_point. 
" - "This information will not be carried in the final " - "tractogram.") - - compressed_sft = StatefulTractogram.from_sft( - compressed_streamlines, sft, - data_per_streamline=sft.data_per_streamline) - - # Return to original space - compressed_sft.to_space(orig_space) - - return compressed_sft - - -def cut_invalid_streamlines(sft): - """ Cut streamlines so their longest segment are within the bounding box. - This function keeps the data_per_point and data_per_streamline. - - Parameters - ---------- - sft: StatefulTractogram - The sft to remove invalid points from. - - Returns - ------- - new_sft : StatefulTractogram - New object with the invalid points removed from each streamline. - cutting_counter : int - Number of streamlines that were cut. - """ - if not len(sft): - return sft, 0 - - # Keep track of the streamlines' original space/origin - space = sft.space - origin = sft.origin - - sft.to_vox() - sft.to_corner() - - copy_sft = copy.deepcopy(sft) - epsilon = 0.001 - indices_to_remove, _ = copy_sft.remove_invalid_streamlines() - - new_streamlines = [] - new_data_per_point = {} - new_data_per_streamline = {} - for key in sft.data_per_point.keys(): - new_data_per_point[key] = [] - for key in sft.data_per_streamline.keys(): - new_data_per_streamline[key] = [] - - cutting_counter = 0 - for ind in range(len(sft.streamlines)): - # No reason to try to cut if all points are within the volume - if ind in indices_to_remove: - best_pos = [0, 0] - cur_pos = [0, 0] - for pos, point in enumerate(sft.streamlines[ind]): - if (point < epsilon).any() or \ - (point >= sft.dimensions - epsilon).any(): - cur_pos = [pos+1, pos+1] - if cur_pos[1] - cur_pos[0] > best_pos[1] - best_pos[0]: - best_pos = cur_pos - cur_pos[1] += 1 - - if not best_pos == [0, 0]: - new_streamlines.append( - sft.streamlines[ind][best_pos[0]:best_pos[1]-1]) - cutting_counter += 1 - for key in sft.data_per_streamline.keys(): - new_data_per_streamline[key].append( - sft.data_per_streamline[key][ind]) - for key in sft.data_per_point.keys(): - new_data_per_point[key].append( - sft.data_per_point[key][ind][best_pos[0]:best_pos[1]-1]) - else: - logging.warning('Streamlines entirely out of the volume.') - else: - new_streamlines.append(sft.streamlines[ind]) - for key in sft.data_per_streamline.keys(): - new_data_per_streamline[key].append( - sft.data_per_streamline[key][ind]) - for key in sft.data_per_point.keys(): - new_data_per_point[key].append(sft.data_per_point[key][ind]) - new_sft = StatefulTractogram.from_sft(new_streamlines, sft, - data_per_streamline=new_data_per_streamline, - data_per_point=new_data_per_point) - - # Move the streamlines back to the original space/origin - sft.to_space(space) - sft.to_origin(origin) - - new_sft.to_space(space) - new_sft.to_origin(origin) - - return new_sft, cutting_counter diff --git a/scripts/scil_bundle_label_map.py b/scripts/scil_bundle_label_map.py index 45488ae86..6db361611 100755 --- a/scripts/scil_bundle_label_map.py +++ b/scripts/scil_bundle_label_map.py @@ -58,8 +58,6 @@ from nibabel.streamlines.array_sequence import ArraySequence import numpy as np import scipy.ndimage as ndi -from sklearn.preprocessing import MinMaxScaler -from sklearn.svm import SVC from scilpy.image.volume_math import neighborhood_correlation_ from scilpy.io.streamlines import load_tractogram_with_reference @@ -70,15 +68,10 @@ assert_output_dirs_exist_and_empty) from scilpy.tractanalysis.bundle_operations import uniformize_bundle_sft from scilpy.tractanalysis.streamlines_metrics import compute_tract_counts_map -from 
scilpy.tractanalysis.distance_to_centroid import (min_dist_to_centroid, - compute_distance_map, - associate_labels, - correct_labels_jump) +from scilpy.tractanalysis.distance_to_centroid import (subdivide_bundles, + compute_distance_map) from scilpy.tractograms.streamline_and_mask_operations import \ cut_streamlines_with_mask -from scilpy.tractograms.streamline_operations import \ - resample_streamlines_num_points, resample_streamlines_step_size -from scilpy.utils.streamlines import uniformize_bundle_sft from scilpy.viz.color import get_lookup_table @@ -124,12 +117,13 @@ def _build_arg_parser(): 'distance.') p.add_argument('--skip_uniformize', action='store_true', help='Skip uniformization of the bundles orientation.') - p.add_argument('--correlation_thr', type=float, default=0.1, + p.add_argument('--correlation_thr', type=float, const=0.1, nargs='?', + default=0, help='Threshold for the correlation map. Only for multi ' 'bundle case. [%(default)s]') - p.add_argument('--streamlines_thr', type=int, default=2, + p.add_argument('--streamlines_thr', type=int, const=2, nargs='?', help='Threshold for the minimum number of streamlines in a ' - 'voxel to be included [%(default)s].') + 'voxel to be included [%(default)s].') add_reference_arg(p) add_verbose_arg(p) @@ -153,6 +147,9 @@ def main(): if args.verbose: logging.getLogger().setLevel(logging.INFO) + # TODO check if correlation thr is positive + # TODO check if streamlines thr is positive (above 1) + # When doing longitudinal data, the split in subsection can be done # on all the bundles at once. Needs to be co-registered. timer = time.time() @@ -184,7 +181,10 @@ def main(): density = compute_tract_counts_map(sft.streamlines, sft.dimensions).astype(float) binary = np.zeros(sft.dimensions, dtype=np.uint8) - binary[density >= args.streamlines_thr] = 1 + if args.streamlines_thr is not None: + binary[density >= args.streamlines_thr] = 1 + else: + binary[density > 0] = 1 binary_list.append(binary) density_list.append(density) @@ -196,23 +196,28 @@ def main(): f'{round(time.time() - timer, 3)}.') if len(density_list) > 1: + timer = time.time() corr_map = neighborhood_correlation_(density_list) + logging.info(f'Computed correlation map in ' + f'{round(time.time() - timer, 3)} seconds') else: corr_map = density_list[0].astype(float) corr_map[corr_map > 0] = 1 # Slightly cut the bundle at the edge to clean up single streamline voxels # with no neighbor. Remove isolated voxels to keep a single 'blob' - binary_map = np.max(binary_list, axis=0) - binary_map[corr_map < args.correlation_thr] = 0 + binary_mask = np.max(binary_list, axis=0) - # TODO eliminate the bottom quartile of the blob - bundle_disjoint, _ = ndi.label(binary_map) - unique, count = np.unique(bundle_disjoint, return_counts=True) - val = unique[np.argmax(count[1:])+1] - binary_map[bundle_disjoint != val] = 0 + if args.correlation_thr > 1e-3: + binary_mask[corr_map < args.correlation_thr] = 0 - nib.save(nib.Nifti1Image(corr_map * binary_map, sft_list[0].affine), + if args.streamlines_thr is not None: + bundle_disjoint, _ = ndi.label(binary_mask) + unique, count = np.unique(bundle_disjoint, return_counts=True) + val = unique[np.argmax(count[1:])+1] + binary_mask[bundle_disjoint != val] = 0 + + nib.save(nib.Nifti1Image(corr_map * binary_mask, sft_list[0].affine), os.path.join(args.out_dir, 'correlation_map.nii.gz')) # A bundle must be contiguous, we can't have isolated voxels. 
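The hunk above rebuilds the binary mask from the per-voxel streamline count and then keeps only the largest connected component, so the bundle stays one contiguous 'blob'. A minimal sketch of that cleanup step, assuming a plain numpy/scipy environment (the largest_component helper name is ours, not part of the patch, and it assumes the mask contains at least some background voxels):

    import numpy as np
    import scipy.ndimage as ndi

    def largest_component(binary_mask):
        """Zero out every connected component except the largest one."""
        labeled, num_features = ndi.label(binary_mask)
        if num_features == 0:
            return binary_mask  # empty mask, nothing to keep
        # unique[0] is the background label 0, so count[1:] holds the
        # component sizes and the +1 maps the argmax back to its label.
        unique, count = np.unique(labeled, return_counts=True)
        largest = unique[np.argmax(count[1:]) + 1]
        out = binary_mask.copy()
        out[labeled != largest] = 0
        return out

    # Two blobs of 64 and 8 voxels; only the bigger one survives.
    mask = np.zeros((10, 10, 10), dtype=np.uint8)
    mask[1:5, 1:5, 1:5] = 1
    mask[7:9, 7:9, 7:9] = 1
    assert largest_component(mask).sum() == 64

This is the same np.argmax(count[1:]) + 1 indexing the script uses; skipping index 0 is what excludes the background label returned by ndi.label.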
@@ -221,99 +226,30 @@ def main(): concat_sft.to_vox() concat_sft.to_corner() for i in range(len(sft_list)): - sft_list[i] = cut_streamlines_with_mask(sft_list[i], - binary_map) + if args.streamlines_thr is not None: + sft_list[i] = cut_streamlines_with_mask(sft_list[i], + binary_mask) + else: + sft_list[i].data_per_streamline = {} + sft_list[i].data_per_point = {} + if len(sft_list[i]): concat_sft += sft_list[i] logging.info(f'Chop bundle(s) in {round(time.time() - timer, 3)} seconds.') + # Here + method = 'hyperplane' if args.hyperplane else 'centerline' args.nb_pts = len(sft_centroid.streamlines[0]) if args.nb_pts is None \ else args.nb_pts - - # This allows to have a more uniform (in size) first and last labels - endpoints_extended = False - if args.hyperplane and args.nb_pts >= 5: - args.nb_pts += 2 - endpoints_extended = True - - sft_centroid = resample_streamlines_num_points(sft_centroid, args.nb_pts) + labels_map = subdivide_bundles(concat_sft, sft_centroid, binary_mask, + args.nb_pts, method=method) timer = time.time() - if not args.hyperplane: - indices = np.array(np.nonzero(binary_map), dtype=int).T - labels = min_dist_to_centroid(indices, - sft_centroid[0].streamlines._data, - nb_pts=args.nb_pts) - logging.info('Computed labels using the euclidian method ' - f'in {round(time.time() - timer, 3)} seconds') - else: - logging.info('Computing Labels using the hyperplane method.\n' - '\tThis can take a while...') - # Select 2000 elements from the SFTs to train the classifier - random_indices = np.random.choice(len(concat_sft), - min(len(concat_sft), 2000), - replace=False) - tmp_sft = resample_streamlines_step_size(concat_sft[random_indices], - 1.0) - # Associate the labels to the streamlines using the centroids as - # reference (to handle shorter bundles due to missing data) - mini_timer = time.time() - sample_size = np.count_nonzero(binary_map) // args.nb_pts - labels, points, = associate_labels(tmp_sft, sft_centroid, - args.nb_pts, sample_set=True, - sample_size=sample_size) - - logging.info('\tAssociated labels to centroids in ' - f'{round(time.time() - mini_timer, 3)} seconds') - - # Initialize the scaler - mini_timer = time.time() - scaler = MinMaxScaler(feature_range=(-1, 1)) - scaler.fit(points) - scaled_streamline_data = scaler.transform(points) - - svc = SVC(C=1.0, kernel='rbf', random_state=1) - - svc.fit(X=scaled_streamline_data, y=labels) - logging.info('\tSVC fit of training data in ' - f'{round(time.time() - mini_timer, 3)} seconds') - - # Scale the coordinates of the voxels - mini_timer = time.time() - voxel_coords = np.array(np.where(binary_map)).T - scaled_voxel_coords = scaler.transform(voxel_coords) - - # Predict the labels for the voxels - labels = svc.predict(X=scaled_voxel_coords) - logging.info('\tSVC prediction of labels in ' - f'{round(time.time() - mini_timer, 3)} seconds') - - logging.info('Computed labels using the hyperplane method ' - f'in {round(time.time() - timer, 3)} seconds') - labels_map = np.zeros(binary_map.shape, dtype=np.uint16) - labels_map[np.where(binary_map)] = labels - - # # Correct the hyperplane labels to have a more uniform size - - timer = time.time() - tmp_sft = resample_streamlines_step_size(concat_sft, 1.0) - labels_map = correct_labels_jump(labels_map, tmp_sft.streamlines, - args.nb_pts - 2) - - if args.hyperplane and endpoints_extended: - labels_map[labels_map == args.nb_pts] = args.nb_pts - 1 - labels_map[labels_map == 1] = 2 - labels_map[labels_map > 0] -= 1 - args.nb_pts -= 2 - logging.info('Corrected labels jump in ' - 
f'{round(time.time() - timer, 3)} seconds') - - timer = time.time() - distance_map = compute_distance_map(labels_map, binary_map, + distance_map = compute_distance_map(labels_map, binary_mask, args.use_manhattan, args.nb_pts) logging.info('Computed distance map in ' f'{round(time.time() - timer, 3)} seconds') - + print(binary_mask.shape, np.count_nonzero(binary_mask)) timer = time.time() cmap = get_lookup_table(args.colormap) for i, sft in enumerate(sft_list): @@ -341,17 +277,17 @@ def main(): continue if len(sft): - tmp_data=ndi.map_coordinates( + tmp_data = ndi.map_coordinates( map, sft.streamlines._data.T - 0.5, order=0) if basename == 'labels': - max_val=args.nb_pts + max_val = args.nb_pts elif basename == 'correlation': - max_val=1 + max_val = 1 else: - max_val=np.max(tmp_data) - max_val=args.nb_pts - new_sft.data_per_point['color']._data=cmap( + max_val = np.max(tmp_data) + max_val = args.nb_pts + new_sft.data_per_point['color']._data = cmap( tmp_data / max_val)[:, 0:3] * 255 # Save the tractogram From 1616219de9da9a4498895a3503804ec6b00c68a6 Mon Sep 17 00:00:00 2001 From: frheault Date: Tue, 8 Oct 2024 08:23:31 -0400 Subject: [PATCH 12/14] TODOs --- scilpy/tractanalysis/distance_to_centroid.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/scilpy/tractanalysis/distance_to_centroid.py b/scilpy/tractanalysis/distance_to_centroid.py index ce3de7793..ebe3693f4 100644 --- a/scilpy/tractanalysis/distance_to_centroid.py +++ b/scilpy/tractanalysis/distance_to_centroid.py @@ -43,6 +43,7 @@ def min_dist_to_centroid(bundle_pts, centroid_pts, nb_pts): def associate_labels(target_sft, source_sft, nb_pts=20, sample_set=False, sample_size=None): + # DOCSTRING curr_ind = 0 source_labels = np.zeros(source_sft.streamlines._data.shape[0], dtype=float) @@ -375,6 +376,7 @@ def correct_labels_jump(labels_map, streamlines, nb_pts): def subdivide_bundles(sft, sft_centroid, binary_mask, nb_pts, method='centerline'): + # TODO DOCSTRING ! 
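The "# TODO DOCSTRING !" comment added just above is the one gap patch 12 leaves open in subdivide_bundles. A hedged sketch of what that docstring could say, with parameter roles inferred from the function body and its call site in scil_bundle_label_map.py (the wording is ours, not the author's):

    def subdivide_bundles(sft, sft_centroid, binary_mask, nb_pts,
                          method='centerline'):
        """Split a bundle into nb_pts sections and label its voxels.

        Parameters
        ----------
        sft: StatefulTractogram
            Bundle (or concatenated bundles) to label, already cut to
            the mask and in vox/corner space.
        sft_centroid: StatefulTractogram
            Centroid streamline used as the labelling reference.
        binary_mask: np.ndarray
            3D binary mask of the voxels to label.
        nb_pts: int
            Number of sections (labels) along the bundle.
        method: str
            'centerline' assigns each voxel to its closest centroid
            point (euclidean); 'hyperplane' trains an SVC on sampled
            streamline points and predicts one label per voxel, with
            a label-jump correction pass afterwards.

        Returns
        -------
        np.ndarray
            uint16 label map of shape binary_mask.shape, 0 outside
            the mask.
        """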
# This allows to have a more uniform (in size) first and last labels endpoints_extended = False if method == 'hyperplane' and nb_pts >= 5: From c915ac4b2c38765b6b72c527fc2235e768454bdb Mon Sep 17 00:00:00 2001 From: frheault Date: Fri, 18 Oct 2024 15:21:30 -0400 Subject: [PATCH 13/14] Change test file --- scripts/tests/test_bundle_label_map.py | 31 ++++++++++--- .../test_compute_bundle_voxel_label_map.py | 44 ------------------- 2 files changed, 24 insertions(+), 51 deletions(-) delete mode 100644 scripts/tests/test_compute_bundle_voxel_label_map.py diff --git a/scripts/tests/test_bundle_label_map.py b/scripts/tests/test_bundle_label_map.py index a2e228a35..00a6b8bc2 100644 --- a/scripts/tests/test_bundle_label_map.py +++ b/scripts/tests/test_bundle_label_map.py @@ -5,7 +5,8 @@ import tempfile from scilpy import SCILPY_HOME -from scilpy.io.fetcher import fetch_data, get_testing_files_dict +from scilpy.io.fetcher import get_testing_files_dict, fetch_data + # If they already exist, this only takes 5 seconds (check md5sum) fetch_data(get_testing_files_dict(), keys=['tractometry.zip']) @@ -13,15 +14,31 @@ def test_help_option(script_runner): - ret = script_runner.run('scil_bundle_label_map.py', '--help') + ret = script_runner.run('scil_compute_bundle_voxel_label_map.py', '--help') assert ret.success -def test_execution_tractometry(script_runner, monkeypatch): - monkeypatch.chdir(os.path.expanduser(tmp_dir.name)) - in_bundle = os.path.join(SCILPY_HOME, 'tractometry', 'IFGWM.trk') +def test_execution_tractometry_euclidian(script_runner): + os.chdir(os.path.expanduser(tmp_dir.name)) + in_bundle = os.path.join(SCILPY_HOME, 'tractometry', + 'IFGWM.trk') + in_centroid = os.path.join(SCILPY_HOME, 'tractometry', + 'IFGWM_uni_c_10.trk') + ret = script_runner.run('scil_compute_bundle_voxel_label_map.py', + in_bundle, in_centroid, + 'results_euc/', + '--colormap', 'viridis') + assert ret.success + +def test_execution_tractometry_hyperplane(script_runner): + os.chdir(os.path.expanduser(tmp_dir.name)) + in_bundle = os.path.join(SCILPY_HOME, 'tractometry', + 'IFGWM.trk') in_centroid = os.path.join(SCILPY_HOME, 'tractometry', 'IFGWM_uni_c_10.trk') - ret = script_runner.run('scil_bundle_label_map.py', in_bundle, in_centroid, - 'results_dir/', '--colormap', 'viridis') + ret = script_runner.run('scil_compute_bundle_voxel_label_map.py', + in_bundle, in_centroid, + 'results_man/', + '--colormap', 'viridis', + '--hyperplane', '--use_manhattan') assert ret.success diff --git a/scripts/tests/test_compute_bundle_voxel_label_map.py b/scripts/tests/test_compute_bundle_voxel_label_map.py deleted file mode 100644 index 00a6b8bc2..000000000 --- a/scripts/tests/test_compute_bundle_voxel_label_map.py +++ /dev/null @@ -1,44 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- - -import os -import tempfile - -from scilpy import SCILPY_HOME -from scilpy.io.fetcher import get_testing_files_dict, fetch_data - - -# If they already exist, this only takes 5 seconds (check md5sum) -fetch_data(get_testing_files_dict(), keys=['tractometry.zip']) -tmp_dir = tempfile.TemporaryDirectory() - - -def test_help_option(script_runner): - ret = script_runner.run('scil_compute_bundle_voxel_label_map.py', '--help') - assert ret.success - - -def test_execution_tractometry_euclidian(script_runner): - os.chdir(os.path.expanduser(tmp_dir.name)) - in_bundle = os.path.join(SCILPY_HOME, 'tractometry', - 'IFGWM.trk') - in_centroid = os.path.join(SCILPY_HOME, 'tractometry', - 'IFGWM_uni_c_10.trk') - ret = 
script_runner.run('scil_compute_bundle_voxel_label_map.py', - in_bundle, in_centroid, - 'results_euc/', - '--colormap', 'viridis') - assert ret.success - -def test_execution_tractometry_hyperplane(script_runner): - os.chdir(os.path.expanduser(tmp_dir.name)) - in_bundle = os.path.join(SCILPY_HOME, 'tractometry', - 'IFGWM.trk') - in_centroid = os.path.join(SCILPY_HOME, 'tractometry', - 'IFGWM_uni_c_10.trk') - ret = script_runner.run('scil_compute_bundle_voxel_label_map.py', - in_bundle, in_centroid, - 'results_man/', - '--colormap', 'viridis', - '--hyperplane', '--use_manhattan') - assert ret.success From 529646a7bf2c00eb1804edd7f28de59186586f2a Mon Sep 17 00:00:00 2001 From: frheault Date: Fri, 18 Oct 2024 15:22:46 -0400 Subject: [PATCH 14/14] pep8 --- scripts/tests/test_bundle_label_map.py | 1 + 1 file changed, 1 insertion(+) diff --git a/scripts/tests/test_bundle_label_map.py b/scripts/tests/test_bundle_label_map.py index 00a6b8bc2..401760f33 100644 --- a/scripts/tests/test_bundle_label_map.py +++ b/scripts/tests/test_bundle_label_map.py @@ -30,6 +30,7 @@ def test_execution_tractometry_euclidian(script_runner): '--colormap', 'viridis') assert ret.success + def test_execution_tractometry_hyperplane(script_runner): os.chdir(os.path.expanduser(tmp_dir.name)) in_bundle = os.path.join(SCILPY_HOME, 'tractometry',
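Taken together, the tests above exercise the same public surface a user of the refactored modules would: one call for the label map, one for the distance map. A hedged end-to-end sketch of that surface, assuming patches 11-12 are applied; the .trk file names are the test fixtures from above, nb_pts is illustrative, and the is_euclidian flag is named after the function signature (the script feeds it args.use_manhattan, so its exact polarity is left as in the patch):

    import numpy as np
    from dipy.io.streamline import load_tractogram

    from scilpy.tractanalysis.distance_to_centroid import (
        compute_distance_map, subdivide_bundles)
    from scilpy.tractanalysis.streamlines_metrics import \
        compute_tract_counts_map

    # Fixtures from scripts/tests/test_bundle_label_map.py
    sft = load_tractogram('IFGWM.trk', 'same')
    sft_centroid = load_tractogram('IFGWM_uni_c_10.trk', 'same')
    for tractogram in (sft, sft_centroid):
        tractogram.to_vox()
        tractogram.to_corner()

    # Every voxel crossed by at least one streamline
    density = compute_tract_counts_map(sft.streamlines, sft.dimensions)
    binary_mask = (density > 0).astype(np.uint8)

    # 10 sections along the bundle with the default (euclidean) method
    labels_map = subdivide_bundles(sft, sft_centroid, binary_mask,
                                   nb_pts=10, method='centerline')

    # Per-voxel distance to each section's barycenter
    distance_map = compute_distance_map(labels_map, binary_mask,
                                        is_euclidian=True, nb_pts=10)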