Bayesian Ensemble model for Llama3 (SGHMC) #93

Closed · wants to merge 63 commits

Commits (63)
7f82b22
Merge pull request #3 from normal-computing/sync
SamDuffield Apr 30, 2024
94dfd42
Idea for BayesLlama class
phoebeklett Apr 25, 2024
ea3dce7
Pare down code, vectorize
phoebeklett Apr 26, 2024
b6f6914
Add back hyperparams, make loading weights functional, use native fun…
phoebeklett Apr 29, 2024
7b1aa60
Make folder for experiment, generalize to all parameter options, add …
phoebeklett Apr 30, 2024
176f577
eliminate einops dependency
phoebeklett Apr 30, 2024
ce2dfd9
Update base classes so that bayes ensemble copies register as named p…
phoebeklett Apr 30, 2024
b54b009
skeleton for training script
phoebeklett Apr 30, 2024
840f34d
start setting up llama3 bayesian training code
johnathanchiu May 1, 2024
b68a5ee
add transform and state
johnathanchiu May 1, 2024
58a0910
get training running, need to debug OOM issues
johnathanchiu May 1, 2024
e3540d3
get running, need to fix issue with functional call not being done pr…
johnathanchiu May 1, 2024
edfd744
remove cropping from logits
johnathanchiu May 2, 2024
6139d3c
get code fully running, need to add logging
johnathanchiu May 2, 2024
684e37d
fix aux, log posterior outputs
johnathanchiu May 2, 2024
1d623d6
remove parameter type complexity
phoebeklett May 2, 2024
8fb80eb
save only trained weights
johnathanchiu May 2, 2024
cd7ab12
Merge branch 'llama3-training' of github.com:normal-computing/posteri…
johnathanchiu May 2, 2024
dad693b
save folders based on experiment config file name
johnathanchiu May 2, 2024
119f528
add params for alpha, beta, other hyperparams
johnathanchiu May 2, 2024
4b755f4
Update inference class to use state dict, update demo accordingly
phoebeklett May 2, 2024
bf81f3b
bayesian layering with deepcopy and loading
johnathanchiu May 3, 2024
caf0217
remove unused load function
johnathanchiu May 3, 2024
1fbeff8
add logs to .gitignore
johnathanchiu May 3, 2024
f54b6a8
allow setting temperature in config
johnathanchiu May 3, 2024
4e29406
add shuffling for data, add sequence length multiplier in temperature
johnathanchiu May 3, 2024
2e7e304
Update for generation step
phoebeklett May 3, 2024
b511d75
fix kv cache with ensemble layers
johnathanchiu May 6, 2024
2bbf0f1
Add eval script, helpers, example conf
phoebeklett May 7, 2024
0a1c573
Add runner and utils
phoebeklett May 7, 2024
d6a5e06
get eval scripts up and running, need to clean up
johnathanchiu May 8, 2024
a9e596c
clean up eval code
johnathanchiu May 8, 2024
0429217
allow specifying instruct vs. base model
johnathanchiu May 9, 2024
3740608
add validation set in training as well
johnathanchiu May 9, 2024
aab4f87
cleanup qa portions
johnathanchiu May 9, 2024
64e8321
update eval running scripts
johnathanchiu May 9, 2024
3ec026b
remove unused items, clean up utils
johnathanchiu May 9, 2024
67dc2bc
allow specifying for instruct model, notebook updates
johnathanchiu May 9, 2024
053e243
change configs for instruct models
johnathanchiu May 9, 2024
84b8c7a
fix dataloader to batch after sequences are tokenized
johnathanchiu May 9, 2024
dfeeef1
remove truncation param
johnathanchiu May 9, 2024
d9658df
debug issue with tokenizer, need to set padding to the left side
johnathanchiu May 9, 2024
a386595
add truncation back, it is necessary
johnathanchiu May 9, 2024
ddd6f42
fix eval to use instruct model weights, add eval code for the base mo…
johnathanchiu May 10, 2024
9083d15
add ensemble training script
johnathanchiu May 10, 2024
a80d599
fix seq length parameter input and grab from stride length
johnathanchiu May 10, 2024
87aa2ad
fix eval script to eval base model as well, eval on ensemble model
johnathanchiu May 10, 2024
48b5a74
setup to eval with head_qa
johnathanchiu May 13, 2024
455d7a4
remove experiments folder, fix eval folder missing param
johnathanchiu May 13, 2024
4c489d5
fix uncertainties calculation, reduce number of output tokens
johnathanchiu May 14, 2024
c7eef9d
move off gpu when calculating uncertainties
johnathanchiu May 14, 2024
6f2fec6
add base configs
May 14, 2024
bef6efe
add eval for base llama3
johnathanchiu May 14, 2024
a76d803
train on test data
johnathanchiu May 14, 2024
c6f08aa
include train, val, test in training for tqa
johnathanchiu May 15, 2024
87fd8d5
add ignore first n tokens
johnathanchiu May 16, 2024
65e5dfb
config for experiments, allow multiple runs of pure SGD
johnathanchiu May 20, 2024
a848010
add saving for last checkpoint
johnathanchiu May 20, 2024
491af26
clean up config files
johnathanchiu May 21, 2024
7e39ed7
save last checkpoint
johnathanchiu May 21, 2024
3f218f9
clean up eval code for statement completion
johnathanchiu May 21, 2024
d198280
add code for plotting experiment results
johnathanchiu May 21, 2024
16d5d92
remove notebooks for experiments
johnathanchiu May 21, 2024
Files changed
3 changes: 2 additions & 1 deletion .gitignore
@@ -107,4 +107,5 @@ ENV/
.mypy_cache/

# Experiment runs
-examples/runs/
+examples/runs/
+datasets
4 changes: 4 additions & 0 deletions examples/bayes_llama3/.gitignore
@@ -0,0 +1,4 @@
Meta-Llama-3-8B
Meta-Llama-3-8B-Instruct
logs
experiments
7 changes: 7 additions & 0 deletions examples/bayes_llama3/configs/evaluation/eval_ensemble.yaml
@@ -0,0 +1,7 @@
experiment_config:
experiment_name: question-answer-test
chat_model: false
n_tokens: 5
tokenizer_pretrained_model_name_or_path: Meta-Llama-3-8B
pretrained_model_name_or_path: Meta-Llama-3-8B
checkpoints_folder: /home/paperspace/Projects/posteriors/examples/bayes_llama3/logs/ensemble_bayes/usable_checkpoints
7 changes: 7 additions & 0 deletions examples/bayes_llama3/configs/evaluation/eval_pretrain.yaml
@@ -0,0 +1,7 @@
experiment_config:
eval_pretrained_model: true
experiment_name: question-answer-test
chat_model: false
n_tokens: 3
tokenizer_pretrained_model_name_or_path: Meta-Llama-3-8B
pretrained_model_name_or_path: Meta-Llama-3-8B
7 changes: 7 additions & 0 deletions examples/bayes_llama3/configs/evaluation/eval_sgd.yaml
@@ -0,0 +1,7 @@
experiment_config:
experiment_name: question-answer-test
chat_model: false
n_tokens: 5
tokenizer_pretrained_model_name_or_path: Meta-Llama-3-8B
pretrained_model_name_or_path: Meta-Llama-3-8B
checkpoints_folder: /home/paperspace/Projects/posteriors/examples/bayes_llama3/logs/sgd_bayes_base/usable_checkpoints
23 changes: 23 additions & 0 deletions examples/bayes_llama3/configs/training/ensemble_bayes.yaml
@@ -0,0 +1,23 @@
seed: 2024
epochs: 10

num_ensembles: 10

learning_rate: 1e-3
alpha: 1e-2
beta: 0.0
momenta: 0.0
set_temperature: True
pretrained_weights_folder: Meta-Llama-3-8B
ignore_first_tokens: 250

metrics_log_frequency: 5
save_frequency: 400

num_workers: 8
batch_size: 10
data_loader:
folder: "/home/paperspace/Projects/posteriors/datasets/tqa_train_val_test"
tokenizer_path: Meta-Llama-3-8B
stride_length: 300
stride_overlap: 100
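
For orientation, here is a minimal sketch (assumed, not taken from this PR's training script) of how these hyperparameters could map onto posteriors' SGHMC transform; log_posterior and sub_params are illustrative names for the log-posterior function and the subset of parameters being sampled.

import posteriors

transform = posteriors.sgmcmc.sghmc.build(
    log_posterior,    # assumed: maps (params, batch) -> (log_post, aux)
    lr=1e-3,          # learning_rate in the config
    alpha=1e-2,       # alpha: friction term
    beta=0.0,         # beta: gradient-noise estimate
    momenta=0.0,      # initial momenta
    temperature=1.0,  # rescaled when set_temperature is true
)

state = transform.init(sub_params)           # the Bayesian parameter subset
for batch in train_dataloader:               # assumed DataLoader of token batches
    state = transform.update(state, batch)   # one SGHMC step per batch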
22 changes: 22 additions & 0 deletions examples/bayes_llama3/configs/training/sgd.yaml
@@ -0,0 +1,22 @@
seed: 2024
epochs: 50

learning_rate: 1e-3
alpha: 1e-2
beta: 0.0
momenta: 0.0
set_temperature: False
max_seq_len: 150
pretrained_weights_folder: Meta-Llama-3-8B
ignore_first_tokens: 250

metrics_log_frequency: 5
save_frequency: 1000

num_workers: 8
batch_size: 10
data_loader:
folder: "/home/paperspace/Projects/posteriors/datasets/tqa_train_val_test"
tokenizer_path: Meta-Llama-3-8B
stride_length: 300
stride_overlap: 100
310 changes: 310 additions & 0 deletions examples/bayes_llama3/llama3/data/statements.py

Large diffs are not rendered by default.

88 changes: 88 additions & 0 deletions examples/bayes_llama3/llama3/data/tqa.py
@@ -0,0 +1,88 @@
import json
import os

from transformers import AutoTokenizer
from torch.utils.data import Dataset, DataLoader
import pytorch_lightning as pl


class TQADataLoader(pl.LightningDataModule):
def __init__(self, data_config, num_workers=8, batch_size=10, shuffle=True):
super().__init__()
self.batch_size = batch_size
self.shuffle = shuffle
self.num_workers = num_workers
self.train_dataset = TQADataset(**data_config, split="train")

def train_dataloader(self):
return DataLoader(
self.train_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=self.shuffle,
)


class TQADataset(Dataset):
def __init__(
self,
folder,
tokenizer_path,
stride_length=300,
stride_overlap=100,
split="train",
):
assert split in {"train", "val", "test"}

self.stride_length = stride_length
self.stride_overlap = stride_overlap
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
self.tokenizer.pad_token = self.tokenizer.eos_token
self.tokenizer.padding_side = "left"

        # Note: regardless of `split`, the dataset concatenates the train, val,
        # and test documents (see the "include train, val, test in training for
        # tqa" commit); only the first folder is selected by `split`.
        self.folder = os.path.join(folder, split)
        with open(os.path.join(self.folder, "tqa_v1_train.json")) as f:
            qa_doc = json.load(f)

        self.folder = os.path.join(folder, "val")
        with open(os.path.join(self.folder, "tqa_v1_val.json")) as f:
            qa_doc += json.load(f)

        self.folder = os.path.join(folder, "test")
        with open(os.path.join(self.folder, "tqa_v2_test.json")) as f:
            qa_doc += json.load(f)

self.list_of_paragraphs = []
for concept in qa_doc:
for key in concept["topics"].keys():
self.list_of_paragraphs.append(
concept["topics"][key]["content"]["text"]
)

self.tokenized_dataset = []
for sample in self.list_of_paragraphs:
sample = self.tokenize_and_stride(sample)
batch_size = sample["input_ids"].size(0)
for idx in range(batch_size):
self.tokenized_dataset.append(
{
"input_ids": sample["input_ids"][idx],
"attention_mask": sample["attention_mask"][idx],
}
)

def tokenize_and_stride(self, sample):
return self.tokenizer(
sample,
truncation=True,
max_length=self.stride_length,
stride=self.stride_overlap,
padding="max_length",
return_tensors="pt",
)

def __getitem__(self, idx):
return self.tokenized_dataset[idx]

def __len__(self):
return len(self.tokenized_dataset)
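
A minimal usage sketch of the data module above (paths illustrative, mirroring the data_loader section of the training configs):

data_config = {
    "folder": "datasets/tqa_train_val_test",  # assumed local path
    "tokenizer_path": "Meta-Llama-3-8B",
    "stride_length": 300,
    "stride_overlap": 100,
}
loader = TQADataLoader(data_config, num_workers=8, batch_size=10)
batch = next(iter(loader.train_dataloader()))
# batch["input_ids"] and batch["attention_mask"] are (batch_size, stride_length)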
193 changes: 193 additions & 0 deletions examples/bayes_llama3/llama3/eval.py
@@ -0,0 +1,193 @@
import os

import torch
import torch.nn.functional as F
from ml_collections.config_dict import ConfigDict, FrozenConfigDict
from transformers import AutoTokenizer, AutoModelForCausalLM
import numpy as np

from llama3.data.statements import scientific_statements, scientific_statements_samoan
from llama3.modules.bayesllama import BayesLlamaForCausalLM
from llama3.utils.load_utils import load_ensemble


PROMPT = "Complete the following statement:\n"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

BATCH_SIZE = 1
eps = 1e-5


def logits_to_uncertainties(eprobs):
    # eprobs: per-ensemble-member predictive probabilities,
    # shape (n_ensemble, batch, vocab).
    probs = eprobs.mean(0)
    # Total: entropy of the ensemble-averaged predictive distribution.
    total_uncertainty = -torch.sum(probs * torch.log(probs), -1)
    # Aleatoric: average entropy of each member's own predictive.
    aleatoric_uncertainty = -(eprobs * torch.log(eprobs)).sum(-1).mean(0)
    # Epistemic: the gap, i.e. the mutual information between the
    # prediction and the choice of ensemble member.
    epistemic_uncertainty = total_uncertainty - aleatoric_uncertainty
    return total_uncertainty, epistemic_uncertainty
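

# Illustrative sanity check of the decomposition (toy numbers, not part of
# the diff): two ensemble members that agree, both predicting [0.9, 0.1],
# give total = aleatoric ≈ 0.325 nats and epistemic ≈ 0; members that
# disagree, [0.9, 0.1] vs [0.1, 0.9], give total = ln 2 ≈ 0.693,
# aleatoric ≈ 0.325, and epistemic ≈ 0.368.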


class EvaluationEngine:
def __init__(self, config: FrozenConfigDict):
self.config = ConfigDict(config)
self.n_tokens = config["n_tokens"]

self.tokenizer = AutoTokenizer.from_pretrained(
config["tokenizer_pretrained_model_name_or_path"]
)
self.tokenizer.pad_token = self.tokenizer.eos_token
self.tokenizer.padding_side = "left"

self.eval_pretrained = config.get("eval_pretrained_model", False)
if self.eval_pretrained:
self.model = AutoModelForCausalLM.from_pretrained(
config["pretrained_model_name_or_path"]
)
else:
assert os.path.isdir(
config["checkpoints_folder"]
), "Provided checkpoints is not a path to a folder"
checkpoints = [
os.path.join(config["checkpoints_folder"], path)
for path in os.listdir(config["checkpoints_folder"])
if path.endswith(".ckpt")
]
parameters = load_ensemble(checkpoints)

self.model = BayesLlamaForCausalLM.from_pretrained(
config["pretrained_model_name_or_path"],
bayes_config={"n_ensemble": len(checkpoints)},
)
self.model.load_bayesian_layers(parameters)

self.model.to(DEVICE)

@torch.no_grad()
def generate(self, inputs, gt_token, max_length=20, use_cache=True, is_base=False):
accuracy = 0.0
prediction_loss = []
epistemic_uncertainties = [[] for _ in range(inputs["input_ids"].size(0))]
total_uncertainties = [[] for _ in range(inputs["input_ids"].size(0))]

predictions = []
for token_idx in range(max_length):
outputs = self.model(**inputs, return_dict=False, use_cache=use_cache)

if "attention_mask" in inputs:
del inputs["attention_mask"]

if is_base:
elogits = outputs[0][:, -1].unsqueeze(0)
else:
elogits = outputs[0][:, :, -1]

eprobs = torch.softmax(elogits, dim=-1)
probs = eprobs.mean(0)

# (batch size, vocab_size)
pred_logits = torch.log(probs)
next_token = probs.argmax(-1).unsqueeze(-1)

if token_idx == 0:
                accuracy += (next_token.squeeze(-1).cpu() == gt_token).sum().item()
loss = F.cross_entropy(pred_logits.cpu(), gt_token)
prediction_loss.append(loss.item())

total_unc, epi_unc = logits_to_uncertainties(eprobs.cpu())
for idx, (tunc, eunc) in enumerate(zip(total_unc, epi_unc)):
total_uncertainties[idx].append(tunc.item())
epistemic_uncertainties[idx].append(eunc.item())

if use_cache:
inputs["past_key_values"] = outputs[1]
if not is_base:
inputs["ensemble_past_key_values"] = outputs[2]
inputs["input_ids"] = next_token
else:
inputs["input_ids"] = torch.cat(
[inputs["input_ids"], next_token], dim=1
)
predictions.append(next_token)

predictions = torch.cat(predictions, -1)
text_outputs = self.tokenizer.batch_decode(
predictions, skip_special_tokens=True
)
return (
text_outputs,
total_uncertainties,
epistemic_uncertainties,
prediction_loss,
accuracy,
)

def run_eval(self, statements, batch_size=3, n_tokens=5):
model_outputs = []
accuracies = 0.0
uncertainties = {"epistemic": [], "total": []}
loss_metrics = []
for i in range(0, len(statements), batch_size):
last_words = [
" " + s.split(" ")[-1] for s in statements[i : i + batch_size]
]
batch_statements = [
" ".join(s.split(" ")[:-1]) for s in statements[i : i + batch_size]
]

self.tokenizer.padding_side = "right"
last_token = self.tokenizer(last_words, padding=True, return_tensors="pt")[
"input_ids"
][:, 1]
self.tokenizer.padding_side = "left"

inputs = self.tokenizer(
[PROMPT + s + " " for s in batch_statements],
padding=True,
return_tensors="pt",
)
(
text_outputs,
total_uncertainty,
epistemic_uncertainty,
batch_loss,
batch_accuracy,
) = self.generate(
                inputs.to(DEVICE),
last_token,
max_length=n_tokens,
use_cache=True,
is_base=self.eval_pretrained,
)
loss_metrics.extend(batch_loss)
accuracies += batch_accuracy

for text_output, total_unc, epi_unc in zip(
text_outputs, total_uncertainty, epistemic_uncertainty
):
model_outputs.append(text_output)
uncertainties["total"].append(total_unc)
uncertainties["epistemic"].append(epi_unc)

return (
model_outputs,
uncertainties,
np.average(loss_metrics),
accuracies / len(statements),
)

def run(self):
statements = [
(scientific_statements, "en"),
(scientific_statements_samoan, "sa"),
]

results = {}
for statement_list, lang in statements:
            outputs, uncertainties, loss, acc = self.run_eval(
                statement_list, n_tokens=self.n_tokens
            )
results[lang] = {
"outputs": outputs,
"uncertainties": uncertainties,
"loss": loss,
"acc": acc,
}
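
A minimal runner sketch (assumed, not part of this diff) wiring one of the evaluation configs above into EvaluationEngine:

import yaml
from ml_collections.config_dict import FrozenConfigDict

# Config path assumed; see configs/evaluation/ above.
with open("configs/evaluation/eval_ensemble.yaml") as f:
    config = FrozenConfigDict(yaml.safe_load(f)["experiment_config"])

engine = EvaluationEngine(config)
engine.run()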