From 25d0ad02ad03db5eaeeb54c9e7773b8e360fce9a Mon Sep 17 00:00:00 2001
From: smarterliu
Date: Thu, 17 Oct 2024 14:18:28 +0800
Subject: [PATCH] new metric name

---
 eval_cot.sh                                |  2 +-
 lm_eval/tasks/hotpot_qa/hotpot_qa_cot.yaml | 18 ++++++++++-----
 lm_eval/tasks/hotpot_qa/utils.py           |  2 +-
 lm_eval/tasks/nq_open/nq_open_cot.yaml     | 26 +++++++++++++---------
 lm_eval/tasks/nq_open/utils.py             |  2 +-
 lm_eval/tasks/triviaqa/triviaqa_cot.yaml   | 18 ++++++++++-----
 lm_eval/tasks/triviaqa/utils.py            |  2 +-
 7 files changed, 44 insertions(+), 26 deletions(-)

diff --git a/eval_cot.sh b/eval_cot.sh
index 1fea4bd3..ad4882e1 100644
--- a/eval_cot.sh
+++ b/eval_cot.sh
@@ -16,5 +16,5 @@ accelerate launch -m lm_eval \
     --tasks rag_cot \
     --batch_size 8 \
     --num_fewshot 1 \
-    --output_path "./rag_res/" \
+    --output_path "./cot_res/" \
     --log_samples
diff --git a/lm_eval/tasks/hotpot_qa/hotpot_qa_cot.yaml b/lm_eval/tasks/hotpot_qa/hotpot_qa_cot.yaml
index 45939310..88b244e9 100644
--- a/lm_eval/tasks/hotpot_qa/hotpot_qa_cot.yaml
+++ b/lm_eval/tasks/hotpot_qa/hotpot_qa_cot.yaml
@@ -32,14 +32,20 @@ filter_list:
       - function: remove_whitespace
       - function: take_first
 metric_list:
-  - metric: exact_match
+  # - metric: exact_match
+  #   aggregation: mean
+  #   higher_is_better: true
+  #   ignore_case: true
+  #   ignore_punctuation: true
+  # - metric: !function utils.f1
+  #   aggregation: mean
+  #   higher_is_better: true
+  #   ignore_case: true
+  #   ignore_punctuation: true
+  - metric: exact_match_score
     aggregation: mean
     higher_is_better: true
-    ignore_case: true
-    ignore_punctuation: true
-  - metric: !function utils.f1
+  - metric: f1_score
     aggregation: mean
     higher_is_better: true
-    ignore_case: true
-    ignore_punctuation: true
 
diff --git a/lm_eval/tasks/hotpot_qa/utils.py b/lm_eval/tasks/hotpot_qa/utils.py
index 235f7a6b..7c7918b9 100644
--- a/lm_eval/tasks/hotpot_qa/utils.py
+++ b/lm_eval/tasks/hotpot_qa/utils.py
@@ -92,7 +92,7 @@ def process_results(doc, results):
     precision = 1.0 * num_same / len(completion_toks)
     recall = 1.0 * num_same / len(ans_toks)
     f1_score = (2 * precision * recall) / (precision + recall)
-    return {"exact_match": exact_score, "f1": f1_score}
+    return {"exact_match_score": exact_score, "f1_score": f1_score}
 
 
 def f1(**kwargs):
diff --git a/lm_eval/tasks/nq_open/nq_open_cot.yaml b/lm_eval/tasks/nq_open/nq_open_cot.yaml
index 49dda03a..f1adae62 100644
--- a/lm_eval/tasks/nq_open/nq_open_cot.yaml
+++ b/lm_eval/tasks/nq_open/nq_open_cot.yaml
@@ -27,19 +27,25 @@ filter_list:
       - function: take_first
 target_delimiter: " "
 metric_list:
-  - metric: exact_match
+  # - metric: exact_match
+  #   aggregation: mean
+  #   higher_is_better: true
+  #   ignore_case: true
+  #   ignore_punctuation: true
+  #   regexes_to_ignore:
+  #     - "\\b(?:The |the |An |A |The |a |an )"
+  # - metric: !function utils.f1
+  #   aggregation: mean
+  #   higher_is_better: true
+  #   ignore_case: true
+  #   ignore_punctuation: true
+  #   regexes_to_ignore:
+  #     - "\ban|a|the\b"
+  - metric: exact_match_score
     aggregation: mean
     higher_is_better: true
-    ignore_case: true
-    ignore_punctuation: true
-    regexes_to_ignore:
-      - "\\b(?:The |the |An |A |The |a |an )"
-  - metric: !function utils.f1
+  - metric: f1_score
     aggregation: mean
     higher_is_better: true
-    ignore_case: true
-    ignore_punctuation: true
-    regexes_to_ignore:
-      - "\ban|a|the\b"
 metadata:
   version: 3.0
diff --git a/lm_eval/tasks/nq_open/utils.py b/lm_eval/tasks/nq_open/utils.py
index 1c5817aa..74593139 100644
--- a/lm_eval/tasks/nq_open/utils.py
+++ b/lm_eval/tasks/nq_open/utils.py
@@ -88,7 +88,7 @@ def process_results(doc, results):
         f1_score = (2 * precision * recall) / (precision + recall)
         best_em_score = max(best_em_score, exact_score)
         best_f1_score = max(best_f1_score, f1_score)
-    return {"exact_match": best_em_score, "f1": best_f1_score}
+    return {"exact_match_score": best_em_score, "f1_score": best_f1_score}
 
 def f1(**kwargs):
     references = kwargs["references"]
diff --git a/lm_eval/tasks/triviaqa/triviaqa_cot.yaml b/lm_eval/tasks/triviaqa/triviaqa_cot.yaml
index e603fd52..e1437d69 100644
--- a/lm_eval/tasks/triviaqa/triviaqa_cot.yaml
+++ b/lm_eval/tasks/triviaqa/triviaqa_cot.yaml
@@ -28,15 +28,21 @@ filter_list:
       - function: take_first
 target_delimiter: " "
 metric_list:
-  - metric: exact_match
+  # - metric: exact_match
+  #   aggregation: mean
+  #   higher_is_better: true
+  #   ignore_case: true
+  #   ignore_punctuation: true
+  # - metric: !function utils.f1
+  #   aggregation: mean
+  #   higher_is_better: true
+  #   ignore_case: true
+  #   ignore_punctuation: true
+  - metric: exact_match_score
     aggregation: mean
     higher_is_better: true
-    ignore_case: true
-    ignore_punctuation: true
-  - metric: !function utils.f1
+  - metric: f1_score
     aggregation: mean
     higher_is_better: true
-    ignore_case: true
-    ignore_punctuation: true
 metadata:
   version: 3.0
diff --git a/lm_eval/tasks/triviaqa/utils.py b/lm_eval/tasks/triviaqa/utils.py
index 09e080f2..a5cc0926 100644
--- a/lm_eval/tasks/triviaqa/utils.py
+++ b/lm_eval/tasks/triviaqa/utils.py
@@ -84,7 +84,7 @@ def process_results(doc, results):
         f1_score = (2 * precision * recall) / (precision + recall)
         best_em_score = max(best_em_score, exact_score)
         best_f1_score = max(best_f1_score, f1_score)
-    return {"exact_match": exact_score, "f1": f1_score}
+    return {"exact_match_score": exact_score, "f1_score": f1_score}
 
 def f1(**kwargs):
     references = kwargs["references"]
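
For reference, a minimal standalone sketch of the per-example computation that the renamed exact_match_score / f1_score keys report, following the process_results hunks above (best token-overlap match over all gold answers, F1 = 2PR/(P+R)). The helper names (normalize_toks, em_f1), the lowercase-plus-whitespace tokenization, and the token-level exact-match check are illustrative assumptions, not code taken from utils.py, which this patch only partially shows.

from collections import Counter


def normalize_toks(text):
    # Assumed normalization: lowercase and split on whitespace; the real
    # normalization in utils.py is not visible in this patch.
    return text.lower().split()


def em_f1(completion, answers):
    """Best token-overlap EM/F1 of `completion` against any gold answer."""
    best_em_score, best_f1_score = 0.0, 0.0
    completion_toks = normalize_toks(completion)
    for ans in answers:
        ans_toks = normalize_toks(ans)
        # Exact match here compares normalized token sequences (assumption).
        exact_score = float(completion_toks == ans_toks)
        best_em_score = max(best_em_score, exact_score)
        # Multiset token overlap between prediction and this gold answer.
        num_same = sum((Counter(completion_toks) & Counter(ans_toks)).values())
        if num_same == 0:
            continue  # F1 is 0 for this answer; also avoids division by zero.
        precision = 1.0 * num_same / len(completion_toks)
        recall = 1.0 * num_same / len(ans_toks)
        f1_score = (2 * precision * recall) / (precision + recall)
        best_f1_score = max(best_f1_score, f1_score)
    # Keys match the renamed entries in the metric_list YAML above.
    return {"exact_match_score": best_em_score, "f1_score": best_f1_score}


print(em_f1("The answer is Paris", ["Paris", "Paris, France"]))
# -> {'exact_match_score': 0.0, 'f1_score': 0.4}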