diff --git a/named-entity-recognition/run_ner.py b/named-entity-recognition/run_ner.py index c0a4232..5b2d809 100644 --- a/named-entity-recognition/run_ner.py +++ b/named-entity-recognition/run_ner.py @@ -246,7 +246,7 @@ def compute_metrics(p: EvalPrediction) -> Dict: trainer.save_model() # For convenience, we also re-save the tokenizer to the same directory, # so that you can share your model easily on huggingface.co/models =) - if trainer.is_world_master(): + if trainer.is_world_process_zero(): tokenizer.save_pretrained(training_args.output_dir) # Evaluation @@ -257,7 +257,7 @@ def compute_metrics(p: EvalPrediction) -> Dict: result = trainer.evaluate() output_eval_file = os.path.join(training_args.output_dir, "eval_results.txt") - if trainer.is_world_master(): + if trainer.is_world_process_zero(): with open(output_eval_file, "w") as writer: logger.info("***** Eval results *****") for key, value in result.items(): @@ -284,7 +284,7 @@ def compute_metrics(p: EvalPrediction) -> Dict: # Save predictions output_test_results_file = os.path.join(training_args.output_dir, "test_results.txt") - if trainer.is_world_master(): + if trainer.is_world_process_zero(): with open(output_test_results_file, "w") as writer: logger.info("***** Test results *****") for key, value in metrics.items(): @@ -293,7 +293,7 @@ def compute_metrics(p: EvalPrediction) -> Dict: output_test_predictions_file = os.path.join(training_args.output_dir, "test_predictions.txt") - if trainer.is_world_master(): + if trainer.is_world_process_zero(): with open(output_test_predictions_file, "w") as writer: with open(os.path.join(data_args.data_dir, "test.txt"), "r") as f: example_id = 0 @@ -304,10 +304,11 @@ def compute_metrics(p: EvalPrediction) -> Dict: example_id += 1 elif preds_list[example_id]: entity_label = preds_list[example_id].pop(0) - if entity_label == 'O': - output_line = line.split()[0] + " " + entity_label + "\n" - else: - output_line = line.split()[0] + " " + entity_label[0] + "\n" + output_line = line.split()[0] + " " + entity_label + "\n" + #if entity_label == 'O': + # output_line = line.split()[0] + " " + entity_label + "\n" + #else: + # output_line = line.split()[0] + " " + entity_label[0] + "\n" # output_line = line.split()[0] + " " + preds_list[example_id].pop(0) + "\n" writer.write(output_line) else: