From d2f05b2a0b8a13a774009d2858d87970089deb9c Mon Sep 17 00:00:00 2001 From: 36000 Date: Thu, 18 Jan 2024 20:55:45 -0800 Subject: [PATCH] try this --- AFQ/tasks/tractography.py | 1 + AFQ/tractography/gputractography.py | 51 ++++++++++++++++++++++------- 2 files changed, 40 insertions(+), 12 deletions(-) diff --git a/AFQ/tasks/tractography.py b/AFQ/tasks/tractography.py index 4be787799..d8868cd0f 100644 --- a/AFQ/tasks/tractography.py +++ b/AFQ/tasks/tractography.py @@ -224,6 +224,7 @@ def gpu_tractography(data_imap, tracking_params, seed, stop, sft = gpu_track( data_imap["data"], data_imap["gtab"], nib.load(seed), nib.load(stop), + tracking_params["odf_model"], tracking_params["seed_threshold"], tracking_params["stop_threshold"], tracking_params["thresholds_as_percentages"], diff --git a/AFQ/tractography/gputractography.py b/AFQ/tractography/gputractography.py index ee2678048..0e38e1b90 100644 --- a/AFQ/tractography/gputractography.py +++ b/AFQ/tractography/gputractography.py @@ -5,7 +5,7 @@ from tqdm import tqdm from dipy.data import small_sphere -from dipy.reconst.shm import CsaOdfModel +from dipy.reconst.shm import OpdtModel, CsaOdfModel from dipy.reconst import shm from dipy.tracking import utils from dipy.io.stateful_tractogram import StatefulTractogram, Space @@ -16,7 +16,7 @@ # Modified from https://github.com/dipy/GPUStreamlines/blob/master/run_dipy_gpu.py -def gpu_track(data, gtab, seed_img, stop_img, +def gpu_track(data, gtab, seed_img, stop_img, odf_model, seed_threshold, stop_threshold, thresholds_as_percentages, max_angle, step_size, sampling_density, ngpus): """ @@ -34,6 +34,8 @@ def gpu_track(data, gtab, seed_img, stop_img, stop_img : Nifti1Image A float or binary mask that determines a stopping criterion (e.g. FA). + odf_model : str, optional + One of {"OPDT", "CSA"} seed_threshold : float The value of the seed_img above which tracking is seeded. stop_threshold : float @@ -62,8 +64,6 @@ def gpu_track(data, gtab, seed_img, stop_img, seed_data = seed_img.get_fdata() stop_data = stop_img.get_fdata() - data = np.ascontiguousarray(data, dtype=np.float64) - if len(np.unique(seed_data)) > 2: if thresholds_as_percentages: seed_threshold = get_percentile_threshold( @@ -74,12 +74,25 @@ def gpu_track(data, gtab, seed_img, stop_img, stop_threshold = get_percentile_threshold( stop_data, stop_threshold) - model = CsaOdfModel( - gtab, sh_order=sh_order, - smooth=0.006, min_signal=1) - fit_matrix = model._fit_matrix - delta_b = fit_matrix - delta_q = fit_matrix + if odf_model.lower() == "opdt": + model = OpdtModel( + gtab, + sh_order=sh_order, + smooth=0.006, + min_signal=1) + fit_matrix = model._fit_matrix + delta_b, delta_q = fit_matrix + elif odf_model.lower() == "csa": + model = CsaOdfModel( + gtab, sh_order=sh_order, + smooth=0.006, min_signal=1) + fit_matrix = model._fit_matrix + delta_b = fit_matrix + delta_q = fit_matrix + else: + raise ValueError(( + f"odf_model must be 'opdt' or " + "'csa', not {odf_model}")) sphere = small_sphere theta = sphere.theta @@ -94,6 +107,20 @@ def gpu_track(data, gtab, seed_img, stop_img, H = shm.hat(B) R = shm.lcr_matrix(H) + data = np.ascontiguousarray(data, dtype=np.float64) + H = np.ascontiguousarray(H, dtype=np.float64) + R = np.ascontiguousarray(R, dtype=np.float64) + delta_b = np.ascontiguousarray(delta_b, dtype=np.float64) + delta_q = np.ascontiguousarray(delta_q, dtype=np.float64) + b0s_mask = np.ascontiguousarray(b0s_mask, dtype=np.int32) + stop_data = np.ascontiguousarray(stop_data, dtype=np.float64) + sampling_matrix = np.ascontiguousarray( + sampling_matrix, dtype=np.float64) + sph_verticies = np.ascontiguousarray( + sphere.vertices, dtype=np.float64) + sph_edges = np.ascontiguousarray( + sphere.edges, dtype=np.int32) + gpu_tracker = cuslines.GPUTracker( cuslines.ModelType.CSAODF, radians(max_angle), @@ -103,9 +130,9 @@ def gpu_track(data, gtab, seed_img, stop_img, 0.25, # relative peak threshold radians(45), # min separation angle data, H, R, delta_b, delta_q, - b0s_mask.astype(np.int32), stop_data.astype(np.float64), + b0s_mask, stop_data, sampling_matrix, - sphere.vertices, sphere.edges.astype(np.int32), + sph_verticies, sph_edges, ngpus=ngpus, rng_seed=0) seed_mask = utils.seeds_from_mask(