Skip to content

Commit

Permalink
try this
Browse files Browse the repository at this point in the history
  • Loading branch information
36000 committed Jan 19, 2024
1 parent 9190ca7 commit d2f05b2
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 12 deletions.
1 change: 1 addition & 0 deletions AFQ/tasks/tractography.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,7 @@ def gpu_tractography(data_imap, tracking_params, seed, stop,
sft = gpu_track(
data_imap["data"], data_imap["gtab"],
nib.load(seed), nib.load(stop),
tracking_params["odf_model"],
tracking_params["seed_threshold"],
tracking_params["stop_threshold"],
tracking_params["thresholds_as_percentages"],
Expand Down
51 changes: 39 additions & 12 deletions AFQ/tractography/gputractography.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from tqdm import tqdm

from dipy.data import small_sphere
from dipy.reconst.shm import CsaOdfModel
from dipy.reconst.shm import OpdtModel, CsaOdfModel
from dipy.reconst import shm
from dipy.tracking import utils
from dipy.io.stateful_tractogram import StatefulTractogram, Space
Expand All @@ -16,7 +16,7 @@


# Modified from https://github.com/dipy/GPUStreamlines/blob/master/run_dipy_gpu.py
def gpu_track(data, gtab, seed_img, stop_img,
def gpu_track(data, gtab, seed_img, stop_img, odf_model,
seed_threshold, stop_threshold, thresholds_as_percentages,
max_angle, step_size, sampling_density, ngpus):
"""
Expand All @@ -34,6 +34,8 @@ def gpu_track(data, gtab, seed_img, stop_img,
stop_img : Nifti1Image
A float or binary mask that determines a stopping criterion
(e.g. FA).
odf_model : str, optional
One of {"OPDT", "CSA"}
seed_threshold : float
The value of the seed_img above which tracking is seeded.
stop_threshold : float
Expand Down Expand Up @@ -62,8 +64,6 @@ def gpu_track(data, gtab, seed_img, stop_img,
seed_data = seed_img.get_fdata()
stop_data = stop_img.get_fdata()

data = np.ascontiguousarray(data, dtype=np.float64)

if len(np.unique(seed_data)) > 2:
if thresholds_as_percentages:
seed_threshold = get_percentile_threshold(
Expand All @@ -74,12 +74,25 @@ def gpu_track(data, gtab, seed_img, stop_img,
stop_threshold = get_percentile_threshold(
stop_data, stop_threshold)

model = CsaOdfModel(
gtab, sh_order=sh_order,
smooth=0.006, min_signal=1)
fit_matrix = model._fit_matrix
delta_b = fit_matrix
delta_q = fit_matrix
if odf_model.lower() == "opdt":
model = OpdtModel(
gtab,
sh_order=sh_order,
smooth=0.006,
min_signal=1)
fit_matrix = model._fit_matrix
delta_b, delta_q = fit_matrix
elif odf_model.lower() == "csa":
model = CsaOdfModel(
gtab, sh_order=sh_order,
smooth=0.006, min_signal=1)
fit_matrix = model._fit_matrix
delta_b = fit_matrix
delta_q = fit_matrix
else:
raise ValueError((
f"odf_model must be 'opdt' or "
"'csa', not {odf_model}"))

sphere = small_sphere
theta = sphere.theta
Expand All @@ -94,6 +107,20 @@ def gpu_track(data, gtab, seed_img, stop_img,
H = shm.hat(B)
R = shm.lcr_matrix(H)

data = np.ascontiguousarray(data, dtype=np.float64)
H = np.ascontiguousarray(H, dtype=np.float64)
R = np.ascontiguousarray(R, dtype=np.float64)
delta_b = np.ascontiguousarray(delta_b, dtype=np.float64)
delta_q = np.ascontiguousarray(delta_q, dtype=np.float64)
b0s_mask = np.ascontiguousarray(b0s_mask, dtype=np.int32)
stop_data = np.ascontiguousarray(stop_data, dtype=np.float64)
sampling_matrix = np.ascontiguousarray(
sampling_matrix, dtype=np.float64)
sph_verticies = np.ascontiguousarray(
sphere.vertices, dtype=np.float64)
sph_edges = np.ascontiguousarray(
sphere.edges, dtype=np.int32)

gpu_tracker = cuslines.GPUTracker(
cuslines.ModelType.CSAODF,
radians(max_angle),
Expand All @@ -103,9 +130,9 @@ def gpu_track(data, gtab, seed_img, stop_img,
0.25, # relative peak threshold
radians(45), # min separation angle
data, H, R, delta_b, delta_q,
b0s_mask.astype(np.int32), stop_data.astype(np.float64),
b0s_mask, stop_data,
sampling_matrix,
sphere.vertices, sphere.edges.astype(np.int32),
sph_verticies, sph_edges,
ngpus=ngpus, rng_seed=0)

seed_mask = utils.seeds_from_mask(
Expand Down

0 comments on commit d2f05b2

Please sign in to comment.