diff --git a/src/scripts/engine.py b/src/scripts/engine.py index c41d340..f8c2f53 100644 --- a/src/scripts/engine.py +++ b/src/scripts/engine.py @@ -189,6 +189,9 @@ def _eval_step( test_ppl = perplexity(test_avg_ll, self.metadata['num_variables']) self.logger.info(f"[{self.args.dataset}] Epoch {epoch_idx}, Test ppl: {test_ppl:.03f}") self.logger.log_scalar('Test/ppl', test_ppl, step=epoch_idx) + if self._log_distribution: + self.logger.log_best_distribution( + self.model, self.args.discretize, lim=self.metadata['domains'], device=self._device) metrics['best_valid_epoch'] = epoch_idx metrics['best_valid_avg_ll'] = valid_avg_ll metrics['best_valid_std_ll'] = valid_std_ll @@ -357,7 +360,7 @@ def run(self): if self._log_distribution: self.logger.save_array(self.metadata['hmap'], 'gt.npy') - self.logger.log_distribution( + self.logger.log_step_distribution( self.model, self.args.discretize, lim=self.metadata['domains'], device=self._device) # The train loop @@ -391,7 +394,7 @@ def run(self): else (max(1, int(2e-1 * self.args.log_frequency)) if epoch_idx == 2 else self.args.log_frequency)) == 0: if self._log_distribution: - self.logger.log_distribution( + self.logger.log_step_distribution( self.model, self.args.discretize, lim=self.metadata['domains'], device=self._device) opt_counter += 1 if diverged: diff --git a/src/scripts/logger.py b/src/scripts/logger.py index 3d10c85..fa36325 100644 --- a/src/scripts/logger.py +++ b/src/scripts/logger.py @@ -4,12 +4,13 @@ import numpy as np import torch import wandb +from torch.utils.data import DataLoader from torch.utils.tensorboard import SummaryWriter from PIL import Image as pillow from graphics.distributions import bivariate_pmf_heatmap, bivariate_pdf_heatmap -from pcs.models import PC +from pcs.models import PC, TensorizedPC class Logger: @@ -33,6 +34,7 @@ def __init__( wandb_kwargs = dict() self._setup_wandb(wandb_path, **wandb_kwargs) + self._best_distribution = None self._logged_distributions = list() self._logged_wcoords = list() @@ -105,7 +107,21 @@ def log_hparams( if wandb.run: wandb.run.summary.update(metric_dict) - def log_distribution( + def log_best_distribution( + self, + model: PC, + discretized: bool, + lim: Tuple[Tuple[Union[float, int], Union[float, int]], Tuple[Union[float, int], Union[float, int]]], + device: Optional[Union[str, torch.device]] = None + ): + xlim, ylim = lim + if discretized: + dist_hmap = bivariate_pmf_heatmap(model, xlim, ylim, device=device) + else: + dist_hmap = bivariate_pdf_heatmap(model, xlim, ylim, device=device) + self._best_distribution = dist_hmap.astype(np.float32, copy=False) + + def log_step_distribution( self, model: PC, discretized: bool, @@ -121,7 +137,8 @@ def log_distribution( def close(self): if self._logged_distributions: - self.save_array(np.stack(self._logged_distributions, axis=0), 'distribution.npy') + self.save_array(self._best_distribution, 'distbest.npy') + self.save_array(np.stack(self._logged_distributions, axis=0), 'diststeps.npy') if self._logged_wcoords: self.save_array(np.stack(self._logged_wcoords, axis=0), 'wcoords.npy') if self._tboard_writer is not None: diff --git a/src/scripts/plots/gpt2dist/lines.py b/src/scripts/plots/gpt2dist/lines.py index 0d3f628..a201a5a 100644 --- a/src/scripts/plots/gpt2dist/lines.py +++ b/src/scripts/plots/gpt2dist/lines.py @@ -89,21 +89,19 @@ def filter_dataframe(df: pd.DataFrame, filter_dict: dict) -> pd.DataFrame: if rows_to_keep is not None: for r, vs in rows_to_keep.items(): model_df = model_df[model_df[r].isin(vs)] - model_df.to_csv(f'{model_name}-gpt2commongen-results.csv', index=None) + model_df.to_csv(f'gpt2commongen-results-{model_name}.csv', index=None) group_model_df = model_df.groupby(by=['init_method', 'learning_rate']) should_label = True metrics[model_name] = defaultdict(list) for j, hparam_df in group_model_df: ms, ps = hparam_df[metric].tolist(), hparam_df['num_components'].tolist() - if len(np.unique(ms)) < num_points or len(np.unique(ps)) < num_points: + if len(np.unique(ps)) < num_points: continue ms = np.array(ms, dtype=np.float64) ps = np.array(ps, dtype=np.int64) sort_indices = np.argsort(ps) ps = ps[sort_indices] ms = ms[sort_indices] - ps = ps[:num_points] - ms = ms[:num_points] for p, m in zip(ps.tolist(), ms.tolist()): metrics[model_name][p].append(m) if not args.median: @@ -130,7 +128,7 @@ def filter_dataframe(df: pd.DataFrame, filter_dict: dict) -> pd.DataFrame: assert len(model_names) == 2 model_a, model_b = model_names spvalues = defaultdict(lambda: defaultdict(dict)) - for ts in ['mannwithneyu', 'ttest']: + for ts in ['mannwithneyu']: for al in ['greater']: for k in sorted(metrics[model_a].keys() & metrics[model_b].keys()): lls_a = metrics[model_a][k] @@ -141,7 +139,7 @@ def filter_dataframe(df: pd.DataFrame, filter_dict: dict) -> pd.DataFrame: s, p = stats.ttest_ind(lls_b, lls_a, alternative=al) else: assert False, "Should not happen :(" - spvalues[ts][al][k] = (round(s, 3), round(p, 4)) + spvalues[ts][al][k] = (round(s, 3), round(p, 7)) print(spvalues) #if args.train: diff --git a/src/scripts/plots/ring/distgif.py b/src/scripts/plots/ring/distgif.py index 98b0c97..f92c016 100644 --- a/src/scripts/plots/ring/distgif.py +++ b/src/scripts/plots/ring/distgif.py @@ -18,14 +18,14 @@ parser.add_argument('--drop-last-frames', type=int, default=0, help="The number of last frames to drop") """ -python -m scripts.plots.ring.distgif checkpoints/loss-landscape --drop-last-frames 164 +python -m scripts.plots.ring.distgif checkpoints/gaussian-ring --drop-last-frames 224 """ if __name__ == '__main__': def to_rgb(x: np.ndarray, cmap: cm.ScalarMappable, cmap_transform: Callable[[np.ndarray], np.ndarray]) -> np.ndarray: #x = x[51:-50, 51:-50] - x = (cmap.to_rgba(cmap_transform(x)) * 255.0).astype(np.uint8)[..., :-1] + x = (cmap.to_rgba(cmap_transform(x.T)) * 255.0).astype(np.uint8)[..., :-1] if x.shape[0] != args.gif_size or x.shape[1] != args.gif_size: x = cv2.resize(x, dsize=(args.gif_size, args.gif_size), interpolation=cv2.INTER_CUBIC) return x @@ -49,7 +49,7 @@ def to_rgb_image(x: np.ndarray) -> pillow.Image: ] gt_array = np.load(os.path.join(checkpoint_paths[0], 'gt.npy')) gt_array = np.broadcast_to(gt_array, (args.max_num_frames, gt_array.shape[0], gt_array.shape[1])) - arrays = map(lambda p: np.load(os.path.join(p, 'distribution.npy')), checkpoint_paths) + arrays = map(lambda p: np.load(os.path.join(p, 'diststeps.npy')), checkpoint_paths) if args.drop_last_frames > 0: arrays = map(lambda a: a[:-args.drop_last_frames], arrays) arrays = [gt_array] + list(arrays) diff --git a/src/scripts/plots/ring/ellipses.py b/src/scripts/plots/ring/ellipses.py index 417cc84..afa58d6 100644 --- a/src/scripts/plots/ring/ellipses.py +++ b/src/scripts/plots/ring/ellipses.py @@ -23,13 +23,13 @@ def ring_kde() -> np.ndarray: - splits = load_artificial_dataset('ring', num_samples=50000, dtype=np.dtype(np.float64)) + splits = load_artificial_dataset('ring', num_samples=500, dtype=np.dtype(np.float64)) data = np.concatenate(splits, axis=0) scaler = StandardScaler() data = scaler.fit_transform(data) data_min, data_max = np.min(data, axis=0), np.max(data, axis=0) - drange = np.abs(data_max - data_min) - data_min, data_max = (data_min - drange * 0.05), (data_max + drange * 0.05) + #drange = np.abs(data_max - data_min) + #data_min, data_max = (data_min - drange * 0.05), (data_max + drange * 0.05) xlim, ylim = [(data_min[i], data_max[i]) for i in range(len(data_min))] return kde_samples_hmap(data, xlim=xlim, ylim=ylim, bandwidth=0.16) @@ -52,7 +52,7 @@ def load_mixture( metadata, _ = setup_data_loaders('ring', 'datasets', 1, num_samples=10000) model: TensorizedPC = setup_model(model_name, metadata, num_components=num_components) exp_id = exp_id_fmt.format(num_components, learning_rate, batch_size) - filepath = os.path.join(args.checkpoint_path, 'gaussian-ring', 'ring', model_name, exp_id, 'model.pt') + filepath = os.path.join(args.checkpoint_path, 'ring', model_name, exp_id, 'model.pt') state_dict = torch.load(filepath, map_location='cpu') model.load_state_dict(state_dict['weights']) return model @@ -66,7 +66,7 @@ def load_pdf( batch_size: int = 64 ) -> np.ndarray: exp_id = exp_id_fmt.format(num_components, learning_rate, batch_size) - filepath = os.path.join(args.checkpoint_path, 'gaussian-ring', 'ring', model, exp_id, 'pdf.npy') + filepath = os.path.join(args.checkpoint_path, 'ring', model, exp_id, 'distbest.npy') return np.load(filepath) @@ -155,41 +155,43 @@ def plot_pdf( models = [ 'MonotonicPC', 'MonotonicPC', - 'BornPC', - 'MAF', - 'NSF' + 'BornPC' ] - num_components = [2, 16, 2, 128, 128] - learning_rates = [5e-3, 5e-3, 4e-3, 1e-3, 1e-3] + num_components = [2, 16, 2] + learning_rates = [5e-3, 5e-3, 1e-3] exp_id_formats = [ 'RGran_R1_K{}_D1_Lcp_OAdam_LR{}_BS{}_IU', 'RGran_R1_K{}_D1_Lcp_OAdam_LR{}_BS{}_IU', - 'RGran_R1_K{}_D1_Lcp_OAdam_LR{}_BS{}_IN', - 'K{}_OAdam_LR{}_BS{}', - 'K{}_OAdam_LR{}_BS{}' + 'RGran_R1_K{}_D1_Lcp_OAdam_LR{}_BS{}_IN' ] + truth_pdf = ring_kde() + mixtures = [ load_mixture(m, eif, nc, lr) - for m, eif, nc, lr in zip(models[:3], exp_id_formats, num_components, learning_rates) - ] + [None, None] + for m, eif, nc, lr in zip(models, exp_id_formats, num_components, learning_rates) + ] pdfs = [ load_pdf(m, eif, nc, lr) for m, eif, nc, lr in zip(models, exp_id_formats, num_components, learning_rates) ] - vmax = np.max(pdfs) + vmax = np.max([truth_pdf] + pdfs) vmin = 0.0 metadata, _ = setup_data_loaders('ring', 'datasets', 1, num_samples=10000) os.makedirs(os.path.join('figures', 'gaussian-ring'), exist_ok=True) - for idx, (p, pdf, m, nc) in enumerate(zip(mixtures, pdfs, models, num_components)): + data_pdfs = [(None, truth_pdf, 'Ground Truth', -1)] + list(zip(mixtures, pdfs, models, num_components)) + for idx, (p, pdf, m, nc) in enumerate(data_pdfs): setup_tueplots(1, 1, rel_width=0.2, hw_ratio=1.0) fig, ax = plt.subplots(1, 1) - title = f"{format_model_name(m, nc)}" if args.title else None + if args.title: + title = f"{format_model_name(m, nc)}" if p is not None else m + else: + title = None plot_pdf(pdf, metadata, ax=ax, vmin=vmin, vmax=vmax) if p is not None: diff --git a/src/scripts/plots/ring/pdfs.py b/src/scripts/plots/ring/pdfs.py index f7e694d..218d900 100644 --- a/src/scripts/plots/ring/pdfs.py +++ b/src/scripts/plots/ring/pdfs.py @@ -2,48 +2,128 @@ import os.path from typing import Optional +import matplotlib as mpl import numpy as np +from scipy import special +import torch from matplotlib import pyplot as plt from sklearn.preprocessing import StandardScaler from datasets.loaders import load_artificial_dataset from graphics.distributions import kde_samples_hmap from graphics.utils import setup_tueplots +from pcs.models import TensorizedPC, PC, MonotonicPC +from scripts.utils import setup_model, setup_data_loaders parser = argparse.ArgumentParser( - description="PDFs plotter" + description="PDFs and ellipses plotter" ) parser.add_argument('--checkpoint-path', default='checkpoints', type=str, help="The checkpoints path") +parser.add_argument('--show-ellipses', default=False, action='store_true', + help="Whether to show the Gaussian components as ellipses") parser.add_argument('--title', default=False, action='store_true', help="Whether to show a title") -def ring_kde() -> np.ndarray: - splits = load_artificial_dataset('ring', num_samples=50000, dtype=np.dtype(np.float64)) - data = np.concatenate(splits, axis=0) - scaler = StandardScaler() - data = scaler.fit_transform(data) - data_min, data_max = np.min(data, axis=0), np.max(data, axis=0) - drange = np.abs(data_max - data_min) - data_min, data_max = (data_min - drange * 0.05), (data_max + drange * 0.05) - xlim, ylim = [(data_min[i], data_max[i]) for i in range(len(data_min))] - return kde_samples_hmap(data, xlim=xlim, ylim=ylim, bandwidth=0.16) - - def format_model_name(m: str, num_components: int) -> str: if m == 'MonotonicPC': - return f"GMM ($K \! = \! {num_components}$)" + return f"GMM ($K \!\! = \!\! {num_components}$)" elif m == 'BornPC': - return f"NGMM ($K \! = \! {num_components}$)" + return f"NGMM ($K \!\! = \!\! {num_components}$)" return m -def load_pdf(model: str, exp_id: str) -> np.ndarray: - filepath = os.path.join(args.checkpoint_path, 'gaussian-ring', 'ring', model, exp_id, 'pdf.npy') +def load_mixture( + model_name: str, + exp_id_fmt: str, + num_components: int, + learning_rate: float = 5e-3, + batch_size: int = 64 +) -> TensorizedPC: + metadata, _ = setup_data_loaders('ring', 'datasets', 1) + model: TensorizedPC = setup_model(model_name, metadata, num_components=num_components) + exp_id = exp_id_fmt.format(num_components, learning_rate, batch_size) + filepath = os.path.join(args.checkpoint_path, 'ring', model_name, exp_id, 'model.pt') + state_dict = torch.load(filepath, map_location='cpu') + model.load_state_dict(state_dict['weights']) + return model + + +def load_pdf( + model: str, + exp_id_fmt: str, + num_components, + learning_rate: float = 5e-3, + batch_size: int = 64 +) -> np.ndarray: + exp_id = exp_id_fmt.format(num_components, learning_rate, batch_size) + filepath = os.path.join(args.checkpoint_path, 'ring', model, exp_id, 'distbest.npy') return np.load(filepath) -def plot_pdf(pdf: np.ndarray, ax: plt.Axes, vmin: Optional[float] = None, vmax: Optional[float] = None): +def plot_mixture_ellipses(mixture: TensorizedPC, ax: plt.Axes): + mus = mixture.input_layer.mu[0, :, 0, :].detach().numpy() + covs = np.exp(2 * mixture.input_layer.log_sigma[0, :, 0, :].detach().numpy()) + num_components = mus.shape[-1] + mix_weights = mixture.layers[-1].weight[0, 0].detach().numpy() + if isinstance(mixture, MonotonicPC): + mix_weights = special.softmax(mix_weights) + mix_weights = mix_weights / np.max(mix_weights) + else: + mix_weights = -mix_weights / np.max(np.abs(mix_weights)) + for i in range(num_components): + if np.abs(mix_weights[i]) < 0.1: + continue + mu = mus[:, i] + cov = np.diag(covs[:, i]) + v, w = np.linalg.eigh(cov) + v = 2.0 * np.sqrt(2.0) * np.sqrt(v) + ell = mpl.patches.Ellipse(mu, v[0], v[1], linewidth=0.7, fill=False) + ell_dot = mpl.patches.Circle(mu, radius=0.03, fill=True) + ell.set_color('red') + if isinstance(mixture, MonotonicPC): + #ell.set_alpha(mix_weights[i]) + #ell_dot.set_alpha(0.5 * mix_weights[i]) + ell_dot.set_color('red') + ell.set_alpha(0.775) + ell_dot.set_alpha(0.775) + else: + if mix_weights[i] <= 0.0: + #ell.set_alpha(min(1.0, 3 * np.abs(mix_weights[i]))) + ell.set_linestyle('dotted') + #ell_dot.set_alpha(0.5 * np.abs(mix_weights[i])) + #%ell_dot.set_color('red') + else: + #ell.set_alpha(mix_weights[i]) + #ell_dot.set_alpha(0.5 * mix_weights[i]) + ell_dot.set_color('red') + ell.set_alpha(0.85) + ell_dot.set_alpha(0.85) + ax.add_artist(ell) + ax.add_artist(ell_dot) + + +def plot_pdf( + pdf: np.ndarray, + metadata: dict, + ax: plt.Axes, vmin: + Optional[float] = None, + vmax: Optional[float] = None +): + pdf = pdf[8:-8, 8:-8] + + x_lim = metadata['domains'][0] + y_lim = metadata['domains'][1] + x_lim = (x_lim[0] * np.sqrt(2.0), x_lim[1] * np.sqrt(2.0)) + y_lim = (y_lim[0] * np.sqrt(2.0), y_lim[1] * np.sqrt(2.0)) + + x_lim = (min(x_lim[0], y_lim[0]), max(x_lim[1], y_lim[1])) + y_lim = (min(x_lim[0], y_lim[0]), max(x_lim[1], y_lim[1])) + xi, yi = np.mgrid[range(pdf.shape[0]), range(pdf.shape[1])] + xi = (xi + 0.5) / pdf.shape[0] + yi = (yi + 0.5) / pdf.shape[1] + xi = xi * (x_lim[1] - x_lim[0]) + x_lim[0] + yi = yi * (y_lim[1] - y_lim[0]) + y_lim[0] ax.pcolormesh(xi, yi, pdf, vmin=vmin, vmax=vmax) @@ -51,54 +131,67 @@ def plot_pdf(pdf: np.ndarray, ax: plt.Axes, vmin: Optional[float] = None, vmax: args = parser.parse_args() models = [ - 'Ground Truth', 'MonotonicPC', 'MonotonicPC', - 'BornPC', + 'BornPC' ] - exp_ids = [ - '', - 'RGran_R1_K2_D1_Lcp_OAdam_LR0.005_BS64_IU', - 'RGran_R1_K16_D1_Lcp_OAdam_LR0.005_BS64_IU', - 'RGran_R1_K2_D1_Lcp_OAdam_LR0.005_BS64_IN', + num_components = [2, 16, 2] + learning_rates = [5e-3, 5e-3, 1e-3] + + exp_id_formats = [ + 'RGran_R1_K{}_D1_Lcp_OAdam_LR{}_BS{}_IU', + 'RGran_R1_K{}_D1_Lcp_OAdam_LR{}_BS{}_IU', + 'RGran_R1_K{}_D1_Lcp_OAdam_LR{}_BS{}_IN' + ] + + truth_pdf = np.load( + os.path.join(args.checkpoint_path, 'ring', models[0], + exp_id_formats[0].format(num_components[0], learning_rates[0], 64), 'gt.npy') + ) + # truth_pdf = ring_kde() + + mixtures = [ + load_mixture(m, eif, nc, lr) + for m, eif, nc, lr in zip(models, exp_id_formats, num_components, learning_rates) ] - truth_pdf = ring_kde() pdfs = [ - load_pdf(m, eid) - for m, eid in zip(models[1:], exp_ids[1:]) + load_pdf(m, eif, nc, lr) + for m, eif, nc, lr in zip(models, exp_id_formats, num_components, learning_rates) ] - pdfs = [truth_pdf] + pdfs - vmax = np.max(truth_pdf) + vmax = np.max([truth_pdf] + pdfs) vmin = 0.0 + metadata, _ = setup_data_loaders('ring', 'datasets', 1, num_samples=10000) + os.makedirs(os.path.join('figures', 'gaussian-ring'), exist_ok=True) - for idx, (p, m, eid) in enumerate(zip(pdfs, models, exp_ids)): + data_pdfs = [(None, truth_pdf, 'Ground Truth', -1)] + list(zip(mixtures, pdfs, models, num_components)) + for idx, (p, pdf, m, nc) in enumerate(data_pdfs): setup_tueplots(1, 1, rel_width=0.2, hw_ratio=1.0) fig, ax = plt.subplots(1, 1) - if eid: - if 'PC' in m: - num_components = int(eid.split('_')[2][1:]) - else: - num_components = int(eid.split('_')[0][1:]) - title = f"{format_model_name(m, num_components)}" + if args.title: + title = f"{format_model_name(m, nc)}" if p is not None else m else: - title = m + title = None - if idx == 0: - vmax = None - args.title = True + plot_pdf(pdf, metadata, ax=ax, vmin=vmin, vmax=vmax) + if p is not None and args.show_ellipses: + plot_mixture_ellipses(p, ax=ax) - plot_pdf(p, vmin=vmin, vmax=vmax, ax=ax) + x_lim = metadata['domains'][0] + y_lim = metadata['domains'][1] + x_lim = (x_lim[0] * np.sqrt(2.0), x_lim[1] * np.sqrt(2.0)) + y_lim = (y_lim[0] * np.sqrt(2.0), y_lim[1] * np.sqrt(2.0)) + #lims = (min(x_lim[0], y_lim[0]), max(x_lim[1], y_lim[1])) + ax.set_xlim(*x_lim) + ax.set_ylim(*y_lim) ax.set_xticks([]) ax.set_yticks([]) ax.set_aspect(1.0) if args.title: ax.set_title(title, rotation='vertical', x=-0.1, y=0.41, va='center') - if idx == 0: - plt.savefig(os.path.join('figures', 'gaussian-ring', f'pdfs-gt.png'), dpi=1200) - else: - plt.savefig(os.path.join('figures', 'gaussian-ring', f'pdfs-{idx}.png'), dpi=1200) + filename = f'pdfs-ellipses-{idx}.png' if args.show_ellipses else f'pdfs-{idx}.png' + plt.savefig(os.path.join('figures', 'gaussian-ring', filename), dpi=1200)