Vision categories #3515

Open · wants to merge 15 commits into base: main
8 changes: 8 additions & 0 deletions fastchat/serve/gradio_web_server.py
@@ -601,6 +601,14 @@ def bot_response(
font-size: 105%;
}

#overview_leaderboard_dataframe table th {
font-size: 90%;
}

#overview_leaderboard_dataframe table td {
font-size: 105%;
}

.tab-nav button {
font-size: 18px;
}
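
Note: these #overview_leaderboard_dataframe selectors style a Gradio component declared with a matching elem_id. A minimal sketch of the hookup, with the exact component declaration assumed rather than taken from this hunk:

```python
import gradio as gr

# Hedged sketch: the CSS above targets a DataFrame created with this elem_id.
# The real declaration in gradio_web_server.py may differ.
with gr.Blocks(css="#overview_leaderboard_dataframe table td {font-size: 105%;}") as demo:
    overview_board = gr.DataFrame(elem_id="overview_leaderboard_dataframe")
```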
109 changes: 106 additions & 3 deletions fastchat/serve/monitor/classify/category.py
@@ -11,6 +11,8 @@
import ast
import re

from vision_utils import pil_to_base64, get_image_file_from_gcs


class Category:
def __init__(self):
@@ -24,6 +26,10 @@ def create_category(name):
return CategoryIF()
elif name == "math_v0.1":
return CategoryMath()
elif name == "refusal_v0.1":
return CategoryRefusal()
elif name == "vision_v0.1":
return CategoryVision()

raise Exception(f"Category name is incorrect: {name}")

@@ -63,7 +69,7 @@ def get_score(self, judgment):

def pre_process(self, prompt):
conv = [{"role": "system", "content": self.sys_prompt}]
conv.append({"role": "user", "content": prompt})
conv.append({"role": "user", "content": prompt["prompt"]})
return conv

def post_process(self, judgment):
@@ -87,10 +93,11 @@ def get_score(self, judgment):
elif len(set(matches)) == 1:
return int(matches[0])
else:
print("Error parsing IF")
return None

def pre_process(self, prompt):
args = {"PROMPT": prompt}
args = {"PROMPT": prompt["prompt"]}
conv = [
{"role": "system", "content": self.system_prompt},
{"role": "user", "content": self.prompt_template.format(**args)},
@@ -124,7 +131,7 @@ def get_score(self, judgment):
return None

def pre_process(self, prompt):
args = {"PROMPT": prompt}
args = {"PROMPT": prompt["prompt"]}
conv = [
{"role": "system", "content": self.system_prompt},
{"role": "user", "content": self.prompt_template.format(**args)},
@@ -134,3 +141,99 @@ def pre_process(self, prompt):
def post_process(self, judgment):
score = self.get_score(judgment=judgment)
return {"math": bool(score == "yes") if score else False}


class CategoryRefusal(Category):
def __init__(self):
super().__init__()
self.name_tag = "refusal_v0.1"
self.pattern = re.compile(r"<decision>([a-z]+)</decision>")
self.system_prompt = 'You are tasked with determining if any of the given model responses are a refusal to answer. A refusal occurs when the model explicitly declines to answer or complete a task, due to reasons like safety, moderation, or model limitations (e.g. the user asks the model to search the web but it does not have that capability). You will be given the responses of two models, A and B, and you are to determine if A refuses to answer, B refuses, both refuse, or neither refuses. Output your verdict in the following format: "<decision>\n[A/B/both/neither]\n</decision>". Do not respond with any additional information.'

self.prompt_template = "<model_A_response>\n{RESPONSE_A}\n</model_A_response>\n<model_B_response>\n{RESPONSE_B}\n</model_B_response>"

def get_score(self, judgment):
match = self.pattern.search(
judgment.replace("\n", "").replace("[", "").replace("]", "").lower()
)
if match:
return match.group(1)
return "error"

def pre_process(self, prompt):
args = {"RESPONSE_A": prompt["response_a"], "RESPONSE_B": prompt["response_a"]}
conv = [
{"role": "system", "content": self.system_prompt},
{"role": "user", "content": self.prompt_template.format(**args)},
]
return conv

def post_process(self, judgment):
score = self.get_score(judgment=judgment)
return {
"refusal_a": score in ("a", "both"),
"refusal_b": score in ("b", "both"),
"refusal": score in ("a", "b", "both"),
}


class CategoryVision(Category):
def __init__(self):
super().__init__()
self.name_tag = "vision_v0.1"
self.system_prompt = """You are an AI assistant specialized in classifying Visual Question Answering (VQA) questions into appropriate categories. When presented with a question or multiple questions about an image, you will analyze both the question and the image and categorize it based on the following criteria:

Categories ([text only] means classification of this category should be based on the text question alone):
1. Captioning[text only]: Questions that ask for a general, overall description of the entire image. A captioning question must be a single, open-ended query that does NOT ask about particular objects, people, or parts of the image, nor require interpretation beyond a broad description of what is visually present. Examples include "What is happening in this image?", "Describe this picture.", "explain", etc.
2. Counting[text only]: Questions requiring counting or identifying the number of objects in the image.
3. Optical Character Recognition: Questions requiring reading and understanding text in the image to answer. If there is some amount of text in the image and the question requires reading the text in any capacity, it should be classified as Optical Character Recognition.
4. Entity Recognition: Questions that ask for the identification of specific objects or people in the image. This does NOT include questions that ask for a general description of the image, questions that only ask for object counts, or questions that only require reading text in the image.
5. Spatial Reasoning[text only]: Questions that explicitly ask about the spatial relationships, locations, or arrangements of objects or elements within the image. This includes queries about relative positions (e.g., left, right, top, bottom), sizes, orientations, or distances between objects.
6. Creative Writing: Questions that explicitly ask for creative or imaginative responses based on the image, such as composing a story, poem, or providing a fictional interpretation. This excludes questions that simply ask for factual observations, interpretations, or speculations about the image content.

Your task is to classify each question into one or more of these categories. Provide your answer in the following format, with category names separated by commas and no additional information:

{category name}, {category name}

If none of the categories apply, enter 'Other'.

Remember to consider all aspects of the question and assign all relevant categories. Do not answer the question, only classify it."""

self.prompt_template = "<user_prompt>\n{PROMPT}\n</user_prompt>"

def get_score(self, judgment):
return judgment.replace("\n", "").replace("[text only]", "").lower()

def pre_process(self, prompt):
args = {"PROMPT": prompt["prompt"]}
base64_image = get_image_file_from_gcs(prompt["image_hash"])
conv = [
{"role": "system", "content": self.system_prompt},
{
"role": "user",
"content": [
{"type": "text", "text": self.prompt_template.format(**args)},
{
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{base64_image}",
},
},
],
},
]
return conv

def post_process(self, judgment):
score = self.get_score(judgment=judgment)
return {
"is_captioning": "captioning" in score,
"is_counting": "counting" in score,
"is_ocr": "optical character recognition" in score,
"is_entity_recognition": "entity recognition" in score,
"is_creative_composition": "creative writing" in score,
"is_spatial_reasoning": "spatial reasoning" in score,
"response": judgment,
}
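
A minimal, hypothetical round-trip for the new categories (illustration only, not part of the diff); the sample row and judgment strings below are made up, but follow the formats the system prompts above request:

```python
# Illustration only: exercise the new categories' parsing logic.
from category import Category

refusal = Category.create_category("refusal_v0.1")
row = {
    "response_a": "I'm sorry, I can't help with that request.",
    "response_b": "Sure, the answer is 42.",
}
conv = refusal.pre_process(row)  # [{"role": "system", ...}, {"role": "user", ...}]

judgment = "<decision>\n[A]\n</decision>"  # format requested by the system prompt
print(refusal.post_process(judgment))
# -> {"refusal_a": True, "refusal_b": False, "refusal": True}

vision = Category.create_category("vision_v0.1")
tags = vision.post_process("Counting, Optical Character Recognition")
# tags["is_counting"] and tags["is_ocr"] are True; the raw judgment is kept
# under tags["response"]. vision.pre_process additionally fetches the image
# from GCS by hash, which this sketch does not exercise.
```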
121 changes: 115 additions & 6 deletions fastchat/serve/monitor/classify/label.py
@@ -9,8 +9,17 @@
import random
import threading
import orjson
import hashlib
from PIL import Image

from category import Category
from vision_utils import get_image_path

import lmdb

if not os.path.exists("cache/category_cache"):
os.makedirs("cache/category_cache")
category_cache = lmdb.open("cache/category_cache", map_size=1024**4)


LOCK = threading.RLock()
@@ -56,7 +65,6 @@ def chat_completion_openai(model, messages, temperature, max_tokens, api_dict=No
output = API_ERROR_OUTPUT
for _ in range(API_MAX_RETRY):
try:
# print(messages)
completion = client.chat.completions.create(
model=model,
messages=messages,
Expand All @@ -65,7 +73,6 @@ def chat_completion_openai(model, messages, temperature, max_tokens, api_dict=No
# extra_body={"guided_choice": GUIDED_CHOICES} if GUIDED_CHOICES else None,
)
output = completion.choices[0].message.content
# print(output)
break
except openai.RateLimitError as e:
print(type(e), e)
@@ -107,7 +114,7 @@ def get_answer(
output_log = {}

for category in categories:
conv = category.pre_process(question["prompt"])
conv = category.pre_process(question)
output = chat_completion_openai(
model=model_name,
messages=conv,
Expand All @@ -125,7 +132,7 @@ def get_answer(
if testing:
question["output_log"] = output_log

question.drop(["prompt", "uid", "required_tasks"], inplace=True)
# question.drop(["prompt", "uid", "required_tasks"], inplace=True)

with LOCK:
with open(answer_file, "a") as fout:
@@ -165,10 +172,14 @@ def find_required_tasks(row):
]


import wandb

if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str, required=True)
parser.add_argument("--testing", action="store_true")
parser.add_argument("--vision", action="store_true")
parser.add_argument("--wandb", action="store_true")
args = parser.parse_args()

enter = input(
@@ -179,6 +190,14 @@ def find_required_tasks(row):

config = make_config(args.config)

if not args.wandb:
os.environ["WANDB_MODE"] = "dryrun"
if args.wandb:
wandb.init(
project="arena",
name=config["input_file"].split("/")[-1].split(".")[0],
)

API_MAX_RETRY = config["max_retry"]
API_RETRY_SLEEP = config["retry_sleep"]
API_ERROR_OUTPUT = config["error_output"]
@@ -194,6 +213,26 @@ def find_required_tasks(row):
data = orjson.loads(f.read())
input_data = pd.DataFrame(data)

if args.vision:
old_len = len(input_data)
input_data["image_hash"] = input_data.conversation_a.map(
lambda convo: convo[0]["content"][1][0]
)
input_data["image_path"] = input_data.image_hash.map(get_image_path)
input_data = input_data[input_data.image_path != False].reset_index(drop=True)
print(f"{len(input_data)} out of {old_len}# images found")

if args.testing:
# remove output file if exists
if os.path.isfile(config["output_file"]):
os.remove(config["output_file"])
if "category_tag" in input_data.columns:
input_data.drop(columns=["category_tag"], inplace=True)
input_data = input_data[input_data["language"] == "English"].reset_index(
drop=True
)
input_data = input_data[:50]

# much faster than pd.apply
input_data["uid"] = input_data.question_id.map(str) + input_data.tstamp.map(str)
assert len(input_data) == len(input_data.uid.unique())
@@ -246,10 +285,25 @@ def find_required_tasks(row):
f"{name}: {len(not_labeled[not_labeled.required_tasks.map(lambda tasks: name in tasks)])}"
)

get_content = lambda c: c if isinstance(c, str) else c[0]
not_labeled["prompt"] = not_labeled.conversation_a.map(
lambda convo: "\n".join([convo[i]["content"] for i in range(0, len(convo), 2)])
lambda convo: "\n".join(
[get_content(convo[i]["content"]) for i in range(0, len(convo), 2)]
)
)
not_labeled["prompt"] = not_labeled.prompt.map(lambda x: x[:12500])
not_labeled["response_a"] = not_labeled.conversation_a.map(
lambda convo: "\n".join(
[get_content(convo[i]["content"]) for i in range(1, len(convo), 2)]
)
)
not_labeled["response_a"] = not_labeled.response_a.map(lambda x: x[:12500])
not_labeled["response_b"] = not_labeled.conversation_b.map(
lambda convo: "\n".join(
[get_content(convo[i]["content"]) for i in range(1, len(convo), 2)]
)
)
not_labeled["response_b"] = not_labeled.response_b.map(lambda x: x[:12500])

with concurrent.futures.ThreadPoolExecutor(
max_workers=config["parallel"]
@@ -277,7 +331,62 @@
):
future.result()

if config["convert_to_json"]:
output = pd.read_json(config["output_file"], lines=True)

# log table to wandb
if args.wandb:

def replace_none_in_nested_dict(d):
if isinstance(d, dict):
return {k: replace_none_in_nested_dict(v) for k, v in d.items()}
elif isinstance(d, list):
return [replace_none_in_nested_dict(v) for v in d]
elif d is None:
return -1  # Replace None with -1
else:
return d

def process_category_tag(df):
df["category_tag"] = df["category_tag"].apply(replace_none_in_nested_dict)
return df

# Use this function before logging to wandb
output = process_category_tag(output)
columns = (
["prompt", "response_a", "response_b", "tstamp", "category_tag"]
if not args.vision
else [
"prompt",
"image",
"response_a",
"response_b",
"tstamp",
"category_tag",
]
)
if args.vision:
# # read image_path into wandb Image
# output["image"] = output.image_path.map(lambda x: wandb.Image(x))
def is_valid_image(filepath):
try:
Image.open(filepath).verify()
return True
except Exception:
print(f"Invalid image: {filepath}")
return False

if args.testing:
output["image"] = output.image_path.map(
lambda x: wandb.Image(x)
if os.path.exists(x) and is_valid_image(x)
else None
)
else:
output["image"] = output.image_path

wandb.log({"categories": wandb.Table(dataframe=output[columns])})

if config["convert_to_json"] and os.path.isfile(config["output_file"]):
# merge two data frames, but only take the fields from the cache data to overwrite the input data
merge_columns = [category.name_tag for category in categories]
print(f"Columns to be merged:\n{merge_columns}")
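
label.py now opens an LMDB environment (category_cache) at import time, but the visible hunks never read or write it. A hedged sketch of one plausible access pattern; the helper names and key scheme are assumptions, not code from this PR:

```python
# Hypothetical cache access for the LMDB environment opened in label.py.
import hashlib

import lmdb
import orjson

env = lmdb.open("cache/category_cache", map_size=1024**4)

def cache_key(uid: str, name_tag: str) -> bytes:
    # One entry per (question uid, category version); hashing keeps keys short.
    return hashlib.sha256(f"{uid}:{name_tag}".encode()).digest()

def cache_get(uid: str, name_tag: str):
    with env.begin() as txn:
        raw = txn.get(cache_key(uid, name_tag))
    return orjson.loads(raw) if raw is not None else None

def cache_put(uid: str, name_tag: str, tag: dict) -> None:
    with env.begin(write=True) as txn:
        txn.put(cache_key(uid, name_tag), orjson.dumps(tag))
```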
27 changes: 27 additions & 0 deletions fastchat/serve/monitor/classify/vision_config.yaml
@@ -0,0 +1,27 @@
# YAML config file for category classification

input_file: "../arena-data-analysis/data/vision_clean_battle_conv_20240822_with_image_hash.json" # json
cache_file: False # json
output_file: "fastchat/serve/monitor/classify/results/vision_clean_battle_conv_20240822_with_image_hash-labeled.json" # json line

convert_to_json: True

task_name:
- refusal_v0.1
- criteria_v0.1
- if_v0.1
- vision_v0.1

model_name: gpt-4o
name: gpt-4o
endpoints:
- api_base: # BASE URL
api_key: # API KEY

parallel: 50
temperature: 0.0
max_token: 512

max_retry: 2
retry_sleep: 10
error_output: $ERROR$
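
label.py resolves this file through make_config(args.config). The loader itself is not part of the diff; a minimal sketch assuming it is a thin PyYAML wrapper:

```python
# Assumed implementation of make_config: plain YAML -> dict, no validation.
import yaml

def make_config(config_path: str) -> dict:
    with open(config_path) as f:
        return yaml.safe_load(f)

config = make_config("fastchat/serve/monitor/classify/vision_config.yaml")
print(config["task_name"])
# -> ['refusal_v0.1', 'criteria_v0.1', 'if_v0.1', 'vision_v0.1']
```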