From 6900805de283ec68fefac481f9eef5a7b70cab3c Mon Sep 17 00:00:00 2001 From: lisadunlap Date: Mon, 14 Oct 2024 17:24:13 +0000 Subject: [PATCH] formatting --- fastchat/serve/monitor/classify/category.py | 130 ++++++++++++++---- .../serve/monitor/classify/vision_config.yaml | 16 ++- 2 files changed, 116 insertions(+), 30 deletions(-) diff --git a/fastchat/serve/monitor/classify/category.py b/fastchat/serve/monitor/classify/category.py index 81fb2e4b3..954c941e9 100644 --- a/fastchat/serve/monitor/classify/category.py +++ b/fastchat/serve/monitor/classify/category.py @@ -190,8 +190,8 @@ def post_process(self, judgment): or bool(score == "both"), } -class CategoryCaptioning(Category): +class CategoryCaptioning(Category): def __init__(self): super().__init__() self.name_tag = "captioning_v0.1" @@ -199,7 +199,6 @@ def __init__(self): self.system_prompt = "You are tasked with determining if a given VQA question is a captioning question. A captioning question asks for a general, overall description of the entire image. It must be a single, open-ended query that does NOT ask about particular objects, people, or parts of the image, nor require interpretation beyond a broad description of what is visually present. Examples include 'What is happening in this image?', 'Describe this picture.', 'Explain', etc. An example of a non-captioning question is 'Describe what is funny in this picture.' because it asks for a specific interpretation of the image content. \n\nOutput your verdict in the following format:\n[yes/no]\n. Do NOT explain." self.prompt_template = "\n{PROMPT}\n" - def get_score(self, judgment): matches = self.pattern.findall(judgment.replace("\n", "").lower()) matches = [m for m in matches if m != ""] @@ -221,9 +220,9 @@ def pre_process(self, prompt): def post_process(self, judgment): score = self.get_score(judgment=judgment) return {"captioning": bool(score == "yes") if score else False} - -class CategoryCounting(Category): + +class CategoryCounting(Category): def __init__(self): super().__init__() self.name_tag = "counting_v0.1" @@ -252,16 +251,16 @@ def pre_process(self, prompt): def post_process(self, judgment): score = self.get_score(judgment=judgment) return {"counting": bool(score == "yes") if score else False} - -class CategoryCreativeWriting(Category): + +class CategoryCreativeWriting(Category): def __init__(self): super().__init__() self.name_tag = "creative_writing_v0.1" self.pattern = re.compile(r"(\w+)<\/decision>") self.system_prompt = "You are tasked with determining if a given VQA question is a creative writing question. A creative writing question explicitly asks for creative or imaginative responses based on the image, such as composing a story, poem, or providing a fictional interpretation. This excludes questions that simply ask for factual observations, interpretations, or speculations about the image content.\n\nOutput your verdict in the following format:\n[yes/no]\n. Do NOT explain." self.prompt_template = "\n{PROMPT}\n" - + def get_score(self, judgment): matches = self.pattern.findall(judgment.replace("\n", "").lower()) matches = [m for m in matches if m != ""] @@ -283,9 +282,9 @@ def pre_process(self, prompt): def post_process(self, judgment): score = self.get_score(judgment=judgment) return {"creative_writing": bool(score == "yes") if score else False} - -class CategoryEntityRecognition(Category): + +class CategoryEntityRecognition(Category): def __init__(self): super().__init__() self.name_tag = "entity_recognition_v0.1" @@ -313,17 +312,20 @@ def pre_process(self, prompt): def post_process(self, judgment): score = self.get_score(judgment=judgment) return {"creative_writing": bool(score == "yes") if score else False} - + + import io import base64 + + def pil_to_base64(image): - buffered = io.BytesIO() - image.save(buffered, format="PNG") - img_str = base64.b64encode(buffered.getvalue()).decode() - return img_str - -class CategoryOpticalCharacterRecognition(Category): + buffered = io.BytesIO() + image.save(buffered, format="PNG") + img_str = base64.b64encode(buffered.getvalue()).decode() + return img_str + +class CategoryOpticalCharacterRecognition(Category): def __init__(self): super().__init__() self.name_tag = "ocr_v0.1" @@ -362,12 +364,90 @@ def pre_process(self, prompt): def post_process(self, judgment): score = self.get_score(judgment=judgment) - return { - "is_captioning": "captioning" in score, - "is_counting": "counting" in score, - "is_ocr": "optical character recognition" in score, - "is_entity_recognition": "entity recognition" in score, - "is_creative_composition": "creative writing" in score, - "is_spatial_reasoning": "spatial reasoning" in score, - "response": judgment, - } + return {"ocr": bool(score == "yes") if score else False} + + +class CategoryHumor(Category): + def __init__(self): + super().__init__() + self.name_tag = "humor_v0.1" + self.pattern = re.compile(r"(\w+)<\/decision>") + self.system_prompt = "You are tasked with determining if a given VQA question is a humor question. A humor question asks for a humorous or funny response based on the image or asks to understand what is funny about an image. This includes questions that ask to explain an image which is humorous, such as memes.\n\nOutput your verdict in the following format:\n[yes/no]\n. Do NOT explain." + self.prompt_template = "\n{PROMPT}\n" + + def get_score(self, judgment): + matches = self.pattern.findall(judgment.replace("\n", "").lower()) + matches = [m for m in matches if m != ""] + if len(set(matches)) == 0: + return None + elif len(set(matches)) == 1: + return matches[0] + else: + return None + + def pre_process(self, prompt): + args = {"PROMPT": prompt["prompt"]} + base64_image = get_image_file_from_gcs(prompt["image_hash"]) + conv = [ + {"role": "system", "content": self.system_prompt}, + { + "role": "user", + "content": [ + {"type": "text", "text": self.prompt_template.format(**args)}, + { + "type": "image_url", + "image_url": { + "url": f"data:image/jpeg;base64,{base64_image}", + }, + }, + ], + }, + ] + return conv + + def post_process(self, judgment): + score = self.get_score(judgment=judgment) + return {"humor": bool(score == "yes") if score else False} + + +class CategoryHomework(Category): + def __init__(self): + super().__init__() + self.name_tag = "homework_v0.1" + self.pattern = re.compile(r"(\w+)<\/decision>") + self.system_prompt = "You are tasked with determining if a given VQA question is a homework question. A homework question asks for explanations, solutions, or assistance with images that are likely from educational materials.\n\nOutput your verdict in the following format:\n[yes/no]\n. Do NOT explain." + self.prompt_template = "\n{PROMPT}\n" + + def get_score(self, judgment): + matches = self.pattern.findall(judgment.replace("\n", "").lower()) + matches = [m for m in matches if m != ""] + if len(set(matches)) == 0: + return None + elif len(set(matches)) == 1: + return matches[0] + else: + return None + + def pre_process(self, prompt): + args = {"PROMPT": prompt["prompt"]} + base64_image = get_image_file_from_gcs(prompt["image_hash"]) + conv = [ + {"role": "system", "content": self.system_prompt}, + { + "role": "user", + "content": [ + {"type": "text", "text": self.prompt_template.format(**args)}, + { + "type": "image_url", + "image_url": { + "url": f"data:image/jpeg;base64,{base64_image}", + }, + }, + ], + }, + ] + return conv + + def post_process(self, judgment): + score = self.get_score(judgment=judgment) + return {"homework": bool(score == "yes") if score else False} diff --git a/fastchat/serve/monitor/classify/vision_config.yaml b/fastchat/serve/monitor/classify/vision_config.yaml index a7890c0e9..c1f7ad5a9 100644 --- a/fastchat/serve/monitor/classify/vision_config.yaml +++ b/fastchat/serve/monitor/classify/vision_config.yaml @@ -2,7 +2,7 @@ input_file: "../arena-data-analysis/data/vision_clean_battle_conv_20240822_with_image_hash.json" # json cache_file: False # json -output_file: "fastchat/serve/monitor/classify/results/vision_clean_battle_conv_20240822_with_image_hash-labeled.json" # json line +output_file: "fastchat/serve/monitor/classify/results/vision_clean_battle_conv_20240822_with_image_hash-new-per-category-llama.json" # json line convert_to_json: True @@ -10,12 +10,18 @@ task_name: - refusal_v0.1 - criteria_v0.1 - if_v0.1 - - vision_v0.1 + - captioning_v0.1 + - homework_v0.1 + - ocr_v0.1 + - counting_v0.1 + - humor_v0.1 + - entity_recognition_v0.1 + - creative_writing_v0.1 -model_name: gpt-4o -name: gpt-4o +model_name: gpt-4o-mini +name: gpt-4o-mini endpoints: - - api_base: # BASE URL + - api_base: # API BASE api_key: # API KEY parallel: 50