train_classifier.py
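"""Fine-tune climatebert/distilroberta-base-climate-f as a text classifier.

Usage:
    python train_classifier.py path/to/<name>.<ext>

The script expects a pickle at traindata/<name>.pkl containing a 3-tuple
(id2label, label2id, text_data), where text_data maps the columns 'text'
and 'label' to parallel lists. Checkpoints are written to models/<name>/.
"""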
import os
import pickle
import sys

import evaluate
import numpy as np
from datasets import Dataset
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    DataCollatorWithPadding,
    Trainer,
    TrainingArguments,
)

MODEL_NAME = 'climatebert/distilroberta-base-climate-f'

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)


def preprocess_function(examples):
    # Tokenize the raw text; padding is deferred to the data collator.
    return tokenizer(examples['text'], truncation=True)


accuracy = evaluate.load('accuracy')


def compute_metrics(eval_pred):
    # eval_pred is a (logits, labels) pair; take the argmax over the
    # class dimension to get predicted label ids.
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)
    return accuracy.compute(predictions=predictions, references=labels)


def main(metadata_path):
    # Derive the training-data and output paths from the metadata filename.
    metadata_filename = os.path.basename(metadata_path)
    metadata_basename, _ = os.path.splitext(metadata_filename)
    traindata_path = os.path.join('traindata', '{}.pkl'.format(metadata_basename))

    # Pad each batch dynamically to the longest sequence in the batch.
    data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

    with open(traindata_path, 'rb') as f:
        id2label, label2id, text_data = pickle.load(f)

    # Build a Dataset, encode the string labels as class ids, and make a
    # stratified 70/30 train/test split.
    text_dataset = Dataset.from_dict(text_data).class_encode_column('label').train_test_split(
        test_size=0.3,
        stratify_by_column='label',
        shuffle=True,
    )
    tokenized_data = text_dataset.map(preprocess_function, batched=True)

    model = AutoModelForSequenceClassification.from_pretrained(
        MODEL_NAME,
        num_labels=len(id2label),
        id2label=id2label,
        label2id=label2id,
    )

    training_args = TrainingArguments(
        output_dir='models/{}'.format(metadata_basename),
        learning_rate=1e-4,
        per_device_train_batch_size=36,
        per_device_eval_batch_size=36,
        num_train_epochs=10,
        weight_decay=0.01,
        evaluation_strategy='epoch',
        save_strategy='epoch',
        load_best_model_at_end=True,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_data['train'],
        eval_dataset=tokenized_data['test'],
        tokenizer=tokenizer,
        data_collator=data_collator,
        compute_metrics=compute_metrics,
    )
    trainer.train()


if __name__ == '__main__':
    main(sys.argv[1])
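
# A minimal sketch (hypothetical, not part of this repo) of how a compatible
# traindata pickle could be produced; the label names below are placeholders:
#
#     import pickle
#     labels = ['risk', 'opportunity']
#     id2label = dict(enumerate(labels))
#     label2id = {name: i for i, name in id2label.items()}
#     text_data = {
#         'text': ['Flooding threatens coastal assets.',
#                  'Demand for renewables keeps growing.'],
#         'label': ['risk', 'opportunity'],
#     }
#     with open('traindata/example.pkl', 'wb') as f:
#         pickle.dump((id2label, label2id, text_data), f)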