Skip to content

Commit

Permalink
Merge pull request #37 from siyuan-chen/master
Browse files Browse the repository at this point in the history
#35 adding general and simplified jump proposal draws
  • Loading branch information
stevertaylor authored Dec 18, 2019
2 parents 91a6a07 + 8b99931 commit 4ce5065
Showing 1 changed file with 132 additions and 0 deletions.
132 changes: 132 additions & 0 deletions enterprise_extensions/model_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -552,6 +552,138 @@ def draw_from_signal_prior(self, x, iter, beta):

return q, float(lqxy)

def draw_from_par_prior(self, par_names):
# Preparing and comparing par_names with PTA parameters
par_names = np.atleast_1d(par_names)
par_list = []
name_list = []
for par_name in par_names:
pn_list = [n for n in self.plist if par_name in n]
if pn_list:
par_list.append(pn_list)
name_list.append(par_name)
if not par_list:
raise UserWarning("No parameter prior match found between {} and PTA.object."
.format(par_names))
par_list = np.concatenate(par_list,axis=None)

def draw(x, iter, beta):
"""Prior draw function generator for custom par_names.
par_names: list of strings
The function signature is specific to PTMCMCSampler.
"""

q = x.copy()
lqxy = 0

# randomly choose parameter
idx_name = np.random.choice(par_list)
idx = self.plist.index(idx_name)

# if vector parameter jump in random component
param = self.params[idx]
if param.size:
idx2 = np.random.randint(0, param.size)
q[self.pmap[str(param)]][idx2] = param.sample()[idx2]

# scalar parameter
else:
q[self.pmap[str(param)]] = param.sample()

# forward-backward jump probability
lqxy = (param.get_logpdf(x[self.pmap[str(param)]]) -
param.get_logpdf(q[self.pmap[str(param)]]))

return q, float(lqxy)

name_string = '_'.join(name_list)
draw.__name__ = 'draw_from_{}_prior'.format(name_string)
return draw

def draw_from_par_log_uniform(self, par_dict):
# Preparing and comparing par_dict.keys() with PTA parameters
par_list = []
name_list = []
for par_name in par_dict.keys():
pn_list = [n for n in self.plist if par_name in n and 'log' in n]
if pn_list:
par_list.append(pn_list)
name_list.append(par_name)
if not par_list:
raise UserWarning("No parameter dictionary match found between {} and PTA.object."
.format(par_dict.keys()))
par_list = np.concatenate(par_list,axis=None)

def draw(x, iter, beta):
"""log uniform prior draw function generator for custom par_names.
par_dict: dictionary with {"par_names":(lower bound,upper bound)}
{ "string":(float,float)}
The function signature is specific to PTMCMCSampler.
"""

q = x.copy()
lqxy = 0

# draw parameter from signal model
idx_name = np.random.choice(par_list)
idx = self.plist.index(idx_name)
q[idx] = np.random.uniform(par_dict[par_name][0],par_dict[par_name][1])

return q, 0

name_string = '_'.join(name_list)
draw.__name__ = 'draw_from_{}_log_uniform'.format(name_string)
return draw

def draw_from_signal(self, signal_names):
# Preparing and comparing signal_names with PTA signals
signal_names = np.atleast_1d(signal_names)
signal_list = []
name_list = []
for signal_name in signal_names:
try:
param_list = self.snames[signal_name]
signal_list.append(param_list)
name_list.append(signal_name)
except:
pass
if not signal_list:
raise UserWarning("No signal match found between {} and PTA.object!"
.format(signal_names))
signal_list = np.concatenate(signal_list,axis=None)

def draw(x, iter, beta):
"""Signal draw function generator for custom signal_names.
signal_names: list of strings
The function signature is specific to PTMCMCSampler.
"""

q = x.copy()
lqxy = 0

# draw parameter from signal model
param = np.random.choice(signal_list)
if param.size:
idx2 = np.random.randint(0, param.size)
q[self.pmap[str(param)]][idx2] = param.sample()[idx2]

# scalar parameter
else:
q[self.pmap[str(param)]] = param.sample()

# forward-backward jump probability
lqxy = (param.get_logpdf(x[self.pmap[str(param)]]) -
param.get_logpdf(q[self.pmap[str(param)]]))

return q, float(lqxy)

name_string = '_'.join(name_list)
draw.__name__ = 'draw_from_{}_signal'.format(name_string)
return draw


def get_global_parameters(pta):
"""Utility function for finding global parameters."""
Expand Down

0 comments on commit 4ce5065

Please sign in to comment.