diff --git a/docs/tutorials/index.md b/docs/tutorials/index.md
index 9dc73de5..2bc52ee5 100644
--- a/docs/tutorials/index.md
+++ b/docs/tutorials/index.md
@@ -32,6 +32,13 @@ on a series of books from the [pg19](https://huggingface.co/datasets/pg19) datas
 Compares a host of `posteriors` methods (highlighting the easy exchangeability)
 on a sentiment analysis task adapted from the [Hugging Face tutorial](https://huggingface.co/docs/transformers/training#train-in-native-pytorch).
+[](https://github.com/normal-computing/posteriors/tree/main/examples/imdb)
+
+- [`imdb`](https://github.com/normal-computing/posteriors/tree/main/examples/imdb): Investigates the [cold posterior effect](https://proceedings.mlr.press/v119/wenzel20a/wenzel20a.pdf) for a range of approximate
+Bayesian methods with a CNN-LSTM model on [IMDB](https://www.tensorflow.org/api_docs/python/tf/keras/datasets/imdb/load_data)
+data, with some interesting takeaways.
+
+
 [](https://github.com/normal-computing/posteriors/blob/main/examples/continual_regression.ipynb)

 - [`continual_regression`](https://github.com/normal-computing/posteriors/blob/main/examples/continual_regression.ipynb):
diff --git a/examples/README.md b/examples/README.md
index 07c4b4c9..42ea63b3 100644
--- a/examples/README.md
+++ b/examples/README.md
@@ -4,6 +4,9 @@ This directory contains examples of how to use the `posteriors` package.
 - [`continual_lora`](continual_lora/): Uses `posteriors.laplace.diag_fisher` to avoid catastrophic forgetting in fine-tuning [Llama-2-7b](https://huggingface.co/meta-llama/Llama-2-7b-hf) on a series of books from the [pg19](https://huggingface.co/datasets/pg19) dataset.
+- [`imdb`](imdb/): Investigates the [cold posterior effect](https://proceedings.mlr.press/v119/wenzel20a/wenzel20a.pdf)
+for a range of approximate Bayesian methods on [IMDB](https://www.tensorflow.org/api_docs/python/tf/keras/datasets/imdb/load_data)
+data.
 - [`yelp`](yelp/): Compares a host of `posteriors` methods (highlighting the easy exchangeability) on a sentiment analysis task adapted from the [Hugging Face tutorial](https://huggingface.co/docs/transformers/training#train-in-native-pytorch).
 - [`continual_regression`](continual_regression.ipynb): [Variational continual learning](https://arxiv.org/abs/1710.10628)
diff --git a/examples/imdb/README.md b/examples/imdb/README.md
new file mode 100644
index 00000000..783f5a40
--- /dev/null
+++ b/examples/imdb/README.md
@@ -0,0 +1,119 @@
+# Cold posterior effect on IMDB data
+
+We investigate a wide range of `posteriors` methods for a [CNN-LSTM](https://proceedings.mlr.press/v119/wenzel20a/wenzel20a.pdf) model on the [IMDB dataset](https://huggingface.co/datasets/stanfordnlp/imdb).
+We run each method for a variety of temperatures, that is, targeting the tempered
+posterior distribution $p(\theta \mid \mathcal{D})^{\frac{1}{T}}$ for $T \geq 0$, thus
+investigating the so-called [cold posterior effect](https://arxiv.org/abs/2008.05912),
+where improved predictive performance has been observed for $T<1$.
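+
+Concretely, writing $p(\theta)$ for the prior and $p(\mathcal{D} \mid \theta)$ for the
+likelihood, tempering simply rescales the log joint:
+
+$$
+p(\theta \mid \mathcal{D})^{\frac{1}{T}} \propto \exp\left(\frac{1}{T}\left(\log p(\theta) + \log p(\mathcal{D} \mid \theta)\right)\right),
+$$
+
+so $T=1$ recovers the standard Bayesian posterior and $T \to 0$ concentrates around the
+MAP estimate.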
+
+We observe significantly improved performance of Bayesian techniques over
+gradient descent, and notably that the cold posterior effect is markedly more
+prominent for Gaussian approximations than for SGMCMC variants.
+
+## Methods
+
+We train the CNN-LSTM model with 2.7m parameters on 25k reviews from the IMDB dataset
+for binary classification of positive and negative reviews. We then use `posteriors` to
+swap between the following methods:
+
+- **map**: Maximum a posteriori (MAP) estimate of the parameters. Simple optimization
+using [AdamW](https://arxiv.org/abs/1711.05101) via `posteriors.torchopt`.
+- **laplace.diag_fisher**: Use the MAP estimate as the basis for a Laplace approximation
+using a diagonal empirical Fisher approximation to the precision matrix. Additionally use
+`posteriors.linearized_forward_diag` to assess a [linearized version](https://arxiv.org/abs/2106.14806)
+of the posterior predictive distribution.
+- **laplace.diag_ggn**: Same as above but using the diagonal Gauss-Newton Fisher matrix,
+which is [equivalent to the conditional Fisher information matrix](https://arxiv.org/abs/1412.1193)
+(integrated over labels rather than using the empirical distribution). Again we assess
+both traditional and linearized Laplace.
+- **vi.diag**: Variational inference with a diagonal Gaussian approximation, optimized
+using [AdamW](https://arxiv.org/abs/1711.05101); we also assess a linearized variant.
+- **sgmcmc.sghmc (Serial)**: [Stochastic gradient Hamiltonian Monte Carlo](https://arxiv.org/abs/1506.04696)
+(SGHMC) with a Monte Carlo approximation to the posterior collected along a single
+trajectory (after removing a burn-in period).
+- **sgmcmc.sghmc (Parallel)**: Same as above but running 15 trajectories in parallel and
+only collecting the final state of each trajectory.
+
+
+All methods are run for a variety of temperatures $T \in \{0.03, 0.1, 0.3, 1.0, 3.0\}$.
+Each training run lasts 30 epochs, except for serial SGHMC which uses 60 epochs to collect 27
+samples from a single trajectory. Each method and temperature is run 5 times with
+different random seeds, except for parallel SGHMC, where we run 35 seeds and
+then bootstrap 5 ensembles of size 15. In all cases we use a diagonal Gaussian prior
+with all variances set to 1/40.
+
+
+## Results
+
+We plot the test loss for each method and temperature in Figure 1.
+
+*Figure 1. Test loss on IMDB data (lower is better) for varying temperatures.*
+
+There are a few takeaways from the results:
+- Bayesian methods (VI and SGMCMC) can significantly improve over gradient descent (MAP;
+we also trained the MLE, which severely overfits and was omitted from the plot for clarity).
+- Regular Laplace does not perform well; linearization helps somewhat, with the GGN variant
+outperforming the empirical Fisher.
+- In contrast, linearization is detrimental for VI, which we posit is because the VI
+training acts in parameter space without knowledge of the linearization.
+- There is a strong cold posterior effect for the Gaussian methods (VI + Laplace), but only a
+very mild cold posterior effect for the non-Gaussian Bayesian methods (SGHMC).
+- Parallel SGHMC outperforms serial SGHMC, and there is also evidence that both outperform a
+[deep ensemble](https://arxiv.org/abs/1612.01474) (which is obtained with parallel SGHMC and
+$T=0$).
+
+
+## Data
+
+We download the IMDB dataset using [keras.datasets.imdb](https://www.tensorflow.org/api_docs/python/tf/keras/datasets/imdb/load_data).
+
+
+## Model
+
+We use the CNN-LSTM model from [Wenzel et al.](https://proceedings.mlr.press/v119/wenzel20a/wenzel20a.pdf),
+which consists of an embedding layer, a convolutional layer, ReLU activation, a max-pooling layer,
+an LSTM layer, and a final dense layer. In total there are 2.7m parameters. We use a diagonal
+Gaussian prior with all variances set to 1/40.
+
+
+## Code structure
+
+- `lstm.py`: Simple, custom code for an LSTM layer that composes with `torch.func`.
+- `model.py`: Specification of the CNN-LSTM model.
+- `data.py`: Functions to load the IMDB data using `keras`.
+- `train.py`: General script for training the model with `posteriors` methods,
+which can easily be swapped.
+- `configs/`: Configuration files (Python files) for the different methods.
+- `train_runner.sh`: Bash script to run a series of training jobs.
+- `combine_states_serial.py`: Combines parameter states from a single SGHMC trajectory
+for ensembled forward calls.
+- `combine_states_parallel.py`: Combines final parameter states from multiple SGHMC
+trajectories for ensembled forward calls.
+- `test.py`: Calculates metrics on the IMDB test data for a given method.
+- `test_runner.sh`: Bash script to run a series of testing jobs.
+- `plot.py`: Generates Figure 1.
+- `utils.py`: Small utility functions for logging.
+
+A single training run can be launched from the root directory:
+```bash
+PYTHONPATH=. python examples/imdb/train.py --config examples/imdb/configs/laplace_fisher.py --temperature 0.1 --seed 0 --device cuda:0
+```
+or by configuring multiple runs in `train_runner.sh` and running:
+```bash
+bash examples/imdb/train_runner.sh
+```
+Similarly, testing can be run from the root directory:
+```bash
+PYTHONPATH=. 
python examples/imdb/test.py --config examples/imdb/configs/laplace_diag_fisher.py --seed 0 --device cuda:0 +``` +or by configuring settings in `test_runner.sh` and running: +```bash +bash examples/imdb/test_runner.sh +``` + diff --git a/examples/imdb/combine_states_parallel.py b/examples/imdb/combine_states_parallel.py new file mode 100644 index 00000000..29727e1e --- /dev/null +++ b/examples/imdb/combine_states_parallel.py @@ -0,0 +1,65 @@ +import pickle +import torch +import os +from optree import tree_map +import shutil + + +temperatures = [0.03, 0.1, 0.3, 1.0, 3.0] + + +# SGHMC Parallel +base = "examples/imdb/results/sghmc_parallel/sghmc_parallel" +save_dir_base = "examples/imdb/results/sghmc_parallel" + + +bootstrap_seeds = [ + (torch.multinomial(torch.ones(35), 15, replacement=False) + 1).tolist() + for _ in range(5) +] + + +for temp in temperatures: + temp_str = str(temp).replace(".", "-") + + load_paths = [] + + for k, seeds in enumerate(bootstrap_seeds): + for seed in seeds: + spec_base = base + spec_base += f"_seed{seed}" + spec_base += f"_temp{temp_str}/" + + match_str = ( + "state_" if seed is None else "state" + ) # don't need last save for serial + + load_paths += [ + spec_base + file for file in os.listdir(spec_base) if match_str in file + ] + + save_dir = save_dir_base + f"_seed{k+1}" + f"_temp{temp_str}/" + + if not os.path.exists(save_dir): + os.makedirs(save_dir) + shutil.copy(spec_base + "config.py", save_dir + "/config.py") + + with open(f"{save_dir}/paths.txt", "w") as f: + f.write("\n".join(load_paths)) + + # Load states + states = [pickle.load(open(d, "rb")) for d in load_paths] + + # Delete auxiliary info + for s in states: + del s.aux + + # Move states to cpu + states = tree_map(lambda x: x.detach().to("cpu"), states) + + # Combine states + combined_state = tree_map(lambda *x: torch.stack(x), *states) + + # Save state + with open(f"{save_dir}state.pkl", "wb") as f: + pickle.dump(combined_state, f) diff --git a/examples/imdb/combine_states_serial.py b/examples/imdb/combine_states_serial.py new file mode 100644 index 00000000..681ab23d --- /dev/null +++ b/examples/imdb/combine_states_serial.py @@ -0,0 +1,64 @@ +import pickle +import torch +import os +from optree import tree_map +import shutil + + +temperatures = [0.03, 0.1, 0.3, 1.0, 3.0] + +# SGHMC Serial +bases = [ + f"examples/imdb/results/sghmc_serial_seed{repeat_seed}" + for repeat_seed in range(1, 6) +] +seeds = [None] + + +for base in bases: + for temp in temperatures: + temp_str = str(temp).replace(".", "-") + + load_paths = [] + + for seed in seeds: + spec_base = base + + if seed is not None: + spec_base += f"_seed{seed}" + + spec_base += f"_temp{temp_str}/" + + match_str = ( + "state_" if seed is None else "state" + ) # don't need last save for serial + + load_paths += [ + spec_base + file for file in os.listdir(spec_base) if match_str in file + ] + + save_dir = base + f"_temp{temp_str}/" + + if not os.path.exists(save_dir): + os.makedirs(save_dir) + shutil.copy(spec_base + "config.py", save_dir + "/config.py") + + with open(f"{save_dir}/paths.txt", "w") as f: + f.write("\n".join(load_paths)) + + # Load states + states = [pickle.load(open(d, "rb")) for d in load_paths] + + # Delete auxiliary info + for s in states: + del s.aux + + # Move states to cpu + states = tree_map(lambda x: x.detach().to("cpu"), states) + + # Combine states + combined_state = tree_map(lambda *x: torch.stack(x), *states) + + # Save state + with open(f"{save_dir}state.pkl", "wb") as f: + pickle.dump(combined_state, f) diff --git 
a/examples/imdb/configs/laplace_fisher.py b/examples/imdb/configs/laplace_fisher.py new file mode 100644 index 00000000..40d42e6c --- /dev/null +++ b/examples/imdb/configs/laplace_fisher.py @@ -0,0 +1,72 @@ +import torch +import posteriors +from optree import tree_map + +name = "laplace_fisher" +save_dir = "examples/imdb/results/" + name +params_dir = "examples/imdb/results/map" # directory to load state containing initialisation params + +prior_sd = torch.inf +batch_size = 32 +burnin = None +save_frequency = None + +method = posteriors.laplace.diag_fisher +config_args = {} # arguments for method.build (aside from log_posterior) +log_metrics = {} # dict containing names of metrics as keys and their paths in state as values +display_metric = "loss" # metric to display in tqdm progress bar + +log_frequency = 100 # frequency at which to log metrics +log_window = 30 # window size for moving average + +n_test_samples = 50 +n_linearised_test_samples = 10000 +epsilon = 1e-3 # small value to avoid division by zero in to_sd_diag + + +def to_sd_diag(state, temperature=1.0): + return tree_map(lambda x: torch.sqrt(temperature / (x + epsilon)), state.prec_diag) + + +def forward(model, state, batch, temperature=1.0): + x, _ = batch + sd_diag = to_sd_diag(state, temperature) + + sampled_params = posteriors.diag_normal_sample( + state.params, sd_diag, (n_test_samples,) + ) + + def model_func(p, x): + return torch.func.functional_call(model, p, x) + + logits = torch.vmap(model_func, in_dims=(0, None))(sampled_params, x).transpose( + 0, 1 + ) + return logits + + +def forward_linearized(model, state, batch, temperature=1.0): + x, _ = batch + sd_diag = to_sd_diag(state, temperature) + + def model_func_with_aux(p, x): + return torch.func.functional_call(model, p, x), torch.tensor([]) + + lin_mean, lin_chol, _ = posteriors.linearized_forward_diag( + model_func_with_aux, + state.params, + x, + sd_diag, + ) + + samps = torch.randn( + lin_mean.shape[0], + n_linearised_test_samples, + lin_mean.shape[1], + device=lin_mean.device, + ) + lin_logits = lin_mean.unsqueeze(1) + samps @ lin_chol.transpose(-1, -2) + return lin_logits + + +forward_dict = {"Laplace EF": forward, "Laplace EF Linearized": forward_linearized} diff --git a/examples/imdb/configs/laplace_ggn.py b/examples/imdb/configs/laplace_ggn.py new file mode 100644 index 00000000..f06760ea --- /dev/null +++ b/examples/imdb/configs/laplace_ggn.py @@ -0,0 +1,72 @@ +import torch +import posteriors +from optree import tree_map + +name = "laplace_ggn" +save_dir = "examples/imdb/results/" + name +params_dir = "examples/imdb/results/map" # directory to load state containing initialisation params + +prior_sd = torch.inf +batch_size = 32 +burnin = None +save_frequency = None + +method = posteriors.laplace.diag_ggn +config_args = {} # arguments for method.build (aside from log_posterior) +log_metrics = {} # dict containing names of metrics as keys and their paths in state as values +display_metric = "loss" # metric to display in tqdm progress bar + +log_frequency = 100 # frequency at which to log metrics +log_window = 30 # window size for moving average + +n_test_samples = 50 +n_linearised_test_samples = 10000 +epsilon = 1e-3 # small value to avoid division by zero in to_sd_diag + + +def to_sd_diag(state, temperature=1.0): + return tree_map(lambda x: torch.sqrt(temperature / (x + epsilon)), state.prec_diag) + + +def forward(model, state, batch, temperature=1.0): + x, _ = batch + sd_diag = to_sd_diag(state, temperature) + + sampled_params = posteriors.diag_normal_sample( 
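+        # Monte Carlo sample of n_test_samples parameter vectors from the diagonal Gaussian Laplace approximation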
+ state.params, sd_diag, (n_test_samples,) + ) + + def model_func(p, x): + return torch.func.functional_call(model, p, x) + + logits = torch.vmap(model_func, in_dims=(0, None))(sampled_params, x).transpose( + 0, 1 + ) + return logits + + +def forward_linearized(model, state, batch, temperature=1.0): + x, _ = batch + sd_diag = to_sd_diag(state, temperature) + + def model_func_with_aux(p, x): + return torch.func.functional_call(model, p, x), torch.tensor([]) + + lin_mean, lin_chol, _ = posteriors.linearized_forward_diag( + model_func_with_aux, + state.params, + x, + sd_diag, + ) + + samps = torch.randn( + lin_mean.shape[0], + n_linearised_test_samples, + lin_mean.shape[1], + device=lin_mean.device, + ) + lin_logits = lin_mean.unsqueeze(1) + samps @ lin_chol.transpose(-1, -2) + return lin_logits + + +forward_dict = {"Laplace GGN": forward, "Laplace GGN Linearized": forward_linearized} diff --git a/examples/imdb/configs/map.py b/examples/imdb/configs/map.py new file mode 100644 index 00000000..02813e11 --- /dev/null +++ b/examples/imdb/configs/map.py @@ -0,0 +1,36 @@ +import posteriors +import torchopt +import torch + +name = "map" +save_dir = "examples/imdb/results/" + name +params_dir = None # directory to load state containing initialisation params + + +prior_sd = (1 / 40) ** 0.5 +batch_size = 32 +burnin = None +save_frequency = None + + +method = posteriors.torchopt +config_args = { + "optimizer": torchopt.adamw(lr=1e-3, maximize=True) +} # arguments for method.build (aside from log_posterior) +log_metrics = { + "log_post": "loss", +} # dict containing names of metrics as keys and their paths in state as values +display_metric = "log_post" # metric to display in tqdm progress bar + + +log_frequency = 100 # frequency at which to log metrics +log_window = 30 # window size for moving average + + +def forward(model, state, batch): + x, _ = batch + logits = torch.func.functional_call(model, state.params, x) + return logits.unsqueeze(1) + + +forward_dict = {"MAP": forward} diff --git a/examples/imdb/configs/mle.py b/examples/imdb/configs/mle.py new file mode 100644 index 00000000..bfb9e5a1 --- /dev/null +++ b/examples/imdb/configs/mle.py @@ -0,0 +1,36 @@ +import posteriors +import torchopt +import torch + +name = "mle" +save_dir = "examples/imdb/results/" + name +params_dir = None # directory to load state containing initialisation params + + +prior_sd = torch.inf +batch_size = 32 +burnin = None +save_frequency = None + + +method = posteriors.torchopt +config_args = { + "optimizer": torchopt.adamw(lr=1e-3, maximize=True) +} # arguments for method.build (aside from log_posterior) +log_metrics = { + "log_post": "loss", +} # dict containing names of metrics as keys and their paths in state as values +display_metric = "log_post" # metric to display in tqdm progress bar + + +log_frequency = 100 # frequency at which to log metrics +log_window = 30 # window size for moving average + + +def forward(model, state, batch): + x, _ = batch + logits = torch.func.functional_call(model, state.params, x) + return logits.unsqueeze(1) + + +forward_dict = {"MLE": forward} diff --git a/examples/imdb/configs/sghmc_parallel.py b/examples/imdb/configs/sghmc_parallel.py new file mode 100644 index 00000000..e7ed818c --- /dev/null +++ b/examples/imdb/configs/sghmc_parallel.py @@ -0,0 +1,42 @@ +import posteriors +import torch + +name = "sghmc_parallel" +save_dir = "examples/imdb/results/sghmc_parallel/" + name +params_dir = None # directory to load state containing initialisation params + +prior_sd = (1 / 40) ** 0.5 
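+# i.e. a diagonal Gaussian prior with all variances set to 1/40, as described in the README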
+batch_size = 32 +burnin = None +save_frequency = None + +lr = 1e-1 + +method = posteriors.sgmcmc.sghmc +config_args = { + "lr": lr, + "alpha": 1.0, + "beta": 0.0, + "temperature": None, # None temperature gets set by train.py + "momenta": 0.0, +} # arguments for method.build (aside from log_posterior) +log_metrics = { + "log_post": "log_posterior", +} # dict containing names of metrics as keys and their paths in state as values +display_metric = "log_post" # metric to display in tqdm progress bar + +log_frequency = 100 # frequency at which to log metrics +log_window = 30 # window size for moving average + + +def forward(model, state, batch): + x, _ = batch + + def model_func(p, x): + return torch.func.functional_call(model, p, x) + + logits = torch.vmap(model_func, in_dims=(0, None))(state.params, x).transpose(0, 1) + return logits + + +forward_dict = {"Parallel SGHMC": forward} diff --git a/examples/imdb/configs/sghmc_serial.py b/examples/imdb/configs/sghmc_serial.py new file mode 100644 index 00000000..aec7e4e8 --- /dev/null +++ b/examples/imdb/configs/sghmc_serial.py @@ -0,0 +1,43 @@ +import posteriors +import torch + +name = "sghmc_serial" +save_dir = "examples/imdb/results/" + name +params_dir = None # directory to load state containing initialisation params + + +prior_sd = (1 / 40) ** 0.5 +batch_size = 32 +burnin = 20000 +save_frequency = 1000 + +lr = 1e-1 + +method = posteriors.sgmcmc.sghmc +config_args = { + "lr": lr, + "alpha": 1.0, + "beta": 0.0, + "temperature": None, # None temperature gets set by train.py + "momenta": 0.0, +} # arguments for method.build (aside from log_posterior) +log_metrics = { + "log_post": "log_posterior", +} # dict containing names of metrics as keys and their paths in state as values +display_metric = "log_post" # metric to display in tqdm progress bar + +log_frequency = 100 # frequency at which to log metrics +log_window = 30 # window size for moving average + + +def forward(model, state, batch): + x, _ = batch + + def model_func(p, x): + return torch.func.functional_call(model, p, x) + + logits = torch.vmap(model_func, in_dims=(0, None))(state.params, x).transpose(0, 1) + return logits + + +forward_dict = {"Serial SGHMC": forward} diff --git a/examples/imdb/configs/vi.py b/examples/imdb/configs/vi.py new file mode 100644 index 00000000..082ddadb --- /dev/null +++ b/examples/imdb/configs/vi.py @@ -0,0 +1,81 @@ +import posteriors +import torchopt +import torch +from optree import tree_map + +name = "vi" +save_dir = "examples/imdb/results/" + name +params_dir = None # directory to load state containing initialisation params + + +prior_sd = (1 / 40) ** 0.5 +batch_size = 32 +burnin = None +save_frequency = None + +method = posteriors.vi.diag +config_args = { + "optimizer": torchopt.adamw(lr=1e-3), + "temperature": None, # None temperature gets set by train.py + "n_samples": 1, + "stl": True, + "init_log_sds": -3, +} # arguments for method.build (aside from log_posterior) +log_metrics = { + "nelbo": "nelbo", +} # dict containing names of metrics as keys and their paths in state as values +display_metric = "nelbo" # metric to display in tqdm progress bar + +log_frequency = 100 # frequency at which to log metrics +log_window = 30 # window size for moving average + +n_test_samples = 50 +n_linearised_test_samples = 10000 + + +def to_sd_diag(state): + return tree_map(lambda x: x.exp(), state.log_sd_diag) + + +def forward(model, state, batch): + x, _ = batch + sd_diag = to_sd_diag(state) + + sampled_params = posteriors.diag_normal_sample( + state.params, 
sd_diag, (n_test_samples,) + ) + + def model_func(p, x): + return torch.func.functional_call(model, p, x) + + logits = torch.vmap(model_func, in_dims=(0, None))(sampled_params, x).transpose( + 0, 1 + ) + return logits + + +def forward_linearized(model, state, batch): + x, _ = batch + sd_diag = to_sd_diag(state) + + def model_func_with_aux(p, x): + return torch.func.functional_call(model, p, x), torch.tensor([]) + + lin_mean, lin_chol, _ = posteriors.linearized_forward_diag( + model_func_with_aux, + state.params, + x, + sd_diag, + ) + + samps = torch.randn( + lin_mean.shape[0], + n_linearised_test_samples, + lin_mean.shape[1], + device=lin_mean.device, + ) + lin_logits = lin_mean.unsqueeze(1) + samps @ lin_chol.transpose(-1, -2) + return lin_logits + + +forward_dict = {"VI": forward, "VI Linearized": forward_linearized} diff --git a/examples/imdb/data.py b/examples/imdb/data.py new file mode 100644 index 00000000..339655c5 --- /dev/null +++ b/examples/imdb/data.py @@ -0,0 +1,27 @@ +import torch +from torch.utils.data import DataLoader, TensorDataset +from keras.datasets import imdb +from keras.preprocessing.sequence import pad_sequences + + +def load_imdb_dataset(batch_size=32, max_features=20000, max_len=100): + # Load and pad IMDB dataset + (x_train, y_train), (x_test, y_test) = imdb.load_data(num_words=max_features) + x_train = pad_sequences(x_train, maxlen=max_len) + x_test = pad_sequences(x_test, maxlen=max_len) + + # Convert data to PyTorch tensors + train_data = torch.tensor(x_train, dtype=torch.long) + train_labels = torch.tensor(y_train, dtype=torch.long) + test_data = torch.tensor(x_test, dtype=torch.long) + test_labels = torch.tensor(y_test, dtype=torch.long) + + # Create Tensor datasets + train_dataset = TensorDataset(train_data, train_labels) + test_dataset = TensorDataset(test_data, test_labels) + + # Create DataLoaders + train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) + test_dataloader = DataLoader(test_dataset, batch_size=batch_size) + + return train_dataloader, test_dataloader diff --git a/examples/imdb/lstm.py b/examples/imdb/lstm.py new file mode 100644 index 00000000..7804cbc4 --- /dev/null +++ b/examples/imdb/lstm.py @@ -0,0 +1,129 @@ +import torch +import torch.nn as nn + + +class CustomLSTMCell(nn.Module): + def __init__(self, input_size, hidden_size): + super(CustomLSTMCell, self).__init__() + self.input_size = input_size + self.hidden_size = hidden_size + self.forget_gate = nn.Linear(input_size + hidden_size, hidden_size) + self.input_gate = nn.Linear(input_size + hidden_size, hidden_size) + self.cell_gate = nn.Linear(input_size + hidden_size, hidden_size) + self.output_gate = nn.Linear(input_size + hidden_size, hidden_size) + + def forward(self, input, states): + h_prev, c_prev = states + combined = torch.cat((input, h_prev), 1) + + f_t = torch.sigmoid(self.forget_gate(combined)) + i_t = torch.sigmoid(self.input_gate(combined)) + c_tilde = torch.tanh(self.cell_gate(combined)) + c_t = f_t * c_prev + i_t * c_tilde + o_t = torch.sigmoid(self.output_gate(combined)) + h_t = o_t * torch.tanh(c_t) + + return h_t, c_t + + +class CustomLSTM(nn.Module): + def __init__(self, input_size, hidden_size, batch_first=False): + super(CustomLSTM, self).__init__() + self.input_size = input_size + self.hidden_size = hidden_size + self.batch_first = batch_first + self.lstm_cell = CustomLSTMCell(input_size, hidden_size) + + def forward(self, input, initial_states=None): + if initial_states is None: + initial_h = torch.zeros( + 1, + input.size(0 if 
self.batch_first else 1), + self.hidden_size, + device=input.device, + ) + initial_c = torch.zeros( + 1, + input.size(0 if self.batch_first else 1), + self.hidden_size, + device=input.device, + ) + else: + initial_h, initial_c = initial_states + + # Ensure we are working with single layer, single direction states + initial_h = initial_h.squeeze(0) + initial_c = initial_c.squeeze(0) + + if self.batch_first: + input = input.transpose( + 0, 1 + ) # Convert (batch, seq_len, feature) to (seq_len, batch, feature) + + outputs = [] + h_t, c_t = initial_h, initial_c + + for i in range( + input.shape[0] + ): # input is expected to be (seq_len, batch, input_size) + h_t, c_t = self.lstm_cell(input[i], (h_t, c_t)) + outputs.append(h_t.unsqueeze(0)) + + outputs = torch.cat(outputs, 0) + + if self.batch_first: + outputs = outputs.transpose( + 0, 1 + ) # Convert back to (batch, seq_len, feature) + + return outputs, (h_t, c_t) + + +# Test equivalence +def test_lstm_equivalence(): + input_size = 10 + hidden_size = 20 + seq_len = 5 + batch_size = 1 + + # Initialize both LSTMs + torch_lstm = nn.LSTM(input_size, hidden_size, batch_first=True) + custom_lstm = CustomLSTM(input_size, hidden_size, batch_first=True) + + # Manually setting the same weights and biases + with torch.no_grad(): + # Copy weights and biases from torch LSTM to custom LSTM + gates = ["input_gate", "forget_gate", "cell_gate", "output_gate"] + + for idx, gate in enumerate(gates): + start = idx * hidden_size + end = (idx + 1) * hidden_size + + getattr(custom_lstm.lstm_cell, gate).weight.data[:, :input_size].copy_( + torch_lstm.weight_ih_l0[start:end] + ) + getattr(custom_lstm.lstm_cell, gate).weight.data[:, input_size:].copy_( + torch_lstm.weight_hh_l0[start:end] + ) + getattr(custom_lstm.lstm_cell, gate).bias.data.copy_( + torch_lstm.bias_ih_l0[start:end] + torch_lstm.bias_hh_l0[start:end] + ) + # Dummy input + inputs = torch.randn(batch_size, seq_len, input_size) + + # Custom LSTM forward pass + custom_outputs, (custom_hn, custom_cn) = custom_lstm(inputs) + + # Torch LSTM forward pass + torch_outputs, (torch_hn, torch_cn) = torch_lstm(inputs) + + # Check outputs and final hidden and cell states + assert torch.allclose(custom_outputs, torch_outputs, atol=1e-6), "Output mismatch" + assert torch.allclose(custom_hn, torch_hn), "Hidden state mismatch" + assert torch.allclose(custom_cn, torch_cn), "Cell state mismatch" + + print("Test passed: Custom LSTM and torch.nn.LSTM outputs are equivalent!") + + +if __name__ == "__main__": + test_lstm_equivalence() diff --git a/examples/imdb/model.py b/examples/imdb/model.py new file mode 100644 index 00000000..651d8894 --- /dev/null +++ b/examples/imdb/model.py @@ -0,0 +1,72 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +# torch.nn.LSTM does not work well with torch.func https://github.com/pytorch/pytorch/issues/105982 +# so use custom simplfied LSTM instead +from examples.imdb.lstm import CustomLSTM + + +class CNNLSTM(nn.Module): + def __init__( + self, + num_classes, + max_features=20000, + embedding_size=128, + cell_size=128, + num_filters=64, + kernel_size=5, + pool_size=4, + use_swish=False, + use_maxpool=True, + ): + super(CNNLSTM, self).__init__() + self.embedding = nn.Embedding( + num_embeddings=max_features, embedding_dim=embedding_size + ) + self.conv1d = nn.Conv1d( + in_channels=embedding_size, + out_channels=num_filters, + kernel_size=kernel_size, + ) + self.use_swish = use_swish + self.use_maxpool = use_maxpool + if use_maxpool: + self.maxpool = 
nn.MaxPool1d(kernel_size=pool_size, stride=pool_size) + self.lstm = CustomLSTM( + input_size=num_filters, hidden_size=cell_size, batch_first=True + ) + + self.fc = nn.Linear(in_features=cell_size, out_features=num_classes) + + def forward(self, x): + # Embedding + x = self.embedding(x) # Shape: [batch_size, seq_length, embedding_size] + x = x.permute( + 0, 2, 1 + ) # Shape: [batch_size, embedding_size, seq_length] to match Conv1d input + + # Convolution + x = self.conv1d(x) + if self.use_swish: + x = x * torch.sigmoid(x) # Swish activation function + else: + x = F.relu(x) + + # Pooling + if self.use_maxpool: + x = self.maxpool(x) + + # Reshape for LSTM + x = x.permute(0, 2, 1) # Shape: [batch_size, seq_length, num_filters] + + # LSTM + output, _ = self.lstm(x) + + # Take the last sequence output + last_output = output[:, -1, :] + + # Fully connected layer + logits = self.fc(last_output) + + return logits diff --git a/examples/imdb/plot.py b/examples/imdb/plot.py new file mode 100644 index 00000000..1c8eeec1 --- /dev/null +++ b/examples/imdb/plot.py @@ -0,0 +1,204 @@ +import numpy as np +import matplotlib.pyplot as plt +from matplotlib.ticker import FormatStrFormatter +import json +import os + +plt.rcParams["font.family"] = "Times New Roman" + + +temperatures = [0.03, 0.1, 0.3, 1.0, 3.0] +temp_strs = [str(temp).replace(".", "-") for temp in temperatures] +seeds = list(range(1, 6)) + + +# [base directory, tempered bool, colors] +bases = [ + # ["examples/imdb/results/mle", False, ["grey"]], + ["examples/imdb/results/map", False, ["grey"]], + ["examples/imdb/results/laplace_fisher", True, ["royalblue", "deepskyblue"]], + ["examples/imdb/results/laplace_ggn", True, ["purple", "mediumvioletred"]], +] + +# bases = [ +# ["examples/imdb/results/mle", False, ["grey"]], +# ["examples/imdb/results/map", False, ["grey"]], +# ["examples/imdb/results/vi", True, ["forestgreen", "darkkhaki"]], +# ] + +# bases = [ +# # ["examples/imdb/results/mle", False, ["grey"]], +# ["examples/imdb/results/map", False, ["grey"]], +# ["examples/imdb/results/sghmc_serial", True, ["firebrick"]], +# ["examples/imdb/results/sghmc_parallel", True, ["tomato"]], +# ] + +with_mle = "mle" in bases[0][0] + +if with_mle: + ylims = {"loss": (0.305, 1.2), "accuracy": (0.47, 0.88)} +else: + ylims = {"loss": (0.305, 0.72), "accuracy": (0.47, 0.88)} + +save_name = bases[-1][0].split("/")[-1].split("_")[0] + + +test_dicts = {} +colour_dict = {} + +for base, tempered, colours in bases: + for seed in seeds: + seed_base = base + "_seed" + str(seed) if seed is not None else base + match_str = seed_base + "_temp" + temp_strs[0] if tempered else seed_base + versions = [file for file in os.listdir(match_str) if "test" in file] + versions = [v.strip(".json").split("_")[1] for v in versions] + versions.sort() + + for k, name in enumerate(versions): + colour_dict[name] = colours[k] + if tempered: + for temp in temperatures: + if temp not in test_dicts: + key = f"{name}_temp{str(temp).replace('.', '-')}_seed{seed}" + test_dicts[key] = json.load( + open( + f"{seed_base}_temp{str(temp).replace('.', '-')}/test_{name}.json" + ) + ) + else: + key = f"{name}_tempNA_seed{seed}" + test_dicts[key] = json.load(open(f"{seed_base}/test_{name}.json")) + + +line_styles = ["--", ":", "-.", "-"] + + +def sorted_set(l_in): + return sorted(set(l_in), key=l_in.index) + + +method_names = sorted_set([key.split("_")[0] for key in test_dicts.keys()]) + + +def metric_vals_dict(metric_name): + md = {} + + for method_name in method_names: + method_name_keys = [ + key for key 
in test_dicts.keys() if key.split("_")[0] == method_name + ] + + temperature_keys = sorted_set([key.split("_")[1] for key in method_name_keys]) + + method_dict = {} + + for tk in temperature_keys: + temp = tk.strip("temp").replace("-", ".") + + method_dict[temp] = [ + np.mean(test_dicts[key][metric_name]) + for key in method_name_keys + if key.split("_")[1] == tk + ] + + md |= {method_name: method_dict} + return md + + +metrics = ["loss", "accuracy"] +metric_dicts = {metric_name: metric_vals_dict(metric_name) for metric_name in metrics} + + +def plot_metric(metric_name): + metric_dict = metric_dicts[metric_name] + + fig, ax = plt.subplots() + k = 0 + + linewidth = 2 if with_mle else 3 + + for method_name, method_dict in metric_dict.items(): + if len(method_dict) == 1: + mn = np.mean(method_dict["NA"]) + sds = np.std(method_dict["NA"]) + ax.fill_between( + [0.0, temperatures[-1] * 10], + [mn - sds] * 2, + [mn + sds] * 2, + color=colour_dict[method_name], + alpha=0.2, + zorder=0, + ) + ax.axhline( + np.mean(method_dict["NA"]), + label=method_name, + c=colour_dict[method_name], + linestyle=line_styles[k], + linewidth=linewidth, + zorder=1, + ) + + k += 1 + + else: + mns = np.array([np.mean(method_dict[temp]) for temp in method_dict.keys()]) + sds = np.array([np.std(method_dict[temp]) for temp in method_dict.keys()]) + + ax.fill_between( + temperatures, + mns - sds, + mns + sds, + color=colour_dict[method_name], + alpha=0.2, + zorder=0, + ) + ax.plot( + temperatures, + mns, + label=method_name, + marker="o", + markersize=10, + c=colour_dict[method_name], + linewidth=linewidth, + zorder=1, + ) + + if metric_name in ylims: + ax.set_ylim(ylims[metric_name]) + + ax.set_xlim(temperatures[0] * 0.9, temperatures[-1] * 1.1) + + ax.set_xscale("log") + + fontsize = 18 + + ax.set_xticks(temperatures) + ax.tick_params(axis="both", which="major", labelsize=fontsize * 0.75) + + # Change xtick labels to scalar format + ax.xaxis.set_major_formatter(FormatStrFormatter("%g")) + + ax.set_xlabel("Temperature", fontsize=fontsize) + ax.set_ylabel("Test " + metric_name.title(), fontsize=fontsize) + + leg_fontsize = fontsize * 0.75 if with_mle else fontsize + ax.legend( + frameon=True, + framealpha=1.0, + facecolor="white", + edgecolor="white", + fontsize=leg_fontsize, + ) + fig.tight_layout() + + save_dir = ( + f"examples/imdb/figures/{save_name}_{metric_name}_with_mle.png" + if with_mle + else f"examples/imdb/figures/{save_name}_{metric_name}.png" + ) + fig.savefig(save_dir, dpi=400) + plt.close() + + +plot_metric("loss") +plot_metric("accuracy") diff --git a/examples/imdb/test.py b/examples/imdb/test.py new file mode 100644 index 00000000..65aa1d0c --- /dev/null +++ b/examples/imdb/test.py @@ -0,0 +1,117 @@ +import argparse +import pickle +import importlib +from tqdm import tqdm +import torch +from torch.distributions import Categorical +from optree import tree_map +import os + +from examples.imdb.model import CNNLSTM +from examples.imdb.data import load_imdb_dataset +from examples.imdb.utils import log_metrics + +# Get config path and device from user +parser = argparse.ArgumentParser() +parser.add_argument("--config", type=str) +parser.add_argument("--device", default="cpu", type=str) +parser.add_argument("--seed", default=42, type=int) +parser.add_argument("--temperature", default=None, type=float) +args = parser.parse_args() + + +torch.manual_seed(args.seed) + +# Import configuration +config = importlib.import_module(args.config.replace("/", ".").replace(".py", "")) +config_dir = os.path.dirname(args.config) 
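+# Test metrics are written next to the loaded config; a temperature suffix is appended
+# when tempering is applied at evaluation time (as for the Laplace runs in test_runner.sh)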
+save_dir = ( + config_dir + if args.temperature is None + else config_dir + f"_temp{str(args.temperature).replace('.', '-')}" +) +if not os.path.exists(save_dir): + os.makedirs(save_dir) + +# Load data +_, test_dataloader = load_imdb_dataset() + +# Load model +model = CNNLSTM(num_classes=2) +model.to(args.device) + +# Load state +state = pickle.load(open(config_dir + "/state.pkl", "rb")) +state = tree_map(lambda x: x.to(args.device), state) + +# Dictionary containing forward functions +forward_dict = config.forward_dict + + +def test_metrics(logits, labels): + probs = torch.nn.functional.softmax(logits, dim=-1) + probs = torch.where(probs < 1e-5, 1e-5, probs) + probs /= probs.sum(dim=-1, keepdim=True) + expected_probs = probs.mean(dim=1) + + logits = torch.log(probs) + expected_logits = torch.log(expected_probs) + + loss = -Categorical(logits=expected_logits).log_prob(labels) + accuracy = (expected_probs.argmax(dim=-1) == labels).float() + + total_uncertainty = -(torch.log(expected_probs) * expected_probs).mean(1) + aleatoric_uncertainty = -(torch.log(probs) * probs).mean(2).mean(1) + epistemic_uncertainty = total_uncertainty - aleatoric_uncertainty + + return { + "loss": loss, + "accuracy": accuracy, + "total_uncertainty": total_uncertainty, + "aleatoric_uncertainty": aleatoric_uncertainty, + "epistemic_uncertainty": epistemic_uncertainty, + } + + +# Run through test data +num_batches = len(test_dataloader) +metrics = [ + "loss", + "accuracy", + "total_uncertainty", + "aleatoric_uncertainty", + "epistemic_uncertainty", +] + +log_dict_forward = {k: {m: [] for m in metrics} for k in forward_dict.keys()} + +for batch in tqdm(test_dataloader): + with torch.no_grad(): + batch = tree_map(lambda x: x.to(args.device), batch) + labels = batch[1] + + forward_logits = { + k: v(model, state, batch) + if args.temperature is None + else v(model, state, batch, args.temperature) + for k, v in forward_dict.items() + } + + forward_metrics = { + k: test_metrics(logits, labels) for k, logits in forward_logits.items() + } + + for forward_k in log_dict_forward.keys(): + for metric_k in metrics: + log_dict_forward[forward_k][metric_k] += forward_metrics[forward_k][ + metric_k + ].tolist() + + +for forward_k, metric_dict in log_dict_forward.items(): + log_metrics( + metric_dict, + save_dir, + file_name="test_" + forward_k, + plot=False, + ) diff --git a/examples/imdb/test_runner.sh b/examples/imdb/test_runner.sh new file mode 100644 index 00000000..3fefc863 --- /dev/null +++ b/examples/imdb/test_runner.sh @@ -0,0 +1,70 @@ +#!/bin/bash + +device="cuda:1" + +# List of temperature params +temperatures=(0.03 0.1 0.3 1.0 3.0) +seeds=$(seq 1 5) + +# Base directories +base_configs=( + # "examples/imdb/results/vi" + # "examples/imdb/results/sghmc_serial" + "examples/imdb/results/sghmc_parallel" +) + +for temperature in "${temperatures[@]}"; do + for seed in $seeds; do + for base_config in "${base_configs[@]}"; do + config_file="${base_config}_seed${seed}_temp${temperature//./-}/config.py" + echo "Executing $config_file" + PYTHONPATH=. 
python examples/imdb/test.py --config "$config_file" \ + --device $device + done + done +done + + + +# # Laplace + +# temperatures=(0.03 0.1 0.3 1.0 3.0) +# seeds=$(seq 2) + + +# # Base directories +# base_configs=( +# "examples/imdb/results/laplace_fisher" +# # "examples/imdb/results/laplace_ggn" +# ) + +# for temperature in "${temperatures[@]}"; do +# for seed in $seeds; do +# for base_config in "${base_configs[@]}"; do +# config_file="${base_config}_seed${seed}/config.py" +# echo "Executing $config_file" +# PYTHONPATH=. python examples/imdb/test.py --config "$config_file" \ +# --device $device --temperature $temperature +# done +# done +# done + + +# # MLE and MAP +# # List of temperature params +# seeds=$(seq 1 5) + +# # Base directories +# base_configs=( +# "examples/imdb/results/mle" +# "examples/imdb/results/map" +# ) + +# for seed in $seeds; do +# for base_config in "${base_configs[@]}"; do +# config_file="${base_config}_seed${seed}/config.py" +# echo "Executing $config_file" +# PYTHONPATH=. python examples/imdb/test.py --config "$config_file" \ +# --device $device +# done +# done \ No newline at end of file diff --git a/examples/imdb/train.py b/examples/imdb/train.py new file mode 100644 index 00000000..1382ad85 --- /dev/null +++ b/examples/imdb/train.py @@ -0,0 +1,149 @@ +import os +import argparse +import pickle +import importlib +from tqdm import tqdm +from optree import tree_map +import torch +import posteriors + +from examples.imdb.model import CNNLSTM +from examples.imdb.data import load_imdb_dataset +from examples.imdb import utils + + +# Get args from user +parser = argparse.ArgumentParser() +parser.add_argument("--config", type=str) +parser.add_argument("--device", default="cpu", type=str) +parser.add_argument("--seed", default=42, type=int) +parser.add_argument("--epochs", default=30, type=int) +parser.add_argument("--temperature", default=1.0, type=float) +args = parser.parse_args() + + +# Import configuration +config = importlib.import_module(args.config.replace("/", ".").replace(".py", "")) + + +# Set seed +if args.seed != 42: + config.save_dir += f"_seed{args.seed}" + if config.params_dir is not None: + config.params_dir += f"_seed{args.seed}/state.pkl" +else: + args.seed = 42 + config.params_dir += "/state.pkl" +torch.manual_seed(args.seed) + +# Load model +model = CNNLSTM(num_classes=2) +model.to(args.device) + +if config.params_dir is not None: + with open(config.params_dir, "rb") as f: + print(config.params_dir) + state = pickle.load(f) + model.load_state_dict(state.params) + +# Load data +train_dataloader, test_dataloader = load_imdb_dataset(batch_size=config.batch_size) +num_data = len(train_dataloader.dataset) + + +# Set temperature +if "temperature" in config.config_args and config.config_args["temperature"] is None: + config.config_args["temperature"] = args.temperature / num_data + temp_str = str(args.temperature).replace(".", "-") + config.save_dir += f"_temp{temp_str}" + + +# Create save directory if it does not exist +if not os.path.exists(config.save_dir): + os.makedirs(config.save_dir) + +# Save config +utils.save_config(args, config.save_dir) +print(f"Config saved to {config.save_dir}") + +# Extract model parameters +params = dict(model.named_parameters()) +num_params = posteriors.tree_size(params).item() +print(f"Number of parameters: {num_params/1e6:.3f}M") + + +# Define log posterior +def forward(p, batch): + x, y = batch + logits = torch.func.functional_call(model, p, x) + loss = torch.nn.functional.cross_entropy(logits, y) + return logits, (loss, 
logits) + + +def outer_log_lik(logits, batch): + _, y = batch + return -torch.nn.functional.cross_entropy(logits, y, reduction="sum") + + +def log_posterior(p, batch): + x, y = batch + logits = torch.func.functional_call(model, p, x) + loss = torch.nn.functional.cross_entropy(logits, y) + log_post = ( + -loss + + posteriors.diag_normal_log_prob(p, sd_diag=config.prior_sd, normalize=False) + / num_data + ) + return log_post, (loss, logits) + + +# Build transform +if config.method == posteriors.laplace.diag_ggn: + transform = config.method.build(forward, outer_log_lik, **config.config_args) +else: + transform = config.method.build(log_posterior, **config.config_args) + +# Initialize state +state = transform.init(params) + +# Train +i = j = 0 +num_batches = len(train_dataloader) +log_dict = {k: [] for k in config.log_metrics.keys()} | {"loss": []} +log_bar = tqdm(total=0, position=1, bar_format="{desc}") +for epoch in range(args.epochs): + for batch in tqdm( + train_dataloader, desc=f"Epoch {epoch+1}/{args.epochs}", position=0 + ): + batch = tree_map(lambda x: x.to(args.device), batch) + state = transform.update(state, batch) + + # Update metrics + log_dict = utils.append_metrics(log_dict, state, config.log_metrics) + log_bar.set_description_str( + f"{config.display_metric}: {log_dict[config.display_metric][-1]:.2f}" + ) + + # Log + i += 1 + if i % config.log_frequency == 0 or i % num_batches == 0: + utils.log_metrics( + log_dict, + config.save_dir, + window=config.log_window, + file_name="training", + ) + + # Save sequential state if desired + if ( + config.save_frequency is not None + and (i - config.burnin) >= 0 + and (i - config.burnin) % config.save_frequency == 0 + ): + with open(f"{config.save_dir}/state_{j}.pkl", "wb") as f: + pickle.dump(state, f) + j += 1 + +# Save final state +with open(f"{config.save_dir}/state.pkl", "wb") as f: + pickle.dump(state, f) diff --git a/examples/imdb/train_runner.sh b/examples/imdb/train_runner.sh new file mode 100644 index 00000000..a9772f95 --- /dev/null +++ b/examples/imdb/train_runner.sh @@ -0,0 +1,47 @@ +#!/bin/bash + +device="cuda:0" + +# temperatures=(0.03 0.1 0.3 1.0 3.0) +# config="examples/imdb/configs/vi.py" +# epochs=30 +# seeds=$(seq 1 5) + +# temperatures=(0.03 0.1 0.3 1.0 3.0) +# config="examples/imdb/configs/sghmc_serial.py" +# epochs=60 +# seeds=$(seq 1 5) + +temperatures=(0.03 0.1 0.3 1.0 3.0) +config="examples/imdb/configs/sghmc_parallel.py" +epochs=30 +seeds=$(seq 1 35) + +# temperatures=(1.0) +# config="examples/imdb/configs/map.py" +# epochs=30 +# seeds=$(seq 1 5) + +# temperatures=(1.0) +# config="examples/imdb/configs/mle.py" +# epochs=30 +# seeds=$(seq 1 5) + +# temperatures=(1.0) +# config="examples/imdb/configs/laplace_fisher.py" +# epochs=1 +# seeds=$(seq 1 5) + +# temperatures=(1.0) +# config="examples/imdb/configs/laplace_ggn.py" +# epochs=1 +# seeds=$(seq 1 5) + +for seed in $seeds +do + for temp in "${temperatures[@]}" + do + PYTHONPATH=. 
python examples/imdb/train.py --config $config --device $device \ + --temperature $temp --epochs $epochs --seed $seed + done +done \ No newline at end of file diff --git a/examples/imdb/utils.py b/examples/imdb/utils.py new file mode 100644 index 00000000..b5b14813 --- /dev/null +++ b/examples/imdb/utils.py @@ -0,0 +1,45 @@ +import json +import matplotlib.pyplot as plt +import numpy as np +import shutil + + +# Save config +def save_config(args, save_dir): + shutil.copy(args.config, save_dir + "/config.py") + json.dump(vars(args), open(f"{save_dir}/args.json", "w")) + + +# Function to calculate moving average and accompanying x-axis values +def moving_average(x, w): + w_check = 1 if w >= len(x) else w + y = np.convolve(x, np.ones(w_check), "valid") / w_check + return range(w_check - 1, len(x)), y + + +# Function to log and plot metrics +def log_metrics(log_dict, save_dir, window=1, file_name="metrics", plot=True): + save_file = f"{save_dir}/{file_name}.json" + + with open(save_file, "w") as f: + json.dump(log_dict, f) + + if plot: + for k, v in log_dict.items(): + fig, ax = plt.subplots() + x, y = moving_average(v, window) + ax.plot(x, y, label=k, alpha=0.7) + ax.legend() + ax.set_xlabel("Iteration") + fig.tight_layout() + fig.savefig(f"{save_dir}/{file_name}_{k}.png", dpi=200) + plt.close(fig) + + +# Function to append metrics to log_dict +def append_metrics(log_dict, state, config_dict): + for k, v in config_dict.items(): + log_dict[k].append(getattr(state, v).mean().item()) + + log_dict["loss"].append(state.aux[0].mean().item()) + return log_dict
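
For orientation, every config above is consumed by `train.py` through the same generic `posteriors` interface: build a transform from a log-posterior, initialise a state, then update it batch by batch. The following minimal sketch mirrors that loop (and the `map.py` config) but is not part of the diff; the linear model and random data are toy stand-ins for the CNN-LSTM and the IMDB dataloader.

```python
import torch
import torchopt
import posteriors

# Toy stand-ins (for illustration only) for the CNN-LSTM model and IMDB dataloader
model = torch.nn.Linear(10, 2)
dataset = torch.utils.data.TensorDataset(torch.randn(64, 10), torch.randint(0, 2, (64,)))
loader = torch.utils.data.DataLoader(dataset, batch_size=8)


def log_posterior(params, batch):
    # Mean log-likelihood with a flat prior; train.py additionally adds a diagonal
    # Gaussian prior term scaled by 1 / num_data
    x, y = batch
    logits = torch.func.functional_call(model, params, x)
    loss = torch.nn.functional.cross_entropy(logits, y)
    return -loss, (loss, logits)


# Any of the configs' methods can be swapped in here,
# e.g. posteriors.vi.diag or posteriors.sgmcmc.sghmc with their config_args
transform = posteriors.torchopt.build(
    log_posterior, optimizer=torchopt.adamw(lr=1e-3, maximize=True)
)
state = transform.init(dict(model.named_parameters()))

for batch in loader:
    state = transform.update(state, batch)
```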