Skip to content

Commit ddccc68

Browse files
authored
【bug fix】Text classification application&example (PaddlePaddle#5070)
* fix version problem
1 parent 1363448 commit ddccc68

File tree

3 files changed

+15
-13
lines changed

3 files changed

+15
-13
lines changed

Diff for: applications/text_classification/multi_class/README.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -339,7 +339,7 @@ checkpoint/prune/
339339
<a name="模型预测"></a>
340340

341341
### 2.5 模型预测
342-
我们推荐使用taskflow进行模型预测。
342+
我们推荐使用taskflow进行模型预测,请保证paddlenlp版本大于2.5.1
343343
```
344344
from paddlenlp import Taskflow
345345

Diff for: applications/text_classification/multi_class/train.py

+13-11
Original file line numberDiff line numberDiff line change
@@ -102,19 +102,21 @@ def main():
102102
training_args.print_config(data_args, "Data")
103103
paddle.set_device(training_args.device)
104104

105+
# Define id2label
106+
id2label = {}
107+
label2id = {}
108+
with open(data_args.label_path, "r", encoding="utf-8") as f:
109+
for i, line in enumerate(f):
110+
l = line.strip()
111+
id2label[i] = l
112+
label2id[l] = i
113+
105114
# Define model & tokenizer
106115
if os.path.isdir(model_args.model_name_or_path):
107-
model = AutoModelForSequenceClassification.from_pretrained(model_args.model_name_or_path)
108-
id2label = model.id2label
109-
label2id = model.label2id
116+
model = AutoModelForSequenceClassification.from_pretrained(
117+
model_args.model_name_or_path, label2id=label2id, id2label=id2label
118+
)
110119
elif model_args.model_name_or_path in SUPPORTED_MODELS:
111-
id2label = {}
112-
label2id = {}
113-
with open(data_args.label_path, "r", encoding="utf-8") as f:
114-
for i, line in enumerate(f):
115-
l = line.strip()
116-
id2label[i] = l
117-
label2id[l] = i
118120
model = AutoModelForSequenceClassification.from_pretrained(
119121
model_args.model_name_or_path, num_classes=len(label2id), label2id=label2id, id2label=id2label
120122
)
@@ -186,7 +188,7 @@ def compute_metrics_debug(eval_preds):
186188
if training_args.do_eval:
187189
if data_args.debug:
188190
output = trainer.predict(test_ds)
189-
log_metrics_debug(output, id2label, dev_ds, data_args.bad_case_path)
191+
log_metrics_debug(output, id2label, test_ds, data_args.bad_case_path)
190192
else:
191193
eval_metrics = trainer.evaluate()
192194
trainer.log_metrics("eval", eval_metrics)

Diff for: examples/text_classification/README.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -12,4 +12,4 @@
1212

1313
## ERNIE-Doc Text Classification
1414

15-
[ERNIE-Doc Text Classification](./ernie-doc) 展示了如何使用预训练模型ERNIE-Doc完成**超长文本**分类任务。
15+
[ERNIE-Doc Text Classification](./ernie_doc) 展示了如何使用预训练模型ERNIE-Doc完成**超长文本**分类任务。

0 commit comments

Comments
 (0)