Skip to content

Commit

Permalink
Add more docstrings
Browse files Browse the repository at this point in the history
  • Loading branch information
bencebecsy committed May 24, 2023
1 parent bfa7b30 commit b565a71
Show file tree
Hide file tree
Showing 2 changed files with 116 additions and 43 deletions.
82 changes: 45 additions & 37 deletions QuickCW/QuickMCMCUtils.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,42 +380,42 @@ def get_param_names(pta):
class ChainParams():
"""store basic parameters the govern the evolution of the mcmc chain
:param T_max:
:param n_chain:
:param n_block_status_update:
:param n_int_block:
:param n_update_fisher:
:param save_every_n:
:param fisher_eig_downsample:
:param T_ladder:
:param includeCW:
:param prior_recovery:
:param verbosity:
:param freq_bounds:
:param gwb_comps:
:param cos_gwtheta_bounds:
:param gwphi_bounds:
:param de_history_size:
:param thin_de:
:param log_fishers:
:param log_mean_likelihoods:
:param savefile:
:param thin:
:param samples_precision:
:param save_first_n_chains:
:param prior_draw_prob:
:param de_prob:
:param fisher_prob:
:param rn_emp_dist_file:
:param dist_jump_weight:
:param rn_jump_weight:
:param gwb_jump_weight:
:param common_jump_weight:
:param all_jump_weight:
:param fix_rn:
:param zero_rn:
:param fix_gwb:
:param zero_gwb:
:param T_max: Maximum temperature of PT ladder
:param n_chain: Number of PT chains
:param n_block_status_update: Number of blocks between status updates
:param n_int_block: Number of iterations in a block [1_000]
:param n_update_fisher: Number of iterations between Fisher updates [100_000]
:param save_every_n: Number of iterations between saving intermediate results (needs to be intiger multiple of n_int_block) [10_000]
:param fisher_eig_downsample: Multiplier for how much less to do more expensive updates to fisher eigendirections for red noise and common parameters compared to diagonal elements [10]
:param T_ladder: Temperature ladder; if None, geometrically spaced ladder is made with n_chain chains reaching T_max [None]
:param includeCW: If False, we are not including the CW in the likelihood (good for testing) [True]
:param prior_recovery: If True, likelihood is set to a constant (good for testing the prior recovery of the MCMC) [False]
:param verbosity: Parameter indicating how much info to print (higher value means more prints) [1]
:param freq_bounds: Lower and upper prior bounds on the GW frequency of the CW; np.nan lower bound is automatically turned into one over the observation time [[np.nan, 1.e-07]]
:param gwb_comps: Number of frequency components to model in the GWB [14]
:param cos_gwtheta_bounds: Prior bounds on the cosine of the GW theta sky location parameter (useful e.g. for targeted searches) [[-1,1]]
:param gwphi_bounds: Prior bounds on the the GW phi sky location parameter (useful e.g. for targeted searches) [[0,2*np.pi]]
:param de_history_size: Size of the differential evolution buffer
:param thin_de: How much to thin samples for the DE buffer
:param log_fishers: --
:param log_mean_likelihoods: --
:param savefile: File name to save the results to, if None, no results are saved [None]
:param thin: How much to thin the samples by for saving [100]
:param samples_precision: Precision to use for the saved samples [np.single]
:param save_first_n_chains: Number of PT chains to save [1]
:param prior_draw_prob: Probability of prior draws [0.1]
:param de_prob: Probability of DE jumps [0.6]
:param fisher_prob: Probability of fisher updates [0.3]
:param rn_emp_dist_file: Filename with empirical distribution to use for per psr RN, if None, do not do empirical distribution jumps [None]
:param dist_jump_weight: Weight if jumps changing pulsar distances [0.2]
:param rn_jump_weight: Weight of jumps changing RN parameters [0.3]
:param gwb_jump_weight: Weight of jumps changing GWB parameters [0.1]
:param common_jump_weight: Weight of jumps changing common CW shape parameters (sky location, frequency, chirp mass) [0.2]
:param all_jump_weight: Weight of jumps changing all parameters [0.2]
:param fix_rn: If True, we fix per psr RN parameters to the value it starts at [False]
:param zero_rn: If True, we fix per psr RN amplitude to a very low value effectively turning it off [False]
:param fix_gwb: If True, we fix GWB parameters to the value it starts at [False]
:param zero_gwb: If True, we fix GWB amplitude to a very low value effectively turning it off [False]
"""

