
Commit
Add batch size to usage_example and move tokenize_function into main
vernadankers committed Nov 19, 2023
1 parent 86e123d commit ad31c99
Showing 1 changed file with 13 additions and 8 deletions.
21 changes: 13 additions & 8 deletions src/genbench/tasks/latent_feature_splits/usage_example.py
@@ -8,11 +8,6 @@
 import os
 
 
-def tokenize_function(example):
-    return tokenizer(
-        example["input"])
-
-
 def compute_metrics(eval_preds):
     metric = evaluate.load("f1")
     logits, labels = eval_preds
@@ -23,16 +18,22 @@ def compute_metrics(eval_preds):
         average="macro")
 
 
-def main(split_name, num_labels, lr, epochs, checkpoint):
+def main(split_name, num_labels, bsz, lr, epochs, checkpoint):
     """
     Basic functionality to load data, train and evaluate the model.
     Args:
         - split_name: str (bert_closest_split | roberta_closest_split)
         - num_labels (int)
+        - bsz (int): batch size
         - lr (float): learning rate
         - epochs (int): number of epochs
         - checkpoint (str): should be a valid HF model name
     """
+
+    def tokenize_function(example):
+        return tokenizer(
+            example["input"])
+
     # Convert GenBench format to HF dataset format, preview dataset
     task = load_task(f"latent_feature_splits:{split_name}")
     ds = task.get_prepared_datasets(PreparationStrategy.FINETUNING)
@@ -42,7 +43,8 @@ def main(split_name, num_labels, lr, epochs, checkpoint):
 
     # Load and preprocess data
     tokenizer = AutoTokenizer.from_pretrained(checkpoint)
-    tokenized_datasets = ds.map(tokenize_function, batched=True)
+    tokenized_datasets = ds.map(
+        tokenize_function, batch_size=bsz, batched=True)
     data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
 
     # Load model and HF trainer, WITH evaluation during training
@@ -52,6 +54,8 @@ def main(split_name, num_labels, lr, epochs, checkpoint):
         "test-trainer",
         learning_rate=lr,
         num_train_epochs=epochs,
+        per_device_train_batch_size=bsz,
+        per_device_eval_batch_size=bsz,
         evaluation_strategy="epoch")
     trainer = Trainer(
         model,
@@ -77,8 +81,9 @@ def main(split_name, num_labels, lr, epochs, checkpoint):
     os.environ["WANDB_DISABLED"] = "true"
     split_name = "bert_closest_split"
     num_labels = 3
+    batch_size = 16
     lr = 3e-5
     epochs = 5
     checkpoint = "prajjwal1/bert-small"
 
-    main(split_name, num_labels, lr, epochs, checkpoint)
+    main(split_name, num_labels, batch_size, lr, epochs, checkpoint)
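Two details of this diff are worth spelling out. Moving tokenize_function inside main is more than style: at module level it referenced tokenizer, a name that only exists as a local inside main, so calling it would raise NameError; nested inside main it closes over that local. And the batch_size passed to ds.map only controls how many rows each preprocessing call receives (Hugging Face datasets defaults this to 1000), independently of the training batch set via per_device_train_batch_size. Below is a minimal sketch, not part of the commit, illustrating both points; the toy Dataset contents are hypothetical, while the checkpoint name is the one used above.

    # Sketch (assumptions noted): closure fix plus batched mapping.
    from datasets import Dataset
    from transformers import AutoTokenizer

    def main(checkpoint="prajjwal1/bert-small", bsz=2):
        tokenizer = AutoTokenizer.from_pretrained(checkpoint)

        # Defined inside main, this closes over the `tokenizer` local above.
        # The same body at module level would raise NameError, since no
        # global `tokenizer` exists.
        def tokenize_function(example):
            return tokenizer(example["input"])

        # Hypothetical toy data. With batched=True, `example` is a dict of
        # lists holding up to batch_size rows, so map makes ceil(4 / 2) = 2
        # calls here; this knob does not affect the Trainer's batch size.
        ds = Dataset.from_dict({"input": ["a", "b", "c", "d"]})
        tokenized = ds.map(tokenize_function, batch_size=bsz, batched=True)
        print(tokenized)

    if __name__ == "__main__":
        main()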
