Skip to content

Commit

Permalink
Add option to show marginalized posteriors in ParametersPlotter (stil…
Browse files Browse the repository at this point in the history
…l WIP)
  • Loading branch information
aymgal committed Nov 16, 2023
1 parent 53610eb commit 155778e
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 20 deletions.
108 changes: 89 additions & 19 deletions coolest/api/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,14 +424,26 @@ class ParametersPlotter(object):
A list of paths matching the coolest files in 'point_estimate_objs'.
ref_coolest_names : array, optional
A list of labels for the models in the 'point_estimate_objs' list. Must have the same order as 'point_estimate_objs'.
posterior_bool_list : list
posterior_bool_list : list, optional
List of bool to toggle errorbars on point-estimate values
colors : list
colors : list, optional
List of pyplot color names to associate to each coolest model.
add_multivariate_margin_samples : bool, optional
If True, will append to the list of compared models
a new chain that is resampled from the multi-variate normal distribution,
where its covariance matrix is computed from the marginalization of
all samples from all models. By default False.
num_samples_per_model_margin : int, optional
Number of samples to (randomly) draw from each model samples to concatenate
before estimating the multi-variate normal marginalization.
"""

np.random.seed(598237) # fix the random seed for reproducibility

def __init__(self, parameter_id_list, coolest_objects, coolest_directories=None, coolest_names=None,
ref_coolest_objects=None, ref_coolest_directories=None, ref_coolest_names=None,
posterior_bool_list=None, colors=None):
posterior_bool_list=None, colors=None,
add_multivariate_margin_samples=False, num_samples_per_model_margin=5_000):
self.parameter_id_list = parameter_id_list
self.coolest_objects = coolest_objects
self.coolest_directories = coolest_directories
Expand All @@ -446,16 +458,22 @@ def __init__(self, parameter_id_list, coolest_objects, coolest_directories=None,
self.num_models = len(self.coolest_objects)
self.num_params = len(self.parameter_id_list)
if colors is None:
colors = plt.cm.turbo(np.linspace(0.1, 0.9, self.num_params))
colors = plt.cm.turbo(np.linspace(0.1, 0.9, self.num_models))
self.colors = colors
self.linestyles = ['--', ':', '-.', '-']
self.markers = ['s', '^', 'o', '*']

self._add_margin_samples = add_multivariate_margin_samples
self._ns_per_model_margin = num_samples_per_model_margin
self._color_margin = 'black'
self._label_margin = "Marginalized multivariate samples"

# self.posterior_bool_list = posterior_bool_list
# self.param_lens, self.param_source = util.split_lens_source_params(
# self.coolest_objects, self.coolest_names, lens_light=False)

def init_getdist(self, shift_sample_list=None, settings_mcsamples=None):
def init_getdist(self, shift_sample_list=None, settings_mcsamples=None,
add_multivariate_margin_samples=False):
"""Initializes the getdist plotter.
Parameters
Expand Down Expand Up @@ -492,6 +510,7 @@ def init_getdist(self, shift_sample_list=None, settings_mcsamples=None):
point_estimates.append(values)

mcsamples = []
samples_margin, weights_margin = None, None
for i in range(self.num_models):
chain_file = os.path.join(self.coolest_directories[i],self.coolest_objects[i].meta["chain_file_name"]) # Here get the chain file path for each coolest object

Expand Down Expand Up @@ -529,20 +548,49 @@ def init_getdist(self, shift_sample_list=None, settings_mcsamples=None):
if shift_sample_list[i] is not None:
for param_id, value in shift_sample_list[i].items():
sample_par_values[:, self.parameter_id_list.index(param_id)] += value
print(f"INFO: posterior for parameter '{param_id}' from model '{self.coolest_names[i]}' "
f"has been shifted by {value}.")
logging.info(f"posterior for parameter '{param_id}' from model '{self.coolest_names[i]}' "
f"has been shifted by {value}.")

# Clean-up the probability weights
mypost = np.array(samples['probability_weights'])
min_non_zero = np.min(mypost[np.nonzero(mypost)])
sample_prob_weight = np.where(mypost<min_non_zero,min_non_zero,mypost)
sample_prob_weight = np.where(mypost<min_non_zero, min_non_zero, mypost)
#sample_prob_weight = mypost

# Create MCSamples object
mysample = MCSamples(samples=sample_par_values,names=self.parameter_id_list,labels=labels,settings=settings_mcsamples)
mysample = MCSamples(samples=sample_par_values, names=self.parameter_id_list,
labels=labels, settings=settings_mcsamples)
mysample.reweightAddingLogLikes(-np.log(sample_prob_weight))
mcsamples.append(mysample)

if self._add_margin_samples: # concatenate samples if required
# draw sample indices
ns_tot = sample_par_values.shape[0]
if self._ns_per_model_margin > ns_tot:
logging.warning(f"The number of samples for model '{self.coolest_names[i]}' "
f"is smaller than the number of samples needed to perform marginalization!")
ns_draw = ns_tot
else:
ns_draw = self._ns_per_model_margin
indices = np.random.choice(np.arange(ns_tot, dtype=int), size=ns_draw, replace=False)
# concatenate the drawn samples and their weights
if i == 0:
samples_margin = np.copy(sample_par_values[indices, :])
weights_margin = np.copy(sample_prob_weight[indices])
else:
samples_margin = np.concatenate([samples_margin, sample_par_values[indices, :]], axis=0)
weights_margin = np.concatenate([weights_margin, sample_prob_weight[indices]])

