feat(benchmarks) Add LLM evaluation pipeline for Medical challenge #3768

Merged: 37 commits, Sep 9, 2024
Commits
a56f2ee
Init medical eval
yan-gao-GY Jul 10, 2024
d1b6bfe
Merge branch 'main' into add-llm-medical-eval
yan-gao-GY Jul 10, 2024
2ba7a39
Update pyproject.toml
yan-gao-GY Jul 10, 2024
ea45e4b
Update readme
yan-gao-GY Jul 12, 2024
e7fbb62
Update readme
yan-gao-GY Jul 12, 2024
a2344ea
Update readme
yan-gao-GY Jul 15, 2024
d00479c
Merge branch 'main' into add-llm-medical-eval
yan-gao-GY Jul 15, 2024
7f40d3a
Update benchmarks/flowertune-llm/evaluation/medical/README.md
yan-gao-GY Aug 7, 2024
2327052
Update benchmarks/flowertune-llm/evaluation/medical/README.md
yan-gao-GY Aug 7, 2024
72db769
Update benchmarks/flowertune-llm/evaluation/medical/README.md
yan-gao-GY Aug 7, 2024
a06646b
Update benchmarks/flowertune-llm/evaluation/medical/README.md
yan-gao-GY Aug 7, 2024
5e9e6a7
Update benchmarks/flowertune-llm/evaluation/medical/README.md
yan-gao-GY Aug 7, 2024
b2529bf
Update benchmarks/flowertune-llm/evaluation/medical/README.md
yan-gao-GY Aug 7, 2024
f1c1804
Update benchmarks/flowertune-llm/evaluation/medical/README.md
yan-gao-GY Aug 7, 2024
a1fe792
Update benchmarks/flowertune-llm/evaluation/medical/README.md
yan-gao-GY Aug 7, 2024
87e1b29
Update benchmarks/flowertune-llm/evaluation/medical/README.md
yan-gao-GY Aug 7, 2024
f63e83e
Merge branch 'main' into add-llm-medical-eval
yan-gao-GY Aug 7, 2024
7d58f1c
Merge overall readme
yan-gao-GY Aug 7, 2024
a885bde
Merge overall readme
yan-gao-GY Aug 7, 2024
6e9de01
update readme & evaluate.py
yan-gao-GY Aug 7, 2024
76785d6
Merge branch 'main' into add-llm-medical-eval
yan-gao-GY Aug 7, 2024
5638c5b
Merge branch 'main' into add-llm-medical-eval
yan-gao-GY Aug 8, 2024
9b50e89
Replace pyproject.toml with requirements.txt
yan-gao-GY Aug 8, 2024
3c2bb9e
Merge branch 'main' into add-llm-medical-eval
yan-gao-GY Aug 13, 2024
9f02f33
Update top readme
yan-gao-GY Aug 13, 2024
ef92bea
Remove useless import
yan-gao-GY Aug 15, 2024
85f0d94
Merge branch 'main' into add-llm-medical-eval
yan-gao-GY Aug 23, 2024
929b962
Update license
yan-gao-GY Aug 23, 2024
b520b06
Merge branch 'main' into add-llm-medical-eval
jafermarq Sep 2, 2024
978ad0f
Merge branch 'main' into add-llm-medical-eval
yan-gao-GY Sep 5, 2024
42763c8
Simplify code
yan-gao-GY Sep 6, 2024
8529b12
Formatting
yan-gao-GY Sep 6, 2024
a10542a
Merge branch 'main' into add-llm-medical-eval
yan-gao-GY Sep 6, 2024
822cf4d
Update benchmarks/flowertune-llm/evaluation/medical/README.md
yan-gao-GY Sep 9, 2024
2b72996
Merge branch 'main' into add-llm-medical-eval
yan-gao-GY Sep 9, 2024
4d1a39a
Merge branch 'main' into add-llm-medical-eval
jafermarq Sep 9, 2024
1b51125
Add ref for instruction
yan-gao-GY Sep 9, 2024
38 changes: 38 additions & 0 deletions benchmarks/flowertune-llm/evaluation/medical/README.md
@@ -0,0 +1,38 @@
# Evaluation for Medical challenge

We build a medical question answering (QA) pipeline to evaluate our fine-tuned LLMs.
Three datasets have been selected for this evaluation: [PubMedQA](https://huggingface.co/datasets/bigbio/pubmed_qa), [MedMCQA](https://huggingface.co/datasets/medmcqa), and [MedQA](https://huggingface.co/datasets/bigbio/med_qa).
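
For reference, each dataset is pulled directly from the Hugging Face Hub in `benchmarks.py`; a minimal sketch of loading the PubMedQA test split used here:

```python
# Minimal sketch: load the PubMedQA split evaluated by this pipeline (see benchmarks.py).
import datasets

pubmedqa = datasets.load_dataset(
    "bigbio/pubmed_qa",
    "pubmed_qa_labeled_fold0_source",
    split="test",
    trust_remote_code=True,
)
print(pubmedqa[0]["QUESTION"])
```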


## Environment Setup

```shell
git clone --depth=1 https://github.com/adap/flower.git && mv flower/benchmarks/flowertune-llm/evaluation/medical ./flowertune-eval-medical && rm -rf flower && cd flowertune-eval-medical
```

Create a new Python environment (we recommend Python 3.10), activate it, then install dependencies with:

```shell
# From a new python environment, run:
pip install -r requirements.txt

# Log in to your Hugging Face account
huggingface-cli login
```

## Generate model decision & calculate accuracy

```bash
python eval.py \
--peft-path=/path/to/fine-tuned-peft-model-dir/ \ # e.g., ./peft_1
--run-name=fl \ # specified name for this run
--batch-size=16 \
--quantization=4 \
--datasets=pubmedqa,medmcqa,medqa
```

The model answers and accuracy values will be saved to `benchmarks/generation_{dataset_name}_{run_name}.jsonl` and `benchmarks/acc_{dataset_name}_{run_name}.txt`, respectively.


> [!NOTE]
> Please ensure that you provide all **three accuracy values (PubMedQA, MedMCQA, MedQA)** when submitting to the LLM Leaderboard (see the [`Make Submission`](https://github.com/adap/flower/tree/main/benchmarks/flowertune-llm/evaluation#make-submission-on-flowertune-llm-leaderboard) section).
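
As a convenience, here is a minimal sketch (not part of the pipeline; it assumes the default `./benchmarks/` output directory and the run name `fl`) that prints the three saved accuracy values once all runs have finished:

```python
# Minimal sketch: print the accuracy saved for each dataset by eval.py.
from pathlib import Path

run_name = "fl"  # assumption: matches the --run-name passed to eval.py
for name in ("pubmedqa", "medmcqa", "medqa"):
    acc_file = Path("benchmarks") / f"acc_{name}_{run_name}.txt"
    print(f"{name}: {acc_file.read_text().strip()}")
```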
174 changes: 174 additions & 0 deletions benchmarks/flowertune-llm/evaluation/medical/benchmarks.py
@@ -0,0 +1,174 @@
import json

import pandas as pd
from sklearn.metrics import accuracy_score
from torch.utils.data import DataLoader
from tqdm import tqdm
from utils import format_answer, format_example, save_results

import datasets

# The instructions refer to Meditron evaluation:
# https://github.com/epfLLM/meditron/blob/main/evaluation/instructions.json
INSTRUCTIONS = {
"pubmedqa": "As an expert doctor in clinical science and medical knowledge, can you tell me if the following statement is correct? Answer yes, no, or maybe.",
"medqa": "You are a medical doctor taking the US Medical Licensing Examination. You need to demonstrate your understanding of basic and clinical science, medical knowledge, and mechanisms underlying health, disease, patient care, and modes of therapy. Show your ability to apply the knowledge essential for medical practice. For the following multiple-choice question, select one correct answer from A to E. Base your answer on the current and standard practices referenced in medical guidelines.",
"medmcqa": "You are a medical doctor answering realworld medical entrance exam questions. Based on your understanding of basic and clinical science, medical knowledge, and mechanisms underlying health, disease, patient care, and modes of therapy, answer the following multiple-choice question. Select one correct answer from A to D. Base your answer on the current and standard practices referenced in medical guidelines.",
}


def infer_pubmedqa(model, tokenizer, batch_size, run_name):
name = "pubmedqa"
answer_type = "boolean"
dataset = datasets.load_dataset(
"bigbio/pubmed_qa",
"pubmed_qa_labeled_fold0_source",
split="test",
trust_remote_code=True,
)
# Post process
instruction = INSTRUCTIONS[name]

def post_process(row):
context = "\n".join(row["CONTEXTS"])
row["prompt"] = f"{context}\n{row['QUESTION']}"
row["gold"] = row["final_decision"]
row["long_answer"] = row["LONG_ANSWER"]
row["prompt"] = f"{instruction}\n{row['prompt']}\nThe answer is:\n"
return row

dataset = dataset.map(post_process)

# Generate results
generate_results(name, run_name, dataset, model, tokenizer, batch_size, answer_type)


def infer_medqa(model, tokenizer, batch_size, run_name):
name = "medqa"
answer_type = "mcq"
dataset = datasets.load_dataset(
"bigbio/med_qa",
"med_qa_en_4options_source",
split="test",
trust_remote_code=True,
)

# Post process
instruction = INSTRUCTIONS[name]

def post_process(row):
choices = [opt["value"] for opt in row["options"]]
row["prompt"] = format_example(row["question"], choices)
for opt in row["options"]:
if opt["value"] == row["answer"]:
row["gold"] = opt["key"]
break
row["prompt"] = f"{instruction}\n{row['prompt']}\nThe answer is:\n"
return row

dataset = dataset.map(post_process)

# Generate results
generate_results(name, run_name, dataset, model, tokenizer, batch_size, answer_type)


def infer_medmcqa(model, tokenizer, batch_size, run_name):
name = "medmcqa"
answer_type = "mcq"
dataset = datasets.load_dataset(
"medmcqa", split="validation", trust_remote_code=True
)

# Post process
instruction = INSTRUCTIONS[name]

def post_process(row):
options = [row["opa"], row["opb"], row["opc"], row["opd"]]
answer = int(row["cop"])
row["prompt"] = format_example(row["question"], options)
row["gold"] = chr(ord("A") + answer) if answer in [0, 1, 2, 3] else None
row["prompt"] = f"{instruction}\n{row['prompt']}\nThe answer is:\n"
return row

dataset = dataset.map(post_process)

# Generate results
generate_results(name, run_name, dataset, model, tokenizer, batch_size, answer_type)


def generate_results(
name, run_name, dataset, model, tokenizer, batch_size, answer_type
):
# Run inference
prediction = inference(dataset, model, tokenizer, batch_size)

# Calculate accuracy
acc = accuracy_compute(prediction, answer_type)

# Save results and generations
save_results(name, run_name, prediction, acc)


def inference(dataset, model, tokenizer, batch_size):
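    # Greedy-decode an answer for every prompt and return a DataFrame with
    # the prompt, gold label, and generated output per example.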
columns_process = ["prompt", "gold"]
dataset_process = pd.DataFrame(dataset, columns=dataset.features)[columns_process]
dataset_process = dataset_process.assign(output="Null")
temperature = 1.0

inference_data = json.loads(dataset_process.to_json(orient="records"))
data_loader = DataLoader(inference_data, batch_size=batch_size, shuffle=False)

batch_counter = 0
for batch in tqdm(data_loader, total=len(data_loader), position=0, leave=True):
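        # Wrap each prompt in question/answer chat markers so decoding starts
        # at the answer turn.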
prompts = [
f"<|im_start|>question\n{prompt}<|im_end|>\n<|im_start|>answer\n"
for prompt in batch["prompt"]
]
if batch_counter == 0:
print(prompts[0])

# Process tokenizer
stop_seq = ["###"]
if tokenizer.eos_token is not None:
stop_seq.append(tokenizer.eos_token)
if tokenizer.pad_token is not None:
stop_seq.append(tokenizer.pad_token)
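        # Cap generation length at the token length of the batch's first gold answer.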
max_new_tokens = len(
tokenizer(batch["gold"][0], add_special_tokens=False)["input_ids"]
)

outputs = []
for prompt in prompts:
input_ids = tokenizer.encode(prompt, return_tensors="pt").to("cuda")
output_ids = model.generate(
inputs=input_ids,
max_new_tokens=max_new_tokens,
do_sample=False,
top_p=1.0,
temperature=temperature,
pad_token_id=tokenizer.eos_token_id,
)
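            # Drop the echoed prompt tokens and keep only the newly generated answer.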
output_ids = output_ids[0][len(input_ids[0]) :]
output = tokenizer.decode(output_ids, skip_special_tokens=True)
outputs.append(output)

for prompt, out in zip(batch["prompt"], outputs):
dataset_process.loc[dataset_process["prompt"] == prompt, "output"] = out
batch_counter += 1

return dataset_process


def accuracy_compute(dataset, answer_type):
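    # Parse each generation with format_answer() and score exact-match accuracy
    # against the gold labels.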
dataset = json.loads(dataset.to_json(orient="records"))
preds, golds = [], []
for row in dataset:
answer = row["gold"].lower()
output = row["output"].lower()
pred, gold = format_answer(output, answer, answer_type=answer_type)
preds.append(pred)
golds.append(gold)

    accuracy = accuracy_score(golds, preds)

return accuracy
62 changes: 62 additions & 0 deletions benchmarks/flowertune-llm/evaluation/medical/eval.py
@@ -0,0 +1,62 @@
import argparse

import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

from benchmarks import infer_medmcqa, infer_medqa, infer_pubmedqa

# Fixed seed
torch.manual_seed(2024)

parser = argparse.ArgumentParser()
parser.add_argument(
"--base-model-name-path", type=str, default="mistralai/Mistral-7B-v0.3"
)
parser.add_argument("--run-name", type=str, default="fl")
parser.add_argument("--peft-path", type=str, default=None)
parser.add_argument(
"--datasets",
type=str,
default="pubmedqa",
help="The dataset to infer on: [pubmedqa, medqa, medmcqa]",
)
parser.add_argument("--batch-size", type=int, default=16)
parser.add_argument("--quantization", type=int, default=4)
args = parser.parse_args()


# Load model and tokenizer
if args.quantization == 4:
quantization_config = BitsAndBytesConfig(load_in_4bit=True)
torch_dtype = torch.float32
elif args.quantization == 8:
quantization_config = BitsAndBytesConfig(load_in_8bit=True)
torch_dtype = torch.float16
else:
raise ValueError(
f"Use 4-bit or 8-bit quantization. You passed: {args.quantization}/"
)

model = AutoModelForCausalLM.from_pretrained(
args.base_model_name_path,
quantization_config=quantization_config,
torch_dtype=torch_dtype,
)
if args.peft_path is not None:
model = PeftModel.from_pretrained(
model, args.peft_path, torch_dtype=torch_dtype
).to("cuda")

tokenizer = AutoTokenizer.from_pretrained(args.base_model_name_path)

# Evaluate
for dataset in args.datasets.split(","):
if dataset == "pubmedqa":
infer_pubmedqa(model, tokenizer, args.batch_size, args.run_name)
elif dataset == "medqa":
infer_medqa(model, tokenizer, args.batch_size, args.run_name)
elif dataset == "medmcqa":
infer_medmcqa(model, tokenizer, args.batch_size, args.run_name)
else:
raise ValueError("Undefined Dataset.")
7 changes: 7 additions & 0 deletions benchmarks/flowertune-llm/evaluation/medical/requirements.txt
@@ -0,0 +1,7 @@
peft==0.6.2
pandas==2.2.2
scikit-learn==1.5.0
datasets==2.20.0
sentencepiece==0.2.0
protobuf==5.27.1
bitsandbytes==0.43.1
81 changes: 81 additions & 0 deletions benchmarks/flowertune-llm/evaluation/medical/utils.py
@@ -0,0 +1,81 @@
import os
import re


def format_example(question, choices):
if not question.endswith("?") and not question.endswith("."):
question += "?"
options_str = "\n".join([f"{chr(65+i)}. {choices[i]}" for i in range(len(choices))])
prompt = "Question: " + question + "\n\nOptions:\n" + options_str
return prompt


def save_results(dataset_name, run_name, dataset, acc):
path = "./benchmarks/"
if not os.path.exists(path):
os.makedirs(path)

# Save results
results_path = os.path.join(path, f"acc_{dataset_name}_{run_name}.txt")
with open(results_path, "w") as f:
f.write(f"Accuracy: {acc}. ")
print(f"Accuracy: {acc}. ")

# Save generations
generation_path = os.path.join(path, f"generation_{dataset_name}_{run_name}.jsonl")
dataset.to_json(generation_path, orient="records")


def format_answer(output_full, answer, answer_type="mcq"):
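    # Normalise the raw generation and the gold answer to short comparable forms
    # ("a"-"e" or "yes"/"no"); fall back to the unparsed pair if extraction fails.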
output = output_full
default = (output_full, answer)
if "\n##" in output:
try:
output = output.split("\n##")[1].split("\n")[0].strip().lower()
except Exception:
return default
if "###" in answer:
try:
answer = answer.split("answer is:")[1].split("###")[0].strip()
except Exception:
return default

output = re.sub(r"[^a-zA-Z0-9]", " ", output).strip()
output = re.sub(" +", " ", output)

if answer_type == "boolean":
output = clean_boolean_answer(output)
elif answer_type == "mcq":
output = clean_mcq_answer(output)

if output in ["a", "b", "c", "d", "e", "yes", "no"]:
return output, answer
else:
return default


def clean_mcq_answer(output):
output = clean_answer(output)
try:
output = output[0]
except Exception:
return output
return output


def clean_boolean_answer(output):
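    # The model occasionally emits doubled or conflicting yes/no tokens;
    # keep the first decision before further cleaning.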
if "yesyes" in output:
output = output.replace("yesyes", "yes")
elif "nono" in output:
output = output.replace("nono", "no")
elif "yesno" in output:
output = output.replace("yesno", "yes")
elif "noyes" in output:
output = output.replace("noyes", "no")
output = clean_answer(output)
return output


def clean_answer(output):
output_clean = output.encode("ascii", "ignore").decode("ascii")
return output_clean