Skip to content

Commit

Permalink
Merge pull request dipy#3254 from deka27/keyword-direction
Browse files Browse the repository at this point in the history
NF: Applying Decorators in Module (Direction)
  • Loading branch information
skoudoro authored Sep 4, 2024
2 parents f0f3dd3 + ead37a6 commit 9892997
Show file tree
Hide file tree
Showing 13 changed files with 251 additions and 101 deletions.
4 changes: 2 additions & 2 deletions dipy/direction/bootstrap_direction_getter.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,8 @@ cdef class BootDirectionGetter(DirectionGetter):
Angular threshold for excluding ODF peaks.
"""
return cls(data, model, max_angle, sphere, max_attempts, sh_order,
b_tol, **kwargs)
return cls(data, model, max_angle, sphere=sphere, max_attempts=max_attempts, sh_order=sh_order,
b_tol=b_tol, **kwargs)


cpdef cnp.ndarray[cnp.float_t, ndim=2] initial_direction(self,
Expand Down
6 changes: 3 additions & 3 deletions dipy/direction/closest_peak_direction_getter.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -163,11 +163,11 @@ cdef class PmfGenDirectionGetter(BasePmfDirectionGetter):
raise ValueError(msg)

pmf_gen = SimplePmfGen(np.asarray(pmf,dtype=float), sphere)
return cls(pmf_gen, max_angle, sphere, pmf_threshold, **kwargs)
return cls(pmf_gen, max_angle, sphere, pmf_threshold=pmf_threshold, **kwargs)

@classmethod
def from_shcoeff(cls, shcoeff, max_angle, sphere=default_sphere,
pmf_threshold=0.1, basis_type=None, legacy=True,
pmf_threshold=0.1, basis_type=None, legacy=True,
sh_to_pmf=False, **kwargs):
"""Probabilistic direction getter from a distribution of directions
on the sphere
Expand Down Expand Up @@ -219,7 +219,7 @@ cdef class PmfGenDirectionGetter(BasePmfDirectionGetter):
else:
pmf_gen = SHCoeffPmfGen(np.asarray(shcoeff,dtype=float), sphere,
basis_type, legacy=legacy)
return cls(pmf_gen, max_angle, sphere, pmf_threshold, **kwargs)
return cls(pmf_gen, max_angle, sphere, pmf_threshold=pmf_threshold, **kwargs)


cdef class ClosestPeakDirectionGetter(PmfGenDirectionGetter):
Expand Down
46 changes: 32 additions & 14 deletions dipy/direction/peaks.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,15 @@
search_descending,
)
from dipy.reconst.shm import sh_to_sf_matrix
from dipy.testing.decorators import warning_for_keywords
from dipy.utils.deprecator import deprecated_params
from dipy.utils.multiproc import determine_num_processes


