Commit 25d0ad0

new metric name

smarterliu committed Oct 17, 2024
1 parent 7e19048 commit 25d0ad0

Showing 7 changed files with 44 additions and 26 deletions.
2 changes: 1 addition & 1 deletion eval_cot.sh
@@ -16,5 +16,5 @@ accelerate launch -m lm_eval \
--tasks rag_cot \
--batch_size 8 \
--num_fewshot 1 \
--output_path "./rag_res/" \
--output_path "./cot_res/" \
--log_samples
18 changes: 12 additions & 6 deletions lm_eval/tasks/hotpot_qa/hotpot_qa_cot.yaml
@@ -32,14 +32,20 @@ filter_list:
- function: remove_whitespace
- function: take_first
metric_list:
- metric: exact_match
# - metric: exact_match
# aggregation: mean
# higher_is_better: true
# ignore_case: true
# ignore_punctuation: true
# - metric: !function utils.f1
# aggregation: mean
# higher_is_better: true
# ignore_case: true
# ignore_punctuation: true
- metric: exact_match_score
aggregation: mean
higher_is_better: true
ignore_case: true
ignore_punctuation: true
- metric: !function utils.f1
- metric: f1_score
aggregation: mean
higher_is_better: true
ignore_case: true
ignore_punctuation: true

2 changes: 1 addition & 1 deletion lm_eval/tasks/hotpot_qa/utils.py
@@ -92,7 +92,7 @@ def process_results(doc, results):
precision = 1.0 * num_same / len(completion_toks)
recall = 1.0 * num_same / len(ans_toks)
f1_score = (2 * precision * recall) / (precision + recall)
return {"exact_match": exact_score, "f1": f1_score}
return {"exact_match_score": exact_score, "f1_score": f1_score}


def f1(**kwargs):
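For context on the rename above: in lm-evaluation-harness, the metric names listed under metric_list are expected to match the keys of the dictionary returned by a task's process_results, which is why the YAML entries (exact_match_score, f1_score) and the return statement change together in this commit. A minimal self-contained sketch of that pairing, assuming SQuAD-style token overlap; the normalize helper and the doc["answer"] field access are illustrative assumptions, not the repository's code:

# Illustrative sketch of how the renamed metric keys line up with process_results.
# normalize() and doc["answer"] are assumptions, not the repository's code.
import string


def normalize(text):
    # Lowercase, drop punctuation, and collapse whitespace before comparing.
    text = text.lower()
    text = "".join(ch for ch in text if ch not in string.punctuation)
    return " ".join(text.split())


def process_results(doc, results):
    completion = normalize(results[0])
    answer = normalize(doc["answer"])

    exact_score = float(completion == answer)

    completion_toks = completion.split()
    ans_toks = answer.split()
    common = set(completion_toks) & set(ans_toks)
    num_same = sum(min(completion_toks.count(t), ans_toks.count(t)) for t in common)
    if num_same == 0:
        f1_score = 0.0
    else:
        precision = 1.0 * num_same / len(completion_toks)
        recall = 1.0 * num_same / len(ans_toks)
        f1_score = (2 * precision * recall) / (precision + recall)

    # These keys must match the `metric:` entries in hotpot_qa_cot.yaml
    # (exact_match_score and f1_score after this commit).
    return {"exact_match_score": exact_score, "f1_score": f1_score}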
26 changes: 16 additions & 10 deletions lm_eval/tasks/nq_open/nq_open_cot.yaml
@@ -27,19 +27,25 @@ filter_list:
- function: take_first
target_delimiter: " "
metric_list:
- metric: exact_match
# - metric: exact_match
# aggregation: mean
# higher_is_better: true
# ignore_case: true
# ignore_punctuation: true
# regexes_to_ignore:
# - "\\b(?:The |the |An |A |The |a |an )"
# - metric: !function utils.f1
# aggregation: mean
# higher_is_better: true
# ignore_case: true
# ignore_punctuation: true
# regexes_to_ignore:
# - "\ban|a|the\b"
- metric: exact_match_score
aggregation: mean
higher_is_better: true
ignore_case: true
ignore_punctuation: true
regexes_to_ignore:
- "\\b(?:The |the |An |A |The |a |an )"
- metric: !function utils.f1
- metric: f1_score
aggregation: mean
higher_is_better: true
ignore_case: true
ignore_punctuation: true
regexes_to_ignore:
- "\ban|a|the\b"
metadata:
version: 3.0
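The regexes_to_ignore entries above are meant to drop leading articles ("a", "an", "the") before string comparison. A small sketch of that kind of article-stripping normalization, offered only to illustrate the option's intent; the exact pattern and the order in which the harness applies it may differ:

# Sketch of article stripping in the spirit of regexes_to_ignore above;
# not the harness's own implementation.
import re

ARTICLE_PATTERN = re.compile(r"\b(?:The |the |An |A |a |an )")


def strip_articles(text):
    # Remove leading articles anywhere in the string, then tidy whitespace.
    return " ".join(ARTICLE_PATTERN.sub("", text).split())


print(strip_articles("The Eiffel Tower"))  # -> "Eiffel Tower"
print(strip_articles("an apple a day"))    # -> "apple day"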
2 changes: 1 addition & 1 deletion lm_eval/tasks/nq_open/utils.py
@@ -88,7 +88,7 @@ def process_results(doc, results):
f1_score = (2 * precision * recall) / (precision + recall)
best_em_score = max(best_em_score, exact_score)
best_f1_score = max(best_f1_score, f1_score)
return {"exact_match": best_em_score, "f1": best_f1_score}
return {"exact_match_score": best_em_score, "f1_score": best_f1_score}

def f1(**kwargs):
references = kwargs["references"]
18 changes: 12 additions & 6 deletions lm_eval/tasks/triviaqa/triviaqa_cot.yaml
@@ -28,15 +28,21 @@ filter_list:
- function: take_first
target_delimiter: " "
metric_list:
- metric: exact_match
# - metric: exact_match
# aggregation: mean
# higher_is_better: true
# ignore_case: true
# ignore_punctuation: true
# - metric: !function utils.f1
# aggregation: mean
# higher_is_better: true
# ignore_case: true
# ignore_punctuation: true
- metric: exact_match_score
aggregation: mean
higher_is_better: true
ignore_case: true
ignore_punctuation: true
- metric: !function utils.f1
- metric: f1_score
aggregation: mean
higher_is_better: true
ignore_case: true
ignore_punctuation: true
metadata:
version: 3.0
2 changes: 1 addition & 1 deletion lm_eval/tasks/triviaqa/utils.py
@@ -84,7 +84,7 @@ def process_results(doc, results):
f1_score = (2 * precision * recall) / (precision + recall)
best_em_score = max(best_em_score, exact_score)
best_f1_score = max(best_f1_score, f1_score)
return {"exact_match": exact_score, "f1": f1_score}
return {"exact_match_score": exact_score, "f1_score": f1_score}

def f1(**kwargs):
references = kwargs["references"]
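NQ-Open and TriviaQA list several acceptable gold aliases per question, which is why the hunks in nq_open/utils.py and triviaqa/utils.py keep a running maximum of EM and F1 over the references. A compact, self-contained sketch of that best-over-references pattern with the renamed keys; score_pair and the doc["answers"] field are illustrative stand-ins for the repository's own normalization and schema:

# Sketch of best-over-references scoring in the spirit of the hunks above;
# score_pair() and doc["answers"] are assumptions, not the repository's code.
def score_pair(completion, reference):
    # Per-reference EM and token-level F1, minus the repository's normalization.
    completion_toks = completion.lower().split()
    ans_toks = reference.lower().split()
    exact_score = float(completion_toks == ans_toks)
    common = set(completion_toks) & set(ans_toks)
    num_same = sum(min(completion_toks.count(t), ans_toks.count(t)) for t in common)
    if num_same == 0:
        return exact_score, 0.0
    precision = 1.0 * num_same / len(completion_toks)
    recall = 1.0 * num_same / len(ans_toks)
    return exact_score, (2 * precision * recall) / (precision + recall)


def process_results(doc, results):
    completion = results[0]
    best_em_score, best_f1_score = 0.0, 0.0
    for reference in doc["answers"]:
        exact_score, f1_score = score_pair(completion, reference)
        # Keep the best score achieved against any acceptable alias.
        best_em_score = max(best_em_score, exact_score)
        best_f1_score = max(best_f1_score, f1_score)
    return {"exact_match_score": best_em_score, "f1_score": best_f1_score}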
