diff --git a/scilpy/tractanalysis/bundle_operations.py b/scilpy/tractanalysis/bundle_operations.py
index 033bbd73b..de1f4fbf8 100644
--- a/scilpy/tractanalysis/bundle_operations.py
+++ b/scilpy/tractanalysis/bundle_operations.py
@@ -60,6 +60,7 @@ def uniformize_bundle_sft(sft, axis=None, ref_bundle=None, swap=False):
     old_origin = sft.origin
     sft.to_vox()
     sft.to_corner()
+
     density = get_endpoints_density_map(sft, point_to_select=3)
     indices = np.argwhere(density > 0)
     kmeans = KMeans(n_clusters=2, random_state=0, copy_x=True,
@@ -342,4 +343,3 @@ def remove_outliers_qb(streamlines, threshold, nb_points=12, nb_samplings=30,
     outliers_ids, inliers_ids = prune(streamlines, threshold, summary)
 
     return outliers_ids, inliers_ids
-
diff --git a/scilpy/tractanalysis/distance_to_centroid.py b/scilpy/tractanalysis/distance_to_centroid.py
index a4c4022a1..ebe3693f4 100644
--- a/scilpy/tractanalysis/distance_to_centroid.py
+++ b/scilpy/tractanalysis/distance_to_centroid.py
@@ -1,7 +1,21 @@
 # -*- coding: utf-8 -*-
 
+import heapq
+import logging
+import time
+
+from dipy.tracking.metrics import length
+from nibabel.streamlines.array_sequence import ArraySequence
 import numpy as np
+import scipy.ndimage as ndi
 from scipy.spatial import KDTree
+from scipy.spatial.distance import pdist, squareform
+from sklearn.preprocessing import MinMaxScaler
+from sklearn.svm import SVC
+
+from scilpy.tractanalysis.streamlines_metrics import compute_tract_counts_map
+from scilpy.tractograms.streamline_operations import \
+    resample_streamlines_num_points, resample_streamlines_step_size
 
 
 def min_dist_to_centroid(bundle_pts, centroid_pts, nb_pts):
@@ -19,20 +33,442 @@
         Array:
     """
     tree = KDTree(centroid_pts, copy_data=True)
-    dists, labels = tree.query(bundle_pts, k=1)
-    dists, labels = np.expand_dims(
-        dists, axis=1), np.expand_dims(labels, axis=1)
+    _, labels = tree.query(bundle_pts, k=1)
+    labels = np.mod(labels, nb_pts) + 1
+
+    return labels.astype(np.uint16)
+
+
+def associate_labels(target_sft, source_sft, nb_pts=20, sample_set=False,
+                     sample_size=None):
+    """
+    Associate a label (in the range 1 to nb_pts) to every point of every
+    streamline in target_sft. Labels are first assigned along each source
+    streamline by arc length, an SVC is trained on those labelled points,
+    and the predicted endpoint labels of each target streamline are then
+    interpolated along its length.
+
+    Parameters:
+        target_sft (StatefulTractogram): Streamlines to label.
+        source_sft (StatefulTractogram): Reference streamlines (centroids).
+        nb_pts (int): Number of labels.
+        sample_set (bool): If True, return a subsample balanced across labels.
+        sample_size (int): Number of points to keep per label. If None, the
+            size of the smallest label is used.
+
+    Returns:
+        tuple: (labels, points), two arrays of matching length.
+    """
+    curr_ind = 0
+    source_labels = np.zeros(source_sft.streamlines._data.shape[0],
+                             dtype=float)
+    for streamline in source_sft.streamlines:
+        curr_length = np.insert(length(streamline, along=True), 0, 0)
+        curr_labels = np.interp(curr_length,
+                                [0, curr_length[-1]],
+                                [1, nb_pts])
+        curr_labels = np.round(curr_labels)
+        source_labels[curr_ind:curr_ind+len(streamline)] = curr_labels
+        curr_ind += len(streamline)
+
+    scaler = MinMaxScaler(feature_range=(-1, 1))
+    scaler.fit(target_sft.streamlines._data)
+    scaled_streamline_data = scaler.transform(source_sft.streamlines._data)
+
+    svc = SVC(C=1.0, kernel='rbf', random_state=1)
+    svc.fit(X=scaled_streamline_data, y=source_labels)
+
+    curr_ind = 0
+    target_labels = np.zeros(target_sft.streamlines._data.shape[0],
+                             dtype=float)
+
+    # TODO Single prediction array
+    for streamline in target_sft.streamlines:
+        head_tail = [streamline[0], streamline[-1]]
+        scaled_head_tail_data = scaler.transform(head_tail)
+        head_tail_labels = svc.predict(X=scaled_head_tail_data)
+        curr_length = np.insert(length(streamline, along=True), 0, 0)
+        curr_labels = np.interp(curr_length,
+                                [0, curr_length[-1]],
+                                head_tail_labels)
+        target_labels[curr_ind:curr_ind+len(streamline)] = curr_labels
+        curr_ind += len(streamline)
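+    # Worked example (hypothetical numbers): if the SVC predicts endpoint
+    # labels 3 and 18 for a given streamline, a point located at 25% of its
+    # arc length receives label round(3 + 0.25 * (18 - 3)) = 7.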
+
+    target_labels = np.round(target_labels).astype(int)
+
+    if sample_set:
+        if sample_size is None:
+            sample_size = np.unique(target_labels, return_counts=True)[1].min()
+
+        # Sort points by labels
+        sorted_indices = target_labels.argsort()
+        sorted_points = target_sft.streamlines._data[sorted_indices]
+        sorted_labels = target_labels[sorted_indices]
+
+        # Find the start and end of each label
+        unique_labels, start_indices = np.unique(
+            sorted_labels, return_index=True)
+        end_indices = np.roll(start_indices, -1)
+        end_indices[-1] = len(target_labels)
+
+        # Sample points and labels for each label
+        sampled_points = []
+        sampled_labels = []
+        for start, end, label in zip(start_indices, end_indices,
+                                     unique_labels):
+            num_points = end - start
+            indices_to_sample = min(num_points, sample_size)
+            sampled_indices = np.random.choice(
+                np.arange(start, end), size=indices_to_sample, replace=False)
+            sampled_points.append(sorted_points[sampled_indices])
+            sampled_labels.extend([label] * indices_to_sample)
+
+        # Concatenate all sampled points
+        sampled_points = np.concatenate(sampled_points)
+        sampled_labels = np.array(sampled_labels)
+
+        return sampled_labels, sampled_points
+    else:
+        return target_labels, target_sft.streamlines._data
+
+
+def find_medoid(points, max_points=10000):
+    """
+    Find the medoid among a set of points.
+
+    Parameters:
+        points (ndarray): Points in N-dimensional space.
+        max_points (int): Maximum number of points to consider; larger sets
+            are randomly subsampled to keep the distance matrix tractable.
+
+    Returns:
+        ndarray: Coordinates of the medoid.
+    """
+    if len(points) > max_points:
+        selected_indices = np.random.choice(len(points), max_points,
+                                            replace=False)
+        points = points[selected_indices]
+
+    distance_matrix = squareform(pdist(points))
+    medoid_idx = np.argmin(distance_matrix.sum(axis=1))
+    return points[medoid_idx]
+
+
+def compute_labels_map_barycenters(labels_map, is_euclidian=False,
+                                   nb_pts=None):
+    """
+    Compute the barycenter of each label in a 3D labels map, either as the
+    mean of the label's coordinates or as their medoid.
+
+    Parameters:
+        labels_map (ndarray): The 3D array containing labels from 1 to nb_pts.
+        is_euclidian (bool): If True, the barycenter is the mean of the
+            points; otherwise, it is their medoid.
+        nb_pts (int): Number of labels. If None, the labels found in the
+            map are used.
+
+    Returns:
+        ndarray: An array of size (nb_pts, 3) containing the barycenter
+        of each label.
+    """
+    labels = np.arange(1, nb_pts+1) if nb_pts else np.unique(labels_map)[1:]
+    barycenters = np.zeros((len(labels), 3))
+    barycenters[:] = np.nan
+
+    for label in labels:
+        indices = np.argwhere(labels_map == label)
+        if indices.size > 0:
+            mask = np.zeros_like(labels_map)
+            mask[labels_map == label] = 1
+            mask_coords = np.argwhere(mask)
+            if is_euclidian:
+                barycenter = np.mean(mask_coords, axis=0)
+            else:
+                barycenter = find_medoid(mask_coords)
+            # If the barycenter is not in the mask, find the closest point
+            if labels_map[tuple(barycenter.astype(int))] != label:
+                tree = KDTree(indices)
+                _, ind = tree.query(barycenter, k=1)
+                del tree
+                barycenter = indices[ind]
+
+            barycenters[label - 1] = barycenter
+
+    return np.array(barycenters)
+
+
+def masked_manhattan_distance(mask, target_positions):
+    """
+    Compute the Manhattan distance from every position in a mask to a set of
+    positions, without stepping out of the mask.
+
+    Parameters:
+        mask (ndarray): A binary 3D array representing the mask.
+        target_positions (list): A list of target positions within the mask.
+
+    Returns:
+        ndarray: A 3D array of the same shape as the mask, containing the
+        Manhattan distances.
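+
+    Example:
+        Illustrative only. For a straight, 1-voxel-wide mask of shape
+        (5, 5, 5) with mask[2, 2, :] = 1, calling
+        masked_manhattan_distance(mask, [(2, 2, 0)]) returns a map where
+        position (2, 2, 4) has distance 4 and positions outside the mask
+        stay at np.inf.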
+    """
+    # Initialize distance array with infinite values
+    distances = np.full(mask.shape, np.inf)
+
+    # Initialize priority queue and set distance for target positions to zero
+    priority_queue = []
+    for x, y, z in target_positions:
+        heapq.heappush(priority_queue, (0, (x, y, z)))
+        distances[x, y, z] = 0
+
+    # Directions for moving in the grid (Manhattan distance)
+    directions = [(0, 0, 1), (0, 0, -1), (0, 1, 0),
+                  (0, -1, 0), (1, 0, 0), (-1, 0, 0)]
+
+    while priority_queue:
+        current_distance, (x, y, z) = heapq.heappop(priority_queue)
+
+        for dx, dy, dz in directions:
+            nx, ny, nz = x + dx, y + dy, z + dz
+
+            if 0 <= nx < mask.shape[0] and \
+                    0 <= ny < mask.shape[1] and \
+                    0 <= nz < mask.shape[2]:
+                if mask[nx, ny, nz]:
+                    new_distance = current_distance + 1
+
+                    if new_distance < distances[nx, ny, nz]:
+                        distances[nx, ny, nz] = new_distance
+                        heapq.heappush(
+                            priority_queue, (new_distance, (nx, ny, nz)))
+
+    return distances
+
+
+def compute_distance_map(labels_map, binary_mask, is_euclidian, nb_pts):
+    """
+    Compute the distance map for each label in the labels map.
+
+    Parameters:
+        labels_map (ndarray): A 3D array representing the labels map.
+        binary_mask (ndarray): A 3D binary map used to compute the
+            barycenter binary map.
+        is_euclidian (bool): If True, use the euclidean distance to the
+            barycenters; otherwise, use the masked Manhattan distance.
+        nb_pts (int): Number of points to use for computing barycenters.
+
+    Returns:
+        ndarray: A 3D array representing the distance map.
+    """
+    barycenters = compute_labels_map_barycenters(labels_map,
+                                                 is_euclidian=is_euclidian,
+                                                 nb_pts=nb_pts)
+    # If the first/last few barycenters are NaN, the head/tail labels do not
+    # span the full 1 to nb_pts range; restrict the computation to the labels
+    # that are actually present.
+    isnan = np.isnan(barycenters).all(axis=1)
+    head = np.argmax(~isnan) + 1
+    tail = len(isnan) - np.argmax(~isnan[::-1])
+
+    # Interpolate the barycenters that are still NaN between head and tail
+    tmp_barycenter = barycenters[head-1:tail]
+    valid_indices = np.argwhere(
+        ~np.isnan(tmp_barycenter).any(axis=1)).flatten()
+    valid_data = tmp_barycenter[valid_indices]
+    interpolated_data = np.array(
+        [np.interp(np.arange(len(tmp_barycenter)),
+                   valid_indices,
+                   valid_data[:, i]) for i in range(tmp_barycenter.shape[1])]).T
+    barycenters[head-1:tail] = interpolated_data
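+    # e.g. if labels 1 and 2 have no voxels at all, head == 3; a label that
+    # is missing strictly between head and tail gets a barycenter linearly
+    # interpolated from its valid neighbours above.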
+
+    distance_map = np.zeros(binary_mask.shape, dtype=float)
+    barycenter_strs = [barycenters[head-1:tail]]
+    barycenter_bin = compute_tract_counts_map(barycenter_strs,
+                                              binary_mask.shape)
+    barycenter_bin[barycenter_bin > 0] = 1
+
+    for label in range(head, tail+1):
+        mask = np.zeros(labels_map.shape)
+        mask[labels_map == label] = 1
+        labels_coords = np.array(np.where(mask)).T
+        if labels_coords.size == 0:
+            continue
+
+        barycenter_bin_intersect = barycenter_bin * mask
+        barycenter_intersect_coords = np.array(
+            np.nonzero(barycenter_bin_intersect), dtype=int).T
+
+        if barycenter_intersect_coords.size == 0:
+            continue
+
+        if is_euclidian:
+            distances = np.linalg.norm(
+                barycenter_intersect_coords[:, np.newaxis] - labels_coords,
+                axis=-1)
+            distance_map[labels_map == label] = np.min(distances, axis=0)
+        else:
+            coords = [tuple(coord) for coord in barycenter_intersect_coords]
+            curr_dists = masked_manhattan_distance(binary_mask, coords)
+            distance_map[labels_map == label] = \
+                curr_dists[labels_map == label]
+
+    return distance_map
+
+
+def correct_labels_jump(labels_map, streamlines, nb_pts):
+    """
+    Remove the longest inconsistent segment of each streamline whose labels
+    jump by more than nb_pts // 2, then recompute a labels map from the
+    remaining streamlines and spatially smooth it.
+
+    Parameters:
+        labels_map (ndarray): Input 3D labels map.
+        streamlines (ArraySequence): Streamlines in voxel space.
+        nb_pts (int): Number of labels.
+
+    Returns:
+        ndarray: Corrected 3D labels map.
+    """
+    labels_data = ndi.map_coordinates(labels_map, streamlines._data.T - 0.5,
+                                      order=0)
+    binary_mask = np.zeros(labels_map.shape, dtype=np.uint8)
+    binary_mask[labels_map > 0] = 1
+
+    # For consistency, labels are not allowed to jump along a streamline:
+    # streamlines should have continuous labels
+    final_streamlines = []
+    final_labels = []
+    curr_ind = 0
+    for streamline in streamlines:
+        next_ind = curr_ind + len(streamline)
+        curr_labels = labels_data[curr_ind:next_ind].astype(int)
+        curr_ind = next_ind
+
+        # Flip streamlines so the labels increase (facilitate if/else)
+        # Should always be ordered in nextflow pipeline
+        gradient = np.ediff1d(curr_labels)
+
+        is_flip = False
+        if len(np.argwhere(gradient < 0)) > len(np.argwhere(gradient > 0)):
+            streamline = streamline[::-1]
+            curr_labels = curr_labels[::-1]
+            is_flip = True
+
+        # Find jumps, cut them and keep the longest chunk
+        gradient = np.ediff1d(curr_labels)
+        max_jump = max(nb_pts // 2, 1)
+        if len(np.argwhere(np.abs(gradient) > max_jump)) > 0:
+            pos_jump = np.where(np.abs(gradient) > max_jump)[0] + 1
+            split_chunk = np.split(curr_labels,
+                                   pos_jump)
+
+            max_len = 0
+            max_pos = 0
+            for j, chunk in enumerate(split_chunk):
+                if len(chunk) > max_len:
+                    max_len = len(chunk)
+                    max_pos = j
+
+            curr_labels = split_chunk[max_pos]
+            gradient_chunk = np.ediff1d(curr_labels)
+            if len(np.unique(np.sign(gradient_chunk))) > 1:
+                continue
+            streamline = np.split(streamline,
+                                  pos_jump)[max_pos]
+
+        if is_flip:
+            streamline = streamline[::-1]
+            curr_labels = curr_labels[::-1]
+        final_streamlines.append(streamline)
+        final_labels.append(curr_labels)
+
+    # Once the streamline abnormalities are corrected, we can
+    # recompute the labels map with the new streamlines/labels
+    final_labels = ArraySequence(final_labels)
+    final_streamlines = ArraySequence(final_streamlines)
+
+    kd_tree = KDTree(final_streamlines._data)
+    indices = np.array(np.nonzero(labels_map), dtype=int).T
+    labels_map = np.zeros(labels_map.shape, dtype=np.uint16)
+
+    for ind in indices:
+        neighbor_ids = kd_tree.query_ball_point(ind, r=1.0)
+        if not len(neighbor_ids):
+            _, neighbor_ids = kd_tree.query(ind, k=5)
+        labels_val = np.median(final_labels._data[neighbor_ids])
+        labels_map[ind[0], ind[1], ind[2]] = labels_val
+
+    # To ensure spatial smoothness, we apply a Gaussian filter on the labels
+    tmp_labels_map = np.zeros(labels_map.shape+(nb_pts+3,), dtype=float)
+    for i in np.unique(labels_map)[1:]:
+        tmp_labels_map[labels_map == i, i] = 1
+        tmp_labels_map[..., i] = ndi.gaussian_filter(tmp_labels_map[..., i],
+                                                     sigma=2)
+    labels_map = np.argmax(tmp_labels_map, axis=-1).astype(np.uint16)
+
+    return binary_mask * labels_map
+
+
+def subdivide_bundles(sft, sft_centroid, binary_mask, nb_pts,
+                      method='centerline'):
+    """
+    Subdivide a bundle into sections (labels) along its length.
+
+    Parameters:
+        sft (StatefulTractogram): Bundle to subdivide, in voxel space.
+        sft_centroid (StatefulTractogram): Centroid streamline(s).
+        binary_mask (ndarray): Binary 3D mask of the bundle.
+        nb_pts (int): Number of sections (labels) to create.
+        method (str): Either 'centerline' (nearest centroid point) or
+            'hyperplane' (SVC-based).
+
+    Returns:
+        ndarray: 3D labels map with labels from 1 to nb_pts.
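+
+    Example:
+        Illustrative only, with inputs assumed already loaded and in voxel
+        space:
+
+        labels_map = subdivide_bundles(sft, sft_centroid, binary_mask,
+                                       nb_pts=20, method='hyperplane')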
+    """
+    # Adding two temporary sections gives a more uniform (in size) first
+    # and last label; they are merged back at the end.
+    endpoints_extended = False
+    if method == 'hyperplane' and nb_pts >= 5:
+        nb_pts += 2
+        endpoints_extended = True
+
+    sft_centroid = resample_streamlines_num_points(sft_centroid, nb_pts)
+
+    timer = time.time()
+    if method == 'centerline':
+        indices = np.array(np.nonzero(binary_mask), dtype=int).T
+        labels = min_dist_to_centroid(indices,
+                                      sft_centroid[0].streamlines._data,
+                                      nb_pts=nb_pts)
+        logging.info('Computed labels using the euclidean method '
+                     f'in {round(time.time() - timer, 3)} seconds')
+    else:
+        logging.info('Computing labels using the hyperplane method.\n'
+                     '\tThis can take a while...')
+        # Select at most 2000 streamlines to train the classifier
+        random_indices = np.random.choice(len(sft),
+                                          min(len(sft), 2000),
+                                          replace=False)
+        tmp_sft = resample_streamlines_step_size(sft[random_indices],
+                                                 1.0)
+        # Associate the labels to the streamlines using the centroids as
+        # reference (to handle shorter bundles due to missing data)
+        mini_timer = time.time()
+        sample_size = np.count_nonzero(binary_mask) // nb_pts
+        labels, points = associate_labels(tmp_sft, sft_centroid,
+                                          nb_pts, sample_set=True,
+                                          sample_size=sample_size)
+
+        logging.info('\tAssociated labels to centroids in '
+                     f'{round(time.time() - mini_timer, 3)} seconds')
+
+        # Initialize the scaler
+        mini_timer = time.time()
+        scaler = MinMaxScaler(feature_range=(-1, 1))
+        scaler.fit(points)
+        scaled_streamline_data = scaler.transform(points)
+
+        svc = SVC(C=1.0, kernel='rbf', random_state=1)
+        svc.fit(X=scaled_streamline_data, y=labels)
+        logging.info('\tSVC fit of training data in '
+                     f'{round(time.time() - mini_timer, 3)} seconds')
+
+        # Scale the coordinates of the voxels
+        mini_timer = time.time()
+        voxel_coords = np.array(np.where(binary_mask)).T
+        scaled_voxel_coords = scaler.transform(voxel_coords)
+
+        # Predict the labels for the voxels
+        labels = svc.predict(X=scaled_voxel_coords)
+        logging.info('\tSVC prediction of labels in '
+                     f'{round(time.time() - mini_timer, 3)} seconds')
+
+        logging.info('Computed labels using the hyperplane method '
+                     f'in {round(time.time() - timer, 3)} seconds')
+
+    labels_map = np.zeros(binary_mask.shape, dtype=np.uint16)
+    labels_map[np.where(binary_mask)] = labels
-    labels = np.mod(labels, nb_pts)
 
+    # Correct the labels to have a more uniform size
-    sum_dist = np.expand_dims(np.sum(dists, axis=1), axis=1)
-    weights = np.exp(-dists / sum_dist)
+    timer = time.time()
+    tmp_sft = resample_streamlines_step_size(sft, 1.0)
+    labels_map = correct_labels_jump(labels_map, tmp_sft.streamlines,
+                                     nb_pts - 2)
 
-    votes = []
-    for i in range(len(bundle_pts)):
-        vote = np.bincount(labels[i], weights=weights[i])
-        total = np.arange(np.amax(labels[i])+1)
-        winner = total[np.argmax(vote)]
-        votes.append(winner)
+    if method == 'hyperplane' and endpoints_extended:
+        labels_map[labels_map == nb_pts] = nb_pts - 1
+        labels_map[labels_map == 1] = 2
+        labels_map[labels_map > 0] -= 1
+        nb_pts -= 2
+    logging.info('Corrected labels jump in '
+                 f'{round(time.time() - timer, 3)} seconds')
 
-    return np.array(votes, dtype=np.uint16), np.average(dists, axis=1)
+    return labels_map
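+
+
+# Illustrative follow-up (hypothetical variable names): the labels map
+# produced by subdivide_bundles() is typically passed to
+# compute_distance_map(), e.g.
+#     distance_map = compute_distance_map(labels_map, binary_mask,
+#                                         is_euclidian=True, nb_pts=nb_pts)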
diff --git a/scripts/scil_bundle_label_map.py b/scripts/scil_bundle_label_map.py
index 9281a5c92..6db361611 100755
--- a/scripts/scil_bundle_label_map.py
+++ b/scripts/scil_bundle_label_map.py
@@ -2,32 +2,62 @@
 # -*- coding: utf-8 -*-
 
 """
-Compute the label image (Nifti) from a centroid and tractograms (all
-representing the same bundle). The label image represents the coverage of
-the bundle, segmented into regions labelled from 0 to --nb_pts, starting from
-the head, ending in the tail.
-
-Each voxel will have the label of its nearest centroid point.
-
-The number of labels will be the same as the centroid's number of points.
-
-Formerly: scil_compute_bundle_voxel_label_map.py
+Compute a label image (Nifti) from bundle(s) and centroid(s).
+Each voxel will have a label that represents its position along the bundle.
+
+The number of labels will be the same as the centroid's number of points,
+unless specified otherwise.
+
+# Single bundle case
+  This script takes as input a bundle file and a centroid streamline
+  corresponding to the bundle. It computes label images, where each voxel is
+  assigned the label of its nearest centroid point. The resulting images
+  represent the labels and the distances between the bundle and the centroid.
+
+# Multiple bundle case
+  When providing multiple (co-registered) bundles, the script will compute a
+  correlation map, which shows the spatial correlation between their density
+  maps. It will also compute the label maps for all bundles at once, ensuring
+  that the labels are spatially consistent between bundles.
+
+# Hyperplane method
+  The default is the euclidean/centerline method, which is fast and works
+  well for most cases.
+  The hyperplane method allows for more complex shapes and splits the bundles
+  into subsections that follow the geometry of each kind of bundle. However,
+  this method is slower and requires extra quality control to ensure that the
+  labels are correct. It also requires a centroid file that contains multiple
+  streamlines.
+  This method is based on the following paper [1], but was heavily modified
+  and adapted to work more robustly across datasets.
+
+# Manhattan distance
+  The default distance (to the barycenter of each label) is the euclidean
+  distance. The Manhattan distance can be used instead to compute the
+  distance to the barycenter without stepping out of the mask.
+
+The --colormap option only affects the colouring of the output tractograms,
+for visualization purposes.
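+
+Example usage (data from the 'tractometry' test dataset; the output
+directory is created by the script):
+$ scil_bundle_label_map.py IFGWM.trk IFGWM_uni_c_10.trk out_dir/ \
+    --hyperplane --use_manhattan --colormap viridis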
"Radiomic tractometry: a + rich and tract-specific class of imaging biomarkers for neuroscience and + medical applications." Research Square (2023). +""" + + def _build_arg_parser(): p = argparse.ArgumentParser( description=__doc__, + epilog=EPILOG, formatter_class=argparse.RawTextHelpFormatter) p.add_argument('in_bundles', nargs='+', @@ -72,8 +109,21 @@ def _build_arg_parser(): p.add_argument('--colormap', default='jet', help='Select the colormap for colored trk (data_per_point) ' '[%(default)s].') - p.add_argument('--new_labelling', action='store_true', - help='Use the new labelling method (multi-centroids).') + p.add_argument('--hyperplane', action='store_true', + help='Use the hyperplane method (multi-centroids) instead ' + 'of the euclidian method (single-centroid).') + p.add_argument('--use_manhattan', action='store_true', + help='Use the manhattan distance instead of the euclidian ' + 'distance.') + p.add_argument('--skip_uniformize', action='store_true', + help='Skip uniformization of the bundles orientation.') + p.add_argument('--correlation_thr', type=float, const=0.1, nargs='?', + default=0, + help='Threshold for the correlation map. Only for multi ' + 'bundle case. [%(default)s]') + p.add_argument('--streamlines_thr', type=int, const=2, nargs='?', + help='Threshold for the minimum number of streamlines in a ' + 'voxel to be included [%(default)s].') add_reference_arg(p) add_verbose_arg(p) @@ -94,215 +144,158 @@ def main(): sft_centroid = load_tractogram_with_reference(parser, args, args.in_centroid) - sft_centroid.to_vox() - sft_centroid.to_corner() + if args.verbose: + logging.getLogger().setLevel(logging.INFO) + + # TODO check if correlation thr is positive + # TODO check if streamlines thr is positive (above 1) + # When doing longitudinal data, the split in subsection can be done + # on all the bundles at once. Needs to be co-registered. + timer = time.time() sft_list = [] for filename in args.in_bundles: sft = load_tractogram_with_reference(parser, args, filename) if not len(sft.streamlines): - raise IOError('Empty bundle file {}. ' - 'Skipping'.format(args.in_bundle)) + raise IOError(f'Empty bundle file {args.in_bundles}. 
 
     add_reference_arg(p)
     add_verbose_arg(p)
@@ -94,215 +144,158 @@ def main():
 
     sft_centroid = load_tractogram_with_reference(parser, args,
                                                   args.in_centroid)
-    sft_centroid.to_vox()
-    sft_centroid.to_corner()
+    if args.verbose:
+        logging.getLogger().setLevel(logging.INFO)
+
+    # TODO check if correlation thr is positive
+    # TODO check if streamlines thr is positive (above 1)
 
+    # With longitudinal data, the split into subsections can be done on all
+    # the bundles at once; they must be co-registered.
+    timer = time.time()
     sft_list = []
     for filename in args.in_bundles:
         sft = load_tractogram_with_reference(parser, args, filename)
 
         if not len(sft.streamlines):
-            raise IOError('Empty bundle file {}. '
-                          'Skipping'.format(args.in_bundle))
+            raise IOError(f'Empty bundle file {filename}.')
+        if not args.skip_uniformize:
+            uniformize_bundle_sft(sft, ref_bundle=sft_centroid)
         sft.to_vox()
         sft.to_corner()
         sft_list.append(sft)
 
         if len(sft_list):
             if not is_header_compatible(sft_list[0], sft_list[-1]):
-                parser.error('Header of {} and {} are not compatible'.format(
-                    args.in_bundles[0], filename))
+                parser.error(f'Headers of {args.in_bundles[0]} and '
+                             f'{filename} are not compatible.')
+
+    sft_centroid.to_vox()
+    sft_centroid.to_corner()
+
+    logging.info(f'Loaded {len(args.in_bundles)} bundle(s) in '
+                 f'{round(time.time() - timer, 3)} seconds.')
 
     density_list = []
     binary_list = []
+    timer = time.time()
     for sft in sft_list:
         density = compute_tract_counts_map(sft.streamlines,
                                            sft.dimensions).astype(float)
-        binary = np.zeros(sft.dimensions)
-        binary[density > 0] = 1
+        binary = np.zeros(sft.dimensions, dtype=np.uint8)
+        if args.streamlines_thr is not None:
+            binary[density >= args.streamlines_thr] = 1
+        else:
+            binary[density > 0] = 1
         binary_list.append(binary)
         density_list.append(density)
 
     if not is_header_compatible(sft_centroid, sft_list[0]):
-        raise IOError('{} and {}do not have a compatible header'.format(
-            args.in_centroid, args.in_bundle))
+        raise IOError(f'{args.in_centroid} and {args.in_bundles[0]} do not '
+                      'have a compatible header.')
+
+    logging.info('Computed density and binary map(s) in '
+                 f'{round(time.time() - timer, 3)} seconds.')
 
     if len(density_list) > 1:
+        timer = time.time()
         corr_map = neighborhood_correlation_(density_list)
+        logging.info('Computed correlation map in '
+                     f'{round(time.time() - timer, 3)} seconds.')
     else:
         corr_map = density_list[0].astype(float)
         corr_map[corr_map > 0] = 1
 
     # Slightly cut the bundle at the edge to clean up single streamline voxels
     # with no neighbor. Remove isolated voxels to keep a single 'blob'
-    binary_bundle = np.zeros(corr_map.shape, dtype=bool)
-    binary_bundle[corr_map > 0.5] = 1
+    binary_mask = np.max(binary_list, axis=0)
 
-    bundle_disjoint, _ = ndi.label(binary_bundle)
-    unique, count = np.unique(bundle_disjoint, return_counts=True)
-    val = unique[np.argmax(count[1:])+1]
-    binary_bundle[bundle_disjoint != val] = 0
+    if args.correlation_thr > 1e-3:
+        binary_mask[corr_map < args.correlation_thr] = 0
 
-    corr_map = corr_map*binary_bundle
-    nib.save(nib.Nifti1Image(corr_map, sft_list[0].affine),
+    if args.streamlines_thr is not None:
+        bundle_disjoint, _ = ndi.label(binary_mask)
+        unique, count = np.unique(bundle_disjoint, return_counts=True)
+        val = unique[np.argmax(count[1:])+1]
+        binary_mask[bundle_disjoint != val] = 0
+
+    nib.save(nib.Nifti1Image(corr_map * binary_mask, sft_list[0].affine),
              os.path.join(args.out_dir, 'correlation_map.nii.gz'))
 
-    # Chop off some streamlines
+    # A bundle must be contiguous; we can't have isolated voxels.
+    timer = time.time()
     concat_sft = StatefulTractogram.from_sft([], sft_list[0])
+    concat_sft.to_vox()
+    concat_sft.to_corner()
     for i in range(len(sft_list)):
-        sft_list[i] = cut_streamlines_with_mask(sft_list[i],
-                                                binary_bundle)
+        if args.streamlines_thr is not None:
+            sft_list[i] = cut_streamlines_with_mask(sft_list[i],
+                                                    binary_mask)
+        else:
+            sft_list[i].data_per_streamline = {}
+            sft_list[i].data_per_point = {}
+
         if len(sft_list[i]):
             concat_sft += sft_list[i]
 
+    logging.info(f'Chopped bundle(s) in {round(time.time() - timer, 3)} '
+                 'seconds.')
+
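+    # 'centerline' assigns each voxel the label of the nearest centroid
+    # point (euclidean distance, fast); 'hyperplane' trains an RBF-kernel
+    # SVC on labelled streamline points and predicts a label per voxel
+    # (slower, but follows the bundle geometry more closely).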
+    method = 'hyperplane' if args.hyperplane else 'centerline'
     args.nb_pts = len(sft_centroid.streamlines[0]) if args.nb_pts is None \
         else args.nb_pts
-
-    sft_centroid = resample_streamlines_num_points(sft_centroid, args.nb_pts)
-    tmp_sft = resample_streamlines_num_points(concat_sft, args.nb_pts)
-
-    if not args.new_labelling:
-        new_streamlines = sft_centroid.streamlines.copy()
-        sft_centroid = StatefulTractogram.from_sft([new_streamlines[0]],
-                                                   sft_centroid)
-    else:
-        srr = StreamlineLinearRegistration()
-        srm = srr.optimize(static=tmp_sft.streamlines,
-                           moving=sft_centroid.streamlines)
-        sft_centroid.streamlines = srm.transform(sft_centroid.streamlines)
-
-    uniformize_bundle_sft(concat_sft, ref_bundle=sft_centroid[0])
-    labels, dists = min_dist_to_centroid(concat_sft.streamlines._data,
-                                         sft_centroid.streamlines._data,
-                                         args.nb_pts)
-    labels += 1  # 0 means no labels
-
-    # It is not allowed that labels jumps labels for consistency
-    # Streamlines should have continous labels
-    final_streamlines = []
-    final_label = []
-    final_dists = []
-    curr_ind = 0
-    for i, streamline in enumerate(concat_sft.streamlines):
-        next_ind = curr_ind + len(streamline)
-        curr_labels = labels[curr_ind:next_ind]
-        curr_dists = dists[curr_ind:next_ind]
-        curr_ind = next_ind
-
-        # Flip streamlines so the labels increase (facilitate if/else)
-        # Should always be ordered in nextflow pipeline
-        gradient = np.gradient(curr_labels)
-        if len(np.argwhere(gradient < 0)) > len(np.argwhere(gradient > 0)):
-            streamline = streamline[::-1]
-            curr_labels = curr_labels[::-1]
-            curr_dists = curr_dists[::-1]
-
-        # # Find jumps, cut them and find the longest
-        gradient = np.ediff1d(curr_labels)
-        max_jump = max(args.nb_pts // 5, 1)
-        if len(np.argwhere(np.abs(gradient) > max_jump)) > 0:
-            pos_jump = np.where(np.abs(gradient) > max_jump)[0] + 1
-            split_chunk = np.split(curr_labels,
-                                   pos_jump)
-
-            max_len = 0
-            max_pos = 0
-            for j, chunk in enumerate(split_chunk):
-                if len(chunk) > max_len:
-                    max_len = len(chunk)
-                    max_pos = j
-
-            curr_labels = split_chunk[max_pos]
-            gradient_chunk = np.ediff1d(chunk)
-            if len(np.unique(np.sign(gradient_chunk))) > 1:
-                continue
-            streamline = np.split(streamline,
-                                  pos_jump)[max_pos]
-            curr_dists = np.split(curr_dists,
-                                  pos_jump)[max_pos]
-
-        final_streamlines.append(streamline)
-        final_label.append(curr_labels)
-        final_dists.append(curr_dists)
-
-    final_streamlines = ArraySequence(final_streamlines)
-    final_labels = ArraySequence(final_label)
-    final_dists = ArraySequence(final_dists)
-
-    kd_tree = cKDTree(final_streamlines._data)
-    labels_map = np.zeros(binary_bundle.shape, dtype=np.int16)
-    distance_map = np.zeros(binary_bundle.shape, dtype=float)
-    indices = np.array(np.nonzero(binary_bundle), dtype=int).T
-
-    for ind in indices:
-        _, neighbor_ids = kd_tree.query(ind, k=5)
-
-        if not len(neighbor_ids):
-            continue
-
-        labels_val = final_labels._data[neighbor_ids]
-        dists_val = final_dists._data[neighbor_ids]
-        sum_dists_vox = np.sum(dists_val)
-        weights_vox = np.exp(-dists_val / sum_dists_vox)
-
-        vote = np.bincount(labels_val, weights=weights_vox)
-        total = np.arange(np.amax(labels_val+1))
-        winner = total[np.argmax(vote)]
-        labels_map[ind[0], ind[1], ind[2]] = winner
-        distance_map[ind[0], ind[1], ind[2]] = np.average(dists_val)
-
-    cmap = get_lookup_table(args.colormap)
-
+    labels_map = subdivide_bundles(concat_sft, sft_centroid, binary_mask,
+                                   args.nb_pts, method=method)
+
+    timer = time.time()
+    distance_map = compute_distance_map(labels_map, binary_mask,
+                                        not args.use_manhattan, args.nb_pts)
+    logging.info('Computed distance map in '
+                 f'{round(time.time() - timer, 3)} seconds')
+
+    timer = time.time()
+    cmap = get_lookup_table(args.colormap)
     for i, sft in enumerate(sft_list):
         if len(sft_list) > 1:
-            sub_out_dir = os.path.join(args.out_dir, 'session_{}'.format(i+1))
+            sub_out_dir = os.path.join(args.out_dir, f'session_{i+1}')
         else:
             sub_out_dir = args.out_dir
+
+        new_sft = StatefulTractogram.from_sft(sft.streamlines, sft_list[0])
+        new_sft.data_per_point['color'] = ArraySequence(new_sft.streamlines)
         if not os.path.isdir(sub_out_dir):
             os.mkdir(sub_out_dir)
 
-        # Save each session map if multiple inputs
-        nib.save(nib.Nifti1Image((binary_list[i]*labels_map).astype(np.uint16),
-                                 sft_list[0].affine),
-                 os.path.join(sub_out_dir, 'labels_map.nii.gz'))
-        nib.save(nib.Nifti1Image(binary_list[i]*distance_map,
-                                 sft_list[0].affine),
-                 os.path.join(sub_out_dir, 'distance_map.nii.gz'))
-        nib.save(nib.Nifti1Image(binary_list[i]*corr_map,
-                                 sft_list[0].affine),
-                 os.path.join(sub_out_dir, 'correlation_map.nii.gz'))
-
-        if len(sft):
-            tmp_labels = ndi.map_coordinates(labels_map,
-                                             sft.streamlines._data.T-0.5,
-                                             order=0)
-            tmp_dists = ndi.map_coordinates(distance_map,
-                                            sft.streamlines._data.T-0.5,
-                                            order=0)
-            tmp_corr = ndi.map_coordinates(corr_map,
-                                           sft.streamlines._data.T-0.5,
-                                           order=0)
-            cmap = plt.colormaps[args.colormap]
-            new_sft.data_per_point['color'] = ArraySequence(
-                new_sft.streamlines)
-
-            # Nicer visualisation for MI-Brain
-            new_sft.data_per_point['color']._data = cmap(
-                tmp_labels / np.max(tmp_labels))[:, 0:3] * 255
-            save_tractogram(new_sft,
-                            os.path.join(sub_out_dir, 'labels.trk'))
-
-        if len(sft):
-            new_sft.data_per_point['color']._data = cmap(
-                tmp_dists / np.max(tmp_dists))[:, 0:3] * 255
-            save_tractogram(new_sft,
-                            os.path.join(sub_out_dir, 'distance.trk'))
-
-        if len(sft):
-            new_sft.data_per_point['color']._data = cmap(tmp_corr)[
-                :, 0:3] * 255
-            save_tractogram(new_sft,
-                            os.path.join(sub_out_dir, 'correlation.trk'))
+        # Dictionary holding the data for each output type
+        data_dict = {'labels': labels_map.astype(np.uint16),
+                     'distance': distance_map.astype(np.float32),
+                     'correlation': corr_map.astype(np.float32)}
+
+        # Save the Nifti and colored trk files for each output type
+        for basename, map_data in data_dict.items():
+            nib.save(nib.Nifti1Image(binary_list[i] * map_data,
+                                     sft_list[0].affine),
+                     os.path.join(sub_out_dir, f'{basename}_map.nii.gz'))
+
+            if basename == 'correlation' and len(args.in_bundles) == 1:
+                continue
+
+            if len(sft):
+                tmp_data = ndi.map_coordinates(
+                    map_data, sft.streamlines._data.T - 0.5, order=0)
+
+                if basename == 'labels':
+                    max_val = args.nb_pts
+                elif basename == 'correlation':
+                    max_val = 1
+                else:
+                    max_val = np.max(tmp_data)
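+                # cmap() expects values in [0, 1]; scale by the maximum
+                # expected value, then rescale the RGB triplets to 0-255
+                # for data_per_point coloring.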
+                new_sft.data_per_point['color']._data = cmap(
+                    tmp_data / max_val)[:, 0:3] * 255
+
+                # Save the tractogram
+                save_tractogram(new_sft,
+                                os.path.join(sub_out_dir,
+                                             f'{basename}.trk'))
+
+    logging.info(f'Saved all data to {args.out_dir} in '
+                 f'{round(time.time() - timer, 3)} seconds')
 
 
 if __name__ == '__main__':
diff --git a/scripts/tests/test_bundle_label_map.py b/scripts/tests/test_bundle_label_map.py
index a2e228a35..401760f33 100644
--- a/scripts/tests/test_bundle_label_map.py
+++ b/scripts/tests/test_bundle_label_map.py
@@ -13,15 +13,29 @@
 
 
 def test_help_option(script_runner):
     ret = script_runner.run('scil_bundle_label_map.py', '--help')
     assert ret.success
 
 
-def test_execution_tractometry(script_runner, monkeypatch):
+def test_execution_tractometry_euclidean(script_runner, monkeypatch):
     monkeypatch.chdir(os.path.expanduser(tmp_dir.name))
     in_bundle = os.path.join(SCILPY_HOME, 'tractometry', 'IFGWM.trk')
     in_centroid = os.path.join(SCILPY_HOME, 'tractometry',
                                'IFGWM_uni_c_10.trk')
-    ret = script_runner.run('scil_bundle_label_map.py', in_bundle, in_centroid,
-                            'results_dir/', '--colormap', 'viridis')
+    ret = script_runner.run('scil_bundle_label_map.py',
+                            in_bundle, in_centroid,
+                            'results_euc/',
+                            '--colormap', 'viridis')
+    assert ret.success
+
+
+def test_execution_tractometry_hyperplane(script_runner, monkeypatch):
+    monkeypatch.chdir(os.path.expanduser(tmp_dir.name))
+    in_bundle = os.path.join(SCILPY_HOME, 'tractometry', 'IFGWM.trk')
+    in_centroid = os.path.join(SCILPY_HOME, 'tractometry',
+                               'IFGWM_uni_c_10.trk')
+    ret = script_runner.run('scil_bundle_label_map.py',
+                            in_bundle, in_centroid,
+                            'results_man/',
+                            '--colormap', 'viridis',
+                            '--hyperplane', '--use_manhattan')
     assert ret.success
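+
+
+def test_execution_tractometry_multi_bundle(script_runner, monkeypatch):
+    # Sketch of the multi-bundle path: the same bundle is passed twice (as
+    # two co-registered 'sessions') to exercise the correlation map and its
+    # threshold. The output directory name is hypothetical.
+    monkeypatch.chdir(os.path.expanduser(tmp_dir.name))
+    in_bundle = os.path.join(SCILPY_HOME, 'tractometry', 'IFGWM.trk')
+    in_centroid = os.path.join(SCILPY_HOME, 'tractometry',
+                               'IFGWM_uni_c_10.trk')
+    ret = script_runner.run('scil_bundle_label_map.py',
+                            in_bundle, in_bundle, in_centroid,
+                            'results_multi/',
+                            '--correlation_thr', '0.1')
+    assert ret.success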