Merge pull request #94 from normal-computing/imdb
Add IMDB example
Showing 21 changed files with 1,500 additions and 0 deletions.
@@ -0,0 +1,119 @@
# Cold posterior effect on IMDB data

We investigate a wide range of `posteriors` methods for a [CNN-LSTM](https://proceedings.mlr.press/v119/wenzel20a/wenzel20a.pdf) model on the [IMDB dataset](https://huggingface.co/datasets/stanfordnlp/imdb).
We run each method for a variety of temperatures, that is, targeting the tempered
posterior distribution $p(\theta \mid \mathcal{D})^{\frac{1}{T}}$ for $T \geq 0$, thus
investigating the so-called [cold posterior effect](https://arxiv.org/abs/2008.05912),
where improved predictive performance has been observed for $T<1$.
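Concretely, tempering rescales the log posterior, so $T=1$ recovers the standard
Bayesian posterior while $T \to 0$ concentrates mass around the MAP estimate:

$$
p(\theta \mid \mathcal{D})^{\frac{1}{T}} \propto \exp\left(\tfrac{1}{T}\left(\log p(\mathcal{D} \mid \theta) + \log p(\theta)\right)\right).
$$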
We observe significantly improved performance of Bayesian techniques over
gradient descent, and notably that the cold posterior effect is significantly more
prominent for Gaussian approximations than for SGMCMC variants.
## Methods

We train the CNN-LSTM model with 2.7m parameters on 25k reviews from the IMDB dataset
for binary classification of positive and negative reviews. We then use `posteriors` to
swap between the following methods (a minimal sketch of the swapping pattern follows the list):

- **map**: Maximum a posteriori (MAP) estimate of the parameters. Simple optimization
using [AdamW](https://arxiv.org/abs/1711.05101) via `posteriors.torchopt`.
- **laplace.diag_fisher**: Uses the MAP estimate as the basis for a Laplace approximation
with a diagonal empirical Fisher covariance matrix. Additionally uses
`posteriors.linearized_forward_diag` to assess a [linearized version](https://arxiv.org/abs/2106.14806)
of the posterior predictive distribution.
- **laplace.diag_ggn**: Same as above but using the diagonal Gauss-Newton Fisher matrix,
which is [equivalent to the conditional Fisher information matrix](https://arxiv.org/abs/1412.1193)
(integrated over labels rather than using the empirical distribution). Again we assess
traditional and linearized Laplace.
- **vi.diag**: Variational inference using a diagonal Gaussian approximation, optimized
using [AdamW](https://arxiv.org/abs/1711.05101); we also assess the linearized variant.
- **sgmcmc.sghmc (Serial)**: [Stochastic gradient Hamiltonian Monte Carlo](https://arxiv.org/abs/1506.04696)
(SGHMC) with a Monte Carlo approximation to the posterior collected by running a single
trajectory (and removing a burn-in).
- **sgmcmc.sghmc (Parallel)**: Same as above but running 15 trajectories in parallel and
only collecting the final state of each trajectory.
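
As a rough illustration of the swapping pattern (a sketch assuming the `build`/`init`/`update`
interface referenced by the config files; `log_posterior`, `params` and `dataloader` are
placeholders, and the exact loop lives in `train.py`):

```python
import posteriors

# Any of the methods above can be dropped in here; config_args holds the
# method-specific arguments passed to method.build (aside from log_posterior).
method = posteriors.sgmcmc.sghmc  # e.g. posteriors.vi.diag, posteriors.laplace.diag_fisher, ...
config_args = {}  # hypothetical: populated from the relevant config file

transform = method.build(log_posterior, **config_args)
state = transform.init(params)

for batch in dataloader:
    state = transform.update(state, batch)
```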

All methods are run for a variety of temperatures $T \in \{0.03, 0.1, 0.3, 1.0, 3.0\}$.
Each training run lasts 30 epochs, except serial SGHMC which uses 60 epochs to collect 27
samples from a single trajectory. Each method and temperature is run 5 times with
different random seeds, except for parallel SGHMC, for which we run 35 seeds and
then bootstrap 5 ensembles of size 15. In all cases we use a diagonal Gaussian prior
with all variances set to 1/40.
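
For reference, a minimal sketch of such a diagonal Gaussian log prior (variance 1/40,
i.e. standard deviation $1/\sqrt{40}$), written with plain `torch.distributions` over a
dict of parameter tensors; the function name is illustrative rather than the repository's:

```python
import torch
from optree import tree_map, tree_reduce

prior_sd = (1 / 40) ** 0.5  # all prior variances set to 1/40


def diag_normal_log_prior(params):
    # Sum of independent N(0, 1/40) log-densities over every parameter element
    per_tensor = tree_map(
        lambda p: torch.distributions.Normal(0.0, prior_sd).log_prob(p).sum(), params
    )
    return tree_reduce(torch.add, per_tensor)
```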

## Results

We plot the test loss for each method and temperature in Figure 1.

<p align="center">
<img src="https://storage.googleapis.com/posteriors/cold_posterior_laplace_loss.png" width="30%">
<img src="https://storage.googleapis.com/posteriors/cold_posterior_vi_loss.png" width="30%">
<img src="https://storage.googleapis.com/posteriors/cold_posterior_sghmc_loss.png" width="30%">
<br>
<em>Figure 1. Test loss on IMDB data (lower is better) for varying temperatures.</em>
</p>

There are a few takeaways from the results (a sketch of the ensemble predictive used for evaluation follows the list):
- Bayesian methods (VI and SGMCMC) can significantly improve over gradient descent (MAP;
we also trained for the MLE, which severely overfits and was omitted from the plot
for clarity).
- Regular Laplace does not perform well; the linearization helps somewhat, with GGN
outperforming the empirical Fisher.
- In contrast, the linearization is detrimental for VI, which we posit is because the VI
training acts in parameter space without knowledge of the linearization.
- There is a strong cold posterior effect for Gaussian methods (VI + Laplace), but only a
very mild cold posterior effect for the non-Gaussian Bayes method (SGHMC).
- Parallel SGHMC outperforms serial SGHMC, and there is also evidence that both outperform a
[deep ensemble](https://arxiv.org/abs/1612.01474) (which is obtained with parallel SGHMC and
$T=0$).
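
For context on how such test losses are evaluated: the forward functions in the config
files return sampled logits of shape `(batch, n_samples, classes)`, and one common way
(not necessarily exactly what `test.py` does) to form an ensemble predictive is to
average the per-sample class probabilities:

```python
import math

import torch


def ensemble_log_probs(logits):
    # logits: (batch, n_samples, classes) -> log of the sample-averaged predictive
    log_probs = torch.log_softmax(logits, dim=-1)
    return torch.logsumexp(log_probs, dim=1) - math.log(logits.shape[1])
```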

## Data

We download the IMDB dataset using [keras.datasets.imdb](https://www.tensorflow.org/api_docs/python/tf/keras/datasets/imdb/load_data).
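
A minimal sketch of what that loading step looks like; the `num_words` and `maxlen`
values here are illustrative assumptions rather than the settings in `data.py`:

```python
from keras.datasets import imdb
from keras.utils import pad_sequences

# Reviews come back as lists of word indices restricted to the num_words most
# frequent tokens; pad/truncate to a fixed length so they batch into one tensor.
(x_train, y_train), (x_test, y_test) = imdb.load_data(num_words=20_000)
x_train = pad_sequences(x_train, maxlen=100)
x_test = pad_sequences(x_test, maxlen=100)
```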

## Model
We use the CNN-LSTM model from [Wenzel et al](https://proceedings.mlr.press/v119/wenzel20a/wenzel20a.pdf),
which consists of an embedding layer, a convolutional layer, ReLU activation, a max-pooling layer,
an LSTM layer, and a final dense layer. In total there are 2.7m parameters. We use a diagonal
Gaussian prior with all variances set to 1/40.
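
A rough PyTorch sketch of that architecture (layer sizes are illustrative assumptions;
the actual definition, including an LSTM written to compose with `torch.func`, lives in
`lstm.py` and `model.py`):

```python
import torch
import torch.nn as nn


class CNNLSTM(nn.Module):
    # Illustrative dimensions only; model.py defines the real ones.
    def __init__(self, vocab_size=20_000, embed_dim=128, filters=64, lstm_dim=128):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.conv = nn.Conv1d(embed_dim, filters, kernel_size=5)
        self.pool = nn.MaxPool1d(kernel_size=4)
        self.lstm = nn.LSTM(filters, lstm_dim, batch_first=True)
        self.dense = nn.Linear(lstm_dim, 2)  # two-class logits

    def forward(self, x):
        e = self.embedding(x)  # (batch, seq, embed_dim)
        c = torch.relu(self.conv(e.transpose(1, 2)))  # convolve over the sequence
        p = self.pool(c).transpose(1, 2)  # (batch, seq', filters)
        out, _ = self.lstm(p)
        return self.dense(out[:, -1])  # logits from the final LSTM state
```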

## Code structure

- `lstm.py`: Simple, custom code for an LSTM layer that composes with `torch.func`.
- `model.py`: Specification of the CNN-LSTM model.
- `data.py`: Functions to load the IMDB data using `keras`.
- `train.py`: General script for training the model using `posteriors` methods,
which can easily be swapped.
- `configs/`: Configuration files (Python files) for the different methods.
- `train_runner.sh`: Bash script to run a series of training jobs.
- `combine_states_serial.py`: Combines parameter states from a single SGHMC trajectory
for ensembled forward calls.
- `combine_states_parallel.py`: Combines final parameter states from multiple SGHMC
trajectories for ensembled forward calls.
- `test.py`: Calculates metrics on the IMDB test data for a given method.
- `test_runner.sh`: Bash script to run a series of testing jobs.
- `plot.py`: Generates Figure 1.
- `utils.py`: Small utility functions for logging.

A single training run can be launched from the root directory:
```bash
PYTHONPATH=. python examples/imdb/train.py --config examples/imdb/configs/laplace_diag_fisher.py --temperature 0.1 --seed 0 --device cuda:0
```
or by configuring multiple runs in `train_runner.sh` and running:
```bash
bash examples/imdb/train_runner.sh
```
Similarly, testing code can be run from the root directory:
```bash
PYTHONPATH=. python examples/imdb/test.py --config examples/imdb/configs/laplace_diag_fisher.py --seed 0 --device cuda:0
```
or by configuring settings in `test_runner.sh` and running:
```bash
bash examples/imdb/test_runner.sh
```

@@ -0,0 +1,65 @@
import pickle
import torch
import os
from optree import tree_map
import shutil


temperatures = [0.03, 0.1, 0.3, 1.0, 3.0]


# SGHMC Parallel
base = "examples/imdb/results/sghmc_parallel/sghmc_parallel"
save_dir_base = "examples/imdb/results/sghmc_parallel"


bootstrap_seeds = [
    (torch.multinomial(torch.ones(35), 15, replacement=False) + 1).tolist()
    for _ in range(5)
]


for temp in temperatures:
    temp_str = str(temp).replace(".", "-")

    load_paths = []

    for k, seeds in enumerate(bootstrap_seeds):
        for seed in seeds:
            spec_base = base
            spec_base += f"_seed{seed}"
            spec_base += f"_temp{temp_str}/"

            match_str = (
                "state_" if seed is None else "state"
            )  # don't need last save for serial

            load_paths += [
                spec_base + file for file in os.listdir(spec_base) if match_str in file
            ]

        save_dir = save_dir_base + f"_seed{k+1}" + f"_temp{temp_str}/"

        if not os.path.exists(save_dir):
            os.makedirs(save_dir)
            shutil.copy(spec_base + "config.py", save_dir + "/config.py")

        with open(f"{save_dir}/paths.txt", "w") as f:
            f.write("\n".join(load_paths))

        # Load states
        states = [pickle.load(open(d, "rb")) for d in load_paths]

        # Delete auxiliary info
        for s in states:
            del s.aux

        # Move states to cpu
        states = tree_map(lambda x: x.detach().to("cpu"), states)

        # Combine states
        combined_state = tree_map(lambda *x: torch.stack(x), *states)

        # Save state
        with open(f"{save_dir}state.pkl", "wb") as f:
            pickle.dump(combined_state, f)

@@ -0,0 +1,64 @@
import pickle
import torch
import os
from optree import tree_map
import shutil


temperatures = [0.03, 0.1, 0.3, 1.0, 3.0]

# SGHMC Serial
bases = [
    f"examples/imdb/results/sghmc_serial_seed{repeat_seed}"
    for repeat_seed in range(1, 6)
]
seeds = [None]


for base in bases:
    for temp in temperatures:
        temp_str = str(temp).replace(".", "-")

        load_paths = []

        for seed in seeds:
            spec_base = base

            if seed is not None:
                spec_base += f"_seed{seed}"

            spec_base += f"_temp{temp_str}/"

            match_str = (
                "state_" if seed is None else "state"
            )  # don't need last save for serial

            load_paths += [
                spec_base + file for file in os.listdir(spec_base) if match_str in file
            ]

        save_dir = base + f"_temp{temp_str}/"

        if not os.path.exists(save_dir):
            os.makedirs(save_dir)
            shutil.copy(spec_base + "config.py", save_dir + "/config.py")

        with open(f"{save_dir}/paths.txt", "w") as f:
            f.write("\n".join(load_paths))

        # Load states
        states = [pickle.load(open(d, "rb")) for d in load_paths]

        # Delete auxiliary info
        for s in states:
            del s.aux

        # Move states to cpu
        states = tree_map(lambda x: x.detach().to("cpu"), states)

        # Combine states
        combined_state = tree_map(lambda *x: torch.stack(x), *states)

        # Save state
        with open(f"{save_dir}state.pkl", "wb") as f:
            pickle.dump(combined_state, f)

@@ -0,0 +1,72 @@
import torch
import posteriors
from optree import tree_map

name = "laplace_fisher"
save_dir = "examples/imdb/results/" + name
params_dir = "examples/imdb/results/map"  # directory to load state containing initialisation params

prior_sd = torch.inf
batch_size = 32
burnin = None
save_frequency = None

method = posteriors.laplace.diag_fisher
config_args = {}  # arguments for method.build (aside from log_posterior)
log_metrics = {}  # dict containing names of metrics as keys and their paths in state as values
display_metric = "loss"  # metric to display in tqdm progress bar

log_frequency = 100  # frequency at which to log metrics
log_window = 30  # window size for moving average

n_test_samples = 50
n_linearised_test_samples = 10000
epsilon = 1e-3  # small value to avoid division by zero in to_sd_diag


def to_sd_diag(state, temperature=1.0):
    return tree_map(lambda x: torch.sqrt(temperature / (x + epsilon)), state.prec_diag)


def forward(model, state, batch, temperature=1.0):
    x, _ = batch
    sd_diag = to_sd_diag(state, temperature)

    sampled_params = posteriors.diag_normal_sample(
        state.params, sd_diag, (n_test_samples,)
    )

    def model_func(p, x):
        return torch.func.functional_call(model, p, x)

    logits = torch.vmap(model_func, in_dims=(0, None))(sampled_params, x).transpose(
        0, 1
    )
    return logits


def forward_linearized(model, state, batch, temperature=1.0):
    x, _ = batch
    sd_diag = to_sd_diag(state, temperature)

    def model_func_with_aux(p, x):
        return torch.func.functional_call(model, p, x), torch.tensor([])

    lin_mean, lin_chol, _ = posteriors.linearized_forward_diag(
        model_func_with_aux,
        state.params,
        x,
        sd_diag,
    )

    samps = torch.randn(
        lin_mean.shape[0],
        n_linearised_test_samples,
        lin_mean.shape[1],
        device=lin_mean.device,
    )
    lin_logits = lin_mean.unsqueeze(1) + samps @ lin_chol.transpose(-1, -2)
    return lin_logits


forward_dict = {"Laplace EF": forward, "Laplace EF Linearized": forward_linearized}