forked from PaddlePaddle/PaddleNLP
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrun_train.py
138 lines (116 loc) Β· 5.12 KB
/
run_train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
import paddle
from paddle.static import InputSpec
from sklearn.metrics import f1_score
from utils import UTCLoss, read_local_dataset
from paddlenlp.datasets import load_dataset
from paddlenlp.prompt import (
PromptModelForSequenceClassification,
PromptTrainer,
PromptTuningArguments,
UTCTemplate,
)
from paddlenlp.trainer import PdArgumentParser
from paddlenlp.transformers import UTC, AutoTokenizer, export_model
@dataclass
class DataArguments:
    """Dataset-related options, parsed from the command line in ``main``.

    The split file names are combined with ``dataset_path`` when ``main``
    calls ``load_dataset(read_local_dataset, ...)``. The threshold is used
    only when computing evaluation metrics.
    """

    # Directory handed to read_local_dataset as ``data_path``.
    dataset_path: str = field(
        metadata={"help": "Local dataset directory including train.txt, dev.txt and label.txt (optional)."},
        default="./data",
    )
    # File name of the training split inside ``dataset_path``.
    train_file: str = field(
        metadata={"help": "Train dataset file name."},
        default="train.txt",
    )
    # File name of the evaluation split inside ``dataset_path``.
    dev_file: str = field(
        metadata={"help": "Dev dataset file name."},
        default="dev.txt",
    )
    # Cut-off compared against sigmoid scores in compute_metrics; scores
    # strictly greater than this value count as positive predictions.
    threshold: float = field(
        metadata={"help": "The threshold to produce predictions."},
        default=0.5,
    )
@dataclass
class ModelArguments:
    """Model-selection and export options, parsed from the command line in ``main``."""

    # Used for both AutoTokenizer.from_pretrained and UTC.from_pretrained.
    model_name_or_path: str = field(
        default="utc-large", metadata={"help": "The built-in pretrained UTC model name or path to its checkpoints."}
    )
    # Forwarded as the last argument of export_model when do_export is set.
    export_type: str = field(default="paddle", metadata={"help": "The type to export. Supports `paddle` and `onnx`."})
    # Output directory forwarded to export_model.
    export_model_dir: str = field(default="checkpoints/model_best", metadata={"help": "The export model path."})
def _load_split(data_args, data_file):
    """Eagerly load one local dataset split through ``read_local_dataset``.

    Args:
        data_args: Parsed ``DataArguments``; only ``dataset_path`` is read here.
        data_file: File name of the split inside ``data_args.dataset_path``.

    Returns:
        The dataset object produced by ``load_dataset(..., lazy=False)``.
    """
    return load_dataset(
        read_local_dataset,
        data_path=data_args.dataset_path,
        data_file=data_file,
        lazy=False,
    )


def _make_compute_metrics(threshold):
    """Build the ``compute_metrics`` callback handed to ``PromptTrainer``.

    Args:
        threshold: Cut-off applied to sigmoid scores; a score strictly greater
            than it is treated as a positive prediction.

    Returns:
        A function taking an object with ``label_ids`` and ``predictions``
        attributes and returning ``{"micro_f1": ..., "macro_f1": ...}``.
    """

    def compute_metrics(eval_preds):
        labels = paddle.to_tensor(eval_preds.label_ids, dtype="int64")
        scores = paddle.nn.functional.sigmoid(paddle.to_tensor(eval_preds.predictions))
        # Entries labelled -100 are excluded from scoring (presumably padding /
        # ignored option slots -- confirm against UTCTemplate). The mask is
        # computed once and applied to both tensors.
        valid = labels != -100
        preds = scores[valid].numpy() > threshold
        labels = labels[valid].numpy()
        micro_f1 = f1_score(y_pred=preds, y_true=labels, average="micro")
        macro_f1 = f1_score(y_pred=preds, y_true=labels, average="macro")
        return {"micro_f1": micro_f1, "macro_f1": macro_f1}

    return compute_metrics


def _export(trainer, model_args):
    """Export ``trainer.pretrained_model`` with a fixed static input signature.

    Args:
        trainer: The ``PromptTrainer`` whose ``pretrained_model`` is exported.
        model_args: Parsed ``ModelArguments``; supplies the output directory
            and export type.
    """
    # Dynamic (None) dimensions everywhere; attention_mask is declared as a
    # rank-4 float32 tensor while every other input is int64.
    input_spec = [
        InputSpec(shape=[None, None], dtype="int64", name="input_ids"),
        InputSpec(shape=[None, None], dtype="int64", name="token_type_ids"),
        InputSpec(shape=[None, None], dtype="int64", name="position_ids"),
        InputSpec(shape=[None, None, None, None], dtype="float32", name="attention_mask"),
        InputSpec(shape=[None, None], dtype="int64", name="omask_positions"),
        InputSpec(shape=[None], dtype="int64", name="cls_positions"),
    ]
    export_model(trainer.pretrained_model, input_spec, model_args.export_model_dir, model_args.export_type)


def main():
    """Fine-tune a UTC model on local data and optionally export it.

    Parses ``ModelArguments``, ``DataArguments`` and ``PromptTuningArguments``,
    loads the pretrained tokenizer/model and the train/dev splits, and builds a
    ``PromptTrainer``. It then trains when ``training_args.do_train`` is set
    and exports when ``training_args.do_export`` is set.
    """
    # Parse the arguments.
    parser = PdArgumentParser((ModelArguments, DataArguments, PromptTuningArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    training_args.print_config(model_args, "Model")
    training_args.print_config(data_args, "Data")
    paddle.set_device(training_args.device)

    # Load the pretrained language model.
    tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path)
    model = UTC.from_pretrained(model_args.model_name_or_path)

    # Define template for preprocess and verbalizer for postprocess.
    template = UTCTemplate(tokenizer, training_args.max_seq_length)

    # Load and preprocess dataset.
    train_ds = _load_split(data_args, data_args.train_file)
    dev_ds = _load_split(data_args, data_args.dev_file)

    # Define the criterion.
    criterion = UTCLoss()

    # Initialize the prompt model (no verbalizer is used, hence None).
    prompt_model = PromptModelForSequenceClassification(
        model, template, None, freeze_plm=training_args.freeze_plm, freeze_dropout=training_args.freeze_dropout
    )

    trainer = PromptTrainer(
        model=prompt_model,
        tokenizer=tokenizer,
        args=training_args,
        criterion=criterion,
        train_dataset=train_ds,
        eval_dataset=dev_ds,
        callbacks=None,
        compute_metrics=_make_compute_metrics(data_args.threshold),
    )

    # Training.
    if training_args.do_train:
        train_results = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
        metrics = train_results.metrics
        trainer.save_model()
        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)
        trainer.save_state()

    # Export.
    if training_args.do_export:
        _export(trainer, model_args)


if __name__ == "__main__":
    main()