Skip to content
This repository has been archived by the owner on Jul 23, 2024. It is now read-only.

Commit

Permalink
Add validation split, replace checkpoint name with bert-base
Browse files Browse the repository at this point in the history
  • Loading branch information
vernadankers committed Nov 20, 2023
1 parent ad31c99 commit 32b5f34
Showing 1 changed file with 11 additions and 7 deletions.
18 changes: 11 additions & 7 deletions src/genbench/tasks/latent_feature_splits/usage_example.py
Original file line number | Diff line number | Diff line change
@@ -1,12 +1,12 @@
from datasets import load_dataset, DatasetDict
from datasets import load_dataset, DatasetDict, Dataset
from transformers import AutoTokenizer, DataCollatorWithPadding, \
Trainer, TrainingArguments, AutoModelForSequenceClassification
import numpy as np
import evaluate
from genbench import load_task
from genbench.api import PreparationStrategy
import os

from sklearn.model_selection import train_test_split

def compute_metrics(eval_preds):
metric = evaluate.load("f1")
@@ -34,10 +34,14 @@ def tokenize_function(example):
return tokenizer(
example["input"])

# Convert GenBench format to HF dataset format, preview dataset
# Convert GenBench format to HF dataset format, get devset, preview dataset
task = load_task(f"latent_feature_splits:{split_name}")
ds = task.get_prepared_datasets(PreparationStrategy.FINETUNING)
ds = DatasetDict(ds)
ds_split = ds["train"].train_test_split(0.1)
ds = DatasetDict({
"train": ds_split["train"],
"validation": ds_split["test"],
"test": ds["test"]})
ds = ds.rename_column("target", "label")
print(ds)

@@ -61,7 +65,7 @@ def tokenize_function(example):
model,
training_args,
train_dataset=tokenized_datasets["train"],
eval_dataset=tokenized_datasets["test"],
eval_dataset=tokenized_datasets["validation"],
data_collator=data_collator,
tokenizer=tokenizer,
compute_metrics=compute_metrics,
@@ -82,8 +86,8 @@ def tokenize_function(example):
split_name = "bert_closest_split"
num_labels = 3
batch_size = 16
lr = 3e-5
lr = 2e-5
epochs = 5
checkpoint = "prajjwal1/bert-small"
checkpoint = "bert-base-uncased"

main(split_name, num_labels, batch_size, lr, epochs, checkpoint)

0 comments on commit 32b5f34

Please sign in to comment.