diff --git a/coolest/api/plotting.py b/coolest/api/plotting.py index 8df2e0d..a5ea28c 100644 --- a/coolest/api/plotting.py +++ b/coolest/api/plotting.py @@ -430,6 +430,8 @@ class ParametersPlotter(object): List of bool to toggle errorbars on point-estimate values colors : list, optional List of pyplot color names to associate to each coolest model. + linestyles : list, optional + List of pyplot linesyles to associate to each coolest model. add_multivariate_margin_samples : bool, optional If True, will append to the list of compared models a new chain that is resampled from the multi-variate normal distribution, @@ -462,6 +464,8 @@ def __init__(self, parameter_id_list, coolest_objects, coolest_directories=None, if colors is None: colors = plt.cm.turbo(np.linspace(0.1, 0.9, self.num_models)) self.colors = colors + if linestyles is None: + linestyles = ['-']*self.num_models self.linestyles = linestyles self.ref_linestyles = ['--', ':', '-.', '-'] self.ref_markers = ['s', '^', 'o', '*'] @@ -597,7 +601,7 @@ def get_margin_mcsamples_getdist(self): def plot_triangle_getdist(self, filled_contours=True, angles_range=None, linewidth_hist=2, linewidth_cont=2, linewidth_margin=4, marker_linewidth=2, marker_size=15, - axes_labelsize=12, legend_fontsize=14, + axes_labelsize=None, legend_fontsize=None, **subplot_kwargs): """Corner array of subplots using getdist.triangle_plot method. @@ -634,8 +638,10 @@ def plot_triangle_getdist(self, filled_contours=True, angles_range=None, # Make the plot g = plots.get_subplot_plotter(**subplot_kwargs) - g.settings.legend_fontsize = legend_fontsize - g.settings.axes_labelsize = axes_labelsize + if legend_fontsize is not None: + g.settings.legend_fontsize = legend_fontsize + if axes_labelsize is not None: + g.settings.axes_labelsize = axes_labelsize g.triangle_plot( self._mcsamples, params=self.parameter_id_list, @@ -685,8 +691,9 @@ def plot_triangle_getdist(self, filled_contours=True, angles_range=None, return g def plot_rectangle_getdist(self, x_param_ids, y_param_ids, subplot_size=1, - legend_ncol=None, filled_contours=True, linewidth=1, - marker_size=15, **subplot_kwargs): + legend_ncol=None, legend_fontsize=None, + filled_contours=True, linewidth=1, + marker_size=15, axes_labelsize=None, **subplot_kwargs): """Array of (2D contours) subplots using getdist.rectangle_plot method. Parameters @@ -711,14 +718,18 @@ def plot_rectangle_getdist(self, x_param_ids, y_param_ids, subplot_size=1, if legend_ncol is None: legend_ncol = 3 # Make the plot - g = plots.get_subplot_plotter(**subplot_kwargs) + g = plots.get_subplot_plotter(**subplot_kwargs) + if legend_fontsize is not None: + g.settings.legend_fontsize = legend_fontsize + if axes_labelsize is not None: + g.settings.axes_labelsize = axes_labelsize g.rectangle_plot(x_param_ids, y_param_ids, roots=self._mcsamples, - legend_labels=legend_labels, - filled=filled_contours, - colors=colors, - legend_ncol=legend_ncol, - line_args=line_args, - contour_colors=self.colors) + filled=filled_contours, + colors=colors, + legend_ncol=legend_ncol, + legend_labels=legend_labels, + line_args=line_args, + contour_colors=self.colors) for k in range(len(self.ref_values)): g.add_param_markers(self.ref_values_markers[k], color='black', ls=self.ref_linestyles[k], lw=linewidth) for j, key_x in enumerate(x_param_ids): @@ -729,7 +740,9 @@ def plot_rectangle_getdist(self, x_param_ids, y_param_ids, subplot_size=1, g.subplots[i, j].scatter(val_x,val_y,s=marker_size,facecolors='black',color='black',marker=self.ref_markers[k]) return g - def plot_1d_getdist(self, num_columns=None, legend_ncol=None, linewidth=1, **subplot_kwargs): + def plot_1d_getdist(self, num_columns=None, legend_ncol=None, + legend_fontsize=None, axes_labelsize=None, + linewidth=1, **subplot_kwargs): """Array of 1D histogram subplots using getdist.plots_1d method. Parameters @@ -757,7 +770,11 @@ def plot_1d_getdist(self, num_columns=None, legend_ncol=None, linewidth=1, **sub if legend_ncol is None: legend_ncol = 3 # Make the plot - g = plots.get_subplot_plotter(**subplot_kwargs) + g = plots.get_subplot_plotter(**subplot_kwargs) + if legend_fontsize is not None: + g.settings.legend_fontsize = legend_fontsize + if axes_labelsize is not None: + g.settings.axes_labelsize = axes_labelsize g.plots_1d(self._mcsamples, params=self.parameter_id_list, legend_labels=legend_labels,