2
2
import numpy as np
3
3
from scipy .stats import uniform
4
4
import matplotlib .pyplot as plt
5
+ import math
5
6
6
7
7
8
def estimate_multivariate_density_w_GMM (samples , name , k_max = 30 , verbose = False , plot = False ):
@@ -56,8 +57,8 @@ def estimate_multivariate_density_w_GMM(samples, name, k_max=30, verbose=False,
56
57
plt .title (f'{ name } ' )
57
58
plt .savefig (f'{ name } .png' )
58
59
if verbose :
59
- print (f"AIC min: k=gmm_best_aic_idx+1, AIC={ aics [gmm_best_aic_idx ]} " )
60
- print (f"BIC min: k=gmm_best_bic_idx+1, AIC={ bics [gmm_best_bic_idx ]} " )
60
+ print (f"AIC min: k={ gmm_best_aic_idx + 1 } , AIC={ aics [gmm_best_aic_idx ]} " )
61
+ print (f"BIC min: k={ gmm_best_bic_idx + 1 } , AIC={ bics [gmm_best_bic_idx ]} " )
61
62
return gmm_best_aic , gmm_best_bic
62
63
63
64
@@ -84,7 +85,6 @@ def kl_divergence_gmm_uniform(gmm, unfiform_prior_bounds, name, n_samples=10**6,
84
85
Samples outside the specified bounds for the uniform distribution are discarded.
85
86
"""
86
87
87
-
88
88
samples = gmm .sample (n_samples )[0 ]
89
89
valid_samples = np .all ([(samples [:, i ] >= r [0 ]) & (samples [:, i ] <= r [1 ]) for i , r in enumerate (unfiform_prior_bounds )], axis = 0 )
90
90
samples = samples [valid_samples ]
@@ -93,4 +93,160 @@ def kl_divergence_gmm_uniform(gmm, unfiform_prior_bounds, name, n_samples=10**6,
93
93
kl_divergence = np .mean (log_gmm_pdf - log_uniform_pdf )
94
94
if verbose :
95
95
print (f"KL divergence of { name } = { kl_divergence } " )
96
- return kl_divergence
96
+ return kl_divergence
97
+
98
+
99
def plot_1D_distributions(sample_arrays, sample_labels, parameter_names, parameter_ranges, parameter_nominals, bins=100, title="1D Parameter Distribution"):
    """
    Plot overlaid 1D histograms of every parameter for several sample sets.

    One subplot is created per parameter, laid out on a near-square grid.
    Each sample array contributes one step histogram per subplot, and the
    nominal value of the parameter is marked with a dashed vertical line.

    Args:
        sample_arrays (list[np.ndarray]): List of 2D sample arrays to be plotted, where each row represents a sample and each column represents a parameter.
        sample_labels (list[str]): List of labels corresponding to each sample array.
        parameter_names (list[str]): List of parameter names (one subplot per name).
        parameter_ranges (list[tuple]): List of (min, max) histogram ranges for each parameter.
        parameter_nominals (list[float]): List of nominal values for each parameter.
        bins (int, optional): Passed to ``Axes.hist`` as ``bins``. Default is 100.
        title (str, optional): Title of the plot. Default is "1D Parameter Distribution".

    Returns:
        matplotlib.figure.Figure: The figure object containing the plotted distributions.

    Raises:
        AssertionError: If the number of sample arrays does not match the number of labels.
        ValueError: If ``parameter_names`` is empty.
    """
    # Raised explicitly instead of via `assert` so the check is not stripped
    # when Python runs with -O; the exception type is unchanged for callers.
    if len(sample_arrays) != len(sample_labels):
        raise AssertionError("Mismatch between number of sample arrays and labels.")

    num_params = len(parameter_names)
    if num_params == 0:
        # Without this guard the grid computation below divides by zero.
        raise ValueError("parameter_names must not be empty.")

    num_cols = math.ceil(math.sqrt(num_params))
    num_rows = math.ceil(num_params / num_cols)

    # squeeze=False guarantees a 2D array of Axes even for a 1x1 grid, so the
    # flatten() below also works when there is only a single parameter.
    fig, axs = plt.subplots(num_rows, num_cols, figsize=(3 * num_cols, 2 * num_rows), squeeze=False)
    axs = axs.flatten()

    for param_idx, param_name in enumerate(parameter_names):
        ax = axs[param_idx]
        for sample_array, label in zip(sample_arrays, sample_labels):
            ax.hist(sample_array[:, param_idx], bins=bins, alpha=0.5, density=True, histtype='step', label=label, range=parameter_ranges[param_idx])
        ax.axvline(parameter_nominals[param_idx], linestyle='--', color='k', linewidth=1)
        ax.set_xlabel(param_name)
        ax.set_ylabel('Density')

    # The grid may contain more cells than parameters; hide the surplus axes
    # instead of leaving empty frames in the figure.
    for unused_ax in axs[num_params:]:
        unused_ax.axis('off')

    # A single legend on the first subplot is enough: labels are shared.
    axs[0].legend()
    fig.suptitle(title)
    # Operate on `fig` directly rather than on pyplot's "current figure".
    fig.tight_layout()
    fig.subplots_adjust(top=0.9)
    return fig
139
+
140
+
141
def plot_2D_corner(sample_arrays, sample_labels, parameter_names, parameter_ranges, parameter_nominals=None, bins=100, title="Corner Plot"):
    """
    Plots a 2D corner plot with 2D density histograms off-diagonal and 1D histograms on the diagonal.

    The grid is ``len(parameter_names)`` x ``len(parameter_names)``. Cells above
    the diagonal are hidden. For every off-diagonal cell the 2D histograms of
    all sample arrays are drawn in order on the same axes with the same
    colormap, so later arrays may obscure earlier ones.

    Args:
        sample_arrays (list[np.ndarray]): List of 2D sample arrays to be plotted, where each row represents a sample and each column represents a parameter.
        sample_labels (list[str]): List of labels corresponding to each sample array.
        parameter_names (list[str]): List of parameter names.
        parameter_ranges (list[tuple]): List of (min, max) ranges for each parameter.
        parameter_nominals (sequence of float, optional): Nominal (reference) values for each parameter. ``None`` or an empty sequence disables the reference lines.
        bins (int or list): Number of bins or a list of bin edges for the histograms.
        title (str, optional): Title of the plot. Default is "Corner Plot".

    Returns:
        matplotlib.figure.Figure: The figure object containing the plotted distributions.

    Raises:
        AssertionError: If the number of sample arrays does not match the number of labels.
        ValueError: If ``parameter_names`` is empty.
    """
    # Raised explicitly instead of via `assert` so the check is not stripped
    # when Python runs with -O; the exception type is unchanged for callers.
    if len(sample_arrays) != len(sample_labels):
        raise AssertionError("Mismatch between number of sample arrays and labels.")

    num_params = len(parameter_names)
    if num_params == 0:
        raise ValueError("parameter_names must not be empty.")

    # Explicit None/length test: plain truthiness raises for NumPy arrays with
    # more than one element, while an empty sequence must still mean "no lines".
    show_nominals = parameter_nominals is not None and len(parameter_nominals) > 0

    # squeeze=False keeps `axs` two-dimensional even for a single parameter,
    # so the axs[row, col] indexing below is always valid.
    fig, axs = plt.subplots(num_params, num_params, figsize=(3 * num_params, 3 * num_params), squeeze=False)

    for row in range(num_params):
        for col in range(num_params):
            ax = axs[row, col]

            # Hide plots in the upper triangle
            if row < col:
                ax.axis('off')
                continue

            # Diagonal: 1D histograms
            if row == col:
                for sample_array, label in zip(sample_arrays, sample_labels):
                    ax.hist(sample_array[:, col], bins=bins, alpha=0.5, density=True, histtype='step', label=label, range=parameter_ranges[col])
                ax.set_xlim(*parameter_ranges[col])
                ax.set_xlabel(parameter_names[col])
                if show_nominals:
                    ax.axvline(parameter_nominals[col], linestyle='--', color='k', linewidth=1)
                continue

            # Off-diagonal: 2D histograms (x = column parameter, y = row parameter).
            # The keyword arguments do not depend on the sample array, so they
            # are built once per cell.
            hist2d_params = {
                "bins": bins,
                "range": [parameter_ranges[col], parameter_ranges[row]],
                "cmap": 'Blues',
                "density": True,
            }
            for sample_array in sample_arrays:
                ax.hist2d(sample_array[:, col], sample_array[:, row], **hist2d_params)
            ax.set_xlim(*parameter_ranges[col])
            ax.set_ylim(*parameter_ranges[row])
            ax.set_xlabel(parameter_names[col])
            ax.set_ylabel(parameter_names[row])
            if show_nominals:
                ax.axvline(parameter_nominals[col], linestyle='--', color='k', linewidth=1)
                ax.axhline(parameter_nominals[row], linestyle='--', color='k', linewidth=1)

    # We set the legend on one of the diagonal plots for compactness
    axs[0, 0].legend(loc='upper right')
    fig.suptitle(title)
    # Operate on `fig` directly rather than on pyplot's "current figure".
    fig.tight_layout()
    fig.subplots_adjust(top=0.95)
    return fig
209
+
210
+
211
def plot_random_sample_predictions(samples, data_gen_func, observed_data, N, gen_func_args=None, title=None):
    """
    Draws N random samples from the given sample array, generates predicted datasets
    for each sample, and plots them against the observed dataset.

    Rows are drawn without replacement using NumPy's global random state
    (``np.random.choice``). Each prediction is drawn as a translucent blue
    line and the observed data as red markers on a single set of axes.

    Args:
        samples (np.ndarray): Sample array where each row is a sample.
        data_gen_func (function): Data generation function called as ``data_gen_func(sample, **gen_func_args)``.
        observed_data (np.ndarray): Observed data set for reference.
        N (int): Number of random samples to be drawn.
        gen_func_args (dict, optional): Additional keyword arguments to be passed to the data generation function.
        title (str, optional): Title of the plot. When omitted (or falsy), a default title mentioning N is used.

    Returns:
        matplotlib.figure.Figure: The figure object containing the plotted predictions and observed data.

    Raises:
        ValueError: If N is not between 1 and the number of rows in ``samples``.
    """
    if gen_func_args is None:
        gen_func_args = {}

    # Check if the number of samples requested is valid
    total_samples = samples.shape[0]
    if N <= 0 or N > total_samples:
        raise ValueError(f"Invalid number of samples requested: { N } . It should be between 1 and { total_samples } .")

    # Select N random samples
    random_samples = samples[np.random.choice(total_samples, N, replace=False)]

    # Hold explicit figure/axes handles: if data_gen_func itself touches
    # pyplot, the implicit "current figure" could change mid-loop and the
    # plot (and the returned figure) would end up being the wrong one.
    fig, ax = plt.subplots(figsize=(12, 6))

    # Generate and plot predicted data for each random sample
    for sample in random_samples:
        predicted_data = data_gen_func(sample, **gen_func_args)
        ax.plot(predicted_data, 'b-', alpha=0.25)

    # Plot observed data for comparison
    ax.plot(observed_data, 'ro', label="Observed Data", alpha=0.45)
    ax.legend()
    if title:
        ax.set_title(f"{ title } ")
    else:
        ax.set_title(f"{ N } Randomly Selected Predictions vs Observed Data")
    return fig
0 commit comments