From 32b5f34eb1236851683399a8f0cf2bae30dc3f8e Mon Sep 17 00:00:00 2001 From: Verna Date: Mon, 20 Nov 2023 16:43:12 +0000 Subject: [PATCH] Add validation split, replace checkpoint name with bert-base --- .../latent_feature_splits/usage_example.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/src/genbench/tasks/latent_feature_splits/usage_example.py b/src/genbench/tasks/latent_feature_splits/usage_example.py index 6551f55..b074a50 100644 --- a/src/genbench/tasks/latent_feature_splits/usage_example.py +++ b/src/genbench/tasks/latent_feature_splits/usage_example.py @@ -1,4 +1,4 @@ -from datasets import load_dataset, DatasetDict +from datasets import load_dataset, DatasetDict, Dataset from transformers import AutoTokenizer, DataCollatorWithPadding, \ Trainer, TrainingArguments, AutoModelForSequenceClassification import numpy as np @@ -6,7 +6,7 @@ from genbench import load_task from genbench.api import PreparationStrategy import os - +from sklearn.model_selection import train_test_split def compute_metrics(eval_preds): metric = evaluate.load("f1") @@ -34,10 +34,14 @@ def tokenize_function(example): return tokenizer( example["input"]) - # Convert GenBench format to HF dataset format, preview dataset + # Convert GenBench format to HF dataset format, get devset, preview dataset task = load_task(f"latent_feature_splits:{split_name}") ds = task.get_prepared_datasets(PreparationStrategy.FINETUNING) - ds = DatasetDict(ds) + ds_split = ds["train"].train_test_split(0.1) + ds = DatasetDict({ + "train": ds_split["train"], + "validation": ds_split["test"], + "test": ds["test"]}) ds = ds.rename_column("target", "label") print(ds) @@ -61,7 +65,7 @@ def tokenize_function(example): model, training_args, train_dataset=tokenized_datasets["train"], - eval_dataset=tokenized_datasets["test"], + eval_dataset=tokenized_datasets["validation"], data_collator=data_collator, tokenizer=tokenizer, compute_metrics=compute_metrics, @@ -82,8 +86,8 @@ def tokenize_function(example): split_name = "bert_closest_split" num_labels = 3 batch_size = 16 - lr = 3e-5 + lr = 2e-5 epochs = 5 - checkpoint = "prajjwal1/bert-small" + checkpoint = "bert-base-uncased" main(split_name, num_labels, batch_size, lr, epochs, checkpoint)