@warning_for_keywords()
def peak_directions_nl(
sphere_eval,
*,
relative_peak_threshold=0.25,
min_separation_angle=25,
sphere=default_sphere,
Expand Down Expand Up @@ -95,8 +98,14 @@ def _helper(x):
return directions, values


@warning_for_keywords()
def peak_directions(
odf, sphere, relative_peak_threshold=0.5, min_separation_angle=25, is_symmetric=True
odf,
sphere,
*,
relative_peak_threshold=0.5,
min_separation_angle=25,
is_symmetric=True,
):
"""Get the directions of odf peaks.
Expand Down Expand Up @@ -420,29 +429,31 @@ def _peaks_from_model_parallel_sub(args):
sphere,
relative_peak_threshold,
min_separation_angle,
mask,
return_odf,
return_sh,
gfa_thr,
normalize_peaks,
sh_order,
sh_basis_type,
legacy,
npeaks,
B,
invB,
mask=mask,
return_odf=return_odf,
return_sh=return_sh,
gfa_thr=gfa_thr,
normalize_peaks=normalize_peaks,
sh_order_max=sh_order,
sh_basis_type=sh_basis_type,
legacy=legacy,
npeaks=npeaks,
B=B,
invB=invB,
parallel=False,
num_processes=None,
)


@deprecated_params("sh_order", new_name="sh_order_max", since="1.9", until="2.0")
@warning_for_keywords()
def peaks_from_model(
model,
data,
sphere,
relative_peak_threshold,
min_separation_angle,
*,
mask=None,
return_odf=False,
return_sh=True,
Expand Down Expand Up @@ -601,7 +612,10 @@ def peaks_from_model(

# Get peaks of odf
direction, pk, ind = peak_directions(
odf, sphere, relative_peak_threshold, min_separation_angle
odf,
sphere,
relative_peak_threshold=relative_peak_threshold,
min_separation_angle=min_separation_angle,
)

# Calculate peak metrics
Expand Down Expand Up @@ -713,7 +727,11 @@ def peaks_from_positions(
for i, s in enumerate(vox_positions):
odf = trilinear_interpolate4d(odfs, s)
peaks, _, _ = peak_directions(
odf, sphere, relative_peak_threshold, min_separation_angle, is_symmetric
odf,
sphere,
relative_peak_threshold=relative_peak_threshold,
min_separation_angle=min_separation_angle,
is_symmetric=is_symmetric,
)
nbr_peaks = min(npeaks, peaks.shape[0])
peaks_arr[i, :nbr_peaks, :] = peaks[:nbr_peaks, :]
Expand Down
6 changes: 2 additions & 4 deletions dipy/direction/probabilistic_direction_getter.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ cdef class ProbabilisticDirectionGetter(PmfGenDirectionGetter):
directions more than ``max_angle`` degrees from the incoming direction are
set to 0 and the result is normalized.
"""

def __init__(self, pmf_gen, max_angle, sphere, pmf_threshold=.1, **kwargs):
"""Direction getter from a pmf generator.
Expand Down Expand Up @@ -57,7 +56,7 @@ cdef class ProbabilisticDirectionGetter(PmfGenDirectionGetter):
"""
PmfGenDirectionGetter.__init__(self, pmf_gen, max_angle, sphere,
pmf_threshold, **kwargs)
pmf_threshold=pmf_threshold, **kwargs)
# The vertices need to be in a contiguous array
self.vertices = self.sphere.vertices.copy()

Expand Down Expand Up @@ -128,10 +127,9 @@ cdef class DeterministicMaximumDirectionGetter(ProbabilisticDirectionGetter):
"""Return direction of a sphere with the highest probability mass
function (pmf).
"""

def __init__(self, pmf_gen, max_angle, sphere, pmf_threshold=.1, **kwargs):
ProbabilisticDirectionGetter.__init__(self, pmf_gen, max_angle, sphere,
pmf_threshold, **kwargs)
pmf_threshold=pmf_threshold, **kwargs)

cdef int get_direction_c(self, double[::1] point, double[::1] direction):
"""Find direction with the highest pmf to updates ``direction`` array
Expand Down
2 changes: 1 addition & 1 deletion dipy/direction/ptt_direction_getter.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ cdef class PTTDirectionGetter(ProbabilisticDirectionGetter):
self.rejection_sampling_nbr_sample = 10 # Adaptively set in Trekker.

ProbabilisticDirectionGetter.__init__(self, pmf_gen, max_angle, sphere,
pmf_threshold, **kwargs)
pmf_threshold=pmf_threshold, **kwargs)


cdef void initialize_candidate(self, double[:] init_dir):
Expand Down
26 changes: 23 additions & 3 deletions dipy/direction/tests/test_bootstrap_direction_getter.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,13 @@ def test_bdg_residual(rng):
)
csd_model = ConstrainedSphericalDeconvModel(gtab, response, sh_order_max=6)
npt.assert_raises(
ValueError, BootDirectionGetter, data, csd_model, 60, hsph_updated, 6
ValueError,
BootDirectionGetter,
data,
csd_model,
60,
sphere=hsph_updated,
max_attempts=6,
)


Expand Down Expand Up @@ -291,8 +297,22 @@ def test_boot_pmf():
category=PendingDeprecationWarning,
)
npt.assert_raises(
ValueError, BootDirectionGetter, data, tensor_model, 60, hsph_updated, 6, 20
ValueError,
BootDirectionGetter,
data,
tensor_model,
60,
sphere=hsph_updated,
max_attempts=6,
b_tol=20,
)
npt.assert_raises(
ValueError, BootDirectionGetter, data, tensor_model, 60, hsph_updated, 6, -1
ValueError,
BootDirectionGetter,
data,
tensor_model,
60,
sphere=hsph_updated,
max_attempts=6,
b_tol=-1,
)
Loading

0 comments on commit 9892997

Please sign in to comment.