Skip to content

Commit

Permalink
formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
lisadunlap committed Oct 14, 2024
1 parent bf26fc8 commit 6900805
Show file tree
Hide file tree
Showing 2 changed files with 116 additions and 30 deletions.
130 changes: 105 additions & 25 deletions fastchat/serve/monitor/classify/category.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,16 +190,15 @@ def post_process(self, judgment):
or bool(score == "both"),
}

class CategoryCaptioning(Category):

class CategoryCaptioning(Category):
def __init__(self):
    """Configure the captioning category: tag, verdict regex and judge prompts."""
    super().__init__()
    self.name_tag = "captioning_v0.1"
    # Captures the single word the judge places between the decision tags.
    self.pattern = re.compile(r"<decision>(\w+)<\/decision>")
    # The prompt is split across adjacent literals purely for readability;
    # the concatenated value is unchanged.
    self.system_prompt = (
        "You are tasked with determining if a given VQA question is a captioning question. "
        "A captioning question asks for a general, overall description of the entire image. "
        "It must be a single, open-ended query that does NOT ask about particular objects, people, or parts of the image, nor require interpretation beyond a broad description of what is visually present. "
        "Examples include 'What is happening in this image?', 'Describe this picture.', 'Explain', etc. "
        "An example of a non-captioning question is 'Describe what is funny in this picture.' because it asks for a specific interpretation of the image content. "
        "\n\nOutput your verdict in the following format:<decision>\n[yes/no]\n</decision>. Do NOT explain."
    )
    # {PROMPT} is filled in with the user's question text.
    self.prompt_template = "<user_prompt>\n{PROMPT}\n</user_prompt>"


def get_score(self, judgment):
matches = self.pattern.findall(judgment.replace("\n", "").lower())
matches = [m for m in matches if m != ""]
Expand All @@ -221,9 +220,9 @@ def pre_process(self, prompt):
def post_process(self, judgment):
    """Turn the raw judge output into the ``{"captioning": bool}`` label.

    The label is True only when ``get_score`` yields exactly ``"yes"``; a
    missing or any other verdict maps to False.
    """
    verdict = self.get_score(judgment=judgment)
    if not verdict:
        return {"captioning": False}
    return {"captioning": bool(verdict == "yes")}

class CategoryCounting(Category):


class CategoryCounting(Category):
def __init__(self):
super().__init__()
self.name_tag = "counting_v0.1"
Expand Down Expand Up @@ -252,16 +251,16 @@ def pre_process(self, prompt):
def post_process(self, judgment):
    """Turn the raw judge output into the ``{"counting": bool}`` label.

    The label is True only when ``get_score`` yields exactly ``"yes"``; a
    missing or any other verdict maps to False.
    """
    verdict = self.get_score(judgment=judgment)
    if not verdict:
        return {"counting": False}
    return {"counting": bool(verdict == "yes")}

class CategoryCreativeWriting(Category):


class CategoryCreativeWriting(Category):
def __init__(self):
    """Configure the creative-writing category: tag, verdict regex and judge prompts."""
    super().__init__()
    self.name_tag = "creative_writing_v0.1"
    # Captures the single word the judge places between the decision tags.
    self.pattern = re.compile(r"<decision>(\w+)<\/decision>")
    # The prompt is split across adjacent literals purely for readability;
    # the concatenated value is unchanged.
    self.system_prompt = (
        "You are tasked with determining if a given VQA question is a creative writing question. "
        "A creative writing question explicitly asks for creative or imaginative responses based on the image, such as composing a story, poem, or providing a fictional interpretation. "
        "This excludes questions that simply ask for factual observations, interpretations, or speculations about the image content."
        "\n\nOutput your verdict in the following format:<decision>\n[yes/no]\n</decision>. Do NOT explain."
    )
    # {PROMPT} is filled in with the user's question text.
    self.prompt_template = "<user_prompt>\n{PROMPT}\n</user_prompt>"

def get_score(self, judgment):
matches = self.pattern.findall(judgment.replace("\n", "").lower())
matches = [m for m in matches if m != ""]
Expand All @@ -283,9 +282,9 @@ def pre_process(self, prompt):
def post_process(self, judgment):
    """Turn the raw judge output into the ``{"creative_writing": bool}`` label.

    The label is True only when ``get_score`` yields exactly ``"yes"``; a
    missing or any other verdict maps to False.
    """
    verdict = self.get_score(judgment=judgment)
    if not verdict:
        return {"creative_writing": False}
    return {"creative_writing": bool(verdict == "yes")}

class CategoryEntityRecognition(Category):


