diff --git a/fastchat/serve/monitor/classify/category.py b/fastchat/serve/monitor/classify/category.py index 223144a32f..12efbf4b30 100644 --- a/fastchat/serve/monitor/classify/category.py +++ b/fastchat/serve/monitor/classify/category.py @@ -9,6 +9,7 @@ # - if # - score import ast +import base64 import re @@ -24,6 +25,8 @@ def create_category(name): return CategoryIF() elif name == "math_v0.1": return CategoryMath() + elif name == "criteria_vision_v0.1": + return CategoryVisionHardPrompt() raise Exception(f"Category name is incorrect: {name}") @@ -134,3 +137,31 @@ def pre_process(self, prompt): def post_process(self, judgment): score = self.get_score(judgment=judgment) return {"math": bool(score == "yes") if score else False} + + + class CategoryVisionHardPrompt(CategoryHardPrompt): + def __init__(self): + super().__init__() + self.name_tag = "criteria_vision_v0.1" + + def _convert_filepath_to_base64(self, filepath): + with open(filepath, "rb") as image_file: + return base64.b64encode(image_file.read()).decode("utf-8") + + def pre_process(self, prompt: str, image_list: list): + # Prompt is a list where the first element is text and the second element is a list of image in base64 format + conv = [{"role": "system", "content": self.sys_prompt}] + single_turn_content_list = [] + single_turn_content_list.append({"type": "text", "text": prompt}) + for image_url in image_list: + single_turn_content_list.append( + { + "type": "image_url", + "image_url": { + "url": f"data:image/png;base64,{self._convert_filepath_to_base64(image_url)}" + }, + } + ) + + conv.append({"role": "user", "content": single_turn_content_list}) + return conv diff --git a/fastchat/serve/monitor/classify/label.py b/fastchat/serve/monitor/classify/label.py index 2d0471a1f1..deb15cc76f 100644 --- a/fastchat/serve/monitor/classify/label.py +++ b/fastchat/serve/monitor/classify/label.py @@ -107,7 +107,10 @@ def get_answer( output_log = {} for category in categories: - conv = 
category.pre_process(question["prompt"]) + if config["images_dir"]: + conv = category.pre_process(question["prompt"], question["image_list"]) + else: + conv = category.pre_process(question["prompt"]) output = chat_completion_openai( model=model_name, messages=conv, @@ -165,6 +168,34 @@ def find_required_tasks(row): ] +def aggregate_entire_conversation(conversation, images_dir): + final_text_content = "" + final_image_list = [] + + for i in range(0, len(conversation), 2): + content = conversation[i]["content"] + if isinstance(content, str): + final_text_content += "\n" + content + elif isinstance(content, list): + text_content, image_list = content + final_text_content += "\n" + text_content + + for image in image_list: + image_url = os.path.join(images_dir, f"{image}.png") + if os.path.exists(image_url): + final_image_list.append(image_url) + + return final_text_content, final_image_list + + +def get_prompt_from_conversation(conversation): + return conversation[0] + + +def get_image_list_from_conversation(conversation): + return conversation[1] + + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--config", type=str, required=True) @@ -247,8 +278,15 @@ def find_required_tasks(row): ) not_labeled["prompt"] = not_labeled.conversation_a.map( - lambda convo: "\n".join([convo[i]["content"] for i in range(0, len(convo), 2)]) + lambda convo: aggregate_entire_conversation(convo, config["images_dir"]) ) + + if config["images_dir"]: + not_labeled["image_list"] = not_labeled.prompt.map( + get_image_list_from_conversation + ) + not_labeled = not_labeled[not_labeled.image_list.map(len) > 0] + not_labeled["prompt"] = not_labeled.prompt.map(get_prompt_from_conversation) not_labeled["prompt"] = not_labeled.prompt.map(lambda x: x[:12500]) with concurrent.futures.ThreadPoolExecutor( diff --git a/fastchat/serve/monitor/elo_analysis.py b/fastchat/serve/monitor/elo_analysis.py index bea808fc5d..b2fa24aab1 100644 --- 
a/fastchat/serve/monitor/elo_analysis.py +++ b/fastchat/serve/monitor/elo_analysis.py @@ -721,7 +721,10 @@ def pretty_print_elo_rating(rating): if args.clean_battle_file: # Read data from a cleaned battle files - battles = pd.read_json(args.clean_battle_file) + if args.clean_battle_file.endswith(".jsonl"): + battles = pd.read_json(args.clean_battle_file, lines=True) + else: + battles = pd.read_json(args.clean_battle_file) else: # Read data from all log files log_files = get_log_files(args.max_num_files) @@ -732,6 +735,10 @@ def pretty_print_elo_rating(rating): "long": filter_long_conv, "chinese": lambda x: x["language"] == "Chinese", "english": lambda x: x["language"] == "English", + "criteria_vision_v0.1": lambda x: sum( + x["category_tag"]["criteria_vision_v0.1"].values() + ) + >= 6, } assert all( [cat in filter_func_map for cat in args.category] )