diff --git a/src/genbench/tasks/latent_feature_splits/usage_example.py b/src/genbench/tasks/latent_feature_splits/usage_example.py
index 4ab9ed1..6551f55 100644
--- a/src/genbench/tasks/latent_feature_splits/usage_example.py
+++ b/src/genbench/tasks/latent_feature_splits/usage_example.py
@@ -8,11 +8,6 @@ import os



-def tokenize_function(example):
-    return tokenizer(
-        example["input"])
-
-
 def compute_metrics(eval_preds):
     metric = evaluate.load("f1")
     logits, labels = eval_preds
@@ -23,16 +18,22 @@ def compute_metrics(eval_preds):
         average="macro")


-def main(split_name, num_labels, lr, epochs, checkpoint):
+def main(split_name, num_labels, bsz, lr, epochs, checkpoint):
     """
     Basic functionality to load data, train and evaluate the model.
     Args:
         - split_name: str (bert_closest_split | roberta_closest_split)
         - num_labels (int)
+        - bsz (int): batch size
         - lr (float): learning rate
         - epochs (int): number of epochs
         - checkpoint (str): should be a valid HF model name
     """
+
+    def tokenize_function(example):
+        return tokenizer(
+            example["input"])
+
     # Convert GenBench format to HF dataset format, preview dataset
     task = load_task(f"latent_feature_splits:{split_name}")
     ds = task.get_prepared_datasets(PreparationStrategy.FINETUNING)
@@ -42,7 +43,8 @@ def main(split_name, num_labels, lr, epochs, checkpoint):

     # Load and preprocess data
     tokenizer = AutoTokenizer.from_pretrained(checkpoint)
-    tokenized_datasets = ds.map(tokenize_function, batched=True)
+    tokenized_datasets = ds.map(
+        tokenize_function, batch_size=bsz, batched=True)
     data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

     # Load model and HF trainer, WITH evaluation during training
@@ -52,6 +54,8 @@ def main(split_name, num_labels, lr, epochs, checkpoint):
         "test-trainer",
         learning_rate=lr,
         num_train_epochs=epochs,
+        per_device_train_batch_size=bsz,
+        per_device_eval_batch_size=bsz,
         evaluation_strategy="epoch")
     trainer = Trainer(
         model,
@@ -77,8 +81,9 @@ def main(split_name, num_labels, lr, epochs, checkpoint):
     os.environ["WANDB_DISABLED"] = "true"
     split_name = "bert_closest_split"
     num_labels = 3
+    batch_size = 16
     lr = 3e-5
     epochs = 5
     checkpoint = "prajjwal1/bert-small"
-    main(split_name, num_labels, lr, epochs, checkpoint)
+    main(split_name, num_labels, batch_size, lr, epochs, checkpoint)