class CategoryEntityRecognition(Category):
def __init__(self):
super().__init__()
self.name_tag = "entity_recognition_v0.1"
Expand Down Expand Up @@ -313,17 +312,20 @@ def pre_process(self, prompt):
def post_process(self, judgment):
    """Turn the raw judge output into the ``{"entity_recognition": bool}`` label.

    Args:
        judgment: Raw text returned by the judge model.

    Returns:
        A one-entry dict whose value is True only when ``get_score`` yields
        exactly ``"yes"``; a missing or any other verdict maps to False.
    """
    score = self.get_score(judgment=judgment)
    # The key must match this category (entity recognition). It was previously
    # "creative_writing", a copy-paste slip that made this label collide with
    # the creative-writing category's output.
    return {"entity_recognition": bool(score == "yes") if score else False}



import io
import base64


def pil_to_base64(image):
    """Encode ``image`` as PNG and return those bytes as base64 text.

    Args:
        image: Object exposing ``save(fp, format=...)``; presumably a PIL
            image, going by the function name (TODO confirm against callers).

    Returns:
        str: The base64 encoding of whatever ``image.save`` wrote.
    """
    png_buffer = io.BytesIO()
    image.save(png_buffer, format="PNG")
    raw_png = png_buffer.getvalue()
    return base64.b64encode(raw_png).decode()

class CategoryOpticalCharacterRecognition(Category):
buffered = io.BytesIO()
image.save(buffered, format="PNG")
img_str = base64.b64encode(buffered.getvalue()).decode()
return img_str


class CategoryOpticalCharacterRecognition(Category):
def __init__(self):
super().__init__()
self.name_tag = "ocr_v0.1"
Expand Down Expand Up @@ -362,12 +364,90 @@ def pre_process(self, prompt):

def post_process(self, judgment):
    """Turn the raw judge output into the ``{"ocr": bool}`` label.

    Args:
        judgment: Raw text returned by the judge model.

    Returns:
        A one-entry dict whose value is True only when ``get_score`` yields
        exactly ``"yes"``; a missing or any other verdict maps to False.
    """
    score = self.get_score(judgment=judgment)
    # The earlier multi-key result ("is_captioning", "is_counting", ...) is
    # dropped: it made this return unreachable and raised TypeError on
    # `"..." in None` whenever no verdict could be parsed.
    return {"ocr": bool(score == "yes") if score else False}


class CategoryHumor(Category):
    """Category that asks a judge model whether a VQA prompt is a humor question.

    The judge is instructed to answer inside ``<decision>`` tags; that answer
    is reduced to a single boolean stored under the ``"humor"`` key.
    """

    def __init__(self):
        """Set the category tag, the verdict regex and the judge prompts."""
        super().__init__()
        self.name_tag = "humor_v0.1"
        # Captures the single word the judge places between the decision tags.
        self.pattern = re.compile(r"<decision>(\w+)<\/decision>")
        # Split across adjacent literals for readability; value is unchanged.
        self.system_prompt = (
            "You are tasked with determining if a given VQA question is a humor question. "
            "A humor question asks for a humorous or funny response based on the image or asks to understand what is funny about an image. "
            "This includes questions that ask to explain an image which is humorous, such as memes."
            "\n\nOutput your verdict in the following format:<decision>\n[yes/no]\n</decision>. Do NOT explain."
        )
        # {PROMPT} is filled in with the user's question text.
        self.prompt_template = "<user_prompt>\n{PROMPT}\n</user_prompt>"

    def get_score(self, judgment):
        """Extract the judge's decision word from ``judgment``.

        Newlines are removed and the text is lower-cased before matching.

        Returns:
            The captured word when all decision tags agree on one value,
            otherwise ``None`` (no tag found, or conflicting tags).
        """
        flattened = judgment.replace("\n", "").lower()
        decisions = [word for word in self.pattern.findall(flattened) if word != ""]
        # Zero decisions, or several that disagree, both mean "no verdict".
        if len(set(decisions)) != 1:
            return None
        return decisions[0]

    def pre_process(self, prompt):
        """Build the two-message conversation (system + user) sent to the judge.

        Args:
            prompt: Mapping that provides the ``"prompt"`` and ``"image_hash"`` keys.

        Returns:
            A list with the system message and a user message whose content
            holds the templated question text and the image as a data URL.
        """
        question = prompt["prompt"]
        # NOTE(review): get_image_file_from_gcs is defined outside this block;
        # presumably it returns the image as base64 text - confirm.
        encoded_image = get_image_file_from_gcs(prompt["image_hash"])
        system_message = {"role": "system", "content": self.system_prompt}
        text_part = {
            "type": "text",
            "text": self.prompt_template.format(PROMPT=question),
        }
        # The data URL always declares image/jpeg, whatever the stored format is.
        image_part = {
            "type": "image_url",
            "image_url": {"url": f"data:image/jpeg;base64,{encoded_image}"},
        }
        user_message = {"role": "user", "content": [text_part, image_part]}
        return [system_message, user_message]

    def post_process(self, judgment):
        """Map the judge output to ``{"humor": bool}``; True only for a "yes" verdict."""
        verdict = self.get_score(judgment=judgment)
        if not verdict:
            return {"humor": False}
        return {"humor": bool(verdict == "yes")}


