diff --git a/fastchat/serve/gradio_web_server.py b/fastchat/serve/gradio_web_server.py index 2ef47b14df..bb95f6afe8 100644 --- a/fastchat/serve/gradio_web_server.py +++ b/fastchat/serve/gradio_web_server.py @@ -601,6 +601,14 @@ def bot_response( font-size: 105%; } +#overview_leaderboard_dataframe table th { + font-size: 90%; +} + +#overview_leaderboard_dataframe table td { + font-size: 105%; +} + .tab-nav button { font-size: 18px; } diff --git a/fastchat/serve/monitor/classify/category.py b/fastchat/serve/monitor/classify/category.py index 223144a32f..e3a3b47399 100644 --- a/fastchat/serve/monitor/classify/category.py +++ b/fastchat/serve/monitor/classify/category.py @@ -11,6 +11,8 @@ import ast import re +from vision_utils import pil_to_base64, get_image_file_from_gcs + class Category: def __init__(self): @@ -24,6 +26,10 @@ def create_category(name): return CategoryIF() elif name == "math_v0.1": return CategoryMath() + elif name == "refusal_v0.1": + return CategoryRefusal() + elif name == "vision_v0.1": + return CategoryVision() raise Exception(f"Category name is incorrect: {name}") @@ -63,7 +69,7 @@ def get_score(self, judgment): def pre_process(self, prompt): conv = [{"role": "system", "content": self.sys_prompt}] - conv.append({"role": "user", "content": prompt}) + conv.append({"role": "user", "content": prompt["prompt"]}) return conv def post_process(self, judgment): @@ -87,10 +93,11 @@ def get_score(self, judgment): elif len(set(matches)) == 1: return int(matches[0]) else: + print("Error parsing IF") return None def pre_process(self, prompt): - args = {"PROMPT": prompt} + args = {"PROMPT": prompt["prompt"]} conv = [ {"role": "system", "content": self.system_prompt}, {"role": "user", "content": self.prompt_template.format(**args)}, @@ -124,7 +131,7 @@ def get_score(self, judgment): return None def pre_process(self, prompt): - args = {"PROMPT": prompt} + args = {"PROMPT": prompt["prompt"]} conv = [ {"role": "system", "content": self.system_prompt}, {"role": 
"user", "content": self.prompt_template.format(**args)}, @@ -134,3 +141,99 @@ def pre_process(self, prompt): def post_process(self, judgment): score = self.get_score(judgment=judgment) return {"math": bool(score == "yes") if score else False} + + +class CategoryRefusal(Category): + def __init__(self): + super().__init__() + self.name_tag = "refusal_v0.1" + self.pattern = re.compile(r"([a-z]+)") + self.system_prompt = 'You are tasked with determining if any of the given model responses are a refusal to answer. A refusal occurs when the model explicitly declines to answer or complete a task, due to reasons like safety, moderation, or model limitations (e.g. the user asks the model to search the web but it does not have that capability). You will be given the responses of 2 models, A and B and you are to determine if A refuses to answer, B refuses, both refuse, or neither refuse. Output your verdict in the following format:"\n[A/B/both/neither]\n". Do not respond with any additional information.' 
+ + self.prompt_template = "\n{RESPONSE_A}\n\n\n{RESPONSE_B}\n" + + def get_score(self, judgment): + match = self.pattern.search( + judgment.replace("\n", "").replace("[", "").replace("]", "").lower() + ) + if match: + return match.group(1) + return "error" + + def pre_process(self, prompt): + args = {"RESPONSE_A": prompt["response_a"], "RESPONSE_B": prompt["response_b"]} + conv = [ + {"role": "system", "content": self.system_prompt}, + {"role": "user", "content": self.prompt_template.format(**args)}, + ] + return conv + + def post_process(self, judgment): + score = self.get_score(judgment=judgment) + return { + "refusal_a": bool(score == "a") or bool(score == "both"), + "refusal_b": bool(score == "b") or bool(score == "both"), + "refusal": bool(score == "a") + or bool(score == "b") + or bool(score == "both"), + } + + +class CategoryVision(Category): + def __init__(self): + super().__init__() + self.name_tag = "vision_v0.1" + self.system_prompt = """You are an AI assistant specialized in classifying Visual Question Answering (VQA) questions into appropriate categories. When presented with a question or multiple questions about an image, you will analyze both the question and the image and categorize it based on the following criteria: + +Categories ([text only] means classification of this category should be based on the text question alone): +1. Captioning[text only]: Questions that ask for a general, overall description of the entire image. A captioning question must be a single, open-ended query that does NOT ask about particular objects, people, or parts of the image, nor require interpretation beyond a broad description of what is visually present. Examples include "What is happening in this image?", "Describe this picture.", "explain", etc. +2. Counting[text only]: Questions requiring counting or identifying the number of objects in the image. +3. Optical Character Recognition: Questions requiring reading and understanding text in the image to answer. 
If there is some amount of text in the image and the question requires reading the text in any capacity it should be classified as Optical Character Recognition. +4. Entity Recognition: Questions that ask for the identification of specific objects or people in the image. This does NOT include questions that ask for a general description of the image, questions that only ask for object counts, or questions that only require reading text in the image. +5. Spatial Reasoning[text only]: Questions that explicitly ask about the spatial relationships, locations, or arrangements of objects or elements within the image. This includes queries about relative positions (e.g., left, right, top, bottom), sizes, orientations, or distances between objects. +6. Creative Writing: Questions that explicitly ask for creative or imaginative responses based on the image, such as composing a story, poem, or providing a fictional interpretation. This excludes questions that simply ask for factual observations, interpretations, or speculations about the image content. + +Your task is to classify each question(s) into one or more of these categories. Provide your answer in the following format, with category names separated by commas and no additional information: + +{category name}, {category name} + +If none of the categories apply, enter 'Other'. + +Remember to consider all aspects of the question and assign all relevant categories. 
Do not answer the question, only classify it.""" + + self.prompt_template = "\n{PROMPT}\n" + + def get_score(self, judgment): + return judgment.replace("\n", "").replace("[text only]", "").lower() + + def pre_process(self, prompt): + args = {"PROMPT": prompt["prompt"]} + base64_image = get_image_file_from_gcs(prompt["image_hash"]) + conv = [ + {"role": "system", "content": self.system_prompt}, + { + "role": "user", + "content": [ + {"type": "text", "text": self.prompt_template.format(**args)}, + { + "type": "image_url", + "image_url": { + "url": f"data:image/jpeg;base64,{base64_image}", + }, + }, + ], + }, + ] + return conv + + def post_process(self, judgment): + score = self.get_score(judgment=judgment) + return { + "is_captioning": "captioning" in score, + "is_counting": "counting" in score, + "is_ocr": "optical character recognition" in score, + "is_entity_recognition": "entity recognition" in score, + "is_creative_composition": "creative writing" in score, + "is_spatial_reasoning": "spatial reasoning" in score, + "response": judgment, + } diff --git a/fastchat/serve/monitor/classify/label.py b/fastchat/serve/monitor/classify/label.py index 2d0471a1f1..ac876b173c 100644 --- a/fastchat/serve/monitor/classify/label.py +++ b/fastchat/serve/monitor/classify/label.py @@ -9,8 +9,17 @@ import random import threading import orjson +import hashlib +from PIL import Image from category import Category +from vision_utils import get_image_path + +import lmdb + +if not os.path.exists("cache/category_cache"): + os.makedirs("cache/category_cache") +category_cache = lmdb.open("cache/category_cache", map_size=1024**4) LOCK = threading.RLock() @@ -56,7 +65,6 @@ def chat_completion_openai(model, messages, temperature, max_tokens, api_dict=No output = API_ERROR_OUTPUT for _ in range(API_MAX_RETRY): try: - # print(messages) completion = client.chat.completions.create( model=model, messages=messages, @@ -65,7 +73,6 @@ def chat_completion_openai(model, messages, temperature, 
max_tokens, api_dict=No # extra_body={"guided_choice": GUIDED_CHOICES} if GUIDED_CHOICES else None, ) output = completion.choices[0].message.content - # print(output) break except openai.RateLimitError as e: print(type(e), e) @@ -107,7 +114,7 @@ def get_answer( output_log = {} for category in categories: - conv = category.pre_process(question["prompt"]) + conv = category.pre_process(question) output = chat_completion_openai( model=model_name, messages=conv, @@ -125,7 +132,7 @@ def get_answer( if testing: question["output_log"] = output_log - question.drop(["prompt", "uid", "required_tasks"], inplace=True) + # question.drop(["prompt", "uid", "required_tasks"], inplace=True) with LOCK: with open(answer_file, "a") as fout: @@ -165,10 +172,14 @@ def find_required_tasks(row): ] +import wandb + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--config", type=str, required=True) parser.add_argument("--testing", action="store_true") + parser.add_argument("--vision", action="store_true") + parser.add_argument("--wandb", action="store_true") args = parser.parse_args() enter = input( @@ -179,6 +190,14 @@ def find_required_tasks(row): config = make_config(args.config) + if not args.wandb: + os.environ["WANDB_MODE"] = "dryrun" + if args.wandb: + wandb.init( + project="arena", + name=config["input_file"].split("/")[-1].split(".")[0], + ) + API_MAX_RETRY = config["max_retry"] API_RETRY_SLEEP = config["retry_sleep"] API_ERROR_OUTPUT = config["error_output"] @@ -194,6 +213,26 @@ def find_required_tasks(row): data = orjson.loads(f.read()) input_data = pd.DataFrame(data) + if args.vision: + old_len = len(input_data) + input_data["image_hash"] = input_data.conversation_a.map( + lambda convo: convo[0]["content"][1][0] + ) + input_data["image_path"] = input_data.image_hash.map(get_image_path) + input_data = input_data[input_data.image_path != False].reset_index(drop=True) + print(f"{len(input_data)} out of {old_len}# images found") + + if args.testing: 
+ # remove output file if exists + if os.path.isfile(config["output_file"]): + os.remove(config["output_file"]) + if "category_tag" in input_data.columns: + input_data.drop(columns=["category_tag"], inplace=True) + input_data = input_data[input_data["language"] == "English"].reset_index( + drop=True + ) + input_data = input_data[:50] + # much faster than pd.apply input_data["uid"] = input_data.question_id.map(str) + input_data.tstamp.map(str) assert len(input_data) == len(input_data.uid.unique()) @@ -246,10 +285,25 @@ def find_required_tasks(row): f"{name}: {len(not_labeled[not_labeled.required_tasks.map(lambda tasks: name in tasks)])}" ) + get_content = lambda c: c if type(c) == str else c[0] not_labeled["prompt"] = not_labeled.conversation_a.map( - lambda convo: "\n".join([convo[i]["content"] for i in range(0, len(convo), 2)]) + lambda convo: "\n".join( + [get_content(convo[i]["content"]) for i in range(0, len(convo), 2)] + ) ) not_labeled["prompt"] = not_labeled.prompt.map(lambda x: x[:12500]) + not_labeled["response_a"] = not_labeled.conversation_a.map( + lambda convo: "\n".join( + [get_content(convo[i]["content"]) for i in range(1, len(convo), 2)] + ) + ) + not_labeled["response_a"] = not_labeled.response_a.map(lambda x: x[:12500]) + not_labeled["response_b"] = not_labeled.conversation_b.map( + lambda convo: "\n".join( + [get_content(convo[i]["content"]) for i in range(1, len(convo), 2)] + ) + ) + not_labeled["response_b"] = not_labeled.response_b.map(lambda x: x[:12500]) with concurrent.futures.ThreadPoolExecutor( max_workers=config["parallel"] @@ -277,7 +331,62 @@ def find_required_tasks(row): ): future.result() - if config["convert_to_json"]: + output = pd.read_json(config["output_file"], lines=True) + + # log table to wandb + if args.wandb: + + def replace_none_in_nested_dict(d): + if isinstance(d, dict): + return {k: replace_none_in_nested_dict(v) for k, v in d.items()} + elif isinstance(d, list): + return [replace_none_in_nested_dict(v) for v in d] + 
elif d is None: + return -1 # Replace None with -1 + else: + return d + + def process_category_tag(df): + df["category_tag"] = df["category_tag"].apply(replace_none_in_nested_dict) + return df + + # Use this function before logging to wandb + output = process_category_tag(output) + columns = ( + ["prompt", "response_a", "response_b", "tstamp", "category_tag"] + if not args.vision + else [ + "prompt", + "image", + "response_a", + "response_b", + "tstamp", + "category_tag", + ] + ) + if args.vision: + # # read image_path into wandb Image + # output["image"] = output.image_path.map(lambda x: wandb.Image(x)) + def is_valid_image(filepath): + try: + Image.open(filepath).verify() + return True + except Exception: + print(f"Invalid image: {filepath}") + return False + + if args.testing: + output["image"] = output.image_path.map( + lambda x: wandb.Image(x) + if os.path.exists(x) and is_valid_image(x) + else None + ) + else: + output["image"] = output.image_path + + wandb.log({"categories": wandb.Table(dataframe=output[columns])}) + + if config["convert_to_json"] and os.path.isfile(config["output_file"]): # merge two data frames, but only take the fields from the cache data to overwrite the input data merge_columns = [category.name_tag for category in categories] print(f"Columns to be merged:\n{merge_columns}") diff --git a/fastchat/serve/monitor/classify/vision_config.yaml b/fastchat/serve/monitor/classify/vision_config.yaml new file mode 100644 index 0000000000..a7890c0e98 --- /dev/null +++ b/fastchat/serve/monitor/classify/vision_config.yaml @@ -0,0 +1,27 @@ +# Yaml config file for category classification + +input_file: "../arena-data-analysis/data/vision_clean_battle_conv_20240822_with_image_hash.json" # json +cache_file: False # json +output_file: "fastchat/serve/monitor/classify/results/vision_clean_battle_conv_20240822_with_image_hash-labeled.json" # json line + +convert_to_json: True + +task_name: + - refusal_v0.1 + - criteria_v0.1 + - if_v0.1 + - vision_v0.1 + 
+ +model_name: gpt-4o +name: gpt-4o +endpoints: + - api_base: # BASE URL + api_key: # API KEY + +parallel: 50 +temperature: 0.0 +max_token: 512 + +max_retry: 2 +retry_sleep: 10 +error_output: $ERROR$ \ No newline at end of file diff --git a/fastchat/serve/monitor/elo_analysis.py b/fastchat/serve/monitor/elo_analysis.py index 5a1daaa9dd..d9564bd2fa 100644 --- a/fastchat/serve/monitor/elo_analysis.py +++ b/fastchat/serve/monitor/elo_analysis.py @@ -714,15 +714,40 @@ def pretty_print_elo_rating(rating): filter_func_map = { "full": lambda x: True, - "long": filter_long_conv, "chinese": lambda x: x["language"] == "Chinese", "english": lambda x: x["language"] == "English", + "russian": lambda x: x["language"] == "Russian", + "vietnamese": lambda x: x["language"] == "Vietnamese", + "multiturn": lambda x: x["turn"] > 1, + "exclude_preset": lambda x: not x["preset"], + "no_refusal": lambda x: not x["is_refusal"], + "is_captioning": lambda x: x["category_tag"]["vision_v0.1"]["is_captioning"], + "is_entity_recognition": lambda x: x["category_tag"]["vision_v0.1"][ + "is_entity_recognition" + ], + "is_ocr": lambda x: x["category_tag"]["vision_v0.1"]["is_ocr"], + "is_counting": lambda x: x["category_tag"]["vision_v0.1"]["is_counting"], + "is_creative_composition": lambda x: x["category_tag"]["vision_v0.1"][ + "is_creative_composition" + ], + "is_spatial_reasoning": lambda x: x["category_tag"]["vision_v0.1"][ + "is_spatial_reasoning" + ], + "if": lambda x: x["category_tag"]["if_v0.1"]["if"], + "math": lambda x: x["category_tag"]["math_v0.1"]["math"], } assert all( [cat in filter_func_map for cat in args.category] ), f"Invalid category: {args.category}" results = {} + # prune categories with zero battles so the analysis loop below skips them + for cat in list(args.category): + values = battles.apply(filter_func_map[cat], axis=1) + print(f"Category {cat} has {values.sum()} battles") + if not any(values): + print(f"Skipping category {cat}") + args.category.remove(cat) for cat in args.category: filter_func = filter_func_map[cat] results[cat] = 
report_elo_analysis_results( @@ -756,3 +781,5 @@ def pretty_print_elo_rating(rating): with open(f"elo_results_{cutoff_date}.pkl", "wb") as fout: pickle.dump(results, fout) + + print(f"saved elo_results_{cutoff_date}.pkl") diff --git a/fastchat/serve/monitor/monitor.py b/fastchat/serve/monitor/monitor.py index f5842c7c4f..75990ad7a5 100644 --- a/fastchat/serve/monitor/monitor.py +++ b/fastchat/serve/monitor/monitor.py @@ -674,12 +674,13 @@ def highlight_top_3(s): style = style.apply(highlight_top_3, subset=[category]) if metric == "rating": - style = style.background_gradient( - cmap="Blues", - subset=category_names, - vmin=1150, - vmax=category_df[category_names].max().max(), - ) + for category in category_names: + style = style.background_gradient( + cmap="Blues", + subset=[category], + vmin=category_df[category].max() - 150, + vmax=category_df[category].max(), + ) return style @@ -705,7 +706,7 @@ def build_category_leaderboard_tab( headers=["Model"] + [key_to_category_name[k] for k in categories], datatype=["markdown"] + ["str" for k in categories], value=full_table_vals, - elem_id="full_leaderboard_dataframe", + elem_id="overview_leaderboard_dataframe", column_widths=[250] + categories_width, # IMPORTANT: THIS IS HARDCODED WITH THE CURRENT CATEGORIES height=800, @@ -731,6 +732,17 @@ def build_category_leaderboard_tab( ] selected_categories_width = [95, 85, 130, 75, 150, 100, 95, 100] +vision_categories = [ + "full", + "is_captioning", + "is_entity_recognition", + "is_ocr", + "is_creative_composition", + "if", + "no_refusal", +] +vision_categories_width = [90, 90, 90, 50, 80, 95, 90, 80] + language_categories = [ "english", "chinese", @@ -838,6 +850,16 @@ def build_leaderboard_tab( language_categories, language_categories_width, ) + if elo_results_vision is not None: + vision_combined_table = get_combined_table( + elo_results_vision, model_table_df + ) + build_category_leaderboard_tab( + vision_combined_table, + "Vision", + vision_categories, + 
vision_categories_width, + ) gr.Markdown( f""" ***Rank (UB)**: model's ranking (upper-bound), defined by one + the number of models that are statistically better than the target model. diff --git a/fastchat/serve/monitor/monitor_md.py b/fastchat/serve/monitor/monitor_md.py index 4708c2b2c7..37cb1d9a9d 100644 --- a/fastchat/serve/monitor/monitor_md.py +++ b/fastchat/serve/monitor/monitor_md.py @@ -27,6 +27,15 @@ "no_refusal": "Exclude Refusal", "overall_limit_5_user_vote": "overall_limit_5_user_vote", "full_old": "Overall (Deprecated)", + "full_style_control": "Overall (Style Control)", + "hard_6_style_control": "Hard Prompts (Overall) (Style Control)", + "exclude_preset": "Exclude Preset", + "is_captioning": "Captioning", + "is_entity_recognition": "Entity Recognition", + "is_ocr": "OCR", + "is_counting": "Counting", + "is_creative_composition": "Creative Writing", + "is_spatial_reasoning": "Spatial Reasoning", } cat_name_to_explanation = { "Overall": "Overall Questions", @@ -51,6 +60,13 @@ "Exclude Refusal": 'Exclude model responses with refusal (e.g., "I cannot answer")', "overall_limit_5_user_vote": "overall_limit_5_user_vote", "Overall (Deprecated)": "Overall without De-duplicating Top Redundant Queries (top 0.1%). See details in [blog post](https://lmsys.org/blog/2024-05-17-category-hard/#note-enhancing-quality-through-de-duplication).", + "Exclude Preset": "Exclude Preset Images", + "Captioning": "Open-Ended Captioning", + "Entity Recognition": "Entity Recognition (e.g. who is in the image)", + "OCR": "Optical Character Recognition", + "Counting": "Counting", + "Creative Writing": "Creative Writing (e.g. write a story about this image)", + "Spatial Reasoning": "Spatial Reasoning", } cat_name_to_baseline = { "Hard Prompts (English)": "English",