Skip to content

Commit

Permalink
Merge pull request dipy#3257 from deka27/keyword-reconst
Browse files Browse the repository at this point in the history
NF: Applying Decorators in Module (Reconst)
  • Loading branch information
skoudoro authored Sep 4, 2024
2 parents 9892997 + d26f769 commit 2ea0906
Show file tree
Hide file tree
Showing 42 changed files with 835 additions and 465 deletions.
8 changes: 6 additions & 2 deletions dipy/direction/peaks.py
Original file line number Diff line number Diff line change
Expand Up @@ -540,7 +540,11 @@ def peaks_from_model(
"""
if return_sh and (B is None or invB is None):
B, invB = sh_to_sf_matrix(
sphere, sh_order_max, sh_basis_type, return_inv=True, legacy=legacy
sphere,
sh_order_max=sh_order_max,
basis_type=sh_basis_type,
return_inv=True,
legacy=legacy,
)

num_processes = determine_num_processes(num_processes)
Expand Down Expand Up @@ -597,7 +601,7 @@ def peaks_from_model(
if not mask[idx]:
continue

odf = model.fit(data[idx]).odf(sphere)
odf = model.fit(data[idx]).odf(sphere=sphere)

if return_sh:
shm_coeff[idx] = np.dot(odf, invB)
Expand Down
5 changes: 4 additions & 1 deletion dipy/reconst/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
"""

from dipy.testing.decorators import warning_for_keywords


class ReconstModel:
"""Abstract class for signal reconstruction models"""
Expand All @@ -25,7 +27,8 @@ def __init__(self, gtab):
"""
self.gtab = gtab

def fit(self, data, mask=None, **kwargs):
@warning_for_keywords()
def fit(self, data, *, mask=None, **kwargs):
return ReconstFit(self, data)


Expand Down
4 changes: 3 additions & 1 deletion dipy/reconst/cache.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from dipy.core.onetime import auto_attr
from dipy.testing.decorators import warning_for_keywords


class Cache:
Expand Down Expand Up @@ -65,7 +66,8 @@ def cache_set(self, tag, key, value):
"""
self._cache[(tag, key)] = value

def cache_get(self, tag, key, default=None):
@warning_for_keywords()
def cache_get(self, tag, key, *, default=None):
"""Retrieve a value from the cache.
Parameters
Expand Down
6 changes: 4 additions & 2 deletions dipy/reconst/cross_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@
import numpy as np

import dipy.core.gradients as gt
from dipy.testing.decorators import warning_for_keywords


def coeff_of_determination(data, model, axis=-1):
@warning_for_keywords()
def coeff_of_determination(data, model, *, axis=-1):
r"""Calculate the coefficient of determination for a model prediction,
relative to data.
Expand Down Expand Up @@ -145,7 +147,7 @@ class for which prediction is conducted. That is, the Fit object that gets
err_str += "do not have an implementation of model prediction"
err_str += " and do not support cross-validation"
raise ValueError(err_str)
this_predict = S0[..., None] * this_fit.predict(left_out_gtab, S0=1)
this_predict = S0[..., None] * this_fit.predict(gtab=left_out_gtab, S0=1)

idx_to_assign = np.where(~gtab.b0s_mask)[0][~fold_mask]
prediction[..., idx_to_assign] = this_predict[..., np.sum(gtab.b0s_mask) :]
Expand Down
41 changes: 30 additions & 11 deletions dipy/reconst/csdeconv.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
)
from dipy.reconst.utils import _mask_from_roi, _roi_in_volume
from dipy.sims.voxel import single_tensor
from dipy.testing.decorators import warning_for_keywords
from dipy.utils.deprecator import deprecate_with_version, deprecated_params


