diff --git a/dwi_ml/data/processing/streamlines/post_processing.py b/dwi_ml/data/processing/streamlines/post_processing.py index eabc079d..aee4852b 100644 --- a/dwi_ml/data/processing/streamlines/post_processing.py +++ b/dwi_ml/data/processing/streamlines/post_processing.py @@ -302,7 +302,6 @@ def _compute_origin_finish_blocs(streamlines, volume_size, nb_blocs): def compute_triu_connectivity_from_labels(streamlines, data_labels, - binary: bool = False, use_scilpy=False): """ Compute a connectivity matrix. @@ -313,8 +312,6 @@ def compute_triu_connectivity_from_labels(streamlines, data_labels, Streamlines, in vox space, corner origin. data_labels: np.ndarray The loaded nifti image. - binary: bool - If True, return a binary matrix. use_scilpy: bool If True, uses scilpy's method: 'Strategy is to keep the longest streamline segment @@ -380,6 +377,7 @@ def compute_triu_connectivity_from_labels(streamlines, data_labels, start_labels.append(start) end_labels.append(end) + matrix[start, end] += 1 if start != end: matrix[end, start] += 1 @@ -387,9 +385,6 @@ def compute_triu_connectivity_from_labels(streamlines, data_labels, matrix = np.triu(matrix) assert matrix.sum() == len(streamlines) - if binary: - matrix = matrix.astype(bool) - return matrix, real_labels, start_labels, end_labels @@ -463,9 +458,11 @@ def prepare_figure_connectivity(matrix): axs[1, 1].imshow(matrix) axs[1, 1].set_title("Binary") + plt.suptitle("All versions of the connectivity matrix.") + def find_streamlines_with_chosen_connectivity( - streamlines, label1, label2, start_labels, end_labels): + streamlines, start_labels, end_labels, label1, label2=None): """ Returns streamlines corresponding to a (label1, label2) or (label2, label1) connection. @@ -474,19 +471,32 @@ def find_streamlines_with_chosen_connectivity( ---------- streamlines: list of np arrays or list of tensors. Streamlines, in vox space, corner origin. - label1: int - The bloc of interest, either as starting or finishing point. - label2: int - The bloc of interest, either as starting or finishing point. start_labels: list[int] The starting bloc for each streamline. end_labels: list[int] The ending bloc for each streamline. + label1: int + The bloc of interest, either as starting or finishing point. + label2: int, optional + The bloc of interest, either as starting or finishing point. + If label2 is None, then all connections (label1, Y) and (X, label1) + are found. """ + start_labels = np.asarray(start_labels) + end_labels = np.asarray(end_labels) - str_ind1 = np.logical_and(start_labels == label1, - end_labels == label2) - str_ind2 = np.logical_and(start_labels == label2, - end_labels == label1) - str_ind = np.logical_or(str_ind1, str_ind2) - return [s for i, s in enumerate(streamlines) if str_ind[i]] + if label2 is None: + labels2 = np.unique(np.concatenate((start_labels[:], end_labels[:]))) + else: + labels2 = [label2] + + found = np.zeros(len(streamlines)) + for label2 in labels2: + str_ind1 = np.logical_and(start_labels == label1, + end_labels == label2) + str_ind2 = np.logical_and(start_labels == label2, + end_labels == label1) + str_ind = np.logical_or(str_ind1, str_ind2) + found = np.logical_or(found, str_ind) + + return [s for i, s in enumerate(streamlines) if found[i]] diff --git a/scripts_python/dwiml_compute_connectivity_matrix_from_blocs.py b/scripts_python/dwiml_compute_connectivity_matrix_from_blocs.py index b7f42f79..2f76885a 100644 --- a/scripts_python/dwiml_compute_connectivity_matrix_from_blocs.py +++ b/scripts_python/dwiml_compute_connectivity_matrix_from_blocs.py @@ -99,7 +99,7 @@ def main(): i, j = np.unravel_index(np.argmax(matrix, axis=None), matrix.shape) print("Saving biggest bundle: {} streamlines.".format(matrix[i, j])) biggest = find_streamlines_with_chosen_connectivity( - in_sft.streamlines, i, j, start_blocs, end_blocs) + in_sft.streamlines, start_blocs, end_blocs, i, j) sft = in_sft.from_sft(biggest, in_sft) save_tractogram(sft, args.save_biggest) @@ -108,7 +108,7 @@ def main(): i, j = np.unravel_index(tmp_matrix.argmin(axis=None), matrix.shape) print("Saving smallest bundle: {} streamlines.".format(matrix[i, j])) biggest = find_streamlines_with_chosen_connectivity( - in_sft.streamlines, i, j, start_blocs, end_blocs) + in_sft.streamlines, start_blocs, end_blocs, i, j) sft = in_sft.from_sft(biggest, in_sft) save_tractogram(sft, args.save_smallest) diff --git a/scripts_python/dwiml_compute_connectivity_matrix_from_labels.py b/scripts_python/dwiml_compute_connectivity_matrix_from_labels.py index 9c682512..877653fd 100644 --- a/scripts_python/dwiml_compute_connectivity_matrix_from_labels.py +++ b/scripts_python/dwiml_compute_connectivity_matrix_from_labels.py @@ -39,7 +39,7 @@ def _build_arg_parser(): "streamline count is saved.") p.add_argument('--show_now', action='store_true', help="If set, shows the matrix with matplotlib.") - p.add_argument('--hide_background', nargs='?', const=0, type=float, + p.add_argument('--hide_background', nargs='?', const=0, type=int, help="If true, set the connectivity matrix for chosen " "label (default: 0), to 0.") p.add_argument( @@ -80,10 +80,11 @@ def main(): p.error("--out_file should have a .npy extension.") out_fig = tmp + '.png' - out_fig_noback = tmp + '_hidden_background.png' out_ordered_labels = tmp + '_labels.txt' + out_rejected_streamlines = tmp + '_rejected_from_background.trk' assert_inputs_exist(p, [args.in_labels, args.streamlines]) - assert_outputs_exist(p, args, [args.out_file, out_fig, out_fig_noback], + assert_outputs_exist(p, args, + [args.out_file, out_fig, out_rejected_streamlines], [args.save_biggest, args.save_smallest]) ext = os.path.splitext(args.streamlines)[1] @@ -101,23 +102,26 @@ def main(): in_sft.to_vox() in_sft.to_corner() - matrix, ordered_labels, start_blocs, end_blocs = \ + matrix, ordered_labels, start_labels, end_labels = \ compute_triu_connectivity_from_labels( in_sft.streamlines, data_labels, use_scilpy=args.use_longest_segment) - prepare_figure_connectivity(matrix) - plt.savefig(out_fig) - if args.hide_background is not None: idx = ordered_labels.index(args.hide_background) nb_hidden = np.sum(matrix[idx, :]) + np.sum(matrix[:, idx]) - \ matrix[idx, idx] if nb_hidden > 0: - logging.info("CAREFUL! Deleting from the matrix {} streamlines " - "with one or both endpoints in a non-labelled area " - "(background = {}; line/column {})" - .format(nb_hidden, args.hide_background, idx)) + logging.warning("CAREFUL! Deleting from the matrix {} streamlines " + "with one or both endpoints in a non-labelled " + "area (background = {}; line/column {})" + .format(nb_hidden, args.hide_background, idx)) + rejected = find_streamlines_with_chosen_connectivity( + in_sft.streamlines, start_labels, end_labels, idx) + logging.info("Saving rejected streamlines in {}" + .format(out_rejected_streamlines)) + sft = in_sft.from_sft(rejected, in_sft) + save_tractogram(sft, out_rejected_streamlines) else: logging.info("No streamlines with endpoints in the background :)") matrix[idx, :] = 0 @@ -125,8 +129,9 @@ def main(): ordered_labels[idx] = ("Hidden background ({})" .format(args.hide_background)) - prepare_figure_connectivity(matrix) - plt.savefig(out_fig_noback) + # Save figure will all versions of the matrix. + prepare_figure_connectivity(matrix) + plt.savefig(out_fig) if args.binary: matrix = matrix > 0 @@ -143,7 +148,7 @@ def main(): .format(matrix[i, j], ordered_labels[i], ordered_labels[j], i, j)) biggest = find_streamlines_with_chosen_connectivity( - in_sft.streamlines, i, j, start_blocs, end_blocs) + in_sft.streamlines, i, j, start_labels, end_labels) sft = in_sft.from_sft(biggest, in_sft) save_tractogram(sft, args.save_biggest) @@ -155,7 +160,7 @@ def main(): .format(matrix[i, j], ordered_labels[i], ordered_labels[j], i, j)) smallest = find_streamlines_with_chosen_connectivity( - in_sft.streamlines, i, j, start_blocs, end_blocs) + in_sft.streamlines, i, j, start_labels, end_labels) sft = in_sft.from_sft(smallest, in_sft) save_tractogram(sft, args.save_smallest)