From 57c46cd8762733c604e0b13b01357004e512031f Mon Sep 17 00:00:00 2001 From: SamDuffield Date: Fri, 10 May 2024 14:30:23 +0000 Subject: [PATCH 1/6] Add code for IMDB cold posterior --- examples/imdb/combine_states.py | 66 +++++++++++ examples/imdb/configs/laplace_fisher.py | 74 ++++++++++++ examples/imdb/configs/laplace_ggn.py | 74 ++++++++++++ examples/imdb/configs/map.py | 36 ++++++ examples/imdb/configs/mle.py | 36 ++++++ examples/imdb/configs/sghmc_parallel.py | 42 +++++++ examples/imdb/configs/sghmc_serial.py | 43 +++++++ examples/imdb/configs/vi.py | 81 +++++++++++++ examples/imdb/data.py | 27 +++++ examples/imdb/lstm.py | 129 ++++++++++++++++++++ examples/imdb/model.py | 72 ++++++++++++ examples/imdb/plot.py | 149 ++++++++++++++++++++++++ examples/imdb/test.py | 104 +++++++++++++++++ examples/imdb/test_runner.sh | 25 ++++ examples/imdb/train.py | 145 +++++++++++++++++++++++ examples/imdb/train_runner.sh | 29 +++++ examples/imdb/utils.py | 45 +++++++ 17 files changed, 1177 insertions(+) create mode 100644 examples/imdb/combine_states.py create mode 100644 examples/imdb/configs/laplace_fisher.py create mode 100644 examples/imdb/configs/laplace_ggn.py create mode 100644 examples/imdb/configs/map.py create mode 100644 examples/imdb/configs/mle.py create mode 100644 examples/imdb/configs/sghmc_parallel.py create mode 100644 examples/imdb/configs/sghmc_serial.py create mode 100644 examples/imdb/configs/vi.py create mode 100644 examples/imdb/data.py create mode 100644 examples/imdb/lstm.py create mode 100644 examples/imdb/model.py create mode 100644 examples/imdb/plot.py create mode 100644 examples/imdb/test.py create mode 100644 examples/imdb/test_runner.sh create mode 100644 examples/imdb/train.py create mode 100644 examples/imdb/train_runner.sh create mode 100644 examples/imdb/utils.py diff --git a/examples/imdb/combine_states.py b/examples/imdb/combine_states.py new file mode 100644 index 00000000..faa98b4e --- /dev/null +++ b/examples/imdb/combine_states.py @@ -0,0 +1,66 @@ +import pickle +import torch +import os +from optree import tree_map +import shutil + + +temperatures = [0.03, 0.1, 0.3, 1.0, 3.0] + + +# # SGHMC Serial +# base = "examples/imdb/results/sghmc_serial" +# seeds = [None] + + +# SGHMC Parallel +base = "examples/imdb/results/sghmc_parallel/sghmc_parallel" +seeds = list(range(1, 21)) + + +for temp in temperatures: + temp_str = str(temp).replace(".", "-") + + load_paths = [] + + for seed in seeds: + spec_base = base + + if seed is not None: + spec_base += f"_seed{seed}" + + spec_base += f"_temp{temp_str}/" + + match_str = ( + "state_" if seed is None else "state" + ) # don't need last save for serial + + load_paths += [ + spec_base + file for file in os.listdir(spec_base) if match_str in file + ] + + save_dir = base + f"_temp{temp_str}/" + + if not os.path.exists(save_dir): + os.makedirs(save_dir) + shutil.copy(spec_base + "config.py", save_dir + "/config.py") + + with open(f"{save_dir}/paths.txt", "w") as f: + f.write("\n".join(load_paths)) + + # Load states + states = [pickle.load(open(d, "rb")) for d in load_paths] + + # Delete auxiliary info + for s in states: + del s.aux + + # Move states to cpu + states = tree_map(lambda x: x.detach().to("cpu"), states) + + # Combine states + combined_state = tree_map(lambda *x: torch.stack(x), *states) + + # Save state + with open(f"{save_dir}state.pkl", "wb") as f: + pickle.dump(combined_state, f) diff --git a/examples/imdb/configs/laplace_fisher.py b/examples/imdb/configs/laplace_fisher.py new file mode 100644 index 
00000000..9fc95125 --- /dev/null +++ b/examples/imdb/configs/laplace_fisher.py @@ -0,0 +1,74 @@ +import torch +import posteriors +from optree import tree_map + +temperature = 1.0 + +name = "laplace_fisher" +save_dir = "examples/imdb/results/" + name +params_dir = "examples/imdb/results/map/state.pkl" # directory to load state containing initialisation params + +prior_sd = torch.inf +batch_size = 32 +burnin = None +save_frequency = None + +method = posteriors.laplace.diag_fisher +config_args = {} # arguments for method.build (aside from log_posterior) +log_metrics = {} # dict containing names of metrics as keys and their paths in state as values +display_metric = "loss" # metric to display in tqdm progress bar + +log_frequency = 100 # frequency at which to log metrics +log_window = 40 # window size for moving average + +n_test_samples = 50 +n_linearised_test_samples = 10000 +epsilon = 1e-3 # small value to avoid division by zero in to_sd_diag + + +def to_sd_diag(state): + return tree_map(lambda x: torch.sqrt(temperature / (x + epsilon)), state.prec_diag) + + +def forward(model, state, batch): + x, _ = batch + sd_diag = to_sd_diag(state) + + sampled_params = posteriors.diag_normal_sample( + state.params, sd_diag, (n_test_samples,) + ) + + def model_func(p, x): + return torch.func.functional_call(model, p, x) + + logits = torch.vmap(model_func, in_dims=(0, None))(sampled_params, x).transpose( + 0, 1 + ) + return logits + + +def forward_linearized(model, state, batch): + x, _ = batch + sd_diag = to_sd_diag(state) + + def model_func_with_aux(p, x): + return torch.func.functional_call(model, p, x), torch.tensor([]) + + lin_mean, lin_chol, _ = posteriors.linearized_forward_diag( + model_func_with_aux, + state.params, + x, + sd_diag, + ) + + samps = torch.randn( + lin_mean.shape[0], + n_linearised_test_samples, + lin_mean.shape[1], + device=lin_mean.device, + ) + lin_logits = lin_mean.unsqueeze(1) + samps @ lin_chol.transpose(-1, -2) + return lin_logits + + +forward_dict = {"Laplace EF": forward, "Laplace EF Linearized": forward_linearized} diff --git a/examples/imdb/configs/laplace_ggn.py b/examples/imdb/configs/laplace_ggn.py new file mode 100644 index 00000000..92a8628b --- /dev/null +++ b/examples/imdb/configs/laplace_ggn.py @@ -0,0 +1,74 @@ +import torch +import posteriors +from optree import tree_map + +temperature = 1.0 + +name = "laplace_ggn" +save_dir = "examples/imdb/results/" + name +params_dir = "examples/imdb/results/map/state.pkl" # directory to load state containing initialisation params + +prior_sd = torch.inf +batch_size = 32 +burnin = None +save_frequency = None + +method = posteriors.laplace.diag_ggn +config_args = {} # arguments for method.build (aside from log_posterior) +log_metrics = {} # dict containing names of metrics as keys and their paths in state as values +display_metric = "loss" # metric to display in tqdm progress bar + +log_frequency = 100 # frequency at which to log metrics +log_window = 40 # window size for moving average + +n_test_samples = 50 +n_linearised_test_samples = 10000 +epsilon = 1e-3 # small value to avoid division by zero in to_sd_diag + + +def to_sd_diag(state): + return tree_map(lambda x: torch.sqrt(temperature / (x + epsilon)), state.prec_diag) + + +def forward(model, state, batch): + x, _ = batch + sd_diag = to_sd_diag(state) + + sampled_params = posteriors.diag_normal_sample( + state.params, sd_diag, (n_test_samples,) + ) + + def model_func(p, x): + return torch.func.functional_call(model, p, x) + + logits = torch.vmap(model_func, in_dims=(0, 
None))(sampled_params, x).transpose( + 0, 1 + ) + return logits + + +def forward_linearized(model, state, batch): + x, _ = batch + sd_diag = to_sd_diag(state) + + def model_func_with_aux(p, x): + return torch.func.functional_call(model, p, x), torch.tensor([]) + + lin_mean, lin_chol, _ = posteriors.linearized_forward_diag( + model_func_with_aux, + state.params, + x, + sd_diag, + ) + + samps = torch.randn( + lin_mean.shape[0], + n_linearised_test_samples, + lin_mean.shape[1], + device=lin_mean.device, + ) + lin_logits = lin_mean.unsqueeze(1) + samps @ lin_chol.transpose(-1, -2) + return lin_logits + + +forward_dict = {"Laplace GGN": forward, "Laplace GGN Linearized": forward_linearized} diff --git a/examples/imdb/configs/map.py b/examples/imdb/configs/map.py new file mode 100644 index 00000000..02813e11 --- /dev/null +++ b/examples/imdb/configs/map.py @@ -0,0 +1,36 @@ +import posteriors +import torchopt +import torch + +name = "map" +save_dir = "examples/imdb/results/" + name +params_dir = None # directory to load state containing initialisation params + + +prior_sd = (1 / 40) ** 0.5 +batch_size = 32 +burnin = None +save_frequency = None + + +method = posteriors.torchopt +config_args = { + "optimizer": torchopt.adamw(lr=1e-3, maximize=True) +} # arguments for method.build (aside from log_posterior) +log_metrics = { + "log_post": "loss", +} # dict containing names of metrics as keys and their paths in state as values +display_metric = "log_post" # metric to display in tqdm progress bar + + +log_frequency = 100 # frequency at which to log metrics +log_window = 30 # window size for moving average + + +def forward(model, state, batch): + x, _ = batch + logits = torch.func.functional_call(model, state.params, x) + return logits.unsqueeze(1) + + +forward_dict = {"MAP": forward} diff --git a/examples/imdb/configs/mle.py b/examples/imdb/configs/mle.py new file mode 100644 index 00000000..bfb9e5a1 --- /dev/null +++ b/examples/imdb/configs/mle.py @@ -0,0 +1,36 @@ +import posteriors +import torchopt +import torch + +name = "mle" +save_dir = "examples/imdb/results/" + name +params_dir = None # directory to load state containing initialisation params + + +prior_sd = torch.inf +batch_size = 32 +burnin = None +save_frequency = None + + +method = posteriors.torchopt +config_args = { + "optimizer": torchopt.adamw(lr=1e-3, maximize=True) +} # arguments for method.build (aside from log_posterior) +log_metrics = { + "log_post": "loss", +} # dict containing names of metrics as keys and their paths in state as values +display_metric = "log_post" # metric to display in tqdm progress bar + + +log_frequency = 100 # frequency at which to log metrics +log_window = 30 # window size for moving average + + +def forward(model, state, batch): + x, _ = batch + logits = torch.func.functional_call(model, state.params, x) + return logits.unsqueeze(1) + + +forward_dict = {"MLE": forward} diff --git a/examples/imdb/configs/sghmc_parallel.py b/examples/imdb/configs/sghmc_parallel.py new file mode 100644 index 00000000..e7ed818c --- /dev/null +++ b/examples/imdb/configs/sghmc_parallel.py @@ -0,0 +1,42 @@ +import posteriors +import torch + +name = "sghmc_parallel" +save_dir = "examples/imdb/results/sghmc_parallel/" + name +params_dir = None # directory to load state containing initialisation params + +prior_sd = (1 / 40) ** 0.5 +batch_size = 32 +burnin = None +save_frequency = None + +lr = 1e-1 + +method = posteriors.sgmcmc.sghmc +config_args = { + "lr": lr, + "alpha": 1.0, + "beta": 0.0, + "temperature": None, # None temperature 
gets set by train.py + "momenta": 0.0, +} # arguments for method.build (aside from log_posterior) +log_metrics = { + "log_post": "log_posterior", +} # dict containing names of metrics as keys and their paths in state as values +display_metric = "log_post" # metric to display in tqdm progress bar + +log_frequency = 100 # frequency at which to log metrics +log_window = 30 # window size for moving average + + +def forward(model, state, batch): + x, _ = batch + + def model_func(p, x): + return torch.func.functional_call(model, p, x) + + logits = torch.vmap(model_func, in_dims=(0, None))(state.params, x).transpose(0, 1) + return logits + + +forward_dict = {"Parallel SGHMC": forward} diff --git a/examples/imdb/configs/sghmc_serial.py b/examples/imdb/configs/sghmc_serial.py new file mode 100644 index 00000000..aec7e4e8 --- /dev/null +++ b/examples/imdb/configs/sghmc_serial.py @@ -0,0 +1,43 @@ +import posteriors +import torch + +name = "sghmc_serial" +save_dir = "examples/imdb/results/" + name +params_dir = None # directory to load state containing initialisation params + + +prior_sd = (1 / 40) ** 0.5 +batch_size = 32 +burnin = 20000 +save_frequency = 1000 + +lr = 1e-1 + +method = posteriors.sgmcmc.sghmc +config_args = { + "lr": lr, + "alpha": 1.0, + "beta": 0.0, + "temperature": None, # None temperature gets set by train.py + "momenta": 0.0, +} # arguments for method.build (aside from log_posterior) +log_metrics = { + "log_post": "log_posterior", +} # dict containing names of metrics as keys and their paths in state as values +display_metric = "log_post" # metric to display in tqdm progress bar + +log_frequency = 100 # frequency at which to log metrics +log_window = 30 # window size for moving average + + +def forward(model, state, batch): + x, _ = batch + + def model_func(p, x): + return torch.func.functional_call(model, p, x) + + logits = torch.vmap(model_func, in_dims=(0, None))(state.params, x).transpose(0, 1) + return logits + + +forward_dict = {"Serial SGHMC": forward} diff --git a/examples/imdb/configs/vi.py b/examples/imdb/configs/vi.py new file mode 100644 index 00000000..082ddadb --- /dev/null +++ b/examples/imdb/configs/vi.py @@ -0,0 +1,81 @@ +import posteriors +import torchopt +import torch +from optree import tree_map + +name = "vi" +save_dir = "examples/imdb/results/" + name +params_dir = None # directory to load state containing initialisation params + + +prior_sd = (1 / 40) ** 0.5 +batch_size = 32 +burnin = None +save_frequency = None + +method = posteriors.vi.diag +config_args = { + "optimizer": torchopt.adamw(lr=1e-3), + "temperature": None, # None temperature gets set by train.py + "n_samples": 1, + "stl": True, + "init_log_sds": -3, +} # arguments for method.build (aside from log_posterior) +log_metrics = { + "nelbo": "nelbo", +} # dict containing names of metrics as keys and their paths in state as values +display_metric = "nelbo" # metric to display in tqdm progress bar + +log_frequency = 100 # frequency at which to log metrics +log_window = 30 # window size for moving average + +n_test_samples = 50 +n_linearised_test_samples = 10000 + + +def to_sd_diag(state): + return tree_map(lambda x: x.exp(), state.log_sd_diag) + + +def forward(model, state, batch): + x, _ = batch + sd_diag = to_sd_diag(state) + + sampled_params = posteriors.diag_normal_sample( + state.params, sd_diag, (n_test_samples,) + ) + + def model_func(p, x): + return torch.func.functional_call(model, p, x) + + logits = torch.vmap(model_func, in_dims=(0, None))(sampled_params, x).transpose( + 0, 1 + ) + return 
logits + + +def forward_linearized(model, state, batch): + x, _ = batch + sd_diag = to_sd_diag(state) + + def model_func_with_aux(p, x): + return torch.func.functional_call(model, p, x), torch.tensor([]) + + lin_mean, lin_chol, _ = posteriors.linearized_forward_diag( + model_func_with_aux, + state.params, + x, + sd_diag, + ) + + samps = torch.randn( + lin_mean.shape[0], + n_linearised_test_samples, + lin_mean.shape[1], + device=lin_mean.device, + ) + lin_logits = lin_mean.unsqueeze(1) + samps @ lin_chol.transpose(-1, -2) + return lin_logits + + +forward_dict = {"VI": forward, "VI Linearized": forward_linearized} diff --git a/examples/imdb/data.py b/examples/imdb/data.py new file mode 100644 index 00000000..339655c5 --- /dev/null +++ b/examples/imdb/data.py @@ -0,0 +1,27 @@ +import torch +from torch.utils.data import DataLoader, TensorDataset +from keras.datasets import imdb +from keras.preprocessing.sequence import pad_sequences + + +def load_imdb_dataset(batch_size=32, max_features=20000, max_len=100): + # Load and pad IMDB dataset + (x_train, y_train), (x_test, y_test) = imdb.load_data(num_words=max_features) + x_train = pad_sequences(x_train, maxlen=max_len) + x_test = pad_sequences(x_test, maxlen=max_len) + + # Convert data to PyTorch tensors + train_data = torch.tensor(x_train, dtype=torch.long) + train_labels = torch.tensor(y_train, dtype=torch.long) + test_data = torch.tensor(x_test, dtype=torch.long) + test_labels = torch.tensor(y_test, dtype=torch.long) + + # Create Tensor datasets + train_dataset = TensorDataset(train_data, train_labels) + test_dataset = TensorDataset(test_data, test_labels) + + # Create DataLoaders + train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) + test_dataloader = DataLoader(test_dataset, batch_size=batch_size) + + return train_dataloader, test_dataloader diff --git a/examples/imdb/lstm.py b/examples/imdb/lstm.py new file mode 100644 index 00000000..7804cbc4 --- /dev/null +++ b/examples/imdb/lstm.py @@ -0,0 +1,129 @@ +import torch +import torch.nn as nn + + +class CustomLSTMCell(nn.Module): + def __init__(self, input_size, hidden_size): + super(CustomLSTMCell, self).__init__() + self.input_size = input_size + self.hidden_size = hidden_size + self.forget_gate = nn.Linear(input_size + hidden_size, hidden_size) + self.input_gate = nn.Linear(input_size + hidden_size, hidden_size) + self.cell_gate = nn.Linear(input_size + hidden_size, hidden_size) + self.output_gate = nn.Linear(input_size + hidden_size, hidden_size) + + def forward(self, input, states): + h_prev, c_prev = states + combined = torch.cat((input, h_prev), 1) + + f_t = torch.sigmoid(self.forget_gate(combined)) + i_t = torch.sigmoid(self.input_gate(combined)) + c_tilde = torch.tanh(self.cell_gate(combined)) + c_t = f_t * c_prev + i_t * c_tilde + o_t = torch.sigmoid(self.output_gate(combined)) + h_t = o_t * torch.tanh(c_t) + + return h_t, c_t + + +class CustomLSTM(nn.Module): + def __init__(self, input_size, hidden_size, batch_first=False): + super(CustomLSTM, self).__init__() + self.input_size = input_size + self.hidden_size = hidden_size + self.batch_first = batch_first + self.lstm_cell = CustomLSTMCell(input_size, hidden_size) + + def forward(self, input, initial_states=None): + if initial_states is None: + initial_h = torch.zeros( + 1, + input.size(0 if self.batch_first else 1), + self.hidden_size, + device=input.device, + ) + initial_c = torch.zeros( + 1, + input.size(0 if self.batch_first else 1), + self.hidden_size, + device=input.device, + ) + else: + initial_h, 
initial_c = initial_states + + # Ensure we are working with single layer, single direction states + initial_h = initial_h.squeeze(0) + initial_c = initial_c.squeeze(0) + + if self.batch_first: + input = input.transpose( + 0, 1 + ) # Convert (batch, seq_len, feature) to (seq_len, batch, feature) + + outputs = [] + h_t, c_t = initial_h, initial_c + + for i in range( + input.shape[0] + ): # input is expected to be (seq_len, batch, input_size) + h_t, c_t = self.lstm_cell(input[i], (h_t, c_t)) + outputs.append(h_t.unsqueeze(0)) + + outputs = torch.cat(outputs, 0) + + if self.batch_first: + outputs = outputs.transpose( + 0, 1 + ) # Convert back to (batch, seq_len, feature) + + return outputs, (h_t, c_t) + + +# Test equivalence +def test_lstm_equivalence(): + input_size = 10 + hidden_size = 20 + seq_len = 5 + batch_size = 1 + + # Initialize both LSTMs + torch_lstm = nn.LSTM(input_size, hidden_size, batch_first=True) + custom_lstm = CustomLSTM(input_size, hidden_size, batch_first=True) + + # Manually setting the same weights and biases + with torch.no_grad(): + # Copy weights and biases from torch LSTM to custom LSTM + gates = ["input_gate", "forget_gate", "cell_gate", "output_gate"] + + for idx, gate in enumerate(gates): + start = idx * hidden_size + end = (idx + 1) * hidden_size + + getattr(custom_lstm.lstm_cell, gate).weight.data[:, :input_size].copy_( + torch_lstm.weight_ih_l0[start:end] + ) + getattr(custom_lstm.lstm_cell, gate).weight.data[:, input_size:].copy_( + torch_lstm.weight_hh_l0[start:end] + ) + getattr(custom_lstm.lstm_cell, gate).bias.data.copy_( + torch_lstm.bias_ih_l0[start:end] + torch_lstm.bias_hh_l0[start:end] + ) + # Dummy input + inputs = torch.randn(batch_size, seq_len, input_size) + + # Custom LSTM forward pass + custom_outputs, (custom_hn, custom_cn) = custom_lstm(inputs) + + # Torch LSTM forward pass + torch_outputs, (torch_hn, torch_cn) = torch_lstm(inputs) + + # Check outputs and final hidden and cell states + assert torch.allclose(custom_outputs, torch_outputs, atol=1e-6), "Output mismatch" + assert torch.allclose(custom_hn, torch_hn), "Hidden state mismatch" + assert torch.allclose(custom_cn, torch_cn), "Cell state mismatch" + + print("Test passed: Custom LSTM and torch.nn.LSTM outputs are equivalent!") + + +if __name__ == "__main__": + test_lstm_equivalence() diff --git a/examples/imdb/model.py b/examples/imdb/model.py new file mode 100644 index 00000000..651d8894 --- /dev/null +++ b/examples/imdb/model.py @@ -0,0 +1,72 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +# torch.nn.LSTM does not work well with torch.func https://github.com/pytorch/pytorch/issues/105982 +# so use custom simplified LSTM instead +from examples.imdb.lstm import CustomLSTM + + +class CNNLSTM(nn.Module): + def __init__( + self, + num_classes, + max_features=20000, + embedding_size=128, + cell_size=128, + num_filters=64, + kernel_size=5, + pool_size=4, + use_swish=False, + use_maxpool=True, + ): + super(CNNLSTM, self).__init__() + self.embedding = nn.Embedding( + num_embeddings=max_features, embedding_dim=embedding_size + ) + self.conv1d = nn.Conv1d( + in_channels=embedding_size, + out_channels=num_filters, + kernel_size=kernel_size, + ) + self.use_swish = use_swish + self.use_maxpool = use_maxpool + if use_maxpool: + self.maxpool = nn.MaxPool1d(kernel_size=pool_size, stride=pool_size) + self.lstm = CustomLSTM( + input_size=num_filters, hidden_size=cell_size, batch_first=True + ) + + self.fc = nn.Linear(in_features=cell_size, out_features=num_classes) + + def
forward(self, x): + # Embedding + x = self.embedding(x) # Shape: [batch_size, seq_length, embedding_size] + x = x.permute( + 0, 2, 1 + ) # Shape: [batch_size, embedding_size, seq_length] to match Conv1d input + + # Convolution + x = self.conv1d(x) + if self.use_swish: + x = x * torch.sigmoid(x) # Swish activation function + else: + x = F.relu(x) + + # Pooling + if self.use_maxpool: + x = self.maxpool(x) + + # Reshape for LSTM + x = x.permute(0, 2, 1) # Shape: [batch_size, seq_length, num_filters] + + # LSTM + output, _ = self.lstm(x) + + # Take the last sequence output + last_output = output[:, -1, :] + + # Fully connected layer + logits = self.fc(last_output) + + return logits diff --git a/examples/imdb/plot.py b/examples/imdb/plot.py new file mode 100644 index 00000000..45489e3c --- /dev/null +++ b/examples/imdb/plot.py @@ -0,0 +1,149 @@ +import numpy as np +import matplotlib.pyplot as plt +from matplotlib.ticker import FormatStrFormatter +import json +import os + +plt.rcParams["font.family"] = "Times New Roman" + + +temperatures = [0.03, 0.1, 0.3, 1.0, 3.0] +temp_strs = [str(temp).replace(".", "-") for temp in temperatures] + + +ylims = {"loss": (0.305, 0.72), "accuracy": (0.47, 0.88)} + + +# [base directory, tempered bool, colors] + +bases = [ + # ["examples/imdb/results/mle", False, ["grey"]], + ["examples/imdb/results/map", False, ["grey"]], + ["examples/imdb/results/laplace_fisher", True, ["royalblue", "deepskyblue"]], + ["examples/imdb/results/laplace_ggn", True, ["purple", "mediumvioletred"]], +] + +# bases = [ +# # ["examples/imdb/results/mle", False, ["grey"]], +# ["examples/imdb/results/map", False, ["grey"]], +# ["examples/imdb/results/vi", True, ["forestgreen", "darkkhaki"]], +# ] + +# bases = [ +# # ["examples/imdb/results/mle", False, ["grey"]], +# ["examples/imdb/results/map", False, ["grey"]], +# ["examples/imdb/results/sghmc_serial", True, ["firebrick"]], +# ["examples/imdb/results/sghmc_parallel/sghmc_parallel", True, ["tomato"]], +# ] + +save_name = bases[-1][0].split("/")[-1].split("_")[0] + + +test_dicts = {} +colour_dict = {} + +for base, tempered, colours in bases: + match_str = base + "_temp" + temp_strs[0] if tempered else base + versions = [file for file in os.listdir(match_str) if "test" in file] + versions = [v.strip(".json").split("_")[1] for v in versions] + versions.sort() + + for k, name in enumerate(versions): + colour_dict[name] = colours[k] + if tempered: + single_dict = { + temperatures[i]: json.load( + open(f"{base}_temp{temp_strs[i]}/test_{name}.json") + ) + for i in range(len(temperatures)) + } + else: + single_dict = {0.0: json.load(open(f"{base}/test_{name}.json"))} + + test_dicts |= {name: single_dict} + + +line_styles = ["--", ":", "-.", "-"] + + +def mean_dict(metric_name): + md = {} + + for method_name, method_dict in test_dicts.items(): + if len(method_dict) == 1: + val = np.mean(method_dict[0.0][metric_name], axis=0) + md |= {method_name: val} + + else: + plot_dict = { + temp: np.mean(vals[metric_name], axis=0) + for temp, vals in method_dict.items() + } + md |= {method_name: plot_dict} + + return md + + +metrics = list(list(list(test_dicts.values())[0].values())[0].keys()) + +metric_dicts = {metric_name: mean_dict(metric_name) for metric_name in metrics} + + +def plot_metric(metric_name): + metric_dict = metric_dicts[metric_name] + + fig, ax = plt.subplots() + k = 0 + + linewidth = 3 + + for method_name, method_dict in metric_dict.items(): + if not isinstance(method_dict, dict): + ax.axhline( + method_dict, + label=method_name, + 
c=colour_dict[method_name], + linestyle=line_styles[k], + linewidth=linewidth, + ) + k += 1 + + else: + ax.plot( + list(method_dict.keys()), + list(method_dict.values()), + label=method_name, + marker="o", + markersize=10, + c=colour_dict[method_name], + linewidth=linewidth, + ) + + if metric_name in ylims: + ax.set_ylim(ylims[metric_name]) + + ax.set_xscale("log") + + fontsize = 18 + + ax.set_xticks(temperatures) + ax.tick_params(axis="both", which="major", labelsize=fontsize * 0.75) + + # Change xtick labels to scalar format + ax.xaxis.set_major_formatter(FormatStrFormatter("%g")) + + ax.set_xlabel("Temperature", fontsize=fontsize) + ax.set_ylabel("Test " + metric_name.title(), fontsize=fontsize) + ax.legend( + frameon=True, + framealpha=1.0, + facecolor="white", + edgecolor="white", + fontsize=fontsize, + ) + fig.tight_layout() + fig.savefig(f"examples/imdb/results/{save_name}_{metric_name}.png", dpi=400) + + +plot_metric("loss") +plot_metric("accuracy") diff --git a/examples/imdb/test.py b/examples/imdb/test.py new file mode 100644 index 00000000..df9cbcd6 --- /dev/null +++ b/examples/imdb/test.py @@ -0,0 +1,104 @@ +import argparse +import pickle +import importlib +from tqdm import tqdm +import torch +from torch.distributions import Categorical +from optree import tree_map +import os + +from examples.imdb.model import CNNLSTM +from examples.imdb.data import load_imdb_dataset +from examples.imdb.utils import log_metrics + +# Get config path and device from user +parser = argparse.ArgumentParser() +parser.add_argument("--config", type=str) +parser.add_argument("--device", default="cpu", type=str) +parser.add_argument("--seed", default=42, type=int) +args = parser.parse_args() + + +torch.manual_seed(args.seed) + +# Import configuration +config = importlib.import_module(args.config.replace("/", ".").replace(".py", "")) +config_dir = os.path.dirname(args.config) + +# Load data +_, test_dataloader = load_imdb_dataset() + +# Load model +model = CNNLSTM(num_classes=2) +model.to(args.device) + +# Load state +state = pickle.load(open(config_dir + "/state.pkl", "rb")) +state = tree_map(lambda x: x.to(args.device), state) + +# Dictionary containing forward functions +forward_dict = config.forward_dict + + +def test_metrics(logits, labels): + probs = torch.nn.functional.softmax(logits, dim=-1) + probs = torch.where(probs < 1e-5, 1e-5, probs) + probs /= probs.sum(dim=-1, keepdim=True) + expected_probs = probs.mean(dim=1) + + logits = torch.log(probs) + expected_logits = torch.log(expected_probs) + + loss = -Categorical(logits=expected_logits).log_prob(labels) + accuracy = (expected_probs.argmax(dim=-1) == labels).float() + + total_uncertainty = -(torch.log(expected_probs) * expected_probs).mean(1) + aleatoric_uncertainty = -(torch.log(probs) * probs).mean(2).mean(1) + epistemic_uncertainty = total_uncertainty - aleatoric_uncertainty + + return { + "loss": loss, + "accuracy": accuracy, + "total_uncertainty": total_uncertainty, + "aleatoric_uncertainty": aleatoric_uncertainty, + "epistemic_uncertainty": epistemic_uncertainty, + } + + +# Run through test data +num_batches = len(test_dataloader) +metrics = [ + "loss", + "accuracy", + "total_uncertainty", + "aleatoric_uncertainty", + "epistemic_uncertainty", +] + +log_dict_forward = {k: {m: [] for m in metrics} for k in forward_dict.keys()} + +for batch in tqdm(test_dataloader): + with torch.no_grad(): + batch = tree_map(lambda x: x.to(args.device), batch) + labels = batch[1] + + forward_logits = {k: v(model, state, batch) for k, v in 
forward_dict.items()} + + forward_metrics = { + k: test_metrics(logits, labels) for k, logits in forward_logits.items() + } + + for forward_k in log_dict_forward.keys(): + for metric_k in metrics: + log_dict_forward[forward_k][metric_k] += forward_metrics[forward_k][ + metric_k + ].tolist() + + +for forward_k, metric_dict in log_dict_forward.items(): + log_metrics( + metric_dict, + config_dir, + file_name="test_" + forward_k, + plot=False, + ) diff --git a/examples/imdb/test_runner.sh b/examples/imdb/test_runner.sh new file mode 100644 index 00000000..4c216ed1 --- /dev/null +++ b/examples/imdb/test_runner.sh @@ -0,0 +1,25 @@ +#!/bin/bash + +device="cuda:1" + +# List of temperature params +temperatures=(0.03 0.1 0.3 1.0 3.0) + +# Base directories +base_configs=( + "examples/imdb/results/laplace_fisher" + "examples/imdb/results/laplace_ggn" + "examples/imdb/results/vi" + "examples/imdb/results/sghmc_serial" + "examples/imdb/results/sghmc_parallel/sghmc_parallel" +) + + +for temperature in "${temperatures[@]}"; do + for base_config in "${base_configs[@]}"; do + config_file="${base_config}_temp${temperature//./-}/config.py" + echo "Executing $config_file" + PYTHONPATH=. python examples/imdb/test.py --config "$config_file" \ + --device $device + done +done diff --git a/examples/imdb/train.py b/examples/imdb/train.py new file mode 100644 index 00000000..fe4e281d --- /dev/null +++ b/examples/imdb/train.py @@ -0,0 +1,145 @@ +import os +import argparse +import pickle +import importlib +from tqdm import tqdm +from optree import tree_map +import torch +import posteriors + +from examples.imdb.model import CNNLSTM +from examples.imdb.data import load_imdb_dataset +from examples.imdb import utils + + +# Get args from user +parser = argparse.ArgumentParser() +parser.add_argument("--config", type=str) +parser.add_argument("--device", default="cpu", type=str) +parser.add_argument("--seed", default=42, type=int) +parser.add_argument("--epochs", default=30, type=int) +parser.add_argument("--temperature", default=1.0, type=float) +args = parser.parse_args() + + +# Import configuration +config = importlib.import_module(args.config.replace("/", ".").replace(".py", "")) + + +# Set seed +if args.seed != 42: + config.save_dir += f"_seed{args.seed}" +else: + args.seed = 42 +torch.manual_seed(args.seed) + +# Load model +model = CNNLSTM(num_classes=2) +model.to(args.device) + +if config.params_dir is not None: + with open(config.params_dir, "rb") as f: + state = pickle.load(f) + model.load_state_dict(state.params) + +# Load data +train_dataloader, test_dataloader = load_imdb_dataset(batch_size=config.batch_size) +num_data = len(train_dataloader.dataset) + + +# Set temperature +if "temperature" in config.config_args and config.config_args["temperature"] is None: + config.config_args["temperature"] = args.temperature / num_data + temp_str = str(args.temperature).replace(".", "-") + config.save_dir += f"_temp{temp_str}" + + +# Create save directory if it does not exist +if not os.path.exists(config.save_dir): + os.makedirs(config.save_dir) + +# Save config +utils.save_config(args, config.save_dir) +print(f"Config saved to {config.save_dir}") + +# Extract model parameters +params = dict(model.named_parameters()) +num_params = posteriors.tree_size(params).item() +print(f"Number of parameters: {num_params/1e6:.3f}M") + + +# Define log posterior +def forward(p, batch): + x, y = batch + logits = torch.func.functional_call(model, p, x) + loss = torch.nn.functional.cross_entropy(logits, y) + return logits, (loss, logits) + + 
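+# forward returns the model output plus (loss, logits) as auxiliary info, and
+# outer_log_lik maps those logits to a scalar log-likelihood; laplace.diag_ggn
+# consumes this (forward, outer_log_lik) pair (see the build call below) so that
+# the network output and the likelihood can be differentiated separately when
+# forming the Gauss-Newton approximation.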
+def outer_log_lik(logits, batch): + _, y = batch + return -torch.nn.functional.cross_entropy(logits, y, reduction="sum") + + +def log_posterior(p, batch): + x, y = batch + logits = torch.func.functional_call(model, p, x) + loss = torch.nn.functional.cross_entropy(logits, y) + log_post = ( + -loss + + posteriors.diag_normal_log_prob(p, sd_diag=config.prior_sd, normalize=False) + / num_data + ) + return log_post, (loss, logits) + + +# Build transform +if config.method == posteriors.laplace.diag_ggn: + transform = config.method.build(forward, outer_log_lik, **config.config_args) +else: + transform = config.method.build(log_posterior, **config.config_args) + +# Initialize state +state = transform.init(params) + +# Train +i = j = 0 +num_batches = len(train_dataloader) +log_dict = {k: [] for k in config.log_metrics.keys()} | {"loss": []} +log_bar = tqdm(total=0, position=1, bar_format="{desc}") +for epoch in range(args.epochs): + for batch in tqdm( + train_dataloader, desc=f"Epoch {epoch+1}/{args.epochs}", position=0 + ): + batch = tree_map(lambda x: x.to(args.device), batch) + state = transform.update(state, batch) + + # Update metrics + log_dict = utils.append_metrics(log_dict, state, config.log_metrics) + log_bar.set_description_str( + f"{config.display_metric}: {log_dict[config.display_metric][-1]:.2f}" + ) + + # Log + i += 1 + if i % config.log_frequency == 0 or i % num_batches == 0: + utils.log_metrics( + log_dict, + config.save_dir, + window=config.log_window, + file_name="training", + ) + + # Save sequential state if desired + if ( + config.save_frequency is not None + and (i - config.burnin) >= 0 + and (i - config.burnin) % config.save_frequency == 0 + ): + with open(f"{config.save_dir}/state_{j}.pkl", "wb") as f: + pickle.dump(state, f) + j += 1 + +# Save final state +with open(f"{config.save_dir}/state.pkl", "wb") as f: + pickle.dump(state, f) diff --git a/examples/imdb/train_runner.sh b/examples/imdb/train_runner.sh new file mode 100644 index 00000000..ecaaaeca --- /dev/null +++ b/examples/imdb/train_runner.sh @@ -0,0 +1,29 @@ +#!/bin/bash + +# List of temperature params +temperatures=(0.03 0.1 0.3 1.0 3.0) + + +device="cuda:1" + + +# config="examples/imdb/configs/vi.py" +# epochs=30 +# seeds=$(42) + +# config="examples/imdb/configs/sghmc_serial.py" +# epochs=60 +# seeds=$(42) + +config="examples/imdb/configs/sghmc_parallel.py" +epochs=30 +seeds=$(seq 2 19) + +for seed in $seeds +do + for temp in "${temperatures[@]}" + do + PYTHONPATH=. 
python examples/imdb/train.py --config $config --device $device \ + --temperature $temp --epochs $epochs --seed $seed + done +done \ No newline at end of file diff --git a/examples/imdb/utils.py b/examples/imdb/utils.py new file mode 100644 index 00000000..b5b14813 --- /dev/null +++ b/examples/imdb/utils.py @@ -0,0 +1,45 @@ +import json +import matplotlib.pyplot as plt +import numpy as np +import shutil + + +# Save config +def save_config(args, save_dir): + shutil.copy(args.config, save_dir + "/config.py") + json.dump(vars(args), open(f"{save_dir}/args.json", "w")) + + +# Function to calculate moving average and accompanying x-axis values +def moving_average(x, w): + w_check = 1 if w >= len(x) else w + y = np.convolve(x, np.ones(w_check), "valid") / w_check + return range(w_check - 1, len(x)), y + + +# Function to log and plot metrics +def log_metrics(log_dict, save_dir, window=1, file_name="metrics", plot=True): + save_file = f"{save_dir}/{file_name}.json" + + with open(save_file, "w") as f: + json.dump(log_dict, f) + + if plot: + for k, v in log_dict.items(): + fig, ax = plt.subplots() + x, y = moving_average(v, window) + ax.plot(x, y, label=k, alpha=0.7) + ax.legend() + ax.set_xlabel("Iteration") + fig.tight_layout() + fig.savefig(f"{save_dir}/{file_name}_{k}.png", dpi=200) + plt.close(fig) + + +# Function to append metrics to log_dict +def append_metrics(log_dict, state, config_dict): + for k, v in config_dict.items(): + log_dict[k].append(getattr(state, v).mean().item()) + + log_dict["loss"].append(state.aux[0].mean().item()) + return log_dict From 420262215715cde8ca9e440316f40df31c813802 Mon Sep 17 00:00:00 2001 From: SamDuffield Date: Sun, 19 May 2024 16:44:29 +0000 Subject: [PATCH 2/6] Plot with MLE --- examples/imdb/plot.py | 21 ++++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/examples/imdb/plot.py b/examples/imdb/plot.py index 45489e3c..ce8f26ce 100644 --- a/examples/imdb/plot.py +++ b/examples/imdb/plot.py @@ -11,9 +11,6 @@ temp_strs = [str(temp).replace(".", "-") for temp in temperatures] -ylims = {"loss": (0.305, 0.72), "accuracy": (0.47, 0.88)} - - # [base directory, tempered bool, colors] bases = [ @@ -36,6 +33,13 @@ # ["examples/imdb/results/sghmc_parallel/sghmc_parallel", True, ["tomato"]], # ] +with_mle = "mle" in bases[0][0] + +if with_mle: + ylims = {"loss": (0.305, 1.1), "accuracy": (0.17, 0.88)} +else: + ylims = {"loss": (0.305, 0.72), "accuracy": (0.47, 0.88)} + save_name = bases[-1][0].split("/")[-1].split("_")[0] @@ -95,7 +99,7 @@ def plot_metric(metric_name): fig, ax = plt.subplots() k = 0 - linewidth = 3 + linewidth = 2 if with_mle else 3 for method_name, method_dict in metric_dict.items(): if not isinstance(method_dict, dict): @@ -142,7 +146,14 @@ def plot_metric(metric_name): fontsize=fontsize, ) fig.tight_layout() - fig.savefig(f"examples/imdb/results/{save_name}_{metric_name}.png", dpi=400) + + save_dir = ( + f"examples/imdb/results/{save_name}_{metric_name}_with_mle.png" + if with_mle + else f"examples/imdb/results/{save_name}_{metric_name}.png" + ) + fig.savefig(save_dir, dpi=400) + plt.close() plot_metric("loss") From fd76761a85e61b75378b232c01fb95a41f71676c Mon Sep 17 00:00:00 2001 From: SamDuffield Date: Sun, 19 May 2024 17:44:55 +0000 Subject: [PATCH 3/6] Fontsize for MLE plot --- examples/imdb/plot.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/examples/imdb/plot.py b/examples/imdb/plot.py index ce8f26ce..7a05e59b 100644 --- a/examples/imdb/plot.py +++ b/examples/imdb/plot.py @@ -12,7 
+12,6 @@ # [base directory, tempered bool, colors] - bases = [ # ["examples/imdb/results/mle", False, ["grey"]], ["examples/imdb/results/map", False, ["grey"]], @@ -36,7 +35,7 @@ with_mle = "mle" in bases[0][0] if with_mle: - ylims = {"loss": (0.305, 1.1), "accuracy": (0.17, 0.88)} + ylims = {"loss": (0.305, 1.1), "accuracy": (0.47, 0.88)} else: ylims = {"loss": (0.305, 0.72), "accuracy": (0.47, 0.88)} @@ -138,12 +137,14 @@ def plot_metric(metric_name): ax.set_xlabel("Temperature", fontsize=fontsize) ax.set_ylabel("Test " + metric_name.title(), fontsize=fontsize) + + leg_fontsize = fontsize * 0.75 if with_mle else fontsize ax.legend( frameon=True, framealpha=1.0, facecolor="white", edgecolor="white", - fontsize=fontsize, + fontsize=leg_fontsize, ) fig.tight_layout() From 114e1d99c119941f92396373b0e242095cdb2ac9 Mon Sep 17 00:00:00 2001 From: SamDuffield Date: Tue, 21 May 2024 17:59:54 +0000 Subject: [PATCH 4/6] Update for more seeds --- examples/imdb/combine_states.py | 66 ------------ examples/imdb/combine_states_parallel.py | 65 ++++++++++++ examples/imdb/combine_states_serial.py | 64 ++++++++++++ examples/imdb/configs/laplace_fisher.py | 16 ++- examples/imdb/configs/laplace_ggn.py | 16 ++- examples/imdb/plot.py | 125 +++++++++++++++-------- examples/imdb/test.py | 17 ++- examples/imdb/test_runner.sh | 67 ++++++++++-- examples/imdb/train.py | 4 + examples/imdb/train_runner.sh | 36 +++++-- 10 files changed, 329 insertions(+), 147 deletions(-) delete mode 100644 examples/imdb/combine_states.py create mode 100644 examples/imdb/combine_states_parallel.py create mode 100644 examples/imdb/combine_states_serial.py diff --git a/examples/imdb/combine_states.py b/examples/imdb/combine_states.py deleted file mode 100644 index faa98b4e..00000000 --- a/examples/imdb/combine_states.py +++ /dev/null @@ -1,66 +0,0 @@ -import pickle -import torch -import os -from optree import tree_map -import shutil - - -temperatures = [0.03, 0.1, 0.3, 1.0, 3.0] - - -# # SGHMC Serial -# base = "examples/imdb/results/sghmc_serial" -# seeds = [None] - - -# SGHMC Parallel -base = "examples/imdb/results/sghmc_parallel/sghmc_parallel" -seeds = list(range(1, 21)) - - -for temp in temperatures: - temp_str = str(temp).replace(".", "-") - - load_paths = [] - - for seed in seeds: - spec_base = base - - if seed is not None: - spec_base += f"_seed{seed}" - - spec_base += f"_temp{temp_str}/" - - match_str = ( - "state_" if seed is None else "state" - ) # don't need last save for serial - - load_paths += [ - spec_base + file for file in os.listdir(spec_base) if match_str in file - ] - - save_dir = base + f"_temp{temp_str}/" - - if not os.path.exists(save_dir): - os.makedirs(save_dir) - shutil.copy(spec_base + "config.py", save_dir + "/config.py") - - with open(f"{save_dir}/paths.txt", "w") as f: - f.write("\n".join(load_paths)) - - # Load states - states = [pickle.load(open(d, "rb")) for d in load_paths] - - # Delete auxiliary info - for s in states: - del s.aux - - # Move states to cpu - states = tree_map(lambda x: x.detach().to("cpu"), states) - - # Combine states - combined_state = tree_map(lambda *x: torch.stack(x), *states) - - # Save state - with open(f"{save_dir}state.pkl", "wb") as f: - pickle.dump(combined_state, f) diff --git a/examples/imdb/combine_states_parallel.py b/examples/imdb/combine_states_parallel.py new file mode 100644 index 00000000..29727e1e --- /dev/null +++ b/examples/imdb/combine_states_parallel.py @@ -0,0 +1,65 @@ +import pickle +import torch +import os +from optree import tree_map +import shutil + + 
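+# Combine final SGHMC states from independent trajectories into bootstrapped
+# ensembles: for each temperature, draw 5 random subsets of 15 of the 35 seeds
+# (sampled without replacement within each subset) and stack the per-seed
+# parameters into a single state whose leading dimension indexes ensemble members.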
+temperatures = [0.03, 0.1, 0.3, 1.0, 3.0] + + +# SGHMC Parallel +base = "examples/imdb/results/sghmc_parallel/sghmc_parallel" +save_dir_base = "examples/imdb/results/sghmc_parallel" + + +bootstrap_seeds = [ + (torch.multinomial(torch.ones(35), 15, replacement=False) + 1).tolist() + for _ in range(5) +] + + +for temp in temperatures: + temp_str = str(temp).replace(".", "-") + + load_paths = [] + + for k, seeds in enumerate(bootstrap_seeds): + for seed in seeds: + spec_base = base + spec_base += f"_seed{seed}" + spec_base += f"_temp{temp_str}/" + + match_str = ( + "state_" if seed is None else "state" + ) # don't need last save for serial + + load_paths += [ + spec_base + file for file in os.listdir(spec_base) if match_str in file + ] + + save_dir = save_dir_base + f"_seed{k+1}" + f"_temp{temp_str}/" + + if not os.path.exists(save_dir): + os.makedirs(save_dir) + shutil.copy(spec_base + "config.py", save_dir + "/config.py") + + with open(f"{save_dir}/paths.txt", "w") as f: + f.write("\n".join(load_paths)) + + # Load states + states = [pickle.load(open(d, "rb")) for d in load_paths] + + # Delete auxiliary info + for s in states: + del s.aux + + # Move states to cpu + states = tree_map(lambda x: x.detach().to("cpu"), states) + + # Combine states + combined_state = tree_map(lambda *x: torch.stack(x), *states) + + # Save state + with open(f"{save_dir}state.pkl", "wb") as f: + pickle.dump(combined_state, f) diff --git a/examples/imdb/combine_states_serial.py b/examples/imdb/combine_states_serial.py new file mode 100644 index 00000000..681ab23d --- /dev/null +++ b/examples/imdb/combine_states_serial.py @@ -0,0 +1,64 @@ +import pickle +import torch +import os +from optree import tree_map +import shutil + + +temperatures = [0.03, 0.1, 0.3, 1.0, 3.0] + +# SGHMC Serial +bases = [ + f"examples/imdb/results/sghmc_serial_seed{repeat_seed}" + for repeat_seed in range(1, 6) +] +seeds = [None] + + +for base in bases: + for temp in temperatures: + temp_str = str(temp).replace(".", "-") + + load_paths = [] + + for seed in seeds: + spec_base = base + + if seed is not None: + spec_base += f"_seed{seed}" + + spec_base += f"_temp{temp_str}/" + + match_str = ( + "state_" if seed is None else "state" + ) # don't need last save for serial + + load_paths += [ + spec_base + file for file in os.listdir(spec_base) if match_str in file + ] + + save_dir = base + f"_temp{temp_str}/" + + if not os.path.exists(save_dir): + os.makedirs(save_dir) + shutil.copy(spec_base + "config.py", save_dir + "/config.py") + + with open(f"{save_dir}/paths.txt", "w") as f: + f.write("\n".join(load_paths)) + + # Load states + states = [pickle.load(open(d, "rb")) for d in load_paths] + + # Delete auxiliary info + for s in states: + del s.aux + + # Move states to cpu + states = tree_map(lambda x: x.detach().to("cpu"), states) + + # Combine states + combined_state = tree_map(lambda *x: torch.stack(x), *states) + + # Save state + with open(f"{save_dir}state.pkl", "wb") as f: + pickle.dump(combined_state, f) diff --git a/examples/imdb/configs/laplace_fisher.py b/examples/imdb/configs/laplace_fisher.py index 9fc95125..40d42e6c 100644 --- a/examples/imdb/configs/laplace_fisher.py +++ b/examples/imdb/configs/laplace_fisher.py @@ -2,11 +2,9 @@ import posteriors from optree import tree_map -temperature = 1.0 - name = "laplace_fisher" save_dir = "examples/imdb/results/" + name -params_dir = "examples/imdb/results/map/state.pkl" # directory to load state containing initialisation params +params_dir = "examples/imdb/results/map" # directory to load state 
containing initialisation params prior_sd = torch.inf batch_size = 32 @@ -19,20 +17,20 @@ display_metric = "loss" # metric to display in tqdm progress bar log_frequency = 100 # frequency at which to log metrics -log_window = 40 # window size for moving average +log_window = 30 # window size for moving average n_test_samples = 50 n_linearised_test_samples = 10000 epsilon = 1e-3 # small value to avoid division by zero in to_sd_diag -def to_sd_diag(state): +def to_sd_diag(state, temperature=1.0): return tree_map(lambda x: torch.sqrt(temperature / (x + epsilon)), state.prec_diag) -def forward(model, state, batch): +def forward(model, state, batch, temperature=1.0): x, _ = batch - sd_diag = to_sd_diag(state) + sd_diag = to_sd_diag(state, temperature) sampled_params = posteriors.diag_normal_sample( state.params, sd_diag, (n_test_samples,) @@ -47,9 +45,9 @@ def model_func(p, x): return logits -def forward_linearized(model, state, batch): +def forward_linearized(model, state, batch, temperature=1.0): x, _ = batch - sd_diag = to_sd_diag(state) + sd_diag = to_sd_diag(state, temperature) def model_func_with_aux(p, x): return torch.func.functional_call(model, p, x), torch.tensor([]) diff --git a/examples/imdb/configs/laplace_ggn.py b/examples/imdb/configs/laplace_ggn.py index 92a8628b..f06760ea 100644 --- a/examples/imdb/configs/laplace_ggn.py +++ b/examples/imdb/configs/laplace_ggn.py @@ -2,11 +2,9 @@ import posteriors from optree import tree_map -temperature = 1.0 - name = "laplace_ggn" save_dir = "examples/imdb/results/" + name -params_dir = "examples/imdb/results/map/state.pkl" # directory to load state containing initialisation params +params_dir = "examples/imdb/results/map" # directory to load state containing initialisation params prior_sd = torch.inf batch_size = 32 @@ -19,20 +17,20 @@ display_metric = "loss" # metric to display in tqdm progress bar log_frequency = 100 # frequency at which to log metrics -log_window = 40 # window size for moving average +log_window = 30 # window size for moving average n_test_samples = 50 n_linearised_test_samples = 10000 epsilon = 1e-3 # small value to avoid division by zero in to_sd_diag -def to_sd_diag(state): +def to_sd_diag(state, temperature=1.0): return tree_map(lambda x: torch.sqrt(temperature / (x + epsilon)), state.prec_diag) -def forward(model, state, batch): +def forward(model, state, batch, temperature=1.0): x, _ = batch - sd_diag = to_sd_diag(state) + sd_diag = to_sd_diag(state, temperature) sampled_params = posteriors.diag_normal_sample( state.params, sd_diag, (n_test_samples,) @@ -47,9 +45,9 @@ def model_func(p, x): return logits -def forward_linearized(model, state, batch): +def forward_linearized(model, state, batch, temperature=1.0): x, _ = batch - sd_diag = to_sd_diag(state) + sd_diag = to_sd_diag(state, temperature) def model_func_with_aux(p, x): return torch.func.functional_call(model, p, x), torch.tensor([]) diff --git a/examples/imdb/plot.py b/examples/imdb/plot.py index 7a05e59b..1c8eeec1 100644 --- a/examples/imdb/plot.py +++ b/examples/imdb/plot.py @@ -9,6 +9,7 @@ temperatures = [0.03, 0.1, 0.3, 1.0, 3.0] temp_strs = [str(temp).replace(".", "-") for temp in temperatures] +seeds = list(range(1, 6)) # [base directory, tempered bool, colors] @@ -20,7 +21,7 @@ ] # bases = [ -# # ["examples/imdb/results/mle", False, ["grey"]], +# ["examples/imdb/results/mle", False, ["grey"]], # ["examples/imdb/results/map", False, ["grey"]], # ["examples/imdb/results/vi", True, ["forestgreen", "darkkhaki"]], # ] @@ -29,13 +30,13 @@ # # 
["examples/imdb/results/mle", False, ["grey"]], # ["examples/imdb/results/map", False, ["grey"]], # ["examples/imdb/results/sghmc_serial", True, ["firebrick"]], -# ["examples/imdb/results/sghmc_parallel/sghmc_parallel", True, ["tomato"]], +# ["examples/imdb/results/sghmc_parallel", True, ["tomato"]], # ] with_mle = "mle" in bases[0][0] if with_mle: - ylims = {"loss": (0.305, 1.1), "accuracy": (0.47, 0.88)} + ylims = {"loss": (0.305, 1.2), "accuracy": (0.47, 0.88)} else: ylims = {"loss": (0.305, 0.72), "accuracy": (0.47, 0.88)} @@ -46,50 +47,66 @@ colour_dict = {} for base, tempered, colours in bases: - match_str = base + "_temp" + temp_strs[0] if tempered else base - versions = [file for file in os.listdir(match_str) if "test" in file] - versions = [v.strip(".json").split("_")[1] for v in versions] - versions.sort() - - for k, name in enumerate(versions): - colour_dict[name] = colours[k] - if tempered: - single_dict = { - temperatures[i]: json.load( - open(f"{base}_temp{temp_strs[i]}/test_{name}.json") - ) - for i in range(len(temperatures)) - } - else: - single_dict = {0.0: json.load(open(f"{base}/test_{name}.json"))} - - test_dicts |= {name: single_dict} + for seed in seeds: + seed_base = base + "_seed" + str(seed) if seed is not None else base + match_str = seed_base + "_temp" + temp_strs[0] if tempered else seed_base + versions = [file for file in os.listdir(match_str) if "test" in file] + versions = [v.strip(".json").split("_")[1] for v in versions] + versions.sort() + + for k, name in enumerate(versions): + colour_dict[name] = colours[k] + if tempered: + for temp in temperatures: + if temp not in test_dicts: + key = f"{name}_temp{str(temp).replace('.', '-')}_seed{seed}" + test_dicts[key] = json.load( + open( + f"{seed_base}_temp{str(temp).replace('.', '-')}/test_{name}.json" + ) + ) + else: + key = f"{name}_tempNA_seed{seed}" + test_dicts[key] = json.load(open(f"{seed_base}/test_{name}.json")) line_styles = ["--", ":", "-.", "-"] -def mean_dict(metric_name): +def sorted_set(l_in): + return sorted(set(l_in), key=l_in.index) + + +method_names = sorted_set([key.split("_")[0] for key in test_dicts.keys()]) + + +def metric_vals_dict(metric_name): md = {} - for method_name, method_dict in test_dicts.items(): - if len(method_dict) == 1: - val = np.mean(method_dict[0.0][metric_name], axis=0) - md |= {method_name: val} + for method_name in method_names: + method_name_keys = [ + key for key in test_dicts.keys() if key.split("_")[0] == method_name + ] - else: - plot_dict = { - temp: np.mean(vals[metric_name], axis=0) - for temp, vals in method_dict.items() - } - md |= {method_name: plot_dict} + temperature_keys = sorted_set([key.split("_")[1] for key in method_name_keys]) - return md + method_dict = {} + + for tk in temperature_keys: + temp = tk.strip("temp").replace("-", ".") + method_dict[temp] = [ + np.mean(test_dicts[key][metric_name]) + for key in method_name_keys + if key.split("_")[1] == tk + ] -metrics = list(list(list(test_dicts.values())[0].values())[0].keys()) + md |= {method_name: method_dict} + return md -metric_dicts = {metric_name: mean_dict(metric_name) for metric_name in metrics} + +metrics = ["loss", "accuracy"] +metric_dicts = {metric_name: metric_vals_dict(metric_name) for metric_name in metrics} def plot_metric(metric_name): @@ -101,30 +118,56 @@ def plot_metric(metric_name): linewidth = 2 if with_mle else 3 for method_name, method_dict in metric_dict.items(): - if not isinstance(method_dict, dict): + if len(method_dict) == 1: + mn = np.mean(method_dict["NA"]) + sds = 
np.std(method_dict["NA"]) + ax.fill_between( + [0.0, temperatures[-1] * 10], + [mn - sds] * 2, + [mn + sds] * 2, + color=colour_dict[method_name], + alpha=0.2, + zorder=0, + ) ax.axhline( - method_dict, + np.mean(method_dict["NA"]), label=method_name, c=colour_dict[method_name], linestyle=line_styles[k], linewidth=linewidth, + zorder=1, ) + k += 1 else: + mns = np.array([np.mean(method_dict[temp]) for temp in method_dict.keys()]) + sds = np.array([np.std(method_dict[temp]) for temp in method_dict.keys()]) + + ax.fill_between( + temperatures, + mns - sds, + mns + sds, + color=colour_dict[method_name], + alpha=0.2, + zorder=0, + ) ax.plot( - list(method_dict.keys()), - list(method_dict.values()), + temperatures, + mns, label=method_name, marker="o", markersize=10, c=colour_dict[method_name], linewidth=linewidth, + zorder=1, ) if metric_name in ylims: ax.set_ylim(ylims[metric_name]) + ax.set_xlim(temperatures[0] * 0.9, temperatures[-1] * 1.1) + ax.set_xscale("log") fontsize = 18 @@ -149,9 +192,9 @@ def plot_metric(metric_name): fig.tight_layout() save_dir = ( - f"examples/imdb/results/{save_name}_{metric_name}_with_mle.png" + f"examples/imdb/figures/{save_name}_{metric_name}_with_mle.png" if with_mle - else f"examples/imdb/results/{save_name}_{metric_name}.png" + else f"examples/imdb/figures/{save_name}_{metric_name}.png" ) fig.savefig(save_dir, dpi=400) plt.close() diff --git a/examples/imdb/test.py b/examples/imdb/test.py index df9cbcd6..65aa1d0c 100644 --- a/examples/imdb/test.py +++ b/examples/imdb/test.py @@ -16,6 +16,7 @@ parser.add_argument("--config", type=str) parser.add_argument("--device", default="cpu", type=str) parser.add_argument("--seed", default=42, type=int) +parser.add_argument("--temperature", default=None, type=float) args = parser.parse_args() @@ -24,6 +25,13 @@ # Import configuration config = importlib.import_module(args.config.replace("/", ".").replace(".py", "")) config_dir = os.path.dirname(args.config) +save_dir = ( + config_dir + if args.temperature is None + else config_dir + f"_temp{str(args.temperature).replace('.', '-')}" +) +if not os.path.exists(save_dir): + os.makedirs(save_dir) # Load data _, test_dataloader = load_imdb_dataset() @@ -82,7 +90,12 @@ def test_metrics(logits, labels): batch = tree_map(lambda x: x.to(args.device), batch) labels = batch[1] - forward_logits = {k: v(model, state, batch) for k, v in forward_dict.items()} + forward_logits = { + k: v(model, state, batch) + if args.temperature is None + else v(model, state, batch, args.temperature) + for k, v in forward_dict.items() + } forward_metrics = { k: test_metrics(logits, labels) for k, logits in forward_logits.items() @@ -98,7 +111,7 @@ def test_metrics(logits, labels): for forward_k, metric_dict in log_dict_forward.items(): log_metrics( metric_dict, - config_dir, + save_dir, file_name="test_" + forward_k, plot=False, ) diff --git a/examples/imdb/test_runner.sh b/examples/imdb/test_runner.sh index 4c216ed1..3fefc863 100644 --- a/examples/imdb/test_runner.sh +++ b/examples/imdb/test_runner.sh @@ -4,22 +4,67 @@ device="cuda:1" # List of temperature params temperatures=(0.03 0.1 0.3 1.0 3.0) +seeds=$(seq 1 5) # Base directories base_configs=( - "examples/imdb/results/laplace_fisher" - "examples/imdb/results/laplace_ggn" - "examples/imdb/results/vi" - "examples/imdb/results/sghmc_serial" - "examples/imdb/results/sghmc_parallel/sghmc_parallel" + # "examples/imdb/results/vi" + # "examples/imdb/results/sghmc_serial" + "examples/imdb/results/sghmc_parallel" ) - for temperature in 
"${temperatures[@]}"; do - for base_config in "${base_configs[@]}"; do - config_file="${base_config}_temp${temperature//./-}/config.py" - echo "Executing $config_file" - PYTHONPATH=. python examples/imdb/test.py --config "$config_file" \ - --device $device + for seed in $seeds; do + for base_config in "${base_configs[@]}"; do + config_file="${base_config}_seed${seed}_temp${temperature//./-}/config.py" + echo "Executing $config_file" + PYTHONPATH=. python examples/imdb/test.py --config "$config_file" \ + --device $device + done done done + + + +# # Laplace + +# temperatures=(0.03 0.1 0.3 1.0 3.0) +# seeds=$(seq 2) + + +# # Base directories +# base_configs=( +# "examples/imdb/results/laplace_fisher" +# # "examples/imdb/results/laplace_ggn" +# ) + +# for temperature in "${temperatures[@]}"; do +# for seed in $seeds; do +# for base_config in "${base_configs[@]}"; do +# config_file="${base_config}_seed${seed}/config.py" +# echo "Executing $config_file" +# PYTHONPATH=. python examples/imdb/test.py --config "$config_file" \ +# --device $device --temperature $temperature +# done +# done +# done + + +# # MLE and MAP +# # List of temperature params +# seeds=$(seq 1 5) + +# # Base directories +# base_configs=( +# "examples/imdb/results/mle" +# "examples/imdb/results/map" +# ) + +# for seed in $seeds; do +# for base_config in "${base_configs[@]}"; do +# config_file="${base_config}_seed${seed}/config.py" +# echo "Executing $config_file" +# PYTHONPATH=. python examples/imdb/test.py --config "$config_file" \ +# --device $device +# done +# done \ No newline at end of file diff --git a/examples/imdb/train.py b/examples/imdb/train.py index fe4e281d..1382ad85 100644 --- a/examples/imdb/train.py +++ b/examples/imdb/train.py @@ -29,8 +29,11 @@ # Set seed if args.seed != 42: config.save_dir += f"_seed{args.seed}" + if config.params_dir is not None: + config.params_dir += f"_seed{args.seed}/state.pkl" else: args.seed = 42 + config.params_dir += "/state.pkl" torch.manual_seed(args.seed) # Load model @@ -39,6 +42,7 @@ if config.params_dir is not None: with open(config.params_dir, "rb") as f: + print(config.params_dir) state = pickle.load(f) model.load_state_dict(state.params) diff --git a/examples/imdb/train_runner.sh b/examples/imdb/train_runner.sh index ecaaaeca..a9772f95 100644 --- a/examples/imdb/train_runner.sh +++ b/examples/imdb/train_runner.sh @@ -1,23 +1,41 @@ #!/bin/bash -# List of temperature params -temperatures=(0.03 0.1 0.3 1.0 3.0) - - -device="cuda:1" - +device="cuda:0" +# temperatures=(0.03 0.1 0.3 1.0 3.0) # config="examples/imdb/configs/vi.py" # epochs=30 -# seeds=$(42) +# seeds=$(seq 1 5) +# temperatures=(0.03 0.1 0.3 1.0 3.0) # config="examples/imdb/configs/sghmc_serial.py" # epochs=60 -# seeds=$(42) +# seeds=$(seq 1 5) +temperatures=(0.03 0.1 0.3 1.0 3.0) config="examples/imdb/configs/sghmc_parallel.py" epochs=30 -seeds=$(seq 2 19) +seeds=$(seq 1 35) + +# temperatures=(1.0) +# config="examples/imdb/configs/map.py" +# epochs=30 +# seeds=$(seq 1 5) + +# temperatures=(1.0) +# config="examples/imdb/configs/mle.py" +# epochs=30 +# seeds=$(seq 1 5) + +# temperatures=(1.0) +# config="examples/imdb/configs/laplace_fisher.py" +# epochs=1 +# seeds=$(seq 1 5) + +# temperatures=(1.0) +# config="examples/imdb/configs/laplace_ggn.py" +# epochs=1 +# seeds=$(seq 1 5) for seed in $seeds do From 54147950c55dffc6758a44121e558863a631bd59 Mon Sep 17 00:00:00 2001 From: Sam Duffield Date: Wed, 22 May 2024 13:10:34 +0100 Subject: [PATCH 5/6] Add README --- examples/imdb/README.md | 119 
++++++++++++++++++++++++++++++++++++++++
 1 file changed, 119 insertions(+)
 create mode 100644 examples/imdb/README.md

diff --git a/examples/imdb/README.md b/examples/imdb/README.md
new file mode 100644
index 00000000..783f5a40
--- /dev/null
+++ b/examples/imdb/README.md
@@ -0,0 +1,119 @@
+# Cold posterior effect on IMDB data
+
+We investigate a wide range of `posteriors` methods for a [CNN-LSTM](https://proceedings.mlr.press/v119/wenzel20a/wenzel20a.pdf) model on the [IMDB dataset](https://huggingface.co/datasets/stanfordnlp/imdb).
+We run each method for a variety of temperatures, that is, targeting the tempered
+posterior distribution $p(\theta \mid \mathcal{D})^{\frac{1}{T}}$ for $T \geq 0$, thus
+investigating the so-called [cold posterior effect](https://arxiv.org/abs/2008.05912),
+where improved predictive performance has been observed for $T<1$.
+
+We observe significantly improved performance of Bayesian techniques over gradient
+descent, and notably that the cold posterior effect is significantly more prominent
+for Gaussian approximations than for SGMCMC variants.
+
+## Methods
+
+We train the CNN-LSTM model with 2.7m parameters on 25k reviews from the IMDB dataset
+for binary classification of positive and negative reviews. We then use `posteriors` to
+swap between the following methods:
+
+- **map**: Maximum a posteriori (MAP) estimate of the parameters. Simple optimization
+using [AdamW](https://arxiv.org/abs/1711.05101) via `posteriors.torchopt`.
+- **laplace.diag_fisher**: Uses the MAP estimate as the basis for a Laplace approximation
+with a diagonal empirical Fisher covariance matrix. We additionally use
+`posteriors.linearized_forward_diag` to assess a [linearized version](https://arxiv.org/abs/2106.14806)
+of the posterior predictive distribution.
+- **laplace.diag_ggn**: Same as above but using the diagonal Gauss-Newton Fisher matrix,
+which is [equivalent to the conditional Fisher information matrix](https://arxiv.org/abs/1412.1193)
+(integrated over labels rather than using the empirical distribution). Again we assess
+traditional and linearized Laplace.
+- **vi.diag**: Variational inference with a diagonal Gaussian approximation, optimized
+using [AdamW](https://arxiv.org/abs/1711.05101); we also assess a linearized variant.
+- **sgmcmc.sghmc (Serial)**: [Stochastic gradient Hamiltonian Monte Carlo](https://arxiv.org/abs/1506.04696)
+(SGHMC) with a Monte Carlo approximation to the posterior collected by running a single
+trajectory (and removing a burn-in).
+- **sgmcmc.sghmc (Parallel)**: Same as above but running 15 trajectories in parallel and
+only collecting the final state of each trajectory.
+
+All methods are run for a variety of temperatures $T \in \{0.03, 0.1, 0.3, 1.0, 3.0\}$.
+Each training run lasts 30 epochs, except serial SGHMC, which uses 60 epochs to collect
+27 samples from a single trajectory. Each method and temperature is run 5 times with
+different random seeds, except for parallel SGHMC, where we run 35 seeds and then
+bootstrap 5 ensembles of size 15. In all cases we use a diagonal Gaussian prior with
+all variances set to 1/40.
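+
+As a concrete illustration, the following is a minimal sketch of how a tempered
+log-posterior can be set up with `posteriors` — a simplified stand-in for the actual
+code in `train.py` and the config files, where the linear `model`, learning rate and
+method choice are placeholder assumptions:
+
+```python
+import torch
+import posteriors
+
+num_data = 25_000           # size of the IMDB training set
+prior_sd = (1 / 40) ** 0.5  # diagonal Gaussian prior with variance 1/40
+
+model = torch.nn.Linear(100, 2)  # placeholder for the CNN-LSTM
+params = dict(model.named_parameters())
+
+
+def log_posterior(params, batch):
+    x, y = batch
+    logits = torch.func.functional_call(model, params, (x,))
+    log_lik = -torch.nn.functional.cross_entropy(logits, y)
+    log_prior = posteriors.diag_normal_log_prob(params, mean=0.0, sd_diag=prior_sd)
+    # posteriors expects the per-sample (normalized) log-posterior,
+    # so the prior contributes with a 1 / num_data factor
+    return log_lik + log_prior / num_data, logits
+
+
+# the temperature then simply enters through the chosen method's build function
+transform = posteriors.sgmcmc.sghmc.build(log_posterior, lr=1e-1, temperature=0.1)
+state = transform.init(params)
+```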
+
+
+## Results
+
+We plot the test loss for each method and temperature in Figure 1.
+
+*Figure 1. Test loss on IMDB data (lower is better) for varying temperatures.*
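+
+The shaded regions in Figure 1 show one standard deviation across the 5 runs (computed
+in `plot.py`); for parallel SGHMC the runs are the bootstrapped ensembles. A rough
+sketch of the bootstrap and the ensembled (model-averaged) prediction, with random
+placeholder tensors standing in for the stored test logits:
+
+```python
+import torch
+
+n_seeds, ensemble_size, n_ensembles = 35, 15, 5
+seed_logits = torch.randn(n_seeds, 25_000, 2)  # placeholder per-seed test logits
+labels = torch.randint(2, (25_000,))
+
+g = torch.Generator().manual_seed(0)
+ensemble_losses = []
+for _ in range(n_ensembles):
+    members = torch.randperm(n_seeds, generator=g)[:ensemble_size]
+    # average member probabilities to get the ensemble predictive distribution
+    probs = torch.softmax(seed_logits[members], dim=-1).mean(0)
+    ensemble_losses.append(torch.nn.functional.nll_loss(probs.log(), labels).item())
+```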
+
+There are a few takeaways from the results:
+- Bayesian methods (VI and SGMCMC) can significantly improve over gradient descent (MAP;
+we also trained the MLE, which severely overfits and was omitted from the plot for
+clarity).
+- Regular Laplace does not perform well; the linearization helps somewhat, with GGN
+outperforming empirical Fisher.
+- In contrast, the linearization is detrimental for VI, which we posit is because the VI
+training acts in parameter space without knowledge of the linearization.
+- There is a strong cold posterior effect for the Gaussian methods (VI + Laplace), but
+only a very mild cold posterior effect for the non-Gaussian Bayesian methods (SGHMC).
+- Parallel SGHMC outperforms serial SGHMC, and there is also evidence that both
+outperform a [deep ensemble](https://arxiv.org/abs/1612.01474) (which is obtained with
+parallel SGHMC and $T=0$).
+
+
+## Data
+
+We download the IMDB dataset using [keras.datasets.imdb](https://www.tensorflow.org/api_docs/python/tf/keras/datasets/imdb/load_data).
+
+
+## Model
+
+We use the CNN-LSTM model from [Wenzel et al](https://proceedings.mlr.press/v119/wenzel20a/wenzel20a.pdf),
+which consists of an embedding layer, a convolutional layer, ReLU activation, a
+max-pooling layer, an LSTM layer and a final dense layer. In total there are 2.7m
+parameters. We use a diagonal Gaussian prior with all variances set to 1/40.
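+
+For orientation, an indicative plain-PyTorch version of this architecture might look as
+follows — the layer sizes are assumptions chosen to roughly match the 2.7m parameter
+count; the actual definition lives in `model.py`, with the custom LSTM from `lstm.py`
+used so that it composes with `torch.func`:
+
+```python
+import torch
+import torch.nn as nn
+
+
+class CNNLSTM(nn.Module):
+    def __init__(self, vocab_size=20_000):
+        super().__init__()
+        self.embedding = nn.Embedding(vocab_size, 128)
+        self.conv = nn.Conv1d(128, 64, kernel_size=5)
+        self.pool = nn.MaxPool1d(4)
+        self.lstm = nn.LSTM(64, 128, batch_first=True)
+        self.out = nn.Linear(128, 2)
+
+    def forward(self, x):
+        e = self.embedding(x).transpose(1, 2)    # [batch, 128, seq_len]
+        c = self.pool(torch.relu(self.conv(e)))  # [batch, 64, seq_len']
+        h, _ = self.lstm(c.transpose(1, 2))      # [batch, seq_len', 128]
+        return self.out(h[:, -1])                # logits for the two classes
+```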
+
+
+## Code structure
+
+- `lstm.py`: Simple, custom code for an LSTM layer that composes with `torch.func`.
+- `model.py`: Specification of the CNN-LSTM model.
+- `data.py`: Functions to load the IMDB data using `keras`.
+- `train.py`: General script for training the model with `posteriors` methods that can
+easily be swapped.
+- `configs/`: Configuration files (Python files) for the different methods.
+- `train_runner.sh`: Bash script to run a series of training jobs.
+- `combine_states_serial.py`: Combines parameter states from a single SGHMC trajectory
+for ensembled forward calls.
+- `combine_states_parallel.py`: Combines final parameter states from multiple SGHMC
+trajectories for ensembled forward calls.
+- `test.py`: Calculates metrics on the IMDB test data for a given method.
+- `test_runner.sh`: Bash script to run a series of testing jobs.
+- `plot.py`: Generates Figure 1.
+- `utils.py`: Small utility functions for logging.
+
+A single training run can be launched from the root directory:
+```bash
+PYTHONPATH=. python examples/imdb/train.py --config examples/imdb/configs/laplace_fisher.py --temperature 0.1 --seed 0 --device cuda:0
+```
+or multiple runs can be configured in `train_runner.sh` and launched with:
+```bash
+bash examples/imdb/train_runner.sh
+```
+Similarly, testing code can be run from the root directory:
+```bash
+PYTHONPATH=. python examples/imdb/test.py --config examples/imdb/configs/laplace_fisher.py --seed 0 --device cuda:0
+```
+or by configuring settings in `test_runner.sh` and running:
+```bash
+bash examples/imdb/test_runner.sh
+```
+

From 620668993d2ec3225b0e0b0d6dd756ce4b326424 Mon Sep 17 00:00:00 2001
From: Sam Duffield
Date: Wed, 22 May 2024 13:19:14 +0100
Subject: [PATCH 6/6] Update tutorials index

---
 docs/tutorials/index.md | 7 +++++++
 examples/README.md      | 3 +++
 2 files changed, 10 insertions(+)

diff --git a/docs/tutorials/index.md b/docs/tutorials/index.md
index 9dc73de5..2bc52ee5 100644
--- a/docs/tutorials/index.md
+++ b/docs/tutorials/index.md
@@ -32,6 +32,13 @@
 on a series of books from the [pg19](https://huggingface.co/datasets/pg19) dataset.
 Compares a host of `posteriors` methods (highlighting the easy exchangeability) on a
 sentiment analysis task adapted from the [Hugging Face tutorial](https://huggingface.co/docs/transformers/training#train-in-native-pytorch).
 
+[](https://github.com/normal-computing/posteriors/tree/main/examples/imdb)
+
+- [`imdb`](https://github.com/normal-computing/posteriors/tree/main/examples/imdb): Investigates the [cold posterior effect](https://proceedings.mlr.press/v119/wenzel20a/wenzel20a.pdf) for a range of approximate
+Bayesian methods with a CNN-LSTM model on [IMDB](https://www.tensorflow.org/api_docs/python/tf/keras/datasets/imdb/load_data)
+data, with some interesting takeaways.
+
+
 [](https://github.com/normal-computing/posteriors/blob/main/examples/continual_regression.ipynb)
 
 - [`continual_regression`](https://github.com/normal-computing/posteriors/blob/main/examples/continual_regression.ipynb):
diff --git a/examples/README.md b/examples/README.md
index 07c4b4c9..42ea63b3 100644
--- a/examples/README.md
+++ b/examples/README.md
@@ -4,6 +4,9 @@ This directory contains examples of how to use the `posteriors` package.
 
 - [`continual_lora`](continual_lora/): Uses `posteriors.laplace.diag_fisher` to avoid catastrophic forgetting in fine-tuning [Llama-2-7b](https://huggingface.co/meta-llama/Llama-2-7b-hf) on a series of books from the [pg19](https://huggingface.co/datasets/pg19) dataset.
+- [`imdb`](imdb/): Investigates the [cold posterior effect](https://proceedings.mlr.press/v119/wenzel20a/wenzel20a.pdf)
+for a range of approximate Bayesian methods on [IMDB](https://www.tensorflow.org/api_docs/python/tf/keras/datasets/imdb/load_data)
+data.
 - [`yelp`](yelp/): Compares a host of `posteriors` methods (highlighting the easy
exchangeability) on a sentiment analysis task adapted from the [Hugging Face tutorial](https://huggingface.co/docs/transformers/training#train-in-native-pytorch).
 - [`continual_regression`](continual_regression.ipynb): [Variational continual learning](https://arxiv.org/abs/1710.10628)