diff --git a/AFQ/_fixes.py b/AFQ/_fixes.py index 90c527ce3..038581fad 100644 --- a/AFQ/_fixes.py +++ b/AFQ/_fixes.py @@ -3,11 +3,6 @@ from scipy.special import lpmv, gammaln from tqdm import tqdm -from dipy.align import Bunch -from dipy.tracking.local_tracking import (LocalTracking, - ParticleFilteringTracking) -from dipy.tracking.stopping_criterion import StreamlineStatus -import random import math @@ -43,80 +38,6 @@ def spherical_harmonics(m, n, theta, phi): return val -TissueTypes = Bunch(OUTSIDEIMAGE=-1, INVALIDPOINT=0, TRACKPOINT=1, ENDPOINT=2) - - -def _verbose_generate_tractogram(self): - """A streamline generator""" - - # Get inverse transform (lin/offset) for seeds - inv_A = np.linalg.inv(self.affine) - lin = inv_A[:3, :3] - offset = inv_A[:3, 3] - - F = np.empty((self.max_length + 1, 3), dtype=float) - B = F.copy() - for s in tqdm(self.seeds): - s = np.dot(lin, s) + offset - # Set the random seed in numpy and random - if self.random_seed is not None: - s_random_seed = hash(np.abs((np.sum(s)) + self.random_seed)) \ - % (2**32 - 1) - random.seed(s_random_seed) - np.random.seed(s_random_seed) - directions = self.direction_getter.initial_direction(s) - if directions.size == 0 and self.return_all: - # only the seed position - if self.save_seeds: - yield [s], s - else: - yield [s] - directions = directions[:self.max_cross] - for first_step in directions: - stepsF, stream_status = self._tracker(s, first_step, F) - if not (self.return_all - or stream_status == StreamlineStatus.ENDPOINT - or stream_status == StreamlineStatus.OUTSIDEIMAGE): - continue - first_step = -first_step - stepsB, stream_status = self._tracker(s, first_step, B) - if not (self.return_all - or stream_status == StreamlineStatus.ENDPOINT - or stream_status == StreamlineStatus.OUTSIDEIMAGE): - continue - if stepsB == 1: - streamline = F[:stepsF].copy() - else: - parts = (B[stepsB - 1:0:-1], F[:stepsF]) - streamline = np.concatenate(parts, axis=0) - - # move to the next streamline if only the seed position - # and not return all - len_sl = len(streamline) - if len_sl >= self.min_length: - if len_sl <= self.max_length: - if self.save_seeds: - yield streamline, s - else: - yield streamline - - -class VerboseLocalTracking(LocalTracking): - def __init__(self, *args, min_length=10, max_length=1000, **kwargs): - super().__init__(*args, **kwargs) - self.min_length = min_length - self.max_length = max_length - _generate_tractogram = _verbose_generate_tractogram - - -class VerboseParticleFilteringTracking(ParticleFilteringTracking): - def __init__(self, *args, min_length=10, max_length=1000, **kwargs): - super().__init__(*args, **kwargs) - self.min_length = min_length - self.max_length = max_length - _generate_tractogram = _verbose_generate_tractogram - - def in_place_norm(vec, axis=-1, keepdims=False, delvec=True): """ Return Vectors with Euclidean (L2) norm diff --git a/AFQ/tasks/tractography.py b/AFQ/tasks/tractography.py index df6feb570..77107fcc6 100644 --- a/AFQ/tasks/tractography.py +++ b/AFQ/tasks/tractography.py @@ -47,8 +47,8 @@ def _meta_from_tracking_params( Parameters=dict( Units="mm", StepSize=tracking_params["step_size"], - MinimumLength=tracking_params["min_length"], - MaximumLength=tracking_params["max_length"], + MinimumLength=tracking_params["minlen"], + MaximumLength=tracking_params["maxlen"], Unidirectional=False), Timing=time() - start_time) return meta diff --git a/AFQ/tests/test_api.py b/AFQ/tests/test_api.py index 72cb25aca..1b8c030bf 100644 --- a/AFQ/tests/test_api.py +++ b/AFQ/tests/test_api.py @@ -497,12 +497,10 @@ def test_API_type_checking(): preproc_pipeline='vistasoft', mapping_definition=IdentityMap(), tracking_params={ - "n_seeds": 1, + "n_seeds": 10, "random_seeds": True, - "directions": "prob", - "odf_model": "CSD"}, - csd_sh_order=2, # reduce CSD fit time - csd_response=((1, 1, 1), 1), + "directions": "det", + "odf_model": "DTI"}, bundle_info=abd.default18_bd()["ARC_L", "ARC_R"]) myafq.export("bundles") del myafq @@ -630,7 +628,7 @@ def test_AFQ_pft(): "stop_mask": stop_mask, "stop_threshold": "CMC", "tracker": "pft", - "max_length": 150, + "maxlen": 150, }) sl_file = myafq.export("streamlines")["01"] dwi_file = myafq.export("dwi")["01"] @@ -640,7 +638,7 @@ def test_AFQ_pft(): bbox_valid_check=False, trk_header_check=False).streamlines for sl in sls: - # double the max_length, due to step size of 0.5 + # double the maxlen, due to step size of 0.5 assert len(sl) <= 300 @@ -782,7 +780,7 @@ def test_AFQ_data_waypoint(): tracking_params = dict(odf_model="csd", seed_mask=RoiImage(), - n_seeds=100, + n_seeds=200, random_seeds=True, rng_seed=42) segmentation_params = dict(filter_by_endpoints=False, @@ -836,14 +834,14 @@ def test_AFQ_data_waypoint(): seg_sft = aus.SegmentedSFT.fromfile( myafq.export("bundles")) - npt.assert_(len(seg_sft.get_bundle('SLF_R').streamlines) > 0) + npt.assert_(len(seg_sft.get_bundle('CST_L').streamlines) > 0) # Test bundles exporting: myafq.export("indiv_bundles") assert op.exists(op.join( myafq.export("results_dir"), 'bundles', - 'sub-01_ses-01_coordsys-RASMM_trkmethod-probCSD_recogmethod-AFQ_desc-SLFR_tractography.trk')) # noqa + 'sub-01_ses-01_coordsys-RASMM_trkmethod-probCSD_recogmethod-AFQ_desc-CSTL_tractography.trk')) # noqa tract_profile_fname = myafq.export("profiles") tract_profiles = pd.read_csv(tract_profile_fname) @@ -884,7 +882,7 @@ def test_AFQ_data_waypoint(): # ROI mask needs to be put in quotes in config tracking_params = dict(odf_model="CSD", seed_mask="RoiImage()", - n_seeds=100, + n_seeds=200, random_seeds=True, rng_seed=42) bundle_dict_as_str = ( diff --git a/AFQ/tests/test_tractography.py b/AFQ/tests/test_tractography.py index 3ea534171..e1eccf0e6 100644 --- a/AFQ/tests/test_tractography.py +++ b/AFQ/tests/test_tractography.py @@ -24,7 +24,7 @@ fdata = op.join(tmpdir.name, 'dti.nii.gz') make_tracking_data(fbval, fbvec, fdata) -min_length = 20 +minlen = 20 step_size = 0.5 @@ -36,7 +36,7 @@ def test_csd_local_tracking(): sh_order=sh_order, lambda_=1, tau=0.1, mask=None, out_dir=tmpdir.name) for directions in ["det", "prob"]: - sl = track( + sls = track( fname, directions, odf_model="CSD", @@ -46,16 +46,17 @@ def test_csd_local_tracking(): n_seeds=seeds, stop_mask=None, step_size=step_size, - min_length=min_length, + minlen=minlen, tracker="local").streamlines - npt.assert_(len(sl[0]) >= step_size * min_length) + for sl in sls: + npt.assert_(len(sl) >= minlen / step_size) def test_dti_local_tracking(): fdict = fit_dti(fdata, fbval, fbvec) for directions in ["det", "prob"]: - sl = track( + sls = track( fdict['params'], directions, max_angle=30., @@ -63,10 +64,11 @@ def test_dti_local_tracking(): seed_mask=None, n_seeds=1, step_size=step_size, - min_length=min_length, + minlen=minlen, odf_model="DTI", tracker="local").streamlines - npt.assert_(len(sl[0]) >= min_length * step_size) + for sl in sls: + npt.assert_(len(sl) >= minlen / step_size) def test_pft_tracking(): @@ -89,7 +91,7 @@ def test_pft_tracking(): for directions in ["det", "prob"]: for stop_threshold in ["ACT", "CMC"]: - sl = track( + sls = track( fname, directions, max_angle=30., @@ -99,10 +101,12 @@ def test_pft_tracking(): stop_threshold=stop_threshold, n_seeds=1, step_size=step_size, - min_length=min_length, + minlen=minlen, odf_model=odf, tracker="pft").streamlines - npt.assert_(len(sl[0]) >= min_length * step_size) + + for sl in sls: + npt.assert_(len(sl) >= minlen / step_size) # Test error handling: with pytest.raises(RuntimeError): @@ -116,7 +120,7 @@ def test_pft_tracking(): stop_threshold=stop_threshold, n_seeds=1, step_size=step_size, - min_length=min_length, + minlen=minlen, tracker="pft") with pytest.raises(RuntimeError): @@ -130,5 +134,5 @@ def test_pft_tracking(): stop_threshold=None, # Stop threshold needs to be a string! n_seeds=1, step_size=step_size, - min_length=min_length, + minlen=minlen, tracker="pft") \ No newline at end of file diff --git a/AFQ/tractography/tractography.py b/AFQ/tractography/tractography.py index 49126e334..b28632d9c 100644 --- a/AFQ/tractography/tractography.py +++ b/AFQ/tractography/tractography.py @@ -2,6 +2,7 @@ import numpy as np import nibabel as nib import logging +from tqdm import tqdm import dipy.data as dpd from dipy.align import resample @@ -15,8 +16,9 @@ from nibabel.streamlines.tractogram import LazyTractogram -from AFQ._fixes import (VerboseLocalTracking, VerboseParticleFilteringTracking, - tensor_odf) +from dipy.tracking.local_tracking import (LocalTracking, + ParticleFilteringTracking) +from AFQ._fixes import tensor_odf def get_percentile_threshold(mask, threshold): @@ -32,7 +34,7 @@ def get_percentile_threshold(mask, threshold): def track(params_file, directions="prob", max_angle=30., sphere=None, seed_mask=None, seed_threshold=0, thresholds_as_percentages=False, n_seeds=1, random_seeds=False, rng_seed=None, stop_mask=None, - stop_threshold=0, step_size=0.5, min_length=50, max_length=250, + stop_threshold=0, step_size=0.5, minlen=50, maxlen=250, odf_model="CSD", tracker="local", trx=False): """ Tractography @@ -95,9 +97,9 @@ def track(params_file, directions="prob", max_angle=30., sphere=None, Default: False step_size : float, optional. The size of a step (in mm) of tractography. Default: 0.5 - min_length: int, optional + minlen: int, optional The miminal length (mm) in a streamline. Default: 20 - max_length: int, optional + maxlen: int, optional The miminal length (mm) in a streamline. Default: 250 odf_model : str, optional One of {"DTI", "CSD", "DKI"}. Defaults to use "DTI" @@ -135,8 +137,8 @@ def track(params_file, directions="prob", max_angle=30., sphere=None, # We need to calculate the size of a voxel, so we can transform # from mm to voxel units: - min_length = int(min_length / step_size) - max_length = int(max_length / step_size) + minlen = int(minlen / step_size) + maxlen = int(maxlen / step_size) logger.info("Generating Seeds...") if isinstance(n_seeds, int): @@ -192,7 +194,7 @@ def track(params_file, directions="prob", max_angle=30., sphere=None, stopping_criterion = ThresholdStoppingCriterion(stop_mask, stop_threshold) - my_tracker = VerboseLocalTracking + my_tracker = LocalTracking elif tracker == "pft": if not isinstance(stop_threshold, str): @@ -233,7 +235,7 @@ def track(params_file, directions="prob", max_angle=30., sphere=None, vox_sizes.append(np.mean(params_img.header.get_zooms()[:3])) - my_tracker = VerboseParticleFilteringTracking + my_tracker = ParticleFilteringTracking if stop_threshold == "CMC": stopping_criterion = CmcStoppingCriterion.from_pve( pve_wm_data, @@ -247,16 +249,17 @@ def track(params_file, directions="prob", max_angle=30., sphere=None, pve_gm_data, pve_csf_data) - logger.info("Tracking...") + logger.info( + f"Tracking with {len(seeds)} seeds, 2 directions per seed...") return _tracking(my_tracker, seeds, dg, stopping_criterion, params_img, - step_size=step_size, min_length=min_length, - max_length=max_length, random_seed=rng_seed, + step_size=step_size, minlen=minlen, + maxlen=maxlen, random_seed=rng_seed, trx=trx) def _tracking(tracker, seeds, dg, stopping_criterion, params_img, - step_size=0.5, min_length=40, max_length=200, + step_size=0.5, minlen=40, maxlen=200, random_seed=None, trx=False): """ Helper function @@ -264,15 +267,16 @@ def _tracking(tracker, seeds, dg, stopping_criterion, params_img, if len(seeds.shape) == 1: seeds = seeds[None, ...] - tracker = tracker( + tracker = tqdm(tracker( dg, stopping_criterion, seeds, params_img.affine, step_size=step_size, - min_length=min_length, - max_length=max_length, - random_seed=random_seed) + minlen=minlen, + maxlen=maxlen, + return_all=False, + random_seed=random_seed)) if trx: return LazyTractogram(lambda: tracker, diff --git a/AFQ/utils/bin.py b/AFQ/utils/bin.py index 794ec1dd6..a2dc21e35 100644 --- a/AFQ/utils/bin.py +++ b/AFQ/utils/bin.py @@ -389,8 +389,8 @@ def generate_json(json_folder, overwrite=False, "tckgen":{ "algorithm": "iFOD2", "select": 1e6, - "max_length": 250, - "min_length": 30, + "maxlen": 250, + "minlen": 30, "power":0.33 }, "sift2":{} diff --git a/setup.cfg b/setup.cfg index 0aef2e34f..cbaa0b5a4 100644 --- a/setup.cfg +++ b/setup.cfg @@ -28,7 +28,7 @@ python_requires = >=3.8 install_requires = # core packages scikit_image>=0.14.2 - dipy>=1.7.0,<1.8.0 + dipy>=1.8.0,<1.9.0 pandas pybids>=0.16.2 templateflow>=0.8