diff --git a/docs/tutorials/index.md b/docs/tutorials/index.md
index 9dc73de5..2bc52ee5 100644
--- a/docs/tutorials/index.md
+++ b/docs/tutorials/index.md
@@ -32,6 +32,13 @@ on a series of books from the [pg19](https://huggingface.co/datasets/pg19) datas
Compares a host of `posteriors` methods (highlighting the easy exchangeability) on a
sentiment analysis task adapted from the [Hugging Face tutorial](https://huggingface.co/docs/transformers/training#train-in-native-pytorch).
+[](https://github.com/normal-computing/posteriors/tree/main/examples/imdb)
+
+- [`imdb`](https://github.com/normal-computing/posteriors/tree/main/examples/imdb): Investigates the [cold posterior effect](https://proceedings.mlr.press/v119/wenzel20a/wenzel20a.pdf) for a range of approximate
+Bayesian methods with a CNN-LSTM model on [IMDB](https://www.tensorflow.org/api_docs/python/tf/keras/datasets/imdb/load_data)
+data, with some interesting takeaways.
+
+
[](https://github.com/normal-computing/posteriors/blob/main/examples/continual_regression.ipynb)
- [`continual_regression`](https://github.com/normal-computing/posteriors/blob/main/examples/continual_regression.ipynb):
diff --git a/examples/README.md b/examples/README.md
index 07c4b4c9..42ea63b3 100644
--- a/examples/README.md
+++ b/examples/README.md
@@ -4,6 +4,9 @@ This directory contains examples of how to use the `posteriors` package.
- [`continual_lora`](continual_lora/): Uses `posteriors.laplace.diag_fisher` to avoid
catastrophic forgetting in fine-tuning [Llama-2-7b](https://huggingface.co/meta-llama/Llama-2-7b-hf)
on a series of books from the [pg19](https://huggingface.co/datasets/pg19) dataset.
+- [`imdb`](imdb/): Investigates the [cold posterior effect](https://proceedings.mlr.press/v119/wenzel20a/wenzel20a.pdf)
+for a range of approximate Bayesian methods on [IMDB](https://www.tensorflow.org/api_docs/python/tf/keras/datasets/imdb/load_data)
+data.
- [`yelp`](yelp/): Compares a host of `posteriors` methods (highlighting the easy
exchangeability) on a sentiment analysis task adapted from the [Hugging Face tutorial](https://huggingface.co/docs/transformers/training#train-in-native-pytorch).
- [`continual_regression`](continual_regression.ipynb): [Variational continual learning](https://arxiv.org/abs/1710.10628)
diff --git a/examples/imdb/README.md b/examples/imdb/README.md
new file mode 100644
index 00000000..783f5a40
--- /dev/null
+++ b/examples/imdb/README.md
@@ -0,0 +1,119 @@
+# Cold posterior effect on IMDB data
+
+We investigate a wide range of `posteriors` methods for a [CNN-LSTM](https://proceedings.mlr.press/v119/wenzel20a/wenzel20a.pdf) model on the [IMDB dataset](https://huggingface.co/datasets/stanfordnlp/imdb).
+We run each method for a variety of temperatures, that is, targeting the tempered
+posterior distribution $p(\theta \mid \mathcal{D})^{\frac{1}{T}}$ for $T \geq 0$, thus
+investigating the so-called [cold posterior effect](https://arxiv.org/abs/2008.05912),
+where improved predictive performance has been observed for $T<1$.
+
+We observe significantly improved performance of Bayesian techniques over
+gradient descent, and notably that the cold posterior effect is much more
+prominent for Gaussian approximations than for SGMCMC variants.
+
+## Methods
+
+We train the CNN-LSTM model with 2.7m parameters on 25k reviews from the IMDB dataset
+for binary classification of positive and negative reviews. We then use `posteriors` to
+swap between the following methods:
+
+- **map**: Maximum a posteriori (MAP) estimate of the parameters. Simple optimization
+using [AdamW](https://arxiv.org/abs/1711.05101) via `posteriors.torchopt`.
+- **laplace.diag_fisher**: Use the MAP estimate as the basis for a Laplace approximation
+using a diagonal empirical Fisher covariance matrix. Additionally use
+`posteriors.linearized_forward_diag` to assess a [linearized version](https://arxiv.org/abs/2106.14806)
+of the posterior predictive distribution.
+- **laplace.diag_ggn**: Same as above but using the diagonal Gauss-Newton Fisher matrix,
+which is [equivalent to the conditional Fisher information matrix](https://arxiv.org/abs/1412.1193)
+(integrated over labels rather than using the empirical distribution). Again, we assess
+traditional and linearized Laplace.
+- **vi.diag**: Variational inference using a diagonal Gaussian approximation. Optimized
+using [AdamW](https://arxiv.org/abs/1711.05101); we also assess a linearized variant.
+- **sgmcmc.sghmc (Serial)**: [Stochastic gradient Hamiltonian Monte Carlo](https://arxiv.org/abs/1506.04696)
+(SGHMC) with a Monte Carlo approximation to the posterior collected by running a single
+trajectory (and removing a burn-in).
+- **sgmcmc.sghmc (Parallel)**: Same as above but running 15 trajectories in parallel and
+only collecting the final state of each trajectory.
+
+
+All methods are run for a variety of temperatures $T \in \{0.03, 0.1, 0.3, 1.0, 3.0\}$.
+Each training run lasts 30 epochs, except serial SGHMC which uses 60 epochs to collect 27
+samples from a single trajectory. Each method and temperature is run 5 times with
+different random seeds, except for parallel SGHMC, for which we run 35 seeds and
+then bootstrap 5 ensembles of size 15. In all cases we use a diagonal Gaussian prior
+with all variances set to 1/40.
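+
+The swap between methods amounts to changing a single config line thanks to the unified
+`posteriors` interface. A minimal sketch (mirroring `train.py`; `log_posterior`, `model`
+and `train_dataloader` are assumed to be defined as in `train.py`):
+
+```python
+import posteriors
+import torchopt
+
+# Only the method and its build arguments change between configs
+method, config_args = posteriors.vi.diag, {"optimizer": torchopt.adamw(lr=1e-3)}
+# method, config_args = posteriors.sgmcmc.sghmc, {"lr": 1e-1}
+
+transform = method.build(log_posterior, **config_args)
+state = transform.init(dict(model.named_parameters()))
+
+for batch in train_dataloader:
+    state = transform.update(state, batch)
+```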
+
+
+## Results
+
+We plot the test loss for each method and temperature in Figure 1.
+
+*Figure 1. Test loss on IMDB data (lower is better) for varying temperatures.*
+
+There are a few takeaways from the results:
+- Bayesian methods (VI and SGMCMC) can significantly improve over gradient descent (MAP;
+we also trained the MLE, which severely overfits and was omitted from the plot for clarity).
+- Regular Laplace does not perform well; linearization helps somewhat, with GGN
+outperforming empirical Fisher.
+- In contrast, linearization is detrimental for VI, which we posit is because VI
+training acts in parameter space without knowledge of the linearization.
+- There is a strong cold posterior effect for the Gaussian methods (VI + Laplace), but only
+a very mild cold posterior effect for the non-Gaussian Bayesian methods (SGHMC).
+- Parallel SGHMC outperforms serial SGHMC, and there is evidence that both outperform a
+[deep ensemble](https://arxiv.org/abs/1612.01474) (which is obtained with parallel SGHMC and
+$T=0$).
+
+
+## Data
+
+We download the IMDB dataset using [keras.datasets.imdb](https://www.tensorflow.org/api_docs/python/tf/keras/datasets/imdb/load_data).
+
+
+## Model
+We use the CNN-LSTM model from [Wenzel et al](https://proceedings.mlr.press/v119/wenzel20a/wenzel20a.pdf),
+which consists of an embedding layer, a convolutional layer, a ReLU activation, a max-pooling layer,
+an LSTM layer, and a final dense layer. In total there are 2.7m parameters. We use a diagonal
+Gaussian prior with all variances set to 1/40.
+
+
+## Code structure
+
+- `lstm.py`: Simple, custom code for an LSTM layer that composes with `torch.func`.
+- `model.py`: Specification of the CNN-LSTM model.
+- `data.py`: Functions to load the IMDB data using `keras`.
+- `train.py`: General script for training the model using `posteriors` methods
+ which can easily be swapped.
+- `configs/` : Configuration files (python files) for the different methods.
+- `train_runner.sh`: Bash script to run a series of training jobs.
+- `combine_states_serial.py`: Combines parameter states from a single SGHMC trajectory
+for ensembled forward calls.
+- `combine_states_parallel.py`: Combines final parameter states from multiple SGHMC
+trajectories for ensembled forward calls.
+- `test.py`: Calculates metrics on the IMDB test data for a given method.
+- `test_runner.sh`: Bash script to run a series of testing jobs.
+- `plot.py`: Generates Figure 1.
+- `utils.py`: Small utility functions for logging.
+
+Training for a single configuration can be run from the root directory:
+```bash
+PYTHONPATH=. python examples/imdb/train.py --config examples/imdb/configs/laplace_fisher.py --temperature 0.1 --seed 0 --device cuda:0
+```
+or by configuring multiple runs in `train_runner.sh` and running:
+```bash
+bash examples/imdb/train_runner.sh
+```
+Similarly, testing can be run from the root directory:
+```bash
+PYTHONPATH=. python examples/imdb/test.py --config examples/imdb/configs/laplace_fisher.py --seed 0 --device cuda:0
+```
+or by configuring settings in `test_runner.sh` and running:
+```bash
+bash examples/imdb/test_runner.sh
+```
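+Figure 1 can then be reproduced, after selecting the desired methods in the `bases` list at
+the top of `plot.py` (and assuming the relevant results directories have been populated), with:
+```bash
+PYTHONPATH=. python examples/imdb/plot.py
+```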
+
diff --git a/examples/imdb/combine_states_parallel.py b/examples/imdb/combine_states_parallel.py
new file mode 100644
index 00000000..29727e1e
--- /dev/null
+++ b/examples/imdb/combine_states_parallel.py
@@ -0,0 +1,65 @@
+import pickle
+import torch
+import os
+from optree import tree_map
+import shutil
+
+
+temperatures = [0.03, 0.1, 0.3, 1.0, 3.0]
+
+
+# SGHMC Parallel
+base = "examples/imdb/results/sghmc_parallel/sghmc_parallel"
+save_dir_base = "examples/imdb/results/sghmc_parallel"
+
+
+bootstrap_seeds = [
+ (torch.multinomial(torch.ones(35), 15, replacement=False) + 1).tolist()
+ for _ in range(5)
+]
+
+
+for temp in temperatures:
+ temp_str = str(temp).replace(".", "-")
+
+    for k, seeds in enumerate(bootstrap_seeds):
+        load_paths = []
+
+ for seed in seeds:
+ spec_base = base
+ spec_base += f"_seed{seed}"
+ spec_base += f"_temp{temp_str}/"
+
+ match_str = (
+ "state_" if seed is None else "state"
+ ) # don't need last save for serial
+
+ load_paths += [
+ spec_base + file for file in os.listdir(spec_base) if match_str in file
+ ]
+
+        # Combine and save one bootstrapped ensemble per bootstrap index k
+        save_dir = save_dir_base + f"_seed{k+1}" + f"_temp{temp_str}/"
+
+        if not os.path.exists(save_dir):
+            os.makedirs(save_dir)
+            shutil.copy(spec_base + "config.py", save_dir + "/config.py")
+
+        with open(f"{save_dir}/paths.txt", "w") as f:
+            f.write("\n".join(load_paths))
+
+        # Load states
+        states = [pickle.load(open(d, "rb")) for d in load_paths]
+
+        # Delete auxiliary info
+        for s in states:
+            del s.aux
+
+        # Move states to cpu
+        states = tree_map(lambda x: x.detach().to("cpu"), states)
+
+        # Combine states
+        combined_state = tree_map(lambda *x: torch.stack(x), *states)
+
+        # Save state
+        with open(f"{save_dir}state.pkl", "wb") as f:
+            pickle.dump(combined_state, f)
diff --git a/examples/imdb/combine_states_serial.py b/examples/imdb/combine_states_serial.py
new file mode 100644
index 00000000..681ab23d
--- /dev/null
+++ b/examples/imdb/combine_states_serial.py
@@ -0,0 +1,64 @@
+import pickle
+import torch
+import os
+from optree import tree_map
+import shutil
+
+
+temperatures = [0.03, 0.1, 0.3, 1.0, 3.0]
+
+# SGHMC Serial
+bases = [
+ f"examples/imdb/results/sghmc_serial_seed{repeat_seed}"
+ for repeat_seed in range(1, 6)
+]
+seeds = [None]
+
+
+for base in bases:
+ for temp in temperatures:
+ temp_str = str(temp).replace(".", "-")
+
+ load_paths = []
+
+ for seed in seeds:
+ spec_base = base
+
+ if seed is not None:
+ spec_base += f"_seed{seed}"
+
+ spec_base += f"_temp{temp_str}/"
+
+ match_str = (
+ "state_" if seed is None else "state"
+ ) # don't need last save for serial
+
+ load_paths += [
+ spec_base + file for file in os.listdir(spec_base) if match_str in file
+ ]
+
+ save_dir = base + f"_temp{temp_str}/"
+
+ if not os.path.exists(save_dir):
+ os.makedirs(save_dir)
+ shutil.copy(spec_base + "config.py", save_dir + "/config.py")
+
+ with open(f"{save_dir}/paths.txt", "w") as f:
+ f.write("\n".join(load_paths))
+
+ # Load states
+ states = [pickle.load(open(d, "rb")) for d in load_paths]
+
+ # Delete auxiliary info
+ for s in states:
+ del s.aux
+
+ # Move states to cpu
+ states = tree_map(lambda x: x.detach().to("cpu"), states)
+
+ # Combine states
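+        # (stacking leaf-wise gives every parameter a leading sample dimension,
+        # which the sghmc forward functions vmap over at test time)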
+ combined_state = tree_map(lambda *x: torch.stack(x), *states)
+
+ # Save state
+ with open(f"{save_dir}state.pkl", "wb") as f:
+ pickle.dump(combined_state, f)
diff --git a/examples/imdb/configs/laplace_fisher.py b/examples/imdb/configs/laplace_fisher.py
new file mode 100644
index 00000000..40d42e6c
--- /dev/null
+++ b/examples/imdb/configs/laplace_fisher.py
@@ -0,0 +1,72 @@
+import torch
+import posteriors
+from optree import tree_map
+
+name = "laplace_fisher"
+save_dir = "examples/imdb/results/" + name
+params_dir = "examples/imdb/results/map" # directory to load state containing initialisation params
+
+prior_sd = torch.inf
+batch_size = 32
+burnin = None
+save_frequency = None
+
+method = posteriors.laplace.diag_fisher
+config_args = {} # arguments for method.build (aside from log_posterior)
+log_metrics = {} # dict containing names of metrics as keys and their paths in state as values
+display_metric = "loss" # metric to display in tqdm progress bar
+
+log_frequency = 100 # frequency at which to log metrics
+log_window = 30 # window size for moving average
+
+n_test_samples = 50
+n_linearised_test_samples = 10000
+epsilon = 1e-3 # small value to avoid division by zero in to_sd_diag
+
+
+def to_sd_diag(state, temperature=1.0):
+ return tree_map(lambda x: torch.sqrt(temperature / (x + epsilon)), state.prec_diag)
+
+
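+# Forward functions return logits with shape [batch, num_samples, num_classes];
+# test.py averages the implied probabilities over the sample dimension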
+def forward(model, state, batch, temperature=1.0):
+ x, _ = batch
+ sd_diag = to_sd_diag(state, temperature)
+
+ sampled_params = posteriors.diag_normal_sample(
+ state.params, sd_diag, (n_test_samples,)
+ )
+
+ def model_func(p, x):
+ return torch.func.functional_call(model, p, x)
+
+ logits = torch.vmap(model_func, in_dims=(0, None))(sampled_params, x).transpose(
+ 0, 1
+ )
+ return logits
+
+
+def forward_linearized(model, state, batch, temperature=1.0):
+ x, _ = batch
+ sd_diag = to_sd_diag(state, temperature)
+
+ def model_func_with_aux(p, x):
+ return torch.func.functional_call(model, p, x), torch.tensor([])
+
+ lin_mean, lin_chol, _ = posteriors.linearized_forward_diag(
+ model_func_with_aux,
+ state.params,
+ x,
+ sd_diag,
+ )
+
+ samps = torch.randn(
+ lin_mean.shape[0],
+ n_linearised_test_samples,
+ lin_mean.shape[1],
+ device=lin_mean.device,
+ )
+ lin_logits = lin_mean.unsqueeze(1) + samps @ lin_chol.transpose(-1, -2)
+ return lin_logits
+
+
+forward_dict = {"Laplace EF": forward, "Laplace EF Linearized": forward_linearized}
diff --git a/examples/imdb/configs/laplace_ggn.py b/examples/imdb/configs/laplace_ggn.py
new file mode 100644
index 00000000..f06760ea
--- /dev/null
+++ b/examples/imdb/configs/laplace_ggn.py
@@ -0,0 +1,72 @@
+import torch
+import posteriors
+from optree import tree_map
+
+name = "laplace_ggn"
+save_dir = "examples/imdb/results/" + name
+params_dir = "examples/imdb/results/map" # directory to load state containing initialisation params
+
+prior_sd = torch.inf
+batch_size = 32
+burnin = None
+save_frequency = None
+
+method = posteriors.laplace.diag_ggn
+config_args = {} # arguments for method.build (aside from log_posterior)
+log_metrics = {} # dict containing names of metrics as keys and their paths in state as values
+display_metric = "loss" # metric to display in tqdm progress bar
+
+log_frequency = 100 # frequency at which to log metrics
+log_window = 30 # window size for moving average
+
+n_test_samples = 50
+n_linearised_test_samples = 10000
+epsilon = 1e-3 # small value to avoid division by zero in to_sd_diag
+
+
+def to_sd_diag(state, temperature=1.0):
+ return tree_map(lambda x: torch.sqrt(temperature / (x + epsilon)), state.prec_diag)
+
+
+def forward(model, state, batch, temperature=1.0):
+ x, _ = batch
+ sd_diag = to_sd_diag(state, temperature)
+
+ sampled_params = posteriors.diag_normal_sample(
+ state.params, sd_diag, (n_test_samples,)
+ )
+
+ def model_func(p, x):
+ return torch.func.functional_call(model, p, x)
+
+ logits = torch.vmap(model_func, in_dims=(0, None))(sampled_params, x).transpose(
+ 0, 1
+ )
+ return logits
+
+
+def forward_linearized(model, state, batch, temperature=1.0):
+ x, _ = batch
+ sd_diag = to_sd_diag(state, temperature)
+
+ def model_func_with_aux(p, x):
+ return torch.func.functional_call(model, p, x), torch.tensor([])
+
+ lin_mean, lin_chol, _ = posteriors.linearized_forward_diag(
+ model_func_with_aux,
+ state.params,
+ x,
+ sd_diag,
+ )
+
+ samps = torch.randn(
+ lin_mean.shape[0],
+ n_linearised_test_samples,
+ lin_mean.shape[1],
+ device=lin_mean.device,
+ )
+ lin_logits = lin_mean.unsqueeze(1) + samps @ lin_chol.transpose(-1, -2)
+ return lin_logits
+
+
+forward_dict = {"Laplace GGN": forward, "Laplace GGN Linearized": forward_linearized}
diff --git a/examples/imdb/configs/map.py b/examples/imdb/configs/map.py
new file mode 100644
index 00000000..02813e11
--- /dev/null
+++ b/examples/imdb/configs/map.py
@@ -0,0 +1,36 @@
+import posteriors
+import torchopt
+import torch
+
+name = "map"
+save_dir = "examples/imdb/results/" + name
+params_dir = None # directory to load state containing initialisation params
+
+
+prior_sd = (1 / 40) ** 0.5
+batch_size = 32
+burnin = None
+save_frequency = None
+
+
+method = posteriors.torchopt
+config_args = {
+ "optimizer": torchopt.adamw(lr=1e-3, maximize=True)
+} # arguments for method.build (aside from log_posterior)
+log_metrics = {
+ "log_post": "loss",
+} # dict containing names of metrics as keys and their paths in state as values
+display_metric = "log_post" # metric to display in tqdm progress bar
+
+
+log_frequency = 100 # frequency at which to log metrics
+log_window = 30 # window size for moving average
+
+
+def forward(model, state, batch):
+ x, _ = batch
+ logits = torch.func.functional_call(model, state.params, x)
+ return logits.unsqueeze(1)
+
+
+forward_dict = {"MAP": forward}
diff --git a/examples/imdb/configs/mle.py b/examples/imdb/configs/mle.py
new file mode 100644
index 00000000..bfb9e5a1
--- /dev/null
+++ b/examples/imdb/configs/mle.py
@@ -0,0 +1,36 @@
+import posteriors
+import torchopt
+import torch
+
+name = "mle"
+save_dir = "examples/imdb/results/" + name
+params_dir = None # directory to load state containing initialisation params
+
+
+prior_sd = torch.inf
+batch_size = 32
+burnin = None
+save_frequency = None
+
+
+method = posteriors.torchopt
+config_args = {
+ "optimizer": torchopt.adamw(lr=1e-3, maximize=True)
+} # arguments for method.build (aside from log_posterior)
+log_metrics = {
+ "log_post": "loss",
+} # dict containing names of metrics as keys and their paths in state as values
+display_metric = "log_post" # metric to display in tqdm progress bar
+
+
+log_frequency = 100 # frequency at which to log metrics
+log_window = 30 # window size for moving average
+
+
+def forward(model, state, batch):
+ x, _ = batch
+ logits = torch.func.functional_call(model, state.params, x)
+ return logits.unsqueeze(1)
+
+
+forward_dict = {"MLE": forward}
diff --git a/examples/imdb/configs/sghmc_parallel.py b/examples/imdb/configs/sghmc_parallel.py
new file mode 100644
index 00000000..e7ed818c
--- /dev/null
+++ b/examples/imdb/configs/sghmc_parallel.py
@@ -0,0 +1,42 @@
+import posteriors
+import torch
+
+name = "sghmc_parallel"
+save_dir = "examples/imdb/results/sghmc_parallel/" + name
+params_dir = None # directory to load state containing initialisation params
+
+prior_sd = (1 / 40) ** 0.5
+batch_size = 32
+burnin = None
+save_frequency = None
+
+lr = 1e-1
+
+method = posteriors.sgmcmc.sghmc
+config_args = {
+ "lr": lr,
+ "alpha": 1.0,
+ "beta": 0.0,
+ "temperature": None, # None temperature gets set by train.py
+ "momenta": 0.0,
+} # arguments for method.build (aside from log_posterior)
+log_metrics = {
+ "log_post": "log_posterior",
+} # dict containing names of metrics as keys and their paths in state as values
+display_metric = "log_post" # metric to display in tqdm progress bar
+
+log_frequency = 100 # frequency at which to log metrics
+log_window = 30 # window size for moving average
+
+
+def forward(model, state, batch):
+ x, _ = batch
+
+ def model_func(p, x):
+ return torch.func.functional_call(model, p, x)
+
+ logits = torch.vmap(model_func, in_dims=(0, None))(state.params, x).transpose(0, 1)
+ return logits
+
+
+forward_dict = {"Parallel SGHMC": forward}
diff --git a/examples/imdb/configs/sghmc_serial.py b/examples/imdb/configs/sghmc_serial.py
new file mode 100644
index 00000000..aec7e4e8
--- /dev/null
+++ b/examples/imdb/configs/sghmc_serial.py
@@ -0,0 +1,43 @@
+import posteriors
+import torch
+
+name = "sghmc_serial"
+save_dir = "examples/imdb/results/" + name
+params_dir = None # directory to load state containing initialisation params
+
+
+prior_sd = (1 / 40) ** 0.5
+batch_size = 32
+burnin = 20000
+save_frequency = 1000
+
+lr = 1e-1
+
+method = posteriors.sgmcmc.sghmc
+config_args = {
+ "lr": lr,
+ "alpha": 1.0,
+ "beta": 0.0,
+ "temperature": None, # None temperature gets set by train.py
+ "momenta": 0.0,
+} # arguments for method.build (aside from log_posterior)
+log_metrics = {
+ "log_post": "log_posterior",
+} # dict containing names of metrics as keys and their paths in state as values
+display_metric = "log_post" # metric to display in tqdm progress bar
+
+log_frequency = 100 # frequency at which to log metrics
+log_window = 30 # window size for moving average
+
+
+def forward(model, state, batch):
+ x, _ = batch
+
+ def model_func(p, x):
+ return torch.func.functional_call(model, p, x)
+
+ logits = torch.vmap(model_func, in_dims=(0, None))(state.params, x).transpose(0, 1)
+ return logits
+
+
+forward_dict = {"Serial SGHMC": forward}
diff --git a/examples/imdb/configs/vi.py b/examples/imdb/configs/vi.py
new file mode 100644
index 00000000..082ddadb
--- /dev/null
+++ b/examples/imdb/configs/vi.py
@@ -0,0 +1,81 @@
+import posteriors
+import torchopt
+import torch
+from optree import tree_map
+
+name = "vi"
+save_dir = "examples/imdb/results/" + name
+params_dir = None # directory to load state containing initialisation params
+
+
+prior_sd = (1 / 40) ** 0.5
+batch_size = 32
+burnin = None
+save_frequency = None
+
+method = posteriors.vi.diag
+config_args = {
+ "optimizer": torchopt.adamw(lr=1e-3),
+ "temperature": None, # None temperature gets set by train.py
+ "n_samples": 1,
+ "stl": True,
+ "init_log_sds": -3,
+} # arguments for method.build (aside from log_posterior)
+log_metrics = {
+ "nelbo": "nelbo",
+} # dict containing names of metrics as keys and their paths in state as values
+display_metric = "nelbo" # metric to display in tqdm progress bar
+
+log_frequency = 100 # frequency at which to log metrics
+log_window = 30 # window size for moving average
+
+n_test_samples = 50
+n_linearised_test_samples = 10000
+
+
+def to_sd_diag(state):
+ return tree_map(lambda x: x.exp(), state.log_sd_diag)
+
+
+def forward(model, state, batch):
+ x, _ = batch
+ sd_diag = to_sd_diag(state)
+
+ sampled_params = posteriors.diag_normal_sample(
+ state.params, sd_diag, (n_test_samples,)
+ )
+
+ def model_func(p, x):
+ return torch.func.functional_call(model, p, x)
+
+ logits = torch.vmap(model_func, in_dims=(0, None))(sampled_params, x).transpose(
+ 0, 1
+ )
+ return logits
+
+
+def forward_linearized(model, state, batch):
+ x, _ = batch
+ sd_diag = to_sd_diag(state)
+
+ def model_func_with_aux(p, x):
+ return torch.func.functional_call(model, p, x), torch.tensor([])
+
+ lin_mean, lin_chol, _ = posteriors.linearized_forward_diag(
+ model_func_with_aux,
+ state.params,
+ x,
+ sd_diag,
+ )
+
+ samps = torch.randn(
+ lin_mean.shape[0],
+ n_linearised_test_samples,
+ lin_mean.shape[1],
+ device=lin_mean.device,
+ )
+ lin_logits = lin_mean.unsqueeze(1) + samps @ lin_chol.transpose(-1, -2)
+ return lin_logits
+
+
+forward_dict = {"VI": forward, "VI Linearized": forward_linearized}
diff --git a/examples/imdb/data.py b/examples/imdb/data.py
new file mode 100644
index 00000000..339655c5
--- /dev/null
+++ b/examples/imdb/data.py
@@ -0,0 +1,27 @@
+import torch
+from torch.utils.data import DataLoader, TensorDataset
+from keras.datasets import imdb
+from keras.preprocessing.sequence import pad_sequences
+
+
+def load_imdb_dataset(batch_size=32, max_features=20000, max_len=100):
+ # Load and pad IMDB dataset
+ (x_train, y_train), (x_test, y_test) = imdb.load_data(num_words=max_features)
+ x_train = pad_sequences(x_train, maxlen=max_len)
+ x_test = pad_sequences(x_test, maxlen=max_len)
+
+ # Convert data to PyTorch tensors
+ train_data = torch.tensor(x_train, dtype=torch.long)
+ train_labels = torch.tensor(y_train, dtype=torch.long)
+ test_data = torch.tensor(x_test, dtype=torch.long)
+ test_labels = torch.tensor(y_test, dtype=torch.long)
+
+ # Create Tensor datasets
+ train_dataset = TensorDataset(train_data, train_labels)
+ test_dataset = TensorDataset(test_data, test_labels)
+
+ # Create DataLoaders
+ train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
+ test_dataloader = DataLoader(test_dataset, batch_size=batch_size)
+
+ return train_dataloader, test_dataloader
diff --git a/examples/imdb/lstm.py b/examples/imdb/lstm.py
new file mode 100644
index 00000000..7804cbc4
--- /dev/null
+++ b/examples/imdb/lstm.py
@@ -0,0 +1,129 @@
+import torch
+import torch.nn as nn
+
+
+class CustomLSTMCell(nn.Module):
+ def __init__(self, input_size, hidden_size):
+ super(CustomLSTMCell, self).__init__()
+ self.input_size = input_size
+ self.hidden_size = hidden_size
+ self.forget_gate = nn.Linear(input_size + hidden_size, hidden_size)
+ self.input_gate = nn.Linear(input_size + hidden_size, hidden_size)
+ self.cell_gate = nn.Linear(input_size + hidden_size, hidden_size)
+ self.output_gate = nn.Linear(input_size + hidden_size, hidden_size)
+
+ def forward(self, input, states):
+ h_prev, c_prev = states
+ combined = torch.cat((input, h_prev), 1)
+
+ f_t = torch.sigmoid(self.forget_gate(combined))
+ i_t = torch.sigmoid(self.input_gate(combined))
+ c_tilde = torch.tanh(self.cell_gate(combined))
+ c_t = f_t * c_prev + i_t * c_tilde
+ o_t = torch.sigmoid(self.output_gate(combined))
+ h_t = o_t * torch.tanh(c_t)
+
+ return h_t, c_t
+
+
+class CustomLSTM(nn.Module):
+ def __init__(self, input_size, hidden_size, batch_first=False):
+ super(CustomLSTM, self).__init__()
+ self.input_size = input_size
+ self.hidden_size = hidden_size
+ self.batch_first = batch_first
+ self.lstm_cell = CustomLSTMCell(input_size, hidden_size)
+
+ def forward(self, input, initial_states=None):
+ if initial_states is None:
+ initial_h = torch.zeros(
+ 1,
+ input.size(0 if self.batch_first else 1),
+ self.hidden_size,
+ device=input.device,
+ )
+ initial_c = torch.zeros(
+ 1,
+ input.size(0 if self.batch_first else 1),
+ self.hidden_size,
+ device=input.device,
+ )
+ else:
+ initial_h, initial_c = initial_states
+
+ # Ensure we are working with single layer, single direction states
+ initial_h = initial_h.squeeze(0)
+ initial_c = initial_c.squeeze(0)
+
+ if self.batch_first:
+ input = input.transpose(
+ 0, 1
+ ) # Convert (batch, seq_len, feature) to (seq_len, batch, feature)
+
+ outputs = []
+ h_t, c_t = initial_h, initial_c
+
+ for i in range(
+ input.shape[0]
+ ): # input is expected to be (seq_len, batch, input_size)
+ h_t, c_t = self.lstm_cell(input[i], (h_t, c_t))
+ outputs.append(h_t.unsqueeze(0))
+
+ outputs = torch.cat(outputs, 0)
+
+ if self.batch_first:
+ outputs = outputs.transpose(
+ 0, 1
+ ) # Convert back to (batch, seq_len, feature)
+
+ return outputs, (h_t, c_t)
+
+
+# Test equivalence
+def test_lstm_equivalence():
+ input_size = 10
+ hidden_size = 20
+ seq_len = 5
+ batch_size = 1
+
+ # Initialize both LSTMs
+ torch_lstm = nn.LSTM(input_size, hidden_size, batch_first=True)
+ custom_lstm = CustomLSTM(input_size, hidden_size, batch_first=True)
+
+ # Manually setting the same weights and biases
+ with torch.no_grad():
+ # Copy weights and biases from torch LSTM to custom LSTM
+ gates = ["input_gate", "forget_gate", "cell_gate", "output_gate"]
+
+ for idx, gate in enumerate(gates):
+ start = idx * hidden_size
+ end = (idx + 1) * hidden_size
+
+ getattr(custom_lstm.lstm_cell, gate).weight.data[:, :input_size].copy_(
+ torch_lstm.weight_ih_l0[start:end]
+ )
+ getattr(custom_lstm.lstm_cell, gate).weight.data[:, input_size:].copy_(
+ torch_lstm.weight_hh_l0[start:end]
+ )
+ getattr(custom_lstm.lstm_cell, gate).bias.data.copy_(
+ torch_lstm.bias_ih_l0[start:end] + torch_lstm.bias_hh_l0[start:end]
+ )
+ # Dummy input
+ inputs = torch.randn(batch_size, seq_len, input_size)
+
+ # Custom LSTM forward pass
+ custom_outputs, (custom_hn, custom_cn) = custom_lstm(inputs)
+
+ # Torch LSTM forward pass
+ torch_outputs, (torch_hn, torch_cn) = torch_lstm(inputs)
+
+ # Check outputs and final hidden and cell states
+ assert torch.allclose(custom_outputs, torch_outputs, atol=1e-6), "Output mismatch"
+ assert torch.allclose(custom_hn, torch_hn), "Hidden state mismatch"
+ assert torch.allclose(custom_cn, torch_cn), "Cell state mismatch"
+
+ print("Test passed: Custom LSTM and torch.nn.LSTM outputs are equivalent!")
+
+
+if __name__ == "__main__":
+ test_lstm_equivalence()
diff --git a/examples/imdb/model.py b/examples/imdb/model.py
new file mode 100644
index 00000000..651d8894
--- /dev/null
+++ b/examples/imdb/model.py
@@ -0,0 +1,72 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+# torch.nn.LSTM does not work well with torch.func https://github.com/pytorch/pytorch/issues/105982
+# so we use a custom simplified LSTM instead
+from examples.imdb.lstm import CustomLSTM
+
+
+class CNNLSTM(nn.Module):
+ def __init__(
+ self,
+ num_classes,
+ max_features=20000,
+ embedding_size=128,
+ cell_size=128,
+ num_filters=64,
+ kernel_size=5,
+ pool_size=4,
+ use_swish=False,
+ use_maxpool=True,
+ ):
+ super(CNNLSTM, self).__init__()
+ self.embedding = nn.Embedding(
+ num_embeddings=max_features, embedding_dim=embedding_size
+ )
+ self.conv1d = nn.Conv1d(
+ in_channels=embedding_size,
+ out_channels=num_filters,
+ kernel_size=kernel_size,
+ )
+ self.use_swish = use_swish
+ self.use_maxpool = use_maxpool
+ if use_maxpool:
+ self.maxpool = nn.MaxPool1d(kernel_size=pool_size, stride=pool_size)
+ self.lstm = CustomLSTM(
+ input_size=num_filters, hidden_size=cell_size, batch_first=True
+ )
+
+ self.fc = nn.Linear(in_features=cell_size, out_features=num_classes)
+
+ def forward(self, x):
+ # Embedding
+ x = self.embedding(x) # Shape: [batch_size, seq_length, embedding_size]
+ x = x.permute(
+ 0, 2, 1
+ ) # Shape: [batch_size, embedding_size, seq_length] to match Conv1d input
+
+ # Convolution
+ x = self.conv1d(x)
+ if self.use_swish:
+ x = x * torch.sigmoid(x) # Swish activation function
+ else:
+ x = F.relu(x)
+
+ # Pooling
+ if self.use_maxpool:
+ x = self.maxpool(x)
+
+ # Reshape for LSTM
+ x = x.permute(0, 2, 1) # Shape: [batch_size, seq_length, num_filters]
+
+ # LSTM
+ output, _ = self.lstm(x)
+
+ # Take the last sequence output
+ last_output = output[:, -1, :]
+
+ # Fully connected layer
+ logits = self.fc(last_output)
+
+ return logits
diff --git a/examples/imdb/plot.py b/examples/imdb/plot.py
new file mode 100644
index 00000000..1c8eeec1
--- /dev/null
+++ b/examples/imdb/plot.py
@@ -0,0 +1,204 @@
+import numpy as np
+import matplotlib.pyplot as plt
+from matplotlib.ticker import FormatStrFormatter
+import json
+import os
+
+plt.rcParams["font.family"] = "Times New Roman"
+
+
+temperatures = [0.03, 0.1, 0.3, 1.0, 3.0]
+temp_strs = [str(temp).replace(".", "-") for temp in temperatures]
+seeds = list(range(1, 6))
+
+
+# [base directory, tempered bool, colors]
+bases = [
+ # ["examples/imdb/results/mle", False, ["grey"]],
+ ["examples/imdb/results/map", False, ["grey"]],
+ ["examples/imdb/results/laplace_fisher", True, ["royalblue", "deepskyblue"]],
+ ["examples/imdb/results/laplace_ggn", True, ["purple", "mediumvioletred"]],
+]
+
+# bases = [
+# ["examples/imdb/results/mle", False, ["grey"]],
+# ["examples/imdb/results/map", False, ["grey"]],
+# ["examples/imdb/results/vi", True, ["forestgreen", "darkkhaki"]],
+# ]
+
+# bases = [
+# # ["examples/imdb/results/mle", False, ["grey"]],
+# ["examples/imdb/results/map", False, ["grey"]],
+# ["examples/imdb/results/sghmc_serial", True, ["firebrick"]],
+# ["examples/imdb/results/sghmc_parallel", True, ["tomato"]],
+# ]
+
+with_mle = "mle" in bases[0][0]
+
+if with_mle:
+ ylims = {"loss": (0.305, 1.2), "accuracy": (0.47, 0.88)}
+else:
+ ylims = {"loss": (0.305, 0.72), "accuracy": (0.47, 0.88)}
+
+save_name = bases[-1][0].split("/")[-1].split("_")[0]
+
+
+test_dicts = {}
+colour_dict = {}
+
+for base, tempered, colours in bases:
+ for seed in seeds:
+ seed_base = base + "_seed" + str(seed) if seed is not None else base
+ match_str = seed_base + "_temp" + temp_strs[0] if tempered else seed_base
+ versions = [file for file in os.listdir(match_str) if "test" in file]
+ versions = [v.strip(".json").split("_")[1] for v in versions]
+ versions.sort()
+
+ for k, name in enumerate(versions):
+ colour_dict[name] = colours[k]
+ if tempered:
+ for temp in temperatures:
+ if temp not in test_dicts:
+ key = f"{name}_temp{str(temp).replace('.', '-')}_seed{seed}"
+ test_dicts[key] = json.load(
+ open(
+ f"{seed_base}_temp{str(temp).replace('.', '-')}/test_{name}.json"
+ )
+ )
+ else:
+ key = f"{name}_tempNA_seed{seed}"
+ test_dicts[key] = json.load(open(f"{seed_base}/test_{name}.json"))
+
+
+line_styles = ["--", ":", "-.", "-"]
+
+
+def sorted_set(l_in):
+ return sorted(set(l_in), key=l_in.index)
+
+
+method_names = sorted_set([key.split("_")[0] for key in test_dicts.keys()])
+
+
+def metric_vals_dict(metric_name):
+ md = {}
+
+ for method_name in method_names:
+ method_name_keys = [
+ key for key in test_dicts.keys() if key.split("_")[0] == method_name
+ ]
+
+ temperature_keys = sorted_set([key.split("_")[1] for key in method_name_keys])
+
+ method_dict = {}
+
+ for tk in temperature_keys:
+ temp = tk.strip("temp").replace("-", ".")
+
+ method_dict[temp] = [
+ np.mean(test_dicts[key][metric_name])
+ for key in method_name_keys
+ if key.split("_")[1] == tk
+ ]
+
+ md |= {method_name: method_dict}
+ return md
+
+
+metrics = ["loss", "accuracy"]
+metric_dicts = {metric_name: metric_vals_dict(metric_name) for metric_name in metrics}
+
+
+def plot_metric(metric_name):
+ metric_dict = metric_dicts[metric_name]
+
+ fig, ax = plt.subplots()
+ k = 0
+
+ linewidth = 2 if with_mle else 3
+
+ for method_name, method_dict in metric_dict.items():
+ if len(method_dict) == 1:
+ mn = np.mean(method_dict["NA"])
+ sds = np.std(method_dict["NA"])
+ ax.fill_between(
+ [0.0, temperatures[-1] * 10],
+ [mn - sds] * 2,
+ [mn + sds] * 2,
+ color=colour_dict[method_name],
+ alpha=0.2,
+ zorder=0,
+ )
+ ax.axhline(
+ np.mean(method_dict["NA"]),
+ label=method_name,
+ c=colour_dict[method_name],
+ linestyle=line_styles[k],
+ linewidth=linewidth,
+ zorder=1,
+ )
+
+ k += 1
+
+ else:
+ mns = np.array([np.mean(method_dict[temp]) for temp in method_dict.keys()])
+ sds = np.array([np.std(method_dict[temp]) for temp in method_dict.keys()])
+
+ ax.fill_between(
+ temperatures,
+ mns - sds,
+ mns + sds,
+ color=colour_dict[method_name],
+ alpha=0.2,
+ zorder=0,
+ )
+ ax.plot(
+ temperatures,
+ mns,
+ label=method_name,
+ marker="o",
+ markersize=10,
+ c=colour_dict[method_name],
+ linewidth=linewidth,
+ zorder=1,
+ )
+
+ if metric_name in ylims:
+ ax.set_ylim(ylims[metric_name])
+
+ ax.set_xlim(temperatures[0] * 0.9, temperatures[-1] * 1.1)
+
+ ax.set_xscale("log")
+
+ fontsize = 18
+
+ ax.set_xticks(temperatures)
+ ax.tick_params(axis="both", which="major", labelsize=fontsize * 0.75)
+
+ # Change xtick labels to scalar format
+ ax.xaxis.set_major_formatter(FormatStrFormatter("%g"))
+
+ ax.set_xlabel("Temperature", fontsize=fontsize)
+ ax.set_ylabel("Test " + metric_name.title(), fontsize=fontsize)
+
+ leg_fontsize = fontsize * 0.75 if with_mle else fontsize
+ ax.legend(
+ frameon=True,
+ framealpha=1.0,
+ facecolor="white",
+ edgecolor="white",
+ fontsize=leg_fontsize,
+ )
+ fig.tight_layout()
+
+ save_dir = (
+ f"examples/imdb/figures/{save_name}_{metric_name}_with_mle.png"
+ if with_mle
+ else f"examples/imdb/figures/{save_name}_{metric_name}.png"
+ )
+ fig.savefig(save_dir, dpi=400)
+ plt.close()
+
+
+plot_metric("loss")
+plot_metric("accuracy")
diff --git a/examples/imdb/test.py b/examples/imdb/test.py
new file mode 100644
index 00000000..65aa1d0c
--- /dev/null
+++ b/examples/imdb/test.py
@@ -0,0 +1,117 @@
+import argparse
+import pickle
+import importlib
+from tqdm import tqdm
+import torch
+from torch.distributions import Categorical
+from optree import tree_map
+import os
+
+from examples.imdb.model import CNNLSTM
+from examples.imdb.data import load_imdb_dataset
+from examples.imdb.utils import log_metrics
+
+# Get config path and device from user
+parser = argparse.ArgumentParser()
+parser.add_argument("--config", type=str)
+parser.add_argument("--device", default="cpu", type=str)
+parser.add_argument("--seed", default=42, type=int)
+parser.add_argument("--temperature", default=None, type=float)
+args = parser.parse_args()
+
+
+torch.manual_seed(args.seed)
+
+# Import configuration
+config = importlib.import_module(args.config.replace("/", ".").replace(".py", ""))
+config_dir = os.path.dirname(args.config)
+save_dir = (
+ config_dir
+ if args.temperature is None
+ else config_dir + f"_temp{str(args.temperature).replace('.', '-')}"
+)
+if not os.path.exists(save_dir):
+ os.makedirs(save_dir)
+
+# Load data
+_, test_dataloader = load_imdb_dataset()
+
+# Load model
+model = CNNLSTM(num_classes=2)
+model.to(args.device)
+
+# Load state
+state = pickle.load(open(config_dir + "/state.pkl", "rb"))
+state = tree_map(lambda x: x.to(args.device), state)
+
+# Dictionary containing forward functions
+forward_dict = config.forward_dict
+
+
+def test_metrics(logits, labels):
+ probs = torch.nn.functional.softmax(logits, dim=-1)
+ probs = torch.where(probs < 1e-5, 1e-5, probs)
+ probs /= probs.sum(dim=-1, keepdim=True)
+ expected_probs = probs.mean(dim=1)
+
+ logits = torch.log(probs)
+ expected_logits = torch.log(expected_probs)
+
+ loss = -Categorical(logits=expected_logits).log_prob(labels)
+ accuracy = (expected_probs.argmax(dim=-1) == labels).float()
+
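+    # Uncertainty decomposition: total = entropy of the sample-averaged predictive,
+    # aleatoric = average entropy of the per-sample predictives, epistemic = difference
+    # (entropies here are averaged, rather than summed, over classes)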
+ total_uncertainty = -(torch.log(expected_probs) * expected_probs).mean(1)
+ aleatoric_uncertainty = -(torch.log(probs) * probs).mean(2).mean(1)
+ epistemic_uncertainty = total_uncertainty - aleatoric_uncertainty
+
+ return {
+ "loss": loss,
+ "accuracy": accuracy,
+ "total_uncertainty": total_uncertainty,
+ "aleatoric_uncertainty": aleatoric_uncertainty,
+ "epistemic_uncertainty": epistemic_uncertainty,
+ }
+
+
+# Run through test data
+num_batches = len(test_dataloader)
+metrics = [
+ "loss",
+ "accuracy",
+ "total_uncertainty",
+ "aleatoric_uncertainty",
+ "epistemic_uncertainty",
+]
+
+log_dict_forward = {k: {m: [] for m in metrics} for k in forward_dict.keys()}
+
+for batch in tqdm(test_dataloader):
+ with torch.no_grad():
+ batch = tree_map(lambda x: x.to(args.device), batch)
+ labels = batch[1]
+
+ forward_logits = {
+ k: v(model, state, batch)
+ if args.temperature is None
+ else v(model, state, batch, args.temperature)
+ for k, v in forward_dict.items()
+ }
+
+ forward_metrics = {
+ k: test_metrics(logits, labels) for k, logits in forward_logits.items()
+ }
+
+ for forward_k in log_dict_forward.keys():
+ for metric_k in metrics:
+ log_dict_forward[forward_k][metric_k] += forward_metrics[forward_k][
+ metric_k
+ ].tolist()
+
+
+for forward_k, metric_dict in log_dict_forward.items():
+ log_metrics(
+ metric_dict,
+ save_dir,
+ file_name="test_" + forward_k,
+ plot=False,
+ )
diff --git a/examples/imdb/test_runner.sh b/examples/imdb/test_runner.sh
new file mode 100644
index 00000000..3fefc863
--- /dev/null
+++ b/examples/imdb/test_runner.sh
@@ -0,0 +1,70 @@
+#!/bin/bash
+
+device="cuda:1"
+
+# List of temperature params
+temperatures=(0.03 0.1 0.3 1.0 3.0)
+seeds=$(seq 1 5)
+
+# Base directories
+base_configs=(
+ # "examples/imdb/results/vi"
+ # "examples/imdb/results/sghmc_serial"
+ "examples/imdb/results/sghmc_parallel"
+)
+
+for temperature in "${temperatures[@]}"; do
+ for seed in $seeds; do
+ for base_config in "${base_configs[@]}"; do
+ config_file="${base_config}_seed${seed}_temp${temperature//./-}/config.py"
+ echo "Executing $config_file"
+ PYTHONPATH=. python examples/imdb/test.py --config "$config_file" \
+ --device $device
+ done
+ done
+done
+
+
+
+# # Laplace
+
+# temperatures=(0.03 0.1 0.3 1.0 3.0)
+# seeds=$(seq 2)
+
+
+# # Base directories
+# base_configs=(
+# "examples/imdb/results/laplace_fisher"
+# # "examples/imdb/results/laplace_ggn"
+# )
+
+# for temperature in "${temperatures[@]}"; do
+# for seed in $seeds; do
+# for base_config in "${base_configs[@]}"; do
+# config_file="${base_config}_seed${seed}/config.py"
+# echo "Executing $config_file"
+# PYTHONPATH=. python examples/imdb/test.py --config "$config_file" \
+# --device $device --temperature $temperature
+# done
+# done
+# done
+
+
+# # MLE and MAP
+# # List of temperature params
+# seeds=$(seq 1 5)
+
+# # Base directories
+# base_configs=(
+# "examples/imdb/results/mle"
+# "examples/imdb/results/map"
+# )
+
+# for seed in $seeds; do
+# for base_config in "${base_configs[@]}"; do
+# config_file="${base_config}_seed${seed}/config.py"
+# echo "Executing $config_file"
+# PYTHONPATH=. python examples/imdb/test.py --config "$config_file" \
+# --device $device
+# done
+# done
\ No newline at end of file
diff --git a/examples/imdb/train.py b/examples/imdb/train.py
new file mode 100644
index 00000000..1382ad85
--- /dev/null
+++ b/examples/imdb/train.py
@@ -0,0 +1,149 @@
+import os
+import argparse
+import pickle
+import importlib
+from tqdm import tqdm
+from optree import tree_map
+import torch
+import posteriors
+
+from examples.imdb.model import CNNLSTM
+from examples.imdb.data import load_imdb_dataset
+from examples.imdb import utils
+
+
+# Get args from user
+parser = argparse.ArgumentParser()
+parser.add_argument("--config", type=str)
+parser.add_argument("--device", default="cpu", type=str)
+parser.add_argument("--seed", default=42, type=int)
+parser.add_argument("--epochs", default=30, type=int)
+parser.add_argument("--temperature", default=1.0, type=float)
+args = parser.parse_args()
+
+
+# Import configuration
+config = importlib.import_module(args.config.replace("/", ".").replace(".py", ""))
+
+
+# Set seed
+if args.seed != 42:
+    config.save_dir += f"_seed{args.seed}"
+    if config.params_dir is not None:
+        config.params_dir += f"_seed{args.seed}/state.pkl"
+elif config.params_dir is not None:
+    config.params_dir += "/state.pkl"
+torch.manual_seed(args.seed)
+
+# Load model
+model = CNNLSTM(num_classes=2)
+model.to(args.device)
+
+if config.params_dir is not None:
+ with open(config.params_dir, "rb") as f:
+ print(config.params_dir)
+ state = pickle.load(f)
+ model.load_state_dict(state.params)
+
+# Load data
+train_dataloader, test_dataloader = load_imdb_dataset(batch_size=config.batch_size)
+num_data = len(train_dataloader.dataset)
+
+
+# Set temperature
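+# log_posterior below is normalized by num_data, so the target temperature T is
+# likewise divided by num_data to temper the full (unnormalized) posterior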
+if "temperature" in config.config_args and config.config_args["temperature"] is None:
+ config.config_args["temperature"] = args.temperature / num_data
+ temp_str = str(args.temperature).replace(".", "-")
+ config.save_dir += f"_temp{temp_str}"
+
+
+# Create save directory if it does not exist
+if not os.path.exists(config.save_dir):
+ os.makedirs(config.save_dir)
+
+# Save config
+utils.save_config(args, config.save_dir)
+print(f"Config saved to {config.save_dir}")
+
+# Extract model parameters
+params = dict(model.named_parameters())
+num_params = posteriors.tree_size(params).item()
+print(f"Number of parameters: {num_params/1e6:.3f}M")
+
+
+# Define log posterior
+def forward(p, batch):
+ x, y = batch
+ logits = torch.func.functional_call(model, p, x)
+ loss = torch.nn.functional.cross_entropy(logits, y)
+ return logits, (loss, logits)
+
+
+def outer_log_lik(logits, batch):
+ _, y = batch
+ return -torch.nn.functional.cross_entropy(logits, y, reduction="sum")
+
+
+def log_posterior(p, batch):
+ x, y = batch
+ logits = torch.func.functional_call(model, p, x)
+ loss = torch.nn.functional.cross_entropy(logits, y)
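+    # per-sample log posterior: mean log likelihood over the batch plus the log prior
+    # scaled by 1/num_data (matching the temperature scaling above)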
+ log_post = (
+ -loss
+ + posteriors.diag_normal_log_prob(p, sd_diag=config.prior_sd, normalize=False)
+ / num_data
+ )
+ return log_post, (loss, logits)
+
+
+# Build transform
+if config.method == posteriors.laplace.diag_ggn:
+ transform = config.method.build(forward, outer_log_lik, **config.config_args)
+else:
+ transform = config.method.build(log_posterior, **config.config_args)
+
+# Initialize state
+state = transform.init(params)
+
+# Train
+i = j = 0
+num_batches = len(train_dataloader)
+log_dict = {k: [] for k in config.log_metrics.keys()} | {"loss": []}
+log_bar = tqdm(total=0, position=1, bar_format="{desc}")
+for epoch in range(args.epochs):
+ for batch in tqdm(
+ train_dataloader, desc=f"Epoch {epoch+1}/{args.epochs}", position=0
+ ):
+ batch = tree_map(lambda x: x.to(args.device), batch)
+ state = transform.update(state, batch)
+
+ # Update metrics
+ log_dict = utils.append_metrics(log_dict, state, config.log_metrics)
+ log_bar.set_description_str(
+ f"{config.display_metric}: {log_dict[config.display_metric][-1]:.2f}"
+ )
+
+ # Log
+ i += 1
+ if i % config.log_frequency == 0 or i % num_batches == 0:
+ utils.log_metrics(
+ log_dict,
+ config.save_dir,
+ window=config.log_window,
+ file_name="training",
+ )
+
+ # Save sequential state if desired
+ if (
+ config.save_frequency is not None
+ and (i - config.burnin) >= 0
+ and (i - config.burnin) % config.save_frequency == 0
+ ):
+ with open(f"{config.save_dir}/state_{j}.pkl", "wb") as f:
+ pickle.dump(state, f)
+ j += 1
+
+# Save final state
+with open(f"{config.save_dir}/state.pkl", "wb") as f:
+ pickle.dump(state, f)
diff --git a/examples/imdb/train_runner.sh b/examples/imdb/train_runner.sh
new file mode 100644
index 00000000..a9772f95
--- /dev/null
+++ b/examples/imdb/train_runner.sh
@@ -0,0 +1,47 @@
+#!/bin/bash
+
+device="cuda:0"
+
+# temperatures=(0.03 0.1 0.3 1.0 3.0)
+# config="examples/imdb/configs/vi.py"
+# epochs=30
+# seeds=$(seq 1 5)
+
+# temperatures=(0.03 0.1 0.3 1.0 3.0)
+# config="examples/imdb/configs/sghmc_serial.py"
+# epochs=60
+# seeds=$(seq 1 5)
+
+temperatures=(0.03 0.1 0.3 1.0 3.0)
+config="examples/imdb/configs/sghmc_parallel.py"
+epochs=30
+seeds=$(seq 1 35)
+
+# temperatures=(1.0)
+# config="examples/imdb/configs/map.py"
+# epochs=30
+# seeds=$(seq 1 5)
+
+# temperatures=(1.0)
+# config="examples/imdb/configs/mle.py"
+# epochs=30
+# seeds=$(seq 1 5)
+
+# temperatures=(1.0)
+# config="examples/imdb/configs/laplace_fisher.py"
+# epochs=1
+# seeds=$(seq 1 5)
+
+# temperatures=(1.0)
+# config="examples/imdb/configs/laplace_ggn.py"
+# epochs=1
+# seeds=$(seq 1 5)
+
+for seed in $seeds
+do
+ for temp in "${temperatures[@]}"
+ do
+ PYTHONPATH=. python examples/imdb/train.py --config $config --device $device \
+ --temperature $temp --epochs $epochs --seed $seed
+ done
+done
\ No newline at end of file
diff --git a/examples/imdb/utils.py b/examples/imdb/utils.py
new file mode 100644
index 00000000..b5b14813
--- /dev/null
+++ b/examples/imdb/utils.py
@@ -0,0 +1,45 @@
+import json
+import matplotlib.pyplot as plt
+import numpy as np
+import shutil
+
+
+# Save config
+def save_config(args, save_dir):
+ shutil.copy(args.config, save_dir + "/config.py")
+ json.dump(vars(args), open(f"{save_dir}/args.json", "w"))
+
+
+# Function to calculate moving average and accompanying x-axis values
+def moving_average(x, w):
+ w_check = 1 if w >= len(x) else w
+ y = np.convolve(x, np.ones(w_check), "valid") / w_check
+ return range(w_check - 1, len(x)), y
+
+
+# Function to log and plot metrics
+def log_metrics(log_dict, save_dir, window=1, file_name="metrics", plot=True):
+ save_file = f"{save_dir}/{file_name}.json"
+
+ with open(save_file, "w") as f:
+ json.dump(log_dict, f)
+
+ if plot:
+ for k, v in log_dict.items():
+ fig, ax = plt.subplots()
+ x, y = moving_average(v, window)
+ ax.plot(x, y, label=k, alpha=0.7)
+ ax.legend()
+ ax.set_xlabel("Iteration")
+ fig.tight_layout()
+ fig.savefig(f"{save_dir}/{file_name}_{k}.png", dpi=200)
+ plt.close(fig)
+
+
+# Function to append metrics to log_dict
+def append_metrics(log_dict, state, config_dict):
+ for k, v in config_dict.items():
+ log_dict[k].append(getattr(state, v).mean().item())
+
+ log_dict["loss"].append(state.aux[0].mean().item())
+ return log_dict