Expand All @@ -36,9 +37,11 @@
since="1.2",
until="1.4",
)
@warning_for_keywords()
def auto_response(
gtab,
data,
*,
roi_center=None,
roi_radius=10,
fa_thr=0.7,
Expand Down Expand Up @@ -161,7 +164,8 @@ class AxSymShResponse:
"""

def __init__(self, S0, dwi_response, bvalue=None):
@warning_for_keywords()
def __init__(self, S0, dwi_response, *, bvalue=None):
self.S0 = S0
self.dwi_response = dwi_response
self.bvalue = bvalue
Expand All @@ -183,10 +187,12 @@ def on_sphere(self, sphere):

class ConstrainedSphericalDeconvModel(SphHarmModel):
@deprecated_params("sh_order", new_name="sh_order_max", since="1.9", until="2.0")
@warning_for_keywords()
def __init__(
self,
gtab,
response,
*,
reg_sphere=None,
sh_order_max=8,
lambda_=1,
Expand Down Expand Up @@ -310,13 +316,14 @@ def fit(self, data, **kwargs):
dwi_data,
self._X,
self.B_reg,
self.tau,
tau=self.tau,
convergence=self.convergence,
P=self._P,
)
return SphHarmFit(self, shm_coeff, None)

def predict(self, sh_coeff, gtab=None, S0=1.0):
@warning_for_keywords()
def predict(self, sh_coeff, *, gtab=None, S0=1.0):
"""Compute a signal prediction given spherical harmonic coefficients
for the provided GradientTable class instance.
Expand Down Expand Up @@ -362,8 +369,9 @@ def predict(self, sh_coeff, gtab=None, S0=1.0):

class ConstrainedSDTModel(SphHarmModel):
@deprecated_params("sh_order", new_name="sh_order_max", since="1.9", until="2.0")
@warning_for_keywords()
def __init__(
self, gtab, ratio, reg_sphere=None, sh_order_max=8, lambda_=1.0, tau=0.1
self, gtab, ratio, *, reg_sphere=None, sh_order_max=8, lambda_=1.0, tau=0.1
):
r"""Spherical Deconvolution Transform (SDT)
:footcite:p:`Descoteaux2009`.
Expand Down Expand Up @@ -457,7 +465,7 @@ def fit(self, data, **kwargs):
# normalize ODF
odf_sh /= Z
shm_coeff, num_it = odf_deconv(
odf_sh, self.R, self.B_reg, self.lambda_, self.tau
odf_sh, self.R, self.B_reg, lambda_=self.lambda_, tau=self.tau
)
# print 'SDT CSD converged after %d iterations' % num_it

Expand Down Expand Up @@ -487,7 +495,8 @@ def estimate_response(gtab, evals, S0):


@deprecated_params("n", new_name="l_values", since="1.9", until="2.0")
def forward_sdt_deconv_mat(ratio, l_values, r2_term=False):
@warning_for_keywords()
def forward_sdt_deconv_mat(ratio, l_values, *, r2_term=False):
r"""Build forward sharpening deconvolution transform (SDT) matrix
Parameters
Expand Down Expand Up @@ -569,7 +578,8 @@ def _solve_cholesky(Q, z):
return f


def csdeconv(dwsignal, X, B_reg, tau=0.1, convergence=50, P=None):
@warning_for_keywords()
def csdeconv(dwsignal, X, B_reg, *, tau=0.1, convergence=50, P=None):
r"""Constrained-regularized spherical deconvolution (CSD).
Deconvolves the axially symmetric single fiber response function `r_rh` in
Expand Down Expand Up @@ -737,7 +747,8 @@ def csdeconv(dwsignal, X, B_reg, tau=0.1, convergence=50, P=None):
return fodf_sh, _num_it


def odf_deconv(odf_sh, R, B_reg, lambda_=1.0, tau=0.1, r2_term=False):
@warning_for_keywords()
def odf_deconv(odf_sh, R, B_reg, *, lambda_=1.0, tau=0.1, r2_term=False):
r"""ODF constrained-regularized spherical deconvolution using
the Sharpening Deconvolution Transform (SDT).
Expand Down Expand Up @@ -840,9 +851,11 @@ def odf_deconv(odf_sh, R, B_reg, lambda_=1.0, tau=0.1, r2_term=False):


@deprecated_params("sh_order", new_name="sh_order_max", since="1.9", until="2.0")
@warning_for_keywords()
def odf_sh_to_sharp(
odfs_sh,
sphere,
*,
basis=None,
ratio=3 / 15.0,
sh_order_max=8,
Expand Down Expand Up @@ -925,7 +938,8 @@ def odf_sh_to_sharp(
return fodf_sh


def mask_for_response_ssst(gtab, data, roi_center=None, roi_radii=10, fa_thr=0.7):
@warning_for_keywords()
def mask_for_response_ssst(gtab, data, *, roi_center=None, roi_radii=10, fa_thr=0.7):
"""Computation of mask for single-shell single-tissue (ssst) response
function using FA.
Expand Down Expand Up @@ -1055,7 +1069,8 @@ def response_from_mask_ssst(gtab, data, mask):
return _get_response(S0s, lambdas)


def auto_response_ssst(gtab, data, roi_center=None, roi_radii=10, fa_thr=0.7):
@warning_for_keywords()
def auto_response_ssst(gtab, data, *, roi_center=None, roi_radii=10, fa_thr=0.7):
"""Automatic estimation of single-shell single-tissue (ssst) response
function using FA.
Expand Down Expand Up @@ -1093,7 +1108,9 @@ def auto_response_ssst(gtab, data, roi_center=None, roi_radii=10, fa_thr=0.7):
`ratio` (more details are available in the description of the function).
"""

mask = mask_for_response_ssst(gtab, data, roi_center, roi_radii, fa_thr)
mask = mask_for_response_ssst(
gtab, data, roi_center=roi_center, roi_radii=roi_radii, fa_thr=fa_thr
)
response, ratio = response_from_mask_ssst(gtab, data, mask)

return response, ratio
Expand All @@ -1114,9 +1131,11 @@ def _get_response(S0s, lambdas):


@deprecated_params("sh_order", new_name="sh_order_max", since="1.9", until="2.0")
@warning_for_keywords()
def recursive_response(
gtab,
data,
*,
mask=None,
sh_order_max=8,
peak_thr=0.01,
Expand Down
25 changes: 16 additions & 9 deletions dipy/reconst/cti.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
)
from dipy.reconst.multi_voxel import multi_voxel_fit
from dipy.reconst.utils import cti_design_matrix as design_matrix
from dipy.testing.decorators import warning_for_keywords


def from_qte_to_cti(C):
Expand Down Expand Up @@ -135,7 +136,8 @@ def split_cti_params(cti_params):
return evals, evecs, kt, ct


def cti_prediction(cti_params, gtab1, gtab2, S0=1):
@warning_for_keywords()
def cti_prediction(cti_params, gtab1, gtab2, *, S0=1):
"""Predict a signal given correlation tensor imaging parameters.
Parameters
Expand Down Expand Up @@ -198,7 +200,7 @@ def cti_prediction(cti_params, gtab1, gtab2, S0=1):
class CorrelationTensorModel(ReconstModel):
"""Class for the Correlation Tensor Model"""

def __init__(self, gtab1, gtab2, fit_method="WLS", *args, **kwargs):
def __init__(self, gtab1, gtab2, *args, fit_method="WLS", **kwargs):
"""Correlation Tensor Imaging Model.
See :footcite:p:`NetoHenriques2020` for further details about the model.
Expand Down Expand Up @@ -254,7 +256,8 @@ def __init__(self, gtab1, gtab2, fit_method="WLS", *args, **kwargs):
self.weights = fit_method in {"WLS", "WLLS", "UWLLS"}

@multi_voxel_fit
def fit(self, data, mask=None):
@warning_for_keywords()
def fit(self, data, *, mask=None):
"""Fit method of the CTI model class.
Parameters
Expand All @@ -278,7 +281,8 @@ def fit(self, data, mask=None):

return CorrelationTensorFit(self, params)

def predict(self, cti_params, S0=1):
@warning_for_keywords()
def predict(self, cti_params, *, S0=1):
"""Predict a signal for the CTI model class instance given parameters
Parameters
Expand Down Expand Up @@ -307,7 +311,7 @@ def predict(self, cti_params, S0=1):
Predicted signal based on the CTI model
"""

return cti_prediction(cti_params, self.gtab1, self.gtab2, S0)
return cti_prediction(cti_params, self.gtab1, self.gtab2, S0=S0)


class CorrelationTensorFit(DiffusionKurtosisFit):
Expand Down Expand Up @@ -341,7 +345,8 @@ def ct(self):
"""
return self.model_params[..., 27:48]

def predict(self, gtab1, gtab2, S0=1):
@warning_for_keywords()
def predict(self, gtab1, gtab2, *, S0=1):
"""Given a CTI model fit, predict the signal on the vertices of a
gradient table
Expand All @@ -360,7 +365,7 @@ def predict(self, gtab1, gtab2, S0=1):
S : numpy.ndarray
Predicted signal based on the CTI model
"""
return cti_prediction(self.model_params, gtab1, gtab2, S0)
return cti_prediction(self.model_params, gtab1, gtab2, S0=S0)

@property
def K_aniso(self):
Expand Down Expand Up @@ -515,7 +520,8 @@ def K_micro(self):
return micro_K


def params_to_cti_params(result, min_diffusivity=0):
@warning_for_keywords()
def params_to_cti_params(result, *, min_diffusivity=0):
# Extracting the diffusion tensor parameters from solution
DT_elements = result[:6]
evals, evecs = decompose_tensor(
Expand All @@ -536,8 +542,9 @@ def params_to_cti_params(result, min_diffusivity=0):
return cti_params


@warning_for_keywords()
def ls_fit_cti(
design_matrix, data, inverse_design_matrix, weights=True, min_diffusivity=0
design_matrix, data, inverse_design_matrix, *, weights=True, min_diffusivity=0
):
r"""Compute the diffusion kurtosis and covariance tensors using an
ordinary or weighted linear least squares approach
Expand Down
Loading

0 comments on commit 2ea0906

Please sign in to comment.