Brute force method to find the optimal number of harmonics for WaveX, DMWaveX, and CMWaveX #1824
Open
abhisrkckl wants to merge 26 commits into nanograv:master from abhisrkckl:select-nharms
Commits
2156329 noise_analysis (abhisrkckl)
c4fceda test_noise_analysis (abhisrkckl)
613d0e5 tests (abhisrkckl)
fdd12ce joblib (abhisrkckl)
e14b82a cmt (abhisrkckl)
3f8c062 Merge branch 'nanograv:master' into select-nharms (abhisrkckl)
3429bae Merge branch 'cmwavex' into select-nharms (abhisrkckl)
05af781 rn (abhisrkckl)
5e5eedb uncmt (abhisrkckl)
97892a6 test_wx2pl (abhisrkckl)
6428c47 CHANGELOG (abhisrkckl)
d96dcc1 test_optimal_nharms (abhisrkckl)
5891c81 rednoise-fit-example.py (abhisrkckl)
a32bb40 CHANGELOG (abhisrkckl)
8c45640 CHANGELOG (abhisrkckl)
93728a1 msg (abhisrkckl)
2723932 docstrings (abhisrkckl)
dc42b8b fix test (abhisrkckl)
8f43826 Merge branch 'nanograv:master' into select-nharms (abhisrkckl)
5276f2d Merge branch 'cmwavex' into select-nharms (abhisrkckl)
e4287c8 Merge branch 'nanograv:master' into select-nharms (abhisrkckl)
3ee4cad Merge branch 'nanograv:master' into select-nharms (abhisrkckl)
7e79406 Merge branch 'master' into select-nharms (abhisrkckl)
e75dada Merge branch 'nanograv:master' into select-nharms (abhisrkckl)
d3ca68e Merge branch 'nanograv:master' into select-nharms (abhisrkckl)
008be2d Merge branch 'master' into select-nharms (abhisrkckl)
@@ -10,3 +10,4 @@ uncertainties
loguru
nestle>=0.2.0
numdifftools
joblib
@@ -0,0 +1,230 @@
from copy import deepcopy
from typing import List, Optional, Tuple
from itertools import product as cartesian_product

from joblib import Parallel, cpu_count, delayed
import numpy as np
from astropy import units as u

from pint.models.chromatic_model import ChromaticCM
from pint.models.dispersion_model import DispersionDM
from pint.models.phase_offset import PhaseOffset
from pint.models.timing_model import TimingModel
from pint.toa import TOAs
from pint.logging import setup as setup_log
from pint.utils import (
    akaike_information_criterion,
    cmwavex_setup,
    dmwavex_setup,
    wavex_setup,
)


def find_optimal_nharms(
    model: TimingModel,
    toas: TOAs,
    include_components: List[str] = ["WaveX", "DMWaveX", "CMWaveX"],
    nharms_max: int = 45,
    chromatic_index: float = 4,
    num_parallel_jobs: Optional[int] = None,
) -> Tuple[np.ndarray, tuple]:
    """Find the optimal number of harmonics for `WaveX`/`DMWaveX`/`CMWaveX` using the
    Akaike Information Criterion.

    This function runs a brute force search over a grid of harmonic numbers, from 0 to
    `nharms_max`. The search is executed in multiple processes using the `joblib`
    library; the number of processes is controlled through the `num_parallel_jobs`
    argument.

    Please note that the execution time scales as `O(nharms_max**len(include_components))`,
    which can quickly become large. Hence, if you are using large values of `nharms_max`, it
    is recommended that this be run on a cluster with a large number of CPUs.

    Parameters
    ----------
    model: `pint.models.timing_model.TimingModel`
        The timing model. Should not already contain `WaveX`/`DMWaveX` or `PLRedNoise`/`PLDMNoise`.
    toas: `pint.toa.TOAs`
        Input TOAs
    include_components: list[str]
        Component names; a non-empty sublist of ["WaveX", "DMWaveX", "CMWaveX"]
    nharms_max: int, optional
        Maximum number of harmonics (default is 45) for each component
    chromatic_index: float
        Chromatic index for `CMWaveX`
    num_parallel_jobs: int, optional
        Number of parallel processes. The default is the number of available CPU cores.

    Returns
    -------
    aics: ndarray
        Array of AIC values, with one axis per entry of `include_components`.
    nharms_opt: tuple
        Optimal numbers of harmonics
    """
    # None of the requested components may already be present in the model.
    assert len(set(include_components).intersection(set(model.components.keys()))) == 0
    assert len(include_components) > 0

    # Every combination of harmonic counts, from 0 to nharms_max per component.
    idxs = list(
        cartesian_product(
            *np.repeat([np.arange(nharms_max + 1)], len(include_components), axis=0)
        )
    )

    if num_parallel_jobs is None:
        num_parallel_jobs = cpu_count()

    aics_flat = Parallel(n_jobs=num_parallel_jobs, verbose=13)(
        delayed(compute_aic)(model, toas, include_components, ii, chromatic_index)
        for ii in idxs
    )

    aics = np.reshape(aics_flat, [nharms_max + 1] * len(include_components))

    assert np.isfinite(aics).all(), "Infs/NaNs found in AICs!"

    return aics, np.unravel_index(np.argmin(aics), aics.shape)


def compute_aic(
    model: TimingModel,
    toas: TOAs,
    include_components: List[str],
    nharms: np.ndarray,
    chromatic_index: float,
):
    """Given a pre-fit model and TOAs, add the `[CM|DM]WaveX` components to the model,
    fit the model to the TOAs, and compute the Akaike Information Criterion using the
    post-fit timing model.

    Parameters
    ----------
    model: `pint.models.timing_model.TimingModel`
        The pre-fit timing model. Should not already contain `WaveX`/`DMWaveX` or `PLRedNoise`/`PLDMNoise`.
    toas: `pint.toa.TOAs`
        Input TOAs
    include_components: list[str]
        Component names; a non-empty sublist of ["WaveX", "DMWaveX", "CMWaveX"]
    nharms: ndarray
        The number of harmonics for each component
    chromatic_index: float
        Chromatic index for `CMWaveX`

    Returns
    -------
    aic: float
        The AIC value.
    """
    setup_log(level="WARNING")

    model1 = prepare_model(
        model, toas.get_Tspan(), include_components, nharms, chromatic_index
    )

    from pint.fitter import Fitter

    # Downhill fitters don't work well here.
    # TODO: Investigate this.
    ftr = Fitter.auto(toas, model1, downhill=False)
    ftr.fit_toas(maxiter=10)

    return akaike_information_criterion(ftr.model, toas)


def prepare_model(
    model: TimingModel,
    Tspan: u.Quantity,
    include_components: List[str],
    nharms: np.ndarray,
    chromatic_index: float,
):
    """Given a pre-fit model and the observation time span, add the `[CM|DM]WaveX`
    components to the model. Also sets parameters like `PHOFF` and the `DM` and `CM`
    derivatives as free.

    Parameters
    ----------
    model: `pint.models.timing_model.TimingModel`
        The pre-fit timing model. Should not already contain `WaveX`/`DMWaveX` or `PLRedNoise`/`PLDMNoise`.
    Tspan: u.Quantity
        The observation time span
    include_components: list[str]
        Component names; a non-empty sublist of ["WaveX", "DMWaveX", "CMWaveX"]
    nharms: ndarray
        The number of harmonics for each component
    chromatic_index: float
        Chromatic index for `CMWaveX`

    Returns
    -------
    model1: `pint.models.timing_model.TimingModel`
        A copy of the input model with the requested components added.
    """

    model1 = deepcopy(model)

    for comp in ["PLRedNoise", "PLDMNoise", "PLCMNoise"]:
        if comp in model1.components:
            model1.remove_component(comp)

    if "PhaseOffset" not in model1.components:
        model1.add_component(PhaseOffset())
    model1.PHOFF.frozen = False

    for jj, comp in enumerate(include_components):
        if comp == "WaveX":
            nharms_wx = nharms[jj]
            if nharms_wx > 0:
                wavex_setup(model1, Tspan, n_freqs=nharms_wx, freeze_params=False)
        elif comp == "DMWaveX":
            nharms_dwx = nharms[jj]
            if nharms_dwx > 0:
                if "DispersionDM" not in model1.components:
                    model1.add_component(DispersionDM())

                model1["DM"].frozen = False

                if model1["DM1"].quantity is None:
                    model1["DM1"].quantity = 0 * model1["DM1"].units
                model1["DM1"].frozen = False

                if "DM2" not in model1.params:
                    # Use model1 here (as in the CM branch): the input model may
                    # not have a DispersionDM component at all.
                    model1.components["DispersionDM"].add_param(
                        model1["DM1"].new_param(2)
                    )
                if model1["DM2"].quantity is None:
                    model1["DM2"].quantity = 0 * model1["DM2"].units
                model1["DM2"].frozen = False

                if model1["DMEPOCH"].quantity is None:
                    model1["DMEPOCH"].quantity = model1["PEPOCH"].quantity

                dmwavex_setup(model1, Tspan, n_freqs=nharms_dwx, freeze_params=False)
        elif comp == "CMWaveX":
            nharms_cwx = nharms[jj]
            if nharms_cwx > 0:
                if "ChromaticCM" not in model1.components:
                    model1.add_component(ChromaticCM())
                    model1["TNCHROMIDX"].value = chromatic_index

                model1["CM"].frozen = False
                if model1["CM1"].quantity is None:
                    model1["CM1"].quantity = 0 * model1["CM1"].units
                model1["CM1"].frozen = False

                if "CM2" not in model1.params:
                    model1.components["ChromaticCM"].add_param(
                        model1["CM1"].new_param(2)
                    )
                if model1["CM2"].quantity is None:
                    model1["CM2"].quantity = 0 * model1["CM2"].units
                model1["CM2"].frozen = False

                if model1["CMEPOCH"].quantity is None:
                    model1["CMEPOCH"].quantity = model1["PEPOCH"].quantity

                cmwavex_setup(model1, Tspan, n_freqs=nharms_cwx, freeze_params=False)
        else:
            raise ValueError(f"Unsupported component {comp}.")

    return model1
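
To make the search in `find_optimal_nharms` easier to follow, here is a minimal, PINT-free sketch of the same pattern: build the Cartesian grid of harmonic counts, evaluate a score at each grid point, reshape to one axis per component, and locate the minimum with `np.unravel_index`. The names `toy_aic` and `grid_search` are hypothetical, and `toy_aic` is only an illustrative stand-in for `compute_aic` (an artificial misfit plus a penalty of 2 per free parameter), not the quantity PINT computes.

```python
from itertools import product

import numpy as np


def toy_aic(nharms, n_best=(3, 1)):
    """Hypothetical stand-in for compute_aic.

    The misfit term stops improving once each component has n_best harmonics;
    the penalty term adds 2 per free parameter (assumed here to be one sine
    and one cosine amplitude per harmonic).
    """
    misfit = sum(100.0 * max(nb - n, 0) for n, nb in zip(nharms, n_best))
    n_params = 2 * sum(nharms)
    return misfit + 2 * n_params


def grid_search(aic_func, n_components, nharms_max):
    """Evaluate aic_func on the full grid and return (aics, best_index)."""
    # product() varies the last index fastest, which matches the C-order
    # layout that np.reshape assumes below.
    idxs = list(product(range(nharms_max + 1), repeat=n_components))
    aics_flat = [aic_func(ii) for ii in idxs]
    aics = np.reshape(aics_flat, [nharms_max + 1] * n_components)
    best = np.unravel_index(np.argmin(aics), aics.shape)
    return aics, tuple(int(b) for b in best)


aics, nharms_opt = grid_search(toy_aic, n_components=2, nharms_max=5)
print(aics.shape, nharms_opt)
```

The grid has `(nharms_max + 1) ** n_components` points, so with the defaults in the diff (`nharms_max=45`, three components) the real function would run 46**3 = 97336 fits.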
Should we discuss having a new requirement more broadly? If this is only needed for a subset of tasks, should it be optional?
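
One way the "optional" route could look is sketched below, purely as an illustration and not as anything proposed in this PR: check whether `joblib` is importable and fall back to a serial loop when it is not. The helper names `have_module` and `parallel_map` are hypothetical.

```python
import importlib.util


def have_module(name: str) -> bool:
    """Return True if a top-level module can be imported in this environment."""
    return importlib.util.find_spec(name) is not None


def parallel_map(func, items, n_jobs=1):
    """Map func over items, using joblib only when it is installed."""
    if have_module("joblib"):
        # Imported lazily so that the package itself does not require joblib.
        from joblib import Parallel, delayed

        return Parallel(n_jobs=n_jobs)(delayed(func)(x) for x in items)
    # Serial fallback when the optional dependency is missing.
    return [func(x) for x in items]
```

The trade-off is that users without `joblib` silently get the slow path, so a real implementation would probably want to log a warning in the fallback branch.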
I'm personally fine with adding new requirements if the requirement is pure python and doesn't have a bunch of its own new requirements. `nestle` seems like it is of that (good) type.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this was actually adding
joblib
?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In particular, I wonder about how using
joblib
compares to the use ofconcurrent.futures
? I see some discussion online. It seems like we might want to stick with one library for that functionality.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I will try
concurrent.futures
.
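
For reference, a rough sketch of what the grid evaluation might look like with the standard-library `concurrent.futures` instead of `joblib`. This is an assumption about one possible port, not the code that ended up in the PR. `toy_aic` and `map_grid` are hypothetical names, with `toy_aic` standing in for `compute_aic`. One practical difference: `ProcessPoolExecutor` pickles the callable with the standard `pickle` module, so a `lambda` like the one in the diff would not work there; a module-level function wrapped in `functools.partial` does.

```python
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
from functools import partial
from itertools import product


def toy_aic(nharms, penalty):
    """Hypothetical stand-in for compute_aic(model, toas, ...)."""
    return float(sum((n - 2) ** 2 for n in nharms) + penalty * sum(nharms))


def map_grid(func, idxs, max_workers=None, executor_cls=ProcessPoolExecutor):
    """Evaluate func at every grid point; results keep the order of idxs."""
    with executor_cls(max_workers=max_workers) as pool:
        # Executor.map returns results in input order, so they can be
        # reshaped exactly like the joblib output.
        return list(pool.map(func, idxs))


# Typical use (guarded, because process pools re-import the main module on
# platforms that spawn workers):
#
# if __name__ == "__main__":
#     idxs = list(product(range(4), repeat=2))
#     aics_flat = map_grid(partial(toy_aic, penalty=0.5), idxs)
```

Passing `executor_cls=ThreadPoolExecutor` runs the same code with threads, which is handy for quick checks but will not speed up CPU-bound fits.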