Add IMDB example #94

Merged · 8 commits · May 22, 2024

7 changes: 7 additions & 0 deletions docs/tutorials/index.md
@@ -32,6 +32,13 @@ on a series of books from the [pg19](https://huggingface.co/datasets/pg19) dataset.
Compares a host of `posteriors` methods (highlighting the easy exchangeability) on a
sentiment analysis task adapted from the [Hugging Face tutorial](https://huggingface.co/docs/transformers/training#train-in-native-pytorch).

[<img style="float: right; width: 6em" src="https://storage.googleapis.com/posteriors/cold_posterior_sghmc_loss.png">](https://github.com/normal-computing/posteriors/tree/main/examples/imdb)

- [`imdb`](https://github.com/normal-computing/posteriors/tree/main/examples/imdb): Investigates the [cold posterior effect](https://proceedings.mlr.press/v119/wenzel20a/wenzel20a.pdf) for a range of approximate
Bayesian methods with a CNN-LSTM model on [IMDB](https://www.tensorflow.org/api_docs/python/tf/keras/datasets/imdb/load_data)
data, with some interesting takeaways.


[<img style="float: right; width: 6em" src="https://storage.googleapis.com/posteriors/variational_continual_learning.png">](https://github.com/normal-computing/posteriors/blob/main/examples/continual_regression.ipynb)

- [`continual_regression`](https://github.com/normal-computing/posteriors/blob/main/examples/continual_regression.ipynb):
3 changes: 3 additions & 0 deletions examples/README.md
@@ -4,6 +4,9 @@ This directory contains examples of how to use the `posteriors` package.
- [`continual_lora`](continual_lora/): Uses `posteriors.laplace.diag_fisher` to avoid
catastrophic forgetting in fine-tuning [Llama-2-7b](https://huggingface.co/meta-llama/Llama-2-7b-hf)
on a series of books from the [pg19](https://huggingface.co/datasets/pg19) dataset.
- [`imdb`](imdb/): Investigates the [cold posterior effect](https://proceedings.mlr.press/v119/wenzel20a/wenzel20a.pdf)
for a range of approximate Bayesian methods on [IMDB](https://www.tensorflow.org/api_docs/python/tf/keras/datasets/imdb/load_data)
data.
- [`yelp`](yelp/): Compares a host of `posteriors` methods (highlighting the easy
exchangeability) on a sentiment analysis task adapted from the [Hugging Face tutorial](https://huggingface.co/docs/transformers/training#train-in-native-pytorch).
- [`continual_regression`](continual_regression.ipynb): [Variational continual learning](https://arxiv.org/abs/1710.10628)
119 changes: 119 additions & 0 deletions examples/imdb/README.md
@@ -0,0 +1,119 @@
# Cold posterior effect on IMDB data

We investigate a wide range of `posteriors` methods for a [CNN-LSTM](https://proceedings.mlr.press/v119/wenzel20a/wenzel20a.pdf) model on the [IMDB dataset](https://huggingface.co/datasets/stanfordnlp/imdb).
We run each method for a variety of temperatures, that is, targeting the tempered
posterior distribution $p(\theta \mid \mathcal{D})^{\frac{1}{T}}$ for $T \geq 0$, thus
investigating the so-called [cold posterior effect](https://arxiv.org/abs/2008.05912),
where improved predictive performance has been observed for $T<1$.
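
For reference, tempering simply rescales the log density:
$\log p(\theta \mid \mathcal{D})^{\frac{1}{T}} = \frac{1}{T}\left(\log p(\mathcal{D} \mid \theta) + \log p(\theta)\right) + \text{const}$.
A minimal, purely illustrative sketch (not code from this example):

```python
def tempered_log_posterior(log_likelihood, log_prior, temperature):
    """Log density of the tempered posterior, up to an additive constant.

    temperature=1 recovers the standard Bayesian posterior, temperature -> 0
    concentrates on the MAP estimate and temperature < 1 gives "cold" posteriors.
    """
    return (log_likelihood + log_prior) / temperature
```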

We observe significantly improved performance of Bayesian techniques over
gradient descent, and notably that the cold posterior effect is significantly more
prominent for Gaussian approximations than for SGMCMC variants.

## Methods

We train the CNN-LSTM model with 2.7m parameters on 25k reviews from the IMDB dataset
for binary classification of positive and negative reviews. We then use `posteriors` to
swap between the following methods:

- **map**: Maximum a posteriori (MAP) estimate of the parameters. Simple optimization
using [AdamW](https://arxiv.org/abs/1711.05101) via `posteriors.torchopt`.
- **laplace.diag_fisher**: Use the MAP estimate as the basis for a Laplace approximation
using a diagonal empirical Fisher covariance matrix. Additionally use
`posteriors.linearized_forward_diag` to assess a [linearized version](https://arxiv.org/abs/2106.14806)
of the posterior predictive distribution.
- **laplace.diag_ggn**: Same as above but using the diagonal Gauss-Newton Fisher matrix,
which is [equivalent to the conditional Fisher information matrix](https://arxiv.org/abs/1412.1193)
(integrated over labels rather than using the empirical distribution). Again we assess
traditional and linearized Laplace.
- **vi.diag**: Variational inference using a diagonal Gaussian approximation. Optimized
using [AdamW](https://arxiv.org/abs/1711.05101); we also assess a linearized variant.
- **sgmcmc.sghmc (Serial)**: [Stochastic gradient Hamiltonian Monte Carlo](https://arxiv.org/abs/1506.04696)
(SGHMC) with a Monte Carlo approximation to the posterior collected by running a single
trajectory (and removing a burn-in).
- **sgmcmc.sghmc (Parallel)**: Same as above but running 15 trajectories in parallel and
only collecting the final state of each trajectory.
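
Swapping between these methods only changes which `posteriors` transform is built. A
rough sketch of the pattern (assuming the standard `posteriors` build/init/update
interface; `log_posterior`, `params` and `dataloader` are placeholders for objects
constructed in the training script):

```python
import posteriors

# Mirrors the `method` and `config_args` fields of the config files
# (e.g. configs/laplace_fisher.py); any other method can be swapped in here.
method, config_args = posteriors.laplace.diag_fisher, {}

transform = method.build(log_posterior, **config_args)
state = transform.init(params)

for batch in dataloader:
    state = transform.update(state, batch)
```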


All methods are run for a variety of temperatures $T \in \{0.03, 0.1, 0.3, 1.0, 3.0\}$.
Each training run is over 30 epochs, except serial SGHMC which uses 60 epochs to collect 27
samples from a single trajectory. Each method and temperature is run 5 times with
different random seeds, except for parallel SGHMC, for which we run 35 seeds and
then bootstrap 5 ensembles of size 15. In all cases we use a diagonal Gaussian prior
with all variances set to 1/40.
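
As a concrete reference, the prior is the same isotropic Gaussian over all parameters;
it could be written roughly as follows (illustrative only, the training script may
construct it differently):

```python
import torch
from torch.distributions import Normal

prior_sd = (1 / 40) ** 0.5  # diagonal Gaussian prior with variance 1/40

def log_prior(params: dict) -> torch.Tensor:
    # Sum of independent N(0, 1/40) log densities over every parameter tensor.
    return sum(Normal(0.0, prior_sd).log_prob(p).sum() for p in params.values())
```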


## Results

We plot the test loss for each method and temperature in Figure 1.

<p align="center">
<img src="https://storage.googleapis.com/posteriors/cold_posterior_laplace_loss.png" width=30%">
<img src="https://storage.googleapis.com/posteriors/cold_posterior_vi_loss.png" width=30%">
<img src="https://storage.googleapis.com/posteriors/cold_posterior_sghmc_loss.png" width=30%">
<br>
<em>Figure 1. Test loss on IMDB data (lower is better) for varying temperatures.</em>
</p>

There are a few takeaways from the results:
- Bayesian methods (VI and SGMCMC) can significantly improve over gradient descent (MAP;
we also trained the MLE, which severely overfits and was omitted from the plot
for clarity).
- Regular Laplace does not perform well; the linearization helps somewhat, with GGN
outperforming empirical Fisher.
- In contrast, the linearization is detrimental for VI, which we posit is due to the VI
training acting in parameter space without knowledge of the linearization.
- There is a strong cold posterior effect for the Gaussian methods (VI + Laplace), but only a
very mild cold posterior effect for the non-Gaussian Bayesian method (SGHMC).
- Parallel SGHMC outperforms serial SGHMC, and there is evidence that both outperform a [deep
ensemble](https://arxiv.org/abs/1612.01474) (which is obtained with parallel SGHMC and
$T=0$).


## Data

We download the IMDB dataset using [keras.datasets.imdb](https://www.tensorflow.org/api_docs/python/tf/keras/datasets/imdb/load_data).
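
Loading and padding the reviews looks roughly like the following (the vocabulary size
and sequence length here are illustrative; see `data.py` for the settings actually used):

```python
import torch
from tensorflow import keras

num_words, maxlen = 20_000, 100  # illustrative values

(x_train, y_train), (x_test, y_test) = keras.datasets.imdb.load_data(num_words=num_words)
x_train = keras.preprocessing.sequence.pad_sequences(x_train, maxlen=maxlen)
x_test = keras.preprocessing.sequence.pad_sequences(x_test, maxlen=maxlen)

# Convert to torch tensors for the training and test loops.
x_train, y_train = torch.as_tensor(x_train), torch.as_tensor(y_train)
x_test, y_test = torch.as_tensor(x_test), torch.as_tensor(y_test)
```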


## Model
We use the CNN-LSTM model from [Wenzel et al.](https://proceedings.mlr.press/v119/wenzel20a/wenzel20a.pdf),
which consists of an embedding layer, a convolutional layer with ReLU activation, a max-pooling layer,
an LSTM layer, and a final dense layer. In total there are 2.7m parameters. We use a diagonal
Gaussian prior with all variances set to 1/40.
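
A rough PyTorch sketch of the architecture (layer sizes are illustrative and `nn.LSTM`
is used for brevity; the actual `model.py` builds on the custom `lstm.py` so that the
LSTM composes with `torch.func`):

```python
import torch
import torch.nn as nn

class CNNLSTM(nn.Module):
    def __init__(self, vocab_size=20_000, embed_dim=128, filters=64, hidden_dim=64, num_classes=2):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.conv = nn.Conv1d(embed_dim, filters, kernel_size=5, padding=2)
        self.pool = nn.MaxPool1d(kernel_size=4)
        self.lstm = nn.LSTM(filters, hidden_dim, batch_first=True)
        self.dense = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):  # x: (batch, seq_len) token ids
        h = self.embedding(x)  # (batch, seq_len, embed_dim)
        h = torch.relu(self.conv(h.transpose(1, 2)))  # (batch, filters, seq_len)
        h = self.pool(h).transpose(1, 2)  # (batch, seq_len // 4, filters)
        _, (h_last, _) = self.lstm(h)  # h_last: (1, batch, hidden_dim)
        return self.dense(h_last[-1])  # (batch, num_classes) logits
```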


## Code structure

- `lstm.py`: Simple, custom code for an LSTM layer that composes with `torch.func`.
- `model.py`: Specification of the CNN-LSTM model.
- `data.py`: Functions to load the IMDB data using `keras`.
- `train.py`: General script for training the model using `posteriors` methods
which can easily be swapped.
- `configs/`: Configuration files (Python files) for the different methods.
- `train_runner.sh`: Bash script to run a series of training jobs.
- `combine_states_serial.py`: Combines parameter states from a single SGHMC trajectory
for ensembled forward calls.
- `combine_states_parallel.py`: Combines the final parameter states from multiple SGHMC
trajectories for ensembled forward calls.
- `test.py`: Calculates metrics on the IMDB test data for a given method.
- `test_runner.sh`: Bash script to run a series of testing jobs.
- `plot.py`: Generates Figure 1.
- `utils.py`: Small utility functions for logging.

Training for a single configuration can be run from the root directory:
```bash
PYTHONPATH=. python examples/imdb/train.py --config examples/imdb/configs/laplace_diag_fisher.py --temperature 0.1 --seed 0 --device cuda:0
```
or by configuring multiple runs in `train_runner.sh` and running:
```bash
bash examples/imdb/train_runner.sh
```
Similarly, testing can be run from the root directory:
```bash
PYTHONPATH=. python examples/imdb/test.py --config examples/imdb/configs/laplace_diag_fisher.py --seed 0 --device cuda:0
```
or by configuring settings in `test_runner.sh` and running:
```bash
bash examples/imdb/test_runner.sh
```

65 changes: 65 additions & 0 deletions examples/imdb/combine_states_parallel.py
@@ -0,0 +1,65 @@
import pickle
import torch
import os
from optree import tree_map
import shutil


temperatures = [0.03, 0.1, 0.3, 1.0, 3.0]


# SGHMC Parallel
base = "examples/imdb/results/sghmc_parallel/sghmc_parallel"
save_dir_base = "examples/imdb/results/sghmc_parallel"


bootstrap_seeds = [
    (torch.multinomial(torch.ones(35), 15, replacement=False) + 1).tolist()
    for _ in range(5)
]
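
# For each temperature, form five bootstrap ensembles, each stacking the saved states
# from 15 of the 35 available parallel SGHMC runs (sampled without replacement).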


for temp in temperatures:
    temp_str = str(temp).replace(".", "-")

    for k, seeds in enumerate(bootstrap_seeds):
        load_paths = []

        for seed in seeds:
            spec_base = base
            spec_base += f"_seed{seed}"
            spec_base += f"_temp{temp_str}/"

            match_str = "state"  # parallel runs only contribute their final saved state

            load_paths += [
                spec_base + file for file in os.listdir(spec_base) if match_str in file
            ]

        save_dir = save_dir_base + f"_seed{k+1}" + f"_temp{temp_str}/"

        if not os.path.exists(save_dir):
            os.makedirs(save_dir)
            shutil.copy(spec_base + "config.py", save_dir + "/config.py")

        with open(f"{save_dir}/paths.txt", "w") as f:
            f.write("\n".join(load_paths))

        # Load states
        states = [pickle.load(open(d, "rb")) for d in load_paths]

        # Delete auxiliary info
        for s in states:
            del s.aux

        # Move states to cpu
        states = tree_map(lambda x: x.detach().to("cpu"), states)

        # Combine states
        combined_state = tree_map(lambda *x: torch.stack(x), *states)

        # Save state
        with open(f"{save_dir}state.pkl", "wb") as f:
            pickle.dump(combined_state, f)
64 changes: 64 additions & 0 deletions examples/imdb/combine_states_serial.py
@@ -0,0 +1,64 @@
import pickle
import torch
import os
from optree import tree_map
import shutil


temperatures = [0.03, 0.1, 0.3, 1.0, 3.0]

# SGHMC Serial
bases = [
    f"examples/imdb/results/sghmc_serial_seed{repeat_seed}"
    for repeat_seed in range(1, 6)
]
seeds = [None]
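
# Serial runs have no per-seed suffix (seeds = [None]); for each repeat and temperature
# we stack every intermediate save from the single SGHMC trajectory into one combined state.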


for base in bases:
    for temp in temperatures:
        temp_str = str(temp).replace(".", "-")

        load_paths = []

        for seed in seeds:
            spec_base = base

            if seed is not None:
                spec_base += f"_seed{seed}"

            spec_base += f"_temp{temp_str}/"

            match_str = (
                "state_" if seed is None else "state"
            )  # don't need the final save for serial runs

            load_paths += [
                spec_base + file for file in os.listdir(spec_base) if match_str in file
            ]

        save_dir = base + f"_temp{temp_str}/"

        if not os.path.exists(save_dir):
            os.makedirs(save_dir)
            shutil.copy(spec_base + "config.py", save_dir + "/config.py")

        with open(f"{save_dir}/paths.txt", "w") as f:
            f.write("\n".join(load_paths))

        # Load states
        states = [pickle.load(open(d, "rb")) for d in load_paths]

        # Delete auxiliary info
        for s in states:
            del s.aux

        # Move states to cpu
        states = tree_map(lambda x: x.detach().to("cpu"), states)

        # Combine states
        combined_state = tree_map(lambda *x: torch.stack(x), *states)

        # Save state
        with open(f"{save_dir}state.pkl", "wb") as f:
            pickle.dump(combined_state, f)
72 changes: 72 additions & 0 deletions examples/imdb/configs/laplace_fisher.py
@@ -0,0 +1,72 @@
import torch
import posteriors
from optree import tree_map

name = "laplace_fisher"
save_dir = "examples/imdb/results/" + name
params_dir = "examples/imdb/results/map" # directory to load state containing initialisation params

prior_sd = torch.inf
batch_size = 32
burnin = None
save_frequency = None

method = posteriors.laplace.diag_fisher
config_args = {} # arguments for method.build (aside from log_posterior)
log_metrics = {} # dict containing names of metrics as keys and their paths in state as values
display_metric = "loss" # metric to display in tqdm progress bar

log_frequency = 100 # frequency at which to log metrics
log_window = 30 # window size for moving average

n_test_samples = 50
n_linearised_test_samples = 10000
epsilon = 1e-3 # small value to avoid division by zero in to_sd_diag


def to_sd_diag(state, temperature=1.0):
    return tree_map(lambda x: torch.sqrt(temperature / (x + epsilon)), state.prec_diag)


def forward(model, state, batch, temperature=1.0):
    x, _ = batch
    sd_diag = to_sd_diag(state, temperature)

    sampled_params = posteriors.diag_normal_sample(
        state.params, sd_diag, (n_test_samples,)
    )

    def model_func(p, x):
        return torch.func.functional_call(model, p, x)

    logits = torch.vmap(model_func, in_dims=(0, None))(sampled_params, x).transpose(
        0, 1
    )
    return logits


def forward_linearized(model, state, batch, temperature=1.0):
    x, _ = batch
    sd_diag = to_sd_diag(state, temperature)

    def model_func_with_aux(p, x):
        return torch.func.functional_call(model, p, x), torch.tensor([])

    lin_mean, lin_chol, _ = posteriors.linearized_forward_diag(
        model_func_with_aux,
        state.params,
        x,
        sd_diag,
    )

    samps = torch.randn(
        lin_mean.shape[0],
        n_linearised_test_samples,
        lin_mean.shape[1],
        device=lin_mean.device,
    )
    lin_logits = lin_mean.unsqueeze(1) + samps @ lin_chol.transpose(-1, -2)
    return lin_logits


forward_dict = {"Laplace EF": forward, "Laplace EF Linearized": forward_linearized}
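
# Illustrative aside (not part of this config or test.py): both forward functions return
# logits of shape (batch, n_samples, n_classes); posterior predictive probabilities can
# be obtained by averaging the softmax over the sample dimension, e.g.
# probs = torch.softmax(logits, dim=-1).mean(dim=1)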