@@ -102,19 +102,21 @@ def main():
     training_args.print_config(data_args, "Data")
     paddle.set_device(training_args.device)
 
+    # Define id2label
+    id2label = {}
+    label2id = {}
+    with open(data_args.label_path, "r", encoding="utf-8") as f:
+        for i, line in enumerate(f):
+            l = line.strip()
+            id2label[i] = l
+            label2id[l] = i
+
     # Define model & tokenizer
     if os.path.isdir(model_args.model_name_or_path):
-        model = AutoModelForSequenceClassification.from_pretrained(model_args.model_name_or_path)
-        id2label = model.id2label
-        label2id = model.label2id
+        model = AutoModelForSequenceClassification.from_pretrained(
+            model_args.model_name_or_path, label2id=label2id, id2label=id2label
+        )
     elif model_args.model_name_or_path in SUPPORTED_MODELS:
-        id2label = {}
-        label2id = {}
-        with open(data_args.label_path, "r", encoding="utf-8") as f:
-            for i, line in enumerate(f):
-                l = line.strip()
-                id2label[i] = l
-                label2id[l] = i
         model = AutoModelForSequenceClassification.from_pretrained(
             model_args.model_name_or_path, num_classes=len(label2id), label2id=label2id, id2label=id2label
         )
@@ -186,7 +188,7 @@ def compute_metrics_debug(eval_preds):
     if training_args.do_eval:
         if data_args.debug:
             output = trainer.predict(test_ds)
-            log_metrics_debug(output, id2label, dev_ds, data_args.bad_case_path)
+            log_metrics_debug(output, id2label, test_ds, data_args.bad_case_path)
         else:
            eval_metrics = trainer.evaluate()
            trainer.log_metrics("eval", eval_metrics)