def __init__(self, T_max: float, n_chain: int, n_block_status_update: int, n_int_block: int = 1000,
Expand Down Expand Up @@ -532,7 +532,15 @@ def __init__(self, T_max: float, n_chain: int, n_block_status_update: int, n_int


class MCMCChain():
"""store the miscellaneous objects needed to manage the mcmc chain"""
"""store the miscellaneous objects needed to manage the mcmc chain
:param chain_params: ChainParams object
:param psrs: List of enterprise pulsar objects
:param pta: enterprise PTA object
:param max_toa: Latest TOA in any pulsar in the array
:param noisedict: Noise dictionary
:param ti: Time after initialization got from time.perf_counter()
"""
def __init__(self,chain_params,psrs,pta,max_toa,noisedict,ti):
#set up fast likelihoods
self.chain_params = chain_params
Expand Down
77 changes: 71 additions & 6 deletions QuickCW/QuickMTHelpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,13 @@
#version using multiple try mcmc (based on Table 6 of https://vixra.org/pdf/1712.0244v3.pdf)
#@profile
def do_intrinsic_update_mt(mcc, itrb):
"""do the intrinsic update using the multiple try mcmc algorithm"""
"""do the intrinsic update using the multiple try mcmc algorithm
:param mcc: MCMCChain onject
:param itrb: Index within saved values (as opposed to block index itri or overall index itrn)
:return mcc.FLI_swap: FastLikeInfo object
"""
Npsr = mcc.x0s[0].Npsr
Ts = mcc.chain_params.Ts
for j in range(mcc.n_chain):
Expand Down Expand Up @@ -514,7 +520,22 @@ def do_intrinsic_update_mt(mcc, itrb):


def do_mt_step(mcc,j,itrb,new_point,samples_current,FLI_mem_save,recompute_rn,log_proposal_ratio):
"""compute the multiple tries and chose a sample"""
"""compute the multiple tries and chose a sample
:param mcc: MCMCChain onject
:param j: Index of PT chain
:param itrb: Index within saved values (as opposed to block index itri or overall index itrn)
:param new_point: Proposed new point (with new shape parameters)
:param samples_current: Current point in parameter space
:param FLI_mem_save: Parts of FLI object saved to memory
:param recompute_rn: If True, recompute everything needed to go to new RN parameters
:param log_proposal_ratio: Log of the proposal ratio needed to calculate acceptance probability
:return log_acc_ratio: Log of acceptance probability
:return chosen_trial: Index of chosen trial
:return sample_choose: Parameters of the chosen trial
:return log_Ls[chosen_trial]: Log likelihood of the chosen trial
"""
Ts = mcc.chain_params.Ts

log_prior_old = CWFastPrior.get_lnprior(samples_current, mcc.FPI)
Expand Down Expand Up @@ -609,7 +630,19 @@ def do_mt_step(mcc,j,itrb,new_point,samples_current,FLI_mem_save,recompute_rn,lo

@njit(parallel=True)
def get_mt_weights(x0_extras, FLI_use, Ts, log_posterior_old,tries,log_prior_news):
"""Helper function to quickly return multiple tries and their likelihoods fo MTMCMC"""
"""Helper function to quickly return multiple tries and their likelihoods fo MTMCMC
:param x0_extras: List of extra CWInfo objects for parallelizing multiple try
:param FLI_use: FastLikeInfo object
:param Ts: List of PT temperatures
:param log_posterior_old: Log posterior at old parameters
:param tries: Parameters at a set of multiple tries for which we want to calculate the weights
:param log_prior_news: Log prior values at propose new points
:return mt_weights: Multiple try weights
:return log_Ls: Log likelihoods
:return log_mt_norm_shift: Amount to shift the multiple try weights (helps with using floating point precision efficiently)
"""
#NOTE isfinite does not work with fastmath enabled
#set up needed arrays
log_mt_weights = np.zeros(cm.n_multi_try)
Expand Down Expand Up @@ -644,7 +677,18 @@ def get_mt_weights(x0_extras, FLI_use, Ts, log_posterior_old,tries,log_prior_new

@njit()
def add_rn_eig_jump(scale_eig0,scale_eig1,new_point,rn_base,idx_rn,Npsr,all_eigs=False):
"""add a fisher eigenvalue jump to the red noise parameters in place"""
"""add a fisher eigenvalue jump to the red noise parameters in place
:param scale_eig0: Amount to scale jump in gamma values by
:param scale_eig1: Amount to scale in log10_A values by
:param new_point: Parameter values to add RN jump to
:param rn_base: RN values to jump from (usually justa slice of new_point)
:param idx_rn: Indices of new_point containing RN parameters
:param Npsr: Number of pulsars
:param all_eigs: If True, perturb all pulsars' RN, if False, pick randomly [False]
:return new_point: Perturbed parameter values
"""
which_eig = np.random.choice(2, size=Npsr)
jump_sizes = np.random.normal(0., 1.,Npsr)

Expand All @@ -662,7 +706,16 @@ def add_rn_eig_jump(scale_eig0,scale_eig1,new_point,rn_base,idx_rn,Npsr,all_eigs

@njit()
def set_params(sample_set,jumps,fisher_mask,random_draws_from_prior,x0):
"""assign parameters to tries for multiple try mcmc"""
"""assign parameters to tries for multiple try mcmc
:param sample_set: Samples to start from
:param jumps: Precaluclated fisher jumps to use
:param fisher_mask: Mask determining which projection parameters to do fisher jump vs prior draw in
:param random_draws_from_prior: Precalculated prior draws to use
:param x0: CWInfo object
:return ref_tries: 2D array holding samples at multiple trials
"""
ref_tries = np.zeros((cm.n_multi_try, sample_set.size))
#jumps and random_draws_from_prior should give a null jump for the 0th value

Expand All @@ -684,7 +737,19 @@ def set_params(sample_set,jumps,fisher_mask,random_draws_from_prior,x0):

@njit(parallel=True)
def get_ref_mt_weights(x0_extras, FLI_use, Ts, log_posterior_old, chosen_trial,ref_tries,log_prior_refs):
"""Helper function to quickly return multiple tries and their likelihoods fo MTMCMC"""
"""Helper function to quickly return multiple tries and their likelihoods fo MTMCMC
:param x0_extras: List of extra CWInfo objects for parallelizing multiple try
:param FLI_use: FastLikeInfo object
:param Ts: List of PT temperatures
:param log_posterior_old: Log posterior at old parameters
:param chosen_trial: Index of chosen trial
:param ref_tries: Parameters at a set of reference multiple tries for which we want to calculate the weights
:param log_prior_refs: Log prior values at reference points
:return ref_mt_weights: Reference point multiply try weights
:return log_ref_mt_norm_shift: Amount to shift the reference point multiple try weights (helps with using floating point precision efficiently)
"""
#NOTE isfinite does not work with fastmath enabled
#set up needed arrays
log_ref_mt_weights = np.zeros(cm.n_multi_try)
Expand Down

0 comments on commit b565a71

Please sign in to comment.