class CategoryHomework(Category):
    """Category that asks a judge model whether a VQA prompt is a homework question.

    The judge is instructed to answer inside ``<decision>`` tags; that answer
    is reduced to a single boolean stored under the ``"homework"`` key.
    """

    def __init__(self):
        """Set the category tag, the verdict regex and the judge prompts."""
        super().__init__()
        self.name_tag = "homework_v0.1"
        # Captures the single word the judge places between the decision tags.
        self.pattern = re.compile(r"<decision>(\w+)<\/decision>")
        # Split across adjacent literals for readability; value is unchanged.
        self.system_prompt = (
            "You are tasked with determining if a given VQA question is a homework question. "
            "A homework question asks for explanations, solutions, or assistance with images that are likely from educational materials."
            "\n\nOutput your verdict in the following format:<decision>\n[yes/no]\n</decision>. Do NOT explain."
        )
        # {PROMPT} is filled in with the user's question text.
        self.prompt_template = "<user_prompt>\n{PROMPT}\n</user_prompt>"

    def get_score(self, judgment):
        """Extract the judge's decision word from ``judgment``.

        Newlines are removed and the text is lower-cased before matching.

        Returns:
            The captured word when all decision tags agree on one value,
            otherwise ``None`` (no tag found, or conflicting tags).
        """
        flattened = judgment.replace("\n", "").lower()
        decisions = [word for word in self.pattern.findall(flattened) if word != ""]
        # Zero decisions, or several that disagree, both mean "no verdict".
        if len(set(decisions)) != 1:
            return None
        return decisions[0]

    def pre_process(self, prompt):
        """Build the two-message conversation (system + user) sent to the judge.

        Args:
            prompt: Mapping that provides the ``"prompt"`` and ``"image_hash"`` keys.

        Returns:
            A list with the system message and a user message whose content
            holds the templated question text and the image as a data URL.
        """
        question = prompt["prompt"]
        # NOTE(review): get_image_file_from_gcs is defined outside this block;
        # presumably it returns the image as base64 text - confirm.
        encoded_image = get_image_file_from_gcs(prompt["image_hash"])
        system_message = {"role": "system", "content": self.system_prompt}
        text_part = {
            "type": "text",
            "text": self.prompt_template.format(PROMPT=question),
        }
        # The data URL always declares image/jpeg, whatever the stored format is.
        image_part = {
            "type": "image_url",
            "image_url": {"url": f"data:image/jpeg;base64,{encoded_image}"},
        }
        user_message = {"role": "user", "content": [text_part, image_part]}
        return [system_message, user_message]

    def post_process(self, judgment):
        """Map the judge output to ``{"homework": bool}``; True only for a "yes" verdict."""
        verdict = self.get_score(judgment=judgment)
        if not verdict:
            return {"homework": False}
        return {"homework": bool(verdict == "yes")}
16 changes: 11 additions & 5 deletions fastchat/serve/monitor/classify/vision_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,26 @@

input_file: "../arena-data-analysis/data/vision_clean_battle_conv_20240822_with_image_hash.json" # json
cache_file: False # json
output_file: "fastchat/serve/monitor/classify/results/vision_clean_battle_conv_20240822_with_image_hash-labeled.json" # json line
output_file: "fastchat/serve/monitor/classify/results/vision_clean_battle_conv_20240822_with_image_hash-new-per-category-llama.json" # json line

convert_to_json: True

task_name:
- refusal_v0.1
- criteria_v0.1
- if_v0.1
- vision_v0.1
- captioning_v0.1
- homework_v0.1
- ocr_v0.1
- counting_v0.1
- humor_v0.1
- entity_recognition_v0.1
- creative_writing_v0.1

model_name: gpt-4o
name: gpt-4o
model_name: gpt-4o-mini
name: gpt-4o-mini
endpoints:
- api_base: # BASE URL
- api_base: # API BASE
api_key: # API KEY

parallel: 50
Expand Down

0 comments on commit 6900805

Please sign in to comment.