Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Vision Hard preliminary v0.0 #3525

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions fastchat/serve/monitor/classify/category.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
# - if
# - score
import ast
import base64
import os
import re


Expand All @@ -24,6 +26,8 @@ def create_category(name):
return CategoryIF()
elif name == "math_v0.1":
return CategoryMath()
elif name == "criteria_vision_v0.1":
return CategoryVisionHardPrompt()

raise Exception(f"Category name is incorrect: {name}")

Expand Down Expand Up @@ -134,3 +138,31 @@ def pre_process(self, prompt):
def post_process(self, judgment):
score = self.get_score(judgment=judgment)
return {"math": bool(score == "yes") if score else False}


class CategoryVisionHardPrompt(CategoryHardPrompt):
    """Hard-prompt criteria category for vision (multimodal) conversations.

    Reuses the parent's system prompt and scoring, but builds an OpenAI-style
    multimodal message list: the user turn carries one text part followed by
    one ``image_url`` part per image, each inlined as a base64 data URL.
    """

    def __init__(self):
        super().__init__()
        # Overrides the parent's tag so labels land under the vision key.
        self.name_tag = "criteria_vision_v0.1"

    def _convert_filepath_to_base64(self, filepath):
        """Read the image file at *filepath* and return its base64-encoded contents as str."""
        with open(filepath, "rb") as image_file:
            return base64.b64encode(image_file.read()).decode("utf-8")

    def pre_process(self, prompt: str, image_list: list):
        """Build the judge conversation for a multimodal prompt.

        Args:
            prompt: aggregated user text of the conversation (plain string).
            image_list: filesystem paths of the conversation's images; each
                file is read and embedded as a base64 data URL.

        Returns:
            list[dict]: [system message, user message] where the user
            message content is [text part, *image_url parts].
        """
        conv = [{"role": "system", "content": self.sys_prompt}]
        single_turn_content_list = [{"type": "text", "text": prompt}]
        for image_path in image_list:
            single_turn_content_list.append(
                {
                    "type": "image_url",
                    "image_url": {
                        # NOTE(review): mime type is hard-coded to png — the caller
                        # (label.py) builds "<id>.png" paths, but confirm if other
                        # formats can appear.
                        "url": f"data:image/png;base64,{self._convert_filepath_to_base64(image_path)}"
                    },
                }
            )

        conv.append({"role": "user", "content": single_turn_content_list})
        return conv
42 changes: 40 additions & 2 deletions fastchat/serve/monitor/classify/label.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,10 @@ def get_answer(
output_log = {}

for category in categories:
conv = category.pre_process(question["prompt"])
if config["images_dir"]:
conv = category.pre_process(question["prompt"], question["image_list"])
else:
conv = category.pre_process(question["prompt"])
output = chat_completion_openai(
model=model_name,
messages=conv,
Expand Down Expand Up @@ -165,6 +168,34 @@ def find_required_tasks(row):
]


def aggregate_entire_conversation(conversation, images_dir):
    """Collapse a conversation's user turns into one prompt plus image paths.

    Walks the even-indexed turns (the user side of an alternating
    user/assistant transcript). String contents are appended to the text;
    list contents are unpacked as (text, image_ids), where each image id is
    resolved to "<images_dir>/<id>.png" and kept only if the file exists.

    Returns:
        tuple[str, list[str]]: ("\\n"-prefixed concatenated text, image paths).
    """
    text_parts = []
    resolved_images = []

    for turn in conversation[::2]:  # even indices: user turns only
        content = turn["content"]
        if isinstance(content, str):
            text_parts.append(content)
        elif isinstance(content, list):
            turn_text, image_ids = content
            text_parts.append(turn_text)
            for image_id in image_ids:
                candidate = os.path.join(images_dir, f"{image_id}.png")
                if os.path.exists(candidate):
                    resolved_images.append(candidate)

    return "".join(f"\n{part}" for part in text_parts), resolved_images


def get_prompt_from_conversation(conversation):
    """Return the aggregated prompt text — element 0 of a (text, images) pair."""
    prompt_text, = conversation[:1]
    return prompt_text


def get_image_list_from_conversation(conversation):
    """Return the image-path list — element 1 of a (text, images) pair."""
    _, image_paths = conversation[:2]
    return image_paths


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str, required=True)
Expand Down Expand Up @@ -247,8 +278,15 @@ def find_required_tasks(row):
)

not_labeled["prompt"] = not_labeled.conversation_a.map(
lambda convo: "\n".join([convo[i]["content"] for i in range(0, len(convo), 2)])
lambda convo: aggregate_entire_conversation(convo, config["images_dir"])
)

if config["images_dir"]:
not_labeled["image_list"] = not_labeled.prompt.map(
get_image_list_from_conversation
)
not_labeled = not_labeled[not_labeled.image_list.map(len) > 0]
not_labeled["prompt"] = not_labeled.prompt.map(get_prompt_from_conversation)
not_labeled["prompt"] = not_labeled.prompt.map(lambda x: x[:12500])

with concurrent.futures.ThreadPoolExecutor(
Expand Down
9 changes: 8 additions & 1 deletion fastchat/serve/monitor/elo_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -721,7 +721,10 @@ def pretty_print_elo_rating(rating):

if args.clean_battle_file:
# Read data from a cleaned battle files
battles = pd.read_json(args.clean_battle_file)
if args.clean_battle_file.endswith(".jsonl"):
battles = pd.read_json(args.clean_battle_file, lines=True)
else:
battles = pd.read_json(args.clean_battle_file)
else:
# Read data from all log files
log_files = get_log_files(args.max_num_files)
Expand All @@ -732,6 +735,10 @@ def pretty_print_elo_rating(rating):
"long": filter_long_conv,
"chinese": lambda x: x["language"] == "Chinese",
"english": lambda x: x["language"] == "English",
"criteria_vision_v0.1": lambda x: sum(
x["category_tag"]["criteria_vision_v0.1"].values()
)
>= 6,
}
assert all(
[cat in filter_func_map for cat in args.category]
Expand Down
Empty file added pbcopy
Empty file.
Loading