
Commit

add cot eval
smarterliu committed Oct 16, 2024
1 parent ab6dff4 commit 7f62e6e
Showing 5 changed files with 54 additions and 0 deletions.
18 changes: 18 additions & 0 deletions lm_eval/tasks/hotpot_qa/utils.py
@@ -10,6 +10,7 @@
import random

f1_gen = evaluate.load("./metrics/f1")
exact_match = evaluate.load("./metrics/exact_match")

def normalize_answer(s):
"""Lower text and remove punctuation, articles and extra whitespace."""
@@ -70,6 +71,23 @@ def _extract_facts(context):
"answer": examples["answer"].strip(),
"facts": facts.strip(),
}

def process_results(doc, results):
    completion = results[0]
    ans = doc["answer"]
    # metrics loaded via evaluate.load are invoked with .compute(); assuming the
    # local ./metrics/exact_match script mirrors HF exact_match's output dict
    exact_score = exact_match.compute(references=[ans], predictions=[completion])["exact_match"]
    ans_toks = get_tokens(ans)
    completion_toks = get_tokens(completion)
    common = collections.Counter(ans_toks) & collections.Counter(completion_toks)
    num_same = sum(common.values())
    if len(ans_toks) == 0 or len(completion_toks) == 0 or num_same == 0:
        # guard num_same == 0 as well, else precision + recall == 0 divides by zero
        f1_score = 0
    else:
        precision = 1.0 * num_same / len(completion_toks)
        recall = 1.0 * num_same / len(ans_toks)
        f1_score = (2 * precision * recall) / (precision + recall)
    return {"exact_match": exact_score, "f1": f1_score}


def f1(**kwargs):
    references = kwargs["references"]
17 changes: 17 additions & 0 deletions lm_eval/tasks/nq_open/utils.py
@@ -9,6 +9,7 @@
from rouge_score import rouge_scorer, scoring

f1_gen = evaluate.load("./metrics/f1")
exact_match = evaluate.load("./metrics/exact_match")

def normalize_answer(s):
"""Lower text and remove punctuation, articles and extra whitespace."""
@@ -61,6 +62,22 @@ def _extract_facts(docs):
"answer": examples["answer"],
"facts": facts.strip(),
}

def process_results(doc, results):
    completion = results[0]
    ans = doc["answer"]
    # metrics loaded via evaluate.load are invoked with .compute(); assuming the
    # local ./metrics/exact_match script mirrors HF exact_match's output dict
    exact_score = exact_match.compute(references=[ans], predictions=[completion])["exact_match"]
    ans_toks = get_tokens(ans)
    completion_toks = get_tokens(completion)
    common = collections.Counter(ans_toks) & collections.Counter(completion_toks)
    num_same = sum(common.values())
    if len(ans_toks) == 0 or len(completion_toks) == 0 or num_same == 0:
        # guard num_same == 0 as well, else precision + recall == 0 divides by zero
        f1_score = 0
    else:
        precision = 1.0 * num_same / len(completion_toks)
        recall = 1.0 * num_same / len(ans_toks)
        f1_score = (2 * precision * recall) / (precision + recall)
    return {"exact_match": exact_score, "f1": f1_score}

def f1(**kwargs):
    references = kwargs["references"]
17 changes: 17 additions & 0 deletions lm_eval/tasks/triviaqa/utils.py
@@ -9,6 +9,7 @@
from rouge_score import rouge_scorer, scoring

f1_gen = evaluate.load("./metrics/f1")
exact_match = evaluate.load("./metrics/exact_match")

def normalize_answer(s):
"""Lower text and remove punctuation, articles and extra whitespace."""
@@ -57,6 +58,22 @@ def _extract_facts(docs):
"answer": examples["answer"],
"facts": facts.strip(),
}

def process_results(doc, results):
    completion = results[0]
    ans = doc["answer"]
    # metrics loaded via evaluate.load are invoked with .compute(); assuming the
    # local ./metrics/exact_match script mirrors HF exact_match's output dict
    exact_score = exact_match.compute(references=[ans], predictions=[completion])["exact_match"]
    ans_toks = get_tokens(ans)
    completion_toks = get_tokens(completion)
    common = collections.Counter(ans_toks) & collections.Counter(completion_toks)
    num_same = sum(common.values())
    if len(ans_toks) == 0 or len(completion_toks) == 0 or num_same == 0:
        # guard num_same == 0 as well, else precision + recall == 0 divides by zero
        f1_score = 0
    else:
        precision = 1.0 * num_same / len(completion_toks)
        recall = 1.0 * num_same / len(ans_toks)
        f1_score = (2 * precision * recall) / (precision + recall)
    return {"exact_match": exact_score, "f1": f1_score}

def f1(**kwargs):
    references = kwargs["references"]
1 change: 1 addition & 0 deletions lm_eval/tasks/truthfulqa/truthfulqa_gen_search.yaml
@@ -1,6 +1,7 @@
group:
- truthfulqa
- rag_search_test
- rag_cot
task: truthfulqa_gen_with_search_results
dataset_path: truthful_qa
dataset_name: generation
1 change: 1 addition & 0 deletions lm_eval/tasks/truthfulqa/truthfulqa_mc1_search.yaml
@@ -1,6 +1,7 @@
group:
- truthfulqa
- rag_search_test
- rag_cot
task: truthfulqa_mc1_with_search_results
dataset_path: truthful_qa
dataset_name: multiple_choice
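
The rag_cot group tag added to both YAML files makes these tasks selectable as a group. A minimal sketch of invoking it through the lm-evaluation-harness Python API; the model name and args are placeholders, and simple_evaluate's signature is assumed from the upstream harness:

import lm_eval

# "rag_cot" comes from the group: lists edited above; any task carrying the
# tag is included in the run
results = lm_eval.simple_evaluate(
    model="hf",                                      # HuggingFace backend
    model_args="pretrained=EleutherAI/pythia-160m",  # placeholder model
    tasks=["rag_cot"],
)
print(results["results"])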
