This repository has been archived by the owner on Jul 23, 2024. It is now read-only.

Fix style
anmoisio committed Nov 20, 2023
1 parent a6b6019 commit 94f8119
Showing 1 changed file with 30 additions and 25 deletions.
55 changes: 30 additions & 25 deletions src/genbench/tasks/europarl_dbca_splits/usage_example.py
@@ -5,23 +5,31 @@
https://huggingface.co/learn/nlp-course/chapter7/4?fw=pt
"""
import argparse
+
+from datasets import DatasetDict
+from transformers import (
+    DataCollatorForSeq2Seq,
+    FSMTConfig,
+    FSMTForConditionalGeneration,
+    FSMTTokenizer,
+    Seq2SeqTrainer,
+    Seq2SeqTrainingArguments,
+    pipeline,
+)
+
from genbench import load_task
from genbench.api import PreparationStrategy
-from datasets import DatasetDict
-from transformers import FSMTConfig, FSMTTokenizer, FSMTForConditionalGeneration, pipeline
-from transformers import DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer


def tokenize_corpus(dataset, save_to_file):
    """
    Tokenizes the dataset and saves it to disk.
    """
+
    def preprocess_function(examples):
        inputs = examples["input"]
        targets = examples["target"]
-        model_inputs = tokenizer(
-            inputs, text_target=targets, max_length=MAX_LENGTH, truncation=True
-        )
+        model_inputs = tokenizer(inputs, text_target=targets, max_length=MAX_LENGTH, truncation=True)
        return model_inputs

    dataset = DatasetDict(dataset)
@@ -58,11 +66,7 @@ def train_from_scratch(tokenized_corpus, output_dir_name):
        attention_dropout=0.1,
        bos_token_id=0,
        d_model=512,
-        decoder={
-            "bos_token_id": 2,
-            "model_type": "fsmt_decoder",
-            "vocab_size": 42024
-        },
+        decoder={"bos_token_id": 2, "model_type": "fsmt_decoder", "vocab_size": 42024},
        decoder_attention_heads=8,
        decoder_ffn_dim=2048,
        decoder_layerdrop=0,
@@ -130,10 +134,10 @@ def train_from_scratch(tokenized_corpus, output_dir_name):
args = argparser.parse_args()

# Load the task
-task = load_task('europarl_dbca_splits')
+task = load_task("europarl_dbca_splits")

# A pretrained multilingual tokenizer, used for both models and both languages
-tokenizer = FSMTTokenizer.from_pretrained('stas/tiny-wmt19-en-de')
+tokenizer = FSMTTokenizer.from_pretrained("stas/tiny-wmt19-en-de")

MAX_LENGTH = 128
BATCH_SIZE = 128
@@ -151,7 +155,7 @@ def train_from_scratch(tokenized_corpus, output_dir_name):

    subtask_dataset = subtask.get_prepared_datasets(PreparationStrategy.FINETUNING)

-    tokenized_dataset_dir = f'ds_de_comdiv{comdiv}_tokenized'
+    tokenized_dataset_dir = f"ds_de_comdiv{comdiv}_tokenized"
    if args.tokenize:
        tokenized_datasets = tokenize_corpus(subtask_dataset, tokenized_dataset_dir)
    else:
@@ -162,27 +166,28 @@ def train_from_scratch(tokenized_corpus, output_dir_name):
    tokenized_datasets["train"] = train_val_split["train"]
    tokenized_datasets["validation"] = train_val_split["test"]

-    nmt_model_dir = f'FSMT_en-de_comdiv{comdiv}'
+    nmt_model_dir = f"FSMT_en-de_comdiv{comdiv}"
    if args.train:
        train_from_scratch(tokenized_datasets, nmt_model_dir)

    if args.eval:
-        cp = 'checkpoint-100000'
+        cp = "checkpoint-100000"
        print(f"Results for comdiv{comdiv}, checkpoint {cp}")
-        preds = translate_sentences(nmt_model_dir + '/' + cp,
-                                    tokenized_datasets["test"]["input"])
+        preds = translate_sentences(nmt_model_dir + "/" + cp, tokenized_datasets["test"]["input"])

        # re-map the keys to match the evaluation script
-        preds = [{'target': pred['translation_text']} for pred in preds]
+        preds = [{"target": pred["translation_text"]} for pred in preds]

        score = subtask.evaluate_predictions(
-                predictions=preds,
-                gold=tokenized_datasets["test"],
-                )
+            predictions=preds,
+            gold=tokenized_datasets["test"],
+        )
        print(score)
        results.append(score)

if args.eval:
-    print('Generalisation score (maximum compound divergence score divided by ' \
-          + 'minimum compound divergence score):')
-    print(results[1]['hf_chrf__score'] / results[0]['hf_chrf__score'])
+    print(
+        "Generalisation score (maximum compound divergence score divided by "
+        + "minimum compound divergence score):"
+    )
+    print(results[1]["hf_chrf__score"] / results[0]["hf_chrf__score"])
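For reference, a minimal sketch of running the reformatted script end to end. The `--tokenize`, `--train`, and `--eval` flag spellings are assumptions inferred from the `args.tokenize`, `args.train`, and `args.eval` attributes visible in the diff; the actual argparse definitions sit in an unexpanded hunk of the file.

# Hypothetical invocation (flag names inferred from the args.* attributes above, not confirmed by the diff)
python src/genbench/tasks/europarl_dbca_splits/usage_example.py --tokenize --train --eval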
