[Feat] add audiobench version of clothoaqa (#302)
* add clothoaqa task

* formatting

* minor fixes

* minor fixes
pbcong authored Oct 7, 2024
1 parent 17216e1 commit 454f6ef
Showing 12 changed files with 178 additions and 75 deletions.
17 changes: 2 additions & 15 deletions lmms_eval/tasks/clotho_aqa/_default_template_yaml
@@ -4,20 +4,7 @@ dataset_kwargs:
doc_to_target: "answer"
doc_to_visual: !function utils.clotho_aqa_doc_to_audio
doc_to_text: !function utils.clotho_aqa_doc_to_text
generation_kwargs:
  max_new_tokens: 16
  temperature: 0
  do_sample: False
lmms_eval_specific_kwargs:
  default:
    pre_prompt: ""
    post_prompt: "\nAnswer the question using a single word."
metric_list:
  - metric: exact_match
    aggregation: !function utils.clotho_aqa_aggregate_results
    higher_is_better: true
    ignore_case: true
    ignore_punctuation: true
process_results: !function utils.clotho_aqa_process_results

metadata:
  gpt_eval_model_name: gpt-4o
  version: 0.0
4 changes: 4 additions & 0 deletions lmms_eval/tasks/clotho_aqa/clotho_aqa.yaml
@@ -0,0 +1,4 @@
group: clotho_aqa
tasks:
- clotho_aqa_val
- clotho_aqa_test
19 changes: 19 additions & 0 deletions lmms_eval/tasks/clotho_aqa/clotho_aqa_test.yaml
@@ -0,0 +1,19 @@
task: "clotho_aqa_test"
dataset_name: "clotho_aqa"
test_split: clotho_aqa_test_filtered
generation_kwargs:
  max_new_tokens: 8
  temperature: 0
  do_sample: False
lmms_eval_specific_kwargs:
  default:
    pre_prompt: ""
    post_prompt: "\nAnswer the question using a single word only. "
metric_list:
  - metric: exact_match
    aggregation: mean
    higher_is_better: true
    ignore_case: true
    ignore_punctuation: true

include: _default_template_yaml
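For context, the prompt these kwargs produce is assembled by clotho_aqa_doc_to_text in utils.py, which simply wraps the question with pre_prompt and post_prompt. A rough sketch with a made-up question:

# Illustrative only; the question string is hypothetical, the prompts come from the YAML above.
pre_prompt = ""
post_prompt = "\nAnswer the question using a single word only. "
question = "Is it raining?"
prompt = f"{pre_prompt}{question}{post_prompt}"
# -> 'Is it raining?\nAnswer the question using a single word only. '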
4 changes: 0 additions & 4 deletions lmms_eval/tasks/clotho_aqa/clotho_aqa_test_all.yaml

This file was deleted.

4 changes: 0 additions & 4 deletions lmms_eval/tasks/clotho_aqa/clotho_aqa_test_majority.yaml

This file was deleted.

4 changes: 0 additions & 4 deletions lmms_eval/tasks/clotho_aqa/clotho_aqa_test_unanimous.yaml

This file was deleted.

19 changes: 19 additions & 0 deletions lmms_eval/tasks/clotho_aqa/clotho_aqa_val.yaml
@@ -0,0 +1,19 @@
task: "clotho_aqa_val"
dataset_name: "clotho_aqa"
test_split: clotho_aqa_val_filtered
generation_kwargs:
  max_new_tokens: 8
  temperature: 0
  do_sample: False
lmms_eval_specific_kwargs:
  default:
    pre_prompt: ""
    post_prompt: "\nAnswer the question using a single word only. "
metric_list:
  - metric: exact_match
    aggregation: mean
    higher_is_better: true
    ignore_case: true
    ignore_punctuation: true

include: _default_template_yaml
4 changes: 0 additions & 4 deletions lmms_eval/tasks/clotho_aqa/clotho_aqa_val_all.yaml

This file was deleted.

4 changes: 0 additions & 4 deletions lmms_eval/tasks/clotho_aqa/clotho_aqa_val_majority.yaml

This file was deleted.

4 changes: 0 additions & 4 deletions lmms_eval/tasks/clotho_aqa/clotho_aqa_val_unanimous.yaml

This file was deleted.

19 changes: 19 additions & 0 deletions lmms_eval/tasks/clotho_aqa/clotho_asqa_test_v2.yaml
@@ -0,0 +1,19 @@
task: "clotho_asqa_test_v2"
dataset_name: "clotho_asqa_test_v2"
test_split: test
generation_kwargs:
  max_new_tokens: 256
  temperature: 0
  do_sample: False
lmms_eval_specific_kwargs:
  default:
    pre_prompt: ""
    post_prompt: ""
metric_list:
  - metric: gpt_eval
    aggregation: !function utils.clotho_aqa_v2_aggregate_results
    higher_is_better: true

process_results: !function utils.clotho_aqa_v2_process_results

include: _default_template_yaml
151 changes: 115 additions & 36 deletions lmms_eval/tasks/clotho_aqa/utils.py
@@ -1,9 +1,17 @@
import datetime
import json
import os
import re
import sys
import time
from pathlib import Path

import requests
import yaml
from loguru import logger as eval_logger

from lmms_eval.tasks._task_utils.file_utils import generate_submission_file
import lmms_eval.tasks._task_utils.file_utils as file_utils
from lmms_eval.filters.extraction import ExtendedRegexFilter


def clotho_aqa_doc_to_audio(doc):
@@ -17,40 +25,111 @@ def clotho_aqa_doc_to_text(doc, lmms_eval_specific_kwargs):
    return f"{pre_prompt}{question}{post_prompt}"


def parse_pred_ans(pred_ans):
    """Brought from Otter Eval"""
    pred_ans = pred_ans.lower().strip().replace(".", "")
    pred_label = None
    if len(pred_ans) == 1:
        if pred_ans == "y":
            pred_label = "yes"
        elif pred_ans == "n":
            pred_label = "no"
        else:
            pred_label = pred_ans
    else:
        if "yes" in pred_ans:
            pred_label = "yes"
        elif "no" in pred_ans:
            pred_label = "no"
        else:
            pred_label = pred_ans
    return pred_label


def clotho_aqa_process_results(doc, results):
    pred = results[0]
    pred_ans = parse_pred_ans(pred)
    gt_ans = doc["answer"].lower().strip().replace(".", "")
    score = 1.0 if pred_ans == gt_ans else 0.0
    return {"exact_match": {"score": score}}
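To illustrate the normalization above, a small sketch of what parse_pred_ans would return for a few hypothetical model outputs (the example strings are invented):

# Hypothetical raw predictions -> labels produced by parse_pred_ans:
examples = {
    "Yes.": "yes",             # trailing "." stripped, "yes" substring detected
    "n": "no",                  # single-character shorthand expanded
    "two birds": "two birds",   # non yes/no answers pass through lowercased
}
for raw, expected in examples.items():
    assert parse_pred_ans(raw) == expected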


# functions for the clotho_asqa_v2 task, need to be tested later

with open(Path(__file__).parent / "_default_template_yaml", "r") as f:
    raw_data = f.readlines()
    safe_data = []
    for i, line in enumerate(raw_data):
        # remove function definition since yaml load cannot handle it
        if "!function" not in line:
            safe_data.append(line)

    config = yaml.safe_load("".join(safe_data))
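Since the !function lines are dropped before parsing, config ends up holding only the plain YAML keys. Roughly, for the template shown earlier in this commit:

# Approximate contents of `config` after stripping the !function lines (sketch, not exact):
# {
#     "doc_to_target": "answer",
#     "metadata": {"gpt_eval_model_name": "gpt-4o", "version": 0.0},
#     ...
# }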


NUM_SECONDS_TO_SLEEP = 2
GPT_EVAL_MODEL_NAME = config["metadata"]["gpt_eval_model_name"]
API_TYPE = os.getenv("API_TYPE", "openai")

if API_TYPE == "openai":
    API_URL = os.getenv("OPENAI_API_URL", "https://api.openai.com/v1/chat/completions")
    API_KEY = os.getenv("OPENAI_API_KEY", "YOUR_API_KEY")
    headers = {
        "Authorization": f"Bearer {API_KEY}",
        "Content-Type": "application/json",
    }
elif API_TYPE == "azure":
    API_URL = os.getenv("AZURE_ENDPOINT", "https://api.cognitive.microsoft.com/sts/v1.0/issueToken")
    API_KEY = os.getenv("AZURE_API_KEY", "YOUR_API_KEY")
    headers = {
        "api-key": API_KEY,
        "Content-Type": "application/json",
    }

eval_prompt = """
[Question]
{question}
[Reference Answer]
{ground_truth}
[Model Answer]
{model_response}
[Task]
Rate the model's answer based on its alignment with the reference answer, focusing on accuracy and relevance to the reference provided. Please be critical on the details.
Criteria: Assess if the model's response mirrors the reference in terms of content, accuracy, and relevance.
Score0: The answer is completely misaligned, providing incorrect or irrelevant information compared to the reference.
Score1: The answer shows minimal alignment, often misunderstanding or providing irrelevant details unrelated to the reference.
Score2: The answer recognizes the topic but diverges significantly from the reference in accuracy or relevance.
Score3: The answer aligns with the reference generally but lacks detail or precise accuracy in some aspects.
Score4: The answer is mostly accurate and relevant, closely following the reference but could be clearer or more detailed.
Score5: The answer is highly accurate, detailed, and matches the reference answer perfectly, capturing its essence and detail.
Your response should be formatted as follows:
Explanation: (Provide a concise explanation of your rating, comparing the reference answer with the model's response. "The reference answer is [XXX], while the model's answer is [YYY]. I think ...")
Rating: (int)"""


def get_eval(max_tokens: int, content: str):
    global headers

    messages = [
        {"role": "user", "content": content},
    ]

    payload = {
        "model": GPT_EVAL_MODEL_NAME,
        "messages": messages,
        "temperature": 0.2,
        "max_tokens": max_tokens,
    }

    try:
        response = requests.post(API_URL, headers=headers, json=payload, timeout=60)
        response.raise_for_status()
        response_data = response.json()

        content = response_data["choices"][0]["message"]["content"].strip()
        if content != "":
            return content, response_data["model"]
    except Exception as e:
        eval_logger.info(f"Attempt failed with error: {e}")
        return "", ""
    return "", ""
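A minimal usage sketch of get_eval together with the prompt template above; the question, reference, and model answer are hypothetical, and a valid API key is assumed to be set via the environment variables described earlier:

# Hypothetical example: ask the judge model to rate one answer.
content = eval_prompt.format(
    question="What animal is making the sound?",
    ground_truth="a dog",
    model_response="It sounds like a dog barking.",
)
eval_answer, judge_model = get_eval(max_tokens=1024, content=content)
# eval_answer is expected to end with a line such as "Rating: 5",
# which clotho_aqa_v2_aggregate_results later parses with a regex.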


def clotho_aqa_v2_process_results(doc, result):
    pred = result[0]
    ground_truth_str = doc["answer"]
    content = eval_prompt.format(model_response=pred, ground_truth=ground_truth_str, question=doc["question"])
    eval_answer, model_name = get_eval(max_tokens=1024, content=content)
    return {
        "gpt_eval": {"eval_answer": eval_answer, "model_name": model_name},
    }


def clotho_aqa_v2_aggregate_results(results):
    score = 0
    for result in results:
        eval_answer = result["eval_answer"]
        try:
            # pull the first 0-5 digit out of the judge's reply
            eval_score = re.search(r"([0-5])", eval_answer).group(1)
            eval_score = float(eval_score)
        except Exception as e:
            eval_logger.error(f"Error parsing eval_score: {e}")
            eval_score = 0.0
        score += eval_score

    return score / len(results) * 20
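As a sanity check on the scaling, a sketch of how per-sample ratings would map onto the final 0-100 score (the judge replies below are fabricated):

# Three hypothetical judge replies with ratings 5, 3 and 4 out of 5.
fake_results = [
    {"eval_answer": "Explanation: matches the reference.\nRating: 5"},
    {"eval_answer": "Explanation: partially correct.\nRating: 3"},
    {"eval_answer": "Explanation: mostly correct.\nRating: 4"},
]
# mean rating = 4.0, scaled by 20 -> 80.0
assert clotho_aqa_v2_aggregate_results(fake_results) == 80.0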
