diff --git a/enterprise_extensions/model_utils.py b/enterprise_extensions/model_utils.py index ac0d8fd5..01bd2aeb 100644 --- a/enterprise_extensions/model_utils.py +++ b/enterprise_extensions/model_utils.py @@ -552,6 +552,138 @@ def draw_from_signal_prior(self, x, iter, beta): return q, float(lqxy) + def draw_from_par_prior(self, par_names): + # Preparing and comparing par_names with PTA parameters + par_names = np.atleast_1d(par_names) + par_list = [] + name_list = [] + for par_name in par_names: + pn_list = [n for n in self.plist if par_name in n] + if pn_list: + par_list.append(pn_list) + name_list.append(par_name) + if not par_list: + raise UserWarning("No parameter prior match found between {} and PTA.object." + .format(par_names)) + par_list = np.concatenate(par_list,axis=None) + + def draw(x, iter, beta): + """Prior draw function generator for custom par_names. + par_names: list of strings + + The function signature is specific to PTMCMCSampler. + """ + + q = x.copy() + lqxy = 0 + + # randomly choose parameter + idx_name = np.random.choice(par_list) + idx = self.plist.index(idx_name) + + # if vector parameter jump in random component + param = self.params[idx] + if param.size: + idx2 = np.random.randint(0, param.size) + q[self.pmap[str(param)]][idx2] = param.sample()[idx2] + + # scalar parameter + else: + q[self.pmap[str(param)]] = param.sample() + + # forward-backward jump probability + lqxy = (param.get_logpdf(x[self.pmap[str(param)]]) - + param.get_logpdf(q[self.pmap[str(param)]])) + + return q, float(lqxy) + + name_string = '_'.join(name_list) + draw.__name__ = 'draw_from_{}_prior'.format(name_string) + return draw + + def draw_from_par_log_uniform(self, par_dict): + # Preparing and comparing par_dict.keys() with PTA parameters + par_list = [] + name_list = [] + for par_name in par_dict.keys(): + pn_list = [n for n in self.plist if par_name in n and 'log' in n] + if pn_list: + par_list.append(pn_list) + name_list.append(par_name) + if not par_list: + raise UserWarning("No parameter dictionary match found between {} and PTA.object." + .format(par_dict.keys())) + par_list = np.concatenate(par_list,axis=None) + + def draw(x, iter, beta): + """log uniform prior draw function generator for custom par_names. + par_dict: dictionary with {"par_names":(lower bound,upper bound)} + { "string":(float,float)} + + The function signature is specific to PTMCMCSampler. + """ + + q = x.copy() + lqxy = 0 + + # draw parameter from signal model + idx_name = np.random.choice(par_list) + idx = self.plist.index(idx_name) + q[idx] = np.random.uniform(par_dict[par_name][0],par_dict[par_name][1]) + + return q, 0 + + name_string = '_'.join(name_list) + draw.__name__ = 'draw_from_{}_log_uniform'.format(name_string) + return draw + + def draw_from_signal(self, signal_names): + # Preparing and comparing signal_names with PTA signals + signal_names = np.atleast_1d(signal_names) + signal_list = [] + name_list = [] + for signal_name in signal_names: + try: + param_list = self.snames[signal_name] + signal_list.append(param_list) + name_list.append(signal_name) + except: + pass + if not signal_list: + raise UserWarning("No signal match found between {} and PTA.object!" + .format(signal_names)) + signal_list = np.concatenate(signal_list,axis=None) + + def draw(x, iter, beta): + """Signal draw function generator for custom signal_names. + signal_names: list of strings + + The function signature is specific to PTMCMCSampler. + """ + + q = x.copy() + lqxy = 0 + + # draw parameter from signal model + param = np.random.choice(signal_list) + if param.size: + idx2 = np.random.randint(0, param.size) + q[self.pmap[str(param)]][idx2] = param.sample()[idx2] + + # scalar parameter + else: + q[self.pmap[str(param)]] = param.sample() + + # forward-backward jump probability + lqxy = (param.get_logpdf(x[self.pmap[str(param)]]) - + param.get_logpdf(q[self.pmap[str(param)]])) + + return q, float(lqxy) + + name_string = '_'.join(name_list) + draw.__name__ = 'draw_from_{}_signal'.format(name_string) + return draw + def get_global_parameters(pta): """Utility function for finding global parameters."""