Merge pull request #94 from normal-computing/imdb
Add IMDB example
SamDuffield authored May 22, 2024
2 parents 8202a93 + 6206689 commit b4c5c0f
Showing 21 changed files with 1,500 additions and 0 deletions.
7 changes: 7 additions & 0 deletions docs/tutorials/index.md
@@ -32,6 +32,13 @@ on a series of books from the [pg19](https://huggingface.co/datasets/pg19) dataset
Compares a host of `posteriors` methods (highlighting the easy exchangeability) on a
sentiment analysis task adapted from the [Hugging Face tutorial](https://huggingface.co/docs/transformers/training#train-in-native-pytorch).

[<img style="float: right; width: 6em" src="https://storage.googleapis.com/posteriors/cold_posterior_sghmc_loss.png">](https://github.com/normal-computing/posteriors/tree/main/examples/imdb)

- [`imdb`](https://github.com/normal-computing/posteriors/tree/main/examples/imdb): Investigates [cold posterior effect](https://proceedings.mlr.press/v119/wenzel20a/wenzel20a.pdf) for a range of approximate
Bayesian methods with a CNN-LSTM model on [IMDB](https://www.tensorflow.org/api_docs/python/tf/keras/datasets/imdb/load_data)
data, with some interesting takeaways.


[<img style="float: right; width: 6em" src="https://storage.googleapis.com/posteriors/variational_continual_learning.png">](https://github.com/normal-computing/posteriors/blob/main/examples/continual_regression.ipynb)

- [`continual_regression`](https://github.com/normal-computing/posteriors/blob/main/examples/continual_regression.ipynb):
3 changes: 3 additions & 0 deletions examples/README.md
@@ -4,6 +4,9 @@ This directory contains examples of how to use the `posteriors` package.
- [`continual_lora`](continual_lora/): Uses `posteriors.laplace.diag_fisher` to avoid
catastrophic forgetting in fine-tuning [Llama-2-7b](https://huggingface.co/meta-llama/Llama-2-7b-hf)
on a series of books from the [pg19](https://huggingface.co/datasets/pg19) dataset.
- [`imdb`](imdb/): Investigates [cold posterior effect](https://proceedings.mlr.press/v119/wenzel20a/wenzel20a.pdf)
for a range of approximate Bayesian methods on [IMDB](https://www.tensorflow.org/api_docs/python/tf/keras/datasets/imdb/load_data)
data.
- [`yelp`](yelp/): Compares a host of `posteriors` methods (highlighting the easy
exchangeability) on a sentiment analysis task adapted from the [Hugging Face tutorial](https://huggingface.co/docs/transformers/training#train-in-native-pytorch).
- [`continual_regression`](continual_regression.ipynb): [Variational continual learning](https://arxiv.org/abs/1710.10628)
119 changes: 119 additions & 0 deletions examples/imdb/README.md
@@ -0,0 +1,119 @@
# Cold posterior effect on IMDB data

We investigate a wide range of `posteriors` methods for a [CNN-LSTM](https://proceedings.mlr.press/v119/wenzel20a/wenzel20a.pdf) model on the [IMDB dataset](https://huggingface.co/datasets/stanfordnlp/imdb).
We run each method for a variety of temperatures, that is, targeting the tempered
posterior distribution $p(\theta \mid \mathcal{D})^{\frac{1}{T}}$ for $T \geq 0$, thus
investigating the so-called [cold posterior effect](https://arxiv.org/abs/2008.05912),
where improved predictive performance has been observed for $T<1$.
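
In terms of the likelihood and prior, the tempered target can be written as

$$
p(\theta \mid \mathcal{D})^{\frac{1}{T}} \propto \big(p(\mathcal{D} \mid \theta)\, p(\theta)\big)^{\frac{1}{T}},
$$

so that $T=1$ recovers the standard Bayesian posterior, $T \to 0$ concentrates on the MAP
estimate and $T>1$ flattens the posterior.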

We observe significantly improved performance of Bayesian techniques over gradient
descent, and notably that the cold posterior effect is much more prominent for Gaussian
approximations than for SGMCMC variants.

## Methods

We train the CNN-LSTM model with 2.7m parameters on 25k reviews from the IMDB dataset
for binary classification of positive and negative reviews. We then use `posteriors` to
swap between the following methods:

- **map**: Maximum a posteriori (MAP) estimate of the parameters. Simple optimization
using [AdamW](https://arxiv.org/abs/1711.05101) via `posteriors.torchopt`.
- **laplace.diag_fisher**: Use the MAP estimate as the basis for a Laplace approximation
using a diagonal empirical Fisher covariance matrix. Additionally use
`posteriors.linearized_forward_diag` to assess a [linearized version](https://arxiv.org/abs/2106.14806)
of the posterior predictive distribution.
- **laplace.diag_ggn**: Same as above but using the diagonal Gauss-Newton Fisher matrix,
which is [equivalent to the conditional Fisher information matrix](https://arxiv.org/abs/1412.1193)
(integrated over labels rather than using the empirical distribution). Again, we assess
traditional and linearized Laplace.
- **vi.diag**: Variational inference using a diagonal Gaussian approximation. Optimized
using [AdamW](https://arxiv.org/abs/1711.05101), and we also assess the linearized variant.
- **sgmcmc.sghmc (Serial)**: [Stochastic gradient Hamiltonian Monte Carlo](https://arxiv.org/abs/1506.04696)
(SGHMC) with a Monte Carlo approximation to the posterior collected by running a single
trajectory (and removing a burn-in).
- **sgmcmc.sghmc (Parallel)**: Same as above but running 15 trajectories in parallel and
only collecting the final state of each trajectory.


All methods are run for a variety of temperatures $T \in \{0.03, 0.1, 0.3, 1.0, 3.0\}$.
Each training run is over 30 epochs, except serial SGHMC, which uses 60 epochs to collect 27
samples from a single trajectory. Each method and temperature is run 5 times with
different random seeds, except for parallel SGHMC, for which we run 35 seeds and
then bootstrap 5 ensembles of size 15. In all cases we use a diagonal Gaussian prior
with all variances set to 1/40.
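
The methods are swapped purely through configuration. Below is a minimal sketch (not the
repo's exact `train.py`) of how a method from the list above is built and updated, assuming
`log_posterior`, `init_params`, `train_dataloader` and `n_epochs` are defined elsewhere; the
method choice and `config_args` shown are illustrative:

```python
import posteriors
import torchopt

# Any of the methods above can be swapped in here, e.g.
# posteriors.laplace.diag_fisher, posteriors.vi.diag or posteriors.sgmcmc.sghmc
method = posteriors.vi.diag
config_args = {"optimizer": torchopt.adamw(lr=1e-3), "temperature": 0.1}

# build -> init -> update is the shared posteriors interface
transform = method.build(log_posterior, **config_args)
state = transform.init(init_params)

for _ in range(n_epochs):
    for batch in train_dataloader:
        state = transform.update(state, batch)
```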


## Results

We plot the test loss for each method and temperature in Figure 1.

<p align="center">
<img src="https://storage.googleapis.com/posteriors/cold_posterior_laplace_loss.png" width=30%">
<img src="https://storage.googleapis.com/posteriors/cold_posterior_vi_loss.png" width=30%">
<img src="https://storage.googleapis.com/posteriors/cold_posterior_sghmc_loss.png" width=30%">
<br>
<em>Figure 1. Test loss on IMDB data (lower is better) for varying temperatures.</em>
</p>

There are a few takeaways from the results:
- Bayesian methods (VI and SGMCMC) can significantly improve over gradient descent (MAP;
we also trained the MLE, which severely overfits and was omitted from the plot
for clarity).
- Regular Laplace does not perform well; the linearization helps somewhat, with GGN
outperforming empirical Fisher.
- In contrast, the linearization is detrimental for VI, which we posit is because the VI
training acts in parameter space without knowledge of the linearization.
- There is a strong cold posterior effect for the Gaussian methods (VI + Laplace), but only a
very mild cold posterior effect for the non-Gaussian Bayesian method (SGHMC).
- Parallel SGHMC outperforms serial SGHMC, and there is also evidence that both outperform a
[deep ensemble](https://arxiv.org/abs/1612.01474) (which is obtained with parallel SGHMC and
$T=0$).


## Data

We download the IMDB dataset using [keras.datasets.imdb](https://www.tensorflow.org/api_docs/python/tf/keras/datasets/imdb/load_data).
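
For illustration, a minimal sketch of loading and batching the data this way (the vocabulary
size and maximum review length are assumptions here, not necessarily the values used in
`data.py`):

```python
import torch
from torch.nn.utils.rnn import pad_sequence
from keras.datasets import imdb

# 25k train and 25k test reviews as lists of integer token ids
(x_train, y_train), (x_test, y_test) = imdb.load_data(num_words=20_000)

# Truncate and zero-pad so reviews can be batched into a single tensor
x_train = pad_sequence([torch.tensor(seq[:200]) for seq in x_train], batch_first=True)
y_train = torch.tensor(y_train, dtype=torch.long)
```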


## Model
We use the CNN-LSTM model from [Wenzel et al](https://proceedings.mlr.press/v119/wenzel20a/wenzel20a.pdf)
which consists of an embedding layer, a convolutional layer, ReLU activation, max-pooling layer,
LSTM layer, and a final dense layer. In total there are 2.7m parameters. We use a diagonal
Gaussian prior with all variances set to 1/40.
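
A minimal sketch of this architecture is below (layer widths are illustrative assumptions,
and `nn.LSTM` stands in for the custom `torch.func`-compatible LSTM that the repo defines in
`lstm.py`):

```python
import torch
import torch.nn as nn


class CNNLSTM(nn.Module):
    """Illustrative CNN-LSTM for binary sentiment classification."""

    def __init__(self, vocab_size=20_000, embed_dim=128, hidden_dim=128):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.conv = nn.Conv1d(embed_dim, 64, kernel_size=5, padding=2)
        self.pool = nn.MaxPool1d(kernel_size=2)
        self.lstm = nn.LSTM(64, hidden_dim, batch_first=True)
        self.dense = nn.Linear(hidden_dim, 2)

    def forward(self, x):
        # x: (batch, seq_len) integer token ids
        e = self.embedding(x).transpose(1, 2)     # (batch, embed_dim, seq_len)
        c = self.pool(torch.relu(self.conv(e)))   # (batch, 64, seq_len // 2)
        out, _ = self.lstm(c.transpose(1, 2))     # (batch, seq_len // 2, hidden_dim)
        return self.dense(out[:, -1])             # (batch, 2) logits
```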


## Code structure

- `lstm.py`: Simple, custom code for an LSTM layer that composes with `torch.func`.
- `model.py`: Specification of the CNN-LSTM model.
- `data.py`: Functions to load the IMDB data using `keras`.
- `train.py`: General script for training the model using `posteriors` methods
which can easily be swapped.
- `configs/`: Configuration files (Python files) for the different methods.
- `train_runner.sh`: Bash script to run a series of training jobs.
- `combine_states_serial.py`: Combines parameter states from a single SGHMC trajectory
for ensembled forward calls.
- `combine_states_parallel.py`: Combines final parameter states from multiple SGHMC
trajectories for ensembled forward calls.
- `test.py`: Calculate metrics on the IMDB test data for a given method.
- `test_runner.sh`: Bash script to run a series of testing jobs.
- `plot.py`: Generate Figure 1.
- `utils.py`: Small utility functions for logging.

Training for a single configuration can be run from the root directory:
```bash
PYTHONPATH=. python examples/imdb/train.py --config examples/imdb/configs/laplace_diag_fisher.py --temperature 0.1 --seed 0 --device cuda:0
```
or by configuring multiple runs in `train_runner.sh` and running:
```bash
bash examples/imdb/train_runner.sh
```
Similarly, testing can be run from the root directory:
```bash
PYTHONPATH=. python examples/imdb/test.py --config examples/imdb/configs/laplace_diag_fisher.py --seed 0 --device cuda:0
```
or by configuring settings in `test_runner.sh` and running:
```bash
bash examples/imdb/test_runner.sh
```

65 changes: 65 additions & 0 deletions examples/imdb/combine_states_parallel.py
@@ -0,0 +1,65 @@
import pickle
import torch
import os
from optree import tree_map
import shutil


temperatures = [0.03, 0.1, 0.3, 1.0, 3.0]


# SGHMC Parallel
base = "examples/imdb/results/sghmc_parallel/sghmc_parallel"
save_dir_base = "examples/imdb/results/sghmc_parallel"


# Bootstrap 5 ensembles of 15 seeds each from the 35 available seeds
bootstrap_seeds = [
    (torch.multinomial(torch.ones(35), 15, replacement=False) + 1).tolist()
    for _ in range(5)
]


for temp in temperatures:
    temp_str = str(temp).replace(".", "-")

    for k, seeds in enumerate(bootstrap_seeds):
        load_paths = []  # state files for this bootstrap ensemble only

        for seed in seeds:
            spec_base = base
            spec_base += f"_seed{seed}"
            spec_base += f"_temp{temp_str}/"

            match_str = (
                "state_" if seed is None else "state"
            )  # don't need last save for serial

            load_paths += [
                spec_base + file for file in os.listdir(spec_base) if match_str in file
            ]

        save_dir = save_dir_base + f"_seed{k+1}" + f"_temp{temp_str}/"

        if not os.path.exists(save_dir):
            os.makedirs(save_dir)
            shutil.copy(spec_base + "config.py", save_dir + "/config.py")

        with open(f"{save_dir}/paths.txt", "w") as f:
            f.write("\n".join(load_paths))

        # Load states
        states = [pickle.load(open(d, "rb")) for d in load_paths]

        # Delete auxiliary info
        for s in states:
            del s.aux

        # Move states to cpu
        states = tree_map(lambda x: x.detach().to("cpu"), states)

        # Combine states by stacking each leaf along a new ensemble dimension
        combined_state = tree_map(lambda *x: torch.stack(x), *states)

        # Save state
        with open(f"{save_dir}state.pkl", "wb") as f:
            pickle.dump(combined_state, f)
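
As context, a hedged sketch (not from the repo) of how a stacked state like `combined_state`
above can drive an ensembled forward call, mirroring the `forward` functions in the configs:

```python
import torch


def ensemble_logits(model, combined_state, x):
    """Forward pass through every member of a stacked parameter state."""

    def model_func(p, x):
        return torch.func.functional_call(model, p, x)

    # vmap over the leading ensemble dimension of the stacked parameters,
    # giving logits of shape (n_members, batch, n_classes)
    return torch.vmap(model_func, in_dims=(0, None))(combined_state.params, x)
```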
64 changes: 64 additions & 0 deletions examples/imdb/combine_states_serial.py
@@ -0,0 +1,64 @@
import pickle
import torch
import os
from optree import tree_map
import shutil


temperatures = [0.03, 0.1, 0.3, 1.0, 3.0]

# SGHMC Serial
bases = [
    f"examples/imdb/results/sghmc_serial_seed{repeat_seed}"
    for repeat_seed in range(1, 6)
]
seeds = [None]


for base in bases:
    for temp in temperatures:
        temp_str = str(temp).replace(".", "-")

        load_paths = []

        for seed in seeds:
            spec_base = base

            if seed is not None:
                spec_base += f"_seed{seed}"

            spec_base += f"_temp{temp_str}/"

            match_str = (
                "state_" if seed is None else "state"
            )  # don't need last save for serial

            load_paths += [
                spec_base + file for file in os.listdir(spec_base) if match_str in file
            ]

        save_dir = base + f"_temp{temp_str}/"

        if not os.path.exists(save_dir):
            os.makedirs(save_dir)
            shutil.copy(spec_base + "config.py", save_dir + "/config.py")

        with open(f"{save_dir}/paths.txt", "w") as f:
            f.write("\n".join(load_paths))

        # Load states
        states = [pickle.load(open(d, "rb")) for d in load_paths]

        # Delete auxiliary info
        for s in states:
            del s.aux

        # Move states to cpu
        states = tree_map(lambda x: x.detach().to("cpu"), states)

        # Combine states
        combined_state = tree_map(lambda *x: torch.stack(x), *states)

        # Save state
        with open(f"{save_dir}state.pkl", "wb") as f:
            pickle.dump(combined_state, f)
72 changes: 72 additions & 0 deletions examples/imdb/configs/laplace_fisher.py
@@ -0,0 +1,72 @@
import torch
import posteriors
from optree import tree_map

name = "laplace_fisher"
save_dir = "examples/imdb/results/" + name
params_dir = "examples/imdb/results/map"  # directory to load state containing initialisation params

prior_sd = torch.inf
batch_size = 32
burnin = None
save_frequency = None

method = posteriors.laplace.diag_fisher
config_args = {}  # arguments for method.build (aside from log_posterior)
log_metrics = {}  # dict containing names of metrics as keys and their paths in state as values
display_metric = "loss"  # metric to display in tqdm progress bar

log_frequency = 100  # frequency at which to log metrics
log_window = 30  # window size for moving average

n_test_samples = 50
n_linearised_test_samples = 10000
epsilon = 1e-3  # small value to avoid division by zero in to_sd_diag


def to_sd_diag(state, temperature=1.0):
    return tree_map(lambda x: torch.sqrt(temperature / (x + epsilon)), state.prec_diag)


def forward(model, state, batch, temperature=1.0):
    x, _ = batch
    sd_diag = to_sd_diag(state, temperature)

    sampled_params = posteriors.diag_normal_sample(
        state.params, sd_diag, (n_test_samples,)
    )

    def model_func(p, x):
        return torch.func.functional_call(model, p, x)

    logits = torch.vmap(model_func, in_dims=(0, None))(sampled_params, x).transpose(
        0, 1
    )
    return logits


def forward_linearized(model, state, batch, temperature=1.0):
    x, _ = batch
    sd_diag = to_sd_diag(state, temperature)

    def model_func_with_aux(p, x):
        return torch.func.functional_call(model, p, x), torch.tensor([])

    lin_mean, lin_chol, _ = posteriors.linearized_forward_diag(
        model_func_with_aux,
        state.params,
        x,
        sd_diag,
    )

    samps = torch.randn(
        lin_mean.shape[0],
        n_linearised_test_samples,
        lin_mean.shape[1],
        device=lin_mean.device,
    )
    lin_logits = lin_mean.unsqueeze(1) + samps @ lin_chol.transpose(-1, -2)
    return lin_logits


forward_dict = {"Laplace EF": forward, "Laplace EF Linearized": forward_linearized}