Skip to content

Commit

Permalink
Merge branch 'type-hints' into wbfit
Browse files Browse the repository at this point in the history
  • Loading branch information
abhisrkckl committed Jan 21, 2025
2 parents d1f7bda + 9a0db16 commit 6b7c10b
Showing 1 changed file with 31 additions and 13 deletions.
44 changes: 31 additions & 13 deletions src/pint/fitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2491,7 +2491,13 @@ def scaled_all_sigma(self):
scaled_sigmas_no_unit.append(scaled_sigma)
return np.hstack(scaled_sigmas_no_unit)

def fit_toas(self, maxiter=1, threshold=0, full_cov=False, debug=False):
def fit_toas(
self,
maxiter: int = 1,
threshold: float = 0.0,
full_cov: bool = False,
debug: bool = False,
):
"""Carry out a generalized least-squares fitting procedure.
The algorithm here is essentially the same as used in
Expand Down Expand Up @@ -2671,15 +2677,15 @@ def fit_toas(self, maxiter=1, threshold=0, full_cov=False, debug=False):
class LMFitter(Fitter):
def fit_toas(
self,
maxiter=50,
maxiter: int = 50,
*,
min_chi2_decrease=1e-3,
lambda_factor_decrease=2,
lambda_factor_increase=3,
lambda_factor_invalid=10,
threshold=1e-14,
min_lambda=0.5,
debug=False,
min_chi2_decrease: float = 1e-3,
lambda_factor_decrease: float = 2.0,
lambda_factor_increase: float = 3.0,
lambda_factor_invalid: float = 10.0,
threshold: float = 1e-14,
min_lambda: float = 0.5,
debug: bool = False,
):
current_state = self.create_state()
try:
Expand Down Expand Up @@ -2800,7 +2806,14 @@ class WidebandLMFitter(LMFitter):
Unfortunately it doesn't.
"""

def __init__(self, toas, model, track_mode=None, residuals=None, add_args=None):
def __init__(
self,
toas: TOAs,
model: TimingModel,
track_mode: Optional[Literal["use_pulse_numbers", "nearest"]] = None,
residuals: WidebandTOAResiduals = None,
add_args: Optional[dict] = None,
):
self.method = "downhill_wideband"
self.full_cov = False
self.threshold = 0
Expand All @@ -2810,9 +2823,12 @@ def __init__(self, toas, model, track_mode=None, residuals=None, add_args=None):
)
self.is_wideband = True

self.model: TimingModel
self.toas: TOAs
self.resids: WidebandTOAResiduals
self.parameter_covariance_matrix: CovarianceMatrix

def make_resids(self, model):
def make_resids(self, model: TimingModel):
return WidebandTOAResiduals(
self.toas,
model,
Expand All @@ -2825,12 +2841,14 @@ def create_state(self):
self, self.model, full_cov=self.full_cov, threshold=self.threshold
)

def fit_toas(self, maxiter=50, full_cov=False, debug=False, **kwargs):
def fit_toas(
self, maxiter: int = 50, full_cov: bool = False, debug: bool = False, **kwargs
):
self.full_cov = full_cov
# FIXME: set up noise residuals et cetera
return super().fit_toas(maxiter=maxiter, debug=debug, **kwargs)

def update_from_state(self, state, debug=False):
def update_from_state(self, state: WidebandState, debug: bool = False):
# Nicer not to keep this if we have a choice, it introduces reference cycles
self.current_state = state
self.model = state.model
Expand Down

0 comments on commit 6b7c10b

Please sign in to comment.