[ENH] simplify verbose tracking #1087

Merged: 6 commits, Jan 11, 2024
AFQ/_fixes.py: 79 changes (0 additions, 79 deletions)
@@ -3,11 +3,6 @@
from scipy.special import lpmv, gammaln

from tqdm import tqdm
from dipy.align import Bunch
from dipy.tracking.local_tracking import (LocalTracking,
ParticleFilteringTracking)
from dipy.tracking.stopping_criterion import StreamlineStatus
import random

import math

@@ -43,80 +38,6 @@ def spherical_harmonics(m, n, theta, phi):
return val


TissueTypes = Bunch(OUTSIDEIMAGE=-1, INVALIDPOINT=0, TRACKPOINT=1, ENDPOINT=2)


def _verbose_generate_tractogram(self):
"""A streamline generator"""

# Get inverse transform (lin/offset) for seeds
inv_A = np.linalg.inv(self.affine)
lin = inv_A[:3, :3]
offset = inv_A[:3, 3]

F = np.empty((self.max_length + 1, 3), dtype=float)
B = F.copy()
for s in tqdm(self.seeds):
s = np.dot(lin, s) + offset
# Set the random seed in numpy and random
if self.random_seed is not None:
s_random_seed = hash(np.abs((np.sum(s)) + self.random_seed)) \
% (2**32 - 1)
random.seed(s_random_seed)
np.random.seed(s_random_seed)
directions = self.direction_getter.initial_direction(s)
if directions.size == 0 and self.return_all:
# only the seed position
if self.save_seeds:
yield [s], s
else:
yield [s]
directions = directions[:self.max_cross]
for first_step in directions:
stepsF, stream_status = self._tracker(s, first_step, F)
if not (self.return_all
or stream_status == StreamlineStatus.ENDPOINT
or stream_status == StreamlineStatus.OUTSIDEIMAGE):
continue
first_step = -first_step
stepsB, stream_status = self._tracker(s, first_step, B)
if not (self.return_all
or stream_status == StreamlineStatus.ENDPOINT
or stream_status == StreamlineStatus.OUTSIDEIMAGE):
continue
if stepsB == 1:
streamline = F[:stepsF].copy()
else:
parts = (B[stepsB - 1:0:-1], F[:stepsF])
streamline = np.concatenate(parts, axis=0)

# move to the next streamline if only the seed position
# and not return all
len_sl = len(streamline)
if len_sl >= self.min_length:
if len_sl <= self.max_length:
if self.save_seeds:
yield streamline, s
else:
yield streamline


class VerboseLocalTracking(LocalTracking):
def __init__(self, *args, min_length=10, max_length=1000, **kwargs):
super().__init__(*args, **kwargs)
self.min_length = min_length
self.max_length = max_length
_generate_tractogram = _verbose_generate_tractogram


class VerboseParticleFilteringTracking(ParticleFilteringTracking):
def __init__(self, *args, min_length=10, max_length=1000, **kwargs):
super().__init__(*args, **kwargs)
self.min_length = min_length
self.max_length = max_length
_generate_tractogram = _verbose_generate_tractogram


def in_place_norm(vec, axis=-1, keepdims=False, delvec=True):
""" Return Vectors with Euclidean (L2) norm

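The block removed above re-implemented DIPY's streamline generator just to add a tqdm progress bar and length filtering. With the DIPY 1.8 pin introduced in setup.cfg at the end of this diff, LocalTracking and ParticleFilteringTracking accept minlen, maxlen, return_all and random_seed directly, so the same behavior is recovered by wrapping the stock tracker in tqdm, as AFQ/tractography/tractography.py now does. A minimal sketch under that assumption; the direction getter, stopping criterion, seeds and affine are placeholders, not values from this PR:

from tqdm import tqdm
from dipy.tracking.local_tracking import LocalTracking


def verbose_local_tracking(dg, stopping_criterion, seeds, affine,
                           step_size=0.5, minlen_mm=50, maxlen_mm=250,
                           rng_seed=None):
    # The length limits are passed to DIPY as a number of steps,
    # so convert them from mm by dividing by the step size.
    minlen = int(minlen_mm / step_size)
    maxlen = int(maxlen_mm / step_size)
    tracker = LocalTracking(
        dg, stopping_criterion, seeds, affine,
        step_size=step_size,
        minlen=minlen,
        maxlen=maxlen,
        # As in the updated _tracking: drop streamlines that fail the
        # length/termination criteria instead of returning everything.
        return_all=False,
        random_seed=rng_seed)
    # tqdm wraps the streamline generator, so progress is reported as
    # streamlines are produced, replacing the removed verbose loop.
    return list(tqdm(tracker))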
AFQ/tasks/tractography.py: 4 changes (2 additions, 2 deletions)
@@ -47,8 +47,8 @@ def _meta_from_tracking_params(
Parameters=dict(
Units="mm",
StepSize=tracking_params["step_size"],
MinimumLength=tracking_params["min_length"],
MaximumLength=tracking_params["max_length"],
MinimumLength=tracking_params["minlen"],
MaximumLength=tracking_params["maxlen"],
Unidirectional=False),
Timing=time() - start_time)
return meta
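This keeps the exported metadata in step with the renamed dictionary keys: the values are still written under the MinimumLength / MaximumLength fields, they are just looked up under "minlen" / "maxlen". A sketch of the resulting Parameters block, using hypothetical tracking_params values:

# Hypothetical tracking parameters, using the renamed keys.
tracking_params = {"step_size": 0.5, "minlen": 50, "maxlen": 250}

parameters_meta = dict(
    Units="mm",
    StepSize=tracking_params["step_size"],
    MinimumLength=tracking_params["minlen"],
    MaximumLength=tracking_params["maxlen"],
    Unidirectional=False)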
AFQ/tests/test_api.py: 20 changes (9 additions, 11 deletions)
@@ -497,12 +497,10 @@ def test_API_type_checking():
preproc_pipeline='vistasoft',
mapping_definition=IdentityMap(),
tracking_params={
"n_seeds": 1,
"n_seeds": 10,
"random_seeds": True,
"directions": "prob",
"odf_model": "CSD"},
csd_sh_order=2, # reduce CSD fit time
csd_response=((1, 1, 1), 1),
"directions": "det",
"odf_model": "DTI"},
bundle_info=abd.default18_bd()["ARC_L", "ARC_R"])
myafq.export("bundles")
del myafq
@@ -630,7 +628,7 @@ def test_AFQ_pft():
"stop_mask": stop_mask,
"stop_threshold": "CMC",
"tracker": "pft",
"max_length": 150,
"maxlen": 150,
})
sl_file = myafq.export("streamlines")["01"]
dwi_file = myafq.export("dwi")["01"]
@@ -640,7 +638,7 @@
bbox_valid_check=False,
trk_header_check=False).streamlines
for sl in sls:
# double the max_length, due to step size of 0.5
# double the maxlen, due to step size of 0.5
assert len(sl) <= 300


@@ -782,7 +780,7 @@ def test_AFQ_data_waypoint():

tracking_params = dict(odf_model="csd",
seed_mask=RoiImage(),
n_seeds=100,
n_seeds=200,
random_seeds=True,
rng_seed=42)
segmentation_params = dict(filter_by_endpoints=False,
@@ -836,14 +834,14 @@ def test_AFQ_data_waypoint():

seg_sft = aus.SegmentedSFT.fromfile(
myafq.export("bundles"))
npt.assert_(len(seg_sft.get_bundle('SLF_R').streamlines) > 0)
npt.assert_(len(seg_sft.get_bundle('CST_L').streamlines) > 0)

# Test bundles exporting:
myafq.export("indiv_bundles")
assert op.exists(op.join(
myafq.export("results_dir"),
'bundles',
'sub-01_ses-01_coordsys-RASMM_trkmethod-probCSD_recogmethod-AFQ_desc-SLFR_tractography.trk')) # noqa
'sub-01_ses-01_coordsys-RASMM_trkmethod-probCSD_recogmethod-AFQ_desc-CSTL_tractography.trk')) # noqa

tract_profile_fname = myafq.export("profiles")
tract_profiles = pd.read_csv(tract_profile_fname)
@@ -884,7 +882,7 @@ def test_AFQ_data_waypoint():
# ROI mask needs to be put in quotes in config
tracking_params = dict(odf_model="CSD",
seed_mask="RoiImage()",
n_seeds=100,
n_seeds=200,
random_seeds=True,
rng_seed=42)
bundle_dict_as_str = (
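For API users, the rename flows through the tracking_params dictionary, as in the tests above where "max_length": 150 becomes "maxlen": 150. A sketch of the corresponding call, assuming pyAFQ's GroupAFQ entry point; the BIDS path and parameter values are hypothetical:

from AFQ.api.group import GroupAFQ

myafq = GroupAFQ(
    bids_path="/path/to/bids",   # hypothetical study directory
    preproc_pipeline="vistasoft",
    tracking_params={
        "n_seeds": 1,
        "random_seeds": True,
        "odf_model": "DTI",
        "minlen": 50,            # was "min_length" (in mm)
        "maxlen": 150,           # was "max_length" (in mm)
    })
myafq.export("streamlines")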
AFQ/tests/test_tractography.py: 28 changes (16 additions, 12 deletions)
@@ -24,7 +24,7 @@
fdata = op.join(tmpdir.name, 'dti.nii.gz')
make_tracking_data(fbval, fbvec, fdata)

min_length = 20
minlen = 20
step_size = 0.5


@@ -36,7 +36,7 @@ def test_csd_local_tracking():
sh_order=sh_order, lambda_=1, tau=0.1, mask=None,
out_dir=tmpdir.name)
for directions in ["det", "prob"]:
sl = track(
sls = track(
fname,
directions,
odf_model="CSD",
@@ -46,27 +46,29 @@ def test_csd_local_tracking():
n_seeds=seeds,
stop_mask=None,
step_size=step_size,
min_length=min_length,
minlen=minlen,
tracker="local").streamlines

npt.assert_(len(sl[0]) >= step_size * min_length)
for sl in sls:
npt.assert_(len(sl) >= minlen / step_size)


def test_dti_local_tracking():
fdict = fit_dti(fdata, fbval, fbvec)
for directions in ["det", "prob"]:
sl = track(
sls = track(
fdict['params'],
directions,
max_angle=30.,
sphere=None,
seed_mask=None,
n_seeds=1,
step_size=step_size,
min_length=min_length,
minlen=minlen,
odf_model="DTI",
tracker="local").streamlines
npt.assert_(len(sl[0]) >= min_length * step_size)
for sl in sls:
npt.assert_(len(sl) >= minlen / step_size)


def test_pft_tracking():
@@ -89,7 +91,7 @@ def test_pft_tracking():

for directions in ["det", "prob"]:
for stop_threshold in ["ACT", "CMC"]:
sl = track(
sls = track(
fname,
directions,
max_angle=30.,
@@ -99,10 +101,12 @@ def test_pft_tracking():
stop_threshold=stop_threshold,
n_seeds=1,
step_size=step_size,
min_length=min_length,
minlen=minlen,
odf_model=odf,
tracker="pft").streamlines
npt.assert_(len(sl[0]) >= min_length * step_size)

for sl in sls:
npt.assert_(len(sl) >= minlen / step_size)

# Test error handling:
with pytest.raises(RuntimeError):
@@ -116,7 +120,7 @@ def test_pft_tracking():
stop_threshold=stop_threshold,
n_seeds=1,
step_size=step_size,
min_length=min_length,
minlen=minlen,
tracker="pft")

with pytest.raises(RuntimeError):
@@ -130,5 +134,5 @@ def test_pft_tracking():
stop_threshold=None, # Stop threshold needs to be a string!
n_seeds=1,
step_size=step_size,
min_length=min_length,
minlen=minlen,
tracker="pft")
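The rewritten assertions check every streamline and count points rather than millimeters: minlen and step_size are both in mm, so a streamline at least minlen mm long must contain at least minlen / step_size points. A worked check with the module-level test values (minlen = 20, step_size = 0.5):

minlen = 20      # minimum streamline length, in mm
step_size = 0.5  # distance between consecutive points, in mm

# Each step advances step_size mm, so a streamline spanning minlen mm
# has at least minlen / step_size points.
min_points = minlen / step_size
assert min_points == 40

# which is what the updated tests assert per streamline:
# for sl in sls:
#     npt.assert_(len(sl) >= minlen / step_size)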
AFQ/tractography/tractography.py: 38 changes (21 additions, 17 deletions)
@@ -2,6 +2,7 @@
import numpy as np
import nibabel as nib
import logging
from tqdm import tqdm

import dipy.data as dpd
from dipy.align import resample
@@ -15,8 +16,9 @@

from nibabel.streamlines.tractogram import LazyTractogram

from AFQ._fixes import (VerboseLocalTracking, VerboseParticleFilteringTracking,
tensor_odf)
from dipy.tracking.local_tracking import (LocalTracking,
ParticleFilteringTracking)
from AFQ._fixes import tensor_odf


def get_percentile_threshold(mask, threshold):
Expand All @@ -32,7 +34,7 @@ def get_percentile_threshold(mask, threshold):
def track(params_file, directions="prob", max_angle=30., sphere=None,
seed_mask=None, seed_threshold=0, thresholds_as_percentages=False,
n_seeds=1, random_seeds=False, rng_seed=None, stop_mask=None,
stop_threshold=0, step_size=0.5, min_length=50, max_length=250,
stop_threshold=0, step_size=0.5, minlen=50, maxlen=250,
odf_model="CSD", tracker="local", trx=False):
"""
Tractography
@@ -95,9 +97,9 @@ def track(params_file, directions="prob", max_angle=30., sphere=None,
Default: False
step_size : float, optional.
The size of a step (in mm) of tractography. Default: 0.5
min_length: int, optional
minlen: int, optional
The minimal length (mm) of a streamline. Default: 50
max_length: int, optional
maxlen: int, optional
The maximal length (mm) of a streamline. Default: 250
odf_model : str, optional
One of {"DTI", "CSD", "DKI"}. Defaults to use "DTI"
@@ -135,8 +137,8 @@ def track(params_file, directions="prob", max_angle=30., sphere=None,

# The length limits are given in mm; convert them to a number of
# steps by dividing by the step size:
min_length = int(min_length / step_size)
max_length = int(max_length / step_size)
minlen = int(minlen / step_size)
maxlen = int(maxlen / step_size)

logger.info("Generating Seeds...")
if isinstance(n_seeds, int):
@@ -192,7 +194,7 @@ def track(params_file, directions="prob", max_angle=30., sphere=None,
stopping_criterion = ThresholdStoppingCriterion(stop_mask,
stop_threshold)

my_tracker = VerboseLocalTracking
my_tracker = LocalTracking

elif tracker == "pft":
if not isinstance(stop_threshold, str):
@@ -233,7 +235,7 @@ def track(params_file, directions="prob", max_angle=30., sphere=None,

vox_sizes.append(np.mean(params_img.header.get_zooms()[:3]))

my_tracker = VerboseParticleFilteringTracking
my_tracker = ParticleFilteringTracking
if stop_threshold == "CMC":
stopping_criterion = CmcStoppingCriterion.from_pve(
pve_wm_data,
@@ -247,32 +249,34 @@ def track(params_file, directions="prob", max_angle=30., sphere=None,
pve_gm_data,
pve_csf_data)

logger.info("Tracking...")
logger.info(
f"Tracking with {len(seeds)} seeds, 2 directions per seed...")

return _tracking(my_tracker, seeds, dg, stopping_criterion, params_img,
step_size=step_size, min_length=min_length,
max_length=max_length, random_seed=rng_seed,
step_size=step_size, minlen=minlen,
maxlen=maxlen, random_seed=rng_seed,
trx=trx)


def _tracking(tracker, seeds, dg, stopping_criterion, params_img,
step_size=0.5, min_length=40, max_length=200,
step_size=0.5, minlen=40, maxlen=200,
random_seed=None, trx=False):
"""
Helper function
"""
if len(seeds.shape) == 1:
seeds = seeds[None, ...]

tracker = tracker(
tracker = tqdm(tracker(
dg,
stopping_criterion,
seeds,
params_img.affine,
step_size=step_size,
min_length=min_length,
max_length=max_length,
random_seed=random_seed)
minlen=minlen,
maxlen=maxlen,
return_all=False,
random_seed=random_seed))

if trx:
return LazyTractogram(lambda: tracker,
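After this change, track() takes minlen and maxlen in mm, converts them to a number of steps, and hands them straight to DIPY's LocalTracking or ParticleFilteringTracking, with tqdm providing the progress reporting the removed wrappers used to add. A sketch of calling the updated function with the renamed keywords; the module path is inferred from the file location and the parameters file name is hypothetical:

from AFQ.tractography.tractography import track

sft = track(
    "sub-01_model-CSD_diffmodel.nii.gz",  # hypothetical CSD params file
    directions="prob",
    max_angle=30.,
    n_seeds=2,
    random_seeds=True,
    step_size=0.5,
    minlen=50,       # minimum streamline length, in mm (was min_length)
    maxlen=250,      # maximum streamline length, in mm (was max_length)
    odf_model="CSD",
    tracker="local")

Because _tracking now passes return_all=False, the streamlines in the returned tractogram already respect these length bounds.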
AFQ/utils/bin.py: 4 changes (2 additions, 2 deletions)
@@ -389,8 +389,8 @@ def generate_json(json_folder, overwrite=False,
"tckgen":{
"algorithm": "iFOD2",
"select": 1e6,
"max_length": 250,
"min_length": 30,
"maxlen": 250,
"minlen": 30,
"power":0.33
},
"sift2":{}
setup.cfg: 2 changes (1 addition, 1 deletion)
@@ -28,7 +28,7 @@ python_requires = >=3.8
install_requires =
# core packages
scikit_image>=0.14.2
dipy>=1.7.0,<1.8.0
dipy>=1.8.0,<1.9.0
pandas
pybids>=0.16.2
templateflow>=0.8