if self._add_margin_samples:
# samples_margin = util.resample_multivariate_normal(
# samples_margin, num_samples=10_000 #, ddof=0, aweights=weights_margin,
# )
mysample_margin = MCSamples(samples=samples_margin, names=self.parameter_id_list,
labels=self._label_margin,
settings=settings_mcsamples)
mysample_margin.reweightAddingLogLikes(-np.log(weights_margin))
mcsamples.append(mysample_margin)

self.mcsamples = mcsamples
self.ref_values = point_estimates
self.ref_markers = [dict(zip(self.parameter_id_list, values)) for values in self.ref_values]
Expand Down Expand Up @@ -571,16 +619,20 @@ def plot_triangle_getdist(self, subplot_size=1, filled_contours=True, angles_ran
GetDistPlotter
Instance of GetDistPlotter corresponding to the figure
"""
line_args, contours_lws, contour_ls, colors, legend_labels \
= self._prepare_getdist_plot(linewidth_hist, lw_cont=linewidth_cont)

# Make the plot
g = plots.get_subplot_plotter(subplot_size=subplot_size)
g.triangle_plot(self.mcsamples,
params=self.parameter_id_list,
legend_labels=self.coolest_names,
legend_labels=legend_labels,
filled=filled_contours,
colors=self.colors,
line_args=[{'ls':'-', 'lw': linewidth_hist, 'color': c} for c in self.colors],
colors=colors,
line_args=line_args,
contour_colors=self.colors,
contours_lws=linewidth_cont)
contours_lws=contours_lws,
contour_ls=contour_ls)

# Add marker lines and points
for k in range(0, len(self.ref_values)):
Expand Down Expand Up @@ -638,16 +690,18 @@ def plot_rectangle_getdist(self, x_param_ids, y_param_ids, subplot_size=1,
GetDistPlotter
Instance of GetDistPlotter corresponding to the figure
"""
line_args, _, _, colors, legend_labels = self._prepare_getdist_plot(linewidth)

if legend_ncol is None:
legend_ncol = 3
# Make the plot
g = plots.get_subplot_plotter(subplot_size=subplot_size)
g.rectangle_plot(x_param_ids, y_param_ids, roots=self.mcsamples,
legend_labels=self.coolest_names,
legend_labels=legend_labels,
filled=filled_contours,
colors=self.colors,
colors=colors,
legend_ncol=legend_ncol,
line_args=[{'ls':'-', 'lw': linewidth, 'color': c} for c in self.colors],
line_args=line_args,
contour_colors=self.colors)
for k in range(len(self.ref_markers)):
g.add_param_markers(self.ref_markers[k], color='black', ls=self.linestyles[k], lw=linewidth)
Expand Down Expand Up @@ -680,6 +734,8 @@ def plot_1d_getdist(self, subplot_size=1, num_columns=None, legend_ncol=None, li
GetDistPlotter
Instance of GetDistPlotter corresponding to the figure
"""
line_args, _, _, colors, legend_labels = self._prepare_getdist_plot(linewidth)

if num_columns is None:
num_columns = self.num_models//2+1
if legend_ncol is None:
Expand All @@ -688,10 +744,10 @@ def plot_1d_getdist(self, subplot_size=1, num_columns=None, legend_ncol=None, li
g = plots.get_subplot_plotter(subplot_size=subplot_size)
g.plots_1d(self.mcsamples,
params=self.parameter_id_list,
legend_labels=self.coolest_names,
colors=self.colors,
legend_labels=legend_labels,
colors=colors,
share_y=True,
line_args=[{'ls':'-', 'lw': linewidth, 'color': c} for c in self.colors],
line_args=line_args,
nx=num_columns, legend_ncol=legend_ncol,
)
for k in range(len(self.ref_values)):
Expand Down Expand Up @@ -787,6 +843,20 @@ def plotting_routine(self, param_dict, idx_file=0):
plt.show()
return f, ax

def _prepare_getdist_plot(self, lw, lw_cont=None):
line_args = [{'ls':'-', 'lw': lw, 'color': c} for c in self.colors]
lw_conts = [lw_cont]*self.num_models
ls_conts = ['-']*self.num_models
legend_labels = copy.deepcopy(self.coolest_names)
colors = copy.deepcopy(self.colors)
if self._add_margin_samples:
line_args.append({'ls': '-.', 'lw': lw+1, 'alpha': 0.8, 'color': self._color_margin})
ls_conts.append('-.')
if lw_cont is not None: lw_conts.append(lw_cont+1)
legend_labels.append(self._label_margin)
colors.append(self._color_margin)
return line_args, lw_conts, ls_conts, colors, legend_labels

# def plot_corner(parameter_id_list,
# chain_objs, chain_dirs, chain_names=None,
# point_estimate_objs=None, point_estimate_dirs=None, point_estimate_names=None,
Expand Down
11 changes: 10 additions & 1 deletion coolest/api/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,4 +438,13 @@ def find_all_lens_lines(coordinates, composable_lens):
crit_lines = find_critical_lines(coordinates, mag_map)
caustics = find_caustics(crit_lines, composable_lens)
return crit_lines, caustics



def resample_multivariate_normal(samples, num_samples=5_000, **kwargs_cov):
"""Resample following multi-variate normal distribution"""
mean = np.mean(samples, axis=0)
cov = np.cov(samples.T, **kwargs_cov)
num_params = samples.shape[1]
resampled = np.random.multivariate_normal(
mean=mean, cov=cov, size=(int(num_samples/num_params), num_params)).reshape((-1, num_params))
return resampled

0 comments on commit 155778e

Please sign in to comment.