Commit

Merge branch 'type-hints' into wbfit

abhisrkckl committed Jan 21, 2025
2 parents 6b7c10b + a622afa commit 1851ffb
Showing 3 changed files with 80 additions and 81 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG-unreleased.md
@@ -20,6 +20,10 @@ the released changes.
- `noise_resids` and `dm_noise_resids` properties in `WidebandTOAResiduals`
- `calc_combined_resids`, `calc_whitened_resids`, and `calc_whitened_dm_resids` methods in `WidebandTOAResiduals`
- Type hints in `pint.models.timing_model`
- Type hints in `pint.fitter`
### Fixed
- Made `TimingModel.is_binary()` more robust.
### Removed
- Unnecessary backport of `cached_property` in `pint.fitter` (it was only needed for Python 3.7, which is no longer supported).
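
For context (not part of the diff itself): the removed block further down was a verbatim backport of `functools.cached_property`, which the standard library ships from Python 3.8 onward. A minimal sketch of the behaviour `pint.fitter` now relies on, using a hypothetical `Example` class:

```python
from functools import cached_property


class Example:
    @cached_property
    def expensive(self) -> int:
        # Evaluated once per instance; the result is stored in the
        # instance __dict__ and returned directly on later accesses.
        print("computing...")
        return 42


e = Example()
e.expensive  # prints "computing..." and returns 42
e.expensive  # returns the cached 42 without recomputing
```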


153 changes: 73 additions & 80 deletions src/pint/fitter.py
@@ -60,8 +60,19 @@

import contextlib
import copy
from typing import Dict, Iterable, List, Literal, Optional, OrderedDict, Tuple, Union
from typing import (
Any,
Dict,
Iterable,
List,
Literal,
Optional,
OrderedDict,
Tuple,
Union,
)
from warnings import warn
from functools import cached_property

import astropy.units as u
import numpy as np
@@ -118,64 +129,6 @@
"MaxiterReached",
]

try:
from functools import cached_property
except ImportError:
# not supported in python 3.7
# This is just the code from python 3.8
from _thread import RLock

_NOT_FOUND = object()

class cached_property:
def __init__(self, func):
self.func = func
self.attrname = None
self.__doc__ = func.__doc__
self.lock = RLock()

def __set_name__(self, owner, name):
if self.attrname is None:
self.attrname = name
elif name != self.attrname:
raise TypeError(
"Cannot assign the same cached_property to two different names "
f"({self.attrname!r} and {name!r})."
)

def __get__(self, instance, owner=None):
if instance is None:
return self
if self.attrname is None:
raise TypeError(
"Cannot use cached_property instance without calling __set_name__ on it."
)
try:
cache = instance.__dict__
except AttributeError:
# not all objects have __dict__ (e.g. class defines slots)
msg = (
f"No '__dict__' attribute on {type(instance).__name__!r} "
f"instance to cache {self.attrname!r} property."
)
raise TypeError(msg) from None
val = cache.get(self.attrname, _NOT_FOUND)
if val is _NOT_FOUND:
with self.lock:
# check if another thread filled cache while we awaited lock
val = cache.get(self.attrname, _NOT_FOUND)
if val is _NOT_FOUND:
val = self.func(instance)
try:
cache[self.attrname] = val
except TypeError:
msg = (
f"The '__dict__' attribute on {type(instance).__name__!r} instance "
f"does not support item assignment for caching {self.attrname!r} property."
)
raise TypeError(msg) from None
return val


class Fitter:
"""Base class for objects encapsulating fitting problems.
@@ -375,6 +328,7 @@ def get_summary(self, nodmx: bool = False) -> str:
# First, print fit quality metrics
s = f"Fitted model using {self.method} method with {len(self.model.free_params)} free parameters to {self.toas.ntoas} TOAs\n"
if is_wideband:
self.resids_init: WidebandTOAResiduals
s += f"Prefit TOA residuals Wrms = {self.resids_init.toa.rms_weighted()}, Postfit TOA residuals Wrms = {self.resids.toa.rms_weighted()}\n"
s += f"Prefit DM residuals Wrms = {self.resids_init.dm.rms_weighted()}, Postfit DM residuals Wrms = {self.resids.dm.rms_weighted()}\n"
else:
@@ -536,7 +490,7 @@ def plot(self) -> None:
ax.grid(True)
plt.show()

def update_model(self, chi2: Optional[float] = None):
def update_model(self, chi2: Optional[float] = None) -> None:
"""Update the model to reflect fit results and TOA properties.
This is called by ``fit_toas`` to ensure that parameters like
@@ -636,7 +590,7 @@ def ftest(
remove: bool = False,
full_output: bool = False,
maxiter: int = 1,
) -> dict:
) -> Dict[str, Any]:
"""Compare the significance of adding/removing parameters to a timing model.
Parameters
@@ -835,7 +789,7 @@ def get_params_dict(
"""
return self.model.get_params_dict(which=which, kind=kind)

def set_fitparams(self, *params):
def set_fitparams(self, *params) -> None:
"""Update the "frozen" attribute of model parameters. Deprecated."""
warn(
"This function is confusing and deprecated. Set self.model.free_params instead.",
@@ -903,7 +857,7 @@ def set_param_uncertainties(self, fitp: Dict[str, float]) -> None:
self.model.set_param_uncertainties(fitp)

@property
def covariance_matrix(self) -> CovarianceMatrix:
def covariance_matrix(self):
warn(
"This parameter is deprecated. Use `parameter_covariance_matrix` instead of `covariance_matrix`",
category=DeprecationWarning,
@@ -925,6 +879,9 @@ def __init__(self, fitter: Fitter, model: TimingModel):
self.fitter = fitter
self.model = model

self.params: List[str]
self.fac: np.ndarray

@cached_property
def resids(self) -> Residuals:
try:
@@ -945,7 +902,7 @@ def step(self):
raise NotImplementedError

@cached_property
def parameter_covariance_matrix(self):
def parameter_covariance_matrix(self) -> CovarianceMatrix:
raise NotImplementedError

@property
@@ -960,7 +917,7 @@ def predicted_chi2(self, step, lambda_):
"""Predict the chi2 after taking a step based on the linear approximation"""
raise NotImplementedError

def take_step_model(self, step, lambda_=1):
def take_step_model(self, step, lambda_=1) -> TimingModel:
"""Make a new model reflecting the new parameters."""
# log.debug(f"Taking step {lambda_} * {list(zip(self.params, step))}")
new_model = copy.deepcopy(self.model)
@@ -980,7 +937,7 @@ def take_step_model(self, step, lambda_=1):
log.warning(f"Unexpected parameter {p}")
return new_model

def take_step(self, step, lambda_):
def take_step(self, step: np.ndarray, lambda_: float) -> "ModelState":
"""Return a new state moved by lambda_*step."""
raise NotImplementedError

@@ -1008,6 +965,13 @@ def __init__(
)
self.method = "downhill_checked"

self.current_state: ModelState

def create_state(self) -> ModelState:
# Subclasses will override this.
# I am adding this here just to improve code highlighting.
raise NotImplementedError
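
Roughly, the `create_state`/`take_step` hooks are consumed by the downhill loop in `_fit_toas` (whose body is truncated below). A heavily simplified, illustrative sketch of that loop; the real implementation also handles convergence criteria, exception recovery, and uncertainty updates, and `maxiter`/`min_lambda` here are stand-in parameters:

```python
from pint.fitter import DownhillFitter, ModelState


def downhill_loop_sketch(fitter: DownhillFitter, maxiter: int = 20, min_lambda: float = 1e-3) -> None:
    """Illustrative accept/reject loop; not the actual DownhillFitter._fit_toas."""
    state: ModelState = fitter.create_state()
    for _ in range(maxiter):
        step = state.step
        lambda_ = 1.0
        new_state = state.take_step(step, lambda_)
        # If the full step makes chi2 worse, shrink it until it helps or lambda_ bottoms out.
        while new_state.resids.chi2 > state.resids.chi2 and lambda_ > min_lambda:
            lambda_ /= 2.0
            new_state = state.take_step(step, lambda_)
        state = new_state
    fitter.current_state = state
```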

def _fit_toas(
self,
maxiter: int = 20,
@@ -1101,7 +1065,7 @@ def _fit_toas(
# I don't know why this fails with multiprocessing, but bypass if it does
with contextlib.suppress(ValueError):
log.trace(f"Setting {getattr(self.model, p)} uncertainty to {e}")
pm = getattr(self.model, p)
pm = self.model[p]
except AttributeError:
if p != "Offset":
log.warning(f"Unexpected parameter {p}")
@@ -1297,13 +1261,15 @@ def __init__(
self.threshold = threshold

@cached_property
def step(self):
def step(self) -> np.ndarray:
# Define the linear system
M, params, units = self.model.designmatrix(
toas=self.fitter.toas, incfrozen=False, incoffset=True
)
# Get residuals and TOA uncertainties in seconds
Nvec = self.model.scaled_toa_uncertainty(self.fitter.toas).to(u.s).value
Nvec: np.ndarray = (
self.model.scaled_toa_uncertainty(self.fitter.toas).to(u.s).value
)
scaled_resids = self.resids.time_resids.to(u.s).value / Nvec

# "Whiten" design matrix and residuals by dividing by uncertainties
@@ -1376,13 +1342,13 @@ def step(self):
# Scaling by fac recovers original units
return (Vt.T @ ((U.T @ scaled_resids) / s)) / fac
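
The tail of `step` above is the SVD solution of the whitened, column-normalized least-squares problem. A rough standalone sketch of the same algebra (illustrative only; the small-singular-value truncation controlled by `threshold` and the unit bookkeeping are omitted):

```python
import numpy as np


def wls_step_sketch(M: np.ndarray, Nvec: np.ndarray, resids: np.ndarray) -> np.ndarray:
    """Whitened least-squares step via SVD: M is the design matrix,
    Nvec the scaled TOA uncertainties (s), resids the time residuals (s)."""
    # Whiten rows of the design matrix and the residuals by the TOA uncertainties.
    Mw = M / Nvec[:, None]
    rw = resids / Nvec
    # Normalize each parameter column to unit norm for numerical stability.
    fac = np.sqrt((Mw**2).sum(axis=0))
    fac[fac == 0] = 1.0
    Mn = Mw / fac
    # Thin SVD of the scaled system; the step is its pseudo-inverse applied to rw.
    U, s, Vt = np.linalg.svd(Mn, full_matrices=False)
    # Dividing by fac at the end recovers the original parameter units.
    return (Vt.T @ ((U.T @ rw) / s)) / fac
```

In the same scaled basis the parameter covariance would be `(Vt.T / s**2) @ Vt`, divided elementwise by `np.outer(fac, fac)` to restore units (an assumption; the corresponding hunk is truncated here).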

def take_step(self, step, lambda_=1):
def take_step(self, step: np.ndarray, lambda_: float = 1.0) -> "WLSState":
return WLSState(
self.fitter, self.take_step_model(step, lambda_), threshold=self.threshold
)

@cached_property
def parameter_covariance_matrix(self):
def parameter_covariance_matrix(self) -> CovarianceMatrix:
# make sure we compute the SVD
self.step
# Sigma = np.dot(Vt.T / s, U.T)
@@ -1401,15 +1367,27 @@ class DownhillWLSFitter(DownhillFitter):
or :class:`pint.fitter.DownhillFitter`.
"""

def __init__(self, toas, model, track_mode=None, residuals=None):
def __init__(
self,
toas: TOAs,
model: TimingModel,
track_mode: Optional[Literal["use_pulse_numbers", "nearest"]] = None,
residuals: Optional[Residuals] = None,
):
if model.has_correlated_errors:
raise CorrelatedErrors(model)
super().__init__(
toas=toas, model=model, residuals=residuals, track_mode=track_mode
)
self.method = "downhill_wls"

def fit_toas(self, maxiter=10, threshold=None, debug=False, **kwargs):
def fit_toas(
self,
maxiter: int = 10,
threshold: Optional[float] = None,
debug: bool = False,
**kwargs,
):
"""Fit TOAs.
This is mostly implemented in
@@ -1429,18 +1407,24 @@ def fit_toas(self, maxiter=10, threshold=None, debug=False, **kwargs):
self.threshold = threshold
super().fit_toas(maxiter=maxiter, debug=debug, **kwargs)

def create_state(self):
def create_state(self) -> WLSState:
return WLSState(self, self.model)
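
A short usage sketch of this fitter with its new annotations (the `.par`/`.tim` file names are placeholders; `get_model_and_toas` is PINT's usual loader):

```python
from pint.models import get_model_and_toas
from pint.fitter import DownhillWLSFitter

# Hypothetical input files; any model without correlated noise will do,
# since DownhillWLSFitter raises CorrelatedErrors otherwise.
model, toas = get_model_and_toas("pulsar.par", "pulsar.tim")

fitter = DownhillWLSFitter(toas, model)
fitter.fit_toas(maxiter=10)
print(fitter.get_summary())
```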


class GLSState(ModelState):
def __init__(self, fitter, model, full_cov=False, threshold=None):
def __init__(
self,
fitter: Fitter,
model: TimingModel,
full_cov: bool = False,
threshold: Optional[float] = None,
):
super().__init__(fitter, model)
self.threshold = threshold
self.full_cov = full_cov

@cached_property
def step(self):
def step(self) -> np.ndarray:
# Define the linear system
M, params, units = self.model.designmatrix(
toas=self.fitter.toas, incfrozen=False, incoffset=True
@@ -1524,7 +1508,7 @@ def step(self):
# compute absolute estimates, normalized errors, covariance matrix
return xhat / norm

def take_step(self, step, lambda_=1):
def take_step(self, step: np.ndarray, lambda_: float = 1.0) -> "GLSState":
return GLSState(
self.fitter,
self.take_step_model(step, lambda_),
@@ -1533,7 +1517,7 @@ def take_step(self, step, lambda_=1):
)

@cached_property
def parameter_covariance_matrix(self):
def parameter_covariance_matrix(self) -> CovarianceMatrix:
# make sure we compute the SVD
self.step
xvar = np.dot(self.Vt.T / self.s, self.Vt)
@@ -1570,12 +1554,21 @@ def __init__(
self.full_cov = False
self.threshold = 0

def create_state(self):
self.current_state: GLSState

def create_state(self) -> GLSState:
return GLSState(
self, self.model, full_cov=self.full_cov, threshold=self.threshold
)

def fit_toas(self, maxiter=10, threshold=0, full_cov=False, debug=False, **kwargs):
def fit_toas(
self,
maxiter: int = 10,
threshold: float = 0.0,
full_cov: bool = False,
debug: bool = False,
**kwargs,
):
"""Fit TOAs.
This is mostly implemented in
4 changes: 3 additions & 1 deletion src/pint/utils.py
@@ -2864,7 +2864,9 @@ def get_unit(parname: str) -> u.Unit:
return ac.param_to_unit(parname)


def normalize_designmatrix(M, params):
def normalize_designmatrix(
M: np.ndarray, params: List[str]
) -> Tuple[np.ndarray, np.ndarray]:
"""Normalize each row of the design matrix.
This is used while computing the GLS chi2 and the GLS fitting step. The
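
Since the body of `normalize_designmatrix` is not shown in the hunk, here is a minimal sketch of the behaviour the new annotations imply, assuming the usual column-wise scaling (one norm per fit parameter, returned so callers such as the GLS/WLS steps can undo the scaling):

```python
from typing import List, Tuple

import numpy as np


def normalize_designmatrix_sketch(M: np.ndarray, params: List[str]) -> Tuple[np.ndarray, np.ndarray]:
    # One Euclidean norm per parameter column; guard against all-zero columns.
    norm = np.sqrt((M**2).sum(axis=0))
    norm[norm == 0] = 1.0
    return M / norm, norm
```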
