diff --git a/fastchat/serve/gradio_web_server.py b/fastchat/serve/gradio_web_server.py
index 2ef47b14d..bb95f6afe 100644
--- a/fastchat/serve/gradio_web_server.py
+++ b/fastchat/serve/gradio_web_server.py
@@ -601,6 +601,14 @@ def bot_response(
     font-size: 105%;
 }
 
+#overview_leaderboard_dataframe table th {
+    font-size: 90%;
+}
+
+#overview_leaderboard_dataframe table td {
+    font-size: 105%;
+}
+
 .tab-nav button {
     font-size: 18px;
 }
diff --git a/fastchat/serve/monitor/add_markdown_info.py b/fastchat/serve/monitor/add_markdown_info.py
index f05468ff9..a33ba1c9d 100644
--- a/fastchat/serve/monitor/add_markdown_info.py
+++ b/fastchat/serve/monitor/add_markdown_info.py
@@ -55,13 +55,75 @@ def get_element_counts(df, column):
 
 def add_markdown_meta(row):
     conv_meta = {k: v for k, v in row["conv_metadata"].items()}
-    return conv_meta | row["markdown_meta_a"] | row["markdown_meta_b"]
+    return (
+        conv_meta
+        | row["markdown_meta_a"]
+        | row["markdown_meta_b"]
+        | {
+            "friendliness_a": row["friendliness_a"],
+            "friendliness_b": row["friendliness_b"],
+            "hesitant_a": row["hesitant_a"],
+            "hesitant_b": row["hesitant_b"],
+        }
+    )
+
+
+from fastchat.serve.monitor.utils_llm import get_llm_output
+
+
+# LLM-as-judge helpers for the friendliness / hesitancy style features.
+def get_score(judgment, pattern):
+    matches = pattern.findall(judgment.replace("\n", "").lower())
+    matches = [m for m in matches if m != ""]
+    if len(set(matches)) == 0:
+        return None
+    elif len(set(matches)) == 1:
+        return matches[0]
+    else:
+        return None
+
+
+def add_friendliness(row, api_key):
+    systems_prompt = "Given a prompt and responses from two LLMs (A and B), your task is to determine which response is more friendly. If both responses are equally friendly, respond with equal. Do not let the order of the responses influence your decision. Do NOT reply or expand on any of the responses, only reply with which response is more friendly.\n\nOutput your verdict in the following format:<decision>\n[A/B/equal]\n</decision>. Do NOT explain."
+    prompt = f"Prompt: {row['prompt']}\n\nResponse A: {row['response_a']}\n\nResponse B: {row['response_b']}"
+    response = get_llm_output("gpt-4o-mini", api_key, systems_prompt, prompt)
+    pattern = response.replace("\n", "").lower().replace("<decision>", "").replace("</decision>", "").replace("[", "").replace("]", "").replace("(", "").replace(")", "").replace("<", "").replace(">", "")
+    if pattern == "a":
+        return 1
+    elif pattern == "b":
+        return 0
+    elif pattern == "equal":
+        return 0.5
+    else:
+        print(f"Error: {pattern}")
+        return 0.5
+
+
+def add_hesitant(row, api_key):
+    systems_prompt = "Given a prompt and responses from two LLMs (A and B), your task is to determine which response is more hesitant. Consider factors like adding disclaimers, stating when they are unsure, recommending the user seek outside information instead, and refusing to give a complete answer. If both responses are equally hesitant, respond with equal. Do not let the order of the responses influence your decision. Do NOT reply or expand on any of the responses, only reply with which response is more hesitant.\n\nOutput your verdict in the following format:<decision>\n[A/B/equal]\n</decision>. Do NOT explain."
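+    # Same convention as add_friendliness: 1 means A is more hesitant, 0 means B, 0.5 means equal (or an unparseable verdict).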
+ prompt = f"Prompt: {row['prompt']}\n\nResponse A: {row['response_a']}\n\nResponse B: {row['response_b']}" + response = get_llm_output("gpt-4o-mini", api_key, systems_prompt, prompt) + pattern = response.replace("\n", "").lower().replace("", "").replace("", "").replace("[", "").replace("]", "").replace("(", "").replace(")", "").replace("<", "").replace(">", "").strip() + if "a" in pattern and "b" not in pattern and "equal" not in pattern: + return 1 + elif "b" in pattern and "a" not in pattern and "equal" not in pattern: + return 0 + elif pattern == "equal": + return 0.5 + else: + print(f"Error: {pattern}") + return 0.5 + +from concurrent.futures import ThreadPoolExecutor, as_completed +def process_data_in_threads(dataframe, api_key, max_workers=64, function=add_friendliness): + with ThreadPoolExecutor(max_workers=max_workers) as executor: + results = list(executor.map( + lambda row: function(row, api_key), + [row for _, row in dataframe.iterrows()] + )) + return results + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--input-file", type=str, required=True) parser.add_argument("--output-file", type=str, required=True) + parser.add_argument("--api_key", type=str, required=False) + parser.add_argument("--api_url", type=str, required=False) args = parser.parse_args() print("loading file...") @@ -71,14 +133,42 @@ def add_markdown_meta(row): temp = data[["question_id", "conv_metadata"]].copy() + data["prompt"] = data.conversation_a.map( + lambda convo: "\n".join([convo[i]["content"] for i in range(0, len(convo), 2)]) + ) + data["prompt"] = data.prompt.map(lambda x: x[:12500]) + get_content = lambda c: c if type(c) == str else c[0] + data["response_a"] = data.conversation_a.map( + lambda convo: "\n".join( + [get_content(convo[i]["content"]) for i in range(1, len(convo), 2)] + ) + ) + data["response_a"] = data.response_a.map(lambda x: x[:12500]) + data["response_b"] = data.conversation_b.map( + lambda convo: "\n".join( + [get_content(convo[i]["content"]) for i in range(1, len(convo), 2)] + ) + ) + data["response_b"] = data.response_b.map(lambda x: x[:12500]) + + print("Processing friendliness") + temp["friendliness_a"] = process_data_in_threads(data, args.api_key) + temp["friendliness_b"] = 1 - temp["friendliness_a"] + + print("Processing hesitant") + temp["hesitant_a"] = process_data_in_threads(data, args.api_key, function=add_hesitant) + temp["hesitant_b"] = 1 - temp["hesitant_a"] + print("Processing conversation_a") temp["markdown_meta_a"] = get_element_counts(data, column="conversation_a") print("Processing conversation_b") temp["markdown_meta_b"] = get_element_counts(data, column="conversation_b") + print(temp) print("Post-processing...") data["conv_metadata"] = temp.apply(add_markdown_meta, axis=1) + print(data.iloc[0]['conv_metadata']) print("Saving to file...") data.to_json(args.output_file, orient="records", indent=4, force_ascii=False) diff --git a/fastchat/serve/monitor/classify/category.py b/fastchat/serve/monitor/classify/category.py index 223144a32..954c941e9 100644 --- a/fastchat/serve/monitor/classify/category.py +++ b/fastchat/serve/monitor/classify/category.py @@ -11,6 +11,8 @@ import ast import re +from vision_utils import pil_to_base64, get_image_file_from_gcs + class Category: def __init__(self): @@ -24,6 +26,22 @@ def create_category(name): return CategoryIF() elif name == "math_v0.1": return CategoryMath() + elif name == "refusal_v0.1": + return CategoryRefusal() + elif name == "captioning_v0.1": + return CategoryCaptioning() + elif 
name == "counting_v0.1": + return CategoryCounting() + elif name == "creative_writing_v0.1": + return CategoryCreativeWriting() + elif name == "entity_recognition_v0.1": + return CategoryEntityRecognition() + elif name == "ocr_v0.1": + return CategoryOpticalCharacterRecognition() + elif name == "humor_v0.1": + return CategoryHumor() + elif name == "homework_v0.1": + return CategoryHomework() raise Exception(f"Category name is incorrect: {name}") @@ -63,7 +81,7 @@ def get_score(self, judgment): def pre_process(self, prompt): conv = [{"role": "system", "content": self.sys_prompt}] - conv.append({"role": "user", "content": prompt}) + conv.append({"role": "user", "content": prompt["prompt"]}) return conv def post_process(self, judgment): @@ -87,10 +105,11 @@ def get_score(self, judgment): elif len(set(matches)) == 1: return int(matches[0]) else: + print("Error parsing IF") return None def pre_process(self, prompt): - args = {"PROMPT": prompt} + args = {"PROMPT": prompt["prompt"]} conv = [ {"role": "system", "content": self.system_prompt}, {"role": "user", "content": self.prompt_template.format(**args)}, @@ -124,7 +143,7 @@ def get_score(self, judgment): return None def pre_process(self, prompt): - args = {"PROMPT": prompt} + args = {"PROMPT": prompt["prompt"]} conv = [ {"role": "system", "content": self.system_prompt}, {"role": "user", "content": self.prompt_template.format(**args)}, @@ -134,3 +153,301 @@ def pre_process(self, prompt): def post_process(self, judgment): score = self.get_score(judgment=judgment) return {"math": bool(score == "yes") if score else False} + + +class CategoryRefusal(Category): + def __init__(self): + super().__init__() + self.name_tag = "refusal_v0.1" + self.pattern = re.compile(r"([a-z]+)") + self.system_prompt = 'You are tasked with determining if any of the given model responses are a refusal to answer. A refusal occurs when the model explicitly declines to answer or complete a task, due to reasons like safety, moderation, or model limitations (e.g. the user asks the model to search the web but it does not have that capability). You will be given the responces of 2 models, A and B and you are to determine if A refuses to answer, B refuses, both refuse, or neither refuse. Output your verdict in the following format:"\n[A/B/both/neither]\n". Do not respond with any additional information.' + + self.prompt_template = "\n{RESPONSE_A}\n\n\n{RESPONSE_B}\n" + + def get_score(self, judgment): + match = self.pattern.search( + judgment.replace("\n", "").replace("[", "").replace("]", "").lower() + ) + if match: + return match.group(1) + return "error" + + def pre_process(self, prompt): + args = {"RESPONSE_A": prompt["response_a"], "RESPONSE_B": prompt["response_a"]} + conv = [ + {"role": "system", "content": self.system_prompt}, + {"role": "user", "content": self.prompt_template.format(**args)}, + ] + return conv + + def post_process(self, judgment): + score = self.get_score(judgment=judgment) + return { + "refusal_a": bool(score == "a") or bool(score == "both"), + "refusal_b": bool(score == "b") or bool(score == "both"), + "refusal": bool(score == "a") + or bool(score == "b") + or bool(score == "both"), + } + + +class CategoryCaptioning(Category): + def __init__(self): + super().__init__() + self.name_tag = "captioning_v0.1" + self.pattern = re.compile(r"(\w+)<\/decision>") + self.system_prompt = "You are tasked with determining if a given VQA question is a captioning question. A captioning question asks for a general, overall description of the entire image. 
It must be a single, open-ended query that does NOT ask about particular objects, people, or parts of the image, nor require interpretation beyond a broad description of what is visually present. Examples include 'What is happening in this image?', 'Describe this picture.', 'Explain', etc. An example of a non-captioning question is 'Describe what is funny in this picture.' because it asks for a specific interpretation of the image content. \n\nOutput your verdict in the following format:\n[yes/no]\n. Do NOT explain." + self.prompt_template = "\n{PROMPT}\n" + + def get_score(self, judgment): + matches = self.pattern.findall(judgment.replace("\n", "").lower()) + matches = [m for m in matches if m != ""] + if len(set(matches)) == 0: + return None + elif len(set(matches)) == 1: + return matches[0] + else: + return None + + def pre_process(self, prompt): + args = {"PROMPT": prompt["prompt"]} + conv = [ + {"role": "system", "content": self.system_prompt}, + {"role": "user", "content": self.prompt_template.format(**args)}, + ] + return conv + + def post_process(self, judgment): + score = self.get_score(judgment=judgment) + return {"captioning": bool(score == "yes") if score else False} + + +class CategoryCounting(Category): + def __init__(self): + super().__init__() + self.name_tag = "counting_v0.1" + self.pattern = re.compile(r"(\w+)<\/decision>") + self.system_prompt = "You are tasked with determining if a given VQA question is a counting question. A counting question explicitly asks for counting or identifying the number of objects in the image. Your task is to analyze the question and determine if it is a counting question.\n\nOutput your verdict in the following format:\n[yes/no]\n. Do NOT explain." + self.prompt_template = "\n{PROMPT}\n" + + def get_score(self, judgment): + matches = self.pattern.findall(judgment.replace("\n", "").lower()) + matches = [m for m in matches if m != ""] + if len(set(matches)) == 0: + return None + elif len(set(matches)) == 1: + return matches[0] + else: + return None + + def pre_process(self, prompt): + args = {"PROMPT": prompt["prompt"]} + conv = [ + {"role": "system", "content": self.system_prompt}, + {"role": "user", "content": self.prompt_template.format(**args)}, + ] + return conv + + def post_process(self, judgment): + score = self.get_score(judgment=judgment) + return {"counting": bool(score == "yes") if score else False} + + +class CategoryCreativeWriting(Category): + def __init__(self): + super().__init__() + self.name_tag = "creative_writing_v0.1" + self.pattern = re.compile(r"(\w+)<\/decision>") + self.system_prompt = "You are tasked with determining if a given VQA question is a creative writing question. A creative writing question explicitly asks for creative or imaginative responses based on the image, such as composing a story, poem, or providing a fictional interpretation. This excludes questions that simply ask for factual observations, interpretations, or speculations about the image content.\n\nOutput your verdict in the following format:\n[yes/no]\n. Do NOT explain." 
+ self.prompt_template = "\n{PROMPT}\n" + + def get_score(self, judgment): + matches = self.pattern.findall(judgment.replace("\n", "").lower()) + matches = [m for m in matches if m != ""] + if len(set(matches)) == 0: + return None + elif len(set(matches)) == 1: + return matches[0] + else: + return None + + def pre_process(self, prompt): + args = {"PROMPT": prompt["prompt"]} + conv = [ + {"role": "system", "content": self.system_prompt}, + {"role": "user", "content": self.prompt_template.format(**args)}, + ] + return conv + + def post_process(self, judgment): + score = self.get_score(judgment=judgment) + return {"creative_writing": bool(score == "yes") if score else False} + + +class CategoryEntityRecognition(Category): + def __init__(self): + super().__init__() + self.name_tag = "entity_recognition_v0.1" + self.pattern = re.compile(r"(\w+)<\/decision>") + self.system_prompt = "You are tasked with determining if a given VQA question is an entity recognition question. An entity recognition question asks for the identification of specific objects or people in the image. This does NOT include questions that ask for a general description of the image, questions that only ask for object counts, or questions that only require reading text in the image.\n\nOutput your verdict in the following format:\n[yes/no]\n. Do NOT explain." + + def get_score(self, judgment): + matches = self.pattern.findall(judgment.replace("\n", "").lower()) + matches = [m for m in matches if m != ""] + if len(set(matches)) == 0: + return None + elif len(set(matches)) == 1: + return matches[0] + else: + return None + + def pre_process(self, prompt): + args = {"PROMPT": prompt["prompt"]} + conv = [ + {"role": "system", "content": self.system_prompt}, + {"role": "user", "content": self.prompt_template.format(**args)}, + ] + return conv + + def post_process(self, judgment): + score = self.get_score(judgment=judgment) + return {"creative_writing": bool(score == "yes") if score else False} + + +import io +import base64 + + +def pil_to_base64(image): + buffered = io.BytesIO() + image.save(buffered, format="PNG") + img_str = base64.b64encode(buffered.getvalue()).decode() + return img_str + + +class CategoryOpticalCharacterRecognition(Category): + def __init__(self): + super().__init__() + self.name_tag = "ocr_v0.1" + self.pattern = re.compile(r"(\w+)<\/decision>") + self.system_prompt = "You are tasked with determining if a given VQA question is an optical character recognition (OCR) question. An OCR question requires reading and understanding text in the image to answer. If there is some amount of text in the image and the question requires reading the text in any capacity it should be classified as Optical Character Recognition.\n\nOutput your verdict in the following format:\n[yes/no]\n. Do NOT explain." 
+
+    def get_score(self, judgment):
+        matches = self.pattern.findall(judgment.replace("\n", "").lower())
+        matches = [m for m in matches if m != ""]
+        if len(set(matches)) == 0:
+            return None
+        elif len(set(matches)) == 1:
+            return matches[0]
+        else:
+            return None
+
+    def pre_process(self, prompt):
+        args = {"PROMPT": prompt["prompt"]}
+        base64_image = get_image_file_from_gcs(prompt["image_hash"])
+        conv = [
+            {"role": "system", "content": self.system_prompt},
+            {
+                "role": "user",
+                "content": [
+                    {"type": "text", "text": self.prompt_template.format(**args)},
+                    {
+                        "type": "image_url",
+                        "image_url": {
+                            "url": f"data:image/jpeg;base64,{base64_image}",
+                        },
+                    },
+                ],
+            },
+        ]
+        return conv
+
+    def post_process(self, judgment):
+        score = self.get_score(judgment=judgment)
+        return {"ocr": bool(score == "yes") if score else False}
+
+
+class CategoryHumor(Category):
+    def __init__(self):
+        super().__init__()
+        self.name_tag = "humor_v0.1"
+        self.pattern = re.compile(r"<decision>(\w+)<\/decision>")
+        self.system_prompt = "You are tasked with determining if a given VQA question is a humor question. A humor question asks for a humorous or funny response based on the image or asks to understand what is funny about an image. This includes questions that ask to explain an image which is humorous, such as memes.\n\nOutput your verdict in the following format:<decision>\n[yes/no]\n</decision>. Do NOT explain."
+        self.prompt_template = "<user_prompt>\n{PROMPT}\n</user_prompt>"
+
+    def get_score(self, judgment):
+        matches = self.pattern.findall(judgment.replace("\n", "").lower())
+        matches = [m for m in matches if m != ""]
+        if len(set(matches)) == 0:
+            return None
+        elif len(set(matches)) == 1:
+            return matches[0]
+        else:
+            return None
+
+    def pre_process(self, prompt):
+        args = {"PROMPT": prompt["prompt"]}
+        base64_image = get_image_file_from_gcs(prompt["image_hash"])
+        conv = [
+            {"role": "system", "content": self.system_prompt},
+            {
+                "role": "user",
+                "content": [
+                    {"type": "text", "text": self.prompt_template.format(**args)},
+                    {
+                        "type": "image_url",
+                        "image_url": {
+                            "url": f"data:image/jpeg;base64,{base64_image}",
+                        },
+                    },
+                ],
+            },
+        ]
+        return conv
+
+    def post_process(self, judgment):
+        score = self.get_score(judgment=judgment)
+        return {"humor": bool(score == "yes") if score else False}
+
+
+class CategoryHomework(Category):
+    def __init__(self):
+        super().__init__()
+        self.name_tag = "homework_v0.1"
+        self.pattern = re.compile(r"<decision>(\w+)<\/decision>")
+        self.system_prompt = "You are tasked with determining if a given VQA question is a homework question. A homework question asks for explanations, solutions, or assistance with images that are likely from educational materials.\n\nOutput your verdict in the following format:<decision>\n[yes/no]\n</decision>. Do NOT explain."
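+        # Like the other vision judges, pre_process below attaches the battle's image (fetched from GCS by hash) alongside the prompt text.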
+ self.prompt_template = "\n{PROMPT}\n" + + def get_score(self, judgment): + matches = self.pattern.findall(judgment.replace("\n", "").lower()) + matches = [m for m in matches if m != ""] + if len(set(matches)) == 0: + return None + elif len(set(matches)) == 1: + return matches[0] + else: + return None + + def pre_process(self, prompt): + args = {"PROMPT": prompt["prompt"]} + base64_image = get_image_file_from_gcs(prompt["image_hash"]) + conv = [ + {"role": "system", "content": self.system_prompt}, + { + "role": "user", + "content": [ + {"type": "text", "text": self.prompt_template.format(**args)}, + { + "type": "image_url", + "image_url": { + "url": f"data:image/jpeg;base64,{base64_image}", + }, + }, + ], + }, + ] + return conv + + def post_process(self, judgment): + score = self.get_score(judgment=judgment) + return {"homework": bool(score == "yes") if score else False} diff --git a/fastchat/serve/monitor/classify/label.py b/fastchat/serve/monitor/classify/label.py index 2d0471a1f..d384e4371 100644 --- a/fastchat/serve/monitor/classify/label.py +++ b/fastchat/serve/monitor/classify/label.py @@ -9,8 +9,17 @@ import random import threading import orjson +import hashlib +from PIL import Image from category import Category +from vision_utils import get_image_path + +import lmdb + +if not os.path.exists("cache/category_cache"): + os.makedirs("cache/category_cache") +category_cache = lmdb.open("cache/category_cache", map_size=1024**4) LOCK = threading.RLock() @@ -56,7 +65,6 @@ def chat_completion_openai(model, messages, temperature, max_tokens, api_dict=No output = API_ERROR_OUTPUT for _ in range(API_MAX_RETRY): try: - # print(messages) completion = client.chat.completions.create( model=model, messages=messages, @@ -65,7 +73,6 @@ def chat_completion_openai(model, messages, temperature, max_tokens, api_dict=No # extra_body={"guided_choice": GUIDED_CHOICES} if GUIDED_CHOICES else None, ) output = completion.choices[0].message.content - # print(output) break except openai.RateLimitError as e: print(type(e), e) @@ -107,7 +114,7 @@ def get_answer( output_log = {} for category in categories: - conv = category.pre_process(question["prompt"]) + conv = category.pre_process(question) output = chat_completion_openai( model=model_name, messages=conv, @@ -125,7 +132,7 @@ def get_answer( if testing: question["output_log"] = output_log - question.drop(["prompt", "uid", "required_tasks"], inplace=True) + # question.drop(["prompt", "uid", "required_tasks"], inplace=True) with LOCK: with open(answer_file, "a") as fout: @@ -165,10 +172,14 @@ def find_required_tasks(row): ] +import wandb + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--config", type=str, required=True) parser.add_argument("--testing", action="store_true") + parser.add_argument("--vision", action="store_true") + parser.add_argument("--wandb", action="store_true") args = parser.parse_args() enter = input( @@ -179,6 +190,14 @@ def find_required_tasks(row): config = make_config(args.config) + if not args.wandb: + os.environ["WANDB_MODE"] = "dryrun" + if args.wandb: + wandb.init( + project="arena", + name=config["input_file"].split("/")[-1].split(".")[0], + ) + API_MAX_RETRY = config["max_retry"] API_RETRY_SLEEP = config["retry_sleep"] API_ERROR_OUTPUT = config["error_output"] @@ -194,6 +213,26 @@ def find_required_tasks(row): data = orjson.loads(f.read()) input_data = pd.DataFrame(data) + if args.vision: + old_len = len(input_data) + input_data["image_hash"] = input_data.conversation_a.map( + lambda convo: 
convo[0]["content"][1][0] + ) + input_data["image_path"] = input_data.image_hash.map(get_image_path) + input_data = input_data[input_data.image_path != False].reset_index(drop=True) + print(f"{len(input_data)} out of {old_len}# images found") + + if args.testing: + # remove output file if exists + if os.path.isfile(config["output_file"]): + os.remove(config["output_file"]) + if "category_tag" in input_data.columns: + input_data.drop(columns=["category_tag"], inplace=True) + input_data = input_data[input_data["language"] == "English"].reset_index( + drop=True + ) + input_data = input_data[:50] + # much faster than pd.apply input_data["uid"] = input_data.question_id.map(str) + input_data.tstamp.map(str) assert len(input_data) == len(input_data.uid.unique()) @@ -246,10 +285,25 @@ def find_required_tasks(row): f"{name}: {len(not_labeled[not_labeled.required_tasks.map(lambda tasks: name in tasks)])}" ) + get_content = lambda c: c if type(c) == str else c[0] not_labeled["prompt"] = not_labeled.conversation_a.map( - lambda convo: "\n".join([convo[i]["content"] for i in range(0, len(convo), 2)]) + lambda convo: "\n".join( + [get_content(convo[i]["content"]) for i in range(0, len(convo), 2)] + ) ) not_labeled["prompt"] = not_labeled.prompt.map(lambda x: x[:12500]) + not_labeled["response_a"] = not_labeled.conversation_a.map( + lambda convo: "\n".join( + [get_content(convo[i]["content"]) for i in range(1, len(convo), 2)] + ) + ) + not_labeled["response_a"] = not_labeled.response_a.map(lambda x: x[:12500]) + not_labeled["response_b"] = not_labeled.conversation_b.map( + lambda convo: "\n".join( + [get_content(convo[i]["content"]) for i in range(1, len(convo), 2)] + ) + ) + not_labeled["response_b"] = not_labeled.response_b.map(lambda x: x[:12500]) with concurrent.futures.ThreadPoolExecutor( max_workers=config["parallel"] @@ -277,6 +331,61 @@ def find_required_tasks(row): ): future.result() + output = pd.read_json(config["output_file"], lines=True) + + # log table to wandb + if args.wandb: + + def replace_none_in_nested_dict(d): + if isinstance(d, dict): + return {k: replace_none_in_nested_dict(v) for k, v in d.items()} + elif isinstance(d, list): + return [replace_none_in_nested_dict(v) for v in d] + elif d is None: + return -1 # Replace None with 0 + else: + return d + + def process_category_tag(df): + df["category_tag"] = df["category_tag"].apply(replace_none_in_nested_dict) + return df + + # Use this function before logging to wandb + output = process_category_tag(output) + columns = ( + ["prompt", "response_a", "response_b", "tstamp", "category_tag"] + if not args.vision + else [ + "prompt", + "image", + "response_a", + "response_b", + "tstamp", + "category_tag", + ] + ) + if args.vision: + # # read image_path into wandb Image + # output["image"] = output.image_path.map(lambda x: wandb.Image(x)) + def is_valid_image(filepath): + try: + Image.open(filepath).verify() + return True + except Exception: + print(f"Invalid image: {filepath}") + return False + + if args.testing: + output["image"] = output.image_path.map( + lambda x: wandb.Image(x) + if os.path.exists(x) and is_valid_image(x) + else None + ) + else: + output["image"] = output.image_path + + wandb.log({"categories": wandb.Table(dataframe=output[columns])}) + if config["convert_to_json"]: # merge two data frames, but only take the fields from the cache data to overwrite the input data merge_columns = [category.name_tag for category in categories] diff --git a/fastchat/serve/monitor/classify/vision_config.yaml 
b/fastchat/serve/monitor/classify/vision_config.yaml new file mode 100644 index 000000000..c1f7ad5a9 --- /dev/null +++ b/fastchat/serve/monitor/classify/vision_config.yaml @@ -0,0 +1,33 @@ +# Yaml config file for category classification + +input_file: "../arena-data-analysis/data/vision_clean_battle_conv_20240822_with_image_hash.json" # json +cache_file: False # json +output_file: "fastchat/serve/monitor/classify/results/vision_clean_battle_conv_20240822_with_image_hash-new-per-category-llama.json" # json line + +convert_to_json: True + +task_name: + - refusal_v0.1 + - criteria_v0.1 + - if_v0.1 + - captioning_v0.1 + - homework_v0.1 + - ocr_v0.1 + - counting_v0.1 + - humor_v0.1 + - entity_recognition_v0.1 + - creative_writing_v0.1 + +model_name: gpt-4o-mini +name: gpt-4o-mini +endpoints: + - api_base: # API BASE + api_key: # API KEY + +parallel: 50 +temperature: 0.0 +max_token: 512 + +max_retry: 2 +retry_sleep: 10 +error_output: $ERROR$ \ No newline at end of file diff --git a/fastchat/serve/monitor/elo_analysis.py b/fastchat/serve/monitor/elo_analysis.py index 5a1daaa9d..67418ab54 100644 --- a/fastchat/serve/monitor/elo_analysis.py +++ b/fastchat/serve/monitor/elo_analysis.py @@ -26,10 +26,14 @@ "header_count_a", "list_count_a", "bold_count_a", + "friendliness_a", + "hesitant_a", "sum_assistant_b_tokens", "header_count_b", "list_count_b", "bold_count_b", + "friendliness_b", + "hesitant_b", ] @@ -435,11 +439,12 @@ def fit_mle_elo(X, Y, models, indices=None, SCALE=400, INIT_RATING=1000): def construct_style_matrices( df, BASE=10, - apply_ratio=[1, 1, 1, 1], + apply_ratio=[1, 1, 1, 1, 1], style_elements=STYLE_CONTROL_ELEMENTS_V1, add_one=True, ): models = pd.concat([df["model_a"], df["model_b"]]).unique() + print(df.columns) models = pd.Series(np.arange(len(models)), index=models) # duplicate battles @@ -458,7 +463,7 @@ def construct_style_matrices( [ df.conv_metadata.map( lambda x: x[element] - if type(x[element]) is int + if type(x[element]) is int or type(x[element]) is float else sum(x[element].values()) ).tolist() for element in style_elements @@ -714,15 +719,34 @@ def pretty_print_elo_rating(rating): filter_func_map = { "full": lambda x: True, - "long": filter_long_conv, "chinese": lambda x: x["language"] == "Chinese", "english": lambda x: x["language"] == "English", + "russian": lambda x: x["language"] == "Russian", + "vietnamese": lambda x: x["language"] == "Vietnamese", + "multiturn": lambda x: x["turn"] > 1, + "exclude_preset": lambda x: not x["preset"], + "no_refusal": lambda x: not x['category_tag']['refusal_v0.1']['refusal'], + "captioning": lambda x: x["category_tag"]["captioning_v0.1"]["captioning"], + "entity_recognition": lambda x: x["category_tag"]["entity_recognition_v0.1"]['entity_recognition'], + "ocr": lambda x: x["category_tag"]["ocr_v0.1"]["ocr"], + "if": lambda x: x["category_tag"]["if_v0.1"]["if"], + "math": lambda x: x["category_tag"]["math_v0.1"]["math"], + "homework": lambda x: x["category_tag"]["homework_v0.1"]["homework"], + "humor": lambda x: x["category_tag"]["humor_v0.1"]["humor"], + "coding": lambda x: x["is_coding"], } assert all( [cat in filter_func_map for cat in args.category] ), f"Invalid category: {args.category}" results = {} + for cat in args.category: + values = battles.apply(filter_func_map[cat], axis=1) + # if all values are False, skip + print(f"Category {cat} has {values.sum()} battles") + if not any(values): + print(f"Skipping category {cat}") + continue for cat in args.category: filter_func = filter_func_map[cat] results[cat] = 
report_elo_analysis_results( @@ -756,3 +780,5 @@ def pretty_print_elo_rating(rating): with open(f"elo_results_{cutoff_date}.pkl", "wb") as fout: pickle.dump(results, fout) + + print(f"saved elo_results_{cutoff_date}.pkl") diff --git a/fastchat/serve/monitor/monitor.py b/fastchat/serve/monitor/monitor.py index f5842c7c4..75990ad7a 100644 --- a/fastchat/serve/monitor/monitor.py +++ b/fastchat/serve/monitor/monitor.py @@ -674,12 +674,13 @@ def highlight_top_3(s): style = style.apply(highlight_top_3, subset=[category]) if metric == "rating": - style = style.background_gradient( - cmap="Blues", - subset=category_names, - vmin=1150, - vmax=category_df[category_names].max().max(), - ) + for category in category_names: + style = style.background_gradient( + cmap="Blues", + subset=[category], + vmin=category_df[category].max() - 150, + vmax=category_df[category].max(), + ) return style @@ -705,7 +706,7 @@ def build_category_leaderboard_tab( headers=["Model"] + [key_to_category_name[k] for k in categories], datatype=["markdown"] + ["str" for k in categories], value=full_table_vals, - elem_id="full_leaderboard_dataframe", + elem_id="overview_leaderboard_dataframe", column_widths=[250] + categories_width, # IMPORTANT: THIS IS HARDCODED WITH THE CURRENT CATEGORIES height=800, @@ -731,6 +732,17 @@ def build_category_leaderboard_tab( ] selected_categories_width = [95, 85, 130, 75, 150, 100, 95, 100] +vision_categories = [ + "full", + "is_captioning", + "is_entity_recognition", + "is_ocr", + "is_creative_composition", + "if", + "no_refusal", +] +vision_categories_width = [90, 90, 90, 50, 80, 95, 90, 80] + language_categories = [ "english", "chinese", @@ -838,6 +850,16 @@ def build_leaderboard_tab( language_categories, language_categories_width, ) + if elo_results_vision is not None: + vision_combined_table = get_combined_table( + elo_results_vision, model_table_df + ) + build_category_leaderboard_tab( + vision_combined_table, + "Vision", + vision_categories, + vision_categories_width, + ) gr.Markdown( f""" ***Rank (UB)**: model's ranking (upper-bound), defined by one + the number of models that are statistically better than the target model. diff --git a/fastchat/serve/monitor/monitor_md.py b/fastchat/serve/monitor/monitor_md.py index 4708c2b2c..37cb1d9a9 100644 --- a/fastchat/serve/monitor/monitor_md.py +++ b/fastchat/serve/monitor/monitor_md.py @@ -27,6 +27,15 @@ "no_refusal": "Exclude Refusal", "overall_limit_5_user_vote": "overall_limit_5_user_vote", "full_old": "Overall (Deprecated)", + "full_style_control": "Overall (Style Control)", + "hard_6_style_control": "Hard Prompts (Overall) (Style Control)", + "exclude_preset": "Exclude Preset", + "is_captioning": "Captioning", + "is_entity_recognition": "Entity Recognition", + "is_ocr": "OCR", + "is_counting": "Counting", + "is_creative_composition": "Creative Writing", + "is_spatial_reasoning": "Spatial Reasoning", } cat_name_to_explanation = { "Overall": "Overall Questions", @@ -51,6 +60,13 @@ "Exclude Refusal": 'Exclude model responses with refusal (e.g., "I cannot answer")', "overall_limit_5_user_vote": "overall_limit_5_user_vote", "Overall (Deprecated)": "Overall without De-duplicating Top Redundant Queries (top 0.1%). See details in [blog post](https://lmsys.org/blog/2024-05-17-category-hard/#note-enhancing-quality-through-de-duplication).", + "Exclude Preset": "Exclude Preset Images", + "Captioning": "Open-Ended Captioning", + "Entity Recognition": "Entity Recognition (e.g. 
who is in the image)", + "OCR": "Optical Character Recognition", + "Counting": "Counting", + "Creative Writing": "Creative Writing (e.g. write a story about this image)", + "Spatial Reasoning": "Spatial Reasoning", } cat_name_to_baseline = { "Hard Prompts (English)": "English",